# R/bcf.R

#' Run the Bayesian Causal Forest (BCF) algorithm for regularized causal effect estimation.
#'
#' @param X_train Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix.
#' Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be
#' preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded,
#' categorical columns stored as ordered factors will be passed as integers to the core algorithm, along with the metadata
#' that the column is ordered categorical).
#' @param Z_train Vector of (continuous or binary) treatment assignments.
#' @param y_train Outcome to be modeled by the ensemble.
#' @param propensity_train (Optional) Vector of propensity scores. If not provided, this will be estimated from the data.
#' @param rfx_group_ids_train (Optional) Group labels used for an additive random effects model.
#' @param rfx_basis_train (Optional) Basis for "random-slope" regression in an additive random effects model.
#' If `rfx_group_ids_train` is provided without a regression basis, an intercept-only random effects model
#' will be estimated.
#' @param X_test (Optional) Test set of covariates used to define "out of sample" evaluation data.
#' May be provided either as a dataframe or a matrix, but the format of `X_test` must be consistent with
#' that of `X_train`.
#' @param Z_test (Optional) Test set of (continuous or binary) treatment assignments.
#' @param propensity_test (Optional) Vector of propensity scores. If not provided, this will be estimated from the data.
#' @param rfx_group_ids_test (Optional) Test set group labels used for an additive random effects model.
#' We do not currently support test set evaluation for group labels that were not in the
#' training set (though we plan to in the near future).
#' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model.
#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
#' @param previous_model_json (Optional) JSON string containing a previous BCF model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: `NULL`.
#' @param previous_model_warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BCF sampler. One-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 1`). Default: `NULL`. If `num_chains` in the `general_params` list is > 1, then each successive chain will be initialized from a different sample, counting backwards from `previous_model_warmstart_sample_num`. That is, if `previous_model_warmstart_sample_num = 10` and `num_chains = 4`, then chain 1 will be initialized from sample 10, chain 2 from sample 9, chain 3 from sample 8, and chain 4 from sample 7. If `previous_model_json` is provided but `previous_model_warmstart_sample_num` is `NULL`, the last sample in the previous model will be used to initialize the first chain, counting backwards as above. If more chains are requested than there are samples available counting back from the requested position, a warning will be raised and all chains will be initialized from the same sample.
#' @param general_params (Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
#'   - `cutpoint_grid_size` Maximum size of the "grid" of potential cutpoints to consider in the GFR algorithm. Default: `100`.
#'   - `standardize` Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: `TRUE`.
#'   - `sample_sigma2_global` Whether or not to update the `sigma^2` global error variance parameter based on `IG(sigma2_global_shape, sigma2_global_scale)`. Default: `TRUE`.
#'   - `sigma2_global_init` Starting value of global error variance parameter. Calibrated internally as `1.0*var((y_train-mean(y_train))/sd(y_train))` if not set.
#'   - `sigma2_global_shape` Shape parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Default: `0`.
#'   - `sigma2_global_scale` Scale parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Default: `0`.
#'   - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to `1/ncol(X_train)`. A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in `X_train`, set `propensity_covariate` to `'none'`, and adjust `keep_vars` accordingly for the `prognostic` or `treatment_effect` forests.
#'   - `propensity_covariate` Whether to include the propensity score as a covariate in either or both of the forests. Enter `"none"` for neither, `"prognostic"` for the prognostic forest, `"treatment_effect"` for the treatment forest, and `"both"` for both forests. If this is not `"none"` and a propensity score is not provided, it will be estimated from (`X_train`, `Z_train`) using `stochtree::bart()`. Default: `"prognostic"`.
#'   - `adaptive_coding` Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via parameters `b_0` and `b_1` that attach to the outcome model `[b_0 (1-Z) + b_1 Z] tau(X)`. This is ignored when Z is not binary. Default: `TRUE`.
#'   - `control_coding_init` Initial value of the "control" group coding parameter. This is ignored when Z is not binary. Default: `-0.5`.
#'   - `treated_coding_init` Initial value of the "treatment" group coding parameter. This is ignored when Z is not binary. Default: `0.5`.
#'   - `rfx_prior_var` Prior on the (diagonals of the) covariance of the additive group-level random regression coefficients. Must be a vector of length `ncol(rfx_basis_train)`. Default: `rep(1, ncol(rfx_basis_train))`.
#'   - `random_seed` Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
#'   - `keep_burnin` Whether or not "burnin" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`.
#'   - `keep_gfr` Whether or not "grow-from-root" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`.
#'   - `keep_every` How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default `1`. Setting `keep_every <- k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples.
#'   - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`. Note that if `num_chains > 1`, the returned model object will contain samples from all chains, stored consecutively. That is, if there are 4 chains with 100 samples each, the first 100 samples will be from chain 1, the next 100 samples will be from chain 2, etc... For more detail on working with multi-chain BCF models, see [the multi chain vignette](https://stochtree.ai/R_docs/pkgdown/articles/MultiChain.html).
#'   - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
#'   - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
#'   - `num_threads` Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads.
#'
#' @param prognostic_forest_params (Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
#'   - `num_trees` Number of trees in the ensemble for the prognostic forest. Default: `250`. Must be a positive integer.
#'   - `alpha` Prior probability of splitting for a tree of depth 0 in the prognostic forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `0.95`.
#'   - `beta` Exponent that decreases split probabilities for nodes of depth > 0 in the prognostic forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `2`.
#'   - `min_samples_leaf` Minimum allowable size of a leaf, in terms of training samples, in the prognostic forest. Default: `5`.
#'   - `max_depth` Maximum depth of any tree in the ensemble in the prognostic forest. Default: `10`. Can be overridden with `-1`, which does not enforce any depth limits on trees.
#'   - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable in the prognostic forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
#'   - `sample_sigma2_leaf` Whether or not to update the leaf scale variance parameter based on `IG(sigma2_leaf_shape, sigma2_leaf_scale)`. Default: `TRUE`.
#'   - `sigma2_leaf_init` Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here.
#'   - `sigma2_leaf_shape` Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Default: `3`.
#'   - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
#'   - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
#'   - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
#'   - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
#'
#' @param treatment_effect_forest_params (Optional) A list of treatment effect forest model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
#'   - `num_trees` Number of trees in the ensemble for the treatment effect forest. Default: `50`. Must be a positive integer.
#'   - `alpha` Prior probability of splitting for a tree of depth 0 in the treatment effect forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `0.25`.
#'   - `beta` Exponent that decreases split probabilities for nodes of depth > 0 in the treatment effect forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `3`.
#'   - `min_samples_leaf` Minimum allowable size of a leaf, in terms of training samples, in the treatment effect forest. Default: `5`.
#'   - `max_depth` Maximum depth of any tree in the ensemble in the treatment effect forest. Default: `5`. Can be overridden with `-1`, which does not enforce any depth limits on trees.
#'   - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable in the treatment effect forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
#'   - `sample_sigma2_leaf` Whether or not to update the leaf scale variance parameter based on `IG(sigma2_leaf_shape, sigma2_leaf_scale)`. Cannot (currently) be set to `TRUE` if `ncol(Z_train) > 1`. Default: `FALSE`.
#'   - `sigma2_leaf_init` Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here.
#'   - `sigma2_leaf_shape` Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Default: `3`.
#'   - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
#'   - `delta_max` Maximum plausible conditional distributional treatment effect (i.e. P(Y(1) = 1 | X) - P(Y(0) = 1 | X)) when the outcome is binary. Only used when the outcome is specified as a probit model in `general_params`. Must be > 0 and < 1. Default: `0.9`. Ignored if `sigma2_leaf_init` is set directly, as this parameter is used to calibrate `sigma2_leaf_init`.
#'   - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
#'   - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
#'   - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
#'
#' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
#'   - `num_trees` Number of trees in the ensemble for the conditional variance model. Default: `0`. Variance is only modeled using a tree / forest if `num_trees > 0`.
#'   - `alpha` Prior probability of splitting for a tree of depth 0 in the variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `0.95`.
#'   - `beta` Exponent that decreases split probabilities for nodes of depth > 0 in the variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `2`.
#'   - `min_samples_leaf` Minimum allowable size of a leaf, in terms of training samples, in the variance model. Default: `5`.
#'   - `max_depth` Maximum depth of any tree in the ensemble in the variance model. Default: `10`. Can be overridden with `-1`, which does not enforce any depth limits on trees.
#'   - `leaf_prior_calibration_param` Hyperparameter used to calibrate the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model. If `var_forest_prior_shape` and `var_forest_prior_scale` are not set below, this calibration parameter is used to set these values to `num_trees / leaf_prior_calibration_param^2 + 0.5` and `num_trees / leaf_prior_calibration_param^2`, respectively. Default: `1.5`.
#'   - `variance_forest_init` Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as `log(0.6*var((y_train-mean(y_train))/sd(y_train)))/num_trees` if not set.
#'   - `var_forest_prior_shape` Shape parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2 + 0.5` if not set.
#'   - `var_forest_prior_scale` Scale parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2` if not set.
#'   - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
#'   - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
#'   - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
#'
#' @param random_effects_params (Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
#'   - `model_spec` Specification of the random effects model. Options are "custom", "intercept_only", and "intercept_plus_treatment". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (`Z_train`) will be dispatched internally at sampling and prediction time. Default: "custom". If either "intercept_only" or "intercept_plus_treatment" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored.
#'   - `working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
#'   - `group_parameter_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
#'   - `working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
#'   - `group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
#'   - `variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
#'   - `variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
#'
#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
#' @export
#'
#' @examples
#' n <- 500
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' mu_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' pi_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
#' )
#' tau_x <- (
#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
#' )
#' Z <- rbinom(n, 1, pi_x)
#' noise_sd <- 1
#' y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' pi_test <- pi_x[test_inds]
#' pi_train <- pi_x[train_inds]
#' Z_test <- Z[test_inds]
#' Z_train <- Z[train_inds]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' mu_test <- mu_x[test_inds]
#' mu_train <- mu_x[train_inds]
#' tau_test <- tau_x[test_inds]
#' tau_train <- tau_x[train_inds]
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                  propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
#'                  propensity_test = pi_test, num_gfr = 10,
#'                  num_burnin = 0, num_mcmc = 10)
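#'
#' # A minimal sketch of warm-starting a continuation sampler from this model's
#' # serialized JSON; assumes `saveBCFModelToJsonString()` returns the string
#' # expected by `previous_model_json`
#' bcf_json <- saveBCFModelToJsonString(bcf_model)
#' bcf_model_restart <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                          propensity_train = pi_train, num_gfr = 0,
#'                          num_burnin = 0, num_mcmc = 10,
#'                          previous_model_json = bcf_json,
#'                          previous_model_warmstart_sample_num = 10)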
bcf <- function(
  X_train,
  Z_train,
  y_train,
  propensity_train = NULL,
  rfx_group_ids_train = NULL,
  rfx_basis_train = NULL,
  X_test = NULL,
  Z_test = NULL,
  propensity_test = NULL,
  rfx_group_ids_test = NULL,
  rfx_basis_test = NULL,
  num_gfr = 5,
  num_burnin = 0,
  num_mcmc = 100,
  previous_model_json = NULL,
  previous_model_warmstart_sample_num = NULL,
  general_params = list(),
  prognostic_forest_params = list(),
  treatment_effect_forest_params = list(),
  variance_forest_params = list(),
  random_effects_params = list()
) {
  # Update general BCF parameters
  general_params_default <- list(
    cutpoint_grid_size = 100,
    standardize = TRUE,
    sample_sigma2_global = TRUE,
    sigma2_global_init = NULL,
    sigma2_global_shape = 0,
    sigma2_global_scale = 0,
    variable_weights = NULL,
    propensity_covariate = "prognostic",
    adaptive_coding = TRUE,
    control_coding_init = -0.5,
    treated_coding_init = 0.5,
    rfx_prior_var = NULL,
    random_seed = -1,
    keep_burnin = FALSE,
    keep_gfr = FALSE,
    keep_every = 1,
    num_chains = 1,
    verbose = FALSE,
    probit_outcome_model = FALSE,
    num_threads = -1
  )
  general_params_updated <- preprocessParams(
    general_params_default,
    general_params
  )
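  # Note: preprocessParams is assumed to overlay the user-supplied list onto
  # the defaults, so users only need to set the parameters they wish to change.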

  # Update mu forest BCF parameters
  prognostic_forest_params_default <- list(
    num_trees = 250,
    alpha = 0.95,
    beta = 2.0,
    min_samples_leaf = 5,
    max_depth = 10,
    sample_sigma2_leaf = TRUE,
    sigma2_leaf_init = NULL,
    sigma2_leaf_shape = 3,
    sigma2_leaf_scale = NULL,
    keep_vars = NULL,
    drop_vars = NULL,
    num_features_subsample = NULL
  )
  prognostic_forest_params_updated <- preprocessParams(
    prognostic_forest_params_default,
    prognostic_forest_params
  )

  # Update tau forest BCF parameters
  treatment_effect_forest_params_default <- list(
    num_trees = 50,
    alpha = 0.25,
    beta = 3.0,
    min_samples_leaf = 5,
    max_depth = 5,
    sample_sigma2_leaf = FALSE,
    sigma2_leaf_init = NULL,
    sigma2_leaf_shape = 3,
    sigma2_leaf_scale = NULL,
    keep_vars = NULL,
    drop_vars = NULL,
    delta_max = 0.9,
    num_features_subsample = NULL
  )
  treatment_effect_forest_params_updated <- preprocessParams(
    treatment_effect_forest_params_default,
    treatment_effect_forest_params
  )

  # Update variance forest BCF parameters
  variance_forest_params_default <- list(
    num_trees = 0,
    alpha = 0.95,
    beta = 2.0,
    min_samples_leaf = 5,
    max_depth = 10,
    leaf_prior_calibration_param = 1.5,
    variance_forest_init = NULL,
    var_forest_prior_shape = NULL,
    var_forest_prior_scale = NULL,
    keep_vars = NULL,
    drop_vars = NULL,
    num_features_subsample = NULL
  )
  variance_forest_params_updated <- preprocessParams(
    variance_forest_params_default,
    variance_forest_params
  )

  # Update random effects parameters
  rfx_params_default <- list(
    model_spec = "custom",
    working_parameter_prior_mean = NULL,
    group_parameter_prior_mean = NULL,
    working_parameter_prior_cov = NULL,
    group_parameter_prior_cov = NULL,
    variance_prior_shape = 1,
    variance_prior_scale = 1
  )
  rfx_params_updated <- preprocessParams(
    rfx_params_default,
    random_effects_params
  )

  ### Unpack all parameter values
  # 1. General parameters
  cutpoint_grid_size <- general_params_updated$cutpoint_grid_size
  standardize <- general_params_updated$standardize
  sample_sigma2_global <- general_params_updated$sample_sigma2_global
  sigma2_init <- general_params_updated$sigma2_global_init
  a_global <- general_params_updated$sigma2_global_shape
  b_global <- general_params_updated$sigma2_global_scale
  variable_weights <- general_params_updated$variable_weights
  propensity_covariate <- general_params_updated$propensity_covariate
  adaptive_coding <- general_params_updated$adaptive_coding
  b_0 <- general_params_updated$control_coding_init
  b_1 <- general_params_updated$treated_coding_init
  rfx_prior_var <- general_params_updated$rfx_prior_var
  random_seed <- general_params_updated$random_seed
  keep_burnin <- general_params_updated$keep_burnin
  keep_gfr <- general_params_updated$keep_gfr
  keep_every <- general_params_updated$keep_every
  num_chains <- general_params_updated$num_chains
  verbose <- general_params_updated$verbose
  probit_outcome_model <- general_params_updated$probit_outcome_model
  num_threads <- general_params_updated$num_threads

  # 2. Mu forest parameters
  num_trees_mu <- prognostic_forest_params_updated$num_trees
  alpha_mu <- prognostic_forest_params_updated$alpha
  beta_mu <- prognostic_forest_params_updated$beta
  min_samples_leaf_mu <- prognostic_forest_params_updated$min_samples_leaf
  max_depth_mu <- prognostic_forest_params_updated$max_depth
  sample_sigma2_leaf_mu <- prognostic_forest_params_updated$sample_sigma2_leaf
  sigma2_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_init
  a_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_shape
  b_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_scale
  keep_vars_mu <- prognostic_forest_params_updated$keep_vars
  drop_vars_mu <- prognostic_forest_params_updated$drop_vars
  num_features_subsample_mu <- prognostic_forest_params_updated$num_features_subsample

  # 3. Tau forest parameters
  num_trees_tau <- treatment_effect_forest_params_updated$num_trees
  alpha_tau <- treatment_effect_forest_params_updated$alpha
  beta_tau <- treatment_effect_forest_params_updated$beta
  min_samples_leaf_tau <- treatment_effect_forest_params_updated$min_samples_leaf
  max_depth_tau <- treatment_effect_forest_params_updated$max_depth
  sample_sigma2_leaf_tau <- treatment_effect_forest_params_updated$sample_sigma2_leaf
  sigma2_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_init
  a_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_shape
  b_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_scale
  keep_vars_tau <- treatment_effect_forest_params_updated$keep_vars
  drop_vars_tau <- treatment_effect_forest_params_updated$drop_vars
  delta_max <- treatment_effect_forest_params_updated$delta_max
  num_features_subsample_tau <- treatment_effect_forest_params_updated$num_features_subsample

  # 4. Variance forest parameters
  num_trees_variance <- variance_forest_params_updated$num_trees
  alpha_variance <- variance_forest_params_updated$alpha
  beta_variance <- variance_forest_params_updated$beta
  min_samples_leaf_variance <- variance_forest_params_updated$min_samples_leaf
  max_depth_variance <- variance_forest_params_updated$max_depth
  a_0 <- variance_forest_params_updated$leaf_prior_calibration_param
  variance_forest_init <- variance_forest_params_updated$variance_forest_init
  a_forest <- variance_forest_params_updated$var_forest_prior_shape
  b_forest <- variance_forest_params_updated$var_forest_prior_scale
  keep_vars_variance <- variance_forest_params_updated$keep_vars
  drop_vars_variance <- variance_forest_params_updated$drop_vars
  num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample

  # 5. Random effects parameters
  rfx_model_spec <- rfx_params_updated$model_spec
  rfx_working_parameter_prior_mean <- rfx_params_updated$working_parameter_prior_mean
  rfx_group_parameter_prior_mean <- rfx_params_updated$group_parameter_prior_mean
  rfx_working_parameter_prior_cov <- rfx_params_updated$working_parameter_prior_cov
  rfx_group_parameter_prior_cov <- rfx_params_updated$group_parameter_prior_cov
  rfx_variance_prior_shape <- rfx_params_updated$variance_prior_shape
  rfx_variance_prior_scale <- rfx_params_updated$variance_prior_scale

  # Handle random effects specification
  if (!is.character(rfx_model_spec)) {
    stop("rfx_model_spec must be a string or character vector")
  }
  if (
    !(rfx_model_spec %in%
      c("custom", "intercept_only", "intercept_plus_treatment"))
  ) {
    stop(
      "rfx_model_spec must either be 'custom', 'intercept_only', or 'intercept_plus_treatment'"
    )
  }

  # Set a function-scoped RNG if user provided a random seed
  custom_rng <- random_seed >= 0
  if (custom_rng) {
    # Store original global environment RNG state
    original_global_seed <- .Random.seed
    # Set new seed and store associated RNG state
    set.seed(random_seed)
    function_scoped_seed <- .Random.seed
  }

  # Check if there are enough GFR samples to seed num_chains samplers
  if (num_gfr > 0) {
    if (num_chains > num_gfr) {
      stop(
        "num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains"
      )
    }
  }

  # Override keep_gfr if there are no MCMC samples
  if (num_mcmc == 0) {
    keep_gfr <- TRUE
  }

  # Check if previous model JSON is provided and parse it if so
  has_prev_model <- !is.null(previous_model_json)
  has_prev_model_index <- !is.null(previous_model_warmstart_sample_num)
  if (has_prev_model) {
    previous_bcf_model <- createBCFModelFromJsonString(previous_model_json)
    prev_num_samples <- previous_bcf_model$model_params$num_samples
    if (!has_prev_model_index) {
      previous_model_warmstart_sample_num <- prev_num_samples
      warning(
        "`previous_model_warmstart_sample_num` was not provided alongside `previous_model_json`, so it will be set to the number of samples available in `previous_model_json`"
      )
    } else {
      if (previous_model_warmstart_sample_num < 1) {
        stop(
          "`previous_model_warmstart_sample_num` must be a positive integer"
        )
      }
      if (previous_model_warmstart_sample_num > prev_num_samples) {
        stop(
          "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`"
        )
      }
    }
    previous_model_decrement <- TRUE
    if (num_chains > previous_model_warmstart_sample_num) {
      warning(
        "The number of chains being sampled exceeds the number of previous model samples available from the requested position in `previous_model_json`. All chains will be initialized from the same sample."
      )
      previous_model_decrement <- FALSE
    }
    previous_y_bar <- previous_bcf_model$model_params$outcome_mean
    previous_y_scale <- previous_bcf_model$model_params$outcome_scale
    previous_forest_samples_mu <- previous_bcf_model$forests_mu
    previous_forest_samples_tau <- previous_bcf_model$forests_tau
    if (previous_bcf_model$model_params$include_variance_forest) {
      previous_forest_samples_variance <- previous_bcf_model$forests_variance
    } else {
      previous_forest_samples_variance <- NULL
    }
    if (previous_bcf_model$model_params$sample_sigma2_global) {
      previous_global_var_samples <- previous_bcf_model$sigma2_global_samples /
        (previous_y_scale * previous_y_scale)
    } else {
      previous_global_var_samples <- NULL
    }
    if (previous_bcf_model$model_params$sample_sigma2_leaf_mu) {
      previous_leaf_var_mu_samples <- previous_bcf_model$sigma2_leaf_mu_samples
    } else {
      previous_leaf_var_mu_samples <- NULL
    }
    if (previous_bcf_model$model_params$sample_sigma2_leaf_tau) {
      previous_leaf_var_tau_samples <- previous_bcf_model$sigma2_leaf_tau_samples
    } else {
      previous_leaf_var_tau_samples <- NULL
    }
    if (previous_bcf_model$model_params$has_rfx) {
      previous_rfx_samples <- previous_bcf_model$rfx_samples
    } else {
      previous_rfx_samples <- NULL
    }
    if (previous_bcf_model$model_params$adaptive_coding) {
      previous_b_1_samples <- previous_bcf_model$b_1_samples
      previous_b_0_samples <- previous_bcf_model$b_0_samples
    } else {
      previous_b_1_samples <- NULL
      previous_b_0_samples <- NULL
    }
    previous_model_num_samples <- previous_bcf_model$model_params$num_samples
    if (previous_model_warmstart_sample_num > previous_model_num_samples) {
      stop(
        "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`"
      )
    }
  } else {
    previous_y_bar <- NULL
    previous_y_scale <- NULL
    previous_global_var_samples <- NULL
    previous_leaf_var_mu_samples <- NULL
    previous_leaf_var_tau_samples <- NULL
    previous_rfx_samples <- NULL
    previous_forest_samples_mu <- NULL
    previous_forest_samples_tau <- NULL
    previous_forest_samples_variance <- NULL
    previous_b_1_samples <- NULL
    previous_b_0_samples <- NULL
  }

  # Determine whether conditional variance will be modeled
  include_variance_forest <- num_trees_variance > 0

  # Set the variance forest priors if not set
  if (include_variance_forest) {
    if (is.null(a_forest)) {
      a_forest <- num_trees_variance / (a_0^2) + 0.5
    }
    if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2)
  } else {
    a_forest <- 1.
    b_forest <- 1.
  }
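  # With the default calibration parameter (a_0 = 1.5), the branch above gives
  # a_forest = num_trees_variance / 2.25 + 0.5 and b_forest = num_trees_variance / 2.25;
  # e.g. num_trees_variance = 50 yields roughly an IG(22.7, 22.2) prior.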

  # Variable weight preprocessing (and initialization if necessary)
  if (is.null(variable_weights)) {
    variable_weights <- rep(1 / ncol(X_train), ncol(X_train))
  }
  if (any(variable_weights < 0)) {
    stop("variable_weights cannot have any negative weights")
  }

  # Check covariates are matrix or dataframe
  if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) {
    stop("X_train must be a matrix or dataframe")
  }
  if (!is.null(X_test)) {
    if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) {
      stop("X_test must be a matrix or dataframe")
    }
  }
  num_cov_orig <- ncol(X_train)

  # Raise a warning if the data have ties and only GFR is being run
  if ((num_gfr > 0) && (num_mcmc == 0) && (num_burnin == 0)) {
    num_values <- nrow(X_train)
    max_grid_size <- ifelse(
      num_values > cutpoint_grid_size,
      floor(num_values / cutpoint_grid_size),
      1
    )
    x_is_df <- is.data.frame(X_train)
    covs_warning_1 <- NULL
    covs_warning_2 <- NULL
    covs_warning_3 <- NULL
    covs_warning_4 <- NULL
    for (i in 1:num_cov_orig) {
      # Skip check for variables that are treated as categorical
      x_numeric <- TRUE
      if (x_is_df) {
        if (is.factor(X_train[, i])) {
          x_numeric <- FALSE
        }
      }

      if (x_numeric) {
        # Determine the number of unique values
        num_unique_values <- length(unique(X_train[, i]))

        # Determine a "name" for the covariate
        cov_name <- ifelse(
          is.null(colnames(X_train)),
          paste0("X", i),
          colnames(X_train)[i]
        )

        # Check for a small relative number of unique values
        unique_full_ratio <- num_unique_values / num_values
        if (unique_full_ratio < 0.2) {
          covs_warning_1 <- c(covs_warning_1, cov_name)
        }

        # Check for a small absolute number of unique values
        if (num_values > 100) {
          if (num_unique_values < 20) {
            covs_warning_2 <- c(covs_warning_2, cov_name)
          }
        }

        # Check for a large number of duplicates of any individual value
        x_j_hist <- table(X_train[, i])
        if (any(x_j_hist > 2 * max_grid_size)) {
          covs_warning_3 <- c(covs_warning_3, cov_name)
        }

        # Check for binary variables
        if (num_unique_values == 2) {
          covs_warning_4 <- c(covs_warning_4, cov_name)
        }
      }
    }

    if (!is.null(covs_warning_1)) {
      warning(
        paste0(
          "Covariate(s) ",
          paste(covs_warning_1, collapse = ", "),
          " have a ratio of unique to overall observations of less than 0.2. ",
          "This might present some issues with the grow-from-root (GFR) algorithm. ",
          "Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
        )
      )
    }

    if (!is.null(covs_warning_2)) {
      warning(
        paste0(
          "Covariate(s) ",
          paste(covs_warning_2, collapse = ", "),
          " have fewer than 20 unique values. ",
          "This might present some issues with the grow-from-root (GFR) algorithm. ",
          "Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
        )
      )
    }

    if (!is.null(covs_warning_3)) {
      warning(
        paste0(
          "Covariate(s) ",
          paste(covs_warning_3, collapse = ", "),
          " have individual values that are repeated more than ",
          2 * max_grid_size,
          " times. ",
          "This might present some issues with the grow-from-root (GFR) algorithm. ",
          "Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
        )
      )
    }

    if (!is.null(covs_warning_4)) {
      warning(
        paste0(
          "Covariate(s) ",
          paste(covs_warning_4, collapse = ", "),
          " appear to be binary but are currently treated by stochtree as continuous. ",
          "This might present some issues with the grow-from-root (GFR) algorithm. ",
          "Consider converting binary variables to ordered factors (i.e. `factor(..., ordered = TRUE)`)."
        )
      )
    }
  }

  # Check delta_max is valid
  if ((delta_max <= 0) || (delta_max >= 1)) {
    stop("delta_max must be > 0 and < 1")
  }

  # Standardize the keep variable lists to numeric indices
  if (!is.null(keep_vars_mu)) {
    if (is.character(keep_vars_mu)) {
      if (!all(keep_vars_mu %in% names(X_train))) {
        stop(
          "keep_vars_mu includes some variable names that are not in X_train"
        )
      }
      variable_subset_mu <- unname(which(
        names(X_train) %in% keep_vars_mu
      ))
    } else {
      if (any(keep_vars_mu > ncol(X_train))) {
        stop(
          "keep_vars_mu includes some variable indices that exceed the number of columns in X_train"
        )
      }
      if (any(keep_vars_mu < 0)) {
        stop("keep_vars_mu includes some negative variable indices")
      }
      variable_subset_mu <- keep_vars_mu
    }
  } else if ((is.null(keep_vars_mu)) && (!is.null(drop_vars_mu))) {
    if (is.character(drop_vars_mu)) {
      if (!all(drop_vars_mu %in% names(X_train))) {
        stop(
          "drop_vars_mu includes some variable names that are not in X_train"
        )
      }
      variable_subset_mu <- unname(which(
        !(names(X_train) %in% drop_vars_mu)
      ))
    } else {
      if (any(drop_vars_mu > ncol(X_train))) {
        stop(
          "drop_vars_mu includes some variable indices that exceed the number of columns in X_train"
        )
      }
      if (any(drop_vars_mu < 0)) {
        stop("drop_vars_mu includes some negative variable indices")
      }
      variable_subset_mu <- (1:ncol(X_train))[
        !(1:ncol(X_train) %in% drop_vars_mu)
      ]
    }
  } else {
    variable_subset_mu <- 1:ncol(X_train)
  }
  if (!is.null(keep_vars_tau)) {
    if (is.character(keep_vars_tau)) {
      if (!all(keep_vars_tau %in% names(X_train))) {
        stop(
          "keep_vars_tau includes some variable names that are not in X_train"
        )
      }
      variable_subset_tau <- unname(which(
        names(X_train) %in% keep_vars_tau
      ))
    } else {
      if (any(keep_vars_tau > ncol(X_train))) {
        stop(
          "keep_vars_tau includes some variable indices that exceed the number of columns in X_train"
        )
      }
      if (any(keep_vars_tau < 0)) {
        stop("keep_vars_tau includes some negative variable indices")
      }
      variable_subset_tau <- keep_vars_tau
    }
  } else if ((is.null(keep_vars_tau)) && (!is.null(drop_vars_tau))) {
    if (is.character(drop_vars_tau)) {
      if (!all(drop_vars_tau %in% names(X_train))) {
        stop(
          "drop_vars_tau includes some variable names that are not in X_train"
        )
      }
      variable_subset_tau <- unname(which(
        !(names(X_train) %in% drop_vars_tau)
      ))
    } else {
      if (any(drop_vars_tau > ncol(X_train))) {
        stop(
          "drop_vars_tau includes some variable indices that exceed the number of columns in X_train"
        )
      }
      if (any(drop_vars_tau < 0)) {
        stop("drop_vars_tau includes some negative variable indices")
      }
      variable_subset_tau <- (1:ncol(X_train))[
        !(1:ncol(X_train) %in% drop_vars_tau)
      ]
    }
  } else {
    variable_subset_tau <- 1:ncol(X_train)
  }
  if (!is.null(keep_vars_variance)) {
    if (is.character(keep_vars_variance)) {
      if (!all(keep_vars_variance %in% names(X_train))) {
        stop(
          "keep_vars_variance includes some variable names that are not in X_train"
        )
      }
      variable_subset_variance <- unname(which(
        names(X_train) %in% keep_vars_variance
      ))
    } else {
      if (any(keep_vars_variance > ncol(X_train))) {
        stop(
          "keep_vars_variance includes some variable indices that exceed the number of columns in X_train"
        )
      }
      if (any(keep_vars_variance < 0)) {
        stop(
          "keep_vars_variance includes some negative variable indices"
        )
      }
      variable_subset_variance <- keep_vars_variance
    }
  } else if ((is.null(keep_vars_variance)) && (!is.null(drop_vars_variance))) {
    if (is.character(drop_vars_variance)) {
      if (!all(drop_vars_variance %in% names(X_train))) {
        stop(
          "drop_vars_variance includes some variable names that are not in X_train"
        )
      }
      variable_subset_variance <- unname(which(
        !(names(X_train) %in% drop_vars_variance)
      ))
    } else {
      if (any(drop_vars_variance > ncol(X_train))) {
        stop(
          "drop_vars_variance includes some variable indices that exceed the number of columns in X_train"
        )
      }
      if (any(drop_vars_variance < 0)) {
        stop(
          "drop_vars_variance includes some negative variable indices"
        )
      }
      variable_subset_variance <- (1:ncol(X_train))[
        !(1:ncol(X_train) %in% drop_vars_variance)
      ]
    }
  } else {
    variable_subset_variance <- 1:ncol(X_train)
  }

  # Preprocess covariates
  if (ncol(X_train) != length(variable_weights)) {
    stop("length(variable_weights) must equal ncol(X_train)")
  }
  train_cov_preprocess_list <- preprocessTrainData(X_train)
  X_train_metadata <- train_cov_preprocess_list$metadata
  X_train_raw <- X_train
  X_train <- train_cov_preprocess_list$data
  original_var_indices <- X_train_metadata$original_var_indices
  feature_types <- X_train_metadata$feature_types
  X_test_raw <- X_test
  if (!is.null(X_test)) {
    X_test <- preprocessPredictionData(X_test, X_train_metadata)
  }

  # Convert all input data to matrices if not already converted
  Z_col <- ifelse(is.null(dim(Z_train)), 1, ncol(Z_train))
  Z_train <- matrix(as.numeric(Z_train), ncol = Z_col)
  if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) {
    propensity_train <- as.matrix(propensity_train)
  }
  if (!is.null(Z_test)) {
    Z_test <- matrix(as.numeric(Z_test), ncol = Z_col)
  }
  if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) {
    propensity_test <- as.matrix(propensity_test)
  }
  if ((is.null(dim(rfx_basis_train))) && (!is.null(rfx_basis_train))) {
    rfx_basis_train <- as.matrix(rfx_basis_train)
  }
  if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) {
    rfx_basis_test <- as.matrix(rfx_basis_test)
  }

  # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
  has_rfx <- FALSE
  has_rfx_test <- FALSE
  if (!is.null(rfx_group_ids_train)) {
    group_ids_factor <- factor(rfx_group_ids_train)
    rfx_group_ids_train <- as.integer(group_ids_factor)
    has_rfx <- TRUE
    if (!is.null(rfx_group_ids_test)) {
      group_ids_factor_test <- factor(
        rfx_group_ids_test,
        levels = levels(group_ids_factor)
      )
      if (sum(is.na(group_ids_factor_test)) > 0) {
        stop(
          "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train"
        )
      }
      rfx_group_ids_test <- as.integer(group_ids_factor_test)
      has_rfx_test <- TRUE
    }
  }

  # Check that outcome and treatment are numeric
  if (!is.numeric(y_train)) {
    stop("y_train must be numeric")
  }
  if (!is.numeric(Z_train)) {
    stop("Z_train must be numeric")
  }
  if (!is.null(Z_test)) {
    if (!is.numeric(Z_test)) stop("Z_test must be numeric")
  }

  # Data consistency checks
  if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) {
    stop("X_train and X_test must have the same number of columns")
  }
  if ((!is.null(Z_test)) && (ncol(Z_test) != ncol(Z_train))) {
    stop("Z_train and Z_test must have the same number of columns")
  }
  if ((!is.null(Z_train)) && (nrow(Z_train) != nrow(X_train))) {
    stop("Z_train and X_train must have the same number of rows")
  }
  if (
    (!is.null(propensity_train)) &&
      (nrow(propensity_train) != nrow(X_train))
  ) {
    stop("propensity_train and X_train must have the same number of rows")
  }
  if ((!is.null(Z_test)) && (nrow(Z_test) != nrow(X_test))) {
    stop("Z_test and X_test must have the same number of rows")
  }
  if ((!is.null(propensity_test)) && (nrow(propensity_test) != nrow(X_test))) {
    stop("propensity_test and X_test must have the same number of rows")
  }
  if (nrow(X_train) != length(y_train)) {
    stop("X_train and y_train must have the same number of observations")
  }
  if (
    (!is.null(rfx_basis_test)) &&
      (ncol(rfx_basis_test) != ncol(rfx_basis_train))
  ) {
    stop(
      "rfx_basis_train and rfx_basis_test must have the same number of columns"
    )
  }
  if (!is.null(rfx_group_ids_train)) {
    if (!is.null(rfx_group_ids_test)) {
      if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) {
        stop(
          "rfx_basis_train is provided but rfx_basis_test is not provided"
        )
      }
    }
  }

  # # Stop if multivariate treatment is provided
  # if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported")

  # Handle multivariate treatment
  has_multivariate_treatment <- ncol(Z_train) > 1
  if (has_multivariate_treatment) {
    # Disable adaptive coding, internal propensity model, and
    # leaf scale sampling if treatment is multivariate
    if (adaptive_coding) {
      warning(
        "Adaptive coding is incompatible with multivariate treatment and will be ignored"
      )
      adaptive_coding <- FALSE
    }
    if (is.null(propensity_train)) {
      if (propensity_covariate != "none") {
        warning(
          "No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'"
        )
        propensity_covariate <- "none"
      }
    }
    if (sample_sigma2_leaf_tau) {
      warning(
        "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model."
      )
      sample_sigma2_leaf_tau <- FALSE
    }
  }

  # Update variable weights
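  # When a categorical covariate was expanded (e.g. one-hot encoded) during
  # preprocessing, original_var_indices maps each processed column back to its
  # original covariate; dividing by the number of processed columns per original
  # covariate splits that covariate's weight evenly across its encodings.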
  variable_weights_adj <- 1 /
    sapply(original_var_indices, function(x) sum(original_var_indices == x))
  variable_weights <- variable_weights[original_var_indices] *
    variable_weights_adj

  # Create mu and tau (and variance) specific variable weights with weights zeroed out for excluded variables
  variable_weights_variance <- variable_weights_tau <- variable_weights_mu <- variable_weights
  variable_weights_mu[!(original_var_indices %in% variable_subset_mu)] <- 0
  variable_weights_tau[!(original_var_indices %in% variable_subset_tau)] <- 0
  if (include_variance_forest) {
    variable_weights_variance[
      !(original_var_indices %in% variable_subset_variance)
    ] <- 0
  }

  # Handle the rfx basis matrices
  has_basis_rfx <- FALSE
  num_basis_rfx <- 0
  if (has_rfx) {
    if (rfx_model_spec == "custom") {
      if (is.null(rfx_basis_train)) {
        stop(
          "A user-provided basis (`rfx_basis_train`) must be provided when the random effects model spec is 'custom'"
        )
      }
      has_basis_rfx <- TRUE
      num_basis_rfx <- ncol(rfx_basis_train)
    } else if (rfx_model_spec == "intercept_only") {
      rfx_basis_train <- matrix(
        rep(1, nrow(X_train)),
        nrow = nrow(X_train),
        ncol = 1
      )
      has_basis_rfx <- TRUE
      num_basis_rfx <- 1
    } else if (rfx_model_spec == "intercept_plus_treatment") {
      rfx_basis_train <- cbind(
        rep(1, nrow(X_train)),
        Z_train
      )
      has_basis_rfx <- TRUE
      num_basis_rfx <- 1 + ncol(Z_train)
    }
    num_rfx_groups <- length(unique(rfx_group_ids_train))
    num_rfx_components <- ncol(rfx_basis_train)
    if (num_rfx_groups == 1) {
      warning(
        "Only one group was provided for random effect sampling, so the random effects model is likely overkill"
      )
    }
  }
  if (has_rfx_test) {
    if (rfx_model_spec == "custom") {
      if (is.null(rfx_basis_test)) {
        stop(
          "A user-provided basis (`rfx_basis_test`) must be provided when the random effects model spec is 'custom'"
        )
      }
    } else if (rfx_model_spec == "intercept_only") {
      rfx_basis_test <- matrix(
        rep(1, nrow(X_test)),
        nrow = nrow(X_test),
        ncol = 1
      )
    } else if (rfx_model_spec == "intercept_plus_treatment") {
      rfx_basis_test <- cbind(
        rep(1, nrow(X_test)),
        Z_test
      )
    }
  }

  # Random effects covariance prior
  if (has_rfx) {
    if (is.null(rfx_prior_var)) {
      rfx_prior_var <- rep(1, ncol(rfx_basis_train))
    } else {
      if ((!is.integer(rfx_prior_var)) && (!is.numeric(rfx_prior_var))) {
        stop("rfx_prior_var must be a numeric vector")
      }
      if (length(rfx_prior_var) != ncol(rfx_basis_train)) {
        stop("length(rfx_prior_var) must equal ncol(rfx_basis_train)")
      }
    }
  }

  # Check that number of samples are all nonnegative
  stopifnot(num_gfr >= 0)
  stopifnot(num_burnin >= 0)
  stopifnot(num_mcmc >= 0)

  # Determine whether a test set is provided
  has_test <- !is.null(X_test)

  # Convert y_train to numeric vector if not already converted
  if (!is.null(dim(y_train))) {
    y_train <- as.numeric(y_train)
  }

  # Check whether treatment is binary (specifically 0-1 binary)
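  # Note: unique() flattens the matrix, so a multivariate Z_train is only
  # flagged binary if every entry across all columns is exactly 0 or 1.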
  binary_treatment <- length(unique(Z_train)) == 2
  if (binary_treatment) {
    unique_treatments <- sort(unique(Z_train))
    if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE
  }

  # Adaptive coding will be ignored for continuous / ordered categorical treatments
  if ((!binary_treatment) && (adaptive_coding)) {
    adaptive_coding <- FALSE
  }

  # Check if propensity_covariate is one of the required inputs
  if (
    !(propensity_covariate %in%
      c("prognostic", "treatment_effect", "both", "none"))
  ) {
    stop(
      "propensity_covariate must equal one of 'none', 'prognostic', 'treatment_effect', or 'both'"
    )
  }

  # Estimate if pre-estimated propensity score is not provided
  internal_propensity_model <- FALSE
  if ((is.null(propensity_train)) && (propensity_covariate != "none")) {
    internal_propensity_model <- TRUE
    # Estimate with a short BART run, averaging predictions over retained samples
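    # Note: this uses stochtree::bart() with default settings on the treatment;
    # for a binary Z_train, the sample-averaged prediction approximates
    # E[Z | X], i.e. the propensity score.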
    num_gfr_propensity <- 10
    num_burnin_propensity <- 0
    num_mcmc_propensity <- 10
    bart_model_propensity <- bart(
      X_train = X_train,
      y_train = as.numeric(Z_train),
      X_test = X_test_raw,
      num_gfr = num_gfr_propensity,
      num_burnin = num_burnin_propensity,
      num_mcmc = num_mcmc_propensity
    )
    propensity_train <- rowMeans(bart_model_propensity$y_hat_train)
    if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) {
      propensity_train <- as.matrix(propensity_train)
    }
    if (has_test) {
      propensity_test <- rowMeans(bart_model_propensity$y_hat_test)
      if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) {
        propensity_test <- as.matrix(propensity_test)
      }
    }
  }

  if (has_test) {
    if (is.null(propensity_test)) {
      stop(
        "Propensity score must be provided for the test set if provided for the training set"
      )
    }
  }

  # Update feature_types and covariates
  feature_types <- as.integer(feature_types)
  if (propensity_covariate != "none") {
    feature_types <- as.integer(c(
      feature_types,
      rep(0, ncol(propensity_train))
    ))
    X_train <- cbind(X_train, propensity_train)
    if (propensity_covariate == "prognostic") {
      variable_weights_mu <- c(
        variable_weights_mu,
        rep(1. / num_cov_orig, ncol(propensity_train))
      )
      variable_weights_tau <- c(
        variable_weights_tau,
        rep(0, ncol(propensity_train))
      )
      if (include_variance_forest) {
        variable_weights_variance <- c(
          variable_weights_variance,
          rep(0, ncol(propensity_train))
        )
      }
    } else if (propensity_covariate == "treatment_effect") {
      variable_weights_mu <- c(
        variable_weights_mu,
        rep(0, ncol(propensity_train))
      )
      variable_weights_tau <- c(
        variable_weights_tau,
        rep(1. / num_cov_orig, ncol(propensity_train))
      )
      if (include_variance_forest) {
        variable_weights_variance <- c(
          variable_weights_variance,
          rep(0, ncol(propensity_train))
        )
      }
    } else if (propensity_covariate == "both") {
      variable_weights_mu <- c(
        variable_weights_mu,
        rep(1. / num_cov_orig, ncol(propensity_train))
      )
      variable_weights_tau <- c(
        variable_weights_tau,
        rep(1. / num_cov_orig, ncol(propensity_train))
      )
      if (include_variance_forest) {
        variable_weights_variance <- c(
          variable_weights_variance,
          rep(0, ncol(propensity_train))
        )
      }
    }
    if (has_test) X_test <- cbind(X_test, propensity_test)
  }

  # Renormalize variable weights
  variable_weights_mu <- variable_weights_mu / sum(variable_weights_mu)
  variable_weights_tau <- variable_weights_tau / sum(variable_weights_tau)
  if (include_variance_forest) {
    variable_weights_variance <- variable_weights_variance /
      sum(variable_weights_variance)
  }

  # Set num_features_subsample to default, ncol(X_train), if not already set
  if (is.null(num_features_subsample_mu)) {
    num_features_subsample_mu <- ncol(X_train)
  }
  if (is.null(num_features_subsample_tau)) {
    num_features_subsample_tau <- ncol(X_train)
  }
  if (is.null(num_features_subsample_variance)) {
    num_features_subsample_variance <- ncol(X_train)
  }

  # Preliminary runtime checks for probit link
  if (probit_outcome_model) {
    if (!(length(unique(y_train)) == 2)) {
      stop(
        "You specified a probit outcome model, but supplied an outcome with more than 2 unique values"
      )
    }
    unique_outcomes <- sort(unique(y_train))
    if (!(all(unique_outcomes == c(0, 1)))) {
      stop(
        "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1"
      )
    }
    if (include_variance_forest) {
      stop("We do not support heteroskedasticity with a probit link")
    }
    if (sample_sigma2_global) {
      warning(
        "Global error variance will not be sampled with a probit link as it is fixed at 1"
      )
      sample_sigma2_global <- FALSE
    }
  }

  # Runtime checks for variance forest
  if (include_variance_forest) {
    if (sample_sigma2_global) {
      warning(
        "Global error variance will not be sampled when heteroskedasticity is modeled with a variance forest"
      )
      sample_sigma2_global <- FALSE
    }
  }

  # Handle standardization, prior calibration, and initialization of forest
  # differently for binary and continuous outcomes
  if (probit_outcome_model) {
    # Compute a probit-scale offset and fix scale to 1
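    # (qnorm of the empirical success rate gives the latent-scale intercept
    # implied by the observed base rate, since P(y = 1) = pnorm(offset) when
    # the forests predict 0)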
    y_bar_train <- qnorm(mean(y_train))
    y_std_train <- 1

    # Set a pseudo outcome by subtracting mean(y_train) from y_train
    resid_train <- y_train - mean(y_train)

    # Set initial value for the mu forest
    init_mu <- 0.0

    # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau
    # Set sigma2_init to 1, ignoring any defaults provided
    sigma2_init <- 1.0
    # Skip variance_forest_init, since variance forests are not supported with a probit link
    if (is.null(b_leaf_mu)) {
      b_leaf_mu <- 1 / num_trees_mu
    }
    if (is.null(b_leaf_tau)) {
      b_leaf_tau <- 1 / (2 * num_trees_tau)
    }
    if (is.null(sigma2_leaf_mu)) {
      sigma2_leaf_mu <- 2 / (num_trees_mu)
      current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu)
    } else {
      if (!is.matrix(sigma2_leaf_mu)) {
        current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu)
      } else {
        current_leaf_scale_mu <- sigma2_leaf_mu
      }
    }
    if (is.null(sigma2_leaf_tau)) {
      # Calibrate the prior so that P(abs(tau(X)) < delta_max / dnorm(0)) = p.
      # Use p = 0.6827 (the coverage of +/- 1 standard deviation of a
      # standard normal) as an internal default rather than adding another
      # user-facing "parameter" of the binary outcome BCF prior.
      # Can be overridden by specifying `sigma2_leaf_init` in
      # treatment_effect_forest_params.
      q_quantile <- qnorm((p + 1) / 2)
      sigma2_leaf_tau <- ((delta_max / (q_quantile * dnorm(0)))^2) /
        num_trees_tau
      current_leaf_scale_tau <- as.matrix(diag(
        sigma2_leaf_tau,
        ncol(Z_train)
      ))
    } else {
      if (!is.matrix(sigma2_leaf_tau)) {
        current_leaf_scale_tau <- as.matrix(diag(
          sigma2_leaf_tau,
          ncol(Z_train)
        ))
      } else {
        if (
          (ncol(sigma2_leaf_tau) != ncol(Z_train)) ||
            (nrow(sigma2_leaf_tau) != ncol(Z_train))
        ) {
          stop(
            "sigma2_leaf_init for the tau forest must be a square matrix with dimension equal to the number of columns of Z_train"
          )
        }
        current_leaf_scale_tau <- sigma2_leaf_tau
      }
    }
    current_sigma2 <- sigma2_init
  } else {
    # Only standardize if user requested
    if (standardize) {
      y_bar_train <- mean(y_train)
      y_std_train <- sd(y_train)
    } else {
      y_bar_train <- 0
      y_std_train <- 1
    }

    # Compute standardized outcome
    resid_train <- (y_train - y_bar_train) / y_std_train

    # Set initial value for the mu forest
    init_mu <- mean(resid_train)

    # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau
    if (is.null(sigma2_init)) {
      sigma2_init <- 1.0 * var(resid_train)
    }
    if (is.null(variance_forest_init)) {
      variance_forest_init <- 1.0 * var(resid_train)
    }
    if (is.null(b_leaf_mu)) {
      b_leaf_mu <- var(resid_train) / (num_trees_mu)
    }
    if (is.null(b_leaf_tau)) {
      b_leaf_tau <- var(resid_train) / (2 * num_trees_tau)
    }
    if (is.null(sigma2_leaf_mu)) {
      sigma2_leaf_mu <- 2.0 * var(resid_train) / (num_trees_mu)
      current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu)
    } else {
      if (!is.matrix(sigma2_leaf_mu)) {
        current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu)
      } else {
        current_leaf_scale_mu <- sigma2_leaf_mu
      }
    }
    if (is.null(sigma2_leaf_tau)) {
      sigma2_leaf_tau <- var(resid_train) / (num_trees_tau)
      current_leaf_scale_tau <- as.matrix(diag(
        sigma2_leaf_tau,
        ncol(Z_train)
      ))
    } else {
      if (!is.matrix(sigma2_leaf_tau)) {
        current_leaf_scale_tau <- as.matrix(diag(
          sigma2_leaf_tau,
          ncol(Z_train)
        ))
      } else {
        if (
          (ncol(sigma2_leaf_tau) != ncol(Z_train)) ||
            (nrow(sigma2_leaf_tau) != ncol(Z_train))
        ) {
          stop(
            "sigma2_leaf_init for the tau forest must be a square matrix with dimension equal to the number of columns of Z_train"
          )
        }
        current_leaf_scale_tau <- sigma2_leaf_tau
      }
    }
    current_sigma2 <- sigma2_init
  }

  # Set mu and tau leaf models / dimensions
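  # (the integer codes here follow the core library's leaf model enumeration:
  # 0 = constant Gaussian leaf, 1 = univariate leaf regression,
  # 2 = multivariate leaf regression, 3 = log-linear variance leaf)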
  leaf_model_mu_forest <- 0
  leaf_dimension_mu_forest <- 1
  if (has_multivariate_treatment) {
    leaf_model_tau_forest <- 2
    leaf_dimension_tau_forest <- ncol(Z_train)
  } else {
    leaf_model_tau_forest <- 1
    leaf_dimension_tau_forest <- 1
  }

  # Set variance leaf model type (currently only one option)
  leaf_model_variance_forest <- 3
  leaf_dimension_variance_forest <- 1

  # Random effects prior parameters
  if (has_rfx) {
    # Prior parameters
    if (is.null(rfx_working_parameter_prior_mean)) {
      if (num_rfx_components == 1) {
        alpha_init <- c(0)
      } else if (num_rfx_components > 1) {
        alpha_init <- rep(0, num_rfx_components)
      } else {
        stop("There must be at least 1 random effect component")
      }
    } else {
      alpha_init <- expand_dims_1d(
        rfx_working_parameter_prior_mean,
        num_rfx_components
      )
    }

    if (is.null(rfx_group_parameter_prior_mean)) {
      xi_init <- matrix(
        rep(alpha_init, num_rfx_groups),
        num_rfx_components,
        num_rfx_groups
      )
    } else {
      xi_init <- expand_dims_2d(
        rfx_group_parameter_prior_mean,
        num_rfx_components,
        num_rfx_groups
      )
    }

    if (is.null(rfx_working_parameter_prior_cov)) {
      sigma_alpha_init <- diag(1, num_rfx_components, num_rfx_components)
    } else {
      sigma_alpha_init <- expand_dims_2d_diag(
        rfx_working_parameter_prior_cov,
        num_rfx_components
      )
    }

    if (is.null(rfx_group_parameter_prior_cov)) {
      sigma_xi_init <- diag(1, num_rfx_components, num_rfx_components)
    } else {
      sigma_xi_init <- expand_dims_2d_diag(
        rfx_group_parameter_prior_cov,
        num_rfx_components
      )
    }

    sigma_xi_shape <- rfx_variance_prior_shape
    sigma_xi_scale <- rfx_variance_prior_scale
  }

  # Random effects data structure and storage container
  if (has_rfx) {
    rfx_dataset_train <- createRandomEffectsDataset(
      rfx_group_ids_train,
      rfx_basis_train
    )
    rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train)
    rfx_model <- createRandomEffectsModel(
      num_rfx_components,
      num_rfx_groups
    )
    rfx_model$set_working_parameter(alpha_init)
    rfx_model$set_group_parameters(xi_init)
    rfx_model$set_working_parameter_cov(sigma_alpha_init)
    rfx_model$set_group_parameter_cov(sigma_xi_init)
    rfx_model$set_variance_prior_shape(sigma_xi_shape)
    rfx_model$set_variance_prior_scale(sigma_xi_scale)
    rfx_samples <- createRandomEffectSamples(
      num_rfx_components,
      num_rfx_groups,
      rfx_tracker_train
    )
  }

  # Container of variance parameter samples
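  # To retain num_mcmc draws per chain with thinning interval keep_every,
  # the sampler runs num_mcmc * keep_every MCMC iterations (plus burn-in)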
  num_actual_mcmc_iter <- num_mcmc * keep_every
  num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter
  # Delete GFR samples from these containers after the fact if desired
  # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc
  num_retained_samples <- num_gfr +
    ifelse(keep_burnin, num_burnin, 0) +
    num_mcmc * num_chains
  if (sample_sigma2_global) {
    global_var_samples <- rep(NA, num_retained_samples)
  }
  if (sample_sigma2_leaf_mu) {
    leaf_scale_mu_samples <- rep(NA, num_retained_samples)
  }
  if (sample_sigma2_leaf_tau) {
    leaf_scale_tau_samples <- rep(NA, num_retained_samples)
  }
  muhat_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples)
  if (include_variance_forest) {
    sigma2_x_train_raw <- matrix(
      NA_real_,
      nrow(X_train),
      num_retained_samples
    )
  }
  sample_counter <- 0

  # Prepare adaptive coding structure
  if (
    (!is.numeric(b_0)) ||
      (!is.numeric(b_1)) ||
      (length(b_0) > 1) ||
      (length(b_1) > 1)
  ) {
    stop("b_0 and b_1 must be single numeric values")
  }
  if (adaptive_coding) {
    b_0_samples <- rep(NA, num_retained_samples)
    b_1_samples <- rep(NA, num_retained_samples)
    current_b_0 <- b_0
    current_b_1 <- b_1
    tau_basis_train <- (1 - Z_train) * current_b_0 + Z_train * current_b_1
    if (has_test) {
      tau_basis_test <- (1 - Z_test) * current_b_0 + Z_test * current_b_1
    }
  } else {
    tau_basis_train <- Z_train
    if (has_test) tau_basis_test <- Z_test
  }

  # Data
  forest_dataset_train <- createForestDataset(X_train, tau_basis_train)
  if (has_test) {
    forest_dataset_test <- createForestDataset(X_test, tau_basis_test)
  }
  outcome_train <- createOutcome(resid_train)

  # Random number generator (std::mt19937)
  if (is.null(random_seed)) {
    random_seed <- sample(1:10000, 1, FALSE)
  }
  rng <- createCppRNG(random_seed)

  # Sampling data structures
  global_model_config <- createGlobalModelConfig(
    global_error_variance = current_sigma2
  )
  forest_model_config_mu <- createForestModelConfig(
    feature_types = feature_types,
    num_trees = num_trees_mu,
    num_features = ncol(X_train),
    num_observations = nrow(X_train),
    variable_weights = variable_weights_mu,
    leaf_dimension = leaf_dimension_mu_forest,
    alpha = alpha_mu,
    beta = beta_mu,
    min_samples_leaf = min_samples_leaf_mu,
    max_depth = max_depth_mu,
    leaf_model_type = leaf_model_mu_forest,
    leaf_model_scale = current_leaf_scale_mu,
    cutpoint_grid_size = cutpoint_grid_size,
    num_features_subsample = num_features_subsample_mu
  )
  forest_model_config_tau <- createForestModelConfig(
    feature_types = feature_types,
    num_trees = num_trees_tau,
    num_features = ncol(X_train),
    num_observations = nrow(X_train),
    variable_weights = variable_weights_tau,
    leaf_dimension = leaf_dimension_tau_forest,
    alpha = alpha_tau,
    beta = beta_tau,
    min_samples_leaf = min_samples_leaf_tau,
    max_depth = max_depth_tau,
    leaf_model_type = leaf_model_tau_forest,
    leaf_model_scale = current_leaf_scale_tau,
    cutpoint_grid_size = cutpoint_grid_size,
    num_features_subsample = num_features_subsample_tau
  )
  forest_model_mu <- createForestModel(
    forest_dataset_train,
    forest_model_config_mu,
    global_model_config
  )
  forest_model_tau <- createForestModel(
    forest_dataset_train,
    forest_model_config_tau,
    global_model_config
  )
  if (include_variance_forest) {
    forest_model_config_variance <- createForestModelConfig(
      feature_types = feature_types,
      num_trees = num_trees_variance,
      num_features = ncol(X_train),
      num_observations = nrow(X_train),
      variable_weights = variable_weights_variance,
      leaf_dimension = leaf_dimension_variance_forest,
      alpha = alpha_variance,
      beta = beta_variance,
      min_samples_leaf = min_samples_leaf_variance,
      max_depth = max_depth_variance,
      leaf_model_type = leaf_model_variance_forest,
      cutpoint_grid_size = cutpoint_grid_size,
      num_features_subsample = num_features_subsample_variance
    )
    forest_model_variance <- createForestModel(
      forest_dataset_train,
      forest_model_config_variance,
      global_model_config
    )
  }

  # Container of forest samples
  forest_samples_mu <- createForestSamples(num_trees_mu, 1, TRUE)
  forest_samples_tau <- createForestSamples(
    num_trees_tau,
    ncol(Z_train),
    FALSE
  )
  active_forest_mu <- createForest(num_trees_mu, 1, TRUE)
  active_forest_tau <- createForest(num_trees_tau, ncol(Z_train), FALSE)
  if (include_variance_forest) {
    forest_samples_variance <- createForestSamples(
      num_trees_variance,
      1,
      TRUE,
      TRUE
    )
    active_forest_variance <- createForest(
      num_trees_variance,
      1,
      TRUE,
      TRUE
    )
  }

  # Initialize the leaves of each tree in the prognostic forest
  active_forest_mu$prepare_for_sampler(
    forest_dataset_train,
    outcome_train,
    forest_model_mu,
    leaf_model_mu_forest,
    init_mu
  )
  active_forest_mu$adjust_residual(
    forest_dataset_train,
    outcome_train,
    forest_model_mu,
    FALSE,
    FALSE
  )

  # Initialize the leaves of each tree in the treatment effect forest
  init_tau <- rep(0., ncol(Z_train))
  active_forest_tau$prepare_for_sampler(
    forest_dataset_train,
    outcome_train,
    forest_model_tau,
    leaf_model_tau_forest,
    init_tau
  )
  active_forest_tau$adjust_residual(
    forest_dataset_train,
    outcome_train,
    forest_model_tau,
    TRUE,
    FALSE
  )

  # Initialize the leaves of each tree in the variance forest
  if (include_variance_forest) {
    active_forest_variance$prepare_for_sampler(
      forest_dataset_train,
      outcome_train,
      forest_model_variance,
      leaf_model_variance_forest,
      variance_forest_init
    )
  }

  # Run GFR (warm start) if specified
  if (num_gfr > 0) {
    for (i in 1:num_gfr) {
      # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC
      # keep_sample <- ifelse(keep_gfr, TRUE, FALSE)
      keep_sample <- TRUE
      if (keep_sample) {
        sample_counter <- sample_counter + 1
      }
      # Print progress
      if (verbose) {
        if ((i %% 10 == 0) || (i == num_gfr)) {
          cat(
            "Sampling",
            i,
            "out of",
            num_gfr,
            "XBCF (grow-from-root) draws\n"
          )
        }
      }

      if (probit_outcome_model) {
        # Sample latent probit variable, z | -
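        # Each draw is a truncated normal generated by inverse-CDF sampling:
        # for y = 0, u ~ U(0, pnorm(-mu)) so z = mu + qnorm(u) lies below 0;
        # for y = 1, u ~ U(pnorm(-mu), 1) so z = mu + qnorm(u) lies above 0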
        mu_forest_pred <- active_forest_mu$predict(forest_dataset_train)
        tau_forest_pred <- active_forest_tau$predict(
          forest_dataset_train
        )
        outcome_pred <- mu_forest_pred + tau_forest_pred
        if (has_rfx) {
          rfx_pred <- rfx_model$predict(
            rfx_dataset_train,
            rfx_tracker_train
          )
          outcome_pred <- outcome_pred + rfx_pred
        }
        mu0 <- outcome_pred[y_train == 0]
        mu1 <- outcome_pred[y_train == 1]
        u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
        u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
        resid_train[y_train == 0] <- mu0 + qnorm(u0)
        resid_train[y_train == 1] <- mu1 + qnorm(u1)

        # Update outcome
        outcome_train$update_data(resid_train - outcome_pred)
      }

      # Sample the prognostic forest
      forest_model_mu$sample_one_iteration(
        forest_dataset = forest_dataset_train,
        residual = outcome_train,
        forest_samples = forest_samples_mu,
        active_forest = active_forest_mu,
        rng = rng,
        forest_model_config = forest_model_config_mu,
        global_model_config = global_model_config,
        num_threads = num_threads,
        keep_forest = keep_sample,
        gfr = TRUE
      )

      # Cache train set predictions since they are already computed during sampling
      if (keep_sample) {
        muhat_train_raw[,
          sample_counter
        ] <- forest_model_mu$get_cached_forest_predictions()
      }

      # Sample variance parameters (if requested)
      if (sample_sigma2_global) {
        current_sigma2 <- sampleGlobalErrorVarianceOneIteration(
          outcome_train,
          forest_dataset_train,
          rng,
          a_global,
          b_global
        )
        global_model_config$update_global_error_variance(current_sigma2)
      }
      if (sample_sigma2_leaf_mu) {
        leaf_scale_mu_double <- sampleLeafVarianceOneIteration(
          active_forest_mu,
          rng,
          a_leaf_mu,
          b_leaf_mu
        )
        current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double)
        if (keep_sample) {
          leaf_scale_mu_samples[
            sample_counter
          ] <- leaf_scale_mu_double
        }
        forest_model_config_mu$update_leaf_model_scale(
          current_leaf_scale_mu
        )
      }

      # Sample the treatment forest
      forest_model_tau$sample_one_iteration(
        forest_dataset = forest_dataset_train,
        residual = outcome_train,
        forest_samples = forest_samples_tau,
        active_forest = active_forest_tau,
        rng = rng,
        forest_model_config = forest_model_config_tau,
        global_model_config = global_model_config,
        num_threads = num_threads,
        keep_forest = keep_sample,
        gfr = TRUE
      )

      # Cannot cache train set predictions for tau because the cached predictions in the
      # tracking data structures are pre-multiplied by the basis (treatment)
      # ...

      # Sample coding parameters (if requested)
      if (adaptive_coding) {
        # Estimate mu(X) and tau(X) and compute y - mu(X)
        mu_x_raw_train <- active_forest_mu$predict_raw(
          forest_dataset_train
        )
        tau_x_raw_train <- active_forest_tau$predict_raw(
          forest_dataset_train
        )
        partial_resid_mu_train <- resid_train - mu_x_raw_train
        if (has_rfx) {
          rfx_preds_train <- rfx_model$predict(
            rfx_dataset_train,
            rfx_tracker_train
          )
          partial_resid_mu_train <- partial_resid_mu_train -
            rfx_preds_train
        }

        # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z]
        s_tt0 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 0))
        s_tt1 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 1))
        s_ty0 <- sum(
          tau_x_raw_train * partial_resid_mu_train * (Z_train == 0)
        )
        s_ty1 <- sum(
          tau_x_raw_train * partial_resid_mu_train * (Z_train == 1)
        )

        # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z)
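        # These are conjugate normal updates: under the implicit N(0, 1/2)
        # prior on each coding parameter, the conditional posterior has mean
        # s_ty / (s_tt + 2 * sigma2) and variance sigma2 / (s_tt + 2 * sigma2)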
        current_b_0 <- rnorm(
          1,
          (s_ty0 / (s_tt0 + 2 * current_sigma2)),
          sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2))
        )
        current_b_1 <- rnorm(
          1,
          (s_ty1 / (s_tt1 + 2 * current_sigma2)),
          sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2))
        )

        # Update basis for the leaf regression
        tau_basis_train <- (1 - Z_train) *
          current_b_0 +
          Z_train * current_b_1
        forest_dataset_train$update_basis(tau_basis_train)
        if (keep_sample) {
          b_0_samples[sample_counter] <- current_b_0
          b_1_samples[sample_counter] <- current_b_1
        }
        if (has_test) {
          tau_basis_test <- (1 - Z_test) *
            current_b_0 +
            Z_test * current_b_1
          forest_dataset_test$update_basis(tau_basis_test)
        }

        # Update leaf predictions and residual
        forest_model_tau$propagate_basis_update(
          forest_dataset_train,
          outcome_train,
          active_forest_tau
        )
      }

      # Sample variance parameters (if requested)
      if (include_variance_forest) {
        forest_model_variance$sample_one_iteration(
          forest_dataset = forest_dataset_train,
          residual = outcome_train,
          forest_samples = forest_samples_variance,
          active_forest = active_forest_variance,
          rng = rng,
          forest_model_config = forest_model_config_variance,
          global_model_config = global_model_config,
          num_threads = num_threads,
          keep_forest = keep_sample,
          gfr = TRUE
        )

        # Cache train set predictions since they are already computed during sampling
        if (keep_sample) {
          sigma2_x_train_raw[,
            sample_counter
          ] <- forest_model_variance$get_cached_forest_predictions()
        }
      }
      if (sample_sigma2_global) {
        current_sigma2 <- sampleGlobalErrorVarianceOneIteration(
          outcome_train,
          forest_dataset_train,
          rng,
          a_global,
          b_global
        )
        if (keep_sample) {
          global_var_samples[sample_counter] <- current_sigma2
        }
        global_model_config$update_global_error_variance(current_sigma2)
      }
      if (sample_sigma2_leaf_tau) {
        leaf_scale_tau_double <- sampleLeafVarianceOneIteration(
          active_forest_tau,
          rng,
          a_leaf_tau,
          b_leaf_tau
        )
        current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double)
        if (keep_sample) {
          leaf_scale_tau_samples[
            sample_counter
          ] <- leaf_scale_tau_double
        }
        forest_model_config_tau$update_leaf_model_scale(
          current_leaf_scale_tau
        )
      }

      # Sample random effects parameters (if requested)
      if (has_rfx) {
        rfx_model$sample_random_effect(
          rfx_dataset_train,
          outcome_train,
          rfx_tracker_train,
          rfx_samples,
          keep_sample,
          current_sigma2,
          rng
        )
      }
    }
  }

  # Run MCMC
  if (num_burnin + num_mcmc > 0) {
    for (chain_num in 1:num_chains) {
      if (num_gfr > 0) {
        # Reset state of active_forest and forest_model based on a previous GFR sample
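        # Chain k is initialized from GFR draw (num_gfr - k), zero-indexed,
        # so chain 1 resumes from the final grow-from-root sample and later
        # chains count backwards through earlier GFR draws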
        forest_ind <- num_gfr - chain_num
        resetActiveForest(
          active_forest_mu,
          forest_samples_mu,
          forest_ind
        )
        resetForestModel(
          forest_model_mu,
          active_forest_mu,
          forest_dataset_train,
          outcome_train,
          TRUE
        )
        resetActiveForest(
          active_forest_tau,
          forest_samples_tau,
          forest_ind
        )
        resetForestModel(
          forest_model_tau,
          active_forest_tau,
          forest_dataset_train,
          outcome_train,
          TRUE
        )
        if (sample_sigma2_leaf_mu) {
          leaf_scale_mu_double <- leaf_scale_mu_samples[
            forest_ind + 1
          ]
          current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double)
          forest_model_config_mu$update_leaf_model_scale(
            current_leaf_scale_mu
          )
        }
        if (sample_sigma2_leaf_tau) {
          leaf_scale_tau_double <- leaf_scale_tau_samples[
            forest_ind + 1
          ]
          current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double)
          forest_model_config_tau$update_leaf_model_scale(
            current_leaf_scale_tau
          )
        }
        if (include_variance_forest) {
          resetActiveForest(
            active_forest_variance,
            forest_samples_variance,
            forest_ind
          )
          resetForestModel(
            forest_model_variance,
            active_forest_variance,
            forest_dataset_train,
            outcome_train,
            FALSE
          )
        }
        if (has_rfx) {
          resetRandomEffectsModel(
            rfx_model,
            rfx_samples,
            forest_ind,
            sigma_alpha_init
          )
          resetRandomEffectsTracker(
            rfx_tracker_train,
            rfx_model,
            rfx_dataset_train,
            outcome_train,
            rfx_samples
          )
        }
        if (adaptive_coding) {
          current_b_1 <- b_1_samples[forest_ind + 1]
          current_b_0 <- b_0_samples[forest_ind + 1]
          tau_basis_train <- (1 - Z_train) *
            current_b_0 +
            Z_train * current_b_1
          forest_dataset_train$update_basis(tau_basis_train)
          if (has_test) {
            tau_basis_test <- (1 - Z_test) *
              current_b_0 +
              Z_test * current_b_1
            forest_dataset_test$update_basis(tau_basis_test)
          }
          forest_model_tau$propagate_basis_update(
            forest_dataset_train,
            outcome_train,
            active_forest_tau
          )
        }
        if (sample_sigma2_global) {
          current_sigma2 <- global_var_samples[forest_ind + 1]
          global_model_config$update_global_error_variance(
            current_sigma2
          )
        }
      } else if (has_prev_model) {
        warmstart_index <- ifelse(
          previous_model_decrement,
          previous_model_warmstart_sample_num - chain_num + 1,
          previous_model_warmstart_sample_num
        )
        resetActiveForest(
          active_forest_mu,
          previous_forest_samples_mu,
          warmstart_index - 1
        )
        resetForestModel(
          forest_model_mu,
          active_forest_mu,
          forest_dataset_train,
          outcome_train,
          TRUE
        )
        resetActiveForest(
          active_forest_tau,
          previous_forest_samples_tau,
          warmstart_index - 1
        )
        resetForestModel(
          forest_model_tau,
          active_forest_tau,
          forest_dataset_train,
          outcome_train,
          TRUE
        )
        if (include_variance_forest) {
          resetActiveForest(
            active_forest_variance,
            previous_forest_samples_variance,
            warmstart_index - 1
          )
          resetForestModel(
            forest_model_variance,
            active_forest_variance,
            forest_dataset_train,
            outcome_train,
            FALSE
          )
        }
        if (
          sample_sigma2_leaf_mu &&
            (!is.null(previous_leaf_var_mu_samples))
        ) {
          leaf_scale_mu_double <- previous_leaf_var_mu_samples[
            warmstart_index
          ]
          current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double)
          forest_model_config_mu$update_leaf_model_scale(
            current_leaf_scale_mu
          )
        }
        if (
          sample_sigma2_leaf_tau &&
            (!is.null(previous_leaf_var_tau_samples))
        ) {
          leaf_scale_tau_double <- previous_leaf_var_tau_samples[
            warmstart_index
          ]
          current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double)
          forest_model_config_tau$update_leaf_model_scale(
            current_leaf_scale_tau
          )
        }
        if (adaptive_coding) {
          if (!is.null(previous_b_1_samples)) {
            current_b_1 <- previous_b_1_samples[
              warmstart_index
            ]
          }
          if (!is.null(previous_b_0_samples)) {
            current_b_0 <- previous_b_0_samples[
              warmstart_index
            ]
          }
          tau_basis_train <- (1 - Z_train) *
            current_b_0 +
            Z_train * current_b_1
          forest_dataset_train$update_basis(tau_basis_train)
          if (has_test) {
            tau_basis_test <- (1 - Z_test) *
              current_b_0 +
              Z_test * current_b_1
            forest_dataset_test$update_basis(tau_basis_test)
          }
          forest_model_tau$propagate_basis_update(
            forest_dataset_train,
            outcome_train,
            active_forest_tau
          )
        }
        if (has_rfx) {
          if (is.null(previous_rfx_samples)) {
            warning(
              "`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started"
            )
            rootResetRandomEffectsModel(
              rfx_model,
              alpha_init,
              xi_init,
              sigma_alpha_init,
              sigma_xi_init,
              sigma_xi_shape,
              sigma_xi_scale
            )
            rootResetRandomEffectsTracker(
              rfx_tracker_train,
              rfx_model,
              rfx_dataset_train,
              outcome_train
            )
          } else {
            resetRandomEffectsModel(
              rfx_model,
              previous_rfx_samples,
              warmstart_index - 1,
              sigma_alpha_init
            )
            resetRandomEffectsTracker(
              rfx_tracker_train,
              rfx_model,
              rfx_dataset_train,
              outcome_train,
              rfx_samples
            )
          }
        }
        if (sample_sigma2_global) {
          if (!is.null(previous_global_var_samples)) {
            current_sigma2 <- previous_global_var_samples[
              warmstart_index
            ]
          }
          global_model_config$update_global_error_variance(
            current_sigma2
          )
        }
      } else {
        resetActiveForest(active_forest_mu)
        active_forest_mu$set_root_leaves(init_mu / num_trees_mu)
        resetForestModel(
          forest_model_mu,
          active_forest_mu,
          forest_dataset_train,
          outcome_train,
          TRUE
        )
        resetActiveForest(active_forest_tau)
        active_forest_tau$set_root_leaves(init_tau / num_trees_tau)
        resetForestModel(
          forest_model_tau,
          active_forest_tau,
          forest_dataset_train,
          outcome_train,
          TRUE
        )
        if (sample_sigma2_leaf_mu) {
          current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu)
          forest_model_config_mu$update_leaf_model_scale(
            current_leaf_scale_mu
          )
        }
        if (sample_sigma2_leaf_tau) {
          current_leaf_scale_tau <- as.matrix(sigma2_leaf_tau)
          forest_model_config_tau$update_leaf_model_scale(
            current_leaf_scale_tau
          )
        }
        if (include_variance_forest) {
          resetActiveForest(active_forest_variance)
          active_forest_variance$set_root_leaves(
            log(variance_forest_init) / num_trees_variance
          )
          resetForestModel(
            forest_model_variance,
            active_forest_variance,
            forest_dataset_train,
            outcome_train,
            FALSE
          )
        }
        if (has_rfx) {
          rootResetRandomEffectsModel(
            rfx_model,
            alpha_init,
            xi_init,
            sigma_alpha_init,
            sigma_xi_init,
            sigma_xi_shape,
            sigma_xi_scale
          )
          rootResetRandomEffectsTracker(
            rfx_tracker_train,
            rfx_model,
            rfx_dataset_train,
            outcome_train
          )
        }
        if (adaptive_coding) {
          current_b_1 <- b_1
          current_b_0 <- b_0
          tau_basis_train <- (1 - Z_train) *
            current_b_0 +
            Z_train * current_b_1
          forest_dataset_train$update_basis(tau_basis_train)
          if (has_test) {
            tau_basis_test <- (1 - Z_test) *
              current_b_0 +
              Z_test * current_b_1
            forest_dataset_test$update_basis(tau_basis_test)
          }
          forest_model_tau$propagate_basis_update(
            forest_dataset_train,
            outcome_train,
            active_forest_tau
          )
        }
        if (sample_sigma2_global) {
          current_sigma2 <- sigma2_init
          global_model_config$update_global_error_variance(
            current_sigma2
          )
        }
      }
      for (i in (num_gfr + 1):num_samples) {
        is_mcmc <- i > (num_gfr + num_burnin)
        if (is_mcmc) {
          mcmc_counter <- i - (num_gfr + num_burnin)
          if (mcmc_counter %% keep_every == 0) {
            keep_sample <- TRUE
          } else {
            keep_sample <- FALSE
          }
        } else {
          if (keep_burnin) {
            keep_sample <- TRUE
          } else {
            keep_sample <- FALSE
          }
        }
        if (keep_sample) {
          sample_counter <- sample_counter + 1
        }
        # Print progress
        if (verbose) {
          if ((num_burnin > 0) && (!is_mcmc)) {
            if (
              ((i - num_gfr) %% 100 == 0) ||
                ((i - num_gfr) == num_burnin)
            ) {
              cat(
                "Sampling",
                i - num_gfr,
                "out of",
                num_burnin,
                "BCF burn-in draws\n"
              )
            }
          }
          if ((num_mcmc > 0) && (is_mcmc)) {
            if (
              ((i - num_gfr - num_burnin) %% 100 == 0) ||
                (i == num_samples)
            ) {
              cat(
                "Sampling",
                i - num_burnin - num_gfr,
                "out of",
                num_mcmc,
                "BCF MCMC draws\n"
              )
            }
          }
        }

        if (probit_outcome_model) {
          # Sample latent probit variable, z | -
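          # (same truncated-normal inverse-CDF draw as in the GFR loop above)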
          mu_forest_pred <- active_forest_mu$predict(
            forest_dataset_train
          )
          tau_forest_pred <- active_forest_tau$predict(
            forest_dataset_train
          )
          outcome_pred <- mu_forest_pred + tau_forest_pred
          if (has_rfx) {
            rfx_pred <- rfx_model$predict(
              rfx_dataset_train,
              rfx_tracker_train
            )
            outcome_pred <- outcome_pred + rfx_pred
          }
          mu0 <- outcome_pred[y_train == 0]
          mu1 <- outcome_pred[y_train == 1]
          u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
          u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
          resid_train[y_train == 0] <- mu0 + qnorm(u0)
          resid_train[y_train == 1] <- mu1 + qnorm(u1)

          # Update outcome
          outcome_train$update_data(resid_train - outcome_pred)
        }

        # Sample the prognostic forest
        forest_model_mu$sample_one_iteration(
          forest_dataset = forest_dataset_train,
          residual = outcome_train,
          forest_samples = forest_samples_mu,
          active_forest = active_forest_mu,
          rng = rng,
          forest_model_config = forest_model_config_mu,
          global_model_config = global_model_config,
          num_threads = num_threads,
          keep_forest = keep_sample,
          gfr = FALSE
        )

        # Cache train set predictions since they are already computed during sampling
        if (keep_sample) {
          muhat_train_raw[,
            sample_counter
          ] <- forest_model_mu$get_cached_forest_predictions()
        }

        # Sample variance parameters (if requested)
        if (sample_sigma2_global) {
          current_sigma2 <- sampleGlobalErrorVarianceOneIteration(
            outcome_train,
            forest_dataset_train,
            rng,
            a_global,
            b_global
          )
          global_model_config$update_global_error_variance(
            current_sigma2
          )
        }
        if (sample_sigma2_leaf_mu) {
          leaf_scale_mu_double <- sampleLeafVarianceOneIteration(
            active_forest_mu,
            rng,
            a_leaf_mu,
            b_leaf_mu
          )
          current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double)
          if (keep_sample) {
            leaf_scale_mu_samples[
              sample_counter
            ] <- leaf_scale_mu_double
          }
          forest_model_config_mu$update_leaf_model_scale(
            current_leaf_scale_mu
          )
        }

        # Sample the treatment forest
        forest_model_tau$sample_one_iteration(
          forest_dataset = forest_dataset_train,
          residual = outcome_train,
          forest_samples = forest_samples_tau,
          active_forest = active_forest_tau,
          rng = rng,
          forest_model_config = forest_model_config_tau,
          global_model_config = global_model_config,
          num_threads = num_threads,
          keep_forest = keep_sample,
          gfr = FALSE
        )

        # Cannot cache train set predictions for tau because the cached predictions in the
        # tracking data structures are pre-multiplied by the basis (treatment)
        # ...

        # Sample coding parameters (if requested)
        if (adaptive_coding) {
          # Estimate mu(X) and tau(X) and compute y - mu(X)
          mu_x_raw_train <- active_forest_mu$predict_raw(
            forest_dataset_train
          )
          tau_x_raw_train <- active_forest_tau$predict_raw(
            forest_dataset_train
          )
          partial_resid_mu_train <- resid_train - mu_x_raw_train
          if (has_rfx) {
            rfx_preds_train <- rfx_model$predict(
              rfx_dataset_train,
              rfx_tracker_train
            )
            partial_resid_mu_train <- partial_resid_mu_train -
              rfx_preds_train
          }

          # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z]
          s_tt0 <- sum(
            tau_x_raw_train * tau_x_raw_train * (Z_train == 0)
          )
          s_tt1 <- sum(
            tau_x_raw_train * tau_x_raw_train * (Z_train == 1)
          )
          s_ty0 <- sum(
            tau_x_raw_train *
              partial_resid_mu_train *
              (Z_train == 0)
          )
          s_ty1 <- sum(
            tau_x_raw_train *
              partial_resid_mu_train *
              (Z_train == 1)
          )

          # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z)
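          # (same conjugate normal update as in the GFR loop above; see the
          # implicit N(0, 1/2) prior note there)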
          current_b_0 <- rnorm(
            1,
            (s_ty0 / (s_tt0 + 2 * current_sigma2)),
            sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2))
          )
          current_b_1 <- rnorm(
            1,
            (s_ty1 / (s_tt1 + 2 * current_sigma2)),
            sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2))
          )

          # Update basis for the leaf regression
          tau_basis_train <- (1 - Z_train) *
            current_b_0 +
            Z_train * current_b_1
          forest_dataset_train$update_basis(tau_basis_train)
          if (keep_sample) {
            b_0_samples[sample_counter] <- current_b_0
            b_1_samples[sample_counter] <- current_b_1
          }
          if (has_test) {
            tau_basis_test <- (1 - Z_test) *
              current_b_0 +
              Z_test * current_b_1
            forest_dataset_test$update_basis(tau_basis_test)
          }

          # Update leaf predictions and residual
          forest_model_tau$propagate_basis_update(
            forest_dataset_train,
            outcome_train,
            active_forest_tau
          )
        }

        # Sample variance parameters (if requested)
        if (include_variance_forest) {
          forest_model_variance$sample_one_iteration(
            forest_dataset = forest_dataset_train,
            residual = outcome_train,
            forest_samples = forest_samples_variance,
            active_forest = active_forest_variance,
            rng = rng,
            forest_model_config = forest_model_config_variance,
            global_model_config = global_model_config,
            num_threads = num_threads,
            keep_forest = keep_sample,
            gfr = FALSE
          )

          # Cache train set predictions since they are already computed during sampling
          if (keep_sample) {
            sigma2_x_train_raw[,
              sample_counter
            ] <- forest_model_variance$get_cached_forest_predictions()
          }
        }
        if (sample_sigma2_global) {
          current_sigma2 <- sampleGlobalErrorVarianceOneIteration(
            outcome_train,
            forest_dataset_train,
            rng,
            a_global,
            b_global
          )
          if (keep_sample) {
            global_var_samples[sample_counter] <- current_sigma2
          }
          global_model_config$update_global_error_variance(
            current_sigma2
          )
        }
        if (sample_sigma2_leaf_tau) {
          leaf_scale_tau_double <- sampleLeafVarianceOneIteration(
            active_forest_tau,
            rng,
            a_leaf_tau,
            b_leaf_tau
          )
          current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double)
          if (keep_sample) {
            leaf_scale_tau_samples[
              sample_counter
            ] <- leaf_scale_tau_double
          }
          forest_model_config_tau$update_leaf_model_scale(
            current_leaf_scale_tau
          )
        }

        # Sample random effects parameters (if requested)
        if (has_rfx) {
          rfx_model$sample_random_effect(
            rfx_dataset_train,
            outcome_train,
            rfx_tracker_train,
            rfx_samples,
            keep_sample,
            current_sigma2,
            rng
          )
        }
      }
    }
  }

  # Remove GFR samples if they are not to be retained
  if ((!keep_gfr) && (num_gfr > 0)) {
    for (i in 1:num_gfr) {
      forest_samples_mu$delete_sample(0)
      forest_samples_tau$delete_sample(0)
      if (include_variance_forest) {
        forest_samples_variance$delete_sample(0)
      }
      if (has_rfx) {
        rfx_samples$delete_sample(0)
      }
    }
    if (sample_sigma2_global) {
      global_var_samples <- global_var_samples[
        (num_gfr + 1):length(global_var_samples)
      ]
    }
    if (sample_sigma2_leaf_mu) {
      leaf_scale_mu_samples <- leaf_scale_mu_samples[
        (num_gfr + 1):length(leaf_scale_mu_samples)
      ]
    }
    if (sample_sigma2_leaf_tau) {
      leaf_scale_tau_samples <- leaf_scale_tau_samples[
        (num_gfr + 1):length(leaf_scale_tau_samples)
      ]
    }
    if (adaptive_coding) {
      b_1_samples <- b_1_samples[(num_gfr + 1):length(b_1_samples)]
      b_0_samples <- b_0_samples[(num_gfr + 1):length(b_0_samples)]
    }
    muhat_train_raw <- muhat_train_raw[,
      (num_gfr + 1):ncol(muhat_train_raw),
      drop = FALSE
    ]
    if (include_variance_forest) {
      sigma2_x_train_raw <- sigma2_x_train_raw[,
        (num_gfr + 1):ncol(sigma2_x_train_raw),
        drop = FALSE
      ]
    }
    num_retained_samples <- num_retained_samples - num_gfr
  }

  # Forest predictions
  mu_hat_train <- muhat_train_raw * y_std_train + y_bar_train
  if (adaptive_coding) {
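    # Under adaptive coding, E[y | X, Z] = mu(X) + tau_raw(X) * (b0 + (b1 - b0) * Z),
    # so the b0 * tau_raw(X) term folds into the prognostic function and the
    # CATE is (b1 - b0) * tau_raw(X)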
    tau_hat_train_raw <- forest_samples_tau$predict_raw(
      forest_dataset_train
    )
    tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples)) *
      y_std_train
    control_adj_train <- t(t(tau_hat_train_raw) * b_0_samples) * y_std_train
    mu_hat_train <- mu_hat_train + control_adj_train
  } else {
    tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train) *
      y_std_train
  }
  if (has_multivariate_treatment) {
    tau_train_dim <- dim(tau_hat_train)
    tau_num_obs <- tau_train_dim[1]
    tau_num_samples <- tau_train_dim[3]
    treatment_term_train <- matrix(
      NA_real_,
      nrow = tau_num_obs,
      tau_num_samples
    )
    for (i in 1:nrow(Z_train)) {
      treatment_term_train[i, ] <- colSums(
        tau_hat_train[i, , ] * Z_train[i, ]
      )
    }
  } else {
    treatment_term_train <- tau_hat_train * as.numeric(Z_train)
  }
  y_hat_train <- mu_hat_train + treatment_term_train
  if (has_test) {
    mu_hat_test <- forest_samples_mu$predict(forest_dataset_test) *
      y_std_train +
      y_bar_train
    if (adaptive_coding) {
      tau_hat_test_raw <- forest_samples_tau$predict_raw(
        forest_dataset_test
      )
      tau_hat_test <- t(
        t(tau_hat_test_raw) * (b_1_samples - b_0_samples)
      ) *
        y_std_train
      control_adj_test <- t(t(tau_hat_test_raw) * b_0_samples) * y_std_train
      mu_hat_test <- mu_hat_test + control_adj_test
    } else {
      tau_hat_test <- forest_samples_tau$predict_raw(
        forest_dataset_test
      ) *
        y_std_train
    }
    if (has_multivariate_treatment) {
      tau_test_dim <- dim(tau_hat_test)
      tau_num_obs <- tau_test_dim[1]
      tau_num_samples <- tau_test_dim[3]
      treatment_term_test <- matrix(
        NA_real_,
        nrow = tau_num_obs,
        tau_num_samples
      )
      for (i in 1:nrow(Z_test)) {
        treatment_term_test[i, ] <- colSums(
          tau_hat_test[i, , ] * Z_test[i, ]
        )
      }
    } else {
      treatment_term_test <- tau_hat_test * as.numeric(Z_test)
    }
    y_hat_test <- mu_hat_test + treatment_term_test
  }
  if (include_variance_forest) {
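    # Cached train-set values are raw (log-scale) leaf sums and must be
    # exponentiated; predict() on the exponentiated variance forest returns
    # values on the variance scale directly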
    sigma2_x_hat_train <- exp(sigma2_x_train_raw)
    if (has_test) {
      sigma2_x_hat_test <- forest_samples_variance$predict(
        forest_dataset_test
      )
    }
  }

  # Random effects predictions
  if (has_rfx) {
    rfx_preds_train <- rfx_samples$predict(
      rfx_group_ids_train,
      rfx_basis_train
    ) *
      y_std_train
    y_hat_train <- y_hat_train + rfx_preds_train
  }
  if ((has_rfx_test) && (has_test)) {
    rfx_preds_test <- rfx_samples$predict(
      rfx_group_ids_test,
      rfx_basis_test
    ) *
      y_std_train
    y_hat_test <- y_hat_test + rfx_preds_test
  }

  # Global error variance
  if (sample_sigma2_global) {
    sigma2_global_samples <- global_var_samples * (y_std_train^2)
  }

  # Leaf parameter variance for prognostic forest
  if (sample_sigma2_leaf_mu) {
    sigma2_leaf_mu_samples <- leaf_scale_mu_samples
  }

  # Leaf parameter variance for treatment effect forest
  if (sample_sigma2_leaf_tau) {
    sigma2_leaf_tau_samples <- leaf_scale_tau_samples
  }

  # Rescale variance forest prediction by global sigma2 (sampled or constant)
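  # (sigma2_global_samples was already multiplied by y_std_train^2 above, so
  # the sampled-sigma2 branch needs no further outcome-scale rescaling)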
  if (include_variance_forest) {
    if (sample_sigma2_global) {
      sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) {
        sigma2_x_hat_train[, i] * sigma2_global_samples[i]
      })
      if (has_test) {
        sigma2_x_hat_test <- sapply(
          1:num_retained_samples,
          function(i) {
            sigma2_x_hat_test[, i] * sigma2_global_samples[i]
          }
        )
      }
    } else {
      sigma2_x_hat_train <- sigma2_x_hat_train *
        sigma2_init *
        y_std_train *
        y_std_train
      if (has_test) {
        sigma2_x_hat_test <- sigma2_x_hat_test *
          sigma2_init *
          y_std_train *
          y_std_train
      }
    }
  }

  # Return results as a list
  if (include_variance_forest) {
    num_variance_covariates <- sum(variable_weights_variance > 0)
  } else {
    num_variance_covariates <- 0
  }
  model_params <- list(
    "initial_sigma2" = sigma2_init,
    "initial_sigma2_leaf_mu" = sigma2_leaf_mu,
    "initial_sigma2_leaf_tau" = sigma2_leaf_tau,
    "initial_b_0" = b_0,
    "initial_b_1" = b_1,
    "a_global" = a_global,
    "b_global" = b_global,
    "a_leaf_mu" = a_leaf_mu,
    "b_leaf_mu" = b_leaf_mu,
    "a_leaf_tau" = a_leaf_tau,
    "b_leaf_tau" = b_leaf_tau,
    "a_forest" = a_forest,
    "b_forest" = b_forest,
    "outcome_mean" = y_bar_train,
    "outcome_scale" = y_std_train,
    "standardize" = standardize,
    "num_covariates" = num_cov_orig,
    "num_prognostic_covariates" = sum(variable_weights_mu > 0),
    "num_treatment_covariates" = sum(variable_weights_tau > 0),
    "num_variance_covariates" = num_variance_covariates,
    "treatment_dim" = ncol(Z_train),
    "propensity_covariate" = propensity_covariate,
    "binary_treatment" = binary_treatment,
    "multivariate_treatment" = has_multivariate_treatment,
    "adaptive_coding" = adaptive_coding,
    "internal_propensity_model" = internal_propensity_model,
    "num_samples" = num_retained_samples,
    "num_gfr" = num_gfr,
    "num_burnin" = num_burnin,
    "num_mcmc" = num_mcmc,
    "keep_every" = keep_every,
    "num_chains" = num_chains,
    "has_rfx" = has_rfx,
    "has_rfx_basis" = has_basis_rfx,
    "num_rfx_basis" = num_basis_rfx,
    "include_variance_forest" = include_variance_forest,
    "sample_sigma2_global" = sample_sigma2_global,
    "sample_sigma2_leaf_mu" = sample_sigma2_leaf_mu,
    "sample_sigma2_leaf_tau" = sample_sigma2_leaf_tau,
    "probit_outcome_model" = probit_outcome_model,
    "rfx_model_spec" = rfx_model_spec
  )
  result <- list(
    "forests_mu" = forest_samples_mu,
    "forests_tau" = forest_samples_tau,
    "model_params" = model_params,
    "mu_hat_train" = mu_hat_train,
    "tau_hat_train" = tau_hat_train,
    "y_hat_train" = y_hat_train,
    "train_set_metadata" = X_train_metadata
  )
  if (has_test) {
    result[["mu_hat_test"]] <- mu_hat_test
    result[["tau_hat_test"]] <- tau_hat_test
    result[["y_hat_test"]] <- y_hat_test
  }
  if (include_variance_forest) {
    result[["forests_variance"]] <- forest_samples_variance
    result[["sigma2_x_hat_train"]] <- sigma2_x_hat_train
    if (has_test) result[["sigma2_x_hat_test"]] <- sigma2_x_hat_test
  }
  if (sample_sigma2_global) {
    result[["sigma2_global_samples"]] <- sigma2_global_samples
  }
  if (sample_sigma2_leaf_mu) {
    result[["sigma2_leaf_mu_samples"]] <- sigma2_leaf_mu_samples
  }
  if (sample_sigma2_leaf_tau) {
    result[["sigma2_leaf_tau_samples"]] <- sigma2_leaf_tau_samples
  }
  if (adaptive_coding) {
    result[["b_0_samples"]] <- b_0_samples
    result[["b_1_samples"]] <- b_1_samples
  }
  if (has_rfx) {
    result[["rfx_samples"]] <- rfx_samples
    result[["rfx_preds_train"]] <- rfx_preds_train
    result[["rfx_unique_group_ids"]] <- levels(group_ids_factor)
  }
  if ((has_rfx_test) && (has_test)) {
    result[["rfx_preds_test"]] <- rfx_preds_test
  }
  if (internal_propensity_model) {
    result[["bart_propensity_model"]] <- bart_model_propensity
  }
  class(result) <- "bcfmodel"

  # Restore global RNG state if user provided a random seed
  # (assignment must target the global environment; a local `.Random.seed <-`
  # inside the function body would have no effect on R's RNG state)
  if (custom_rng) {
    assign(".Random.seed", original_global_seed, envir = globalenv())
  }

  return(result)
}

#' Predict from a sampled BCF model on new data
#'
#' @param object Object of type `bcfmodel` containing draws of a Bayesian causal forest model and associated sampling outputs.
#' @param X Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.
#' @param Z Treatments used for prediction.
#' @param propensity (Optional) Propensities used for prediction.
#' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model.
#' We do not currently support (but plan to in the near future) test set evaluation for group labels
#' that were not in the training set.
#' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. If the model was sampled with a random effects `model_spec` of "intercept_only" or "intercept_plus_treatment", this is optional, but if it is provided, it will be used.
#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".
#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is requested, the request will simply be ignored. If a model has random effects fit with either the "intercept_only" or "intercept_plus_treatment" model_spec, then "prognostic_function" refers to the predictions of the prognostic forest plus the random intercept and "cate" refers to the predictions of the treatment effect forest plus the random slope on the treatment variable. For these models, the forest predictions alone can be requested via "mu" (prognostic forest) and "tau" (treatment effect forest). In all other cases, "mu" will return exactly the same result as "prognostic_function" and "tau" will return exactly the same result as "cate". If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all".
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
#' @param ... (Optional) Other prediction parameters.
#'
#' @return List of prediction matrices or single prediction matrix / vector, depending on the terms requested.
#' @export
#'
#' @examples
#' n <- 500
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' mu_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' pi_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
#' )
#' tau_x <- (
#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
#' )
#' Z <- rbinom(n, 1, pi_x)
#' noise_sd <- 1
#' y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' pi_test <- pi_x[test_inds]
#' pi_train <- pi_x[train_inds]
#' Z_test <- Z[test_inds]
#' Z_train <- Z[train_inds]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' mu_test <- mu_x[test_inds]
#' mu_train <- mu_x[train_inds]
#' tau_test <- tau_x[test_inds]
#' tau_train <- tau_x[train_inds]
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                  propensity_train = pi_train, num_gfr = 10,
#'                  num_burnin = 0, num_mcmc = 10)
#' preds <- predict(bcf_model, X_test, Z_test, pi_test)
predict.bcfmodel <- function(
  object,
  X,
  Z,
  propensity = NULL,
  rfx_group_ids = NULL,
  rfx_basis = NULL,
  type = "posterior",
  terms = "all",
  scale = "linear",
  ...
) {
  # Handle mean function scale
  if (!is.character(scale)) {
    stop("scale must be a string or character vector")
  }
  if (!(scale %in% c("linear", "probability"))) {
    stop("scale must either be 'linear' or 'probability'")
  }
  is_probit <- object$model_params$probit_outcome_model
  if ((scale == "probability") && (!is_probit)) {
    stop(
      "scale cannot be 'probability' for models not fit with a probit outcome model"
    )
  }
  probability_scale <- scale == "probability"

  # Handle prediction type
  if (!is.character(type)) {
    stop("type must be a string or character vector")
  }
  if (!(type %in% c("mean", "posterior"))) {
    stop("type must either be 'mean' or 'posterior")
  }
  predict_mean <- type == "mean"

  # Warn users about CATE / prognostic function when rfx_model_spec is "custom"
  if (object$model_params$has_rfx) {
    if (object$model_params$rfx_model_spec == "custom") {
      if (("prognostic_function" %in% terms) || ("cate" %in% terms)) {
        warning(paste0(
          "This BCF model was fit with a custom random effects model specification (i.e. a user-provided basis). ",
          "As a result, 'prognostic_function' and 'cate' refer only to the prognostic ('mu') ",
          "and treatment effect ('tau') forests, respectively, and do not include any random ",
          "effects contributions. If your user-provided random effects basis includes a random intercept or a ",
          "random slope on the treatment variable, you will need to compute the prognostic or CATE functions manually by predicting ",
          "'y_hat' for different covariate and rfx_basis values."
        ))
      }
    }
  }

  # Handle prediction terms
  rfx_model_spec <- object$model_params$rfx_model_spec
  rfx_intercept_only <- rfx_model_spec == "intercept_only"
  rfx_intercept_plus_treatment <- rfx_model_spec == "intercept_plus_treatment"
  rfx_intercept <- rfx_intercept_only || rfx_intercept_plus_treatment
  mu_prog_separate <- rfx_intercept
  tau_cate_separate <- rfx_intercept_plus_treatment
  if (!is.character(terms)) {
    stop("type must be a string or character vector")
  }
  for (term in terms) {
    if (
      !(term %in%
        c(
          "y_hat",
          "prognostic_function",
          "mu",
          "cate",
          "tau",
          "rfx",
          "variance_forest",
          "all"
        ))
    ) {
      warning(paste0(
        "Term '",
        term,
        "' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'.",
        " This term will be ignored and prediction will only proceed if other requested terms are available in the model."
      ))
    }
  }
  num_terms <- length(terms)
  has_mu_forest <- TRUE
  has_tau_forest <- TRUE
  has_variance_forest <- object$model_params$include_variance_forest
  has_rfx <- object$model_params$has_rfx
  has_y_hat <- TRUE
  predict_y_hat <- has_y_hat &&
    (("y_hat" %in% terms) || ("all" %in% terms))
  predict_mu_forest <- has_mu_forest &&
    (("mu" %in% terms) || ("all" %in% terms))
  predict_tau_forest <- has_tau_forest &&
    (("tau" %in% terms) || ("all" %in% terms))
  predict_prog_function <- has_mu_forest &&
    (("prognostic_function" %in% terms) || ("all" %in% terms))
  predict_cate_function <- has_tau_forest &&
    (("cate" %in% terms) || ("all" %in% terms))
  predict_rfx <- has_rfx &&
    (("rfx" %in% terms) || ("all" %in% terms))
  predict_variance_forest <- has_variance_forest &&
    (("variance_forest" %in% terms) || ("all" %in% terms))
  predict_count <- sum(c(
    predict_y_hat,
    predict_mu_forest,
    predict_prog_function,
    predict_tau_forest,
    predict_cate_function,
    predict_rfx,
    predict_variance_forest
  ))
  if (predict_count == 0) {
    warning(paste0(
      "None of the requested model terms, ",
      paste(terms, collapse = ", "),
      ", were fit in this model"
    ))
    return(NULL)
  }
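  # "Intermediate" quantities are computed (but not returned) when a requested
  # term must be assembled from them: y_hat sums the mu, treatment, and rfx
  # terms, while the prognostic / CATE functions may fold in raw rfx coefficients.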
  predict_rfx_intermediate <- (predict_y_hat && has_rfx)
  predict_rfx_raw <- ((predict_prog_function && has_rfx && rfx_intercept) ||
    (predict_cate_function && has_rfx && rfx_intercept_plus_treatment))
  predict_mu_forest_intermediate <- ((predict_y_hat || predict_prog_function) &&
    has_mu_forest)
  predict_tau_forest_intermediate <- ((predict_y_hat ||
    predict_cate_function) &&
    has_tau_forest)

  # Make sure covariates are matrix or data frame
  if ((!is.data.frame(X)) && (!is.matrix(X))) {
    stop("X must be a matrix or dataframe")
  }

  # Convert all input data to matrices if not already converted
  if ((is.null(dim(Z))) && (!is.null(Z))) {
    Z <- as.matrix(as.numeric(Z))
  }
  if ((is.null(dim(propensity))) && (!is.null(propensity))) {
    propensity <- as.matrix(propensity)
  }
  if ((is.null(dim(rfx_basis))) && (!is.null(rfx_basis))) {
    rfx_basis <- as.matrix(rfx_basis)
  }

  # Data checks
  if (
    (object$model_params$propensity_covariate != "none") &&
      (is.null(propensity))
  ) {
    if (!object$model_params$internal_propensity_model) {
      stop("propensity must be provided for this model")
    }
    # Compute propensity score using the internal bart model
    propensity <- rowMeans(predict(object$bart_propensity_model, X)$y_hat)
  }
  if (nrow(X) != nrow(Z)) {
    stop("X and Z must have the same number of rows")
  }
  if (object$model_params$num_covariates != ncol(X)) {
    stop(
      "X must have the same number of columns as the covariates used to train the model"
    )
  }
  if ((object$model_params$has_rfx) && (is.null(rfx_group_ids))) {
    stop(
      "Random effect group labels (rfx_group_ids) must be provided for this model"
    )
  }
  if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis))) {
    if (object$model_params$rfx_model_spec == "custom") {
      stop("Random effects basis (rfx_basis) must be provided for this model")
    }
  }
  if ((object$model_params$num_rfx_basis > 0) && (!is.null(rfx_basis))) {
    if (ncol(rfx_basis) != object$model_params$num_rfx_basis) {
      stop(
        "Random effects basis has a different dimension than the basis used to train this model"
      )
    }
  }

  # Preprocess covariates
  train_set_metadata <- object$train_set_metadata
  X <- preprocessPredictionData(X, train_set_metadata)

  # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
  has_rfx <- FALSE
  if (!is.null(rfx_group_ids)) {
    rfx_unique_group_ids <- object$rfx_unique_group_ids
    group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
    if (sum(is.na(group_ids_factor)) > 0) {
      stop(
        "All random effect group labels provided in rfx_group_ids must have been present at sampling time"
      )
    }
    rfx_group_ids <- as.integer(group_ids_factor)
    has_rfx <- TRUE
  }

  # Handle RFX model specification
  if (has_rfx) {
    if (object$model_params$rfx_model_spec == "custom") {
      if (is.null(rfx_basis)) {
        stop(
          "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
        )
      }
    } else if (object$model_params$rfx_model_spec == "intercept_only") {
      # Only construct a basis if user-provided basis missing
      if (is.null(rfx_basis)) {
        rfx_basis <- matrix(
          rep(1, nrow(X)),
          nrow = nrow(X),
          ncol = 1
        )
      }
    } else if (
      object$model_params$rfx_model_spec == "intercept_plus_treatment"
    ) {
      # Only construct a basis if user-provided basis missing
      if (is.null(rfx_basis)) {
        rfx_basis <- cbind(
          rep(1, nrow(X)),
          Z
        )
      }
    }
  }

  # Add propensities to covariate set if necessary
  X_combined <- X
  if (object$model_params$propensity_covariate != "none") {
    X_combined <- cbind(X, propensity)
  }

  # Create prediction datasets
  forest_dataset_pred <- createForestDataset(X_combined, Z)

  # Compute variance forest predictions
  if (predict_variance_forest) {
    s_x_raw <- object$forests_variance$predict(forest_dataset_pred)
  }

  # Scale variance forest predictions
  num_samples <- object$model_params$num_samples
  y_std <- object$model_params$outcome_scale
  y_bar <- object$model_params$outcome_mean
  initial_sigma2 <- object$model_params$initial_sigma2
  if (predict_variance_forest) {
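    # The variance forest predicts multiplicative scale factors; rescale them to
    # the outcome's variance scale using either the sampled global error
    # variances (one per draw) or, if sigma^2 was not sampled, the initial
    # value times the squared outcome standard deviation.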
    if (object$model_params$sample_sigma2_global) {
      sigma2_global_samples <- object$sigma2_global_samples
      variance_forest_predictions <- sapply(1:num_samples, function(i) {
        s_x_raw[, i] * sigma2_global_samples[i]
      })
    } else {
      variance_forest_predictions <- s_x_raw *
        initial_sigma2 *
        y_std *
        y_std
    }
    if (predict_mean) {
      variance_forest_predictions <- rowMeans(variance_forest_predictions)
    }
  }

  # Compute mu forest predictions
  if (predict_mu_forest || predict_mu_forest_intermediate) {
    mu_hat_forest <- object$forests_mu$predict(forest_dataset_pred) *
      y_std +
      y_bar
  }

  # Compute CATE forest predictions
  if (predict_tau_forest || predict_tau_forest_intermediate) {
    if (object$model_params$adaptive_coding) {
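      # Under adaptive coding, treatment arms are coded by sampled scalars b_0
      # (control) and b_1 (treated) rather than 0/1, so the raw tau forest
      # output is rescaled by (b_1 - b_0), and the control code's contribution
      # (raw tau * b_0) is folded into the prognostic term.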
      tau_hat_raw <- object$forests_tau$predict_raw(forest_dataset_pred)
      tau_hat_forest <- t(
        t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples)
      ) *
        y_std
      if (predict_mu_forest || predict_mu_forest_intermediate) {
        control_adj <- t(t(tau_hat_raw) * object$b_0_samples) * y_std
        mu_hat_forest <- mu_hat_forest + control_adj
      }
    } else {
      tau_hat_forest <- object$forests_tau$predict_raw(forest_dataset_pred) *
        y_std
    }
    if (object$model_params$multivariate_treatment) {
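      # With a multivariate treatment, tau_hat_forest has dimension
      # (num_obs, treatment_dim, num_samples); each observation's treatment
      # term is the inner product of its treatment vector with the sampled
      # effects, computed draw-by-draw.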
      tau_dim <- dim(tau_hat_forest)
      tau_num_obs <- tau_dim[1]
      tau_num_samples <- tau_dim[3]
      treatment_term <- matrix(NA_real_, nrow = tau_num_obs, ncol = tau_num_samples)
      for (i in 1:nrow(Z)) {
        treatment_term[i, ] <- colSums(tau_hat_forest[i, , ] * Z[i, ])
      }
    } else {
      treatment_term <- tau_hat_forest * as.numeric(Z)
    }
  }

  # Compute rfx predictions
  if (predict_rfx || predict_rfx_intermediate) {
    rfx_predictions <- object$rfx_samples$predict(
      rfx_group_ids,
      rfx_basis
    ) *
      y_std
  }

  # Extract "raw" rfx coefficients for each rfx basis term if needed
  if (predict_rfx_raw) {
    # Extract the raw RFX samples and scale by train set outcome standard deviation
    rfx_param_list <- object$rfx_samples$extract_parameter_samples()
    rfx_beta_draws <- rfx_param_list$beta_samples *
      object$model_params$outcome_scale

    # Construct a matrix with the appropriate group random effects arranged for each observation
    rfx_predictions_raw <- array(
      NA,
      dim = c(
        nrow(X),
        ncol(rfx_basis),
        object$model_params$num_samples
      )
    )
    for (i in 1:nrow(X)) {
      rfx_predictions_raw[i, , ] <-
        rfx_beta_draws[, rfx_group_ids[i], ]
    }
  }

  # Add raw RFX predictions to mu and tau if warranted by the RFX model spec
  if (predict_prog_function) {
    if (mu_prog_separate) {
      prognostic_function <- mu_hat_forest + rfx_predictions_raw[, 1, ]
    } else {
      prognostic_function <- mu_hat_forest
    }
  }
  if (predict_cate_function) {
    if (tau_cate_separate) {
      cate <- (tau_hat_forest +
        rfx_predictions_raw[, 2:ncol(rfx_basis), ])
    } else {
      cate <- tau_hat_forest
    }
  }

  # Combine into y hat predictions
  needs_mean_term_preds <- predict_y_hat ||
    predict_mu_forest ||
    predict_tau_forest ||
    predict_prog_function ||
    predict_cate_function ||
    predict_rfx
  if (needs_mean_term_preds) {
    if (probability_scale) {
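      # For probit outcome models, each requested mean-scale term is mapped
      # through the standard normal CDF (pnorm) to the probability scale.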
      if (has_rfx) {
        if (predict_y_hat) {
          y_hat <- pnorm(mu_hat_forest + treatment_term + rfx_predictions)
        }
        if (predict_rfx) {
          rfx_predictions <- pnorm(rfx_predictions)
        }
      } else {
        if (predict_y_hat) {
          y_hat <- pnorm(mu_hat_forest + treatment_term)
        }
      }
      if (predict_mu_forest) {
        mu_hat <- pnorm(mu_hat_forest)
      }
      if (predict_tau_forest) {
        tau_hat <- pnorm(tau_hat_forest)
      }
      if (predict_prog_function) {
        prognostic_function <- pnorm(prognostic_function)
      }
      if (predict_cate_function) {
        cate <- pnorm(cate)
      }
    } else {
      if (has_rfx) {
        if (predict_y_hat) {
          y_hat <- mu_hat_forest + treatment_term + rfx_predictions
        }
      } else {
        if (predict_y_hat) {
          y_hat <- mu_hat_forest + treatment_term
        }
      }
      if (predict_mu_forest) {
        mu_hat <- mu_hat_forest
      }
      if (predict_tau_forest) {
        tau_hat <- tau_hat_forest
      }
      # prognostic_function and cate are already on the linear scale
    }
  }

  # Collapse to posterior mean predictions if requested
  if (predict_mean) {
    if (predict_mu_forest) {
      mu_hat <- rowMeans(mu_hat)
    }
    if (predict_tau_forest) {
      if (object$model_params$multivariate_treatment) {
        tau_hat <- apply(tau_hat, c(1, 2), mean)
      } else {
        tau_hat <- rowMeans(tau_hat)
      }
    }
    if (predict_prog_function) {
      prognostic_function <- rowMeans(prognostic_function)
    }
    if (predict_cate_function) {
      if (object$model_params$multivariate_treatment) {
        cate <- apply(cate, c(1, 2), mean)
      } else {
        cate <- rowMeans(cate)
      }
    }
    if (predict_rfx) {
      rfx_predictions <- rowMeans(rfx_predictions)
    }
    if (predict_y_hat) {
      y_hat <- rowMeans(y_hat)
    }
  }

  # Return results
  if (predict_count == 1) {
    if (predict_y_hat) {
      return(y_hat)
    } else if (predict_mu_forest) {
      return(mu_hat)
    } else if (predict_tau_forest) {
      return(tau_hat)
    } else if (predict_prog_function) {
      return(prognostic_function)
    } else if (predict_cate_function) {
      return(cate)
    } else if (predict_rfx) {
      return(rfx_predictions)
    } else if (predict_variance_forest) {
      return(variance_forest_predictions)
    }
  } else {
    result <- list()
    if (predict_y_hat) {
      result[["y_hat"]] <- y_hat
    }
    if (predict_mu_forest) {
      result[["mu_hat"]] <- mu_hat
    }
    if (predict_tau_forest) {
      result[["tau_hat"]] <- tau_hat
    }
    if (predict_prog_function) {
      result[["prognostic_function"]] <- prognostic_function
    }
    if (predict_cate_function) {
      result[["cate"]] <- cate
    }
    if (predict_rfx) {
      result[["rfx_predictions"]] <- rfx_predictions
    }
    if (predict_variance_forest) {
      result[["variance_forest_predictions"]] <- variance_forest_predictions
    }
  }
  return(result)
}
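
# A minimal sketch (not part of the package API) of the manual CATE computation
# referenced in the warning inside predict.bcfmodel for models fit with
# rfx_model_spec = "custom". It assumes a two-column custom basis of the form
# (intercept, treatment); `bcf_model`, `X_new`, `pi_new`, and `group_ids` are
# placeholder names for the user's own objects:
#   basis_treated <- cbind(1, rep(1, nrow(X_new)))
#   basis_control <- cbind(1, rep(0, nrow(X_new)))
#   y_hat_1 <- predict(bcf_model, X_new, Z = rep(1, nrow(X_new)),
#                      propensity = pi_new, rfx_group_ids = group_ids,
#                      rfx_basis = basis_treated, terms = "y_hat")
#   y_hat_0 <- predict(bcf_model, X_new, Z = rep(0, nrow(X_new)),
#                      propensity = pi_new, rfx_group_ids = group_ids,
#                      rfx_basis = basis_control, terms = "y_hat")
#   cate_manual <- y_hat_1 - y_hat_0  # one column per posterior draw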

#' Extract raw sample values for each of the random effect parameter terms.
#'
#' @param object Object of type `bcfmodel` containing draws of a Bayesian causal forest model and associated sampling outputs.
#' @param ... Other parameters to be used in random effects extraction
#' @return List of arrays. The alpha array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`.
#' The xi and beta arrays have dimension (`num_components`, `num_groups`, `num_samples`) and are simply matrices if `num_components = 1`.
#' The sigma array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`.
#' @export
#'
#' @examples
#' n <- 500
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' mu_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' pi_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
#' )
#' tau_x <- (
#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
#' )
#' Z <- rbinom(n, 1, pi_x)
#' E_XZ <- mu_x + Z*tau_x
#' snr <- 3
#' rfx_group_ids <- rep(c(1,2), n %/% 2)
#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
#' rfx_basis <- cbind(1, runif(n, -1, 1))
#' rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' pi_test <- pi_x[test_inds]
#' pi_train <- pi_x[train_inds]
#' Z_test <- Z[test_inds]
#' Z_train <- Z[train_inds]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' mu_test <- mu_x[test_inds]
#' mu_train <- mu_x[train_inds]
#' tau_test <- tau_x[test_inds]
#' tau_train <- tau_x[train_inds]
#' rfx_group_ids_test <- rfx_group_ids[test_inds]
#' rfx_group_ids_train <- rfx_group_ids[train_inds]
#' rfx_basis_test <- rfx_basis[test_inds,]
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' mu_params <- list(sample_sigma2_leaf = TRUE)
#' tau_params <- list(sample_sigma2_leaf = FALSE)
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                  propensity_train = pi_train,
#'                  rfx_group_ids_train = rfx_group_ids_train,
#'                  rfx_basis_train = rfx_basis_train, X_test = X_test,
#'                  Z_test = Z_test, propensity_test = pi_test,
#'                  rfx_group_ids_test = rfx_group_ids_test,
#'                  rfx_basis_test = rfx_basis_test,
#'                  num_gfr = 10, num_burnin = 0, num_mcmc = 10,
#'                  prognostic_forest_params = mu_params,
#'                  treatment_effect_forest_params = tau_params)
#' rfx_samples <- getRandomEffectSamples(bcf_model)
getRandomEffectSamples.bcfmodel <- function(object, ...) {
  result <- list()

  if (!object$model_params$has_rfx) {
    warning("This model has no RFX terms, returning an empty list")
    return(result)
  }

  # Extract the samples
  result <- object$rfx_samples$extract_parameter_samples()

  # Scale by the training-set outcome standard deviation
  # (variance parameters are scaled by its square)
  result$beta_samples <- result$beta_samples *
    object$model_params$outcome_scale
  result$xi_samples <- result$xi_samples * object$model_params$outcome_scale
  result$alpha_samples <- result$alpha_samples *
    object$model_params$outcome_scale
  result$sigma_samples <- result$sigma_samples *
    (object$model_params$outcome_scale^2)

  return(result)
}
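
# A minimal indexing sketch for the list returned above, assuming
# num_components > 1 so the arrays keep all documented dimensions;
# `rfx_samples` is a placeholder for the result of getRandomEffectSamples():
#   betas <- rfx_samples$beta_samples   # (num_components, num_groups, num_samples)
#   betas_group1 <- betas[, 1, ]        # all draws of every component for group 1
#   beta_post_means <- apply(betas, c(1, 2), mean)  # posterior mean per component x group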

#' Convert the persistent aspects of a BCF model to (in-memory) JSON
#'
#' @param object Object of type `bcfmodel` containing draws of a Bayesian causal forest model and associated sampling outputs.
#'
#' @return Object of type `CppJson`
#' @export
#'
#' @examples
#' n <- 500
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' mu_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' pi_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
#' )
#' tau_x <- (
#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
#' )
#' Z <- rbinom(n, 1, pi_x)
#' E_XZ <- mu_x + Z*tau_x
#' snr <- 3
#' rfx_group_ids <- rep(c(1,2), n %/% 2)
#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
#' rfx_basis <- cbind(1, runif(n, -1, 1))
#' rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' pi_test <- pi_x[test_inds]
#' pi_train <- pi_x[train_inds]
#' Z_test <- Z[test_inds]
#' Z_train <- Z[train_inds]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' mu_test <- mu_x[test_inds]
#' mu_train <- mu_x[train_inds]
#' tau_test <- tau_x[test_inds]
#' tau_train <- tau_x[train_inds]
#' rfx_group_ids_test <- rfx_group_ids[test_inds]
#' rfx_group_ids_train <- rfx_group_ids[train_inds]
#' rfx_basis_test <- rfx_basis[test_inds,]
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' mu_params <- list(sample_sigma2_leaf = TRUE)
#' tau_params <- list(sample_sigma2_leaf = FALSE)
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                  propensity_train = pi_train,
#'                  rfx_group_ids_train = rfx_group_ids_train,
#'                  rfx_basis_train = rfx_basis_train, X_test = X_test,
#'                  Z_test = Z_test, propensity_test = pi_test,
#'                  rfx_group_ids_test = rfx_group_ids_test,
#'                  rfx_basis_test = rfx_basis_test,
#'                  num_gfr = 10, num_burnin = 0, num_mcmc = 10,
#'                  prognostic_forest_params = mu_params,
#'                  treatment_effect_forest_params = tau_params)
#' bcf_json <- saveBCFModelToJson(bcf_model)
saveBCFModelToJson <- function(object) {
  jsonobj <- createCppJson()

  if (!inherits(object, "bcfmodel")) {
    stop("`object` must be a BCF model")
  }

  if (is.null(object$model_params)) {
    stop("This BCF model has not yet been sampled")
  }

  # Add the forests
  jsonobj$add_forest(object$forests_mu)
  jsonobj$add_forest(object$forests_tau)
  if (object$model_params$include_variance_forest) {
    jsonobj$add_forest(object$forests_variance)
  }

  # Add metadata
  jsonobj$add_scalar(
    "num_numeric_vars",
    object$train_set_metadata$num_numeric_vars
  )
  jsonobj$add_scalar(
    "num_ordered_cat_vars",
    object$train_set_metadata$num_ordered_cat_vars
  )
  jsonobj$add_scalar(
    "num_unordered_cat_vars",
    object$train_set_metadata$num_unordered_cat_vars
  )
  if (object$train_set_metadata$num_numeric_vars > 0) {
    jsonobj$add_string_vector(
      "numeric_vars",
      object$train_set_metadata$numeric_vars
    )
  }
  if (object$train_set_metadata$num_ordered_cat_vars > 0) {
    jsonobj$add_string_vector(
      "ordered_cat_vars",
      object$train_set_metadata$ordered_cat_vars
    )
    jsonobj$add_string_list(
      "ordered_unique_levels",
      object$train_set_metadata$ordered_unique_levels
    )
  }
  if (object$train_set_metadata$num_unordered_cat_vars > 0) {
    jsonobj$add_string_vector(
      "unordered_cat_vars",
      object$train_set_metadata$unordered_cat_vars
    )
    jsonobj$add_string_list(
      "unordered_unique_levels",
      object$train_set_metadata$unordered_unique_levels
    )
  }

  # Add global parameters
  jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale)
  jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean)
  jsonobj$add_boolean("standardize", object$model_params$standardize)
  jsonobj$add_scalar("initial_sigma2", object$model_params$initial_sigma2)
  jsonobj$add_boolean(
    "sample_sigma2_global",
    object$model_params$sample_sigma2_global
  )
  jsonobj$add_boolean(
    "sample_sigma2_leaf_mu",
    object$model_params$sample_sigma2_leaf_mu
  )
  jsonobj$add_boolean(
    "sample_sigma2_leaf_tau",
    object$model_params$sample_sigma2_leaf_tau
  )
  jsonobj$add_boolean(
    "include_variance_forest",
    object$model_params$include_variance_forest
  )
  jsonobj$add_string(
    "propensity_covariate",
    object$model_params$propensity_covariate
  )
  jsonobj$add_boolean("has_rfx", object$model_params$has_rfx)
  jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis)
  jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis)
  jsonobj$add_boolean(
    "multivariate_treatment",
    object$model_params$multivariate_treatment
  )
  jsonobj$add_boolean("adaptive_coding", object$model_params$adaptive_coding)
  jsonobj$add_boolean(
    "internal_propensity_model",
    object$model_params$internal_propensity_model
  )
  jsonobj$add_scalar("num_gfr", object$model_params$num_gfr)
  jsonobj$add_scalar("num_burnin", object$model_params$num_burnin)
  jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc)
  jsonobj$add_scalar("num_samples", object$model_params$num_samples)
  jsonobj$add_scalar("keep_every", object$model_params$keep_every)
  jsonobj$add_scalar("num_chains", object$model_params$num_chains)
  jsonobj$add_scalar("num_covariates", object$model_params$num_covariates)
  jsonobj$add_boolean(
    "probit_outcome_model",
    object$model_params$probit_outcome_model
  )
  if (object$model_params$sample_sigma2_global) {
    jsonobj$add_vector(
      "sigma2_global_samples",
      object$sigma2_global_samples,
      "parameters"
    )
  }
  if (object$model_params$sample_sigma2_leaf_mu) {
    jsonobj$add_vector(
      "sigma2_leaf_mu_samples",
      object$sigma2_leaf_mu_samples,
      "parameters"
    )
  }
  if (object$model_params$sample_sigma2_leaf_tau) {
    jsonobj$add_vector(
      "sigma2_leaf_tau_samples",
      object$sigma2_leaf_tau_samples,
      "parameters"
    )
  }
  if (object$model_params$adaptive_coding) {
    jsonobj$add_vector("b_1_samples", object$b_1_samples, "parameters")
    jsonobj$add_vector("b_0_samples", object$b_0_samples, "parameters")
  }

  # Add random effects (if present)
  if (object$model_params$has_rfx) {
    jsonobj$add_random_effects(object$rfx_samples)
    jsonobj$add_string_vector(
      "rfx_unique_group_ids",
      object$rfx_unique_group_ids
    )
  }
  jsonobj$add_string(
    "rfx_model_spec",
    object$model_params$rfx_model_spec
  )

  # Add propensity model (if it exists)
  if (object$model_params$internal_propensity_model) {
    bart_propensity_string <- saveBARTModelToJsonString(
      object$bart_propensity_model
    )
    jsonobj$add_string("bart_propensity_model", bart_propensity_string)
  }

  # Add covariate preprocessor metadata
  preprocessor_metadata_string <- savePreprocessorToJsonString(
    object$train_set_metadata
  )
  jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string)

  return(jsonobj)
}

#' Convert the persistent aspects of a BCF model to (in-memory) JSON and save to a file
#'
#' @param object Object of type `bcfmodel` containing draws of a Bayesian causal forest model and associated sampling outputs.
#' @param filename String of filepath, must end in ".json"
#'
#' @return No return value; the model's JSON representation is written to `filename`
#' @export
#'
#' @examples
#' n <- 500
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' mu_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' pi_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
#' )
#' tau_x <- (
#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
#' )
#' Z <- rbinom(n, 1, pi_x)
#' E_XZ <- mu_x + Z*tau_x
#' snr <- 3
#' rfx_group_ids <- rep(c(1,2), n %/% 2)
#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
#' rfx_basis <- cbind(1, runif(n, -1, 1))
#' rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' pi_test <- pi_x[test_inds]
#' pi_train <- pi_x[train_inds]
#' Z_test <- Z[test_inds]
#' Z_train <- Z[train_inds]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' mu_test <- mu_x[test_inds]
#' mu_train <- mu_x[train_inds]
#' tau_test <- tau_x[test_inds]
#' tau_train <- tau_x[train_inds]
#' rfx_group_ids_test <- rfx_group_ids[test_inds]
#' rfx_group_ids_train <- rfx_group_ids[train_inds]
#' rfx_basis_test <- rfx_basis[test_inds,]
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' mu_params <- list(sample_sigma2_leaf = TRUE)
#' tau_params <- list(sample_sigma2_leaf = FALSE)
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                  propensity_train = pi_train,
#'                  rfx_group_ids_train = rfx_group_ids_train,
#'                  rfx_basis_train = rfx_basis_train, X_test = X_test,
#'                  Z_test = Z_test, propensity_test = pi_test,
#'                  rfx_group_ids_test = rfx_group_ids_test,
#'                  rfx_basis_test = rfx_basis_test,
#'                  num_gfr = 10, num_burnin = 0, num_mcmc = 10,
#'                  prognostic_forest_params = mu_params,
#'                  treatment_effect_forest_params = tau_params)
#' tmpjson <- tempfile(fileext = ".json")
#' saveBCFModelToJsonFile(bcf_model, file.path(tmpjson))
#' unlink(tmpjson)
saveBCFModelToJsonFile <- function(object, filename) {
  # Convert to Json
  jsonobj <- saveBCFModelToJson(object)

  # Save to file
  jsonobj$save_file(filename)
}

#' Convert the persistent aspects of a BCF model to (in-memory) JSON string
#'
#' @param object Object of type `bcfmodel` containing draws of a Bayesian causal forest model and associated sampling outputs.
#' @return JSON string
#' @export
#'
#' @examples
#' n <- 500
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' mu_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' pi_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
#' )
#' tau_x <- (
#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
#' )
#' Z <- rbinom(n, 1, pi_x)
#' E_XZ <- mu_x + Z*tau_x
#' snr <- 3
#' rfx_group_ids <- rep(c(1,2), n %/% 2)
#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
#' rfx_basis <- cbind(1, runif(n, -1, 1))
#' rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' pi_test <- pi_x[test_inds]
#' pi_train <- pi_x[train_inds]
#' Z_test <- Z[test_inds]
#' Z_train <- Z[train_inds]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' mu_test <- mu_x[test_inds]
#' mu_train <- mu_x[train_inds]
#' tau_test <- tau_x[test_inds]
#' tau_train <- tau_x[train_inds]
#' rfx_group_ids_test <- rfx_group_ids[test_inds]
#' rfx_group_ids_train <- rfx_group_ids[train_inds]
#' rfx_basis_test <- rfx_basis[test_inds,]
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' mu_params <- list(sample_sigma2_leaf = TRUE)
#' tau_params <- list(sample_sigma2_leaf = FALSE)
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                  propensity_train = pi_train,
#'                  rfx_group_ids_train = rfx_group_ids_train,
#'                  rfx_basis_train = rfx_basis_train, X_test = X_test,
#'                  Z_test = Z_test, propensity_test = pi_test,
#'                  rfx_group_ids_test = rfx_group_ids_test,
#'                  rfx_basis_test = rfx_basis_test,
#'                  num_gfr = 10, num_burnin = 0, num_mcmc = 10,
#'                  prognostic_forest_params = mu_params,
#'                  treatment_effect_forest_params = tau_params)
#' saveBCFModelToJsonString(bcf_model)
saveBCFModelToJsonString <- function(object) {
  # Convert to Json
  jsonobj <- saveBCFModelToJson(object)

  # Dump to string
  return(jsonobj$return_json_string())
}

#' Convert an (in-memory) JSON representation of a BCF model to a BCF model object
#' which can be used for prediction, etc...
#'
#' @param json_object Object of type `CppJson` containing the JSON representation of a BCF model
#'
#' @return Object of type `bcfmodel`
#' @export
#'
#' @examples
#' n <- 500
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' mu_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' pi_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
#' )
#' tau_x <- (
#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
#' )
#' Z <- rbinom(n, 1, pi_x)
#' E_XZ <- mu_x + Z*tau_x
#' snr <- 3
#' rfx_group_ids <- rep(c(1,2), n %/% 2)
#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
#' rfx_basis <- cbind(1, runif(n, -1, 1))
#' rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' pi_test <- pi_x[test_inds]
#' pi_train <- pi_x[train_inds]
#' Z_test <- Z[test_inds]
#' Z_train <- Z[train_inds]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' mu_test <- mu_x[test_inds]
#' mu_train <- mu_x[train_inds]
#' tau_test <- tau_x[test_inds]
#' tau_train <- tau_x[train_inds]
#' rfx_group_ids_test <- rfx_group_ids[test_inds]
#' rfx_group_ids_train <- rfx_group_ids[train_inds]
#' rfx_basis_test <- rfx_basis[test_inds,]
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' mu_params <- list(sample_sigma2_leaf = TRUE)
#' tau_params <- list(sample_sigma2_leaf = FALSE)
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                  propensity_train = pi_train,
#'                  rfx_group_ids_train = rfx_group_ids_train,
#'                  rfx_basis_train = rfx_basis_train, X_test = X_test,
#'                  Z_test = Z_test, propensity_test = pi_test,
#'                  rfx_group_ids_test = rfx_group_ids_test,
#'                  rfx_basis_test = rfx_basis_test,
#'                  num_gfr = 10, num_burnin = 0, num_mcmc = 10,
#'                  prognostic_forest_params = mu_params,
#'                  treatment_effect_forest_params = tau_params)
#' bcf_json <- saveBCFModelToJson(bcf_model)
#' bcf_model_roundtrip <- createBCFModelFromJson(bcf_json)
createBCFModelFromJson <- function(json_object) {
  # Initialize the BCF model
  output <- list()

  # Unpack the forests
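  # Forests are keyed by the order they were added in saveBCFModelToJson:
  # forest_0 = prognostic (mu), forest_1 = treatment effect (tau),
  # forest_2 = variance (only if a variance forest was included)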
  output[["forests_mu"]] <- loadForestContainerJson(json_object, "forest_0")
  output[["forests_tau"]] <- loadForestContainerJson(json_object, "forest_1")
  include_variance_forest <- json_object$get_boolean(
    "include_variance_forest"
  )
  if (include_variance_forest) {
    output[["forests_variance"]] <- loadForestContainerJson(
      json_object,
      "forest_2"
    )
  }

  # Unpack metadata
  train_set_metadata <- list()
  train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar(
    "num_numeric_vars"
  )
  train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar(
    "num_ordered_cat_vars"
  )
  train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar(
    "num_unordered_cat_vars"
  )
  if (train_set_metadata[["num_numeric_vars"]] > 0) {
    train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector(
      "numeric_vars"
    )
  }
  if (train_set_metadata[["num_ordered_cat_vars"]] > 0) {
    train_set_metadata[[
      "ordered_cat_vars"
    ]] <- json_object$get_string_vector("ordered_cat_vars")
    train_set_metadata[[
      "ordered_unique_levels"
    ]] <- json_object$get_string_list(
      "ordered_unique_levels",
      train_set_metadata[["ordered_cat_vars"]]
    )
  }
  if (train_set_metadata[["num_unordered_cat_vars"]] > 0) {
    train_set_metadata[[
      "unordered_cat_vars"
    ]] <- json_object$get_string_vector("unordered_cat_vars")
    train_set_metadata[[
      "unordered_unique_levels"
    ]] <- json_object$get_string_list(
      "unordered_unique_levels",
      train_set_metadata[["unordered_cat_vars"]]
    )
  }
  output[["train_set_metadata"]] <- train_set_metadata

  # Unpack model params
  model_params <- list()
  model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale")
  model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean")
  model_params[["standardize"]] <- json_object$get_boolean("standardize")
  model_params[["initial_sigma2"]] <- json_object$get_scalar("initial_sigma2")
  model_params[["sample_sigma2_global"]] <- json_object$get_boolean(
    "sample_sigma2_global"
  )
  model_params[["sample_sigma2_leaf_mu"]] <- json_object$get_boolean(
    "sample_sigma2_leaf_mu"
  )
  model_params[["sample_sigma2_leaf_tau"]] <- json_object$get_boolean(
    "sample_sigma2_leaf_tau"
  )
  model_params[["include_variance_forest"]] <- include_variance_forest
  model_params[["propensity_covariate"]] <- json_object$get_string(
    "propensity_covariate"
  )
  model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx")
  model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis")
  model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis")
  model_params[["adaptive_coding"]] <- json_object$get_boolean(
    "adaptive_coding"
  )
  model_params[["multivariate_treatment"]] <- json_object$get_boolean(
    "multivariate_treatment"
  )
  model_params[["internal_propensity_model"]] <- json_object$get_boolean(
    "internal_propensity_model"
  )
  model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr")
  model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin")
  model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc")
  model_params[["num_samples"]] <- json_object$get_scalar("num_samples")
  model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates")
  model_params[["probit_outcome_model"]] <- json_object$get_boolean(
    "probit_outcome_model"
  )
  model_params[["rfx_model_spec"]] <- json_object$get_string(
    "rfx_model_spec"
  )
  output[["model_params"]] <- model_params

  # Unpack sampled parameters
  if (model_params[["sample_sigma2_global"]]) {
    output[["sigma2_global_samples"]] <- json_object$get_vector(
      "sigma2_global_samples",
      "parameters"
    )
  }
  if (model_params[["sample_sigma2_leaf_mu"]]) {
    output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector(
      "sigma2_leaf_mu_samples",
      "parameters"
    )
  }
  if (model_params[["sample_sigma2_leaf_tau"]]) {
    output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector(
      "sigma2_leaf_tau_samples",
      "parameters"
    )
  }
  if (model_params[["adaptive_coding"]]) {
    output[["b_1_samples"]] <- json_object$get_vector(
      "b_1_samples",
      "parameters"
    )
    output[["b_0_samples"]] <- json_object$get_vector(
      "b_0_samples",
      "parameters"
    )
  }

  # Unpack random effects
  if (model_params[["has_rfx"]]) {
    output[["rfx_unique_group_ids"]] <- json_object$get_string_vector(
      "rfx_unique_group_ids"
    )
    output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0)
  }

  # Unpack propensity model (if it exists)
  if (model_params[["internal_propensity_model"]]) {
    bart_propensity_string <- json_object$get_string(
      "bart_propensity_model"
    )
    output[["bart_propensity_model"]] <- createBARTModelFromJsonString(
      bart_propensity_string
    )
  }

  # Unpack covariate preprocessor
  preprocessor_metadata_string <- json_object$get_string(
    "preprocessor_metadata"
  )
  output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
    preprocessor_metadata_string
  )

  class(output) <- "bcfmodel"
  return(output)
}

#' Convert a JSON file containing sample information on a trained BCF model
#' to a BCF model object which can be used for prediction, etc...
#'
#' @param json_filename String of filepath, must end in ".json"
#'
#' @return Object of type `bcfmodel`
#' @export
#'
#' @examples
#' n <- 500
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' mu_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' pi_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
#' )
#' tau_x <- (
#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
#' )
#' Z <- rbinom(n, 1, pi_x)
#' E_XZ <- mu_x + Z*tau_x
#' snr <- 3
#' rfx_group_ids <- rep(c(1,2), n %/% 2)
#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
#' rfx_basis <- cbind(1, runif(n, -1, 1))
#' rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' pi_test <- pi_x[test_inds]
#' pi_train <- pi_x[train_inds]
#' Z_test <- Z[test_inds]
#' Z_train <- Z[train_inds]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' mu_test <- mu_x[test_inds]
#' mu_train <- mu_x[train_inds]
#' tau_test <- tau_x[test_inds]
#' tau_train <- tau_x[train_inds]
#' rfx_group_ids_test <- rfx_group_ids[test_inds]
#' rfx_group_ids_train <- rfx_group_ids[train_inds]
#' rfx_basis_test <- rfx_basis[test_inds,]
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' mu_params <- list(sample_sigma2_leaf = TRUE)
#' tau_params <- list(sample_sigma2_leaf = FALSE)
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                  propensity_train = pi_train,
#'                  rfx_group_ids_train = rfx_group_ids_train,
#'                  rfx_basis_train = rfx_basis_train, X_test = X_test,
#'                  Z_test = Z_test, propensity_test = pi_test,
#'                  rfx_group_ids_test = rfx_group_ids_test,
#'                  rfx_basis_test = rfx_basis_test,
#'                  num_gfr = 10, num_burnin = 0, num_mcmc = 10,
#'                  prognostic_forest_params = mu_params,
#'                  treatment_effect_forest_params = tau_params)
#' tmpjson <- tempfile(fileext = ".json")
#' saveBCFModelToJsonFile(bcf_model, file.path(tmpjson))
#' bcf_model_roundtrip <- createBCFModelFromJsonFile(file.path(tmpjson))
#' unlink(tmpjson)
createBCFModelFromJsonFile <- function(json_filename) {
  # Load a `CppJson` object from file
  bcf_json <- createCppJsonFile(json_filename)

  # Create and return the BCF object
  bcf_object <- createBCFModelFromJson(bcf_json)

  return(bcf_object)
}

#' Convert a JSON string containing sample information on a trained BCF model
#' to a BCF model object which can be used for prediction, etc...
#'
#' @param json_string JSON string dump
#'
#' @return Object of type `bcfmodel`
#' @export
#'
#' @examples
#' n <- 500
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' mu_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' pi_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
#' )
#' tau_x <- (
#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
#' )
#' Z <- rbinom(n, 1, pi_x)
#' E_XZ <- mu_x + Z*tau_x
#' snr <- 3
#' rfx_group_ids <- rep(c(1,2), n %/% 2)
#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
#' rfx_basis <- cbind(1, runif(n, -1, 1))
#' rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' pi_test <- pi_x[test_inds]
#' pi_train <- pi_x[train_inds]
#' Z_test <- Z[test_inds]
#' Z_train <- Z[train_inds]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' mu_test <- mu_x[test_inds]
#' mu_train <- mu_x[train_inds]
#' tau_test <- tau_x[test_inds]
#' tau_train <- tau_x[train_inds]
#' rfx_group_ids_test <- rfx_group_ids[test_inds]
#' rfx_group_ids_train <- rfx_group_ids[train_inds]
#' rfx_basis_test <- rfx_basis[test_inds,]
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                  propensity_train = pi_train,
#'                  rfx_group_ids_train = rfx_group_ids_train,
#'                  rfx_basis_train = rfx_basis_train, X_test = X_test,
#'                  Z_test = Z_test, propensity_test = pi_test,
#'                  rfx_group_ids_test = rfx_group_ids_test,
#'                  rfx_basis_test = rfx_basis_test,
#'                  num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' bcf_json <- saveBCFModelToJsonString(bcf_model)
#' bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json)
createBCFModelFromJsonString <- function(json_string) {
  # Load a `CppJson` object from string
  bcf_json <- createCppJsonString(json_string)

  # Create and return the BCF object
  bcf_object <- createBCFModelFromJson(bcf_json)

  return(bcf_object)
}

#' Convert a list of (in-memory) JSON objects that represent BCF models to a single combined BCF model object
#' which can be used for prediction, etc...
#'
#' @param json_object_list List of objects of type `CppJson`, each containing the JSON representation of a BCF model
#'
#' @return Object of type `bcfmodel`
#' @export
#'
#' @examples
#' n <- 500
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' mu_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' pi_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
#' )
#' tau_x <- (
#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
#' )
#' Z <- rbinom(n, 1, pi_x)
#' E_XZ <- mu_x + Z*tau_x
#' snr <- 3
#' rfx_group_ids <- rep(c(1,2), n %/% 2)
#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
#' rfx_basis <- cbind(1, runif(n, -1, 1))
#' rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' pi_test <- pi_x[test_inds]
#' pi_train <- pi_x[train_inds]
#' Z_test <- Z[test_inds]
#' Z_train <- Z[train_inds]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' mu_test <- mu_x[test_inds]
#' mu_train <- mu_x[train_inds]
#' tau_test <- tau_x[test_inds]
#' tau_train <- tau_x[train_inds]
#' rfx_group_ids_test <- rfx_group_ids[test_inds]
#' rfx_group_ids_train <- rfx_group_ids[train_inds]
#' rfx_basis_test <- rfx_basis[test_inds,]
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                  propensity_train = pi_train,
#'                  rfx_group_ids_train = rfx_group_ids_train,
#'                  rfx_basis_train = rfx_basis_train, X_test = X_test,
#'                  Z_test = Z_test, propensity_test = pi_test,
#'                  rfx_group_ids_test = rfx_group_ids_test,
#'                  rfx_basis_test = rfx_basis_test,
#'                  num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' bcf_json_list <- list(saveBCFModelToJson(bcf_model))
#' bcf_model_roundtrip <- createBCFModelFromCombinedJson(bcf_json_list)
createBCFModelFromCombinedJson <- function(json_object_list) {
  # Initialize the BCF model
  output <- list()

  # For scalar / preprocessing details which aren't sample-dependent,
  # defer to the first json
  json_object_default <- json_object_list[[1]]

  # Unpack the forests
  output[["forests_mu"]] <- loadForestContainerCombinedJson(
    json_object_list,
    "forest_0"
  )
  output[["forests_tau"]] <- loadForestContainerCombinedJson(
    json_object_list,
    "forest_1"
  )
  include_variance_forest <- json_object_default$get_boolean(
    "include_variance_forest"
  )
  if (include_variance_forest) {
    output[["forests_variance"]] <- loadForestContainerCombinedJson(
      json_object_list,
      "forest_2"
    )
  }

  # Unpack metadata
  train_set_metadata <- list()
  train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar(
    "num_numeric_vars"
  )
  train_set_metadata[[
    "num_ordered_cat_vars"
  ]] <- json_object_default$get_scalar("num_ordered_cat_vars")
  train_set_metadata[[
    "num_unordered_cat_vars"
  ]] <- json_object_default$get_scalar("num_unordered_cat_vars")
  if (train_set_metadata[["num_numeric_vars"]] > 0) {
    train_set_metadata[[
      "numeric_vars"
    ]] <- json_object_default$get_string_vector("numeric_vars")
  }
  if (train_set_metadata[["num_ordered_cat_vars"]] > 0) {
    train_set_metadata[[
      "ordered_cat_vars"
    ]] <- json_object_default$get_string_vector("ordered_cat_vars")
    train_set_metadata[[
      "ordered_unique_levels"
    ]] <- json_object_default$get_string_list(
      "ordered_unique_levels",
      train_set_metadata[["ordered_cat_vars"]]
    )
  }
  if (train_set_metadata[["num_unordered_cat_vars"]] > 0) {
    train_set_metadata[[
      "unordered_cat_vars"
    ]] <- json_object_default$get_string_vector("unordered_cat_vars")
    train_set_metadata[[
      "unordered_unique_levels"
    ]] <- json_object_default$get_string_list(
      "unordered_unique_levels",
      train_set_metadata[["unordered_cat_vars"]]
    )
  }
  output[["train_set_metadata"]] <- train_set_metadata

  # Unpack model params
  model_params <- list()
  model_params[["outcome_scale"]] <- json_object_default$get_scalar(
    "outcome_scale"
  )
  model_params[["outcome_mean"]] <- json_object_default$get_scalar(
    "outcome_mean"
  )
  model_params[["standardize"]] <- json_object_default$get_boolean(
    "standardize"
  )
  model_params[["initial_sigma2"]] <- json_object_default$get_scalar(
    "initial_sigma2"
  )
  model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean(
    "sample_sigma2_global"
  )
  model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean(
    "sample_sigma2_leaf_mu"
  )
  model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean(
    "sample_sigma2_leaf_tau"
  )
  model_params[["include_variance_forest"]] <- include_variance_forest
  model_params[["propensity_covariate"]] <- json_object_default$get_string(
    "propensity_covariate"
  )
  model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx")
  model_params[["has_rfx_basis"]] <- json_object_default$get_boolean(
    "has_rfx_basis"
  )
  model_params[["num_rfx_basis"]] <- json_object_default$get_scalar(
    "num_rfx_basis"
  )
  model_params[["num_covariates"]] <- json_object_default$get_scalar(
    "num_covariates"
  )
  model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
  model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
  model_params[["adaptive_coding"]] <- json_object_default$get_boolean(
    "adaptive_coding"
  )
  model_params[["multivariate_treatment"]] <- json_object_default$get_boolean(
    "multivariate_treatment"
  )
  model_params[[
    "internal_propensity_model"
  ]] <- json_object_default$get_boolean("internal_propensity_model")
  model_params[["probit_outcome_model"]] <- json_object_default$get_boolean(
    "probit_outcome_model"
  )
  model_params[["rfx_model_spec"]] <- json_object_default$get_string(
    "rfx_model_spec"
  )

  # Combine values that are sample-specific
  for (i in 1:length(json_object_list)) {
    json_object <- json_object_list[[i]]
    if (i == 1) {
      model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr")
      model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin")
      model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc")
      model_params[["num_samples"]] <- json_object$get_scalar(
        "num_samples"
      )
    } else {
      model_params[["num_gfr"]] <- model_params[["num_gfr"]] +
        json_object$get_scalar("num_gfr")
      model_params[["num_burnin"]] <- model_params[["num_burnin"]] +
        json_object$get_scalar("num_burnin")
      model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] +
        json_object$get_scalar("num_mcmc")
      model_params[["num_samples"]] <- model_params[["num_samples"]] +
        json_object$get_scalar("num_samples")
    }
  }
  output[["model_params"]] <- model_params

  # Unpack sampled parameters
  if (model_params[["sample_sigma2_global"]]) {
    for (i in 1:length(json_object_list)) {
      json_object <- json_object_list[[i]]
      if (i == 1) {
        output[["sigma2_global_samples"]] <- json_object$get_vector(
          "sigma2_global_samples",
          "parameters"
        )
      } else {
        output[["sigma2_global_samples"]] <- c(
          output[["sigma2_global_samples"]],
          json_object$get_vector(
            "sigma2_global_samples",
            "parameters"
          )
        )
      }
    }
  }
  if (model_params[["sample_sigma2_leaf_mu"]]) {
    for (i in 1:length(json_object_list)) {
      json_object <- json_object_list[[i]]
      if (i == 1) {
        output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector(
          "sigma2_leaf_mu_samples",
          "parameters"
        )
      } else {
        output[["sigma2_leaf_mu_samples"]] <- c(
          output[["sigma2_leaf_mu_samples"]],
          json_object$get_vector(
            "sigma2_leaf_mu_samples",
            "parameters"
          )
        )
      }
    }
  }
  if (model_params[["sample_sigma2_leaf_tau"]]) {
    for (i in 1:length(json_object_list)) {
      json_object <- json_object_list[[i]]
      if (i == 1) {
        output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector(
          "sigma2_leaf_tau_samples",
          "parameters"
        )
      } else {
        output[["sigma2_leaf_tau_samples"]] <- c(
          output[["sigma2_leaf_tau_samples"]],
          json_object$get_vector(
            "sigma2_leaf_tau_samples",
            "parameters"
          )
        )
      }
    }
  }
  if (model_params[["sample_sigma2_leaf_tau"]]) {
    for (i in 1:length(json_object_list)) {
      json_object <- json_object_list[[i]]
      if (i == 1) {
        output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector(
          "sigma2_leaf_tau_samples",
          "parameters"
        )
      } else {
        output[["sigma2_leaf_tau_samples"]] <- c(
          output[["sigma2_leaf_tau_samples"]],
          json_object$get_vector(
            "sigma2_leaf_tau_samples",
            "parameters"
          )
        )
      }
    }
  }
  if (model_params[["adaptive_coding"]]) {
    for (i in 1:length(json_object_list)) {
      json_object <- json_object_list[[i]]
      if (i == 1) {
        output[["b_1_samples"]] <- json_object$get_vector(
          "b_1_samples",
          "parameters"
        )
        output[["b_0_samples"]] <- json_object$get_vector(
          "b_0_samples",
          "parameters"
        )
      } else {
        output[["b_1_samples"]] <- c(
          output[["b_1_samples"]],
          json_object$get_vector("b_1_samples", "parameters")
        )
        output[["b_0_samples"]] <- c(
          output[["b_0_samples"]],
          json_object$get_vector("b_0_samples", "parameters")
        )
      }
    }
  }

  # Unpack random effects
  if (model_params[["has_rfx"]]) {
    output[[
      "rfx_unique_group_ids"
    ]] <- json_object_default$get_string_vector("rfx_unique_group_ids")
    output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(
      json_object_list,
      0
    )
  }

  # Unpack covariate preprocessor
  preprocessor_metadata_string <- json_object_default$get_string(
    "preprocessor_metadata"
  )
  output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
    preprocessor_metadata_string
  )

  class(output) <- "bcfmodel"
  return(output)
}
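
# A hypothetical sketch (placeholder names, not package objects) of combining
# two independently sampled BCF models, e.g. parallel chains run with
# externally supplied propensities, into a single model for prediction:
#   json_list <- list(saveBCFModelToJson(bcf_chain_1), saveBCFModelToJson(bcf_chain_2))
#   combined <- createBCFModelFromCombinedJson(json_list)
#   preds <- predict(combined, X_test, Z_test, pi_test)
# The string-based variant below accepts serialized JSON strings instead, which
# suits parallel workers that can only return strings.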

#' Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object
#' which can be used for prediction, etc...
#'
#' @param json_string_list List of JSON strings, each parseable into an object of type `CppJson` containing the JSON representation of a BCF model
#'
#' @return Object of type `bcfmodel`
#' @export
#'
#' @examples
#' n <- 500
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' mu_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' pi_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
#' )
#' tau_x <- (
#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
#' )
#' Z <- rbinom(n, 1, pi_x)
#' E_XZ <- mu_x + Z*tau_x
#' snr <- 3
#' rfx_group_ids <- rep(c(1,2), n %/% 2)
#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
#' rfx_basis <- cbind(1, runif(n, -1, 1))
#' rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
#' y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' pi_test <- pi_x[test_inds]
#' pi_train <- pi_x[train_inds]
#' Z_test <- Z[test_inds]
#' Z_train <- Z[train_inds]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' mu_test <- mu_x[test_inds]
#' mu_train <- mu_x[train_inds]
#' tau_test <- tau_x[test_inds]
#' tau_train <- tau_x[train_inds]
#' rfx_group_ids_test <- rfx_group_ids[test_inds]
#' rfx_group_ids_train <- rfx_group_ids[train_inds]
#' rfx_basis_test <- rfx_basis[test_inds,]
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                  propensity_train = pi_train,
#'                  rfx_group_ids_train = rfx_group_ids_train,
#'                  rfx_basis_train = rfx_basis_train, X_test = X_test,
#'                  Z_test = Z_test, propensity_test = pi_test,
#'                  rfx_group_ids_test = rfx_group_ids_test,
#'                  rfx_basis_test = rfx_basis_test,
#'                  num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model))
#' bcf_model_roundtrip <- createBCFModelFromCombinedJsonString(bcf_json_string_list)
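#' # The same mechanism merges several independently sampled models. A sketch,
#' # assuming a second model fit with the same data and settings as above:
#' bcf_model_2 <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                    propensity_train = pi_train,
#'                    rfx_group_ids_train = rfx_group_ids_train,
#'                    rfx_basis_train = rfx_basis_train, X_test = X_test,
#'                    Z_test = Z_test, propensity_test = pi_test,
#'                    rfx_group_ids_test = rfx_group_ids_test,
#'                    rfx_basis_test = rfx_basis_test,
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' bcf_models_combined <- createBCFModelFromCombinedJsonString(
#'     list(saveBCFModelToJsonString(bcf_model),
#'          saveBCFModelToJsonString(bcf_model_2))
#' )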
createBCFModelFromCombinedJsonString <- function(json_string_list) {
  # Initialize the BCF model
  output <- list()

  # Convert JSON strings
  json_object_list <- list()
  for (i in 1:length(json_string_list)) {
    json_string <- json_string_list[[i]]
    json_object_list[[i]] <- createCppJsonString(json_string)
    # Add runtime check for separately serialized propensity models
    # We don't support merging BCF models with independent propensity models
    # this way at the moment
    if (json_object_list[[i]]$get_boolean("internal_propensity_model")) {
      stop(
        "Combining separate BCF models with cached internal propensity models is currently unsupported. To make this work, please first train a propensity model and then pass the propensities as data to the separate BCF models before sampling."
      )
    }
  }
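
  # A sketch of the workaround suggested in the error above (hypothetical
  # names; any propensity estimator may stand in for the logistic regression):
  #   propensity_fit <- glm(Z_train ~ X_train, family = binomial)
  #   pi_train <- predict(propensity_fit, type = "response")
  #   bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
  #                    propensity_train = pi_train)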

  # For scalar / preprocessing details which aren't sample-dependent,
  # defer to the first json
  json_object_default <- json_object_list[[1]]

  # Unpack the forests
  output[["forests_mu"]] <- loadForestContainerCombinedJson(
    json_object_list,
    "forest_0"
  )
  output[["forests_tau"]] <- loadForestContainerCombinedJson(
    json_object_list,
    "forest_1"
  )
  include_variance_forest <- json_object_default$get_boolean(
    "include_variance_forest"
  )
  if (include_variance_forest) {
    output[["forests_variance"]] <- loadForestContainerCombinedJson(
      json_object_list,
      "forest_2"
    )
  }

  # Unpack metadata
  train_set_metadata <- list()
  train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar(
    "num_numeric_vars"
  )
  train_set_metadata[[
    "num_ordered_cat_vars"
  ]] <- json_object_default$get_scalar("num_ordered_cat_vars")
  train_set_metadata[[
    "num_unordered_cat_vars"
  ]] <- json_object_default$get_scalar("num_unordered_cat_vars")
  if (train_set_metadata[["num_numeric_vars"]] > 0) {
    train_set_metadata[[
      "numeric_vars"
    ]] <- json_object_default$get_string_vector("numeric_vars")
  }
  if (train_set_metadata[["num_ordered_cat_vars"]] > 0) {
    train_set_metadata[[
      "ordered_cat_vars"
    ]] <- json_object_default$get_string_vector("ordered_cat_vars")
    train_set_metadata[[
      "ordered_unique_levels"
    ]] <- json_object_default$get_string_list(
      "ordered_unique_levels",
      train_set_metadata[["ordered_cat_vars"]]
    )
  }
  if (train_set_metadata[["num_unordered_cat_vars"]] > 0) {
    train_set_metadata[[
      "unordered_cat_vars"
    ]] <- json_object_default$get_string_vector("unordered_cat_vars")
    train_set_metadata[[
      "unordered_unique_levels"
    ]] <- json_object_default$get_string_list(
      "unordered_unique_levels",
      train_set_metadata[["unordered_cat_vars"]]
    )
  }
  output[["train_set_metadata"]] <- train_set_metadata

  # Unpack model params
  model_params <- list()
  model_params[["outcome_scale"]] <- json_object_default$get_scalar(
    "outcome_scale"
  )
  model_params[["outcome_mean"]] <- json_object_default$get_scalar(
    "outcome_mean"
  )
  model_params[["standardize"]] <- json_object_default$get_boolean(
    "standardize"
  )
  model_params[["initial_sigma2"]] <- json_object_default$get_scalar(
    "initial_sigma2"
  )
  model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean(
    "sample_sigma2_global"
  )
  model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean(
    "sample_sigma2_leaf_mu"
  )
  model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean(
    "sample_sigma2_leaf_tau"
  )
  model_params[["include_variance_forest"]] <- include_variance_forest
  model_params[["propensity_covariate"]] <- json_object_default$get_string(
    "propensity_covariate"
  )
  model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx")
  model_params[["has_rfx_basis"]] <- json_object_default$get_boolean(
    "has_rfx_basis"
  )
  model_params[["num_rfx_basis"]] <- json_object_default$get_scalar(
    "num_rfx_basis"
  )
  model_params[["num_covariates"]] <- json_object_default$get_scalar(
    "num_covariates"
  )
  model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
  model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
  model_params[["multivariate_treatment"]] <- json_object_default$get_boolean(
    "multivariate_treatment"
  )
  model_params[["adaptive_coding"]] <- json_object_default$get_boolean(
    "adaptive_coding"
  )
  model_params[[
    "internal_propensity_model"
  ]] <- json_object_default$get_boolean("internal_propensity_model")
  model_params[["probit_outcome_model"]] <- json_object_default$get_boolean(
    "probit_outcome_model"
  )
  model_params[["rfx_model_spec"]] <- json_object_default$get_string(
    "rfx_model_spec"
  )

  # Combine values that are sample-specific (iteration counts are summed
  # across the supplied models)
  for (i in 1:length(json_object_list)) {
    json_object <- json_object_list[[i]]
    if (i == 1) {
      model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr")
      model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin")
      model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc")
      model_params[["num_samples"]] <- json_object$get_scalar(
        "num_samples"
      )
    } else {
      model_params[["num_gfr"]] <- model_params[["num_gfr"]] +
        json_object$get_scalar("num_gfr")
      model_params[["num_burnin"]] <- model_params[["num_burnin"]] +
        json_object$get_scalar("num_burnin")
      model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] +
        json_object$get_scalar("num_mcmc")
      model_params[["num_samples"]] <- model_params[["num_samples"]] +
        json_object$get_scalar("num_samples")
    }
  }
  output[["model_params"]] <- model_params

  # Unpack sampled parameters
  if (model_params[["sample_sigma2_global"]]) {
    for (i in 1:length(json_object_list)) {
      json_object <- json_object_list[[i]]
      if (i == 1) {
        output[["sigma2_global_samples"]] <- json_object$get_vector(
          "sigma2_global_samples",
          "parameters"
        )
      } else {
        output[["sigma2_global_samples"]] <- c(
          output[["sigma2_global_samples"]],
          json_object$get_vector(
            "sigma2_global_samples",
            "parameters"
          )
        )
      }
    }
  }
  if (model_params[["sample_sigma2_leaf_mu"]]) {
    for (i in 1:length(json_object_list)) {
      json_object <- json_object_list[[i]]
      if (i == 1) {
        output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector(
          "sigma2_leaf_mu_samples",
          "parameters"
        )
      } else {
        output[["sigma2_leaf_mu_samples"]] <- c(
          output[["sigma2_leaf_mu_samples"]],
          json_object$get_vector(
            "sigma2_leaf_mu_samples",
            "parameters"
          )
        )
      }
    }
  }
  if (model_params[["sample_sigma2_leaf_tau"]]) {
    for (i in 1:length(json_object_list)) {
      json_object <- json_object_list[[i]]
      if (i == 1) {
        output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector(
          "sigma2_leaf_tau_samples",
          "parameters"
        )
      } else {
        output[["sigma2_leaf_tau_samples"]] <- c(
          output[["sigma2_leaf_tau_samples"]],
          json_object$get_vector(
            "sigma2_leaf_tau_samples",
            "parameters"
          )
        )
      }
    }
  }
  if (model_params[["adaptive_coding"]]) {
    for (i in 1:length(json_object_list)) {
      json_object <- json_object_list[[i]]
      if (i == 1) {
        output[["b_1_samples"]] <- json_object$get_vector(
          "b_1_samples",
          "parameters"
        )
        output[["b_0_samples"]] <- json_object$get_vector(
          "b_0_samples",
          "parameters"
        )
      } else {
        output[["b_1_samples"]] <- c(
          output[["b_1_samples"]],
          json_object$get_vector("b_1_samples", "parameters")
        )
        output[["b_0_samples"]] <- c(
          output[["b_0_samples"]],
          json_object$get_vector("b_0_samples", "parameters")
        )
      }
    }
  }

  # Unpack random effects
  if (model_params[["has_rfx"]]) {
    output[[
      "rfx_unique_group_ids"
    ]] <- json_object_default$get_string_vector("rfx_unique_group_ids")
    output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(
      json_object_list,
      0
    )
  }

  # Unpack covariate preprocessor
  preprocessor_metadata_string <- json_object_default$get_string(
    "preprocessor_metadata"
  )
  output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
    preprocessor_metadata_string
  )

  class(output) <- "bcfmodel"
  return(output)
}
