R/mcmc.R

Defines functions validate.stan_data remove_if_all_missing as_stan_array as_indices get_pattern_groups get_pattern_groups_unique prepare_stan_data QR_decomp check_mcmc check_hmc_diagn check_ESS get_ESS extract_draws split_dim fit_mcmc

Documented in as_indices as_stan_array check_ESS check_hmc_diagn check_mcmc extract_draws fit_mcmc get_ESS get_pattern_groups get_pattern_groups_unique prepare_stan_data QR_decomp remove_if_all_missing split_dim validate.stan_data

#' Fit the base imputation model using a Bayesian approach
#'
#' @description
#' `fit_mcmc()` fits the base imputation model using a Bayesian approach.
#' This is done through a MCMC method that is implemented in `stan`
#' and is run by using the function [rstan::sampling()].
#' The function returns the draws from the posterior distribution of the model parameters
#' and the `stanfit` object. Additionally it performs multiple diagnostics checks of the chain
#' and returns warnings in case of any detected issues.
#'
#' @param designmat The design matrix of the fixed effects.
#' @param outcome The response variable. Must be numeric.
#' @param group Character vector containing the group variable.
#' @param subjid Character vector containing the subjects IDs.
#' @param visit Character vector containing the visit variable.
#' @param method A `method` object as generated by [method_bayes()].
#' @param quiet Specify whether the stan sampling log should be printed to the console.
#'
#' @details
#' The Bayesian model assumes a multivariate normal likelihood function and weakly-informative
#' priors for the model parameters: in particular, uniform priors are assumed for the regression
#' coefficients and inverse-Wishart priors for the covariance matrices.
#' The chain is initialized using the REML parameter estimates from MMRM as starting values.
#'
#' The function performs the following steps:
#' 1. Fit MMRM using a REML approach.
#' 2. Prepare the input data for the MCMC fit as described in the `data{}`
#' block of the Stan file. See [prepare_stan_data()] for details.
#' 3. Run the MCMC according to the input arguments, using as starting values the REML parameter estimates
#' obtained at step 1.
#' 4. Performs diagnostics checks of the MCMC. See [check_mcmc()] for details.
#' 5. Extract the draws from the model fit.
#'
#' The chains perform `method$n_samples` draws by keeping one every `method$burn_between` iterations. Additionally
#' the first `method$burn_in` iterations are discarded. The total number of iterations will
#' then be `method$burn_in + method$burn_between*method$n_samples`.
#' The purpose of `method$burn_in` is to ensure that the samples are drawn from the stationary
#' distribution of the Markov Chain.
#' The purpose of `method$burn_between` is to keep the draws uncorrelated from each other.
#'
#' @return
#' A named list composed by the following:
#' - `samples`: a named list containing the draws for each parameter. It corresponds to the output of [extract_draws()].
#' - `fit`: a `stanfit` object.
#'
#'
#' @import Rcpp
#' @import methods
#' @importFrom rstan sampling
#' @useDynLib rbmi, .registration = TRUE
fit_mcmc <- function(
    designmat,
    outcome,
    group,
    subjid,
    visit,
    method,
    quiet = FALSE
) {

    n_imputations <- method$n_samples
    burn_in <- method$burn_in
    seed <- method$seed
    burn_between <- method$burn_between
    same_cov <- method$same_cov

    # Total number of iterations: warmup plus `burn_between` iterations
    # for every retained draw
    n_iter <- burn_in + burn_between * n_imputations

    # fit MMRM (needed for initial values)
    mmrm_initial <- fit_mmrm(
        designmat = designmat,
        outcome = outcome,
        subjid = subjid,
        visit = visit,
        group = group,
        cov_struct = "us",
        REML = TRUE,
        same_cov = same_cov
    )

    if (mmrm_initial$failed) {
        stop("Fitting MMRM to original dataset failed")
    }

    # With a shared covariance matrix every subject is put into a single
    # covariance group
    stan_data <- prepare_stan_data(
        ddat = designmat,
        subjid = subjid,
        visit = visit,
        outcome = outcome,
        group = ife(same_cov == TRUE, rep(1, length(group)), group)
    )

    stan_data$Sigma_init <- ife(
        same_cov == TRUE,
        list(mmrm_initial$sigma[[1]]),
        mmrm_initial$sigma
    )

    sampling_args <- list(
        object = stanmodels$MMRM,
        data = stan_data,
        pars = c("beta", "Sigma"),
        chains = 1,
        warmup = burn_in,
        thin = burn_between,
        iter = n_iter,
        # Start the chain from the REML estimates; the coefficients are
        # premultiplied by R because the design matrix passed to Stan is the
        # Q factor of the QR decomposition
        init = list(list(
            theta = as.vector(stan_data$R %*% mmrm_initial$beta),
            sigma = mmrm_initial$sigma
        )),
        refresh = ife(
            quiet,
            0,
            n_iter / 10
        )
    )

    # Check for NULL / length first so that every kind of invalid seed is
    # reported with the same message (is.na(NULL) has length zero)
    assert_that(
        !is.null(seed),
        is.numeric(seed),
        length(seed) == 1,
        !is.na(seed),
        msg = "mcmc seed is invalid"
    )
    sampling_args$seed <- seed

    # `stan_fit` is read below through its `results`, `warnings` and `errors`
    # elements
    stan_fit <- record({
        do.call(sampling, sampling_args)
    })

    if (!is.null(stan_fit$errors)) {
        stop(stan_fit$errors)
    }

    ignorable_warnings <- c(
        "Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.\nRunning the chains for more iterations may help. See\nhttps://mc-stan.org/misc/warnings.html#bulk-ess",
        "Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.\nRunning the chains for more iterations may help. See\nhttps://mc-stan.org/misc/warnings.html#tail-ess"
    )

    # Re-emit every recorded warning that is not in `ignorable_warnings`,
    # one warning() call per message
    warnings <- stan_fit$warnings
    warnings_not_allowed <- warnings[!warnings %in% ignorable_warnings]
    for (warning_msg in warnings_not_allowed) warning(warning_msg)

    fit <- stan_fit$results
    check_mcmc(fit, n_imputations)

    draws <- extract_draws(fit)

    ret_obj <- list(
        "samples" = draws,
        "fit" = fit
    )

    return(ret_obj)
}




