R/fusion_SMC_BLR.R

Defines functions get_resampled_top_level_SMC hierarchical_fusion_SMC_TA_BLR par_fusion_SMC_TA_BLR Q_importance_sample

Documented in get_resampled_top_level_SMC hierarchical_fusion_SMC_TA_BLR par_fusion_SMC_TA_BLR

# rho_importance_sample <- function(N,
#                                   dim,
#                                   time,
#                                   m,
#                                   samples_to_fuse,
#                                   sub_posterior_weights) {
#   # variables for first importance sampling step
#   x_samples <- rep(list(matrix(nrow = m, ncol = dim)), N)
#   log_rho_weights <- rep(0, N)
#   
#   # first importance sampling step
#   for (i in 1:N) {
#     # sample particle
#     for (j in 1:m) {
#       x_samples[[i]][j,] <- samples_to_fuse[[j]][sample(nrow(samples_to_fuse[[j]]), 1),]
#     }
#     # calculate the associated weight for the particle
#     # calculate the weighted average of the sub-posterior samples
#     weighted_avg <- BayesLogitFusion:::weighted_column_mean(matrix = x_samples[[i]],
#                                                             weights = sub_posterior_weights)
#     # calculate the first acceptance probability (weight for first particle set)
#     log_rho_weights[i] <- BayesLogitFusion:::log_rho_time_adapting(x = x_samples[[i]],
#                                                                    weighted_mean = weighted_avg,
#                                                                    sub_posterior_weights = sub_posterior_weights,
#                                                                    time = time)
#   }
#   
#   return(list('x_samples' = x_samples,
#               'rho_weights' = exp(log_rho_weights)))
# }

# Second (Q) importance sampling step of time-adapting SMC fusion for a batch of particles.
#
# For every particle x_samples[[i]] (a matrix whose c-th row is the value for
# sub-posterior c), a proposal y is drawn from a Gaussian centred at the weighted
# column mean of the particle with covariance (time / sum(sub_posterior_weights)) * I,
# and is weighted by the product over c = 1, ..., m of diffusion_probability_BLR(...)
# evaluated from x_samples[[i]][c, ] to y over [T* - T_c, T*], where
# T_c = time / sub_posterior_weights[c] and T* = max_c T_c.
#
# Arguments mirror those of par_fusion_SMC_TA_BLR; x_samples is the list of particles
# to process (possibly empty) and level/node/core are only used for progress logging.
# samples_to_fuse is accepted for interface compatibility but is not used here.
#
# Returns a list with
#   y_samples: length(x_samples) x dim matrix of proposals (one row per particle)
#   Q_weights: vector of unnormalised Q weights (one per particle)
#
# Side effect: appends progress lines to 'Q_importance_sample_progress.txt' in the
# working directory.
Q_importance_sample <- function(dim,
                                y_split,
                                X_split,
                                prior_means,
                                prior_variances,
                                time,
                                m,
                                C,
                                power,
                                precondition_matrices,
                                K,
                                samples_to_fuse,
                                sub_posterior_weights,
                                x_samples,
                                level = 1,
                                node = 1,
                                core = 1) {
  # variables for second importance sampling step
  N <- length(x_samples)
  y_samples <- matrix(data = NA, nrow = N, ncol = dim)
  Q_weights <- rep(1, N)

  # finding T_{c}'s and storing them in new_times
  new_times <- time / sub_posterior_weights
  # setting T^{*} (T-star) as the largest T_{c} = (time / w_{c}) for c = 1, ..., m
  T_star <- max(new_times)
  # the proposal covariance is the same for every particle, so build it only once
  proposal_cov <- diag(time / sum(sub_posterior_weights), dim)

  # second importance sampling step
  # (seq_len so that an empty batch of particles performs no iterations)
  for (i in seq_len(N)) {
    # calculate the weighted average of the sub-posterior samples
    weighted_avg <- BayesLogitFusion:::weighted_column_mean(matrix = x_samples[[i]],
                                                            weights = sub_posterior_weights)
    # simulate proposed value y from a Gaussian distribution
    y_samples[i, ] <- BayesLogitFusion:::mvrnormArma(N = 1,
                                                     mu = weighted_avg,
                                                     Sigma = proposal_cov)
    # simulate m diffusions and obtain their probabilities (weight for second particle set)
    # Q_weights[i] is a product of m diffusion probabilities
    for (c in seq_len(m)) {
      Q_weights[i] <- Q_weights[i] * diffusion_probability_BLR(dim = dim,
                                                               x0 = x_samples[[i]][c, ],
                                                               y = y_samples[i, ],
                                                               s = T_star - new_times[c],
                                                               t = T_star,
                                                               K = K[c],
                                                               initial_parameters = weighted_avg,
                                                               y_labels = y_split[[c]],
                                                               X = X_split[[c]],
                                                               prior_means = prior_means,
                                                               prior_variances = prior_variances,
                                                               C = C,
                                                               power = power,
                                                               precondition_mat = precondition_matrices[[c]])
    }

    cat('Level:', level, '|| Node:', node, '|| Core:', core, '||', i, '/', N, '\n',
        file = 'Q_importance_sample_progress.txt', append = TRUE)
  }

  cat('Core', core, 'finished sampling \n',
      file = 'Q_importance_sample_progress.txt', append = TRUE)

  return(list('y_samples' = y_samples,
              'Q_weights' = Q_weights))
}

