# R/combineChain.R

#' Combine the results of several independent sampler runs
#'
#' Stacks the chains of every run in `output` into one result with the same
#' structure as a single run, pools the acceptance statistics (weighted by
#' each run's `num.samp`) and recomputes R-hat diagnostics over all chains.
#'
#' @param output A non-empty list of run results. Each element is read for
#'   `call`, `args`, `initial.component`, `final.component`, `time`,
#'   `iter` (`num.samp`, `n.half`, `save.iter`, `n.sweep`, `n.iter`,
#'   `n.burnin`, `n.temps`), `diagnostics$move.type` (`accept.ratio`,
#'   `accept.pct`), `lp`, `numComp` and `correlations$raw$Y` /
#'   `correlations$raw$X`, which are 4-d arrays whose 4th dimension indexes
#'   chains. All runs must agree on every dimension except the chain count.
#' @return A list with elements `call` and `args` (from the first run, with
#'   `args$n.chain` removed), `initial.component` and `final.component` (one
#'   entry per run), `diagnostics` (pooled acceptance stats and named
#'   `rhat`, with `NA` values dropped), `correlations` (`mean` and stacked
#'   `raw`), `lp`, `numComp` (one column per chain), `time` and `iter`
#'   (`num.samp` summed over runs, everything else from the first run).
combineChain <- function(output) {
  stopifnot(is.list(output), length(output) > 0L)

  # NOTE(review): runs are assumed to share the same call/args; this is not
  # verified, and the first run's values are returned.
  calls <- lapply(output, function(o) o$call)
  args <- lapply(output, function(o) {
    a <- o$args
    a$n.chain <- NULL  # the combined chain count is derived from the arrays
    a
  })
  initial.comp <- lapply(output, function(o) o$initial.component)

  # Per-run sample counts; also the weights for the pooled acceptance stats.
  samp.per.run <- sapply(output, function(o) o$iter$num.samp)
  num.samp <- sum(samp.per.run)

  # Where each run's chains land in the stacked arrays. seq_len() keeps a run
  # with zero chains from producing a bogus descending `ch:(ch - 1)` range.
  chains.per.run <- vapply(output,
                           function(o) dim(o$correlations$raw$Y)[4L],
                           integer(1))
  n.chain <- sum(chains.per.run)
  chain.offset <- cumsum(chains.per.run) - chains.per.run

  first.raw <- output[[1]]$correlations$raw
  dimsYcorr <- dim(first.raw$Y)
  dimsXcorr <- dim(first.raw$X)
  Q <- dimsXcorr[1]
  P <- dimsXcorr[2]
  dimsYcorr[4] <- dimsXcorr[4] <- n.chain

  # Label the chain dimension 1..n.chain. Writing by position also works
  # when the input arrays carry no dimnames at all.
  with_chain_names <- function(dn) {
    if (is.null(dn)) dn <- vector("list", 4L)
    dn[[4L]] <- seq_len(n.chain)
    dn
  }
  Y_corr <- array(NA, dim = dimsYcorr,
                  dimnames = with_chain_names(dimnames(first.raw$Y)))
  X_corr <- array(NA, dim = dimsXcorr,
                  dimnames = with_chain_names(dimnames(first.raw$X)))
  log_prob <- matrix(NA, nrow = nrow(output[[1]]$lp), ncol = n.chain)
  numComp <- matrix(NA, nrow = nrow(output[[1]]$numComp), ncol = n.chain)

  for (i in seq_along(output)) {
    cols <- chain.offset[i] + seq_len(chains.per.run[i])
    run <- output[[i]]
    Y_corr[, , , cols] <- run$correlations$raw$Y
    X_corr[, , , cols] <- run$correlations$raw$X
    log_prob[, cols] <- drop(run$lp)
    numComp[, cols] <- drop(run$numComp)
  }

  # Mean over the 3rd dimension and all chains; the divisor assumes the 3rd
  # dimension has extent 2 (the R-hat code below reads slices 1 and 2).
  Y_corr_mean <- apply(Y_corr, 1:2, sum) / (2 * n.chain)
  X_corr_mean <- apply(X_corr, 1:2, sum) / (2 * n.chain)

  accept.ratio <- sapply(output,
                         function(o) o$diagnostics$move.type$accept.ratio)
  AR <- weighted.mean(accept.ratio, samp.per.run)
  accept.pct <- sapply(output,
                       function(o) o$diagnostics$move.type$accept.pct) %*%
    samp.per.run / num.samp

  n.half <- output[[1]]$iter$n.half

  # Monitored quantities: the strict upper triangle of Y_corr and every
  # entry of the Q x P X_corr block. Masks/indices come from the dimensions,
  # not from slices like X_corr[, , 1, 1] that drop to vectors at extent 1.
  ut <- upper.tri(matrix(NA, dimsYcorr[1], dimsYcorr[2]))
  ntriY <- sum(ut)
  nX <- P * Q

  # Reshape to (element, slice, chain): merging the first two dimensions of
  # a column-major array is a pure re-dimensioning, so element order matches
  # c(Y_corr[, , h, ch]).
  Y_corr_upper <- array(Y_corr, dim = c(length(ut), dimsYcorr[3], n.chain))[
    c(ut), 1:2, , drop = FALSE]
  X_corr_upper <- array(X_corr, dim = c(nX, dimsXcorr[3], n.chain))[
    , 1:2, , drop = FALSE]

  # seq_len() keeps the index ranges empty (rather than c(1, 0)) when there
  # is a single Y variable, i.e. ntriY == 0.
  y.idx <- seq_len(ntriY)
  x.idx <- ntriY + seq_len(nX)
  rhat <- numeric(ntriY + nX + 2)
  if (ntriY > 0) rhat[y.idx] <- corr_rhat_rfun(Y_corr_upper, n.half)
  if (nX > 0) rhat[x.idx] <- corr_rhat_rfun(X_corr_upper, n.half)
  rhat[ntriY + nX + 1] <- split_rhat_rfun(log_prob)
  rhat[ntriY + nX + 2] <- split_rhat_rfun(numComp)
  names(rhat) <- c(
    sprintf("Y_corr %d, %d", row(ut)[ut], col(ut)[ut]),
    sprintf("X_corr %d, %d",
            rep(seq_len(Q), times = P), rep(seq_len(P), each = Q)),
    "log_prob",
    "number components"
  )
  rhat <- rhat[!is.na(rhat)]

  diagnostics <- list(
    move.type = list(accept.ratio = AR, accept.pct = accept.pct),
    move1 = NULL,
    move2 = NULL,
    rhat = rhat
  )

  first.iter <- output[[1]]$iter
  list(
    call = calls[[1]],
    args = args[[1]],
    initial.component = initial.comp,
    final.component = lapply(output, function(o) o$final.component),
    diagnostics = diagnostics,
    correlations = list(mean = list(Y = Y_corr_mean, X = X_corr_mean),
                        raw = list(Y = Y_corr, X = X_corr)),
    lp = log_prob,
    numComp = numComp,
    time = sapply(output, function(o) o$time),
    iter = list(num.samp = num.samp,
                n.half = first.iter$n.half,
                save.iter = first.iter$save.iter,
                n.sweep = first.iter$n.sweep,
                n.iter = first.iter$n.iter,
                n.burnin = first.iter$n.burnin,
                n.temps = first.iter$n.temps)
  )
}
# eifer4/stochasticSampling documentation built on May 14, 2019, 11:16 a.m.