Forest (R Documentation)
Wrapper around a C++ class that stores a single ensemble of decision trees (often treated as the "active forest", i.e. the current state of a forest term, in an R sampling loop)
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checking is performed, so users are responsible for providing correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
forest_ptr: External pointer to a C++ TreeEnsemble class
internal_forest_is_empty: Whether the forest has not yet been "initialized", i.e. is not yet ready for its predict function to be called.
new(): Create a new Forest object.
Forest$new(num_trees, leaf_dimension = 1, is_leaf_constant = FALSE, is_exponentiated = FALSE)
num_trees: Number of trees in the forest
leaf_dimension: Dimensionality of the leaf model (the size of each leaf's parameter)
is_leaf_constant: Whether leaves are constant, i.e. whether predictions are made without a basis
is_exponentiated: Whether forest predictions should be exponentiated before being returned
Returns: A new Forest object.
merge_forest(): Create a larger forest by merging the trees of this forest with those of another forest
Forest$merge_forest(forest)
forest: Forest to be merged into this forest
add_constant(): Add a constant value to every leaf of every tree in the ensemble. If leaves are multi-dimensional, constant_value is added to every dimension of the leaves.
Forest$add_constant(constant_value)
constant_value: Value that will be added to every leaf of every tree
multiply_constant(): Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, every dimension of the leaves is multiplied by constant_multiple.
Forest$multiply_constant(constant_multiple)
constant_multiple: Value by which every leaf of every tree will be multiplied
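A minimal base-R sketch of the leaf arithmetic described for add_constant() and multiply_constant(). The list-of-matrices forest below (one row per leaf, one column per leaf dimension) is a hypothetical stand-in, not stochtree's TreeEnsemble layout, and the toy_* helpers are illustrative names only.

```r
# Hypothetical toy forest (not stochtree's internal layout): each tree is a
# matrix of leaf values with one row per leaf and one column per leaf dimension.
toy_forest <- list(
  matrix(c(0.5, -0.5, 1.0, 2.0), nrow = 2),  # tree 1: two leaves, leaf_dimension = 2
  matrix(c(1.5, 3.0), nrow = 1)              # tree 2: root only, leaf_dimension = 2
)

# Like add_constant(): add the same value to every dimension of every leaf.
toy_add_constant <- function(forest, constant_value) {
  lapply(forest, function(leaves) leaves + constant_value)
}

# Like multiply_constant(): scale every dimension of every leaf.
toy_multiply_constant <- function(forest, constant_multiple) {
  lapply(forest, function(leaves) leaves * constant_multiple)
}

shifted <- toy_add_constant(toy_forest, 1)
scaled <- toy_multiply_constant(toy_forest, 2)
print(shifted[[2]])  # root leaf becomes (2.5, 4.0)
print(scaled[[1]])   # leaves become (1, 2) and (-1, 4)
```

For a constant-leaf forest that is not exponentiated, predictions sum over trees, so adding constant_value to every leaf shifts each prediction by num_trees * constant_value, while multiply_constant() scales each prediction by constant_multiple.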
predict(): Predict the forest on every observation in forest_dataset
Forest$predict(forest_dataset)
forest_dataset: ForestDataset object
Returns: Vector of predictions with one entry per observation (row) in forest_dataset
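A minimal sketch of how a forest prediction is formed from its trees, using hypothetical depth-1 stumps in base R rather than the stochtree API. It assumes that observations with X[, feature] <= threshold go to the left child and, following the is_exponentiated description, that exp() is applied to the summed tree output.

```r
# Hypothetical toy trees (not stochtree's internals): each tree is a depth-1
# "stump" given by a feature index, a threshold, and two constant leaf values.
toy_trees <- list(
  list(feature = 1, threshold = 0.5, left = -1.0, right = 1.0),
  list(feature = 2, threshold = 0.0, left = 0.25, right = 0.75)
)

# Predict one tree on every row of X: rows with X[, feature] <= threshold
# land in the left leaf (an assumption), all other rows in the right leaf.
toy_predict_tree <- function(tree, X) {
  ifelse(X[, tree$feature] <= tree$threshold, tree$left, tree$right)
}

# Sum the tree predictions per observation; exponentiate the sum if requested.
toy_predict_forest <- function(trees, X, is_exponentiated = FALSE) {
  pred <- Reduce(`+`, lapply(trees, toy_predict_tree, X = X))
  if (is_exponentiated) exp(pred) else pred
}

X <- matrix(c(0.2, 0.9, 0.7,    # feature 1
              -1.0, 1.0, 0.0),  # feature 2
            ncol = 2)
print(toy_predict_forest(toy_trees, X))  # -0.75  1.75  1.25
```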
predict_raw(): Predict "raw" leaf values (without multiplying by the basis) for every observation in forest_dataset
Forest$predict_raw(forest_dataset)
forest_dataset: ForestDataset object
Returns: Array of predictions for each observation in forest_dataset, with each prediction having the dimensionality of the forest's leaf model. For a constant leaf model or a univariate leaf regression, this array is a vector whose length is the number of observations. For a multivariate leaf regression, it is a matrix with one row per observation and one column per leaf model dimension.
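A base-R sketch of how predict_raw() output relates to predict() for a multivariate leaf regression, assuming a hypothetical two-tree forest with a 2-dimensional leaf model. It assumes each tree's prediction is the dot product of an observation's basis row with its leaf parameters; by linearity, multiplying the summed raw values by the basis matches summing the basis-weighted tree predictions. All numbers are made up, not stochtree output.

```r
# Hypothetical per-tree raw leaf values for 3 observations (rows) and a
# 2-dimensional leaf model (columns).
raw_tree1 <- matrix(c(0.5, 0.5, -0.5,  1.0, 0.0, 0.5), ncol = 2)
raw_tree2 <- matrix(c(0.5, 0.0, -0.5,  1.0, 0.0, 0.5), ncol = 2)

# Summed over trees: an n x leaf_dimension matrix, the shape predict_raw()
# is documented to return in the multivariate case.
raw <- raw_tree1 + raw_tree2

# Basis W: one row per observation, one column per leaf dimension.
W <- matrix(c(1.0, 1.0, 1.0,
              0.5, 2.0, -1.0), ncol = 2)

# Basis-weighted prediction from the summed raw values ...
pred_from_raw <- rowSums(W * raw)
# ... equals the sum of each tree's basis-weighted prediction.
pred_per_tree <- rowSums(W * raw_tree1) + rowSums(W * raw_tree2)
print(pred_from_raw)  # 2.0  0.5 -2.0
```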
set_root_leaves(): Set a constant predicted value for every tree in the ensemble. Stops if any tree consists of more than a root node.
Forest$set_root_leaves(leaf_value)
leaf_value: Constant leaf value(s) to assign to the root of every tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.
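A toy calculation in base R (not a stochtree call) of what set_root_leaves() implies for predictions: every tree is a single root node, so every observation lands in it, and a non-exponentiated constant-leaf forest predicts num_trees * leaf_value everywhere. Choosing leaf_value as a target divided by num_trees is shown here only as an illustration; the outcome values are hypothetical.

```r
# Hypothetical outcome and forest size.
num_trees <- 200
y <- c(3.1, 2.4, 4.0, 3.5)

# Spread a target initial prediction (here the outcome mean) evenly over trees.
leaf_value <- mean(y) / num_trees
initial_prediction <- rep(num_trees * leaf_value, length(y))
print(leaf_value)          # 0.01625
print(initial_prediction)  # 3.25 for every observation

# With leaf_dimension = 2, leaf_value is a vector and each observation's raw
# prediction is num_trees times that vector.
leaf_value_2d <- c(0.001, -0.002)
raw_init <- matrix(num_trees * leaf_value_2d, nrow = length(y), ncol = 2,
                   byrow = TRUE)
```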
prepare_for_sampler(): Initialize every tree in the ensemble to a single root node carrying a constant value, and propagate this state to the ForestModel so that the two are synchronized before sampling begins.
Forest$prepare_for_sampler(dataset, outcome, forest_model, leaf_model_int, leaf_value)
dataset: ForestDataset object (covariates, basis, etc.)
outcome: Outcome object (residual / partial residual)
forest_model: ForestModel object storing tracking structures used in training / sampling
leaf_model_int: Integer value encoding the leaf model type (0 = constant Gaussian, 1 = univariate Gaussian, 2 = multivariate Gaussian, 3 = log-linear variance)
leaf_value: Constant leaf value(s) to assign to the root of every tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.
adjust_residual(): Adjust the residual based on the predictions of a forest
This is typically run just once at the beginning of a forest sampling algorithm: after the trees are initialized with constant root node predictions, those root predictions are subtracted from the residual.
Forest$adjust_residual(dataset, outcome, forest_model, requires_basis, add)
dataset: ForestDataset object storing the covariates and bases for a given forest
outcome: Outcome object storing the residuals to be updated based on forest predictions
forest_model: ForestModel object storing tracking structures used in training / sampling
requires_basis: Whether the forest requires a basis for prediction
add: Whether forest predictions should be added to (TRUE) or subtracted from (FALSE) the residuals
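A toy base-R sketch of the residual arithmetic controlled by the add argument. The forest_pred vector stands in for the forest's predictions (for a leaf regression with requires_basis = TRUE these would be basis-weighted), and toy_adjust_residual() is a hypothetical helper, not stochtree's implementation.

```r
# Hypothetical outcome and forest predictions (e.g. a forest freshly
# initialized so that every observation is predicted at 3.25).
y <- c(3.1, 2.4, 4.0, 3.5)
forest_pred <- rep(3.25, length(y))

toy_adjust_residual <- function(residual, forest_pred, add) {
  if (add) residual + forest_pred else residual - forest_pred
}

# add = FALSE: subtract the root predictions (the typical initial step).
residual <- toy_adjust_residual(y, forest_pred, add = FALSE)
print(residual)  # -0.15 -0.85  0.75  0.25

# add = TRUE: add the forest's contribution back.
restored <- toy_adjust_residual(residual, forest_pred, add = TRUE)
```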
num_trees(): Return the number of trees in a Forest object
Forest$num_trees()
Returns: Tree count
leaf_dimension(): Return the output dimension of the trees in a Forest object
Forest$leaf_dimension()
Returns: Leaf node parameter size
is_leaf_constant(): Return the constant leaf status of the trees in a Forest object
Forest$is_leaf_constant()
Returns: TRUE if leaves are constant, FALSE otherwise
is_exponentiated(): Return the exponentiation status of the trees in a Forest object
Forest$is_exponentiated()
Returns: TRUE if leaf predictions must be exponentiated, FALSE otherwise
add_numeric_split_tree(): Add a numeric split (i.e. X[,i] <= c) to a given tree in the ensemble
Forest$add_numeric_split_tree(tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value)
tree_num: Index of the tree to be split
leaf_num: Index of the leaf node to be split
feature_num: Index of the feature that defines the new split
split_threshold: Value that defines the cutoff of the new split
left_leaf_value: Value (or vector of values) to assign to the newly created left node
right_leaf_value: Value (or vector of values) to assign to the newly created right node
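A base-R sketch of what splitting a leaf does to a tree's node structure, using a hypothetical node table (not stochtree's TreeEnsemble layout) with scalar leaf values. The node ids, the toy_* helpers, and the rule that X[, feature] <= threshold goes left are illustrative assumptions; the last helper mirrors the idea behind get_tree_leaves().

```r
# Hypothetical toy tree: a data frame of nodes; a leaf has NA children.
new_toy_tree <- function(root_value) {
  data.frame(node = 0L, left = NA_integer_, right = NA_integer_,
             feature = NA_integer_, threshold = NA_real_, value = root_value)
}

# Turn leaf `leaf_num` into an internal node with two new leaf children.
toy_add_numeric_split <- function(tree, leaf_num, feature_num, split_threshold,
                                  left_leaf_value, right_leaf_value) {
  row <- match(leaf_num, tree$node)
  stopifnot(!is.na(row), is.na(tree$left[row]))  # must split an existing leaf
  left_id <- max(tree$node) + 1L
  right_id <- left_id + 1L
  tree$left[row] <- left_id
  tree$right[row] <- right_id
  tree$feature[row] <- feature_num
  tree$threshold[row] <- split_threshold
  tree$value[row] <- NA_real_  # internal nodes carry no leaf value
  rbind(tree,
        data.frame(node = c(left_id, right_id),
                   left = NA_integer_, right = NA_integer_,
                   feature = NA_integer_, threshold = NA_real_,
                   value = c(left_leaf_value, right_leaf_value)))
}

# Ids of the current leaf nodes (nodes without children).
toy_tree_leaves <- function(tree) tree$node[is.na(tree$left)]

tree <- new_toy_tree(0)
tree <- toy_add_numeric_split(tree, leaf_num = 0L, feature_num = 1L,
                              split_threshold = 0.5,
                              left_leaf_value = -1, right_leaf_value = 1)
tree <- toy_add_numeric_split(tree, leaf_num = 2L, feature_num = 2L,
                              split_threshold = 0.0,
                              left_leaf_value = 0.5, right_leaf_value = 1.5)
print(toy_tree_leaves(tree))  # 1 3 4
```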
get_tree_leaves(): Retrieve a vector of indices of the leaf nodes of a given tree in the forest
Forest$get_tree_leaves(tree_num)
tree_num: Index of the tree for which leaf indices will be retrieved
get_tree_split_counts(): Retrieve a vector of split counts for every training set variable in a given tree in the forest
Forest$get_tree_split_counts(tree_num, num_features)
tree_num: Index of the tree for which split counts will be retrieved
num_features: Total number of features in the training set
get_forest_split_counts(): Retrieve a vector of split counts for every training set variable in the forest
Forest$get_forest_split_counts(num_features)
num_features: Total number of features in the training set
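A base-R sketch of per-tree and forest-wide split counting, assuming a hypothetical list of the (1-based) feature indices used by each tree's internal nodes. It shows why num_features is needed: it fixes the length of the returned vector, so features that never appear in a split still get a count of zero.

```r
# Hypothetical toy data: feature indices used by the internal nodes of each tree.
tree_split_features <- list(
  c(1L, 3L, 1L),  # tree 1 splits twice on feature 1 and once on feature 3
  integer(0),     # tree 2 is a single root node (no splits)
  c(2L, 3L)       # tree 3
)
num_features <- 4L

# Per-tree counts, in the spirit of get_tree_split_counts().
toy_tree_split_counts <- function(features, num_features) {
  tabulate(features, nbins = num_features)
}

# Forest-wide counts, in the spirit of get_forest_split_counts().
toy_forest_split_counts <- function(trees, num_features) {
  Reduce(`+`, lapply(trees, toy_tree_split_counts, num_features = num_features))
}

print(toy_tree_split_counts(tree_split_features[[1]], num_features))  # 2 0 1 0
print(toy_forest_split_counts(tree_split_features, num_features))     # 2 1 2 0
```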
tree_max_depth(): Maximum depth of a specific tree in the forest
Forest$tree_max_depth(tree_num)
tree_num: Tree index within the forest
Returns: Maximum leaf depth
average_max_depth(): Average the maximum depths of all trees in the forest
Forest$average_max_depth()
Returns: Average maximum depth
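A base-R sketch of the quantities behind tree_max_depth() and average_max_depth(), using hypothetical nested-list trees rather than stochtree objects. It assumes the root sits at depth 0, so a root-only tree has maximum depth 0.

```r
# Maximum depth of a toy tree: a leaf has no `left` child and adds no depth.
toy_max_depth <- function(node) {
  if (is.null(node$left)) return(0L)
  1L + max(toy_max_depth(node$left), toy_max_depth(node$right))
}

leaf <- list(value = 0)
trees <- list(
  leaf,                                                       # root only
  list(left = leaf, right = leaf),                            # one split
  list(left = leaf, right = list(left = leaf, right = leaf))  # two levels
)

depths <- vapply(trees, toy_max_depth, integer(1))
print(depths)        # 0 1 2
print(mean(depths))  # 1
```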
is_empty(): Check whether the forest is still empty (uninitialized).
When a Forest object is created, it is "empty" in the sense that none
of its component trees have leaves with values. There are two ways to
"initialize" a Forest object. First, the set_root_leaves() method
simply initializes every tree in the forest to a single node carrying
the same (user-specified) leaf value. Second, the prepare_for_sampler()
method initializes every tree in the forest to a single node with the
same value and also propagates this information through to a ForestModel
object, which must be synchronized with a Forest during a forest
sampler loop.
Forest$is_empty()
Returns: TRUE if the Forest has not yet been initialized with a constant
root value, FALSE if it has already been initialized or grown.