######################################## parallelised time adapting SMC fusion ########################################

#' Parallel Time-adapting Sequential Monte Carlo Fusion for Bayesian Logistic Regression model
#'
#' Fuses m sub-posteriors in two importance sampling steps: a rho step
#' (\code{rho_importance_sample}, run serially) followed by a Q step
#' (\code{Q_importance_sample}, run in parallel over \code{n_cores} workers).
#' After each step the particles are resampled only if the effective sample size
#' is below N*ESS_threshold. Progress is appended to 'fusion_SMC_TA_BLR_progress.txt'
#' in the working directory.
#'
#' @param N number of samples
#' @param dim dimension of the predictors (= p+1)
#' @param y_split list of length m, where y_split[[c]] is the y responses for sub-posterior c
#' @param X_split list of length m, where X_split[[c]] is the design matrix for sub-posterior c
#' @param prior_means prior for means of predictors
#' @param prior_variances prior for variances of predictors
#' @param time time T for fusion algorithm
#' @param m number of sub-posteriors to combine
#' @param C overall number of sub-posteriors
#' @param power exponent of target distribution
#' @param precondition boolean value determining whether or not a preconditioning matrix is to be used
#' @param samples_to_fuse list of length m, where samples_to_fuse[[c]] contains the samples (one row per sample,
#'                        dim columns) for the c-th sub-posterior
#' @param sub_posterior_weights vector of length m, where sub_posterior_weights[c] is the weight for sub-posterior c
#' @param ESS_threshold number between 0 and 1 defining the proportion of the number of samples that ESS needs to be
#'                      lower than for resampling (i.e. resampling is carried out only when ESS < N*ESS_threshold)
#' @param seed seed number - default is NULL, meaning there is no seed
#' @param level indicates which level this is for the hierarchy (default 1)
#' @param node indicates which node this is for the hierarchy (default 1)
#' @param n_cores number of cores to use
#'
#' @return A list with components:
#' \describe{
#'   \item{samples}{samples from fusion (resampled if ESS < N*ESS_threshold)}
#'   \item{weighted_samples}{weighted samples (the proposals from the Q step, before any resampling)}
#'   \item{Q_weights}{normalised Q weights (after Q step)}
#'   \item{x_samples}{samples after the first fusion step (after rho step)}
#'   \item{rho_weights}{normalised rho weights (after rho step)}
#'   \item{ESS}{effective sample size after rho and Q step}
#'   \item{combined_y}{combined y responses after fusion}
#'   \item{combined_X}{combined design matrix after fusion}
#'   \item{time}{elapsed run time (from \code{proc.time}) of the two importance sampling steps}
#'   \item{resampled}{named logical vector (elements 'rho' and 'Q') indicating whether or not samples were
#'                    resampled after each step}
#'   \item{precondition_matrices}{pre-conditioning matrices that were used}
#' }
#'
#' @export
par_fusion_SMC_TA_BLR <- function(N,
                                  dim,
                                  y_split,
                                  X_split,
                                  prior_means,
                                  prior_variances,
                                  time,
                                  m,
                                  C,
                                  power,
                                  precondition = FALSE,
                                  samples_to_fuse,
                                  sub_posterior_weights,
                                  ESS_threshold = 0.5,
                                  seed = NULL,
                                  level = 1,
                                  node = 1,
                                  n_cores = parallel::detectCores()) {
  if (length(samples_to_fuse) != m) {
    stop("par_fusion_SMC_TA_BLR: samples_to_fuse must be a list of length m")
  } else if (length(y_split) != m) {
    stop("par_fusion_SMC_TA_BLR: y_split must be a list of length m")
  } else if (length(X_split) != m) {
    stop("par_fusion_SMC_TA_BLR: X_split must be a list of length m")
  } else if (length(sub_posterior_weights) != m) {
    stop("par_fusion_SMC_TA_BLR: sub_posterior_weights must be a vector of length m")
  }

  # check that all the samples in samples_to_fuse have dim columns
  # (isTRUE also rejects objects without columns, for which ncol() is NULL)
  for (c in seq_along(samples_to_fuse)) {
    if (!isTRUE(ncol(samples_to_fuse[[c]]) == dim)) {
      stop("par_fusion_SMC_TA_BLR: check that samples_to_fuse contains matrices with dim columns")
    }
  }

  # set a seed if one is supplied
  if (!is.null(seed)) {
    set.seed(seed)
  }

  # if using a preconditioning matrix, we need to calculate them
  if (precondition) {
    precondition_matrices <- lapply(seq_len(m), function(c) {
      # diagonal matrix of the sample standard deviations of the c-th sub-posterior;
      # nrow = dim keeps this a dim x dim matrix when dim = 1 (diag() of a single
      # number would otherwise build an identity matrix of that size)
      # NOTE(review): this was previously described as the "sqrt inverse covariance",
      # but the (non-inverted) standard deviations are used - confirm which is intended
      diag(sqrt(diag(cov(samples_to_fuse[[c]]))), nrow = dim)
    })
  } else {
    # dont want to use a preconditioning matrix - i.e. use the identity matrix
    precondition_matrices <- lapply(seq_len(m), function(c) diag(1, dim))
  }

  # need to find the lower bound K for the Exact Algorithm (second importance sampling step for Q)
  K <- rep(0, m)
  for (c in seq_len(m)) {
    K[c] <- BayesLogitFusion:::phi_LB_BLR(X = X_split[[c]],
                                          prior_variances = prior_variances,
                                          C = C,
                                          power = power,
                                          precondition_mat = precondition_matrices[[c]])
  }

  cat('Level:', level, '||', 'Node:', node, '|| Starting fusion for level',
      level, '( node', node, ') trying to get', N, 'samples\n',
      file = 'fusion_SMC_TA_BLR_progress.txt', append = TRUE)

  # start time recording
  pcm <- proc.time()
  ########## first importance sampling step
  rho_weighted_samples <- BayesLogitFusion:::rho_importance_sample(N = N,
                                                                   dim = dim,
                                                                   time = time,
                                                                   m = m,
                                                                   samples_to_fuse = samples_to_fuse,
                                                                   sub_posterior_weights = sub_posterior_weights)

  # normalise weights
  rho_weights <- rho_weighted_samples$rho_weights / sum(rho_weighted_samples$rho_weights)
  # check that the weights are not NaN (can happen if we have numerical error)
  if (any(is.na(rho_weights))) {
    stop("par_fusion_SMC_TA_BLR: rho_weights are NaN - potential numerical error. Try increasing value for T")
  }
  # named logical vector recording whether the samples were resampled after each step
  resampled <- c()
  # only resample if ESS < N*ESS_threshold
  ESS_rho <- 1 / sum(rho_weights^2)
  if (ESS_rho < N*ESS_threshold) {
    resampled['rho'] <- TRUE
    x_samples <- sample(x = rho_weighted_samples$x_samples, size = N, replace = TRUE, prob = rho_weights)
  } else {
    resampled['rho'] <- FALSE
    x_samples <- rho_weighted_samples$x_samples
  }

  # print completion of first importance sampling step
  cat('Level:', level, '||', 'Node:', node, '|| Finished first importance sampling step \n',
      file = 'fusion_SMC_TA_BLR_progress.txt', append = TRUE)

  ########## second importance sampling step
  # creating parallel cluster
  cl <- parallel::makeCluster(n_cores)
  # make sure the workers are shut down on exit, even if a later step throws an error
  on.exit(parallel::stopCluster(cl), add = TRUE)
  # creating variable and functions list to pass into cluster using clusterExport
  varlist <- list("phi_LB_BLR",
                  "bound_phi_BLR",
                  "Q_importance_sample",
                  "diffusion_probability_BLR",
                  "precondition_matrices")
  parallel::clusterExport(cl, envir = environment(), varlist = varlist)
  # exporting functions from layeredBB package to simulate layered Brownian bridges
  parallel::clusterExport(cl, varlist = ls("package:layeredBB"))

  if (!is.null(seed)) {
    # setting seed for the cores in the cluster
    parallel::clusterSetRNGStream(cl, iseed = seed)
  }

  # split the (possibly resampled) x values into approximately equal lists
  max_samples_per_core <- ceiling(N/n_cores)
  split_x_samples <- split(x_samples, ceiling(seq_along(x_samples)/max_samples_per_core))

  # for each particle x, we propose a new value y and assign a weight for it
  # sample for y and importance weight in parallel
  Q_weighted_samples <- parallel::parLapply(cl, X = seq_along(split_x_samples), fun = function(i) {
    Q_importance_sample(dim = dim,
                        y_split = y_split,
                        X_split = X_split,
                        prior_means = prior_means,
                        prior_variances = prior_variances,
                        time = time,
                        m = m,
                        C = C,
                        power = power,
                        precondition_matrices = precondition_matrices,
                        K = K,
                        samples_to_fuse = samples_to_fuse,
                        sub_posterior_weights = sub_posterior_weights,
                        x_samples = split_x_samples[[i]],
                        level = level,
                        node = node,
                        core = i)
  })
  final <- proc.time() - pcm

  ########## gather the samples and weights from all cores

  y_samples <- do.call(rbind, lapply(Q_weighted_samples, function(res) res[['y_samples']]))
  Q_weights <- unlist(lapply(Q_weighted_samples, function(res) res[['Q_weights']]))

  # print completion of second importance sampling step
  cat('Level:', level, '||', 'Node:', node, '|| Finished second importance sampling step \n',
      file = 'fusion_SMC_TA_BLR_progress.txt', append = TRUE)

  ########## return importance samples

  # normalise weights
  Q_weights <- Q_weights / sum(Q_weights)
  # check that the weights are not NaN (can happen if we have numerical error)
  if (any(is.na(Q_weights))) {
    stop("par_fusion_SMC_TA_BLR: Q_weights are NaN - potential numerical error. Try decreasing value for T")
  }

  # only resample if ESS < N*ESS_threshold
  ESS_Q <- 1 / sum(Q_weights^2)
  if (ESS_Q < N*ESS_threshold) {
    resampled['Q'] <- TRUE
    # drop = FALSE keeps a one-column matrix when dim = 1
    resample_idx <- sample.int(nrow(y_samples), size = N, replace = TRUE, prob = Q_weights)
    fusion_samples <- y_samples[resample_idx, , drop = FALSE]
  } else {
    resampled['Q'] <- FALSE
    fusion_samples <- y_samples
  }

  # return the full data set combined together
  # (do.call(append, y_split) only works for m = 2: a third element would be taken
  # as append()'s 'after' argument, so concatenate with c() instead)
  combined_y <- do.call(c, unname(y_split))
  combined_X <- do.call(rbind, X_split)

  # print completion
  cat('Level:', level, '||', 'Node:', node, '|| Completed SMC fusion \n',
      file = 'fusion_SMC_TA_BLR_progress.txt', append = TRUE)

  return(list('samples' = fusion_samples,
              'weighted_samples' = y_samples,
              'Q_weights' = Q_weights,
              'x_samples' = x_samples,
              'rho_weights' = rho_weights,
              'ESS' = c('rho' = ESS_rho, 'Q' = ESS_Q),
              'combined_y' = combined_y,
              'combined_X' = combined_X,
              'time' = final['elapsed'],
              'resampled' = resampled,
              'precondition_matrices' = precondition_matrices))
}

