R/config.R

Defines functions createGlobalModelConfig createForestModelConfig

Documented in createForestModelConfig createGlobalModelConfig

#' Object used to get / set parameters and other model configuration options
#' for a forest model in the "low-level" stochtree interface
#'
#' @description
#' The "low-level" stochtree interface enables a high degree of sampler
#' customization, in which users employ R wrappers around C++ objects
#' like ForestDataset, Outcome, CppRng, and ForestModel to run the
#' Gibbs sampler of a BART model with custom modifications.
#' ForestModelConfig allows users to specify / query the parameters of a
#' forest model they wish to run.

ForestModelConfig <- R6::R6Class(
    classname = "ForestModelConfig",
    cloneable = FALSE,
    public = list(

        #' @field feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
        feature_types = NULL,
        
        #' @field num_trees Number of trees in the forest being sampled
        num_trees = NULL,
        
        #' @field num_features Number of features in training dataset
        num_features = NULL,
        
        #' @field num_observations Number of observations in training dataset
        num_observations = NULL,
        
        #' @field leaf_dimension Dimension of the leaf model
        leaf_dimension = NULL,
        
        #' @field alpha Root node split probability in tree prior
        alpha = NULL,
        
        #' @field beta Depth prior penalty in tree prior
        beta = NULL,
        
        #' @field min_samples_leaf Minimum number of samples in a tree leaf
        min_samples_leaf = NULL,
        
        #' @field max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees.
        max_depth = NULL,
        
        #' @field leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression, 3 = IG variance leaf model, see `variance_forest_shape` / `variance_forest_scale`)
        leaf_model_type = NULL,
        
        #' @field leaf_model_scale Scale parameter used in Gaussian leaf models, stored as a `leaf_dimension` x `leaf_dimension` matrix
        leaf_model_scale = NULL,
        
        #' @field variable_weights Vector specifying sampling probability for all p covariates in ForestDataset
        variable_weights = NULL,
        
        #' @field variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`)
        variance_forest_shape = NULL,
        
        #' @field variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`)
        variance_forest_scale = NULL,
        
        #' @field cutpoint_grid_size Number of unique cutpoints to consider
        cutpoint_grid_size = NULL,

        #' @description
        #' Create a new ForestModelConfig object.
        #'
        #' At least one of `feature_types` and `num_features` must be supplied.
        #' Missing `feature_types` default to all-numeric and missing
        #' `variable_weights` default to equal weights (each with a warning).
        #'
        #' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
        #' @param num_trees Number of trees in the forest being sampled
        #' @param num_features Number of features in training dataset
        #' @param num_observations Number of observations in training dataset
        #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset
        #' @param leaf_dimension Dimension of the leaf model (default: `1`)
        #' @param alpha Root node split probability in tree prior (default: `0.95`)
        #' @param beta Depth prior penalty in tree prior (default: `2.0`)
        #' @param min_samples_leaf Minimum number of samples in a tree leaf (default: `5`)
        #' @param max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. Default: `-1`.
        #' @param leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression, 3 = IG variance leaf model). Must be a whole number between 0 and 3. Default: `1`.
        #' @param leaf_model_scale Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_type = 2`). Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models.
        #' @param variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`.
        #' @param variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`.
        #' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`)
        #' 
        #' @return A new ForestModelConfig object.
        initialize = function(feature_types = NULL, num_trees = NULL, num_features = NULL, 
                              num_observations = NULL, variable_weights = NULL, leaf_dimension = 1, 
                              alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1, 
                              leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0, 
                              variance_forest_scale = 1.0, cutpoint_grid_size = 100) {
            # Resolve `feature_types` / `num_features` from whichever was given
            if (is.null(feature_types)) {
                if (is.null(num_features)) {
                    stop("Neither of `num_features` nor `feature_types` (a vector from which `num_features` can be inferred) was provided. Please provide at least one of these inputs when creating a ForestModelConfig object.")
                }
                warning("`feature_types` not provided, will be assumed to be numeric")
                feature_types <- rep(0, num_features)
            } else {
                if (is.null(num_features)) {
                    num_features <- length(feature_types)
                }
            }
            if (is.null(variable_weights)) {
                warning("`variable_weights` not provided, will be assumed to be equal-weighted")
                variable_weights <- rep(1/num_features, num_features)
            }
            
            # Per-feature vectors must agree with the declared feature count
            if (num_features != length(feature_types)) {
                stop("`feature_types` must have `num_features` total elements")
            }
            if (num_features != length(variable_weights)) {
                stop("`variable_weights` must have `num_features` total elements")
            }
            self$feature_types <- feature_types
            self$variable_weights <- variable_weights
            self$num_trees <- num_trees
            self$num_features <- num_features
            self$num_observations <- num_observations
            self$leaf_dimension <- leaf_dimension
            self$alpha <- alpha
            self$beta <- beta
            self$min_samples_leaf <- min_samples_leaf
            self$max_depth <- max_depth
            self$variance_forest_shape <- variance_forest_shape
            self$variance_forest_scale <- variance_forest_scale
            self$cutpoint_grid_size <- cutpoint_grid_size
            
            # Validate the leaf model code: it must be whole-valued AND in
            # range. The two checks are independent so that an integer-valued
            # but out-of-range code (e.g. 7) is rejected as well.
            if (!(as.integer(leaf_model_type) == leaf_model_type)) {
                stop("`leaf_model_type` must be an integer between 0 and 3")
            }
            if ((leaf_model_type < 0) || (leaf_model_type > 3)) {
                stop("`leaf_model_type` must be an integer between 0 and 3")
            }
            self$leaf_model_type <- leaf_model_type
            
            # The leaf scale is always stored as a leaf_dimension x
            # leaf_dimension matrix; scalars are spread along the diagonal.
            if (is.null(leaf_model_scale)) {
                self$leaf_model_scale <- diag(1/num_trees, leaf_dimension)
            } else if (is.matrix(leaf_model_scale)) {
                if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) {
                    stop("`leaf_model_scale` must be a square matrix")
                }
                if (ncol(leaf_model_scale) != leaf_dimension) {
                    stop("`leaf_model_scale` must have `leaf_dimension` rows and columns")
                }
                self$leaf_model_scale <- leaf_model_scale
            } else {
                if (leaf_model_scale <= 0) {
                    stop("`leaf_model_scale` must be positive, if provided as scalar")
                }
                self$leaf_model_scale <- diag(leaf_model_scale, leaf_dimension)
            }
        },
        
        #' @description
        #' Update feature types
        #' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
        update_feature_types = function(feature_types) {
            stopifnot(length(feature_types) == self$num_features)
            self$feature_types <- feature_types
        }, 
        
        #' @description
        #' Update variable weights
        #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset
        update_variable_weights = function(variable_weights) {
            stopifnot(length(variable_weights) == self$num_features)
            self$variable_weights <- variable_weights
        }, 
        
        #' @description
        #' Update root node split probability in tree prior
        #' @param alpha Root node split probability in tree prior
        update_alpha = function(alpha) {
            self$alpha <- alpha
        }, 
        
        #' @description
        #' Update depth prior penalty in tree prior
        #' @param beta Depth prior penalty in tree prior
        update_beta = function(beta) {
            self$beta <- beta
        }, 
        
        #' @description
        #' Update minimum number of samples in a tree leaf
        #' @param min_samples_leaf Minimum number of samples in a tree leaf
        update_min_samples_leaf = function(min_samples_leaf) {
            self$min_samples_leaf <- min_samples_leaf
        }, 
        
        #' @description
        #' Update maximum depth of any tree in the ensemble
        #' @param max_depth Maximum depth of any tree in the ensemble in the model
        update_max_depth = function(max_depth) {
            self$max_depth <- max_depth
        }, 
        
        #' @description
        #' Update scale parameter used in Gaussian leaf models
        #' @param leaf_model_scale Scale parameter used in Gaussian leaf models (either a positive scalar or a square matrix with `leaf_dimension` rows and columns)
        update_leaf_model_scale = function(leaf_model_scale) {
            if (is.matrix(leaf_model_scale)) {
                if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) {
                    stop("`leaf_model_scale` must be a square matrix")
                }
                if (ncol(leaf_model_scale) != self$leaf_dimension) {
                    stop("`leaf_model_scale` must have `leaf_dimension` rows and columns")
                }
                self$leaf_model_scale <- leaf_model_scale
            } else {
                if (leaf_model_scale <= 0) {
                    stop("`leaf_model_scale` must be positive, if provided as scalar")
                }
                # The dimension lives on the object; there is no local
                # `leaf_dimension` inside this method.
                self$leaf_model_scale <- diag(leaf_model_scale, self$leaf_dimension)
            }
        }, 
        
        #' @description
        #' Update shape parameter for IG leaf models
        #' @param variance_forest_shape Shape parameter for IG leaf models
        update_variance_forest_shape = function(variance_forest_shape) {
            self$variance_forest_shape <- variance_forest_shape
        }, 
        
        #' @description
        #' Update scale parameter for IG leaf models
        #' @param variance_forest_scale Scale parameter for IG leaf models
        update_variance_forest_scale = function(variance_forest_scale) {
            self$variance_forest_scale <- variance_forest_scale
        }, 
        
        #' @description
        #' Update number of unique cutpoints to consider
        #' @param cutpoint_grid_size Number of unique cutpoints to consider
        update_cutpoint_grid_size = function(cutpoint_grid_size) {
            self$cutpoint_grid_size <- cutpoint_grid_size
        },
        
        #' @description
        #' Query feature types for this ForestModelConfig object
        #' @returns Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
        get_feature_types = function() {
            return(self$feature_types)
        }, 
        
        #' @description
        #' Query variable weights for this ForestModelConfig object
        #' @returns Vector specifying sampling probability for all p covariates in ForestDataset
        get_variable_weights = function() {
            return(self$variable_weights)
        }, 
        
        #' @description
        #' Query root node split probability in tree prior for this ForestModelConfig object
        #' @returns Root node split probability in tree prior
        get_alpha = function() {
            return(self$alpha)
        }, 
        
        #' @description
        #' Query depth prior penalty in tree prior for this ForestModelConfig object
        #' @returns Depth prior penalty in tree prior
        get_beta = function() {
            return(self$beta)
        }, 
        
        #' @description
        #' Query minimum number of samples in a tree leaf for this ForestModelConfig object
        #' @returns Minimum number of samples in a tree leaf
        get_min_samples_leaf = function() {
            return(self$min_samples_leaf)
        }, 
        
        #' @description
        #' Query maximum tree depth for this ForestModelConfig object
        #' @returns Maximum depth of any tree in the ensemble in the model
        get_max_depth = function() {
            return(self$max_depth)
        }, 
        
        #' @description
        #' Query scale parameter used in Gaussian leaf models for this ForestModelConfig object
        #' @returns Scale parameter used in Gaussian leaf models
        get_leaf_model_scale = function() {
            return(self$leaf_model_scale)
        }, 
        
        #' @description
        #' Query shape parameter for IG leaf models for this ForestModelConfig object
        #' @returns Shape parameter for IG leaf models
        get_variance_forest_shape = function() {
            return(self$variance_forest_shape)
        }, 
        
        #' @description
        #' Query scale parameter for IG leaf models for this ForestModelConfig object
        #' @returns Scale parameter for IG leaf models
        get_variance_forest_scale = function() {
            return(self$variance_forest_scale)
        }, 
        
        #' @description
        #' Query number of unique cutpoints to consider for this ForestModelConfig object
        #' @returns Number of unique cutpoints to consider
        get_cutpoint_grid_size = function() {
            return(self$cutpoint_grid_size)
        }
    )
)