#' Transform array into list of arrays
#'
#' @description
#' Transform an array into list of arrays where the listing
#' is performed on a given dimension.
#'
#' @param a Array with number of dimensions at least 2.
#' @param n Positive integer. Dimension of `a` to be listed.
#'
#' @details
#' For example, if `a` is a 3 dimensional array and `n = 1`,
#' `split_dim(a,n)` returns a list of 2 dimensional arrays (i.e.
#' a list of matrices) where each element of the list is `a[i, , ]`, where
#' `i` takes values from 1 to the length of the first dimension of the array.
#'
#' Example:
#'
#' inputs:
#' `a <- array( c(1,2,3,4,5,6,7,8,9,10,11,12), dim = c(3,2,2))`,
#' which means that:
#' ```
#' a[1,,]     a[2,,]     a[3,,]
#'
#' [,1] [,2]  [,1] [,2]  [,1] [,2]
#' ---------  ---------  ---------
#'  1    7     2    8     3    9
#'  4    10    5    11    6    12
#' ```
#'
#' `n <- 1`
#'
#' output of `res <- split_dim(a,n)` is a list of 3 elements:
#' ```
#' res[[1]]   res[[2]]   res[[3]]
#'
#' [,1] [,2]  [,1] [,2]  [,1] [,2]
#' ---------  ---------  ---------
#'  1    7     2    8     3    9
#'  4    10    5    11    6    12
#' ```
#'
#' @return
#' A list of length `dim(a)[n]` of arrays with number of dimensions equal to the
#' number of dimensions of `a` minus 1.
#'
#' @importFrom stats setNames
split_dim <- function(a, n) {
    dims <- dim(a)
    dim_names <- dimnames(a)

    # Position of every cell of `a` along dimension `n`
    slice_id <- arrayInd(seq_along(a), dims)[, n]

    # Group the cell values by slice, then give each group the shape (and
    # dimnames) of `a` with dimension `n` removed
    slices <- lapply(
        split(a, slice_id),
        function(values) {
            array(values, dim = dims[-n], dimnames = dim_names[-n])
        }
    )

    # Name the slices after the labels of the dimension that was split
    setNames(slices, dim_names[[n]])
}


