View source: R/kernel.R

computeForestMaxLeafIndex    R Documentation

Compute and return the largest possible leaf index computable by computeForestLeafIndices for the forests in a designated forest sample container.

Description

Compute and return the largest possible leaf index computable by computeForestLeafIndices for the forests in a designated forest sample container.

Usage

computeForestMaxLeafIndex(
  model_object,
  covariates,
  forest_type = NULL,
  forest_inds = NULL
)

Arguments

model_object

Object of type bartmodel or bcfmodel, corresponding to a BART / BCF model with at least one forest sample, or a low-level ForestSamples object.

covariates

Covariates to use for prediction. Must have the same dimensions / column types as the data used to train a forest.

forest_type

Which forest to use from model_object. Valid inputs depend on the model type, and whether or not a given forest was sampled in that model.

1. BART

  • 'mean': Computes max leaf indices for the mean forest

  • 'variance': Computes max leaf indices for the variance forest

2. BCF

  • 'prognostic': Computes max leaf indices for the prognostic forest

  • 'treatment': Computes max leaf indices for the treatment effect forest

  • 'variance': Computes max leaf indices for the variance forest

3. ForestSamples

  • NULL: It is not necessary to disambiguate when this function is called directly on a ForestSamples object. This is the default value of this argument.

forest_inds

(Optional) Indices of the forest sample(s) for which to compute max leaf indices. If not provided, this function returns the max leaf index for every sample of the forest. Indices are 0-based, so the first forest sample corresponds to forest_inds = 0, the second to forest_inds = 1, and so on.
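Because R itself counts from 1, it is easy to be off by one when selecting forest samples. The sketch below shows one way to convert 1-based sample positions into the 0-based values that forest_inds expects. The helper to_forest_inds is hypothetical and is not part of stochtree; it only illustrates the index shift.

```r
# Hypothetical helper (not part of stochtree): convert 1-based sample
# positions, as R users normally count them, into the 0-based values
# expected by the forest_inds argument.
to_forest_inds <- function(sample_positions) {
  # Positions below 1 have no 0-based counterpart
  stopifnot(all(sample_positions >= 1))
  as.integer(sample_positions) - 1L
}

# The 1st, 2nd, and 10th retained forest samples
forest_inds <- to_forest_inds(c(1, 2, 10))
print(forest_inds)
```

The resulting vector can then be passed as the forest_inds argument, provided the model retains at least ten forest samples.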

Value

Vector containing the largest possible leaf index computable by computeForestLeafIndices for the forests in a designated forest sample container.

Examples

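library(stochtree)
# Simulate 100 observations of 10 covariates with a step-function outcome
X <- matrix(runif(10*100), ncol = 10)
y <- -5 + 10*(X[,1] > 0.5) + rnorm(100)
# Fit a BART model that retains 10 MCMC forest samples
bart_model <- bart(X, y, num_gfr=0, num_mcmc=10)
# Max leaf index for every retained sample of the mean forest
computeForestMaxLeafIndex(bart_model, X, "mean")
# Only the first sample (forest_inds is 0-indexed)
computeForestMaxLeafIndex(bart_model, X, "mean", 0)
# Samples at 0-based positions 1, 3, and 9
computeForestMaxLeafIndex(bart_model, X, "mean", c(1,3,9))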

stochtree documentation built on April 4, 2025, 2:11 a.m.