#' Object used to get / set global parameters and other global model 
#' configuration options in the "low-level" stochtree interface
#'
#' @description
#' The "low-level" stochtree interface enables a high degree of sampler
#' customization, in which users employ R wrappers around C++ objects
#' like ForestDataset, Outcome, CppRng, and ForestModel to run the
#' Gibbs sampler of a BART model with custom modifications.
#' GlobalModelConfig allows users to specify / query the global parameters 
#' of a model they wish to run.

GlobalModelConfig <- R6::R6Class(
    "GlobalModelConfig",
    public = list(

        #' @field global_error_variance Global error variance parameter
        global_error_variance = NULL,

        #' @description
        #' Construct a GlobalModelConfig holding the supplied global error variance.
        #'
        #' @param global_error_variance Global error variance parameter (default: `1.0`)
        #'
        #' @return A new GlobalModelConfig object.
        initialize = function(global_error_variance = 1.0) {
            self$global_error_variance <- global_error_variance
        },

        #' @description
        #' Overwrite the stored global error variance parameter
        #' @param global_error_variance Global error variance parameter
        update_global_error_variance = function(global_error_variance) {
            self$global_error_variance <- global_error_variance
        },

        #' @description
        #' Read the global error variance parameter stored on this GlobalModelConfig object
        #' @returns Global error variance parameter
        get_global_error_variance = function() {
            self$global_error_variance
        }
    ),
    cloneable = FALSE
)

