#' Fit the base imputation model using a Bayesian approach
#'
#' @description
#' `fit_mcmc()` fits the base imputation model using a Bayesian approach.
#' This is done through an MCMC method that is implemented in `stan`
#' and is run by using the function `rstan::sampling()`.
#' The function returns the draws from the posterior distribution of the model parameters
#' and the `stanfit` object. Additionally, it performs multiple diagnostic checks of the chain
#' and raises warnings if any issues are detected.
#'
#' @param designmat The design matrix of the fixed effects.
#' @param outcome The response variable. Must be numeric.
#' @param group Character vector containing the group variable.
#' @param subjid Character vector containing the subjects IDs.
#' @param visit Character vector containing the visit variable.
#' @param method A `method` object as generated by [method_bayes()].
#' @param quiet Specify whether the stan sampling log should be printed to the console
#'   (forwarded to `complete_control_bayes()`).
#'
#' @details
#' The Bayesian model assumes a multivariate normal likelihood function and weakly-informative
#' priors for the model parameters: in particular, uniform priors are assumed for the regression
#' coefficients and inverse-Wishart priors for the covariance matrices.
#' The chain is initialized using the REML parameter estimates from MMRM as starting values.
#'
#' The function performs the following steps:
#' 1. Fit MMRM using a REML approach.
#' 2. Prepare the input data for the MCMC fit as described in the `data{}`
#' block of the Stan file. See [prepare_stan_data()] for details.
#' 3. Run the MCMC according to the input arguments, using the REML parameter estimates
#' from step 1 as starting values.
#' 4. Perform diagnostic checks of the MCMC. See [check_mcmc()] for details.
#' 5. Extract the draws from the model fit.
#'
#' The chains perform `method$n_samples` draws by keeping one every `method$burn_between` iterations. Additionally
#' the first `method$burn_in` iterations are discarded. The total number of iterations will
#' then be `method$burn_in + method$burn_between*method$n_samples`.
#' The purpose of `method$burn_in` is to ensure that the samples are drawn from the stationary
#' distribution of the Markov Chain.
#' The purpose of `method$burn_between` is to reduce the correlation between the retained draws.
#'
#' Warnings raised by `rstan::sampling()` are re-raised one at a time, except for the
#' bulk/tail ESS warnings, which are skipped here (ESS is checked separately by [check_mcmc()]).
#'
#' @return
#' A named list composed by the following:
#' - `samples`: a named list containing the draws for each parameter. It corresponds to the output of [extract_draws()].
#' - `fit`: a `stanfit` object.
#'
#'
#' @import methods
fit_mcmc <- function(
    designmat,
    outcome,
    group,
    subjid,
    visit,
    method,
    quiet = FALSE
) {
    # Fit MMRM (needed for Sigma prior parameter and possibly initial values).
    mmrm_initial <- fit_mmrm(
        designmat = designmat,
        outcome = outcome,
        subjid = subjid,
        visit = visit,
        group = group,
        cov_struct = "us",
        REML = TRUE,
        same_cov = method$same_cov
    )
    if (mmrm_initial$failed) {
        stop("Fitting MMRM to original dataset failed")
    }

    # With a shared covariance matrix every subject is placed in a single group.
    stan_data <- prepare_stan_data(
        ddat = designmat,
        subjid = subjid,
        visit = visit,
        outcome = outcome,
        group = ife(
            isTRUE(method$same_cov),
            rep(1, length(group)),
            group
        )
    )
    stan_data$Sigma_init <- ife(
        isTRUE(method$same_cov),
        list(mmrm_initial$sigma[[1]]),
        mmrm_initial$sigma
    )

    control <- complete_control_bayes(
        control = method$control,
        n_samples = method$n_samples,
        quiet = quiet,
        stan_data = stan_data,
        mmrm_initial = mmrm_initial
    )
    sampling_args <- c(
        list(
            object = get_stan_model(),
            data = stan_data,
            pars = c("beta", "Sigma")
        ),
        control
    )

    # `record()` captures the result together with any errors and warnings
    # (accessed below as $results, $errors and $warnings).
    stan_fit <- record({
        do.call(rstan::sampling, sampling_args)
    })
    if (!is.null(stan_fit$errors)) {
        stop(stan_fit$errors)
    }

    # ESS is assessed by check_mcmc() below, so rstan's own ESS warnings are dropped.
    ignorable_warnings <- c(
        "Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.\nRunning the chains for more iterations may help. See\nhttps://mc-stan.org/misc/warnings.html#bulk-ess",
        "Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.\nRunning the chains for more iterations may help. See\nhttps://mc-stan.org/misc/warnings.html#tail-ess"
    )
    captured_warnings <- stan_fit$warnings
    warnings_not_allowed <- captured_warnings[!captured_warnings %in% ignorable_warnings]
    # Re-raise each remaining warning exactly once, on its own.
    for (msg in warnings_not_allowed) {
        warning(msg, call. = FALSE)
    }

    fit <- stan_fit$results
    check_mcmc(fit, method$n_samples)
    draws <- extract_draws(fit, method$n_samples)

    list(
        "samples" = draws,
        "fit" = fit
    )
}
#' Transform array into list of arrays
#'
#' @description
#' Split an array into a list of sub-arrays, slicing along a given dimension.
#'
#' @param a Array with number of dimensions at least 2.
#' @param n Positive integer. Dimension of `a` to be listed.
#'
#' @details
#' For example, if `a` is a 3 dimensional array and `n = 1`,
#' `split_dim(a,n)` returns a list of 2 dimensional arrays (i.e.
#' a list of matrices) where each element of the list is `a[i, , ]`, where
#' `i` takes values from 1 to the length of the first dimension of the array.
#'
#' Example:
#'
#' inputs:
#' `a <- array( c(1,2,3,4,5,6,7,8,9,10,11,12), dim = c(3,2,2))`,
#' which means that:
#' ```
#' a[1,,] a[2,,] a[3,,]
#'
#' [,1] [,2] [,1] [,2] [,1] [,2]
#' --------- --------- ---------
#' 1 7 2 8 3 9
#' 4 10 5 11 6 12
#' ```
#'
#' `n <- 1`
#'
#' output of `res <- split_dim(a,n)` is a list of 3 elements:
#' ```
#' res[[1]] res[[2]] res[[3]]
#'
#' [,1] [,2] [,1] [,2] [,1] [,2]
#' --------- --------- ---------
#' 1 7 2 8 3 9
#' 4 10 5 11 6 12
#' ```
#'
#' @return
#' A list of length `dim(a)[n]` (one element per index of dimension `n`). Each element is an
#' array whose number of dimensions is that of `a` minus 1, carrying the remaining `dimnames`
#' of `a`. The list is named by `dimnames(a)[[n]]` when present, and unnamed otherwise.
#'
#' @importFrom stats setNames
split_dim <- function(a, n) {
    # Index of each element of `a` along dimension `n` (elements in storage order).
    slice_id <- arrayInd(seq_along(a), dim(a))[, n]
    kept_dim <- dim(a)[-n]
    kept_dimnames <- dimnames(a)[-n]
    slices <- lapply(
        split(a, slice_id),
        function(values) array(values, dim = kept_dim, dimnames = kept_dimnames)
    )
    setNames(slices, dimnames(a)[[n]])
}
#' Extract draws from a `stanfit` object
#'
#' @description
#' Extract draws from a `stanfit` object and convert them into lists.
#'
#' The function `rstan::extract()` returns the draws for a given parameter as an array. This function
#' calls `rstan::extract()` to extract the draws from a `stanfit` object
#' and then converts the arrays into lists. Only the first `n_samples` draws are kept.
#'
#' @param stan_fit A `stanfit` object.
#'
#' @param n_samples Number of MCMC draws.
#'
#' @return
#' A named list of length 2 containing:
#' - `beta`: a list of length equal to `n_samples` containing
#' the draws from the posterior distribution of the regression coefficients.
#' - `sigma`: a list of length equal to `n_samples` containing
#' the draws from the posterior distribution of the covariance matrices. Each element
#' of the list is a list with length equal to 1 if `same_cov = TRUE` or equal to the
#' number of groups if `same_cov = FALSE`.
#'
extract_draws <- function(stan_fit, n_samples) {
    assertthat::assert_that(assertthat::is.number(n_samples))
    draws <- rstan::extract(stan_fit, pars = c("beta", "Sigma"))

    # Look the parameters up by name rather than renaming the returned list by
    # position, so the result does not depend on the element order rstan uses.
    # Sigma: first split over draws, then (presumably) over covariance groups.
    sigma <- split_dim(draws[["Sigma"]], 1)
    sigma <- lapply(sigma, function(x) split_dim(x, 1))
    assertthat::assert_that(length(sigma) >= n_samples)
    sigma <- sigma[seq_len(n_samples)]

    # beta: one plain numeric vector of coefficients per draw.
    beta <- split_dim(draws[["beta"]], 1)
    beta <- lapply(beta, as.vector)
    assertthat::assert_that(length(beta) >= n_samples)
    beta <- beta[seq_len(n_samples)]

    list(beta = beta, sigma = sigma)
}
#' Extract the Effective Sample Size (ESS) from a `stanfit` object
#'
#' Reads the `n_eff` column of `rstan::summary()` for the `beta` and `Sigma` parameters.
#'
#' @param stan_fit A `stanfit` object.
#'
#' @return
#' A named vector containing the ESS for each parameter of the model.
#'
get_ESS <- function(stan_fit) {
    fit_summary <- rstan::summary(stan_fit, pars = c("beta", "Sigma"))
    ess_values <- fit_summary$summary[, "n_eff"]
    ess_values
}
#' Diagnostics of the MCMC based on ESS
#'
#' @description
#' Check the quality of the MCMC draws from the posterior distribution
#' by checking whether the relative ESS is sufficiently large.
#'
#' @inheritParams check_mcmc
#'
#' @details
#' `check_ESS()` works as follows:
#' 1. Extract the ESS from `stan_fit` for each parameter of the model.
#' 2. Compute the relative ESS (i.e. the ESS divided by the number of draws).
#' 3. Count the parameters whose relative ESS is lower than `threshold_lowESS`. A parameter
#' whose ESS is undefined (`NA`/`NaN`) is counted as low as well.
#' If at least one parameter is counted as low, a warning is thrown.
#'
#' @inherit check_mcmc return
#'
check_ESS <- function(stan_fit, n_draws, threshold_lowESS = 0.4) {
    ESS <- get_ESS(stan_fit)
    relative_ESS <- ESS / n_draws
    # An undefined ESS cannot be shown to be adequate, so it is flagged as low.
    # This also prevents an `if (NA)` error when any ESS is NA/NaN.
    is_low <- is.na(relative_ESS) | relative_ESS < threshold_lowESS
    n_low_ESS <- sum(is_low)
    if (n_low_ESS > 0) {
        warning(
            paste0(
                "The Effective Sample Size is below ",
                threshold_lowESS * 100,
                "% for ",
                n_low_ESS,
                " parameters. Please consider increasing burn-in and/or burn-between, or the number of samples"
            ),
            call. = FALSE
        )
    }
    invisible(NULL)
}
#' Diagnostics of the MCMC based on HMC-related measures.
#'
#' @description
#' Check that:
#' 1. There are no divergent iterations.
#' 2. The Bayesian Fraction of Missing Information (BFMI) is not too low, i.e. no chain
#' has a BFMI below 0.2.
#' 3. The number of iterations that saturated the max treedepth is zero.
#'
#' Please see `rstan::check_hmc_diagnostics()` for details.
#'
#' @param stan_fit A `stanfit` object.
#'
#' @inherit check_mcmc return
#'
check_hmc_diagn <- function(stan_fit) {
    # `||` short-circuits: later diagnostics are only computed if earlier ones pass.
    # get_bfmi() returns one value per chain, so `any()` is needed. `isTRUE()` on a
    # length > 1 comparison is always FALSE and would hide low BFMI with several chains.
    if (
        any(rstan::get_divergent_iterations(stan_fit)) || # draws "out of the distribution"
        any(rstan::get_bfmi(stan_fit) < 0.2, na.rm = TRUE) || # exploring well the target distribution
        any(rstan::get_max_treedepth_iterations(stan_fit)) # efficiency of the algorithm
    ) {
        warning(
            "Lack of efficiency in the HMC sampler: please consider increasing the burn-in period.",
            call. = FALSE
        )
    }
    invisible(NULL)
}
#' Diagnostics of the MCMC
#'
#' @param stan_fit A `stanfit` object.
#' @param n_draws Number of MCMC draws.
#' @param threshold_lowESS A number in `[0,1]` indicating the minimum acceptable
#' value of the relative ESS. See details.
#'
#' @details
#' Runs the ESS-based check [check_ESS()] followed by the HMC-specific check
#' [check_hmc_diagn()]. See those functions for details.
#'
#' @returns
#' A warning message in case of detected problems.
#'
check_mcmc <- function(stan_fit, n_draws, threshold_lowESS = 0.4) {
    check_ESS(stan_fit = stan_fit, n_draws = n_draws, threshold_lowESS = threshold_lowESS)
    check_hmc_diagn(stan_fit)
    invisible(NULL)
}
#' QR decomposition
#'
#' @description
#' QR decomposition as defined in the
#' [Stan user's guide (section 1.2)](https://mc-stan.org/docs/2_27/stan-users-guide/QR-reparameterization-section.html).
#'
#' `Q` is multiplied and `R` divided by `sqrt(nrow(mat) - 1)`, so `Q %*% R` still reproduces
#' `mat` (up to any column pivoting performed by [qr()] for rank-deficient input).
#'
#' @param mat A matrix to perform the QR decomposition on.
QR_decomp <- function(mat) {
    decomposition <- qr(mat)
    scale_factor <- sqrt(nrow(mat) - 1)
    list(
        Q = qr.Q(qr = decomposition) * scale_factor,
        R = qr.R(qr = decomposition) / scale_factor
    )
}
#' Prepare input data to run the Stan model
#'
#' @description
#' Prepare input data to run the Stan model.
#' Creates / calculates all the required inputs as required by the `data{}` block of the MMRM Stan program.
#'
#' @param ddat A design matrix
#' @param subjid Character vector containing the subjects IDs.
#' @param visit Vector containing the visits (factor or numeric).
#' @param outcome Numeric vector containing the outcome variable.
#' @param group Vector containing the group variable.
#'
#' @details
#' - The `group` argument determines which covariance matrix group the subject belongs to. If you
#' want all subjects to use a shared covariance matrix then set group to "1" for everyone.
#' - Subjects without any observed outcome are removed before the data are assembled.
#'
#'
#' @returns
#' A `stan_data` object. A named list as per `data{}` block of the related Stan file. In particular it returns:
#'
#' - N - The number of rows in the design matrix
#' - P - The number of columns in the design matrix
#' - G - The number of distinct covariance matrix groups (i.e. `length(unique(group))`)
#' - n_visit - The number of unique outcome visits (number of levels for a factor `visit`)
#' - n_pat - The total number of pattern groups (as defined by missingness patterns & covariance group)
#' - pat_G - Index for which Sigma each pattern group should use
#' - pat_n_pt - number of patients within each pattern group
#' - pat_n_visit - number of non-missing visits in each pattern group
#' - pat_sigma_index - rows/cols from Sigma to subset on for the pattern group (padded by 999's)
#' - y - The outcome variable
#' - Q - design matrix (after QR decomposition)
#' - R - R matrix from the QR decomposition of the design matrix
prepare_stan_data <- function(ddat, subjid, visit, outcome, group) {
    assert_that(
        is.factor(group) || is.numeric(group),
        is.factor(visit) || is.numeric(visit),
        is.character(subjid) || is.factor(subjid),
        is.numeric(outcome),
        is.data.frame(ddat) || is.matrix(ddat),
        length(group) == length(visit),
        length(subjid) == length(visit),
        length(outcome) == length(group),
        length(outcome) == nrow(ddat),
        length(unique(subjid)) * length(unique(visit)) == nrow(ddat)
    )

    # Rename design columns to V1..VP so they cannot clash with the helper columns below.
    design_variables <- paste0("V", seq_len(ncol(ddat)))
    ddat <- as.data.frame(ddat)
    names(ddat) <- design_variables
    ddat$subjid <- as.character(subjid)
    ddat$visit <- visit
    ddat$outcome <- outcome
    ddat$group <- group
    ddat$is_avail <- (!is.na(ddat$outcome)) * 1

    # Drop subjects that have no observed outcome at all.
    ddat <- remove_if_all_missing(ddat)

    dat_pgroups <- get_pattern_groups(ddat)
    ddat2 <- merge(ddat, dat_pgroups, by = "subjid", all = TRUE)
    ddat2 <- sort_by(ddat2, c("pgroup", "subjid", "visit"))
    assert_that(nrow(ddat2) == nrow(ddat))

    # Keep only rows with an observed outcome.
    ddat3 <- ddat2[!is.na(ddat2$outcome), ]
    dmat <- as.matrix(ddat3[, design_variables])
    qr_res <- QR_decomp(dmat)
    dat_pgroups_u <- get_pattern_groups_unique(dat_pgroups)

    # A numeric `visit` (accepted by the assertion above) has no levels, so
    # count its distinct values rather than reporting 0 visits.
    n_visit <- if (is.factor(visit)) length(levels(visit)) else length(unique(visit))

    stan_dat <- list(
        N = nrow(dmat),
        P = ncol(dmat),
        G = length(unique(group)),
        n_visit = n_visit,
        n_pat = nrow(dat_pgroups_u),
        pat_G = as_stan_array(dat_pgroups_u$group_n),
        pat_n_pt = as_stan_array(dat_pgroups_u$n),
        pat_n_visit = as_stan_array(dat_pgroups_u$n_avail),
        pat_sigma_index = as_indices(dat_pgroups_u$pattern),
        y = ddat3$outcome,
        Q = qr_res$Q,
        R = qr_res$R
    )
    class(stan_dat) <- c("list", "stan_data")
    validate(stan_dat)
    stan_dat
}
#' Get Pattern Summary
#'
#' Collapses a dataset of pattern information into a summary dataset
#' with exactly 1 row per pattern group.
#'
#' @param patterns A `data.frame` with the columns `pgroup`, `pattern` and `group`
#' @details
#' - The column `pgroup` must be a numeric vector indicating which pattern group the patient belongs to
#' - The column `pattern` must be a character string of `0`'s or `1`'s. It must be identical for all
#' rows within the same `pgroup`
#' - The column `group` must be a character / numeric vector indicating which covariance group the observation
#' belongs to. It must be identical within the same `pgroup`
#'
#' The returned `data.frame` is sorted by `pgroup` and contains `pgroup`, `pattern`,
#' `group_n` (numeric version of `group`), `n` (rows of `patterns` in the pattern group)
#' and `n_avail` (number of `1`'s in `pattern`). The `group` column is dropped.
#' Note: `n` is aligned with the rows through the sorted order of `pgroup` (via `tapply()`),
#' which assumes `sort_by()` orders `pgroup` the same way.
get_pattern_groups_unique <- function(patterns) {
    key_cols <- c("pgroup", "pattern", "group")
    out <- sort_by(unique(patterns[, key_cols]), "pgroup")
    out$group_n <- as.numeric(out$group)

    pgroup_sizes <- tapply(patterns$pgroup, patterns$pgroup, length)
    out$n <- as.numeric(pgroup_sizes)

    visit_flags <- strsplit(out$pattern, "")
    out$n_avail <- vapply(
        visit_flags,
        function(flags) sum(as.numeric(flags)),
        numeric(1)
    )

    out$group <- NULL
    out
}
#' Determine patients missingness group
#'
#' Takes a design matrix with multiple rows per subject and returns a dataset
#' with 1 row per subject with a new column `pgroup` indicating which group
#' the patient belongs to (based upon their missingness pattern and treatment group)
#'
#' @param ddat a `data.frame` with columns `subjid`, `visit`, `group`, `is_avail`
#' @details
#' - The column `is_avail` must be a character or numeric `0` or `1`
#' - The returned `data.frame` has the columns `subjid`, `group`, `pattern` and `pgroup`,
#' where `pattern` concatenates a subject's `is_avail` values in the order produced by
#' sorting on `subjid` and `visit`.
get_pattern_groups <- function(ddat) {
    ordered <- sort_by(ddat, c("subjid", "visit"))
    ddat <- ordered[, c("subjid", "group", "is_avail")]

    # One availability string per subject, e.g. "1101".
    subject_patterns <- tapply(ddat$is_avail, ddat$subjid, paste0, collapse = "")
    dat_pattern <- data.frame(
        subjid = names(subject_patterns),
        pattern = subject_patterns,
        stringsAsFactors = FALSE,
        row.names = NULL
    )

    # Each subject must belong to exactly one covariance group.
    dat_group <- unique(ddat[, c("subjid", "group")])
    n_subject_rows <- nrow(dat_group)
    assert_that(
        n_subject_rows == length(unique(ddat$subjid)),
        n_subject_rows == nrow(dat_pattern),
        all(ddat$subjid %in% dat_group$subjid)
    )

    result <- merge(dat_group, dat_pattern, all = TRUE, by = "subjid")
    result$pgroup <- as_strata(result$pattern, result$group)
    result
}
#' Convert indicator to index
#'
#' Converts each string of 0's and 1's into the index positions of its 1's,
#' padding the results with 999's so they all have the same length
#' (the length of the input strings).
#'
#' i.e.
#' ```
#' as_indices(c("1101", "0001")) -> list(c(1,2,4,999), c(4,999, 999, 999))
#' ```
#'
#' @param x a character vector whose values are all either "0" or "1". All elements of
#' the vector must be the same length, and that length must be less than 999
#' (999 is reserved as the padding value).
as_indices <- function(x) {
    widths <- nchar(x)
    assert_that(
        length(unique(widths)) == 1,
        msg = "all values of x must be the same length"
    )
    assert_that(
        unique(widths) < 999,
        msg = "Number of pattern groups must be < 999"
    )
    pad_to <- max(widths)
    lapply(strsplit(x, ""), function(chars) {
        assert_that(
            all(chars %in% c("0", "1")),
            msg = "All values of x must be 0 or 1"
        )
        positions <- which(chars == "1")
        padded <- rep(999, pad_to)
        padded[seq_along(positions)] <- positions
        padded
    })
}
#' As array
#'
#' Wraps a length-1 value in a 1-dimensional array; longer inputs are returned unchanged.
#' This avoids type errors that stan throws when R provides a length 1 numeric vector
#' for a stan::vector input.
#'
#' @param x a numeric vector
as_stan_array <- function(x) {
    if (length(x) == 1) {
        return(array(x, dim = 1))
    }
    x
}
#' Remove subjects from dataset if they have no observed values
#'
#' This function takes a `data.frame` with variables `outcome` & `subjid`
#' (other columns, e.g. `visit`, are carried along unchanged).
#' It then removes all rows for a given `subjid` if they don't have any non-missing
#' values for `outcome`.
#'
#' @param dat a `data.frame`
remove_if_all_missing <- function(dat) {
    # Decide per subject directly, rather than by comparing a missing-count with
    # the overall number of visits. The count comparison kept subjects that had
    # fewer rows than there are visits, even when every outcome was missing.
    all_missing <- tapply(dat$outcome, dat$subjid, function(x) all(is.na(x)))
    # which() drops NA cells (e.g. unused factor levels) and keeps the names.
    remove_me_pt <- names(which(all_missing))
    dat[!dat$subjid %in% remove_me_pt, ]
}
#' Validate a `stan_data` object
#'
#' Checks the internal consistency of a `stan_data` object: dimensions of `Q`, `R` and `y`,
#' lengths of the per-pattern vectors, and that every `pat_sigma_index` entry is either a
#' valid visit index or the padding value 999. Any failure raises the error
#' "Invalid Stan Data Object".
#'
#' @param x A `stan_data` object.
#' @param ... Not used.
#'
#' @export
validate.stan_data <- function(x, ...) {
    n_rows_Q <- nrow(x$Q)
    sigma_index <- x$pat_sigma_index
    n_patterns <- length(x$pat_G)
    assert_that(
        x$N == n_rows_Q,
        x$P == ncol(x$Q),
        sum(x$pat_n_visit * x$pat_n_pt) == n_rows_Q,
        ncol(x$Q) == ncol(x$R),
        ncol(x$R) == nrow(x$R),
        length(x$y) == n_rows_Q,
        n_patterns == length(x$pat_n_pt),
        n_patterns == length(x$pat_n_visit),
        n_patterns == length(sigma_index),
        length(unique(lapply(sigma_index, length))) == 1,
        length(sigma_index[[1]]) == x$n_visit,
        all(vapply(
            sigma_index,
            function(idx) all(idx %in% c(seq_len(x$n_visit), 999)),
            logical(1)
        )),
        msg = "Invalid Stan Data Object"
    )
}
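# NOTE: the two lines below are stray text that is not R code; kept as comments so the file parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.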