# R/posterior_transformation.R


#' Compute a contrast between two outcome prediction specifications for a BCF model
#'
#' Compute a contrast using a BCF model by making two sets of outcome predictions and taking their difference.
#' For simple BCF models with binary treatment, this will yield the same prediction as requesting `terms = "cate"`
#' in the `predict.bcfmodel` function. For more general models, such as models with continuous / multivariate treatments or
#' an additive random effects term with a coefficient on the treatment, this function provides the flexibility to compute
#' any contrast of interest by specifying covariates, treatments, and random effects bases and IDs for both sides of a two-term
#' contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend of the
#' contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment" terminology of a classic
#' two-arm experiment. We mirror the function calls and terminology of the `predict.bcfmodel` function, labeling each prediction
#' data term with a `1` to denote its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the
#' control prediction.
#'
#' @param object Object of type `bcfmodel` containing draws of a Bayesian causal forest model and associated sampling outputs.
#' @param X_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe.
#' @param X_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.
#' @param Z_0 Treatments used for prediction in the "control" case. Must be a matrix or vector.
#' @param Z_1 Treatments used for prediction in the "treatment" case. Must be a matrix or vector.
#' @param propensity_0 (Optional) Propensities used for prediction in the "control" case. Must be a matrix or vector.
#' @param propensity_1 (Optional) Propensities used for prediction in the "treatment" case. Must be a matrix or vector.
#' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects
#' model in the "control" case. Test set evaluation for group labels that were not in the training set is not
#' currently supported (support is planned). Must be a vector.
#' @param rfx_group_ids_1 (Optional) Test set group labels used for prediction from an additive random effects
#' model in the "treatment" case. Test set evaluation for group labels that were not in the training set is not
#' currently supported (support is planned). Must be a vector.
#' @param rfx_basis_0 (Optional) Test set basis used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector.
#' @param rfx_basis_1 (Optional) Test set basis used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.
#' @param type (Optional) Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".
#' @param scale (Optional) Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
#'
#' @return Contrast matrix (one row per observation, one column per posterior draw) if `type = "posterior"`, or vector of posterior mean contrasts (one per observation) if `type = "mean"`.
#' @export
#'
#' @examples
#' n <- 500
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' mu_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' pi_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
#' )
#' tau_x <- (
#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
#' )
#' Z <- rbinom(n, 1, pi_x)
#' noise_sd <- 1
#' y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' pi_test <- pi_x[test_inds]
#' pi_train <- pi_x[train_inds]
#' Z_test <- Z[test_inds]
#' Z_train <- Z[train_inds]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' mu_test <- mu_x[test_inds]
#' mu_train <- mu_x[train_inds]
#' tau_test <- tau_x[test_inds]
#' tau_train <- tau_x[train_inds]
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                  propensity_train = pi_train, num_gfr = 10,
#'                  num_burnin = 0, num_mcmc = 10)
#' tau_hat_test <- compute_contrast_bcf_model(
#'     bcf_model, X_0=X_test, X_1=X_test, Z_0=rep(0, n_test), Z_1=rep(1, n_test),
#'     propensity_0 = pi_test, propensity_1 = pi_test
#' )
compute_contrast_bcf_model <- function(
  object,
  X_0,
  X_1,
  Z_0,
  Z_1,
  propensity_0 = NULL,
  propensity_1 = NULL,
  rfx_group_ids_0 = NULL,
  rfx_group_ids_1 = NULL,
  rfx_basis_0 = NULL,
  rfx_basis_1 = NULL,
  type = "posterior",
  scale = "linear"
) {
  # Handle mean function scale
  if (!is.character(scale)) {
    stop("scale must be a string or character vector")
  }
  if (!(scale %in% c("linear", "probability"))) {
    stop("scale must either be 'linear' or 'probability'")
  }
  is_probit <- object$model_params$probit_outcome_model
  if ((scale == "probability") && (!is_probit)) {
    stop(
      "scale cannot be 'probability' for models not fit with a probit outcome model"
    )
  }
  probability_scale <- scale == "probability"

  # Handle prediction type
  if (!is.character(type)) {
    stop("type must be a string or character vector")
  }
  if (!(type %in% c("mean", "posterior"))) {
    stop("type must either be 'mean' or 'posterior'")
  }
  predict_mean <- type == "mean"

  # Make sure covariates are matrix or data frame
  if ((!is.data.frame(X_0)) && (!is.matrix(X_0))) {
    stop("X_0 must be a matrix or dataframe")
  }
  if ((!is.data.frame(X_1)) && (!is.matrix(X_1))) {
    stop("X_1 must be a matrix or dataframe")
  }

  # Convert all input data to matrices if not already converted
  if ((is.null(dim(Z_0))) && (!is.null(Z_0))) {
    Z_0 <- as.matrix(as.numeric(Z_0))
  }
  if ((is.null(dim(Z_1))) && (!is.null(Z_1))) {
    Z_1 <- as.matrix(as.numeric(Z_1))
  }
  if ((is.null(dim(propensity_0))) && (!is.null(propensity_0))) {
    propensity_0 <- as.matrix(propensity_0)
  }
  if ((is.null(dim(propensity_1))) && (!is.null(propensity_1))) {
    propensity_1 <- as.matrix(propensity_1)
  }
  if ((is.null(dim(rfx_basis_0))) && (!is.null(rfx_basis_0))) {
    rfx_basis_0 <- as.matrix(rfx_basis_0)
  }
  if ((is.null(dim(rfx_basis_1))) && (!is.null(rfx_basis_1))) {
    rfx_basis_1 <- as.matrix(rfx_basis_1)
  }

  # Data checks
  if (
    (object$model_params$propensity_covariate != "none") &&
      ((is.null(propensity_0)) ||
        (is.null(propensity_1)))
  ) {
    if (!object$model_params$internal_propensity_model) {
      stop("propensity_0 and propensity_1 must be provided for this model")
    }
  }
  if (nrow(X_0) != nrow(Z_0)) {
    stop("X_0 and Z_0 must have the same number of rows")
  }
  if (nrow(X_1) != nrow(Z_1)) {
    stop("X_1 and Z_1 must have the same number of rows")
  }
  if (object$model_params$num_covariates != ncol(X_0)) {
    stop(
      "X_0 must have the same number of columns as the covariates used to train the model"
    )
  }
  if (object$model_params$num_covariates != ncol(X_1)) {
    stop(
      "X_1 must have the same number of columns as the covariates used to train the model"
    )
  }
  if ((object$model_params$has_rfx) && (is.null(rfx_group_ids_0))) {
    stop(
      "Random effect group labels (rfx_group_ids_0) must be provided for this model"
    )
  }
  if ((object$model_params$has_rfx) && (is.null(rfx_group_ids_1))) {
    stop(
      "Random effect group labels (rfx_group_ids_1) must be provided for this model"
    )
  }
  if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis_0))) {
    stop("Random effects basis (rfx_basis_0) must be provided for this model")
  }
  if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis_1))) {
    stop("Random effects basis (rfx_basis_1) must be provided for this model")
  }
  if (
    (object$model_params$num_rfx_basis > 0) &&
      (ncol(rfx_basis_0) != object$model_params$num_rfx_basis)
  ) {
    stop(
      "Random effects basis has a different dimension than the basis used to train this model"
    )
  }
  if (
    (object$model_params$num_rfx_basis > 0) &&
      (ncol(rfx_basis_1) != object$model_params$num_rfx_basis)
  ) {
    stop(
      "Random effects basis has a different dimension than the basis used to train this model"
    )
  }

  # Predict for the control arm
  control_preds <- predict(
    object = object,
    X = X_0,
    Z = Z_0,
    propensity = propensity_0,
    rfx_group_ids = rfx_group_ids_0,
    rfx_basis = rfx_basis_0,
    type = "posterior",
    terms = "y_hat",
    scale = "linear"
  )

  # Predict for the treatment arm
  treatment_preds <- predict(
    object = object,
    X = X_1,
    Z = Z_1,
    propensity = propensity_1,
    rfx_group_ids = rfx_group_ids_1,
    rfx_basis = rfx_basis_1,
    type = "posterior",
    terms = "y_hat",
    scale = "linear"
  )

  # Transform to probability scale if requested
  if (probability_scale) {
    treatment_preds <- pnorm(treatment_preds)
    control_preds <- pnorm(control_preds)
  }

  # Compute and return contrast
  if (predict_mean) {
    return(rowMeans(treatment_preds - control_preds))
  } else {
    return(treatment_preds - control_preds)
  }
}
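
# Illustration (a minimal sketch, not part of the package API): the last step
# of compute_contrast_bcf_model() applied to two prediction matrices, assumed
# here to have one row per observation and one column per posterior draw.
# `contrast_from_draws_sketch` is a hypothetical helper. On the probability
# scale each arm is passed through pnorm() before differencing; because pnorm()
# is nonlinear, this is not the same as pnorm() of the linear-scale contrast.
contrast_from_draws_sketch <- function(
  y1_linear,
  y0_linear,
  type = "posterior",
  scale = "linear"
) {
  # Hypothetical helper for illustration only
  stopifnot(identical(dim(y1_linear), dim(y0_linear)))
  if (scale == "probability") {
    # Convert each arm to P(y == 1) under a probit link, then difference
    y1_linear <- pnorm(y1_linear)
    y0_linear <- pnorm(y0_linear)
  }
  draws <- y1_linear - y0_linear
  # type = "mean" averages over posterior draws (columns), one value per row
  if (type == "mean") rowMeans(draws) else draws
}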

#' Compute a contrast between two outcome prediction specifications for a BART model
#'
#' Compute a contrast using a BART model by making two sets of outcome predictions and taking their difference.
#' This function provides the flexibility to compute any contrast of interest by specifying covariates, leaf basis, and random effects
#' bases / IDs for both sides of a two term contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or
#' `Y0` term and the minuend of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment"
#' terminology of a classic two-treatment causal inference problem. We mirror the function calls and terminology of the `predict.bartmodel`
#' function, labeling each prediction data term with a `1` to denote its contribution to the treatment prediction of a contrast and
#' `0` to denote inclusion in the control prediction.
#'
#' Only valid when there is either a mean forest or a random effects term in the BART model.
#'
#' @param object Object of type `bartmodel` containing draws of a regression forest and associated sampling outputs.
#' @param X_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe.
#' @param X_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.
#' @param leaf_basis_0 (Optional) Bases used for prediction in the "control" case (e.g. by dot product with leaf values). Default: `NULL`.
#' @param leaf_basis_1 (Optional) Bases used for prediction in the "treatment" case (e.g. by dot product with leaf values). Default: `NULL`.
#' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects
#' model in the "control" case. Test set evaluation for group labels that were not in the training set is not
#' currently supported (support is planned). Must be a vector.
#' @param rfx_group_ids_1 (Optional) Test set group labels used for prediction from an additive random effects
#' model in the "treatment" case. Test set evaluation for group labels that were not in the training set is not
#' currently supported (support is planned). Must be a vector.
#' @param rfx_basis_0 (Optional) Test set basis used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector.
#' @param rfx_basis_1 (Optional) Test set basis used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.
#' @param type (Optional) Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".
#' @param scale (Optional) Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
#'
#' @return Contrast matrix or vector, depending on whether type = "mean" or "posterior".
#' @export
#'
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' W <- matrix(runif(n*1), ncol = 1)
#' f_XW <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1])
#' )
#' noise_sd <- 1
#' y <- f_XW + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' W_test <- W[test_inds,]
#' W_train <- W[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, leaf_basis_train = W_train, y_train = y_train,
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' contrast_test <- compute_contrast_bart_model(
#'     bart_model,
#'     X_0 = X_test,
#'     X_1 = X_test,
#'     leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1),
#'     leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1),
#'     type = "posterior",
#'     scale = "linear"
#' )
compute_contrast_bart_model <- function(
  object,
  X_0,
  X_1,
  leaf_basis_0 = NULL,
  leaf_basis_1 = NULL,
  rfx_group_ids_0 = NULL,
  rfx_group_ids_1 = NULL,
  rfx_basis_0 = NULL,
  rfx_basis_1 = NULL,
  type = "posterior",
  scale = "linear"
) {
  # Handle mean function scale
  if (!is.character(scale)) {
    stop("scale must be a string or character vector")
  }
  if (!(scale %in% c("linear", "probability"))) {
    stop("scale must either be 'linear' or 'probability'")
  }
  is_probit <- object$model_params$probit_outcome_model
  if ((scale == "probability") && (!is_probit)) {
    stop(
      "scale cannot be 'probability' for models not fit with a probit outcome model"
    )
  }
  probability_scale <- scale == "probability"

  # Handle prediction type
  if (!is.character(type)) {
    stop("type must be a string or character vector")
  }
  if (!(type %in% c("mean", "posterior"))) {
    stop("type must either be 'mean' or 'posterior'")
  }
  predict_mean <- type == "mean"

  # Handle prediction terms
  has_mean_forest <- object$model_params$include_mean_forest
  has_rfx <- object$model_params$has_rfx
  if ((!has_mean_forest) && (!has_rfx)) {
    stop(
      "Model must have a mean forest, a random effects term, or both in order to compute the requested contrast."
    )
  }

  # Check that covariates are matrix or data frame
  if ((!is.data.frame(X_0)) && (!is.matrix(X_0))) {
    stop("X_0 must be a matrix or dataframe")
  }
  if ((!is.data.frame(X_1)) && (!is.matrix(X_1))) {
    stop("X_1 must be a matrix or dataframe")
  }

  # Convert all input data to matrices if not already converted
  if ((is.null(dim(leaf_basis_0))) && (!is.null(leaf_basis_0))) {
    leaf_basis_0 <- as.matrix(leaf_basis_0)
  }
  if ((is.null(dim(leaf_basis_1))) && (!is.null(leaf_basis_1))) {
    leaf_basis_1 <- as.matrix(leaf_basis_1)
  }
  if ((is.null(dim(rfx_basis_0))) && (!is.null(rfx_basis_0))) {
    rfx_basis_0 <- as.matrix(rfx_basis_0)
  }
  if ((is.null(dim(rfx_basis_1))) && (!is.null(rfx_basis_1))) {
    rfx_basis_1 <- as.matrix(rfx_basis_1)
  }

  # Data checks
  if (
    (object$model_params$requires_basis) &&
      (is.null(leaf_basis_0) || is.null(leaf_basis_1))
  ) {
    stop("leaf_basis_0 and leaf_basis_1 must be provided for this model")
  }
  if ((!is.null(leaf_basis_0)) && (nrow(X_0) != nrow(leaf_basis_0))) {
    stop("X_0 and leaf_basis_0 must have the same number of rows")
  }
  if ((!is.null(leaf_basis_1)) && (nrow(X_1) != nrow(leaf_basis_1))) {
    stop("X_1 and leaf_basis_1 must have the same number of rows")
  }
  if (object$model_params$num_covariates != ncol(X_0)) {
    stop(
      "X_0 must contain the same number of columns as the BART model's training dataset"
    )
  }
  if (object$model_params$num_covariates != ncol(X_1)) {
    stop(
      "X_1 must contain the same number of columns as the BART model's training dataset"
    )
  }
  if ((has_rfx) && (is.null(rfx_group_ids_0) || is.null(rfx_group_ids_1))) {
    stop(
      "rfx_group_ids_0 and rfx_group_ids_1 must be provided for this model"
    )
  }
  if ((has_rfx) && (is.null(rfx_basis_0) || is.null(rfx_basis_1))) {
    stop(
      "rfx_basis_0 and rfx_basis_1 must be provided for this model"
    )
  }
  if (
    (object$model_params$num_rfx_basis > 0) &&
      ((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) ||
        (ncol(rfx_basis_1) != object$model_params$num_rfx_basis))
  ) {
    stop(
      "rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model"
    )
  }

  # Predict for the control arm
  control_preds <- predict(
    object = object,
    X = X_0,
    leaf_basis = leaf_basis_0,
    rfx_group_ids = rfx_group_ids_0,
    rfx_basis = rfx_basis_0,
    type = "posterior",
    terms = "y_hat",
    scale = "linear"
  )

  # Predict for the treatment arm
  treatment_preds <- predict(
    object = object,
    X = X_1,
    leaf_basis = leaf_basis_1,
    rfx_group_ids = rfx_group_ids_1,
    rfx_basis = rfx_basis_1,
    type = "posterior",
    terms = "y_hat",
    scale = "linear"
  )

  # Transform to probability scale if requested
  if (probability_scale) {
    treatment_preds <- pnorm(treatment_preds)
    control_preds <- pnorm(control_preds)
  }

  # Compute and return contrast
  if (predict_mean) {
    return(rowMeans(treatment_preds - control_preds))
  } else {
    return(treatment_preds - control_preds)
  }
}
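
# Illustration (a minimal sketch under stated assumptions, not part of the
# package API): for a univariate leaf regression, each posterior draw s is
# assumed to predict y_hat = beta_s(x) * w, i.e. the mean forest output is
# linear in the leaf basis w. Under that assumption the contrast between
# leaf_basis_1 = w1 and leaf_basis_0 = w0 is beta_s(x) * (w1 - w0), so the
# example above (w0 = 0, w1 = 1) returns draws of beta(x) itself. Random
# effects terms cancel when both sides share the same group IDs and basis.
# `leaf_basis_contrast_sketch` is a hypothetical helper.
leaf_basis_contrast_sketch <- function(beta_draws, w0, w1) {
  # Hypothetical helper: beta_draws is an n x S matrix of per-draw leaf
  # coefficients; w0 and w1 are length-n basis vectors, recycled down each
  # column so that row i is scaled by w0[i] or w1[i]
  y0 <- beta_draws * w0
  y1 <- beta_draws * w1
  y1 - y0
}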

#' Sample from the posterior predictive distribution for outcomes modeled by BCF
#'
#' @param model_object A fitted BCF model object of class `bcfmodel`.
#' @param X A matrix or data frame of covariates.
#' @param Z A vector or matrix of treatment assignments.
#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.
#' @param rfx_group_ids (Optional) A vector of group IDs for random effects model. Required if the BCF model includes random effects.
#' @param rfx_basis (Optional) A matrix of bases for random effects model. Required if the BCF model includes random effects.
#' @param num_draws_per_sample (Optional) The number of samples to draw from the likelihood for each draw of the posterior. Defaults to a heuristic based on the number of posterior draws in the BCF model: if the model has at least 1000 draws, one draw from the likelihood is taken per posterior sample; otherwise each posterior sample is upsampled so that there are at least 1000 posterior predictive draws per observation.
#'
#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples).
#'
#' @export
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(rnorm(n * p), nrow = n, ncol = p)
#' pi_X <- pnorm(X[,1] / 2)
#' Z <- rbinom(n, 1, pi_X)
#' y <- 2 * X[,2] + 0.5 * X[,2] * Z + rnorm(n)
#' bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, propensity_train = pi_X)
#' ppd_samples <- sample_bcf_posterior_predictive(
#'   model_object = bcf_model, X = X,
#'   Z = Z, propensity = pi_X
#' )
sample_bcf_posterior_predictive <- function(
  model_object,
  X = NULL,
  Z = NULL,
  propensity = NULL,
  rfx_group_ids = NULL,
  rfx_basis = NULL,
  num_draws_per_sample = NULL
) {
  # Check the provided model object
  check_model_is_valid(model_object)

  # Determine whether the outcome is continuous (Gaussian) or binary (probit-link)
  is_probit <- model_object$model_params$probit_outcome_model

  # Check that all the inputs needed for posterior predictive sampling were provided
  needs_covariates <- TRUE
  if (needs_covariates) {
    if (is.null(X)) {
      stop(
        "'X' must be provided in order to sample from the posterior predictive distribution"
      )
    }
    if (!is.matrix(X) && !is.data.frame(X)) {
      stop("'X' must be a matrix or data frame")
    }
  }
  needs_treatment <- needs_covariates
  if (needs_treatment) {
    if (is.null(Z)) {
      stop(
        "'Z' must be provided in order to sample from the posterior predictive distribution"
      )
    }
    if (!is.matrix(Z) && !is.numeric(Z)) {
      stop("'Z' must be a numeric vector or matrix")
    }
    if (is.matrix(Z)) {
      if (nrow(Z) != nrow(X)) {
        stop("'Z' must have the same number of rows as 'X'")
      }
    } else {
      if (length(Z) != nrow(X)) {
        stop(
          "'Z' must have one element per row of 'X'"
        )
      }
    }
  }
  uses_propensity <- model_object$model_params$propensity_covariate != "none"
  internal_propensity_model <- model_object$model_params$internal_propensity_model
  needs_propensity <- (needs_covariates &&
    uses_propensity &&
    (!internal_propensity_model))
  if (needs_propensity) {
    if (is.null(propensity)) {
      stop(
        "'propensity' must be provided in order to sample from the posterior predictive distribution"
      )
    }
    if (!is.matrix(propensity) && !is.numeric(propensity)) {
      stop("'propensity' must be a numeric vector or matrix")
    }
    if (is.matrix(propensity)) {
      if (nrow(propensity) != nrow(X)) {
        stop("'propensity' must have the same number of rows as 'X'")
      }
    } else {
      if (length(propensity) != nrow(X)) {
        stop(
          "'propensity' must have one element per row of 'X'"
        )
      }
    }
  }
  needs_rfx_data <- model_object$model_params$has_rfx
  if (needs_rfx_data) {
    if (is.null(rfx_group_ids)) {
      stop(
        "'rfx_group_ids' must be provided in order to sample from the posterior predictive distribution"
      )
    }
    if (length(rfx_group_ids) != nrow(X)) {
      stop(
        "'rfx_group_ids' must have the same length as the number of rows in 'X'"
      )
    }
    if (is.null(rfx_basis)) {
      stop(
        "'rfx_basis' must be provided in order to sample from the posterior predictive distribution"
      )
    }
    if (!is.matrix(rfx_basis)) {
      stop("'rfx_basis' must be a matrix")
    }
    if (nrow(rfx_basis) != nrow(X)) {
      stop("'rfx_basis' must have the same number of rows as 'X'")
    }
  }

  # Compute posterior samples
  bcf_preds <- predict(
    model_object,
    X = X,
    Z = Z,
    propensity = propensity,
    rfx_group_ids = rfx_group_ids,
    rfx_basis = rfx_basis,
    type = "posterior",
    terms = c("all"),
    scale = "linear"
  )

  # Compute outcome mean and variance for every posterior draw
  has_rfx <- model_object$model_params$has_rfx
  has_variance_forest <- model_object$model_params$include_variance_forest
  samples_global_variance <- model_object$model_params$sample_sigma2_global
  num_posterior_draws <- model_object$model_params$num_samples
  num_observations <- nrow(X)
  ppd_mean <- bcf_preds$y_hat
  if (has_variance_forest) {
    ppd_variance <- bcf_preds$variance_forest_predictions
  } else {
    if (samples_global_variance) {
      ppd_variance <- matrix(
        rep(
          model_object$sigma2_global_samples,
          each = num_observations
        ),
        nrow = num_observations
      )
    } else {
      ppd_variance <- model_object$model_params$initial_sigma2
    }
  }

  # Sample from the posterior predictive distribution
  if (is.null(num_draws_per_sample)) {
    ppd_draw_multiplier <- posterior_predictive_heuristic_multiplier(
      num_posterior_draws,
      num_observations
    )
  } else {
    ppd_draw_multiplier <- num_draws_per_sample
  }
  num_ppd_draws <- ppd_draw_multiplier * num_posterior_draws * num_observations
  ppd_vector <- rnorm(num_ppd_draws, ppd_mean, sqrt(ppd_variance))

  # Reshape data
  if (ppd_draw_multiplier > 1) {
    ppd_array <- array(
      ppd_vector,
      dim = c(num_observations, num_posterior_draws, ppd_draw_multiplier)
    )
  } else {
    ppd_array <- array(
      ppd_vector,
      dim = c(num_observations, num_posterior_draws)
    )
  }

  # Binarize outcomes for probit models
  if (is_probit) {
    ppd_array <- (ppd_array > 0.0) * 1
  }

  return(ppd_array)
}
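
# Illustration (a minimal sketch, not part of the package API): the sampling
# and reshaping steps above for the case of one global error variance per
# posterior draw. rnorm() recycles its mean and sd arguments, so an n x S mean
# matrix and variance matrix (both stored column-major) line up with
# array(dim = c(n, S, k)): element [i, s, j] is drawn from
# N(mu[i, s], sigma2[s]). `ppd_sample_sketch` is a hypothetical helper.
ppd_sample_sketch <- function(mu, sigma2_draws, k) {
  # Hypothetical helper: mu is an n x S matrix, sigma2_draws has length S
  n <- nrow(mu)
  num_draws <- ncol(mu)
  stopifnot(length(sigma2_draws) == num_draws)
  # Column s repeats sigma2_draws[s] for every observation
  var_mat <- matrix(rep(sigma2_draws, each = n), nrow = n)
  draws <- rnorm(n * num_draws * k, mean = mu, sd = sqrt(var_mat))
  array(draws, dim = c(n, num_draws, k))
}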

#' Sample from the posterior predictive distribution for outcomes modeled by BART
#'
#' @param model_object A fitted BART model object of class `bartmodel`.
#' @param X A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).
#' @param leaf_basis A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
#' @param rfx_group_ids A vector of group IDs for random effects model. Required if the BART model includes random effects.
#' @param rfx_basis A matrix of bases for random effects model. Required if the BART model includes random effects.
#' @param num_draws_per_sample (Optional) The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of posterior draws in the BART model: if the model has at least 1000 draws, one draw from the likelihood is taken per posterior sample; otherwise each posterior sample is upsampled so that intervals are based on at least 1000 posterior predictive draws per observation.
#'
#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples).
#'
#' @export
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(rnorm(n * p), nrow = n, ncol = p)
#' y <- 2 * X[,1] + rnorm(n)
#' bart_model <- bart(y_train = y, X_train = X)
#' ppd_samples <- sample_bart_posterior_predictive(
#'   model_object = bart_model, X = X
#' )
sample_bart_posterior_predictive <- function(
  model_object,
  X = NULL,
  leaf_basis = NULL,
  rfx_group_ids = NULL,
  rfx_basis = NULL,
  num_draws_per_sample = NULL
) {
  # Check the provided model object
  check_model_is_valid(model_object)

  # Determine whether the outcome is continuous (Gaussian) or binary (probit-link)
  is_probit <- model_object$model_params$probit_outcome_model

  # Check that all the inputs needed for posterior predictive sampling were provided
  needs_covariates <- model_object$model_params$include_mean_forest
  if (needs_covariates) {
    if (is.null(X)) {
      stop(
        "'X' must be provided in order to sample from the posterior predictive distribution"
      )
    }
    if (!is.matrix(X) && !is.data.frame(X)) {
      stop("'X' must be a matrix or data frame")
    }
  }
  needs_basis <- needs_covariates && model_object$model_params$has_basis
  if (needs_basis) {
    if (is.null(leaf_basis)) {
      stop(
        "'leaf_basis' must be provided in order to sample from the posterior predictive distribution"
      )
    }
    if (!is.matrix(leaf_basis)) {
      stop("'leaf_basis' must be a matrix")
    }
    if (nrow(leaf_basis) != nrow(X)) {
      stop("'leaf_basis' must have the same number of rows as 'X'")
    }
  }
  needs_rfx_data <- model_object$model_params$has_rfx
  if (needs_rfx_data) {
    if (is.null(rfx_group_ids)) {
      stop(
        "'rfx_group_ids' must be provided in order to sample from the posterior predictive distribution"
      )
    }
    if (length(rfx_group_ids) != nrow(X)) {
      stop(
        "'rfx_group_ids' must have the same length as the number of rows in 'X'"
      )
    }
    if (is.null(rfx_basis)) {
      stop(
        "'rfx_basis' must be provided in order to sample from the posterior predictive distribution"
      )
    }
    if (!is.matrix(rfx_basis)) {
      stop("'rfx_basis' must be a matrix")
    }
    if (nrow(rfx_basis) != nrow(X)) {
      stop("'rfx_basis' must have the same number of rows as 'X'")
    }
  }

  # Compute posterior samples
  bart_preds <- predict(
    model_object,
    X = X,
    leaf_basis = leaf_basis,
    rfx_group_ids = rfx_group_ids,
    rfx_basis = rfx_basis,
    type = "posterior",
    terms = c("all"),
    scale = "linear"
  )

  # Compute outcome mean and variance for every posterior draw
  has_mean_term <- (model_object$model_params$include_mean_forest ||
    model_object$model_params$has_rfx)
  has_variance_forest <- model_object$model_params$include_variance_forest
  samples_global_variance <- model_object$model_params$sample_sigma2_global
  num_posterior_draws <- model_object$model_params$num_samples
  num_observations <- nrow(X)
  if (has_mean_term) {
    ppd_mean <- bart_preds$y_hat
  } else {
    ppd_mean <- 0
  }
  if (has_variance_forest) {
    ppd_variance <- bart_preds$variance_forest_predictions
  } else {
    if (samples_global_variance) {
      ppd_variance <- matrix(
        rep(
          model_object$sigma2_global_samples,
          each = num_observations
        ),
        nrow = num_observations
      )
    } else {
      ppd_variance <- model_object$model_params$sigma2_init
    }
  }

  # Sample from the posterior predictive distribution
  if (is.null(num_draws_per_sample)) {
    ppd_draw_multiplier <- posterior_predictive_heuristic_multiplier(
      num_posterior_draws,
      num_observations
    )
  } else {
    ppd_draw_multiplier <- num_draws_per_sample
  }
  num_ppd_draws <- ppd_draw_multiplier * num_posterior_draws * num_observations
  ppd_vector <- rnorm(num_ppd_draws, ppd_mean, sqrt(ppd_variance))

  # Reshape data
  if (ppd_draw_multiplier > 1) {
    ppd_array <- array(
      ppd_vector,
      dim = c(num_observations, num_posterior_draws, ppd_draw_multiplier)
    )
  } else {
    ppd_array <- array(
      ppd_vector,
      dim = c(num_observations, num_posterior_draws)
    )
  }

  # Binarize outcomes for probit models
  if (is_probit) {
    ppd_array <- (ppd_array > 0.0) * 1
  }

  return(ppd_array)
}
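
# Illustration (a minimal sketch, not part of the package API): why
# thresholding latent draws at zero yields binary posterior predictive draws
# for probit models. Assuming unit latent variance (the usual probit
# convention), if z ~ N(mu, 1) then P(z > 0) = pnorm(mu), so the share of ones
# approximates the probability-scale prediction for that posterior draw.
# `probit_ppd_sketch` is a hypothetical helper.
probit_ppd_sketch <- function(mu, num_draws) {
  # Hypothetical helper: sample latent outcomes, then binarize as above
  latent <- rnorm(num_draws, mean = mu, sd = 1)
  (latent > 0) * 1
}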

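# Number of likelihood draws per posterior sample used when
# num_draws_per_sample is not supplied: 1 if there are at least 1000 posterior
# samples, otherwise ceiling(1000 / num_samples), so that each observation gets
# at least 1000 posterior predictive draws. num_observations is currently unused.
posterior_predictive_heuristic_multiplier <- function(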
  num_samples,
  num_observations
) {
  if (num_samples >= 1000) {
    return(1)
  } else {
    return(ceiling(1000 / num_samples))
  }
}
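
# Worked example of the heuristic above (restated inline so the snippet runs
# on its own; `heuristic_total_draws_sketch` is a hypothetical helper): the
# multiplier is chosen so that multiplier * num_samples, the number of
# posterior predictive draws per observation, is at least 1000. For 100
# samples the multiplier is 10 (1000 draws); for 300 it is 4 (1200 draws);
# for 999 it is 2 (1998 draws); at 1000 or more samples it is 1. Rounding up
# with ceiling() guarantees the 1000-draw floor at the cost of occasionally
# nearly doubling the work, as in the 999-sample case.
heuristic_total_draws_sketch <- function(num_samples) {
  # Hypothetical helper mirroring posterior_predictive_heuristic_multiplier()
  multiplier <- if (num_samples >= 1000) 1 else ceiling(1000 / num_samples)
  multiplier * num_samples
}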

#' Compute posterior credible intervals for BCF model terms
#'
#' Compute posterior credible intervals for specified terms from a fitted BCF model. Supports intervals for prognostic forests, CATE forests, variance forests, random effects, and overall mean outcome predictions.
#'
#' @param model_object A fitted BCF model object of class `bcfmodel`.
#' @param terms A character string or character vector specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. Note that `"mu"` differs from `"prognostic_function"` only if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"`, and `"tau"` differs from `"cate"` only if random effects are included with a model spec of `"intercept_plus_treatment"`.
#' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval).
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
#' @param X (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).
#' @param Z (Optional) A vector or matrix of treatment assignments. Required whenever `X` is required.
#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.
#' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects.
#' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects and the model was fit with a `"custom"` random effects model spec.
#'
#' @returns A list with elements `lower` and `upper` containing the bounds of the credible interval for the requested term. If multiple terms are requested, a named list containing one such interval per term is returned.
#'
#' @export
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(rnorm(n * p), nrow = n, ncol = p)
#' pi_X <- pnorm(0.5 * X[,1])
#' Z <- rbinom(n, 1, pi_X)
#' mu_X <- X[,1]
#' tau_X <- 0.25 * X[,2]
#' y <- mu_X + tau_X * Z + rnorm(n)
#' bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y,
#'                  propensity_train = pi_X)
#' intervals <- compute_bcf_posterior_interval(
#'  model_object = bcf_model,
#'  terms = c("prognostic_function", "cate"),
#'  X = X,
#'  Z = Z,
#'  propensity = pi_X,
#'  level = 0.90
#' )
compute_bcf_posterior_interval <- function(
  model_object,
  terms,
  level = 0.95,
  scale = "linear",
  X = NULL,
  Z = NULL,
  propensity = NULL,
  rfx_group_ids = NULL,
  rfx_basis = NULL
) {
  # Check the provided model object and requested term
  check_model_is_valid(model_object)
  for (term in terms) {
    check_model_has_term(model_object, term)
  }

  # Handle mean function scale
  if (!is.character(scale)) {
    stop("scale must be a string or character vector")
  }
  if (!(scale %in% c("linear", "probability"))) {
    stop("scale must either be 'linear' or 'probability'")
  }
  is_probit <- model_object$model_params$probit_outcome_model
  if ((scale == "probability") && (!is_probit)) {
    stop(
      "scale cannot be 'probability' for models not fit with a probit outcome model"
    )
  }

  # Check that all the necessary inputs were provided for interval computation
  for (term in terms) {
    if (
      !(term %in%
        c(
          "prognostic_function",
          "mu",
          "cate",
          "tau",
          "variance_forest",
          "rfx",
          "y_hat",
          "all"
        ))
    ) {
      stop(
        paste0(
          "Term '",
          term,
          "' was requested. Valid terms are 'prognostic_function', 'mu', 'cate', 'tau', 'variance_forest', 'rfx', 'y_hat', and 'all'."
        )
      )
    }
  }
  needs_covariates_intermediate <- (("y_hat" %in% terms) ||
    ("all" %in% terms))
  needs_covariates <- (("prognostic_function" %in% terms) ||
    ("mu" %in% terms) ||
    ("cate" %in% terms) ||
    ("tau" %in% terms) ||
    ("variance_forest" %in% terms) ||
    (needs_covariates_intermediate))
  if (needs_covariates) {
    if (is.null(X)) {
      stop(
        "'X' must be provided in order to compute the requested intervals"
      )
    }
    if (!is.matrix(X) && !is.data.frame(X)) {
      stop("'X' must be a matrix or data frame")
    }
  }
  needs_treatment <- needs_covariates
  if (needs_treatment) {
    if (is.null(Z)) {
      stop(
        "'Z' must be provided in order to compute the requested intervals"
      )
    }
    if (!is.matrix(Z) && !is.numeric(Z)) {
      stop("'Z' must be a numeric vector or matrix")
    }
    if (is.matrix(Z)) {
      if (nrow(Z) != nrow(X)) {
        stop("'Z' must have the same number of rows as 'X'")
      }
    } else {
      if (length(Z) != nrow(X)) {
        stop(
          "'Z' must have as many elements as 'X' has rows"
        )
      }
    }
  }
  uses_propensity <- model_object$model_params$propensity_covariate != "none"
  internal_propensity_model <- model_object$model_params$internal_propensity_model
  needs_propensity <- (needs_covariates &&
    uses_propensity &&
    (!internal_propensity_model))
  if (needs_propensity) {
    if (is.null(propensity)) {
      stop(
        "'propensity' must be provided in order to compute the requested intervals"
      )
    }
    if (!is.matrix(propensity) && !is.numeric(propensity)) {
      stop("'propensity' must be a numeric vector or matrix")
    }
    if (is.matrix(propensity)) {
      if (nrow(propensity) != nrow(X)) {
        stop("'propensity' must have the same number of rows as 'X'")
      }
    } else {
      if (length(propensity) != nrow(X)) {
        stop(
          "'propensity' must have as many elements as 'X' has rows"
        )
      }
    }
  }
  needs_rfx_data_intermediate <- ((("y_hat" %in% terms) ||
    ("all" %in% terms)) &&
    model_object$model_params$has_rfx)
  needs_rfx_data <- (("rfx" %in% terms) ||
    (needs_rfx_data_intermediate))
  if (needs_rfx_data) {
    if (is.null(rfx_group_ids)) {
      stop(
        "'rfx_group_ids' must be provided in order to compute the requested intervals"
      )
    }
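    # X may be NULL when only "rfx" is requested; nrow(NULL) would otherwise
    # produce a zero-length condition and a cryptic error
    if (!is.null(X) && length(rfx_group_ids) != nrow(X)) {
      stop(
        "'rfx_group_ids' must have the same length as the number of rows in 'X'"
      )
    }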

    if (model_object$model_params$rfx_model_spec == "custom") {
      if (is.null(rfx_basis)) {
        stop(
          "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
        )
      }
    }

    if (!is.null(rfx_basis)) {
      if (!is.matrix(rfx_basis)) {
        stop("'rfx_basis' must be a matrix")
      }
      if (!is.null(X) && nrow(rfx_basis) != nrow(X)) {
        stop("'rfx_basis' must have the same number of rows as 'X'")
      }
    }
  }

  # Compute posterior matrices for the requested model terms
  predictions <- predict(
    model_object,
    X = X,
    Z = Z,
    propensity = propensity,
    rfx_group_ids = rfx_group_ids,
    rfx_basis = rfx_basis,
    type = "posterior",
    terms = terms,
    scale = scale
  )
  has_multiple_terms <- is.list(predictions)

  # Compute the interval
  if (has_multiple_terms) {
    result <- list()
    for (term_name in names(predictions)) {
      if (!is.null(predictions[[term_name]])) {
        result[[term_name]] <- summarize_interval(
          predictions[[term_name]],
          sample_dim = 2,
          level = level
        )
      } else {
        result[[term_name]] <- NULL
      }
    }
    return(result)
  } else {
    return(summarize_interval(
      predictions,
      sample_dim = 2,
      level = level
    ))
  }
}
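
# Illustrative usage sketch (not part of the package): flag observations
# whose CATE credible interval excludes zero. This assumes `intervals` was
# computed as in the @examples above and that the returned list is keyed by
# the requested term names, as the loop over names(predictions) suggests.
#   cate_ci <- intervals$cate
#   excludes_zero <- (cate_ci$lower > 0) | (cate_ci$upper < 0)
#   mean(excludes_zero)  # share of observations whose interval excludes zero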

#' Compute posterior credible intervals for BART model terms
#'
#' Compute posterior credible intervals for specified terms from a fitted BART model. Supports intervals for mean functions, variance functions, random effects, and overall outcome predictions.
#'
#' @param model_object A fitted BART model object of class `bartmodel`.
#' @param terms A character string or character vector specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`.
#' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval).
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
#' @param X (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).
#' @param leaf_basis An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
#' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects.
#' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.
#'
#' @returns A list with elements `lower` and `upper` containing the bounds of the credible interval for the requested term. If multiple terms are requested, a named list containing one such interval per term is returned.
#'
#' @export
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(rnorm(n * p), nrow = n, ncol = p)
#' y <- 2 * X[,1] + rnorm(n)
#' bart_model <- bart(y_train = y, X_train = X)
#' intervals <- compute_bart_posterior_interval(
#'  model_object = bart_model,
#'  terms = c("mean_forest", "y_hat"),
#'  X = X,
#'  level = 0.90
#' )
compute_bart_posterior_interval <- function(
  model_object,
  terms,
  level = 0.95,
  scale = "linear",
  X = NULL,
  leaf_basis = NULL,
  rfx_group_ids = NULL,
  rfx_basis = NULL
) {
  # Check the provided model object and requested term
  check_model_is_valid(model_object)
  for (term in terms) {
    check_model_has_term(model_object, term)
  }

  # Handle mean function scale
  if (!is.character(scale)) {
    stop("scale must be a string or character vector")
  }
  if (!(scale %in% c("linear", "probability"))) {
    stop("scale must either be 'linear' or 'probability'")
  }
  is_probit <- model_object$model_params$probit_outcome_model
  if ((scale == "probability") && (!is_probit)) {
    stop(
      "scale cannot be 'probability' for models not fit with a probit outcome model"
    )
  }

  # Check that all the necessary inputs were provided for interval computation
  needs_covariates_intermediate <- ((("y_hat" %in% terms) ||
    ("all" %in% terms)) &&
    model_object$model_params$include_mean_forest)
  needs_covariates <- (("mean_forest" %in% terms) ||
    ("variance_forest" %in% terms) ||
    (needs_covariates_intermediate))
  if (needs_covariates) {
    if (is.null(X)) {
      stop(
        "'X' must be provided in order to compute the requested intervals"
      )
    }
    if (!is.matrix(X) && !is.data.frame(X)) {
      stop("'X' must be a matrix or data frame")
    }
  }
  needs_basis <- needs_covariates && model_object$model_params$has_basis
  if (needs_basis) {
    if (is.null(leaf_basis)) {
      stop(
        "'leaf_basis' must be provided in order to compute the requested intervals"
      )
    }
    if (!is.matrix(leaf_basis)) {
      stop("'leaf_basis' must be a matrix")
    }
    if (nrow(leaf_basis) != nrow(X)) {
      stop("'leaf_basis' must have the same number of rows as 'X'")
    }
  }
  needs_rfx_data_intermediate <- ((("y_hat" %in% terms) ||
    ("all" %in% terms)) &&
    model_object$model_params$has_rfx)
  needs_rfx_data <- (("rfx" %in% terms) ||
    (needs_rfx_data_intermediate))
  if (needs_rfx_data) {
    if (is.null(rfx_group_ids)) {
      stop(
        "'rfx_group_ids' must be provided in order to compute the requested intervals"
      )
    }
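    # X may be NULL when only "rfx" is requested (or for "y_hat" in a model
    # without a mean forest), so only compare lengths when X was provided
    if (!is.null(X) && length(rfx_group_ids) != nrow(X)) {
      stop(
        "'rfx_group_ids' must have the same length as the number of rows in 'X'"
      )
    }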
    if (is.null(rfx_basis)) {
      stop(
        "'rfx_basis' must be provided in order to compute the requested intervals"
      )
    }
    if (!is.matrix(rfx_basis)) {
      stop("'rfx_basis' must be a matrix")
    }
    if (!is.null(X) && nrow(rfx_basis) != nrow(X)) {
      stop("'rfx_basis' must have the same number of rows as 'X'")
    }
  }

  # Compute posterior matrices for the requested model terms
  predictions <- predict(
    model_object,
    X = X,
    leaf_basis = leaf_basis,
    rfx_group_ids = rfx_group_ids,
    rfx_basis = rfx_basis,
    type = "posterior",
    terms = terms,
    scale = scale
  )
  has_multiple_terms <- is.list(predictions)

  # Compute the interval
  if (has_multiple_terms) {
    result <- list()
    for (term_name in names(predictions)) {
      if (!is.null(predictions[[term_name]])) {
        result[[term_name]] <- summarize_interval(
          predictions[[term_name]],
          sample_dim = 2,
          level = level
        )
      } else {
        result[[term_name]] <- NULL
      }
    }
    return(result)
  } else {
    return(summarize_interval(
      predictions,
      sample_dim = 2,
      level = level
    ))
  }
}
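
# Illustrative usage sketch (not part of the package): with `X` and
# `intervals` from the @examples above, and assuming the returned list is
# keyed by term name, check how often the 90% interval for the conditional
# mean covers the true mean function 2 * X[, 1] used to simulate the data.
# This is an interval for the mean, not a predictive interval for `y`.
#   y_ci <- intervals$y_hat
#   true_mean <- 2 * X[, 1]
#   mean((true_mean >= y_ci$lower) & (true_mean <= y_ci$upper))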

summarize_interval <- function(array, sample_dim = 2, level = 0.95) {
  # Check that the array is numeric and at least 2 dimensional
  stopifnot(is.numeric(array) && length(dim(array)) >= 2)

  # Compute lower and upper quantiles based on the requested interval
  quantile_lb <- (1 - level) / 2
  quantile_ub <- 1 - quantile_lb

  # Determine the dimensions over which interval is computed
  apply_dim <- setdiff(seq_along(dim(array)), sample_dim)

  # Calculate the interval
  result_lb <- apply(array, apply_dim, function(x) {
    quantile(x, probs = quantile_lb, names = FALSE)
  })
  result_ub <- apply(array, apply_dim, function(x) {
    quantile(x, probs = quantile_ub, names = FALSE)
  })

  return(list(lower = result_lb, upper = result_ub))
}
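
# Worked example (base R only; hypothetical data): two observations, each
# with the draws 1, ..., 100 stored along dimension 2. With R's default
# quantile type 7, the 95% interval is (3.475, 97.525) for each row:
#   draws <- matrix(rep(1:100, times = 2), nrow = 2, byrow = TRUE)
#   ci <- summarize_interval(draws, sample_dim = 2, level = 0.95)
#   stopifnot(isTRUE(all.equal(ci$lower, c(3.475, 3.475))),
#             isTRUE(all.equal(ci$upper, c(97.525, 97.525))))
# For a 3-d posterior predictive array (observations x draws x multiplier),
# passing sample_dim = c(2, 3) pools all draws for each observation.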

check_model_is_valid <- function(model_object) {
  if (
    (!inherits(model_object, "bartmodel")) &&
      (!inherits(model_object, "bcfmodel"))
  ) {
    stop("'model_object' must be a bartmodel or bcfmodel")
  }
}

check_model_has_term <- function(model_object, term) {
  # Parse inputs
  if (!is.character(term) || length(term) != 1) {
    stop("'term' must be a single character string")
  }
  if (
    (!inherits(model_object, "bartmodel")) &&
      (!inherits(model_object, "bcfmodel"))
  ) {
    stop("'model_object' must be a bartmodel or bcfmodel")
  }
  model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", "bcf")

  # Check if the term was fitted as part of the provided model
  if (model_type == "bart") {
    validate_bart_term(term)
    return(bart_model_has_term(model_object, term))
  } else {
    validate_bcf_term(term)
    return(bcf_model_has_term(model_object, term))
  }
}
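
# Illustrative sketch (assumes `bart_model` was fit as in the
# compute_bart_posterior_interval() @examples). The term name is validated
# first, so an invalid name stops with an error. A valid term returns the
# corresponding model_params flag.
#   check_model_has_term(bart_model, "variance_forest")  # model_params$include_variance_forest
#   check_model_has_term(bart_model, "all")              # TRUE
#   check_model_has_term(bart_model, "tau")              # stops: not a bartmodel term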

bart_model_has_term <- function(model_object, term) {
  if (term == "mean_forest") {
    return(model_object$model_params$include_mean_forest)
  } else if (term == "variance_forest") {
    return(model_object$model_params$include_variance_forest)
  } else if (term == "rfx") {
    return(model_object$model_params$has_rfx)
  } else if (term == "y_hat") {
    return(
      model_object$model_params$include_mean_forest ||
        model_object$model_params$has_rfx
    )
  } else if (term == "all") {
    return(TRUE)
  } else {
    return(FALSE)
  }
}

bcf_model_has_term <- function(model_object, term) {
  if (term == "prognostic_function") {
    return(TRUE)
  } else if (term == "mu") {
    return(TRUE)
  } else if (term == "cate") {
    return(TRUE)
  } else if (term == "tau") {
    return(TRUE)
  } else if (term == "variance_forest") {
    return(model_object$model_params$include_variance_forest)
  } else if (term == "rfx") {
    return(model_object$model_params$has_rfx)
  } else if (term == "y_hat") {
    return(TRUE)
  } else if (term == "all") {
    return(TRUE)
  } else {
    return(FALSE)
  }
}

validate_bart_term <- function(term) {
  model_terms <- c("mean_forest", "variance_forest", "rfx", "y_hat", "all")
  if (!(term %in% model_terms)) {
    stop(
      "'term' must be one of 'mean_forest', 'variance_forest', 'rfx', 'y_hat', or 'all' for bartmodel objects"
    )
  }
}

validate_bcf_term <- function(term) {
  model_terms <- c(
    "prognostic_function",
    "mu",
    "cate",
    "tau",
    "variance_forest",
    "rfx",
    "y_hat",
    "all"
  )
  if (!(term %in% model_terms)) {
    stop(
      "'term' must be one of 'prognostic_function', 'mu', 'cate', 'tau', 'variance_forest', 'rfx', 'y_hat', or 'all' for bcfmodel objects"
    )
  }
}

stochtree documentation built on Nov. 22, 2025, 9:06 a.m.