#' Create a forest model config object
#'
#' Thin wrapper that forwards every argument to `ForestModelConfig$new()` and
#' returns the resulting object invisibly.
#'
#' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
#' @param num_trees Number of trees in the forest being sampled
#' @param num_features Number of features in training dataset
#' @param num_observations Number of observations in training dataset
#' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset
#' @param leaf_dimension Dimension of the leaf model (default: `1`)
#' @param alpha Root node split probability in tree prior (default: `0.95`)
#' @param beta Depth prior penalty in tree prior (default: `2.0`)
#' @param min_samples_leaf Minimum number of samples in a tree leaf (default: `5`)
#' @param max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. Default: `-1`.
#' @param leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: `1`.
#' @param leaf_model_scale Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_type = 2`). Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models.
#' @param variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`.
#' @param variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`.
#' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`)
#' @return ForestModelConfig object (returned invisibly)
#' @export
#'
#' @examples
#' config <- createForestModelConfig(num_trees = 10, num_features = 5, num_observations = 100)
createForestModelConfig <- function(feature_types = NULL, num_trees = NULL, num_features = NULL, 
                                    num_observations = NULL, variable_weights = NULL, leaf_dimension = 1, 
                                    alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1, 
                                    leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0, 
                                    variance_forest_scale = 1.0, cutpoint_grid_size = 100){
    # Forward by name so the mapping does not depend on argument order
    forest_config <- ForestModelConfig$new(
        feature_types = feature_types,
        num_trees = num_trees,
        num_features = num_features,
        num_observations = num_observations,
        variable_weights = variable_weights,
        leaf_dimension = leaf_dimension,
        alpha = alpha,
        beta = beta,
        min_samples_leaf = min_samples_leaf,
        max_depth = max_depth,
        leaf_model_type = leaf_model_type,
        leaf_model_scale = leaf_model_scale,
        variance_forest_shape = variance_forest_shape,
        variance_forest_scale = variance_forest_scale,
        cutpoint_grid_size = cutpoint_grid_size
    )
    invisible(forest_config)
}

#' Create a global model config object
#'
#' Thin wrapper around `GlobalModelConfig$new()`; the new object is returned
#' invisibly.
#'
#' @param global_error_variance Global error variance parameter (default: `1.0`)
#' @return GlobalModelConfig object (returned invisibly)
#' @export
#'
#' @examples
#' config <- createGlobalModelConfig(global_error_variance = 100)
createGlobalModelConfig <- function(global_error_variance = 1.0){
    global_config <- GlobalModelConfig$new(global_error_variance = global_error_variance)
    invisible(global_config)
}

Try the stochtree package in your browser

Any scripts or data that you put into this service are public.

stochtree documentation built on April 4, 2025, 2:11 a.m.