Nothing
#' Sample from the ensemble posterior distribution
#'
#' Given model averaging weights (e.g., from Bayesian model averaging (BMA), pseudo-BMA, or stacking) and a matrix of
#' posterior samples from the candidate models, this function draws samples from the model-averaged posterior distribution.
#' Here, each "model" refers to a unique combination of an outcome model and its associated priors. Posterior draws
#' are randomly selected from the candidate models in proportion to their specified weights, producing samples from
#' the ensemble of posterior distributions.
#'
#' This function is typically used in combination with [compute.ensemble.weights()], which computes model averaging
#' weights using methods such as Bayesian model averaging (BMA), pseudo-BMA, pseudo-BMA with the Bayesian bootstrap,
#' or stacking. The input matrix of posterior samples should have one column per candidate model, with each column
#' containing posterior draws from that model.
#'
#' Sampling proceeds in two steps. First, each of the `nrow(samples.mtx)` output draws is assigned to a model by
#' sampling model indices with replacement, using `wts` as probabilities. Then, for each selected model, the required
#' number of draws is taken without replacement from that model's column (via a random permutation of the rows).
#'
#' @export
#'
#' @param wts a numeric vector of non-negative model averaging weights (e.g., from [compute.ensemble.weights()]),
#' with at least one positive entry. Weights need not sum to one; they are used as relative probabilities.
#' The length of `wts` must match the number of columns in `samples.mtx`.
#' @param samples.mtx a matrix of posterior samples. Each column corresponds to samples from a different model, and each
#' row is one posterior draw (e.g., from Markov chain Monte Carlo (MCMC) sampling). All columns must
#' have the same number of samples. A plain vector is treated as a one-column matrix (i.e., a single model).
#'
#' @return
#' The function returns a numeric vector of ensemble posterior draws, sampled proportionally to the provided model weights.
#' The returned vector has the same length as the number of rows in `samples.mtx` (a zero-length vector when
#' `samples.mtx` has no rows). Draws are grouped by model, in the order in which the models were first selected;
#' they are not interleaved across models.
#'
#' @seealso [compute.ensemble.weights()]
#'
#' @examples
#' if (instantiate::stan_cmdstan_exists()) {
#' if(requireNamespace("survival")){
#' library(survival)
#' data(E1684)
#' data(E1690)
#' ## replace 0 failure times with 0.50 days
#' E1684$failtime[E1684$failtime == 0] = 0.50/365.25
#' E1690$failtime[E1690$failtime == 0] = 0.50/365.25
#' E1684$cage = as.numeric(scale(E1684$age))
#' E1690$cage = as.numeric(scale(E1690$age))
#' data_list = list(currdata = E1690, histdata = E1684)
#' nbreaks = 3
#' probs = 1:nbreaks / nbreaks
#' breaks = as.numeric(
#' quantile(E1690[E1690$failcens==1, ]$failtime, probs = probs)
#' )
#' breaks = c(0, breaks)
#' breaks[length(breaks)] = max(10000, 1000 * breaks[length(breaks)])
#' fit.pwe.pp = pwe.pp(
#' formula = survival::Surv(failtime, failcens) ~ treatment + sex + cage + node_bin,
#' data.list = data_list,
#' breaks = breaks,
#' a0 = 0.5,
#' get.loglik = TRUE,
#' chains = 1, iter_warmup = 1000, iter_sampling = 2000
#' )
#' fit.pwe.post = pwe.post(
#' formula = survival::Surv(failtime, failcens) ~ treatment + sex + cage + node_bin,
#' data.list = data_list,
#' breaks = breaks,
#' get.loglik = TRUE,
#' chains = 1, iter_warmup = 1000, iter_sampling = 2000
#' )
#' fit.pwe.commensurate = pwe.commensurate(
#' formula = survival::Surv(failtime, failcens) ~ treatment + sex + cage + node_bin,
#' data.list = data_list,
#' breaks = breaks,
#' p.spike = 0.1,
#' get.loglik = TRUE,
#' chains = 1, iter_warmup = 1000, iter_sampling = 2000
#' )
#' fit.list = list(fit.pwe.post, fit.pwe.pp, fit.pwe.commensurate)
#' samples.mtx = do.call(
#' cbind, lapply(fit.list, function(d){
#' as.numeric( d[["treatment"]] )
#' })
#' )
#' wts = compute.ensemble.weights(
#' fit.list = fit.list,
#' type = "pseudobma+"
#' )$weights
#' sample.ensemble(
#' wts = wts, samples.mtx = samples.mtx
#' )
#' }
#' }
sample.ensemble = function(
    wts,
    samples.mtx
) {
  ## a plain vector of draws is treated as a single model (one-column matrix);
  ## matrices and data frames already carry dimensions and are left as-is
  if (is.null(dim(samples.mtx))) {
    samples.mtx = matrix(samples.mtx, ncol = 1)
  }
  if( ncol(samples.mtx) != length(wts) ){
    stop("The number of columns in samples.mtx must match the length of wts.")
  }
  wts = as.numeric(wts)
  ## sample.int() accepts unnormalized weights, but they must be finite,
  ## non-negative, and not all zero; fail early with a clear message otherwise
  if (!all(is.finite(wts)) || any(wts < 0) || sum(wts) <= 0) {
    stop("wts must be finite, non-negative, and contain at least one positive value.")
  }
  n = nrow(samples.mtx)
  if (n == 0) {
    ## nothing to draw: return an empty numeric vector rather than NULL
    return(numeric(0))
  }
  ## randomly permute the rows: taking the first k rows of a column then yields
  ## a simple random sample (without replacement) of size k from that model.
  ## drop = FALSE keeps a matrix when there is a single row or a single column.
  samples.mtx.permuted = samples.mtx[sample.int(n, size = n, replace = FALSE), , drop = FALSE]
  ## draw n i.i.d. model labels (c0) from the categorical distribution with
  ## probabilities proportional to `wts`
  c0 = sample.int(length(wts), size = n, replace = TRUE, prob = wts)
  ## number of draws assigned to each model; the counts sum to n, so no model
  ## ever needs more than the n rows available in its column
  counts = tabulate(c0, nbins = length(wts))
  ## unselected models contribute nothing; selected models appear in order of
  ## first selection
  res.samples = lapply(unique(c0), function(j) {
    samples.mtx.permuted[seq_len(counts[j]), j]
  })
  unlist(res.samples)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.