######################################## time adapting SMC hierarchical fusion ########################################

#' Time-adapting Hierarchical Monte Carlo Fusion using SMC for Bayesian Logistic Regression model
#'
#' Level L of the hierarchy is the base level holding \code{base_samples}. For each level
#' k = L-1, ..., 1 the nodes of level k+1 are fused in consecutive groups of m_schedule[k]
#' using \code{par_fusion_SMC_TA_BLR}. Progress is written to 'hierarchical_fusion_SMC_TA_BLR.txt'
#' in the working directory.
#'
#' @param N_schedule vector of length (L-1), where N_schedule[l] is the number of samples per node at level l
#' @param dim dimension of the predictors (= p+1)
#' @param y_split list of length C, where y_split[[c]] is the y responses for sub-posterior c
#' @param X_split list of length C, where X_split[[c]] is the design matrix for sub-posterior c
#' @param prior_means prior for means of predictors
#' @param prior_variances prior for variances of predictors
#' @param global_T time T for time-adapting fusion algorithm
#' @param m_schedule vector of length (L-1), where m_schedule[l] is the number of samples to fuse for level l
#' @param C number of sub-posteriors at the base level
#' @param power exponent of target distribution
#' @param precondition boolean value determining whether or not a preconditioning matrix is to be used
#' @param L total number of levels in the hierarchy (must be at least 2)
#' @param base_samples list of length C, where base_samples[[c]] contains the samples for the c-th sub-posterior
#'                     at the base level
#' @param ESS_threshold number between 0 and 1 defining the proportion of the number of samples that ESS needs to be
#'                      lower than for resampling (i.e. resampling is carried out only when ESS < N*ESS_threshold)
#' @param seed seed number - default is NULL, meaning there is no seed
#' @param n_cores number of cores to use
#'
#' @return A list with components (when C = prod(m_schedule), the level 1 entries hold the single top node's
#'         values directly rather than a list of length 1):
#' \describe{
#'   \item{samples}{list of length L, where samples[[l]][[i]] are the samples for level l, node i
#'                  (samples[[L]] is base_samples)}
#'   \item{weighted_samples}{list of length (L-1), where weighted_samples[[l]][[i]] are the weighted samples for level l, node i}
#'   \item{Q_weights}{list of length (L-1), where Q_weights[[l]][[i]] is the normalised Q weights for level l, node i}
#'   \item{ESS}{list of length (L-1), where ESS[[l]][[i]] is the ESS for samples in level l, node i}
#'   \item{time}{list of length (L-1), where time[[l]][[i]] is the run time for level l, node i}
#'   \item{resampled}{list of length (L-1), where resampled[[l]][[i]] is a named logical vector indicating whether
#'         the samples in level l, node i were resampled after each fusion step}
#'   \item{y_inputs}{input y data for each level and node}
#'   \item{X_inputs}{input X data for each level and node}
#'   \item{C_inputs}{vector of length (L-1), where C_inputs[l] is the number of sub-posteriors at level l+1 (the input for C to get to level l)}
#'   \item{sub_posterior_weight_inputs}{list of length (L-1), where sub_posterior_weight_inputs[[l]] is the input for the sub-posterior weights for level l}
#'   \item{diffusion_times}{list of length (L-1), where diffusion_times[[l]] are the times for fusion in level l}
#'   \item{power}{exponent of target distributions in the hierarchy}
#'   \item{precondition_matrices}{pre-conditioning matrices that were used}
#' }
#'
#' @export
hierarchical_fusion_SMC_TA_BLR <- function(N_schedule,
                                           dim,
                                           y_split,
                                           X_split,
                                           prior_means,
                                           prior_variances,
                                           global_T,
                                           m_schedule,
                                           C,
                                           power,
                                           precondition = FALSE,
                                           L,
                                           base_samples,
                                           ESS_threshold = 0.5,
                                           seed = NULL,
                                           n_cores = parallel::detectCores()) {
  # the hierarchy needs a base level plus at least one fusion level above it
  if (L < 2) {
    stop('hierarchical_fusion_SMC_TA_BLR: L must be at least 2')
  }

  # check variables are of the right length
  if (length(N_schedule) != (L-1)) {
    stop('hierarchical_fusion_SMC_TA_BLR: length of N_schedule must be equal to (L-1)')
  } else if (!is.list(y_split) || length(y_split) != C) {
    stop('hierarchical_fusion_SMC_TA_BLR: check that y_split is a list of length C')
  } else if (!is.list(X_split) || length(X_split) != C) {
    stop('hierarchical_fusion_SMC_TA_BLR: check that X_split is a list of length C')
  } else if (!is.list(base_samples) || length(base_samples) != C) {
    stop('hierarchical_fusion_SMC_TA_BLR: check that base_samples is a list of length C')
  }

  # output warning to say that top level does not have C=1
  if (C != prod(m_schedule)) {
    warning('hierarchical_fusion_SMC_TA_BLR: C != prod(m_schedule) - the top level does not have C=1')
  }

  # check that at each level, we are fusing a suitable number
  if (length(m_schedule) != (L-1)) {
    stop('hierarchical_fusion_SMC_TA_BLR: m_schedule must be a vector of length (L-1)')
  }
  for (l in (L-1):1) {
    if ((C/prod(m_schedule[(L-1):l])) %% 1 != 0) {
      stop('hierarchical_fusion_SMC_TA_BLR: check that C/prod(m_schedule[(L-1):l]) is an integer for l=L-1,...,1')
    }
  }

  # check ESS_threshold is between 0 and 1 (both end points are allowed)
  if (ESS_threshold < 0 || ESS_threshold > 1) {
    stop('hierarchical_fusion_SMC_TA_BLR: ESS_threshold must be between 0 and 1')
  }

  # we append 1 to the vector m_schedule to make the indices work later on when we call fusion
  m_schedule <- c(m_schedule, 1)

  # initialising results that we want to keep
  hier_samples <- list()
  hier_samples[[L]] <- base_samples # base level
  weighted_samples <- list()
  Q_weights <- list()
  ESS <- list()
  y_inputs <- list()
  y_inputs[[L]] <- y_split # base level input samples for y
  X_inputs <- list()
  X_inputs[[L]] <- X_split # base level input samples for X
  C_inputs <- rep(0, L-1)
  sub_posterior_weight_inputs <- list()
  diffusion_times <- list()
  time <- list()
  resampled <- list()
  precondition_matrices <- list()

  log_file <- 'hierarchical_fusion_SMC_TA_BLR.txt'
  # add to output file that starting hierarchical fusion (overwrites any previous log)
  cat('Starting hierarchical fusion \n', file = log_file)

  # fuse level by level going up the hierarchy
  for (k in (L-1):1) {
    # number of base-level sub-posteriors already combined into each input of this level
    fused_below <- prod(m_schedule[L:(k+1)])
    # number of sub-posteriors (data subsets) that are the inputs to this level
    C_k <- C / fused_below
    # number of sub-posteriors fused together by each node at this level
    m_k <- m_schedule[k]
    # the previous level has C_k nodes and we fuse m_schedule[k] of them per node
    n_nodes <- C / prod(m_schedule[L:k])

    # printing out some stuff to log file to track the progress
    cat('########################\n', file = log_file, append = TRUE)
    cat('Starting to fuse', m_k, 'sub-posteriors for level', k, 'with time',
        global_T / fused_below, ', which is using', n_cores, 'cores\n',
        file = log_file, append = TRUE)
    cat('At this level, the data is split up into', C_k, 'subsets\n',
        file = log_file, append = TRUE)
    cat('There are', n_nodes, 'nodes at the next level each giving', N_schedule[k],
        'samples \n', file = log_file, append = TRUE)
    cat('########################\n', file = log_file, append = TRUE)

    # starting fusion
    # NOTE(review): the same seed is passed to every node, so all nodes draw from the
    # same random stream - confirm this is intended
    fused <- lapply(X = seq_len(n_nodes), FUN = function(i) {
      # node i fuses inputs m_k*(i-1)+1, ..., m_k*i of the level below
      node_idx <- ((m_k*i)-(m_k-1)):(m_k*i)
      par_fusion_SMC_TA_BLR(N = N_schedule[k],
                            dim = dim,
                            y_split = y_inputs[[k+1]][node_idx],
                            X_split = X_inputs[[k+1]][node_idx],
                            prior_means = prior_means,
                            prior_variances = prior_variances,
                            time = global_T,
                            m = m_k,
                            C = C_k,
                            power = power,
                            precondition = precondition,
                            samples_to_fuse = hier_samples[[k+1]][node_idx],
                            sub_posterior_weights = rep(fused_below, m_k),
                            ESS_threshold = ESS_threshold,
                            seed = seed,
                            level = k,
                            node = i,
                            n_cores = n_cores)
    })

    # collect one component of every node's result into a list (one entry per node)
    collect <- function(name) lapply(fused, function(node_result) node_result[[name]])
    hier_samples[[k]] <- collect('samples')
    weighted_samples[[k]] <- collect('weighted_samples')
    Q_weights[[k]] <- collect('Q_weights')
    y_inputs[[k]] <- collect('combined_y')
    X_inputs[[k]] <- collect('combined_X')
    ESS[[k]] <- collect('ESS')
    time[[k]] <- collect('time')
    resampled[[k]] <- collect('resampled')
    precondition_matrices[[k]] <- collect('precondition_matrices')

    # keep track of other information
    C_inputs[k] <- C_k
    sub_posterior_weight_inputs[[k]] <- rep(fused_below, m_k)
    diffusion_times[[k]] <- global_T / rep(fused_below, m_k)
  }

  # print completion
  cat('Completed hierarchical fusion\n', file = log_file, append = TRUE)

  # if the top level is a single node, unwrap its results from the length-1 lists
  if (C == prod(m_schedule)) {
    hier_samples[[1]] <- hier_samples[[1]][[1]]
    y_inputs[[1]] <- y_inputs[[1]][[1]]
    X_inputs[[1]] <- X_inputs[[1]][[1]]
    weighted_samples[[1]] <- weighted_samples[[1]][[1]]
    Q_weights[[1]] <- Q_weights[[1]][[1]]
    ESS[[1]] <- ESS[[1]][[1]]
    time[[1]] <- time[[1]][[1]]
    resampled[[1]] <- resampled[[1]][[1]]
    precondition_matrices[[1]] <- precondition_matrices[[1]][[1]]
  }

  return(list('samples' = hier_samples,
              'weighted_samples' = weighted_samples,
              'Q_weights' = Q_weights,
              'ESS' = ESS,
              'time' = time,
              'resampled' = resampled,
              'y_inputs' = y_inputs,
              'X_inputs' = X_inputs,
              'C_inputs' = C_inputs,
              'sub_posterior_weight_inputs' = sub_posterior_weight_inputs,
              'diffusion_times' = diffusion_times,
              'power' = power,
              'precondition_matrices' = precondition_matrices))
}

