R/model.R

Defines functions createForestModel createCppRNG

Documented in createCppRNG createForestModel

#' Class that wraps a C++ random number generator (for reproducibility)
#'
#' @description
#' Persists a C++ random number generator throughout an R session to 
#' ensure reproducibility from a given random seed. If no seed is provided, 
#' the C++ random number generator is initialized using `std::random_device`.

CppRNG <- R6::R6Class(
    classname = "CppRNG",
    cloneable = FALSE,
    public = list(
        
        #' @field rng_ptr External pointer to a C++ std::mt19937 class
        rng_ptr = NULL,

        #' @description
        #' Create a new CppRNG object.
        #' @param random_seed (Optional) random seed for sampling
        #' @return A new `CppRNG` object.
        initialize = function(random_seed = -1) {
            self$rng_ptr <- rng_cpp(random_seed)
        }
    )
)

#' Class that defines and samples a forest model
#'
#' @description
#' Hosts the C++ data structures needed to sample an ensemble of decision 
#' trees, and exposes functionality to run a forest sampler 
#' (using either MCMC or the grow-from-root algorithm).

ForestModel <- R6::R6Class(
    classname = "ForestModel",
    cloneable = FALSE,
    public = list(
        
        #' @field tracker_ptr External pointer to a C++ ForestTracker class
        tracker_ptr = NULL,
        
        #' @field tree_prior_ptr External pointer to a C++ TreePrior class
        tree_prior_ptr = NULL, 
        
        #' @description
        #' Create a new ForestModel object.
        #' @param forest_dataset `ForestDataset` object, used to initialize forest sampling data structures
        #' @param feature_types Feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
        #' @param num_trees Number of trees in the forest being sampled
        #' @param n Number of observations in `forest_dataset`
        #' @param alpha Root node split probability in tree prior
        #' @param beta Depth prior penalty in tree prior
        #' @param min_samples_leaf Minimum number of samples in a tree leaf
        #' @param max_depth Maximum depth that any tree can reach
        #' @return A new `ForestModel` object.
        initialize = function(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth = -1) {
            stopifnot(!is.null(forest_dataset$data_ptr))
            self$tracker_ptr <- forest_tracker_cpp(forest_dataset$data_ptr, feature_types, num_trees, n)
            self$tree_prior_ptr <- tree_prior_cpp(alpha, beta, min_samples_leaf, max_depth)
        }, 
        
        #' @description
        #' Run a single iteration of the forest sampling algorithm (MCMC or GFR)
        #' @param forest_dataset Dataset used to sample the forest
        #' @param residual Outcome used to sample the forest
        #' @param forest_samples Container of forest samples
        #' @param active_forest "Active" forest updated by the sampler in each iteration
        #' @param rng Wrapper around C++ random number generator
        #' @param forest_model_config ForestModelConfig object containing forest model parameters and settings
        #' @param global_model_config GlobalModelConfig object containing global model parameters and settings
        #' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `TRUE`.
        #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `TRUE`.
        sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest, 
                                        rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = TRUE) {
            if (active_forest$is_empty()) {
                stop("`active_forest` has not yet been initialized, which is necessary to run the sampler. Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods.")
            }

            # Unpack parameters from model config object
            feature_types <- forest_model_config$feature_types
            leaf_model_int <- forest_model_config$leaf_model_type
            leaf_model_scale <- forest_model_config$leaf_model_scale
            variable_weights <- forest_model_config$variable_weights
            a_forest <- forest_model_config$variance_forest_shape
            b_forest <- forest_model_config$variance_forest_scale
            global_scale <- global_model_config$global_error_variance
            cutpoint_grid_size <- forest_model_config$cutpoint_grid_size
            
            if (gfr) {
                sample_gfr_one_iteration_cpp(
                    forest_dataset$data_ptr, residual$data_ptr, 
                    forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr, 
                    self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale, 
                    variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest
                )
            } else {
                sample_mcmc_one_iteration_cpp(
                    forest_dataset$data_ptr, residual$data_ptr, 
                    forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr, 
                    self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale, 
                    variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest
                ) 
            }
        }, 
        
        #' @description
        #' Propagates basis update through to the (full/partial) residual by iteratively 
        #' (a) adding back in the previous prediction of each tree, (b) recomputing predictions 
        #' for each tree (caching on the C++ side), (c) subtracting the new predictions from the residual.
        #' 
        #' This is useful in cases where a basis (for e.g. leaf regression) is updated outside 
        #' of a tree sampler (as with e.g. adaptive coding for binary treatment BCF). 
        #' Once a basis has been updated, the overall "function" represented by a tree model has 
        #' changed and this should be reflected through to the residual before the next sampling loop is run.
        #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest
        #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions
        #' @param active_forest "Active" forest updated by the sampler in each iteration
        propagate_basis_update = function(dataset, outcome, active_forest) {
            stopifnot(!is.null(dataset$data_ptr))
            stopifnot(!is.null(outcome$data_ptr))
            stopifnot(!is.null(self$tracker_ptr))
            stopifnot(!is.null(active_forest$forest_ptr))
            
            propagate_basis_update_active_forest_cpp(
                dataset$data_ptr, outcome$data_ptr, active_forest$forest_ptr, 
                self$tracker_ptr
            )
        }, 
        
        #' @description
        #' Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree. 
        #' This function is run after the `Outcome` class's `update_data` method, which overwrites the partial residual with an entirely new stream of outcome data.
        #' @param residual Outcome used to sample the forest
        #' @return None
        propagate_residual_update = function(residual) {
            propagate_trees_column_vector_cpp(self$tracker_ptr, residual$data_ptr)
        }, 
        
        #' @description
        #' Update alpha in the tree prior
        #' @param alpha New value of alpha to be used
        #' @return None
        update_alpha = function(alpha) {
            update_alpha_tree_prior_cpp(self$tree_prior_ptr, alpha)
        }, 
        
        #' @description
        #' Update beta in the tree prior
        #' @param beta New value of beta to be used
        #' @return None
        update_beta = function(beta) {
            update_beta_tree_prior_cpp(self$tree_prior_ptr, beta)
        }, 
        
        #' @description
        #' Update min_samples_leaf in the tree prior
        #' @param min_samples_leaf New value of min_samples_leaf to be used
        #' @return None
        update_min_samples_leaf = function(min_samples_leaf) {
            update_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr, min_samples_leaf)
        }, 
        
        #' @description
        #' Update max_depth in the tree prior
        #' @param max_depth New value of max_depth to be used
        #' @return None
        update_max_depth = function(max_depth) {
            update_max_depth_tree_prior_cpp(self$tree_prior_ptr, max_depth)
        }
    )
)

