ForestModel: Class that defines and samples a forest model

Description

Hosts the C++ data structures needed to sample an ensemble of decision trees, and exposes functionality to run a forest sampler (using either MCMC or the grow-from-root (GFR) algorithm).

Public fields

tracker_ptr

External pointer to a C++ ForestTracker object

tree_prior_ptr

External pointer to a C++ TreePrior object

Methods

Public methods


Method new()

Create a new ForestModel object.

Usage
ForestModel$new(
  forest_dataset,
  feature_types,
  num_trees,
  n,
  alpha,
  beta,
  min_samples_leaf,
  max_depth = -1
)
Arguments
forest_dataset

ForestDataset object, used to initialize forest sampling data structures

feature_types

Integer vector with one entry per covariate giving its feature type (0 = numeric, 1 = ordered categorical, 2 = unordered categorical)

num_trees

Number of trees in the forest being sampled

n

Number of observations in forest_dataset

alpha

Root node split probability in the tree prior

beta

Depth penalty in the tree prior; larger values make splits at deeper nodes less likely

min_samples_leaf

Minimum number of samples in a tree leaf

max_depth

Maximum depth that any tree can reach. Default: -1, which imposes no depth limit.
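The feature_types codes can be derived from the column classes of a covariate data frame. The sketch below assumes numeric columns map to 0, ordered factors to 1 and unordered factors to 2, following the coding described above; infer_feature_types() is a hypothetical helper, not part of stochtree.

```r
# Hypothetical helper (not part of stochtree): map each column of a covariate
# data frame to a feature_types code.
# 0 = numeric, 1 = ordered categorical, 2 = unordered categorical
infer_feature_types <- function(df) {
  codes <- vapply(df, function(col) {
    if (is.ordered(col)) {
      1L  # ordered factor: checked first, since it is also a factor
    } else if (is.factor(col)) {
      2L  # unordered factor
    } else if (is.numeric(col)) {
      0L  # numeric column
    } else {
      stop("unsupported column class: ", class(col)[1])
    }
  }, integer(1))
  unname(codes)
}

covariates <- data.frame(
  x1 = c(0.5, 1.2, 3.4),
  x2 = factor(c("low", "mid", "high"), levels = c("low", "mid", "high"), ordered = TRUE),
  x3 = factor(c("a", "b", "a"))
)
feature_types <- infer_feature_types(covariates)
print(feature_types)  # [1] 0 1 2
```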

Returns

A new ForestModel object.
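The alpha and beta arguments parameterize the tree prior. Assuming it follows the standard BART form, a node at depth d splits with probability alpha * (1 + d)^(-beta). The hypothetical split_prob() below simply evaluates that formula, using the classic BART values alpha = 0.95 and beta = 2 as illustrative defaults (not necessarily stochtree's defaults).

```r
# Standard BART tree prior (assumed here): a node at depth d splits with
# probability alpha * (1 + d)^(-beta). split_prob() is a hypothetical helper.
split_prob <- function(depth, alpha = 0.95, beta = 2) {
  alpha * (1 + depth)^(-beta)
}

# Split probabilities at depths 0, 1, 2 with alpha = 0.95, beta = 2
print(round(split_prob(0:2), 4))  # [1] 0.9500 0.2375 0.1056

# A larger beta penalizes depth more strongly
split_prob(2, beta = 1) > split_prob(2, beta = 2)  # TRUE
```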


Method sample_one_iteration()

Run a single iteration of the forest sampling algorithm (MCMC or GFR)

Usage
ForestModel$sample_one_iteration(
  forest_dataset,
  residual,
  forest_samples,
  active_forest,
  rng,
  forest_model_config,
  global_model_config,
  keep_forest = TRUE,
  gfr = TRUE
)
Arguments
forest_dataset

Dataset used to sample the forest

residual

Outcome used to sample the forest

forest_samples

Container of forest samples

active_forest

"Active" forest updated by the sampler in each iteration

rng

Wrapper around C++ random number generator

forest_model_config

ForestModelConfig object containing forest model parameters and settings

global_model_config

GlobalModelConfig object containing global model parameters and settings

keep_forest

(Optional) Whether the updated forest sample should be saved to forest_samples. Default: TRUE.

gfr

(Optional) Whether the forest should be sampled using the "grow-from-root" (GFR) algorithm; if FALSE, MCMC is used instead. Default: TRUE.
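Since gfr and keep_forest are separate flags, a sampling loop decides both on every call. The sketch below assumes a common pattern (a few GFR warm-start iterations, then MCMC burn-in, then retained MCMC draws); sampler_schedule() is a hypothetical helper, and the commented call only shows where its flags would plug into the documented signature.

```r
# Hypothetical helper: per-iteration gfr / keep_forest flags, assuming a run
# of GFR warm-start iterations, then MCMC burn-in, then retained MCMC draws.
sampler_schedule <- function(num_gfr, num_burnin, num_mcmc) {
  iter <- seq_len(num_gfr + num_burnin + num_mcmc)
  data.frame(
    iter = iter,
    gfr = iter <= num_gfr,                      # GFR first, MCMC afterwards
    keep_forest = iter > num_gfr + num_burnin   # keep only post-burn-in MCMC draws
  )
}

sched <- sampler_schedule(num_gfr = 2, num_burnin = 1, num_mcmc = 3)
print(sched)

# Each row would then drive one call such as (objects created elsewhere):
# forest_model$sample_one_iteration(
#   forest_dataset, residual, forest_samples, active_forest, rng,
#   forest_model_config, global_model_config,
#   keep_forest = sched$keep_forest[i], gfr = sched$gfr[i]
# )
```

Whether GFR or burn-in draws are retained is a modeling choice; the schedule above is only one option.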


Method propagate_basis_update()

Propagates a basis update through to the (full or partial) residual by iterating over the trees and, for each tree, (a) adding its previous prediction back to the residual, (b) recomputing its prediction under the new basis (cached on the C++ side), and (c) subtracting the new prediction from the residual.

This is useful when a basis (e.g. for leaf regression) is updated outside of the tree sampler, as with adaptive coding of a binary treatment in BCF. Once a basis has been updated, the overall function represented by the tree ensemble has changed, and this change must be reflected in the residual before the next sampling loop is run.

Usage
ForestModel$propagate_basis_update(dataset, outcome, active_forest)
Arguments
dataset

ForestDataset object storing the covariates and bases for a given forest

outcome

Outcome object storing the residuals to be updated based on forest predictions

active_forest

"Active" forest updated by the sampler in each iteration

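The three steps of propagate_basis_update() can be sketched in plain R for a univariate leaf-regression forest. This is an illustration under an assumption: each tree is summarized by the coefficient of the leaf each observation falls into, so its prediction is basis * coefficient. The real method operates on C++ data structures and cached predictions; propagate_basis_update_sketch() and forest_pred() are hypothetical helpers.

```r
# Plain-R sketch of propagate_basis_update() for a univariate leaf-regression
# forest. Assumption: each tree is summarized by the leaf coefficient of the
# leaf each observation falls into, so its prediction is basis * coefficient.
propagate_basis_update_sketch <- function(residual, leaf_coefs, old_basis, new_basis) {
  for (leaf_coef in leaf_coefs) {
    residual <- residual + old_basis * leaf_coef  # (a) add back the previous prediction
    new_pred <- new_basis * leaf_coef             # (b) recompute the tree's prediction
    residual <- residual - new_pred               # (c) subtract the new prediction
  }
  residual
}

# Sum of all tree predictions for a given basis
forest_pred <- function(basis, leaf_coefs) {
  Reduce(`+`, lapply(leaf_coefs, function(leaf_coef) basis * leaf_coef))
}

y <- c(1, 2, 3, 4)
old_basis <- c(0, 0, 1, 1)            # e.g. a binary treatment coded 0/1
new_basis <- c(-0.5, -0.5, 0.5, 0.5)  # e.g. the same treatment after recoding
leaf_coefs <- list(c(0.2, 0.2, 0.4, 0.4), c(0.1, 0.3, 0.1, 0.3))

residual <- y - forest_pred(old_basis, leaf_coefs)
residual <- propagate_basis_update_sketch(residual, leaf_coefs, old_basis, new_basis)

# The residual now reflects the forest's predictions under the new basis
isTRUE(all.equal(residual, y - forest_pred(new_basis, leaf_coefs)))  # TRUE
```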

Method propagate_residual_update()

Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree. This function is run after the Outcome class's update_data method, which overwrites the partial residual with an entirely new stream of outcome data.

Usage
ForestModel$propagate_residual_update(residual)
Arguments
residual

Outcome used to sample the forest

Returns

None
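In plain R, the update amounts to starting from the freshly overwritten outcome and subtracting each tree's current prediction in turn. The sketch below assumes per-tree predictions are available as numeric vectors; propagate_residual_update_sketch() is a hypothetical helper, not stochtree's implementation.

```r
# Plain-R sketch: after the outcome has been overwritten with new data,
# subtract each tree's current prediction so it becomes a partial residual again.
propagate_residual_update_sketch <- function(new_outcome, tree_preds) {
  residual <- new_outcome
  for (pred in tree_preds) {
    residual <- residual - pred  # remove one tree's contribution
  }
  residual
}

y_new <- c(5, 6, 7)
tree_preds <- list(c(1, 1, 1), c(0.5, 0, -0.5))
res <- propagate_residual_update_sketch(y_new, tree_preds)
print(res)  # [1] 3.5 5.0 6.5
```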


Method update_alpha()

Update alpha in the tree prior

Usage
ForestModel$update_alpha(alpha)
Arguments
alpha

New value of alpha to be used

Returns

None


Method update_beta()

Update beta in the tree prior

Usage
ForestModel$update_beta(beta)
Arguments
beta

New value of beta to be used

Returns

None


Method update_min_samples_leaf()

Update min_samples_leaf in the tree prior

Usage
ForestModel$update_min_samples_leaf(min_samples_leaf)
Arguments
min_samples_leaf

New value of min_samples_leaf to be used

Returns

None


Method update_max_depth()

Update max_depth in the tree prior

Usage
ForestModel$update_max_depth(max_depth)
Arguments
max_depth

New value of max_depth to be used

Returns

None
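Unlike alpha and beta, which shape split probabilities, min_samples_leaf and max_depth act as hard constraints on which splits are allowed. The sketch below assumes a split is admissible only if the node is shallower than max_depth (with -1 meaning no limit) and both children keep at least min_samples_leaf observations; split_allowed() is a hypothetical helper, not stochtree's API.

```r
# Hypothetical check (not stochtree's API) of the two hard constraints set by
# update_min_samples_leaf() and update_max_depth().
split_allowed <- function(depth, n_left, n_right, min_samples_leaf, max_depth = -1) {
  depth_ok <- max_depth < 0 || depth < max_depth  # -1 assumed to mean "no limit"
  leaf_ok <- n_left >= min_samples_leaf && n_right >= min_samples_leaf
  depth_ok && leaf_ok
}

split_allowed(depth = 2, n_left = 10, n_right = 3, min_samples_leaf = 5)                 # FALSE: right child too small
split_allowed(depth = 2, n_left = 10, n_right = 8, min_samples_leaf = 5, max_depth = 2)  # FALSE: depth limit reached
split_allowed(depth = 2, n_left = 10, n_right = 8, min_samples_leaf = 5)                 # TRUE
```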


stochtree documentation built on April 4, 2025, 2:11 a.m.