View source: R/compute_ensemble_weights.R
| compute.ensemble.weights | R Documentation |
Compute model averaging weights for a set of Bayesian models using Bayesian model averaging (BMA), pseudo-BMA,
pseudo-BMA+ (pseudo-BMA with the Bayesian bootstrap), or stacking. This function takes a list of model fit objects,
each containing posterior samples from a generalized linear model (GLM) or survival model, and returns normalized
weights that can be used for model comparison or combining posterior samples using functions like sample.ensemble().
compute.ensemble.weights(
  fit.list,
  type = c("bma", "pseudobma", "pseudobma+", "stacking"),
  prior.prob = NULL,
  bridge.args = NULL,
  loo.args = NULL,
  loo.wts.args = NULL,
  iter_warmup = 1000,
  iter_sampling = 1000,
  chains = 4,
  ...
)
fit.list |
a list of model fit objects returned by functions in the hdbayes package |
type |
a character string specifying the ensemble method used to compute model weights. Options are "bma" (Bayesian model averaging (BMA)), "pseudobma" (pseudo-BMA without the Bayesian bootstrap), "pseudobma+" (pseudo-BMA with the Bayesian bootstrap), and "stacking". |
prior.prob |
a numeric vector of prior model probabilities, used only when type = "bma" |
bridge.args |
a |
loo.args |
a |
loo.wts.args |
a |
iter_warmup |
number of warmup iterations to run per chain. Used only when computing the log marginal likelihood (i.e., when type = "bma"). |
iter_sampling |
number of post-warmup iterations to run per chain. Used only when computing the log marginal likelihood (i.e., when type = "bma"). |
chains |
number of Markov chains to run. Used only when computing the log marginal likelihood (i.e., when type = "bma"). |
... |
arguments passed to |
The input fit.list should be a list of outputs from model fitting functions in the hdbayes package, such as glm.pp()
(for generalized linear models), aft.pp() (for accelerated failure time models), pwe.pp() (for piecewise exponential (PWE)
models), or curepwe.pp() (for mixture cure rate models with a PWE component for the non-cured population). To compute
pseudo-BMA, pseudo-BMA+, or stacking weights, each fit must include pointwise log-likelihood values. To ensure this, the
fitting function must be called with get.loglik = TRUE.
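To show why pointwise log-likelihood values are needed, the following is a minimal sketch of how pseudo-BMA weights (without the Bayesian bootstrap) can be formed from pointwise log predictive densities, following Yao et al. (2018). The matrix lpd_point and the model names are hypothetical, simulated values; this is not the internal code of compute.ensemble.weights(), which delegates the computation to the loo package.

```r
# Hypothetical pointwise log predictive densities (e.g., leave-one-out estimates):
# rows = observations, columns = candidate models. The values are simulated
# for illustration only and do not come from any hdbayes fit.
set.seed(1)
n <- 50
lpd_point <- cbind(
  model_a = rnorm(n, mean = -1.0, sd = 0.3),
  model_b = rnorm(n, mean = -1.2, sd = 0.3),
  model_c = rnorm(n, mean = -1.1, sd = 0.3)
)

# Expected log predictive density (elpd) of each model: sum over observations.
elpd <- colSums(lpd_point)

# Pseudo-BMA weights are proportional to exp(elpd_k).
# Subtracting the maximum before exponentiating avoids numerical underflow.
pseudobma_wts <- exp(elpd - max(elpd))
pseudobma_wts <- pseudobma_wts / sum(pseudobma_wts)

print(round(pseudobma_wts, 3))
```

Without pointwise values there is no per-observation matrix to sum over, which is why get.loglik = TRUE is required for these three methods.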
The Markov chain Monte Carlo (MCMC) arguments (iter_warmup, iter_sampling, and chains) are used only to compute the logarithm of the normalizing constant required for BMA, that is, when type = "bma".
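As a rough illustration of how BMA weights follow from log marginal likelihoods, the sketch below normalizes prior probability times marginal likelihood on the log scale. The values in logml are hypothetical, and the uniform prior_prob is an assumption made for this example; the page does not state what the function does when prior.prob = NULL.

```r
# Hypothetical log marginal likelihood estimates for three models
# (illustrative numbers, not output from hdbayes or bridge sampling).
logml <- c(model_a = -512.3, model_b = -510.8, model_c = -515.1)

# Assumed uniform prior model probabilities (an assumption for this sketch).
prior_prob <- rep(1 / length(logml), length(logml))

# Posterior model probability is proportional to prior * marginal likelihood.
# Work on the log scale and subtract the maximum for numerical stability.
log_unnorm <- logml + log(prior_prob)
bma_wts <- exp(log_unnorm - max(log_unnorm))
bma_wts <- bma_wts / sum(bma_wts)

print(round(bma_wts, 3))
```

With equal prior probabilities, the ratio of two weights equals the Bayes factor between the corresponding models, here exp(1.5) for model_b against model_a.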
The function returns a list with the following objects:
a numeric vector of normalized model weights corresponding to the models in fit.list. The names of the
weights are made unique based on the model identifiers.
a character string indicating the method used to compute the model weights (e.g., "bma", "pseudobma", "pseudobma+", or "stacking")
a list of log marginal likelihood estimation results, returned only when type = "bma"
a list of outputs from loo::loo(), returned only when type is "pseudobma", "pseudobma+", or "stacking"
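To illustrate how a vector of normalized weights can be used to combine posterior samples, here is a minimal sketch that draws from a weighted mixture of per-model posterior draws. The objects draws and wts are hypothetical, and this is not the implementation of sample.ensemble(); it only shows the general idea of weight-proportional resampling.

```r
set.seed(2)

# Hypothetical posterior draws of a shared parameter from three models
# (simulated for illustration; not output from hdbayes).
draws <- list(
  model_a = rnorm(2000, mean = 0.10, sd = 0.05),
  model_b = rnorm(2000, mean = 0.20, sd = 0.05),
  model_c = rnorm(2000, mean = 0.15, sd = 0.05)
)

# Hypothetical normalized model weights, in the same order as 'draws'.
wts <- c(model_a = 0.2, model_b = 0.5, model_c = 0.3)

n_out <- 2000

# Step 1: pick a model for each output draw with probability equal to its weight.
model_id <- sample(seq_along(draws), size = n_out, replace = TRUE, prob = wts)

# Step 2: take one posterior draw at random from the selected model.
ensemble <- vapply(
  model_id,
  function(k) sample(draws[[k]], size = 1),
  numeric(1)
)

# The ensemble mean should be close to the weighted average of the model means.
print(c(ensemble = mean(ensemble),
        weighted = sum(wts * vapply(draws, mean, numeric(1)))))
```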
Yao, Y., Vehtari, A., Simpson, D., and Gelman, A. (2018). Using stacking to average Bayesian predictive distributions. Bayesian Analysis, 13(3), 917–1007.
Vehtari, A., Gelman, A., and Gabry, J. (2017). Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5), 1413–1432.
sample.ensemble()
if (instantiate::stan_cmdstan_exists()) {
  if (requireNamespace("survival")) {
    library(survival)
    data(E1684)
    data(E1690)
    ## replace 0 failure times with 0.50 days
    E1684$failtime[E1684$failtime == 0] = 0.50/365.25
    E1690$failtime[E1690$failtime == 0] = 0.50/365.25
    ## center and scale age
    E1684$cage = as.numeric(scale(E1684$age))
    E1690$cage = as.numeric(scale(E1690$age))
    data_list = list(currdata = E1690, histdata = E1684)
    ## place interval cut points at quantiles of the observed event times
    nbreaks = 3
    probs = 1:nbreaks / nbreaks
    breaks = as.numeric(
      quantile(E1690[E1690$failcens==1, ]$failtime, probs = probs)
    )
    breaks = c(0, breaks)
    ## make the last cut point very large so the final interval is open-ended in practice
    breaks[length(breaks)] = max(10000, 1000 * breaks[length(breaks)])
    ## fit three candidate models, each with pointwise log-likelihoods saved
    fit.pwe.pp = pwe.pp(
      formula = survival::Surv(failtime, failcens) ~ treatment + sex + cage + node_bin,
      data.list = data_list,
      breaks = breaks,
      a0 = 0.5,
      get.loglik = TRUE,
      chains = 1, iter_warmup = 1000, iter_sampling = 2000
    )
    fit.pwe.post = pwe.post(
      formula = survival::Surv(failtime, failcens) ~ treatment + sex + cage + node_bin,
      data.list = data_list,
      breaks = breaks,
      get.loglik = TRUE,
      chains = 1, iter_warmup = 1000, iter_sampling = 2000
    )
    fit.aft.post = aft.post(
      formula = survival::Surv(failtime, failcens) ~ treatment + sex + cage + node_bin,
      data.list = data_list,
      dist = "weibull",
      beta.sd = 10,
      get.loglik = TRUE,
      chains = 1, iter_warmup = 1000, iter_sampling = 2000
    )
    ## compute pseudo-BMA+ weights for the three fits
    compute.ensemble.weights(
      fit.list = list(fit.pwe.post, fit.pwe.pp, fit.aft.post),
      type = "pseudobma+",
      loo.args = list(save_psis = FALSE),
      loo.wts.args = list(optim_method = "BFGS")
    )
  }
}