######################################## miscellanea ########################################

#' Get resampled top level for SMC hierarchy
#'
#' Function resamples the top level of the hierarchy if it has not been resampled.
#' If the top-level samples were already resampled after the Q step they are returned
#' unchanged; otherwise they are resampled with replacement according to the
#' normalised Q weights of the top level.
#'
#' @param hierarchical_result hierarchy returned by 'hierarchical_fusion_SMC_TA_BLR' function.
#'                            Assumes the top level is a single node (C = prod(m_schedule)), so that
#'                            samples[[1]] is a matrix and Q_weights[[1]] is a vector
#' @param seed seed number - default is NULL, meaning there is no seed
#'
#' @return matrix of equally weighted top-level samples, with the same dimensions as
#'         hierarchical_result$samples[[1]]
#'
#' @export
get_resampled_top_level_SMC <- function(hierarchical_result, seed = NULL) {
  if (!is.null(seed)) {
    # setting seed if given
    set.seed(seed)
  }

  top_samples <- hierarchical_result$samples[[1]]

  # if the top level got resampled, then we keep that
  if (isTRUE(hierarchical_result$resampled[[1]]['Q'])) {
    return(top_samples)
  }

  # otherwise resample according to the Q weights; drop = FALSE keeps a one-column
  # matrix when dim = 1
  n_samples <- nrow(top_samples)
  resample_idx <- sample.int(n_samples,
                             size = n_samples,
                             replace = TRUE,
                             prob = hierarchical_result$Q_weights[[1]])
  top_samples[resample_idx, , drop = FALSE]
}
rchan26/BayesLogitFusion documentation built on June 13, 2020, 5:03 a.m.