R/hmc_tempered_GM.R

##########################
# functions that use HMC (Stan) to sample from the tempered mixture Gaussian distribution using the 'mixture_gaussian.stan' file (should be in src/stan_files)
##########################

#' HMC sampler for tempered mixture Gaussian
#'
#' Sample from the tempered mixture Gaussian target using Stan (HMC) via the
#' package's compiled \code{mixture_gaussian} model.
#'
#' @param weights vector: weights of mixture Gaussian
#' @param means vector: means of mixture Gaussian
#' @param sds vector: st.devs of mixture Gaussian
#' @param beta temperature level
#' @param iterations number of iterations per chain - note that burn in is 0.5
#'   of this number
#' @param chains number of chains
#' @param output boolean value: defaults to FALSE, determines whether or not
#'   Stan's sampling progress is printed to the console (the start/finish
#'   messages are always printed)
#'
#' @return samples of \code{x} from the tempered target: the post burn-in
#'   draws of all chains, as returned by \code{rstan::extract}
#'
#' @examples
#' \donttest{
#' hmc_sample_tempered_mixG(weights = c(0.5, 0.5), means = c(-2, 2),
#'                          sds = c(0.5, 1), beta = 0.2,
#'                          iterations = 2*100000, chains = 1, output = FALSE)
#' }
#'
#' @export
hmc_sample_tempered_mixG <- function(weights, means, sds, beta, iterations,
                                     chains, output = FALSE) {
  # every mixture component needs a weight, a mean and a standard deviation;
  # K below is taken from length(weights), so the three must agree
  stopifnot(length(weights) == length(means),
            length(means) == length(sds))

  # print output to console
  print("Sampling from tempered mixture Gaussian density")

  # data passed to the Stan model
  training_data <- list(beta = beta,
                        K = length(weights),
                        weights = weights,
                        means = means,
                        sds = sds)

  # build the sampler call once; only the console-verbosity arguments differ
  sampling_args <- list(object = stanmodels$mixture_gaussian,
                        data = training_data,
                        iter = iterations,
                        chains = chains)
  if (!output) {
    # hide Stan's sampling progress from the console
    sampling_args <- c(sampling_args, list(verbose = FALSE, refresh = 0))
  }
  model <- do.call(rstan::sampling, sampling_args)

  # print completion
  print("Finished sampling from tempered mixture Gaussian density")
  rstan::extract(model)[["x"]]
}

#' HMC sampler for base level
#'
#' Sample for the base level (tempered mixture Gaussian) and split the draws
#' into one group per node.
#'
#' @param weights vector: weights of mixture Gaussian
#' @param means vector: means of mixture Gaussian
#' @param sds vector: st.devs of mixture Gaussian
#' @param beta temperature level
#' @param nsamples number of samples per node
#' @param nchains number of nodes (one Stan chain is run per node)
#'
#' @return a list with one element per node (named "1", "2", ...), each holding
#'   consecutive blocks of \code{nsamples} samples from the tempered target
#'
#' @examples
#' \donttest{
#' hmc_base_sampler_mixG(weights = c(0.5, 0.5), means = c(-2, 2),
#'                       sds = c(0.5, 1), beta = 1,
#'                       nsamples = 100000, nchains = 1)
#' }
#'
#' @export
hmc_base_sampler_mixG <- function(weights, means, sds, beta, nsamples, nchains) {
  # sample at the base level: run 2 * nsamples iterations per chain so that,
  # with half of them discarded as burn in, each chain yields nsamples draws.
  # NOTE(review): draws come back pooled over chains from rstan::extract, so a
  # node's block is not necessarily from a single chain - confirm this is fine
  base <- hmc_sample_tempered_mixG(weights, means, sds, beta,
                                   iterations = 2 * nsamples,
                                   chains = nchains, output = FALSE)

  # split into nodes: consecutive runs of nsamples draws form one node each
  # (seq_along stays correct when no draws are returned, unlike 1:length)
  split(base, ceiling(seq_along(base) / nsamples))
}
rchan26/mixGaussTempering documentation built on June 14, 2019, 3:26 p.m.