# R/stochasticPartitioningSampler.R

#' Stochastic Partition Sampling
#'
#' Samples from the model configuration space of a multivariate normal outcome
#' by running \code{n.chain} MCMC chains through \code{chainFunction},
#' optionally with tempering over \code{n.temps} temperatures.
#'
#' @param X Covariate matrix with one row per observation. Only used when
#'   \code{invariantData} is \code{NULL}.
#' @param Y Outcome matrix with the same number of rows as \code{X}. Only used
#'   when \code{invariantData} is \code{NULL}.
#' @param invariantData Optional list of precomputed data summaries shared by
#'   all chains. It must contain the elements \code{N}, \code{P}, \code{Q},
#'   \code{Y}, \code{XcolSum}, \code{XcolCross}, \code{XYcross}, \code{A},
#'   \code{V_part} and \code{V_part_sq}. When supplied, \code{X} and \code{Y}
#'   are ignored.
#' @param priors List of prior hyperparameters. \code{priors$alpha0} and
#'   \code{priors$h0} are read here; the whole list is passed on to
#'   \code{chainFunction}.
#' @param comp Optional starting component configuration, reused for every
#'   chain and temperature. If \code{NULL}, starting components are generated
#'   with \code{components(Q, P)}.
#' @param n.iter Iterations per chain. It is rounded to a whole number of
#'   sweeps.
#' @param n.burnin Burn-in iterations. Defaults to \code{floor(n.iter / 2)}.
#' @param sweep Iterations per sweep when tempering. It is set to
#'   \code{n.iter} when \code{tempering = "none"}.
#' @param n.chain Number of chains.
#' @param save.iter Interval at which the log probability and number of
#'   components are stored.
#' @param out.file Currently unused.
#' @param parallel Logical. For \code{tempering} \code{"none"} or
#'   \code{"sequential"}, chains are run with \code{parallel::mcmapply}.
#'   Must be \code{TRUE} for \code{tempering = "parallel"}.
#' @param tempering One of \code{"none"} (also \code{FALSE} or \code{NULL}),
#'   \code{"sequential"}, \code{"parallel"}, or \code{"mpi"}. \code{"mpi"} is
#'   currently run as \code{"parallel"}.
#' @param n.temps Number of temperatures. The temperatures are
#'   \code{1 / 1.02^(0:(n.temps - 1))}.
#' @param ... Currently unused; recorded in the returned call.
#'
#' @return A list with elements \code{call}, \code{args},
#'   \code{initial.component}, \code{final.component}, \code{diagnostics}
#'   (acceptance ratio and \code{rhat} values), \code{correlations} (mean and
#'   per-half Y and X correlation estimates), \code{lp}, \code{numComp},
#'   \code{time} and \code{iter} (the iteration settings actually used).
#'
#' @import Rcpp RcppEigen parallel
#'
#' @export
stochPartMCMC <- function(X = NULL, Y = NULL, invariantData = NULL, priors,
                          comp = NULL, n.iter = 1E5, n.burnin = NULL,
                          sweep = 1000, n.chain = 10, save.iter = 1000,
                          out.file = NULL, parallel = FALSE,
                          tempering = c(FALSE, "none", "sequential", "parallel", "mpi"),
                          n.temps = 10, ...) {

  # missing() is only reliable before a formal argument is reassigned,
  # so record it up front.
  tempering_supplied <- !missing(tempering)
  caller_env <- parent.frame()

  # save initial call
  init.call <- match.call(stochPartMCMC, expand.dots = TRUE)

  # Save the initial args. Evaluate them in the caller's frame: eval() called
  # directly from lapply() would use lapply's own frame, where e.g. `X` is
  # lapply's argument, and caller-local variables would not be found.
  args <- lapply(as.list(init.call), eval, envir = caller_env)

  # if X is part of the stochPart class, then pull out old data (future)
  # for now can feed invariant data and then the final chain info as initial chain info

  # On an uncaught error, dump the frames to error.RData in the working
  # directory for post-mortem debugging. The previous handler is restored
  # on exit.
  old_error_opt <- options(error = function() {
    dump.frames()
    error.save <- as.list(last.dump[[1]])
    save(error.save, file = paste0(getwd(), "/error.RData"))
    # .rs.breakOnError only exists inside RStudio; calling it unguarded
    # would make the handler itself fail elsewhere.
    if (exists(".rs.breakOnError", mode = "function")) {
      get(".rs.breakOnError", mode = "function")(FALSE)
    }
  })
  on.exit(options(old_error_opt), add = TRUE)

  # ---- Data that is the same across chains and iterations ----
  required_invariant <- c("N", "P", "Q", "Y", "XcolSum", "XcolCross",
                          "XYcross", "A", "V_part", "V_part_sq")
  if (is.null(invariantData)) {
    if (is.null(X) || is.null(Y)) {
      stop("Either X and Y, or invariantData, must be supplied.")
    }
    if (nrow(X) != nrow(Y)) stop("Data (X and Y) must have same number of rows!")
    Q <- ncol(Y)
    P <- ncol(X)
    N <- nrow(X)
    # NOTE(review): colSum is not base R (base is colSums); presumably a
    # package helper -- confirm.
    XcolSum <- colSum(X)
    Xcross <- crossprod(X)
    XcolCross <- tcrossprod(XcolSum)
    XYcross <- crossprod(X, Y)
    A <- Acalc(XcolSum, XcolCross, Xcross, priors$h0, N)
    V_part <- priors$alpha0 / priors$h0 + colSum(Y)
    V_part_sq <- V_part^2

    invariantData <- list(N = N, Q = Q, P = P,
                          Y = Y, XcolSum = XcolSum, XcolCross = XcolCross,
                          XYcross = XYcross, A = A, V_part = V_part,
                          V_part_sq = V_part_sq)
    rm(X, Y)
  } else {
    # Every required element must be present; extra elements are allowed.
    missing_names <- setdiff(required_invariant, names(invariantData))
    if (length(missing_names) > 0) {
      stop(paste("Names of invariantData must include",
                 paste(required_invariant, collapse = ", "),
                 "(missing:", paste(missing_names, collapse = ", "), ")"))
    }
    N <- invariantData$N
    Q <- invariantData$Q
    P <- invariantData$P
  }

  # ---- Tempering mode: reduce to "none", "sequential" or "parallel" ----
  run_parallel <- isTRUE(parallel)
  if (is.null(tempering) || isFALSE(tempering)) tempering <- "none"
  if (length(tempering) > 1) {
    # Only warn when the user passed several values, not for the default.
    if (tempering_supplied) warning("Only first argument of tempering will be used.")
    tempering <- tempering[1]
  }
  tempering <- as.character(tempering)
  # The default vector contains FALSE, which c() coerces to "FALSE".
  if (identical(tempering, "FALSE")) tempering <- "none"
  tempering <- match.arg(tempering,
                         choices = c("none", "sequential", "parallel", "mpi"))
  # No MPI backend exists yet; run it as parallel tempering.
  if (tempering == "mpi") tempering <- "parallel"
  if (tempering == "parallel" && !run_parallel) {
    stop('tempering = "parallel" requires parallel = TRUE', call. = FALSE)
  }

  # ---- Iteration settings ----
  if (is.null(sweep) || sweep == 0 || sweep > n.iter) sweep <- n.iter
  if (tempering == "none") sweep <- n.iter
  # Guard against a zero-length sweep when n.iter < 10.
  if (tempering != "none" && sweep == n.iter) sweep <- max(1, floor(n.iter / 10))
  n.sweep <- round(n.iter / sweep) # number of sweeps if tempering
  n.iter <- n.sweep * sweep # makes sure you have whole numbers of sweeps

  if (is.null(n.burnin)) n.burnin <- floor(n.iter / 2)
  progress.iter <- ceiling(n.iter / 10)
  # Post-burn-in iterations split into two halves for split diagnostics.
  n.half <- round((n.iter - n.burnin) / 2)
  n.half <- c(n.half, n.iter - n.burnin - n.half)
  half.switch <- n.half[1] + n.burnin
  n.save <- ceiling((n.iter - n.burnin) / save.iter)

  # ---- Tracked quantities ----
  Y_corr <- array(0, dim = c(Q, Q, 2, n.chain),
                  dimnames = list(yrows = NULL, ycols = NULL, half = 1:2, chains = 1:n.chain))
  X_corr <- array(0, dim = c(Q, P, 2, n.chain),
                  dimnames = list(yrows = NULL, xcols = NULL, half = 1:2, chains = 1:n.chain))
  log_prob <- numComp <- matrix(NA, nrow = n.save, ncol = n.chain)
  tempMonitor <- array(0, dim = c(n.temps, n.temps, n.chain),
                       dimnames = list(temps = 1:n.temps, temp.chain = 1:n.temps, chains = 1:n.chain))

  # ---- Starting components for each chain and temperature ----
  if (is.null(comp)) {
    componentTotal <- lapply(seq_len(n.chain), function(i) {
      lapply(seq_len(n.temps), function(j) components(Q, P))
    })
    comp <- lapply(componentTotal, function(ch) lapply(ch, function(x) x$comp))
  } else {
    comp <- lapply(seq_len(n.chain), function(i) lapply(seq_len(n.temps), function(t) comp))
  }
  initial.comp <- comp

  # quantities to diagnose MCMC performance
  AR <- MS1.ratio <- MS2.ratio <- NULL
  accept <- numeric(2)
  upperY <- upper.tri(Y_corr[, , 1, 1])
  ntriY <- sum(upperY)
  rhat <- numeric(ntriY + P * Q + 2)
  Y_corr_rhat <- matrix(0, ncol = Q, nrow = Q)
  X_corr_rhat <- matrix(NA, ncol = P, nrow = Q)

  # count of partitions with fully saturated models
  pXcount <- lapply(comp, function(x) lapply(x, pX, p = P))
  # count of components with more than 0 X variables included
  nC <- lapply(comp, function(x) lapply(x, nComp))
  # count of components with only 1 Y
  p1Y <- lapply(comp, function(x) lapply(x, pY, p = 1))

  # vector of temperatures: 1, 1/1.02, 1/1.02^2, ...
  temps <- 1 / (1.02^(seq_len(n.temps) - 1))
  tempIdx <- lapply(seq_len(n.chain), function(x) 1:n.temps)

  # Per-chain, per-temperature state handed to chainFunction.
  componentData <- lapply(seq_len(n.chain), function(ch) {
    lapply(seq_len(n.temps), function(tm) {
      list(comp = comp[[ch]][[tm]],
           compCount = list(pXcount = pXcount[[ch]][[tm]], # fully saturated components
                            nC = nC[[ch]][[tm]],           # length X>0, length Y>0
                            p1Y = p1Y[[ch]][[tm]],         # length Y == 1
                            K = length(comp[[ch]][[tm]])   # total components
           ))
    })
  })

  iterCond <- list(progress.iter = progress.iter,
                   save.iter = save.iter,
                   n.save = n.save,
                   n.burnin = n.burnin,
                   half.switch = half.switch,
                   n.sweep = n.sweep,
                   n.temps = n.temps,
                   sweep = sweep)

  # out list from function
  out <- vector("list", n.chain)

  # ---- Start MCMC ----
  cat("\nBegining MCMC ", date(), " \n")
  time <- proc.time()

  if (tempering == "none" && run_parallel) {
    # NOTE(review): this passes temps = 1 / tempIdx = 1, while the serial
    # "none" path passes the full temps vector -- confirm which is intended.
    out <- parallel::mcmapply(chainFunction, componentData = componentData,
                              chain = seq_len(n.chain),
                              MoreArgs = list(priors = priors,
                                              invariantData = invariantData,
                                              iterCond = iterCond, temps = 1,
                                              tempIdx = 1, parallel = FALSE),
                              SIMPLIFY = FALSE)
  } else if (tempering == "sequential" && run_parallel) {
    out <- parallel::mcmapply(chainFunction, componentData = componentData,
                              chain = seq_len(n.chain), tempIdx = tempIdx,
                              MoreArgs = list(priors = priors,
                                              invariantData = invariantData,
                                              iterCond = iterCond,
                                              temps = temps, parallel = FALSE),
                              SIMPLIFY = FALSE)
  } else {
    # Chains run one after another. This covers "none"/"sequential" without
    # parallel, and "parallel" tempering, where chainFunction itself gets
    # parallel = TRUE.
    for (ch in seq_len(n.chain)) {
      out[[ch]] <- chainFunction(componentData = componentData[[ch]],
                                 priors = priors,
                                 invariantData = invariantData,
                                 chain = ch, iterCond = iterCond,
                                 temps = temps, tempIdx = tempIdx[[ch]],
                                 parallel = (tempering == "parallel"))
    }
  }

  # mcmapply returns failed workers as "try-error" objects; surface them.
  failed <- vapply(out, function(o) inherits(o, "try-error"), logical(1))
  if (any(failed)) {
    stop("Chain(s) ", paste(which(failed), collapse = ", "), " failed: ",
         paste(unique(vapply(out[failed], as.character, character(1))), collapse = "; "))
  }

  # save data from sampling
  for (ch in seq_len(n.chain)) {
    Y_corr[, , , ch] <- out[[ch]]$Y_corr
    X_corr[, , , ch] <- out[[ch]]$X_corr
    log_prob[, ch] <- out[[ch]]$log_prob
    numComp[, ch] <- out[[ch]]$numComp
    tempMonitor[, , ch] <- out[[ch]]$tempMonitor
    accept <- accept + out[[ch]]$accept
  }

  # save mean correlations
  Y_corr_mean <- apply(Y_corr, 1:2, sum) / ((n.iter - n.burnin) * n.chain)
  X_corr_mean <- apply(X_corr, 1:2, sum) / ((n.iter - n.burnin) * n.chain)

  # change from counts to correlations (per half)
  for (i in 1:2) {
    Y_corr[, , i, ] <- Y_corr[, , i, ] / n.half[i]
    X_corr[, , i, ] <- X_corr[, , i, ] / n.half[i]
  }

  # timings (round to whole minutes first so minutes never reads 60)
  total.time <- (proc.time() - time)[3]
  total.minutes <- round(total.time / 60)
  hours <- total.minutes %/% 60
  minutes <- total.minutes %% 60
  cat("\nFinished in ", hours, " hours and ", minutes, " minutes.\n", date(), "\n")

  # ---- MCMC diagnostics ----
  AR <- accept[1] / accept[2]

  if (n.chain > 1) {
    n_x <- P * Q
    y_idx <- seq_len(ntriY)
    x_idx <- ntriY + seq_len(n_x)
    lp_idx <- ntriY + n_x + 1
    nc_idx <- ntriY + n_x + 2

    Y_corr_upper <- array(NA, dim = c(ntriY, 2, n.chain))
    X_corr_upper <- array(NA, dim = c(n_x, 2, n.chain))
    for (ch in seq_len(n.chain)) {
      for (h in 1:2) {
        Y_corr_upper[, h, ch] <- c(Y_corr[, , h, ch][upperY])
        X_corr_upper[, h, ch] <- c(X_corr[, , h, ch])
      }
    }

    if (ntriY > 0) rhat[y_idx] <- corr_rhat_rfun(Y_corr_upper, n.half)
    if (n_x > 0) rhat[x_idx] <- corr_rhat_rfun(X_corr_upper, n.half)
    rhat[lp_idx] <- split_rhat_rfun(log_prob)
    rhat[nc_idx] <- split_rhat_rfun(numComp)

    # Build labels from index vectors so that Q == 1 or P == 1 (dropped
    # dimensions) and empty ranges are handled correctly.
    y_names <- if (ntriY > 0) {
      paste0("Y_corr ", row(upperY)[upperY], ", ", col(upperY)[upperY])
    } else {
      character(0)
    }
    x_names <- if (n_x > 0) {
      paste0("X_corr ", rep(seq_len(Q), times = P), ", ", rep(seq_len(P), each = Q))
    } else {
      character(0)
    }
    names(rhat) <- c(y_names, x_names, "log_prob", "number components")

    # Fill the matrix views BEFORE dropping NA entries; dropping first
    # would shift positions and misplace values.
    Y_corr_rhat[upperY] <- rhat[y_idx]
    X_corr_rhat[seq_len(n_x)] <- rhat[x_idx]
    rhat <- rhat[!is.na(rhat)]
  }

  list(call = init.call,
       args = args,
       initial.component = initial.comp,
       final.component = lapply(out, function(o) o$comp),
       diagnostics = list(move.type = list(accept.ratio = AR,
                                           accept.pct = accept / (n.iter * n.chain)),
                          move1 = MS1.ratio,
                          move2 = MS2.ratio,
                          rhat = rhat,
                          Y_corr_rhat = Y_corr_rhat,
                          X_corr_rhat = X_corr_rhat),
       correlations = list(mean = list(Y = Y_corr_mean, X = X_corr_mean),
                           raw = list(Y = Y_corr, X = X_corr)),
       lp = log_prob,
       numComp = numComp,
       time = total.time,
       iter = list(num.samp = (n.iter - n.burnin) * n.chain,
                   n.half = n.half,
                   save.iter = save.iter,
                   n.sweep = n.sweep,
                   n.iter = n.iter,
                   n.burnin = n.burnin,
                   n.temps = n.temps,
                   progress.iter = progress.iter,
                   n.save = n.save,
                   half.switch = half.switch,
                   sweep = sweep))
}
# eifer4/stochasticSampling documentation built on May 14, 2019, 11:16 a.m.