#' Extract draws from a `stanfit` object
#'
#' @description
#' Extract draws from a `stanfit` object and convert them into lists.
#'
#' The function [rstan::extract()] returns the draws for a given parameter as an array. This function
#' calls [rstan::extract()] to extract the draws from a `stanfit` object
#' and then convert the arrays into lists.
#'
#' @param stan_fit A `stanfit` object.
#'
#' @return
#' A named list of length 2 containing:
#' - `beta`: a list of length equal to the number of draws containing
#'   the draws from the posterior distribution of the regression coefficients.
#' - `sigma`: a list of length equal to the number of draws containing
#'   the draws from the posterior distribution of the covariance matrices. Each element
#'   of the list is a list with length equal to 1 if `same_cov = TRUE` or equal to the
#'   number of groups if `same_cov = FALSE`.
#'
#' @importFrom rstan extract
extract_draws <- function(stan_fit) {

    draws <- extract(stan_fit, pars = c("beta", "Sigma"))
    names(draws) <- c("beta", "sigma")

    # Covariance draws: split along the first dimension to get one element per
    # draw, then split each of those along its first dimension again
    sigma_by_draw <- split_dim(draws$sigma, 1)
    draws$sigma <- lapply(sigma_by_draw, split_dim, n = 1)

    # Coefficient draws: one plain vector per draw
    beta_by_draw <- split_dim(draws$beta, 1)
    draws$beta <- lapply(beta_by_draw, as.vector)

    return(draws)
}


#' Extract the Effective Sample Size (ESS) from a `stanfit` object
#'
#' @param stan_fit A `stanfit` object.
#'
#' @return
#' A named vector containing the ESS for each parameter of the model.
#'
#' @importFrom rstan summary
get_ESS <- function(stan_fit) {
    # Summarise only the parameters of interest and keep the "n_eff" column
    fit_summary <- rstan::summary(stan_fit, pars = c("beta", "Sigma"))
    return(fit_summary$summary[, "n_eff"])
}


#' Diagnostics of the MCMC based on ESS
#'
#' @description
#' Check the quality of the MCMC draws from the posterior distribution
#' by checking whether the relative ESS is sufficiently large.
#'
#' @inheritParams check_mcmc
#'
#' @details
#' `check_ESS()` works as follows:
#' 1. Extract the ESS from `stan_fit` for each parameter of the model.
#' 2. Compute the relative ESS (i.e. the ESS divided by the number of draws).
#' 3. Check whether for any of the parameters the relative ESS is lower than `threshold_lowESS`.
#'    If for at least one parameter the relative ESS is below the threshold,
#'    a warning is thrown.
#'
#' @inherit check_mcmc return
#'
check_ESS <- function(stan_fit, n_draws, threshold_lowESS = 0.4) {

    # Relative ESS per parameter, flagged when it falls below the threshold
    relative_ESS <- get_ESS(stan_fit) / n_draws
    is_low <- relative_ESS < threshold_lowESS
    n_low_ESS <- sum(is_low)

    if (any(is_low)) {
        msg <- paste0(
            "The Effective Sample Size is below ",
            threshold_lowESS * 100,
            "% for ",
            n_low_ESS,
            " parameters. Please consider increasing burn-in and/or burn-between, or the number of samples"
        )
        warning(msg, call. = FALSE)
    }

    invisible(NULL)
}


#' Diagnostics of the MCMC based on HMC-related measures.
#'
#' @description
#' Check that:
#' 1. There are no divergent iterations.
#' 2. The Bayesian Fraction of Missing Information (BFMI) is sufficiently high (i.e. not below 0.2).
#' 3. The number of iterations that saturated the max treedepth is zero.
#'
#' Please see [rstan::check_hmc_diagnostics()] for details.
#'
#' @param stan_fit A `stanfit` object.
#'
#' @inherit check_mcmc return
#'
check_hmc_diagn <- function(stan_fit) {

    # draws "out of the distribution"
    has_divergences <- any(rstan::get_divergent_iterations(stan_fit))

    # exploring well the target distribution.
    # any() makes the check work when more than one BFMI value is returned
    # (isTRUE() on a vector longer than 1 is always FALSE); missing values are
    # ignored, so a single missing BFMI still counts as "not low"
    bfmi <- rstan::get_bfmi(stan_fit)
    has_low_bfmi <- isTRUE(any(bfmi < 0.2, na.rm = TRUE))

    # efficiency of the algorithm
    hit_max_treedepth <- any(rstan::get_max_treedepth_iterations(stan_fit))

    if (has_divergences || has_low_bfmi || hit_max_treedepth) {
        warning(
            "Lack of efficiency in the HMC sampler: please consider increasing the burn-in period.",
            call. = FALSE
        )
    }

    return(invisible(NULL))
}


