R/bart.R

#' Run the BART algorithm for supervised learning. 
#'
#' @param X_train Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. 
#' Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be 
#' preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, 
#' categorical columns stored as ordered factors will be passed as integers to the core algorithm, along with the metadata 
#' that the column is ordered categorical).
#' @param y_train Outcome to be modeled by the ensemble.
#' @param leaf_basis_train (Optional) Bases used to define a regression model `y ~ W` in 
#' each leaf of each regression tree. By default, BART assumes constant leaf node 
#' parameters, implicitly regressing on a constant basis of ones (i.e. `y ~ 1`).
#' @param rfx_group_ids_train (Optional) Group labels used for an additive random effects model.
#' @param rfx_basis_train (Optional) Basis for "random-slope" regression in an additive random effects model.
#' If `rfx_group_ids_train` is provided but `rfx_basis_train` is not, an intercept-only random effects model 
#' will be estimated.
#' @param X_test (Optional) Test set of covariates used to define "out of sample" evaluation data. 
#' May be provided either as a dataframe or a matrix, but the format of `X_test` must be consistent with 
#' that of `X_train`.
#' @param leaf_basis_test (Optional) Test set of bases used to define "out of sample" evaluation data. 
#' While a test set is optional, the structure of any provided test set must match that 
#' of the training set (i.e. if both `X_train` and `leaf_basis_train` are provided, then a test set must 
#' consist of `X_test` and `leaf_basis_test` with the same number of columns).
#' @param rfx_group_ids_test (Optional) Test set group labels used for an additive random effects model. 
#' We do not currently support test set evaluation for group labels that were not in the training set 
#' (though we plan to in the near future).
#' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model.
#' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
#' @param previous_model_json (Optional) JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: `NULL`.
#' @param previous_model_warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BART sampler. One-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 1`). Default: `NULL`.
#' @param general_params (Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
#'   - `cutpoint_grid_size` Maximum size of the "grid" of potential cutpoints to consider in the GFR algorithm. Default: `100`.
#'   - `standardize` Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: `TRUE`.
#'   - `sample_sigma2_global` Whether or not to update the `sigma^2` global error variance parameter based on `IG(sigma2_global_shape, sigma2_global_scale)`. Default: `TRUE`.
#'   - `sigma2_global_init` Starting value of global error variance parameter. Calibrated internally as `1.0*var(y_train)`, where `y_train` is the possibly standardized outcome, if not set.
#'   - `sigma2_global_shape` Shape parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Default: `0`.
#'   - `sigma2_global_scale` Scale parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Default: `0`.
#'   - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to `1/ncol(X_train)`.
#'   - `random_seed` Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
#'   - `keep_burnin` Whether or not "burnin" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`.
#'   - `keep_gfr` Whether or not "grow-from-root" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`.
#'   - `keep_every` How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default `1`. Setting `keep_every <- k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples.
#'   - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`.
#'   - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
#'
#' @param mean_forest_params (Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
#'   - `num_trees` Number of trees in the ensemble for the conditional mean model. Default: `200`. If `num_trees = 0`, the conditional mean is not modeled with a forest, and the function proceeds only if the variance forest has `num_trees > 0`.
#'   - `alpha` Prior probability of splitting for a tree of depth 0 in the mean model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `0.95`.
#'   - `beta` Exponent that decreases split probabilities for nodes of depth > 0 in the mean model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `2`.
#'   - `min_samples_leaf` Minimum allowable size of a leaf, in terms of training samples, in the mean model. Default: `5`.
#'   - `max_depth` Maximum depth of any tree in the ensemble in the mean model. Default: `10`. Can be overridden with `-1`, which does not enforce any depth limit.
#'   - `sample_sigma2_leaf` Whether or not to update the leaf scale variance parameter based on `IG(sigma2_leaf_shape, sigma2_leaf_scale)`. Cannot (currently) be used with a multivariate leaf basis; if set to `TRUE` when `ncol(leaf_basis_train) > 1`, it is ignored with a warning. Default: `TRUE`.
#'   - `sigma2_leaf_init` Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here.
#'   - `sigma2_leaf_shape` Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Default: `3`.
#'   - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
#'   - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
#'   - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
#'
#' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
#'   - `num_trees` Number of trees in the ensemble for the conditional variance model. Default: `0`. Variance is only modeled using a tree / forest if `num_trees > 0`.
#'   - `alpha` Prior probability of splitting for a tree of depth 0 in the variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `0.95`.
#'   - `beta` Exponent that decreases split probabilities for nodes of depth > 0 in the variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `2`.
#'   - `min_samples_leaf` Minimum allowable size of a leaf, in terms of training samples, in the variance model. Default: `5`.
#'   - `max_depth` Maximum depth of any tree in the ensemble in the variance model. Default: `10`. Can be overridden with `-1`, which does not enforce any depth limit.
#'   - `leaf_prior_calibration_param` Hyperparameter used to calibrate the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model. If `var_forest_prior_shape` and `var_forest_prior_scale` are not set below, this calibration parameter is used to set these values to `num_trees / leaf_prior_calibration_param^2 + 0.5` and `num_trees / leaf_prior_calibration_param^2`, respectively. Default: `1.5`.
#'   - `var_forest_leaf_init` Starting value of the root forest prediction in the conditional (heteroskedastic) error variance model. Calibrated internally as `var(y_train)`, where `y_train` is the possibly standardized outcome, if not set.
#'   - `var_forest_prior_shape` Shape parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / leaf_prior_calibration_param^2 + 0.5` if not set.
#'   - `var_forest_prior_scale` Scale parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / leaf_prior_calibration_param^2` if not set.
#'   - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
#'   - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
#'   
#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
#' @export
#'
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' f_XW <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' noise_sd <- 1
#' y <- f_XW + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, 
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
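#'
#' # Optional variations on the call above (illustrative settings, not tuned):
#' # (1) add a heteroskedastic variance forest and thin the retained MCMC draws
#' bart_model_het <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
#'                        num_gfr = 10, num_burnin = 0, num_mcmc = 10,
#'                        general_params = list(keep_every = 2),
#'                        variance_forest_params = list(num_trees = 50))
#' # (2) continue sampling from a serialized model via `previous_model_json`
#' bart_model_json <- saveBARTModelToJsonString(bart_model)
#' bart_model_restart <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
#'                            num_gfr = 0, num_burnin = 0, num_mcmc = 10,
#'                            previous_model_json = bart_model_json,
#'                            previous_model_warmstart_sample_num = 1)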
bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train = NULL, 
                 rfx_basis_train = NULL, X_test = NULL, leaf_basis_test = NULL, 
                 rfx_group_ids_test = NULL, rfx_basis_test = NULL, 
                 num_gfr = 5, num_burnin = 0, num_mcmc = 100, 
                 previous_model_json = NULL, previous_model_warmstart_sample_num = NULL, 
                 general_params = list(), mean_forest_params = list(), 
                 variance_forest_params = list()) {
    # Update general BART parameters
    general_params_default <- list(
        cutpoint_grid_size = 100, standardize = TRUE, 
        sample_sigma2_global = TRUE, sigma2_global_init = NULL, 
        sigma2_global_shape = 0, sigma2_global_scale = 0, 
        variable_weights = NULL, random_seed = -1, 
        keep_burnin = FALSE, keep_gfr = FALSE, keep_every = 1, 
        num_chains = 1, verbose = FALSE
    )
    general_params_updated <- preprocessParams(
        general_params_default, general_params
    )
    
    # Update mean forest BART parameters
    mean_forest_params_default <- list(
        num_trees = 200, alpha = 0.95, beta = 2.0, 
        min_samples_leaf = 5, max_depth = 10, 
        sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL, 
        sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL, 
        keep_vars = NULL, drop_vars = NULL
    )
    mean_forest_params_updated <- preprocessParams(
        mean_forest_params_default, mean_forest_params
    )
    
    # Update variance forest BART parameters
    variance_forest_params_default <- list(
        num_trees = 0, alpha = 0.95, beta = 2.0, 
        min_samples_leaf = 5, max_depth = 10, 
        leaf_prior_calibration_param = 1.5, 
        var_forest_leaf_init = NULL,
        var_forest_prior_shape = NULL, 
        var_forest_prior_scale = NULL, 
        keep_vars = NULL, drop_vars = NULL
    )
    variance_forest_params_updated <- preprocessParams(
        variance_forest_params_default, variance_forest_params
    )
    
    ### Unpack all parameter values
    # 1. General parameters
    cutpoint_grid_size <- general_params_updated$cutpoint_grid_size
    standardize <- general_params_updated$standardize
    sample_sigma_global <- general_params_updated$sample_sigma2_global
    sigma2_init <- general_params_updated$sigma2_global_init
    a_global <- general_params_updated$sigma2_global_shape
    b_global <- general_params_updated$sigma2_global_scale
    variable_weights <- general_params_updated$variable_weights
    random_seed <- general_params_updated$random_seed
    keep_burnin <- general_params_updated$keep_burnin
    keep_gfr <- general_params_updated$keep_gfr
    keep_every <- general_params_updated$keep_every
    num_chains <- general_params_updated$num_chains
    verbose <- general_params_updated$verbose
    
    # 2. Mean forest parameters
    num_trees_mean <- mean_forest_params_updated$num_trees
    alpha_mean <- mean_forest_params_updated$alpha
    beta_mean <- mean_forest_params_updated$beta
    min_samples_leaf_mean <- mean_forest_params_updated$min_samples_leaf
    max_depth_mean <- mean_forest_params_updated$max_depth
    sample_sigma_leaf <- mean_forest_params_updated$sample_sigma2_leaf
    sigma_leaf_init <- mean_forest_params_updated$sigma2_leaf_init
    a_leaf <- mean_forest_params_updated$sigma2_leaf_shape
    b_leaf <- mean_forest_params_updated$sigma2_leaf_scale
    keep_vars_mean <- mean_forest_params_updated$keep_vars
    drop_vars_mean <- mean_forest_params_updated$drop_vars
    
    # 3. Variance forest parameters
    num_trees_variance <- variance_forest_params_updated$num_trees
    alpha_variance <- variance_forest_params_updated$alpha
    beta_variance <- variance_forest_params_updated$beta
    min_samples_leaf_variance <- variance_forest_params_updated$min_samples_leaf
    max_depth_variance <- variance_forest_params_updated$max_depth
    a_0 <- variance_forest_params_updated$leaf_prior_calibration_param
    variance_forest_init <- variance_forest_params_updated$var_forest_leaf_init
    a_forest <- variance_forest_params_updated$var_forest_prior_shape
    b_forest <- variance_forest_params_updated$var_forest_prior_scale
    keep_vars_variance <- variance_forest_params_updated$keep_vars
    drop_vars_variance <- variance_forest_params_updated$drop_vars
    
    # Check if there are enough GFR samples to seed num_chains samplers
    if (num_gfr > 0) {
        if (num_chains > num_gfr) {
            stop("num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains")
        }
    }
    
    # Override keep_gfr if there are no MCMC samples
    if (num_mcmc == 0) keep_gfr <- TRUE
    
    # Check if previous model JSON is provided and parse it if so
    has_prev_model <- !is.null(previous_model_json)
    if (has_prev_model) {
        previous_bart_model <- createBARTModelFromJsonString(previous_model_json)
        previous_y_bar <- previous_bart_model$model_params$outcome_mean
        previous_y_scale <- previous_bart_model$model_params$outcome_scale
        if (previous_bart_model$model_params$include_mean_forest) {
            previous_forest_samples_mean <- previous_bart_model$mean_forests
        } else previous_forest_samples_mean <- NULL
        if (previous_bart_model$model_params$include_variance_forest) {
            previous_forest_samples_variance <- previous_bart_model$variance_forests
        } else previous_forest_samples_variance <- NULL
        if (previous_bart_model$model_params$sample_sigma_global) {
            previous_global_var_samples <- previous_bart_model$sigma2_global_samples / (
                previous_y_scale*previous_y_scale
            )
        } else previous_global_var_samples <- NULL
        if (previous_bart_model$model_params$sample_sigma_leaf) {
            previous_leaf_var_samples <- previous_bart_model$sigma2_leaf_samples
        } else previous_leaf_var_samples <- NULL
        if (previous_bart_model$model_params$has_rfx) {
            previous_rfx_samples <- previous_bart_model$rfx_samples
        } else previous_rfx_samples <- NULL
        previous_model_num_samples <- previous_bart_model$model_params$num_samples
        if (is.null(previous_model_warmstart_sample_num)) {
            stop("`previous_model_warmstart_sample_num` must be provided whenever `previous_model_json` is specified")
        }
        if (previous_model_warmstart_sample_num > previous_model_num_samples) {
            stop("`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`")
        }
    } else {
        previous_y_bar <- NULL
        previous_y_scale <- NULL
        previous_global_var_samples <- NULL
        previous_leaf_var_samples <- NULL
        previous_rfx_samples <- NULL
        previous_forest_samples_mean <- NULL
        previous_forest_samples_variance <- NULL
        previous_model_num_samples <- 0
    }
    
    # Determine whether conditional mean, variance, or both will be modeled
    include_variance_forest <- num_trees_variance > 0
    include_mean_forest <- num_trees_mean > 0
    
    # Set the variance forest priors if not set
    if (include_variance_forest) {
        if (is.null(a_forest)) a_forest <- num_trees_variance / (a_0^2) + 0.5
        if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2)
    } else {
        a_forest <- 1.
        b_forest <- 1.
    }
    
    # Override tau sampling if there is no mean forest
    if (!include_mean_forest) sample_sigma_leaf <- FALSE
    
    # Variable weight preprocessing (and initialization if necessary)
    if (is.null(variable_weights)) {
        variable_weights <- rep(1/ncol(X_train), ncol(X_train))
    }
    if (any(variable_weights < 0)) {
        stop("variable_weights cannot have any negative weights")
    }

    # Check covariates are matrix or dataframe
    if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) {
        stop("X_train must be a matrix or dataframe")
    }
    if (!is.null(X_test)){
        if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) {
            stop("X_test must be a matrix or dataframe")
        }
    }
    num_cov_orig <- ncol(X_train)

    # Standardize the keep variable lists to numeric indices
    if (!is.null(keep_vars_mean)) {
        if (is.character(keep_vars_mean)) {
            if (!all(keep_vars_mean %in% names(X_train))) {
                stop("keep_vars_mean includes some variable names that are not in X_train")
            }
            variable_subset_mu <- unname(which(names(X_train) %in% keep_vars_mean))
        } else {
            if (any(keep_vars_mean > ncol(X_train))) {
                stop("keep_vars_mean includes some variable indices that exceed the number of columns in X_train")
            }
            if (any(keep_vars_mean < 0)) {
                stop("keep_vars_mean includes some negative variable indices")
            }
            variable_subset_mu <- keep_vars_mean
        }
    } else if ((is.null(keep_vars_mean)) && (!is.null(drop_vars_mean))) {
        if (is.character(drop_vars_mean)) {
            if (!all(drop_vars_mean %in% names(X_train))) {
                stop("drop_vars_mean includes some variable names that are not in X_train")
            }
            variable_subset_mean <- unname(which(!(names(X_train) %in% drop_vars_mean)))
        } else {
            if (any(drop_vars_mean > ncol(X_train))) {
                stop("drop_vars_mean includes some variable indices that exceed the number of columns in X_train")
            }
            if (any(drop_vars_mean < 0)) {
                stop("drop_vars_mean includes some negative variable indices")
            }
            variable_subset_mean <- (1:ncol(X_train))[!(1:ncol(X_train) %in% drop_vars_mean)]
        }
    } else {
        variable_subset_mean <- 1:ncol(X_train)
    }
    if (!is.null(keep_vars_variance)) {
        if (is.character(keep_vars_variance)) {
            if (!all(keep_vars_variance %in% names(X_train))) {
                stop("keep_vars_variance includes some variable names that are not in X_train")
            }
            variable_subset_variance <- unname(which(names(X_train) %in% keep_vars_variance))
        } else {
            if (any(keep_vars_variance > ncol(X_train))) {
                stop("keep_vars_variance includes some variable indices that exceed the number of columns in X_train")
            }
            if (any(keep_vars_variance < 0)) {
                stop("keep_vars_variance includes some negative variable indices")
            }
            variable_subset_variance <- keep_vars_variance
        }
    } else if ((is.null(keep_vars_variance)) && (!is.null(drop_vars_variance))) {
        if (is.character(drop_vars_variance)) {
            if (!all(drop_vars_variance %in% names(X_train))) {
                stop("drop_vars_variance includes some variable names that are not in X_train")
            }
            variable_subset_variance <- unname(which(!(names(X_train) %in% drop_vars_variance)))
        } else {
            if (any(drop_vars_variance > ncol(X_train))) {
                stop("drop_vars_variance includes some variable indices that exceed the number of columns in X_train")
            }
            if (any(drop_vars_variance < 0)) {
                stop("drop_vars_variance includes some negative variable indices")
            }
            variable_subset_variance <- (1:ncol(X_train))[!(1:ncol(X_train) %in% drop_vars_variance)]
        }
    } else {
        variable_subset_variance <- 1:ncol(X_train)
    }
    
    # Preprocess covariates
    if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) {
        stop("X_train must be a matrix or dataframe")
    }
    if (!is.null(X_test)){
        if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) {
            stop("X_test must be a matrix or dataframe")
        }
    }
    if (ncol(X_train) != length(variable_weights)) {
        stop("length(variable_weights) must equal ncol(X_train)")
    }
    train_cov_preprocess_list <- preprocessTrainData(X_train)
    X_train_metadata <- train_cov_preprocess_list$metadata
    X_train <- train_cov_preprocess_list$data
    original_var_indices <- X_train_metadata$original_var_indices
    feature_types <- X_train_metadata$feature_types
    if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)
    
    # Update variable weights
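    # (one-hot encoding may expand a single categorical covariate into several
    # columns; each original variable's weight is split evenly across the columns
    # it generated, and variables outside the keep/drop subset get weight 0)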
    variable_weights_mean <- variable_weights_variance <- variable_weights
    variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x))
    if (include_mean_forest) {
        variable_weights_mean <- variable_weights_mean[original_var_indices]*variable_weights_adj
        variable_weights_mean[!(original_var_indices %in% variable_subset_mean)] <- 0
    }
    if (include_variance_forest) {
        variable_weights_variance <- variable_weights_variance[original_var_indices]*variable_weights_adj
        variable_weights_variance[!(original_var_indices %in% variable_subset_variance)] <- 0
    }

    # Convert all input data to matrices if not already converted
    if ((is.null(dim(leaf_basis_train))) && (!is.null(leaf_basis_train))) {
        leaf_basis_train <- as.matrix(leaf_basis_train)
    }
    if ((is.null(dim(leaf_basis_test))) && (!is.null(leaf_basis_test))) {
        leaf_basis_test <- as.matrix(leaf_basis_test)
    }
    if ((is.null(dim(rfx_basis_train))) && (!is.null(rfx_basis_train))) {
        rfx_basis_train <- as.matrix(rfx_basis_train)
    }
    if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) {
        rfx_basis_test <- as.matrix(rfx_basis_test)
    }
    
    # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
    has_rfx <- FALSE
    has_rfx_test <- FALSE
    if (!is.null(rfx_group_ids_train)) {
        group_ids_factor <- factor(rfx_group_ids_train)
        rfx_group_ids_train <- as.integer(group_ids_factor)
        has_rfx <- TRUE
        if (!is.null(rfx_group_ids_test)) {
            group_ids_factor_test <- factor(rfx_group_ids_test, levels = levels(group_ids_factor))
            if (sum(is.na(group_ids_factor_test)) > 0) {
                stop("All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train")
            }
            rfx_group_ids_test <- as.integer(group_ids_factor_test)
            has_rfx_test <- TRUE
        }
    }
    
    # Data consistency checks
    if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) {
        stop("X_train and X_test must have the same number of columns")
    }
    if ((!is.null(leaf_basis_test)) && (ncol(leaf_basis_test) != ncol(leaf_basis_train))) {
        stop("leaf_basis_train and leaf_basis_test must have the same number of columns")
    }
    if ((!is.null(leaf_basis_train)) && (nrow(leaf_basis_train) != nrow(X_train))) {
        stop("leaf_basis_train and X_train must have the same number of rows")
    }
    if ((!is.null(leaf_basis_test)) && (nrow(leaf_basis_test) != nrow(X_test))) {
        stop("leaf_basis_test and X_test must have the same number of rows")
    }
    if (nrow(X_train) != length(y_train)) {
        stop("X_train and y_train must have the same number of observations")
    }
    if ((!is.null(rfx_basis_test)) && (ncol(rfx_basis_test) != ncol(rfx_basis_train))) {
        stop("rfx_basis_train and rfx_basis_test must have the same number of columns")
    }
    if (!is.null(rfx_group_ids_train)) {
        if (!is.null(rfx_group_ids_test)) {
            if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) {
                stop("rfx_basis_train is provided but rfx_basis_test is not provided")
            }
        }
    }
    
    # Fill in the rfx basis as a vector of 1s (random intercept) if a basis is not provided
    has_basis_rfx <- FALSE
    num_basis_rfx <- 0
    if (has_rfx) {
        if (is.null(rfx_basis_train)) {
            rfx_basis_train <- matrix(rep(1,nrow(X_train)), nrow = nrow(X_train), ncol = 1)
        } else {
            has_basis_rfx <- TRUE
            num_basis_rfx <- ncol(rfx_basis_train)
        }
        num_rfx_groups <- length(unique(rfx_group_ids_train))
        num_rfx_components <- ncol(rfx_basis_train)
        if (num_rfx_groups == 1) warning("Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill")
    }
    if (has_rfx_test) {
        if (is.null(rfx_basis_test)) {
            if (!is.null(rfx_basis_train)) {
                stop("Random effects basis provided for training set, must also be provided for the test set")
            }
            rfx_basis_test <- matrix(rep(1,nrow(X_test)), nrow = nrow(X_test), ncol = 1)
        }
    }

    # Coerce y_train to a matrix if it was passed as a data frame or other two-dimensional object
    if (!is.null(dim(y_train))) {
        y_train <- as.matrix(y_train)
    }
    
    # Determine whether a basis vector is provided
    has_basis <- !is.null(leaf_basis_train)
    
    # Determine whether a test set is provided
    has_test <- !is.null(X_test)

    # Standardize the outcome (if requested), storing the training mean and standard deviation
    if (standardize) {
        y_bar_train <- mean(y_train)
        y_std_train <- sd(y_train)
    } else {
        y_bar_train <- 0
        y_std_train <- 1
    }
    resid_train <- (y_train-y_bar_train)/y_std_train
    
    # Compute initial value of root nodes in mean forest
    init_val_mean <- mean(resid_train)

    # Calibrate priors for sigma^2 and tau
    if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train)
    if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train)
    if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean)
    if (has_basis) {
        if (ncol(leaf_basis_train) > 1) {
            if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(var(resid_train)/(num_trees_mean), ncol(leaf_basis_train))
            if (!is.matrix(sigma_leaf_init)) {
                current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train)))
            } else {
                current_leaf_scale <- sigma_leaf_init
            }
        } else {
            if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean))
            if (!is.matrix(sigma_leaf_init)) {
                current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
            } else {
                current_leaf_scale <- sigma_leaf_init
            }
        }
    } else {
        if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean))
        if (!is.matrix(sigma_leaf_init)) {
            current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
        } else {
            current_leaf_scale <- sigma_leaf_init
        }
    }
    current_sigma2 <- sigma2_init

    # Determine leaf model type
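    # (0 = constant leaves, 1 = univariate leaf regression, 2 = multivariate leaf regression)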
    if (!has_basis) leaf_model_mean_forest <- 0
    else if (ncol(leaf_basis_train) == 1) leaf_model_mean_forest <- 1
    else if (ncol(leaf_basis_train) > 1) leaf_model_mean_forest <- 2
    else stop("leaf_basis_train must be a matrix with at least 1 column")

    # Set variance leaf model type (currently only one option)
    leaf_model_variance_forest <- 3
    
    # Unpack model type info
    if (leaf_model_mean_forest == 0) {
        leaf_dimension = 1
        is_leaf_constant = TRUE
        leaf_regression = FALSE
    } else if (leaf_model_mean_forest == 1) {
        stopifnot(has_basis)
        stopifnot(ncol(leaf_basis_train) == 1)
        leaf_dimension = 1
        is_leaf_constant = FALSE
        leaf_regression = TRUE
    } else if (leaf_model_mean_forest == 2) {
        stopifnot(has_basis)
        stopifnot(ncol(leaf_basis_train) > 1)
        leaf_dimension = ncol(leaf_basis_train)
        is_leaf_constant = FALSE
        leaf_regression = TRUE
        if (sample_sigma_leaf) {
            warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model.")
            sample_sigma_leaf <- FALSE
        }
    }
    
    # Data
    if (leaf_regression) {
        forest_dataset_train <- createForestDataset(X_train, leaf_basis_train)
        if (has_test) forest_dataset_test <- createForestDataset(X_test, leaf_basis_test)
        requires_basis <- TRUE
    } else {
        forest_dataset_train <- createForestDataset(X_train)
        if (has_test) forest_dataset_test <- createForestDataset(X_test)
        requires_basis <- FALSE
    }
    outcome_train <- createOutcome(resid_train)
    
    # Random number generator (std::mt19937)
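    # (random_seed defaults to -1, in which case the C++ layer seeds itself from
    # std::random_device, per the `random_seed` docs above; a NULL seed is
    # replaced by a random draw here)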
    if (is.null(random_seed)) random_seed <- sample(1:10000, 1, FALSE)
    rng <- createCppRNG(random_seed)
    
    # Sampling data structures
    feature_types <- as.integer(feature_types)
    global_model_config <- createGlobalModelConfig(global_error_variance=current_sigma2)
    if (include_mean_forest) {
        forest_model_config_mean <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_mean, num_features=ncol(X_train), 
                                                            num_observations=nrow(X_train), variable_weights=variable_weights_mean, leaf_dimension=leaf_dimension, 
                                                            alpha=alpha_mean, beta=beta_mean, min_samples_leaf=min_samples_leaf_mean, max_depth=max_depth_mean, 
                                                            leaf_model_type=leaf_model_mean_forest, leaf_model_scale=current_leaf_scale, 
                                                            cutpoint_grid_size=cutpoint_grid_size)
        forest_model_mean <- createForestModel(forest_dataset_train, forest_model_config_mean, global_model_config)
    }
    if (include_variance_forest) {
        forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train), 
                                                                num_observations=nrow(X_train), variable_weights=variable_weights_variance, leaf_dimension=1, 
                                                                alpha=alpha_variance, beta=beta_variance, min_samples_leaf=min_samples_leaf_variance, 
                                                                max_depth=max_depth_variance, leaf_model_type=leaf_model_variance_forest, 
                                                                cutpoint_grid_size=cutpoint_grid_size)
        forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance, global_model_config)
    }
    
    # Container of forest samples
    if (include_mean_forest) {
        forest_samples_mean <- createForestSamples(num_trees_mean, leaf_dimension, is_leaf_constant, FALSE)
        active_forest_mean <- createForest(num_trees_mean, leaf_dimension, is_leaf_constant, FALSE)
    }
    if (include_variance_forest) {
        forest_samples_variance <- createForestSamples(num_trees_variance, 1, TRUE, TRUE)
        active_forest_variance <- createForest(num_trees_variance, 1, TRUE, TRUE)
    }
    
    # Random effects initialization 
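    # The additive random effects model uses a "redundant parameterization": a
    # shared working parameter (alpha) and per-group parameters (xi), with a
    # shape/scale prior on the group-parameter variance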
    if (has_rfx) {
        # Prior parameters
        if (num_rfx_components == 1) {
            alpha_init <- c(1)
        } else if (num_rfx_components > 1) {
            alpha_init <- c(1,rep(0,num_rfx_components-1))
        } else {
            stop("There must be at least 1 random effect component")
        }
        xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups)
        sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components)
        sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components)
        sigma_xi_shape <- 1
        sigma_xi_scale <- 1
        
        # Random effects data structure and storage container
        rfx_dataset_train <- createRandomEffectsDataset(rfx_group_ids_train, rfx_basis_train)
        rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train)
        rfx_model <- createRandomEffectsModel(num_rfx_components, num_rfx_groups)
        rfx_model$set_working_parameter(alpha_init)
        rfx_model$set_group_parameters(xi_init)
        rfx_model$set_working_parameter_cov(sigma_alpha_init)
        rfx_model$set_group_parameter_cov(sigma_xi_init)
        rfx_model$set_variance_prior_shape(sigma_xi_shape)
        rfx_model$set_variance_prior_scale(sigma_xi_scale)
        rfx_samples <- createRandomEffectSamples(num_rfx_components, num_rfx_groups, rfx_tracker_train)
    }

    # Compute the number of retained samples and allocate containers for variance parameter samples
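    # (with keep_every-fold thinning, obtaining num_mcmc retained draws requires
    # num_mcmc * keep_every MCMC iterations per chain)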
    num_actual_mcmc_iter <- num_mcmc * keep_every
    num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter
    # Delete GFR samples from these containers after the fact if desired
    # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc
    num_retained_samples <- num_gfr + ifelse(keep_burnin, num_burnin, 0) + num_mcmc * num_chains
    if (sample_sigma_global) global_var_samples <- rep(NA, num_retained_samples)
    if (sample_sigma_leaf) leaf_scale_samples <- rep(NA, num_retained_samples)
    sample_counter <- 0
    
    # Initialize the leaves of each tree in the mean forest
    if (include_mean_forest) {
        if (requires_basis) init_values_mean_forest <- rep(0., ncol(leaf_basis_train))
        else init_values_mean_forest <- 0.
        active_forest_mean$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mean, leaf_model_mean_forest, init_values_mean_forest)
        active_forest_mean$adjust_residual(forest_dataset_train, outcome_train, forest_model_mean, requires_basis, FALSE)
    }

    # Initialize the leaves of each tree in the variance forest
    if (include_variance_forest) {
        active_forest_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init)

    }
    
    # Run GFR (warm start) if specified
    if (num_gfr > 0){
        for (i in 1:num_gfr) {
            # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC
            # keep_sample <- ifelse(keep_gfr, TRUE, FALSE)
            keep_sample <- TRUE
            if (keep_sample) sample_counter <- sample_counter + 1
            # Print progress
            if (verbose) {
                if ((i %% 10 == 0) || (i == num_gfr)) {
                    cat("Sampling", i, "out of", num_gfr, "XBART (grow-from-root) draws\n")
                }
            }
            
            if (include_mean_forest) {
                forest_model_mean$sample_one_iteration(
                    forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean, 
                    active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, 
                    global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
                )
            }
            if (include_variance_forest) {
                forest_model_variance$sample_one_iteration(
                    forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, 
                    active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, 
                    global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
                )
            }
            if (sample_sigma_global) {
                current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
                if (keep_sample) global_var_samples[sample_counter] <- current_sigma2
                global_model_config$update_global_error_variance(current_sigma2)
            }
            if (sample_sigma_leaf) {
                leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf)
                current_leaf_scale <- as.matrix(leaf_scale_double)
                if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double
                forest_model_config_mean$update_leaf_model_scale(current_leaf_scale)
            }
            if (has_rfx) {
                rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng)
            }
        }
    }
    
    # Run MCMC
    if (num_burnin + num_mcmc > 0) {
        for (chain_num in 1:num_chains) {
            if (num_gfr > 0) {
                # Reset state of active_forest and forest_model based on a previous GFR sample
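                # (forest samples are zero-indexed in the underlying container, so
                # chain 1 restarts from the last GFR draw, chain 2 from the
                # second-to-last, and so on; num_chains <= num_gfr was checked above)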
                forest_ind <- num_gfr - chain_num
                if (include_mean_forest) {
                    resetActiveForest(active_forest_mean, forest_samples_mean, forest_ind)
                    resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE)
                    if (sample_sigma_leaf) {
                        leaf_scale_double <- leaf_scale_samples[forest_ind + 1]
                        current_leaf_scale <- as.matrix(leaf_scale_double)
                        forest_model_config_mean$update_leaf_model_scale(current_leaf_scale)
                    }
                }
                if (include_variance_forest) {
                    resetActiveForest(active_forest_variance, forest_samples_variance, forest_ind)
                    resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE)
                }
                if (has_rfx) {
                    resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init)
                    resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples)
                }
                if (sample_sigma_global) {
                    current_sigma2 <- global_var_samples[forest_ind + 1]
                    global_model_config$update_global_error_variance(current_sigma2)
                }
            } else if (has_prev_model) {
                if (include_mean_forest) {
                    resetActiveForest(active_forest_mean, previous_forest_samples_mean, previous_model_warmstart_sample_num - 1)
                    resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE)
                    if (sample_sigma_leaf && (!is.null(previous_leaf_var_samples))) {
                        leaf_scale_double <- previous_leaf_var_samples[previous_model_warmstart_sample_num]
                        current_leaf_scale <- as.matrix(leaf_scale_double)
                        forest_model_config_mean$update_leaf_model_scale(current_leaf_scale)
                    }
                }
                if (include_variance_forest) {
                    resetActiveForest(active_forest_variance, previous_forest_samples_variance, previous_model_warmstart_sample_num - 1)
                    resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE)
                }
                if (has_rfx) {
                    if (is.null(previous_rfx_samples)) {
                        warning("`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started")
                        rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init,
                                                    sigma_xi_init, sigma_xi_shape, sigma_xi_scale)
                        rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train)
                    } else {
                        resetRandomEffectsModel(rfx_model, previous_rfx_samples, previous_model_warmstart_sample_num - 1, sigma_alpha_init)
                        resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples)
                    }
                }
                if (sample_sigma_global) {
                    if (!is.null(previous_global_var_samples)) {
                        current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num]
                        global_model_config$update_global_error_variance(current_sigma2)
                    }
                }
            } else {
                if (include_mean_forest) {
                    resetActiveForest(active_forest_mean)
                    active_forest_mean$set_root_leaves(init_values_mean_forest / num_trees_mean)
                    resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE)
                    if (sample_sigma_leaf) {
                        current_leaf_scale <- as.matrix(sigma_leaf_init)
                        forest_model_config_mean$update_leaf_model_scale(current_leaf_scale)
                    }
                }
                if (include_variance_forest) {
                    resetActiveForest(active_forest_variance)
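                    # (the variance forest models log-variance additively, so each
                    # root leaf takes an equal share of log(variance_forest_init))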
                    active_forest_variance$set_root_leaves(log(variance_forest_init) / num_trees_variance)
                    resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE)
                }
                if (has_rfx) {
                    rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init,
                                                sigma_xi_init, sigma_xi_shape, sigma_xi_scale)
                    rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train)
                }
                if (sample_sigma_global) {
                    current_sigma2 <- sigma2_init
                    global_model_config$update_global_error_variance(current_sigma2)
                }
            }
            for (i in (num_gfr+1):num_samples) {
                is_mcmc <- i > (num_gfr + num_burnin)
                if (is_mcmc) {
                    mcmc_counter <- i - (num_gfr + num_burnin)
                    if (mcmc_counter %% keep_every == 0) keep_sample <- TRUE
                    else keep_sample <- FALSE
                } else {
                    if (keep_burnin) keep_sample <- TRUE
                    else keep_sample <- FALSE
                }
                if (keep_sample) sample_counter <- sample_counter + 1
                # Print progress
                if (verbose) {
                    if (num_burnin > 0) {
                        if (((i - num_gfr) %% 100 == 0) || ((i - num_gfr) == num_burnin)) {
                            cat("Sampling", i - num_gfr, "out of", num_burnin, "BART burn-in draws; Chain number ", chain_num, "\n")
                        }
                    }
                    if (num_mcmc > 0) {
                        if (((i - num_gfr - num_burnin) %% 100 == 0) || (i == num_samples)) {
                            cat("Sampling", i - num_burnin - num_gfr, "out of", num_mcmc, "BART MCMC draws; Chain number ", chain_num, "\n")
                        }
                    }
                }
                
                if (include_mean_forest) {
                    forest_model_mean$sample_one_iteration(
                        forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean, 
                        active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, 
                        global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
                    )
                }
                if (include_variance_forest) {
                    forest_model_variance$sample_one_iteration(
                        forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, 
                        active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, 
                        global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
                    )
                }
                if (sample_sigma_global) {
                    current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
                    if (keep_sample) global_var_samples[sample_counter] <- current_sigma2
                    global_model_config$update_global_error_variance(current_sigma2)
                }
                if (sample_sigma_leaf) {
                    leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf)
                    current_leaf_scale <- as.matrix(leaf_scale_double)
                    if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double
                    forest_model_config_mean$update_leaf_model_scale(current_leaf_scale)
                }
                if (has_rfx) {
                    rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng)
                }
            }
        }
    }
    
    # Remove GFR samples if they are not to be retained
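    # (delete_sample uses zero-based indexing, hence the i-1 below)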
    if ((!keep_gfr) && (num_gfr > 0)) {
        for (i in 1:num_gfr) {
            if (include_mean_forest) {
                forest_samples_mean$delete_sample(i-1)
            }
            if (include_variance_forest) {
                forest_samples_variance$delete_sample(i-1)
            }
            if (has_rfx) {
                rfx_samples$delete_sample(i-1)
            }
        }
        if (sample_sigma_global) {
            global_var_samples <- global_var_samples[(num_gfr+1):length(global_var_samples)]
        }
        if (sample_sigma_leaf) {
            leaf_scale_samples <- leaf_scale_samples[(num_gfr+1):length(leaf_scale_samples)]
        }
        num_retained_samples <- num_retained_samples - num_gfr
    }

    # Mean forest predictions
    if (include_mean_forest) {
        y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train
        if (has_test) y_hat_test <- forest_samples_mean$predict(forest_dataset_test)*y_std_train + y_bar_train
    }
    
    # Variance forest predictions
    if (include_variance_forest) {
        sigma_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
        if (has_test) sigma_x_hat_test <- forest_samples_variance$predict(forest_dataset_test)
    }
    
    # Random effects predictions
    if (has_rfx) {
        rfx_preds_train <- rfx_samples$predict(rfx_group_ids_train, rfx_basis_train)*y_std_train
        y_hat_train <- y_hat_train + rfx_preds_train
    }
    if ((has_rfx_test) && (has_test)) {
        rfx_preds_test <- rfx_samples$predict(rfx_group_ids_test, rfx_basis_test)*y_std_train
        y_hat_test <- y_hat_test + rfx_preds_test
    }

    # Global error variance
    if (sample_sigma_global) sigma2_samples <- global_var_samples*(y_std_train^2)
    
    # Leaf parameter variance
    if (sample_sigma_leaf) tau_samples <- leaf_scale_samples
    
    # Rescale variance forest prediction by global sigma2 (sampled or constant)
    if (include_variance_forest) {
        if (sample_sigma_global) {
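            # sigma2_samples is already on the original outcome scale (it was
            # multiplied by y_std_train^2 above), so no extra scaling is needed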
            sigma_x_hat_train <- sapply(1:num_retained_samples, function(i) sqrt(sigma_x_hat_train[,i]*sigma2_samples[i]))
            if (has_test) sigma_x_hat_test <- sapply(1:num_retained_samples, function(i) sqrt(sigma_x_hat_test[,i]*sigma2_samples[i]))
        } else {
            sigma_x_hat_train <- sqrt(sigma_x_hat_train*sigma2_init)*y_std_train
            if (has_test) sigma_x_hat_test <- sqrt(sigma_x_hat_test*sigma2_init)*y_std_train
        }
    }
    
    # Return results as a list
    model_params <- list(
        "sigma2_init" = sigma2_init, 
        "sigma_leaf_init" = sigma_leaf_init,
        "a_global" = a_global,
        "b_global" = b_global, 
        "a_leaf" = a_leaf, 
        "b_leaf" = b_leaf,
        "a_forest" = a_forest, 
        "b_forest" = b_forest,
        "outcome_mean" = y_bar_train,
        "outcome_scale" = y_std_train, 
        "standardize" = standardize, 
        "leaf_dimension" = leaf_dimension,
        "is_leaf_constant" = is_leaf_constant,
        "leaf_regression" = leaf_regression,
        "requires_basis" = requires_basis, 
        "num_covariates" = ncol(X_train), 
        "num_basis" = ifelse(is.null(leaf_basis_train),0,ncol(leaf_basis_train)), 
        "num_samples" = num_retained_samples, 
        "num_gfr" = num_gfr, 
        "num_burnin" = num_burnin, 
        "num_mcmc" = num_mcmc, 
        "keep_every" = keep_every,
        "num_chains" = num_chains,
        "has_basis" = !is.null(leaf_basis_train), 
        "has_rfx" = has_rfx, 
        "has_rfx_basis" = has_basis_rfx, 
        "num_rfx_basis" = num_basis_rfx, 
        "sample_sigma_global" = sample_sigma_global,
        "sample_sigma_leaf" = sample_sigma_leaf,
        "include_mean_forest" = include_mean_forest,
        "include_variance_forest" = include_variance_forest
    )
    result <- list(
        "model_params" = model_params, 
        "train_set_metadata" = X_train_metadata
    )
    if (include_mean_forest) {
        result[["mean_forests"]] = forest_samples_mean
        result[["y_hat_train"]] = y_hat_train
        if (has_test) result[["y_hat_test"]] = y_hat_test
    }
    if (include_variance_forest) {
        result[["variance_forests"]] = forest_samples_variance
        result[["sigma_x_hat_train"]] = sigma_x_hat_train
        if (has_test) result[["sigma_x_hat_test"]] = sigma_x_hat_test
    }
    if (sample_sigma_global) result[["sigma2_global_samples"]] = sigma2_samples
    if (sample_sigma_leaf) result[["sigma2_leaf_samples"]] = tau_samples
    if (has_rfx) {
        result[["rfx_samples"]] = rfx_samples
        result[["rfx_preds_train"]] = rfx_preds_train
        result[["rfx_unique_group_ids"]] = levels(group_ids_factor)
    }
    if ((has_rfx_test) && (has_test)) result[["rfx_preds_test"]] = rfx_preds_test
    class(result) <- "bartmodel"
    
    # Clean up classes with external pointers to C++ data structures
    if (include_mean_forest) rm(forest_model_mean)
    if (include_variance_forest) rm(forest_model_variance)
    rm(forest_dataset_train)
    if (has_test) rm(forest_dataset_test)
    if (has_rfx) rm(rfx_dataset_train, rfx_tracker_train, rfx_model)
    rm(outcome_train)
    rm(rng)
    
    return(result)
}

#' Predict from a sampled BART model on new data
#'
#' @param object Object of type `bartmodel` containing draws of a regression forest and associated sampling outputs.
#' @param X Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.
#' @param leaf_basis (Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: `NULL`.
#' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model. 
#' We do not currently support test set evaluation for group labels that were not in the training set 
#' (though we plan to in the near future).
#' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model.
#' @param ... (Optional) Other prediction parameters.
#'
#' @return List of prediction matrices. If model does not have random effects, the list has one element -- the predictions from the forest. 
#' If the model does have random effects, the list has three elements -- forest predictions, random effects predictions, and their sum (`y_hat`).
#' @export
#'
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' f_XW <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' noise_sd <- 1
#' y <- f_XW + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train, 
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' y_hat_test <- predict(bart_model, X_test)$y_hat
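#' # y_hat_test is an n_test x num_samples matrix of posterior draws; averaging
#' # across columns gives point predictions
#' y_hat_test_mean <- rowMeans(y_hat_test)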
predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, ...){
    # Preprocess covariates
    if ((!is.data.frame(X)) && (!is.matrix(X))) {
        stop("X must be a matrix or dataframe")
    }
    train_set_metadata <- object$train_set_metadata
    X <- preprocessPredictionData(X, train_set_metadata)
    
    # Convert all input data to matrices if not already converted
    if ((is.null(dim(leaf_basis))) && (!is.null(leaf_basis))) {
        leaf_basis <- as.matrix(leaf_basis)
    }
    if ((is.null(dim(rfx_basis))) && (!is.null(rfx_basis))) {
        rfx_basis <- as.matrix(rfx_basis)
    }
    
    # Data checks
    if ((object$model_params$requires_basis) && (is.null(leaf_basis))) {
        stop("Basis (leaf_basis) must be provided for this model")
    }
    if ((!is.null(leaf_basis)) && (nrow(X) != nrow(leaf_basis))) {
        stop("X and leaf_basis must have the same number of rows")
    }
    if (object$model_params$num_covariates != ncol(X)) {
        stop("X must contain the same number of covariates as the data used to train this model")
    }
    if ((object$model_params$has_rfx) && (is.null(rfx_group_ids))) {
        stop("Random effect group labels (rfx_group_ids) must be provided for this model")
    }
    if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis))) {
        stop("Random effects basis (rfx_basis) must be provided for this model")
    }
    if ((object$model_params$num_rfx_basis > 0) && (ncol(rfx_basis) != object$model_params$num_rfx_basis)) {
        stop("Random effects basis has a different dimension than the basis used to train this model")
    }
    
    # Recode group IDs to an integer vector (they may be passed as, e.g., a vector of county names)
    has_rfx <- FALSE
    if (!is.null(rfx_group_ids)) {
        rfx_unique_group_ids <- object$rfx_unique_group_ids
        group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
        if (sum(is.na(group_ids_factor)) > 0) {
            stop("All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train")
        }
        rfx_group_ids <- as.integer(group_ids_factor)
        has_rfx <- TRUE
    }
    
    # Produce basis for the "intercept-only" random effects case
    if ((object$model_params$has_rfx) && (is.null(rfx_basis))) {
        rfx_basis <- matrix(rep(1, nrow(X)), ncol = 1)
    }
    
    # Create prediction dataset
    if (!is.null(leaf_basis)) prediction_dataset <- createForestDataset(X, leaf_basis)
    else prediction_dataset <- createForestDataset(X)
    
    # Compute mean forest predictions
    num_samples <- object$model_params$num_samples
    y_std <- object$model_params$outcome_scale
    y_bar <- object$model_params$outcome_mean
    sigma2_init <- object$model_params$sigma2_init
    if (object$model_params$include_mean_forest) {
        mean_forest_predictions <- object$mean_forests$predict(prediction_dataset)*y_std + y_bar
    }
    
    # Compute variance forest predictions
    if (object$model_params$include_variance_forest) {
        s_x_raw <- object$variance_forests$predict(prediction_dataset)
    }
    
    # Compute rfx predictions (if needed)
    if (object$model_params$has_rfx) {
        rfx_predictions <- object$rfx_samples$predict(rfx_group_ids, rfx_basis)*y_std
    }
    
    # Scale variance forest predictions
    if (object$model_params$include_variance_forest) {
        if (object$model_params$sample_sigma_global) {
            sigma2_samples <- object$sigma2_global_samples
            variance_forest_predictions <- sapply(1:num_samples, function(i) sqrt(s_x_raw[,i]*sigma2_samples[i]))
        } else {
            variance_forest_predictions <- sqrt(s_x_raw*sigma2_init)*y_std
        }
    }

    if ((object$model_params$include_mean_forest) && (object$model_params$has_rfx)) {
        y_hat <- mean_forest_predictions + rfx_predictions
    } else if ((object$model_params$include_mean_forest) && (!object$model_params$has_rfx)) {
        y_hat <- mean_forest_predictions
    } else if ((!object$model_params$include_mean_forest) && (object$model_params$has_rfx)) {
        y_hat <- rfx_predictions
    } 
    
    result <- list()
    if ((object$model_params$has_rfx) || (object$model_params$include_mean_forest)) {
        result[["y_hat"]] = y_hat
    }
    if (object$model_params$include_mean_forest) {
        result[["mean_forest_predictions"]] = mean_forest_predictions
    }
    if (object$model_params$has_rfx) {
        result[["rfx_predictions"]] = rfx_predictions
    }
    if (object$model_params$include_variance_forest) {
        result[["variance_forest_predictions"]] = variance_forest_predictions
    }
    return(result)
}

#' Extract raw sample values for each of the random effect parameter terms.
#'
#' @param object Object of type `bartmodel` containing draws of a BART model and associated sampling outputs.
#' @param ... Other parameters to be used in random effects extraction
#' @return List of arrays. The alpha array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`.
#' The xi and beta arrays have dimension (`num_components`, `num_groups`, `num_samples`) and are simply matrices if `num_components = 1`.
#' The sigma array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`.
#' @export
#'
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' f_XW <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' snr <- 3
#' group_ids <- rep(c(1,2), n %/% 2)
#' rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
#' rfx_basis <- cbind(1, runif(n, -1, 1))
#' rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis)
#' E_y <- f_XW + rfx_term
#' y <- E_y + rnorm(n, 0, 1)*(sd(E_y)/snr)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' rfx_group_ids_test <- group_ids[test_inds]
#' rfx_group_ids_train <- group_ids[train_inds]
#' rfx_basis_test <- rfx_basis[test_inds,]
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, 
#'                    rfx_group_ids_train = rfx_group_ids_train, 
#'                    rfx_group_ids_test = rfx_group_ids_test, 
#'                    rfx_basis_train = rfx_basis_train, 
#'                    rfx_basis_test = rfx_basis_test, 
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' rfx_samples <- getRandomEffectSamples(bart_model)
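#' # With a two-column basis the beta draws form a 3D array (a sketch assuming 
#' # `num_components > 1`); posterior means by component and group:
#' beta_hat <- apply(rfx_samples$beta_samples, c(1, 2), mean)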
getRandomEffectSamples.bartmodel <- function(object, ...){
    result = list()
    
    if (!object$model_params$has_rfx) {
        warning("This model has no RFX terms, returning an empty list")
        return(result)
    }
    
    # Extract the samples
    result <- object$rfx_samples$extract_parameter_samples()
    
    # Scale by sd(y_train)
    result$beta_samples <- result$beta_samples*object$model_params$outcome_scale
    result$xi_samples <- result$xi_samples*object$model_params$outcome_scale
    result$alpha_samples <- result$alpha_samples*object$model_params$outcome_scale
    result$sigma_samples <- result$sigma_samples*(object$model_params$outcome_scale^2)
    
    return(result)
}

#' Convert the persistent aspects of a BART model to (in-memory) JSON
#'
#' @param object Object of type `bartmodel` containing draws of a BART model and associated sampling outputs.
#'
#' @return Object of type `CppJson`
#' @export
#'
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' f_XW <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' noise_sd <- 1
#' y <- f_XW + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train, 
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' bart_json <- saveBARTModelToJson(bart_model)
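#' # The returned `CppJson` object can also be written to disk directly 
#' # (a usage sketch with a temporary path):
#' tmpjson <- tempfile(fileext = ".json")
#' bart_json$save_file(tmpjson)
#' unlink(tmpjson)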
saveBARTModelToJson <- function(object){
    jsonobj <- createCppJson()
    
    if (!inherits(object, "bartmodel")) {
        stop("`object` must be a BART model")
    }
    
    if (is.null(object$model_params)) {
        stop("This BCF model has not yet been sampled")
    }

    # Add the forests
    if (object$model_params$include_mean_forest) {
        jsonobj$add_forest(object$mean_forests)
    }
    if (object$model_params$include_variance_forest) {
        jsonobj$add_forest(object$variance_forests)
    }

    # Add metadata
    jsonobj$add_scalar("num_numeric_vars", object$train_set_metadata$num_numeric_vars)
    jsonobj$add_scalar("num_ordered_cat_vars", object$train_set_metadata$num_ordered_cat_vars)
    jsonobj$add_scalar("num_unordered_cat_vars", object$train_set_metadata$num_unordered_cat_vars)
    if (object$train_set_metadata$num_numeric_vars > 0) {
        jsonobj$add_string_vector("numeric_vars", object$train_set_metadata$numeric_vars)
    }
    if (object$train_set_metadata$num_ordered_cat_vars > 0) {
        jsonobj$add_string_vector("ordered_cat_vars", object$train_set_metadata$ordered_cat_vars)
        jsonobj$add_string_list("ordered_unique_levels", object$train_set_metadata$ordered_unique_levels)
    }
    if (object$train_set_metadata$num_unordered_cat_vars > 0) {
        jsonobj$add_string_vector("unordered_cat_vars", object$train_set_metadata$unordered_cat_vars)
        jsonobj$add_string_list("unordered_unique_levels", object$train_set_metadata$unordered_unique_levels)
    }
    
    # Add global parameters
    jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale)
    jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean)
    jsonobj$add_boolean("standardize", object$model_params$standardize)
    jsonobj$add_scalar("sigma2_init", object$model_params$sigma2_init)
    jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global)
    jsonobj$add_boolean("sample_sigma_leaf", object$model_params$sample_sigma_leaf)
    jsonobj$add_boolean("include_mean_forest", object$model_params$include_mean_forest)
    jsonobj$add_boolean("include_variance_forest", object$model_params$include_variance_forest)
    jsonobj$add_boolean("has_rfx", object$model_params$has_rfx)
    jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis)
    jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis)
    jsonobj$add_scalar("num_gfr", object$model_params$num_gfr)
    jsonobj$add_scalar("num_burnin", object$model_params$num_burnin)
    jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc)
    jsonobj$add_scalar("num_samples", object$model_params$num_samples)
    jsonobj$add_scalar("num_covariates", object$model_params$num_covariates)
    jsonobj$add_scalar("num_basis", object$model_params$num_basis)
    jsonobj$add_scalar("num_chains", object$model_params$num_chains)
    jsonobj$add_scalar("keep_every", object$model_params$keep_every)
    jsonobj$add_boolean("requires_basis", object$model_params$requires_basis)
    if (object$model_params$sample_sigma_global) {
        jsonobj$add_vector("sigma2_global_samples", object$sigma2_global_samples, "parameters")
    }
    if (object$model_params$sample_sigma_leaf) {
        jsonobj$add_vector("sigma2_leaf_samples", object$sigma2_leaf_samples, "parameters")
    }

    # Add random effects (if present)
    if (object$model_params$has_rfx) {
        jsonobj$add_random_effects(object$rfx_samples)
        jsonobj$add_string_vector("rfx_unique_group_ids", object$rfx_unique_group_ids)
    }
    
    # Add covariate preprocessor metadata
    preprocessor_metadata_string <- savePreprocessorToJsonString(
        object$train_set_metadata
    )
    jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string)
    
    return(jsonobj)
}

#' Convert the persistent aspects of a BART model to (in-memory) JSON and save to a file
#'
#' @param object Object of type `bartmodel` containing draws of a BART model and associated sampling outputs.
#' @param filename String of filepath, must end in ".json"
#'
#' @return None
#' @export
#'
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' f_XW <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' noise_sd <- 1
#' y <- f_XW + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train, 
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' tmpjson <- tempfile(fileext = ".json")
#' saveBARTModelToJsonFile(bart_model, file.path(tmpjson))
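#' # The file can be read back into a model object (a usage sketch):
#' bart_model_reloaded <- createBARTModelFromJsonFile(tmpjson)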
#' unlink(tmpjson)
saveBARTModelToJsonFile <- function(object, filename){
    # Convert to Json
    jsonobj <- saveBARTModelToJson(object)
    
    # Save to file
    jsonobj$save_file(filename)
}

#' Convert the persistent aspects of a BART model to (in-memory) JSON string
#'
#' @param object Object of type `bartmodel` containing draws of a BART model and associated sampling outputs.
#' @return in-memory JSON string
#' @export
#'
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' f_XW <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' noise_sd <- 1
#' y <- f_XW + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train, 
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' bart_json_string <- saveBARTModelToJsonString(bart_model)
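#' # The string can be parsed back into a model object (a usage sketch):
#' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json_string)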
saveBARTModelToJsonString <- function(object){
    # Convert to Json
    jsonobj <- saveBARTModelToJson(object)
    
    # Dump to string
    return(jsonobj$return_json_string())
}

#' Convert an (in-memory) JSON representation of a BART model to a BART model object 
#' which can be used for prediction, etc...
#'
#' @param json_object Object of type `CppJson` containing Json representation of a BART model
#'
#' @return Object of type `bartmodel`
#' @export
#'
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' f_XW <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' noise_sd <- 1
#' y <- f_XW + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train, 
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' bart_json <- saveBARTModelToJson(bart_model)
#' bart_model_roundtrip <- createBARTModelFromJson(bart_json)
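#' # The deserialized model can be used for prediction just like the original (a usage sketch):
#' y_hat_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat)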
createBARTModelFromJson <- function(json_object){
    # Initialize the BART model
    output <- list()
    
    # Unpack the forests
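    # (Forests are serialized in the order they were added in saveBARTModelToJson: 
    # the mean forest, when present, is "forest_0" and the variance forest follows.)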
    include_mean_forest <- json_object$get_boolean("include_mean_forest")
    include_variance_forest <- json_object$get_boolean("include_variance_forest")
    if (include_mean_forest) {
        output[["mean_forests"]] <- loadForestContainerJson(json_object, "forest_0")
        if (include_variance_forest) {
            output[["variance_forests"]] <- loadForestContainerJson(json_object, "forest_1")
        }
    } else {
        output[["variance_forests"]] <- loadForestContainerJson(json_object, "forest_0")
    }

    # Unpack metadata
    train_set_metadata = list()
    train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar("num_numeric_vars")
    train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar("num_ordered_cat_vars")
    train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar("num_unordered_cat_vars")
    if (train_set_metadata[["num_numeric_vars"]] > 0) {
        train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector("numeric_vars")
    }
    if (train_set_metadata[["num_ordered_cat_vars"]] > 0) {
        train_set_metadata[["ordered_cat_vars"]] <- json_object$get_string_vector("ordered_cat_vars")
        train_set_metadata[["ordered_unique_levels"]] <- json_object$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]])
    }
    if (train_set_metadata[["num_unordered_cat_vars"]] > 0) {
        train_set_metadata[["unordered_cat_vars"]] <- json_object$get_string_vector("unordered_cat_vars")
        train_set_metadata[["unordered_unique_levels"]] <- json_object$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]])
    }
    output[["train_set_metadata"]] <- train_set_metadata
    
    # Unpack model params
    model_params = list()
    model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale")
    model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean")
    model_params[["standardize"]] <- json_object$get_boolean("standardize")
    model_params[["sigma2_init"]] <- json_object$get_scalar("sigma2_init")
    model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
    model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf")
    model_params[["include_mean_forest"]] <- include_mean_forest
    model_params[["include_variance_forest"]] <- include_variance_forest
    model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx")
    model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis")
    model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis")
    model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr")
    model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin")
    model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc")
    model_params[["num_samples"]] <- json_object$get_scalar("num_samples")
    model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates")
    model_params[["num_basis"]] <- json_object$get_scalar("num_basis")
    model_params[["num_chains"]] <- json_object$get_scalar("num_chains")
    model_params[["keep_every"]] <- json_object$get_scalar("keep_every")
    model_params[["requires_basis"]] <- json_object$get_boolean("requires_basis")
    output[["model_params"]] <- model_params
    
    # Unpack sampled parameters
    if (model_params[["sample_sigma_global"]]) {
        output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters")
    }
    if (model_params[["sample_sigma_leaf"]]) {
        output[["sigma2_leaf_samples"]] <- json_object$get_vector("sigma2_leaf_samples", "parameters")
    }

    # Unpack random effects
    if (model_params[["has_rfx"]]) {
        output[["rfx_unique_group_ids"]] <- json_object$get_string_vector("rfx_unique_group_ids")
        output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0)
    }
    
    # Unpack covariate preprocessor
    preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata")
    output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
        preprocessor_metadata_string
    )
    
    class(output) <- "bartmodel"
    return(output)
}

#' Convert a JSON file containing sample information on a trained BART model 
#' to a BART model object which can be used for prediction, etc...
#'
#' @param json_filename String of filepath, must end in ".json"
#'
#' @return Object of type `bartmodel`
#' @export
#'
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' f_XW <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' noise_sd <- 1
#' y <- f_XW + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train, 
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' tmpjson <- tempfile(fileext = ".json")
#' saveBARTModelToJsonFile(bart_model, file.path(tmpjson))
#' bart_model_roundtrip <- createBARTModelFromJsonFile(file.path(tmpjson))
#' unlink(tmpjson)
createBARTModelFromJsonFile <- function(json_filename){
    # Load a `CppJson` object from file
    bart_json <- createCppJsonFile(json_filename)
    
    # Create and return the BART object
    bart_object <- createBARTModelFromJson(bart_json)
    
    return(bart_object)
}

#' Convert a JSON string containing sample information on a trained BART model 
#' to a BART model object which can be used for prediction, etc...
#'
#' @param json_string JSON string dump
#'
#' @return Object of type `bartmodel`
#' @export
#'
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' f_XW <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' noise_sd <- 1
#' y <- f_XW + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train, 
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' bart_json <- saveBARTModelToJsonString(bart_model)
#' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json)
#' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat)
createBARTModelFromJsonString <- function(json_string){
    # Load a `CppJson` object from string
    bart_json <- createCppJsonString(json_string)
    
    # Create and return the BART object
    bart_object <- createBARTModelFromJson(bart_json)
    
    return(bart_object)
}

#' Convert a list of (in-memory) JSON representations of a BART model to a single combined BART model object 
#' which can be used for prediction, etc...
#'
#' @param json_object_list List of objects of type `CppJson` containing Json representation of a BART model
#'
#' @return Object of type `bartmodel`
#' @export
#'
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' f_XW <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' noise_sd <- 1
#' y <- f_XW + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train, 
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' bart_json <- list(saveBARTModelToJson(bart_model))
#' bart_model_roundtrip <- createBARTModelFromCombinedJson(bart_json)
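#' # A sketch of combining two independently sampled models into one object:
#' bart_model_2 <- bart(X_train = X_train, y_train = y_train, 
#'                      num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' bart_json_pair <- list(saveBARTModelToJson(bart_model), 
#'                        saveBARTModelToJson(bart_model_2))
#' bart_model_combined <- createBARTModelFromCombinedJson(bart_json_pair)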
createBARTModelFromCombinedJson <- function(json_object_list){
    # Initialize the BART model
    output <- list()

    # For scalar / preprocessing details which aren't sample-dependent, 
    # defer to the first json
    json_object_default <- json_object_list[[1]]
    
    # Unpack the forests
    include_mean_forest <- json_object_default$get_boolean("include_mean_forest")
    include_variance_forest <- json_object_default$get_boolean("include_variance_forest")
    if (include_mean_forest) {
        output[["mean_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0")
        if (include_variance_forest) {
            output[["variance_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_1")
        }
    } else {
        output[["variance_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0")
    }
    
    # Unpack metadata
    train_set_metadata = list()
    train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar("num_numeric_vars")
    train_set_metadata[["num_ordered_cat_vars"]] <- json_object_default$get_scalar("num_ordered_cat_vars")
    train_set_metadata[["num_unordered_cat_vars"]] <- json_object_default$get_scalar("num_unordered_cat_vars")
    if (train_set_metadata[["num_numeric_vars"]] > 0) {
        train_set_metadata[["numeric_vars"]] <- json_object_default$get_string_vector("numeric_vars")
    }
    if (train_set_metadata[["num_ordered_cat_vars"]] > 0) {
        train_set_metadata[["ordered_cat_vars"]] <- json_object_default$get_string_vector("ordered_cat_vars")
        train_set_metadata[["ordered_unique_levels"]] <- json_object_default$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]])
    }
    if (train_set_metadata[["num_unordered_cat_vars"]] > 0) {
        train_set_metadata[["unordered_cat_vars"]] <- json_object_default$get_string_vector("unordered_cat_vars")
        train_set_metadata[["unordered_unique_levels"]] <- json_object_default$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]])
    }
    output[["train_set_metadata"]] <- train_set_metadata

    # Unpack model params
    model_params = list()
    model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale")
    model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean")
    model_params[["standardize"]] <- json_object_default$get_boolean("standardize")
    model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init")
    model_params[["sample_sigma_global"]] <- json_object_default$get_boolean("sample_sigma_global")
    model_params[["sample_sigma_leaf"]] <- json_object_default$get_boolean("sample_sigma_leaf")
    model_params[["include_mean_forest"]] <- include_mean_forest
    model_params[["include_variance_forest"]] <- include_variance_forest
    model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx")
    model_params[["has_rfx_basis"]] <- json_object_default$get_boolean("has_rfx_basis")
    model_params[["num_rfx_basis"]] <- json_object_default$get_scalar("num_rfx_basis")
    model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates")
    model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis")
    model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis")
    model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
    model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")

    # Combine values that are sample-specific
    for (i in 1:length(json_object_list)) {
        json_object <- json_object_list[[i]]
        if (i == 1) {
            model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr")
            model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin")
            model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc")
            model_params[["num_samples"]] <- json_object$get_scalar("num_samples")
        } else {
            model_params[["num_gfr"]] <- model_params[["num_gfr"]] + json_object$get_scalar("num_gfr")
            model_params[["num_burnin"]] <- model_params[["num_burnin"]] + json_object$get_scalar("num_burnin")
            model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + json_object$get_scalar("num_mcmc")
            model_params[["num_samples"]] <- model_params[["num_samples"]] + json_object$get_scalar("num_samples")
        }
    }
    output[["model_params"]] <- model_params
    
    # Unpack sampled parameters
    if (model_params[["sample_sigma_global"]]) {
        for (i in 1:length(json_object_list)) {
            json_object <- json_object_list[[i]]
            if (i == 1) {
                output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters")
            } else {
                output[["sigma2_global_samples"]] <- c(output[["sigma2_global_samples"]], json_object$get_vector("sigma2_global_samples", "parameters"))
            }
        }
    }
    if (model_params[["sample_sigma_leaf"]]) {
        for (i in 1:length(json_object_list)) {
            json_object <- json_object_list[[i]]
            if (i == 1) {
                output[["sigma2_leaf_samples"]] <- json_object$get_vector("sigma2_leaf_samples", "parameters")
            } else {
                output[["sigma2_leaf_samples"]] <- c(output[["sigma2_leaf_samples"]], json_object$get_vector("sigma2_leaf_samples", "parameters"))
            }
        }
    }
    
    # Unpack random effects
    if (model_params[["has_rfx"]]) {
        output[["rfx_unique_group_ids"]] <- json_object_default$get_string_vector("rfx_unique_group_ids")
        output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0)
    }
    
    # Unpack covariate preprocessor
    preprocessor_metadata_string <- json_object_default$get_string("preprocessor_metadata")
    output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
        preprocessor_metadata_string
    )
    
    class(output) <- "bartmodel"
    return(output)
}

#' Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object 
#' which can be used for prediction, etc...
#'
#' @param json_string_list List of JSON strings which can be parsed to objects of type `CppJson` containing Json representation of a BART model
#'
#' @return Object of type `bartmodel`
#' @export
#'
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' f_XW <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' noise_sd <- 1
#' y <- f_XW + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train, 
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' bart_json_string_list <- list(saveBARTModelToJsonString(bart_model))
#' bart_model_roundtrip <- createBARTModelFromCombinedJsonString(bart_json_string_list)
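#' # A sketch of combining two independently serialized chains:
#' bart_model_2 <- bart(X_train = X_train, y_train = y_train, 
#'                      num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' bart_json_string_pair <- list(saveBARTModelToJsonString(bart_model), 
#'                               saveBARTModelToJsonString(bart_model_2))
#' bart_model_combined <- createBARTModelFromCombinedJsonString(bart_json_string_pair)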
createBARTModelFromCombinedJsonString <- function(json_string_list){
    # Initialize the BART model
    output <- list()
    
    # Convert JSON strings
    json_object_list <- list()
    for (i in 1:length(json_string_list)) {
        json_string <- json_string_list[[i]]
        json_object_list[[i]] <- createCppJsonString(json_string)
    }
    
    # For scalar / preprocessing details which aren't sample-dependent, 
    # defer to the first json
    json_object_default <- json_object_list[[1]]
    
    # Unpack the forests
    include_mean_forest <- json_object_default$get_boolean("include_mean_forest")
    include_variance_forest <- json_object_default$get_boolean("include_variance_forest")
    if (include_mean_forest) {
        output[["mean_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0")
        if (include_variance_forest) {
            output[["variance_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_1")
        }
    } else {
        output[["variance_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0")
    }
    
    # Unpack metadata
    train_set_metadata = list()
    train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar("num_numeric_vars")
    train_set_metadata[["num_ordered_cat_vars"]] <- json_object_default$get_scalar("num_ordered_cat_vars")
    train_set_metadata[["num_unordered_cat_vars"]] <- json_object_default$get_scalar("num_unordered_cat_vars")
    if (train_set_metadata[["num_numeric_vars"]] > 0) {
        train_set_metadata[["numeric_vars"]] <- json_object_default$get_string_vector("numeric_vars")
    }
    if (train_set_metadata[["num_ordered_cat_vars"]] > 0) {
        train_set_metadata[["ordered_cat_vars"]] <- json_object_default$get_string_vector("ordered_cat_vars")
        train_set_metadata[["ordered_unique_levels"]] <- json_object_default$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]])
    }
    if (train_set_metadata[["num_unordered_cat_vars"]] > 0) {
        train_set_metadata[["unordered_cat_vars"]] <- json_object_default$get_string_vector("unordered_cat_vars")
        train_set_metadata[["unordered_unique_levels"]] <- json_object_default$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]])
    }
    output[["train_set_metadata"]] <- train_set_metadata

    # Unpack model params
    model_params = list()
    model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale")
    model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean")
    model_params[["standardize"]] <- json_object_default$get_boolean("standardize")
    model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init")
    model_params[["sample_sigma_global"]] <- json_object_default$get_boolean("sample_sigma_global")
    model_params[["sample_sigma_leaf"]] <- json_object_default$get_boolean("sample_sigma_leaf")
    model_params[["include_mean_forest"]] <- include_mean_forest
    model_params[["include_variance_forest"]] <- include_variance_forest
    model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx")
    model_params[["has_rfx_basis"]] <- json_object_default$get_boolean("has_rfx_basis")
    model_params[["num_rfx_basis"]] <- json_object_default$get_scalar("num_rfx_basis")
    model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates")
    model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis")
    model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
    model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
    model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis")
    
    # Combine values that are sample-specific
    for (i in 1:length(json_object_list)) {
        json_object <- json_object_list[[i]]
        if (i == 1) {
            model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr")
            model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin")
            model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc")
            model_params[["num_samples"]] <- json_object$get_scalar("num_samples")
        } else {
            model_params[["num_gfr"]] <- model_params[["num_gfr"]] + json_object$get_scalar("num_gfr")
            model_params[["num_burnin"]] <- model_params[["num_burnin"]] + json_object$get_scalar("num_burnin")
            model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + json_object$get_scalar("num_mcmc")
            model_params[["num_samples"]] <- model_params[["num_samples"]] + json_object$get_scalar("num_samples")
        }
    }
    output[["model_params"]] <- model_params
    
    # Unpack sampled parameters
    if (model_params[["sample_sigma_global"]]) {
        for (i in 1:length(json_object_list)) {
            json_object <- json_object_list[[i]]
            if (i == 1) {
                output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters")
            } else {
                output[["sigma2_global_samples"]] <- c(output[["sigma2_global_samples"]], json_object$get_vector("sigma2_global_samples", "parameters"))
            }
        }
    }
    if (model_params[["sample_sigma_leaf"]]) {
        for (i in 1:length(json_object_list)) {
            json_object <- json_object_list[[i]]
            if (i == 1) {
                output[["sigma2_leaf_samples"]] <- json_object$get_vector("sigma2_leaf_samples", "parameters")
            } else {
                output[["sigma2_leaf_samples"]] <- c(output[["sigma2_leaf_samples"]], json_object$get_vector("sigma2_leaf_samples", "parameters"))
            }
        }
    }
    
    # Unpack random effects
    if (model_params[["has_rfx"]]) {
        output[["rfx_unique_group_ids"]] <- json_object_default$get_string_vector("rfx_unique_group_ids")
        output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0)
    }
    
    # Unpack covariate preprocessor
    preprocessor_metadata_string <- json_object_default$get_string("preprocessor_metadata")
    output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
        preprocessor_metadata_string
    )
    
    class(output) <- "bartmodel"
    return(output)
}
