# R/chainFunction.R

#' Run one chain of repeated `partitionSampler` sweeps over a set of
#' (optionally tempered) sub-chains and collect post-burn-in summaries.
#'
#' Each of the `iterCond$n.sweep` sweeps runs `partitionSampler` for
#' `iterCond$sweep` iterations on every sub-chain. Once the running iteration
#' counter exceeds `iterCond$n.burnin`, the output of the sub-chain whose
#' temperature index is 1 is accumulated/stored. When more than one sub-chain
#' is present, `tempExchange` is called after every sweep.
#'
#' @param componentData List with one element per sub-chain; each element is
#'   the state handed to `partitionSampler`.
#' @param priors Passed through to `partitionSampler` and `tempExchange`.
#' @param invariantData List; `Q` and `P` are read here to size the
#'   correlation accumulators, and it is passed through to the samplers.
#' @param chain Chain identifier, forwarded to `progressFunction`.
#' @param iterCond List of iteration settings. This function reads `n.sweep`,
#'   `n.temp`, `n.burnin`, `n.save` and `sweep`.
#' @param temps Vector of temperatures.
#' @param tempIdx Integer vector giving, for each sub-chain, its index into
#'   `temps`. Samples are saved from the sub-chain with index 1.
#' @param parallel If `TRUE`, sub-chains run via `parallel::mcmapply()`.
#' @return A list with:
#'   - `Y_corr` (Q x Q x 2) and `X_corr` (Q x P x 2): summed arrays.
#'   - `log_prob` and `numComp`: saved traces.
#'   - `accept`: summed acceptance counts.
#'   - `tempMonitor`: sweeps each sub-chain spent at each temperature index,
#'     counted only when exchange is active.
#'   - `comp`: the final component data of every sub-chain, simplified by
#'     `sapply()`.
#' @noRd
chainFunction <- function(componentData, priors, invariantData, chain, iterCond, temps, tempIdx, parallel=FALSE) {

  samplerData <- lapply(componentData, function(x) list(componentData = x))
  tidx <- tempIdx
  # Temperature exchange is only meaningful with more than one sub-chain.
  tExch <- length(tidx) > 1

  n.sweep <- iterCond$n.sweep
  n.temp <- iterCond$n.temp
  n.burnin <- iterCond$n.burnin
  n.save <- iterCond$n.save
  sweep <- iterCond$sweep

  iter <- 0
  sidx <- 1
  Q <- invariantData$Q
  P <- invariantData$P
  Y_corr <- array(0, dim = c(Q, Q, 2),
                  dimnames = list(yrows = NULL, ycols = NULL, half = 1:2))
  X_corr <- array(0, dim = c(Q, P, 2),
                  dimnames = list(yrows = NULL, xcols = NULL, half = 1:2))
  log_prob <- numComp <- numeric(n.save)
  accept <- numeric(2)
  tempMonitor <- matrix(0, nrow = length(temps), ncol = length(temps))

  # seq_len() so that n.sweep == 0 runs no sweeps (1:0 would run two).
  for (s in seq_len(n.sweep)) {
    if (tExch) {
      # Record one sweep for sub-chain tm at its current temperature index.
      for (tm in seq_along(tidx)) {
        tempMonitor[tm, tidx[tm]] <- tempMonitor[tm, tidx[tm]] + 1
      }
    }

    if (parallel) {
      compDat <- lapply(samplerData, function(x) x$componentData)
      samplerData <- parallel::mcmapply(
        partitionSampler,
        componentData = compDat, tempIdx = tidx,
        MoreArgs = list(priors = priors, invariantData = invariantData,
                        n.iter = sweep, total.iter = iter,
                        iterCond = iterCond,
                        # NOTE(review): the sequential branch passes the
                        # scalar temps[tidx[tm]]; this passes the whole
                        # vector. Confirm which partitionSampler expects.
                        temp = temps),
        SIMPLIFY = FALSE)
      # Sum acceptance counts exactly as the sequential branch does, so
      # `accept` is not silently left at zero in parallel mode.
      for (res in samplerData) accept <- res$accept + accept
    } else {
      for (tm in seq_len(n.temp)) {
        samplerData[[tm]] <- partitionSampler(samplerData[[tm]]$componentData,
                                              priors, invariantData,
                                              sweep, iter, iterCond,
                                              temp = temps[tidx[tm]], tidx[tm])
        accept <- samplerData[[tm]]$accept + accept
      }
    }

    # NOTE(review): advances by sweep - 1 rather than sweep; with sweep == 1
    # `iter` never grows and nothing is ever saved. Confirm against how
    # partitionSampler uses total.iter before changing.
    iter <- iter + sweep - 1

    if (iter > n.burnin) {
      saveidx <- which(tidx == 1)
      saved <- samplerData[[saveidx]]
      n.new <- saved$n.save

      Y_corr[, , ] <- Y_corr[, , ] + saved$Y_corr
      X_corr[, , ] <- X_corr[, , ] + saved$X_corr
      # With n.new == 0, sidx:(sidx - 1) is a length-2 index and assigning a
      # zero-length vector into it errors; skip the trace update instead.
      if (n.new > 0) {
        new.sidx <- sidx + n.new - 1
        log_prob[sidx:new.sidx] <- saved$log_prob
        numComp[sidx:new.sidx] <- saved$numComp
        sidx <- new.sidx + 1
      }
    }

    progressFunction(iter, n.burnin, chain)

    # NOTE(review): the return value of tempExchange is discarded, so `tidx`
    # never changes here. Confirm whether it should be assigned back.
    if (tExch) tempExchange(invariantData = invariantData, comp = samplerData,
                            priors = priors, tempIdx = tidx, temps = temps)
  }

  list(Y_corr = Y_corr,
       X_corr = X_corr,
       log_prob = log_prob,
       numComp = numComp,
       accept = accept,
       tempMonitor = tempMonitor,
       comp = sapply(samplerData, function(s) s$componentData))
}
# eifer4/stochasticSampling documentation built on May 14, 2019, 11:16 a.m.