#' Diagnostics of the MCMC
#'
#' @param stan_fit A `stanfit` object.
#' @param n_draws Number of MCMC draws.
#' @param threshold_lowESS A number in `[0,1]` indicating the minimum acceptable
#' value of the relative ESS. See details.
#'
#' @details
#' Performs checks of the quality of the MCMC. See [check_ESS()] and [check_hmc_diagn()]
#' for details.
#'
#' @returns
#' A warning message in case of detected problems.
#'
check_mcmc <- function(stan_fit, n_draws, threshold_lowESS = 0.4) {
    # Effective sample size check first, then the HMC sampler diagnostics;
    # both signal problems through warnings only
    check_ESS(stan_fit, n_draws, threshold_lowESS)
    check_hmc_diagn(stan_fit = stan_fit)

    invisible(NULL)
}



#' QR decomposition
#'
#' @description
#' QR decomposition as defined in the
#' [Stan user's guide (section 1.2)](https://mc-stan.org/docs/2_27/stan-users-guide/QR-reparameterization-section.html).
#'
#' @param mat A matrix to perform the QR decomposition on.
QR_decomp <- function(mat) {
    # Q is scaled up and R scaled down by sqrt(N - 1), where N is the number
    # of rows of `mat`, so that their product is still `mat`
    scale_factor <- sqrt(nrow(mat) - 1)
    decomposition <- qr(mat)

    list(
        Q = qr.Q(qr = decomposition) * scale_factor,
        R = qr.R(qr = decomposition) / scale_factor
    )
}



#' Prepare input data to run the Stan model
#'
#' @description
#' Prepare input data to run the Stan model.
#' Creates / calculates all the required inputs as required by the `data{}` block of the MMRM Stan program.
#'
#' @param ddat A design matrix
#' @param subjid Character vector containing the subjects IDs.
#' @param visit Vector containing the visits.
#' @param outcome Numeric vector containing the outcome variable.
#' @param group Vector containing the group variable.
#'
#' @details
#' - The `group` argument determines which covariance matrix group the subject belongs to. If you
#' want all subjects to use a shared covariance matrix then set group to "1" for everyone.
#'
#'
#' @returns
#' A `stan_data` object. A named list as per `data{}` block of the related Stan file. In particular it returns:
#'
#' - N - The number of rows in the design matrix
#' - P - The number of columns in the design matrix
#' - G - The number of distinct covariance matrix groups (i.e. `length(unique(group))`)
#' - n_visit - The number of unique outcome visits
#' - n_pat - The total number of pattern groups (as defined by missingness patterns & covariance group)
#' - pat_G - Index for which Sigma each pattern group should use
#' - pat_n_pt - number of patients within each pattern group
#' - pat_n_visit - number of non-missing visits in each pattern group
#' - pat_sigma_index - rows/cols from Sigma to subset on for the pattern group (padded by 0's)
#' - y - The outcome variable
#' - Q - design matrix (after QR decomposition)
#' - R - R matrix from the QR decomposition of the design matrix
prepare_stan_data <- function(ddat, subjid, visit, outcome, group) {

    # Input checks: types, matching lengths, and exactly one row per
    # subject / visit combination
    assert_that(
        is.factor(group) | is.numeric(group),
        is.factor(visit) | is.numeric(visit),
        is.character(subjid) | is.factor(subjid),
        is.numeric(outcome),
        is.data.frame(ddat) | is.matrix(ddat),
        length(group) == length(visit),
        length(subjid) == length(visit),
        length(outcome) == length(group),
        length(outcome) == nrow(ddat),
        length(unique(subjid)) * length(unique(visit)) == nrow(ddat)
    )

    # Give the design columns fixed names (V1, V2, ...) so they can be
    # recovered after the bookkeeping columns are added
    design_variables <- paste0("V", seq_len(ncol(ddat)))
    ddat <- as.data.frame(ddat)
    names(ddat) <- design_variables
    ddat$subjid <- as.character(subjid)
    ddat$visit <- visit
    ddat$outcome <- outcome
    ddat$group <- group
    ddat$is_avail <- (!is.na(ddat$outcome)) * 1

    ddat <- remove_if_all_missing(ddat)

    dat_pgroups <- get_pattern_groups(ddat)

    # Attach the pattern group to every row and order the data by pattern
    # group, then subject, then visit
    ddat2 <- merge(ddat, dat_pgroups, by = "subjid", all = TRUE)
    ddat2 <- sort_by(ddat2, c("pgroup", "subjid", "visit"))
    assert_that(nrow(ddat2) == nrow(ddat))

    # Only rows with an observed outcome are passed on
    ddat3 <- ddat2[!is.na(ddat2$outcome), ]

    dmat <- as.matrix(ddat3[, design_variables])

    qr_dmat <- QR_decomp(dmat)

    dat_pgroups_u <- get_pattern_groups_unique(dat_pgroups)

    stan_dat <- list(
        N = nrow(dmat),
        P = ncol(dmat),
        G = length(unique(group)),
        # Count the visits actually present: levels() is NULL for a numeric
        # `visit` (accepted by the checks above) and would also count unused
        # factor levels, neither of which matches the pattern length
        n_visit = length(unique(visit)),
        n_pat = nrow(dat_pgroups_u),
        pat_G = as_stan_array(dat_pgroups_u$group_n),
        pat_n_pt = as_stan_array(dat_pgroups_u$n),
        pat_n_visit = as_stan_array(dat_pgroups_u$n_avail),
        pat_sigma_index = as_indices(dat_pgroups_u$pattern),
        y = ddat3$outcome,
        Q = qr_dmat$Q,
        R = qr_dmat$R
    )

    class(stan_dat) <- c("list", "stan_data")
    validate(stan_dat)
    return(stan_dat)
}