#' Create an R class that wraps a C++ random number generator
#'
#' @param random_seed (Optional) random seed for sampling
#'
#' @return `CppRng` object
#' @export
#' 
#' @examples
#' rng <- createCppRNG(1234)
#' rng <- createCppRNG()
createCppRNG <- function(random_seed = -1){
    return(invisible((
        CppRNG$new(random_seed)
    )))
}

#' Create a forest model object
#'
#' @param forest_dataset ForestDataset object, used to initialize forest sampling data structures
#' @param forest_model_config ForestModelConfig object containing forest model parameters and settings
#' @param global_model_config GlobalModelConfig object containing global model parameters and settings
#'
#' @return `ForestModel` object
#' @export
#' 
#' @examples
#' num_trees <- 100
#' n <- 100
#' p <- 10
#' alpha <- 0.95
#' beta <- 2.0
#' min_samples_leaf <- 2
#' max_depth <- 10
#' feature_types <- as.integer(rep(0, p))
#' X <- matrix(runif(n*p), ncol = p)
#' forest_dataset <- createForestDataset(X)
#' forest_model_config <- createForestModelConfig(feature_types=feature_types, 
#'                                                num_trees=num_trees, num_features=p, 
#'                                                num_observations=n, alpha=alpha, beta=beta, 
#'                                                min_samples_leaf=min_samples_leaf, 
#'                                                max_depth=max_depth, leaf_model_type=1)
#' global_model_config <- createGlobalModelConfig(global_error_variance=1.0)
#' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
createForestModel <- function(forest_dataset, forest_model_config, global_model_config) {
    return(invisible((
        ForestModel$new(forest_dataset, forest_model_config$feature_types, forest_model_config$num_trees, 
                        forest_model_config$num_observations, forest_model_config$alpha, forest_model_config$beta, 
                        forest_model_config$min_samples_leaf, forest_model_config$max_depth)
    )))
}

Try the stochtree package in your browser

Any scripts or data that you put into this service are public.

stochtree documentation built on April 4, 2025, 2:11 a.m.