#' Get Pattern Summary
#'
#' Takes a dataset of pattern information and creates a summary dataset of it
#' with just 1 row per pattern
#'
#' @param patterns A `data.frame` with the columns `pgroup`, `pattern` and `group`
#' @details
#' - The column `pgroup` must be a numeric vector indicating which pattern group the patient belongs to
#' - The column `pattern` must be a character string of `0`'s or `1`'s. It must be identical for all
#' rows within the same `pgroup`
#' - The column `group` must be a character / numeric vector indicating which covariance group the observation
#' belongs to. It must be identical within the same `pgroup`
get_pattern_groups_unique <- function(patterns) {
    # Number of rows of `patterns` in each pattern group, ordered by pgroup
    pgroup_sizes <- tapply(patterns$pgroup, patterns$pgroup, length)

    # One row per distinct (pgroup, pattern, group) combination, ordered by
    # pgroup so the rows line up with `pgroup_sizes`
    summary_dat <- sort_by(
        unique(patterns[, c("pgroup", "pattern", "group")]),
        "pgroup"
    )

    summary_dat$group_n <- as.numeric(summary_dat$group)
    summary_dat$n <- as.numeric(pgroup_sizes)

    # Number of available visits = sum of the digits in the pattern string
    summary_dat$n_avail <- vapply(
        strsplit(summary_dat$pattern, ""),
        function(flags) sum(as.numeric(flags)),
        numeric(1)
    )

    summary_dat$group <- NULL
    return(summary_dat)
}


#' Determine patients missingness group
#'
#' Takes a design matrix with multiple rows per subject and returns a dataset
#' with 1 row per subject with a new column `pgroup` indicating which group
#' the patient belongs to (based upon their missingness pattern and treatment group)
#'
#' @param ddat a `data.frame` with columns `subjid`, `visit`, `group`, `is_avail`
#' @details
#' - The column `is_avail` must be a character or numeric `0` or `1`
get_pattern_groups <- function(ddat) {
    # Order by subject and visit so that each subject's availability flags
    # are concatenated in visit order; keep only the columns used below
    keep_cols <- c("subjid", "group", "is_avail")
    ddat <- sort_by(ddat, c("subjid", "visit"))
    ddat <- ddat[, keep_cols]

    # One availability string (e.g. "1101") per subject
    pattern_by_subject <- tapply(
        ddat$is_avail,
        ddat$subjid,
        paste0,
        collapse = ""
    )
    dat_pattern <- data.frame(
        subjid = names(pattern_by_subject),
        pattern = pattern_by_subject,
        stringsAsFactors = FALSE,
        row.names = NULL
    )

    # One row per distinct subject / group combination
    dat_group <- unique(ddat[, c("subjid", "group")])

    # Every subject must appear in exactly one group and have one pattern
    assert_that(
        nrow(dat_group) == length(unique(ddat$subjid)),
        nrow(dat_group) == nrow(dat_pattern),
        all(ddat$subjid %in% dat_group$subjid)
    )

    result <- merge(dat_group, dat_pattern, all = TRUE, by = "subjid")
    result$pgroup <- as_strata(result$pattern, result$group)
    return(result)
}


#' Convert indicator to index
#'
#' Converts a string of 0's and 1's into index positions of the 1's
#' padding the results by 0's so they are all the same length
#'
#' i.e.
#' ```
#' as_indices(c("1101", "0001"))  ->   list(c(1,2,4,999), c(4,999, 999, 999))
#' ```
#'
#' @param x a character vector whose values are all either "0" or "1". All elements of
#' the vector must be the same length
as_indices <- function(x) {

    n_chars <- nchar(x)

    assert_that(
        length(unique(n_chars)) == 1,
        msg = "all values of x must be the same length"
    )
    assert_that(
        unique(n_chars) < 999,
        msg = "Number of pattern groups must be < 999"
    )

    pad_len <- max(n_chars)

    # Turn one split pattern into the positions of its "1"s, right-padded
    # with 999 up to the common pattern length
    to_positions <- function(flags) {
        assert_that(
            all(flags %in% c("0", "1")),
            msg = "All values of x must be 0 or 1"
        )
        positions <- which(flags == "1")
        c(positions, rep(999, pad_len - length(positions)))
    }

    lapply(strsplit(x, ""), to_positions)
}

#' As array
#'
#' Converts a numeric value of length 1 into a 1 dimension array.
#' This is to avoid type errors that are thrown by stan when length 1 numeric vectors
#' are provided by R for stan::vector inputs
#'
#' @param x a numeric vector
as_stan_array <- function(x) {
    # Only length-1 inputs are wrapped into a 1-dimensional array; anything
    # else is passed through as is
    is_single_value <- length(x) == 1
    ife(is_single_value, array(x, dim = 1), x)
}


#' Remove subjects from dataset if they have no observed values
#'
#' This function takes a `data.frame` with variables `visit`, `outcome` & `subjid`.
#' It then removes all rows for a given `subjid` if they don't have any non-missing
#' values for `outcome`.
#'
#' @param dat a `data.frame`
remove_if_all_missing <- function(dat) {
    total_visits <- length(unique(dat$visit))

    # Number of missing outcomes per subject
    missing_counts <- tapply(
        dat$outcome,
        dat$subjid,
        function(values) sum(is.na(values))
    )

    # Subjects whose number of missing outcomes equals the number of visits
    all_missing <- which(missing_counts == total_visits)
    drop_ids <- names(missing_counts)[all_missing]

    keep_row <- !(dat$subjid %in% drop_ids)
    dat[keep_row, ]
}


#' Validate a `stan_data` object
#'
#' @param x A `stan_data` object.
#' @param ... Not used.
#'
#' @export
validate.stan_data <- function(x, ...) {
    q_rows <- nrow(x$Q)
    q_cols <- ncol(x$Q)
    n_pattern <- length(x$pat_G)

    # TRUE when every entry is a visit position (1..n_visit) or the padding
    # value 999
    is_valid_index <- function(z) all(z %in% c(seq_len(x$n_visit), 999))

    # Assertions are checked in order; any failure is reported with the
    # single message below
    assert_that(
        x$N == q_rows,
        x$P == q_cols,
        sum(x$pat_n_visit * x$pat_n_pt) == q_rows,
        q_cols == ncol(x$R),
        ncol(x$R) == nrow(x$R),
        length(x$y) == q_rows,
        n_pattern == length(x$pat_n_pt),
        n_pattern == length(x$pat_n_visit),
        n_pattern == length(x$pat_sigma_index),
        length(unique(lapply(x$pat_sigma_index, length))) == 1,
        length(x$pat_sigma_index[[1]]) == x$n_visit,
        all(vapply(x$pat_sigma_index, is_valid_index, logical(1))),
        msg = "Invalid Stan Data Object"
    )
}

Try the rbmi package in your browser

Any scripts or data that you put into this service are public.

rbmi documentation built on Nov. 24, 2023, 5:11 p.m.