MakeForest: Create an Rcpp_Forest Object

MakeForest {SoftBart}    R Documentation

Create an Rcpp_Forest Object

Description

Make an object of type Rcpp_Forest, which can be used to embed a soft BART model into other models. Some examples are given in the package vignette.

Usage

MakeForest(hypers, opts, warn = TRUE)

Arguments

hypers

A list of hyperparameter values obtained from the Hypers() function.

opts

A list of MCMC chain settings obtained from the Opts() function.

warn

If TRUE (the default), reminds the user to normalize the design matrix when interacting with the forest object.
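The warn reminder concerns normalizing the design matrix before passing it to the forest. A minimal sketch of one way to do this, assuming the forest expects every predictor to lie in [0, 1] (the simulated X in the Examples section is drawn with runif for this reason): map each column through the empirical CDF of the training data and apply the same training-set ECDFs to the test matrix. normalize_ecdf is a hypothetical helper, not a SoftBart function, and choosing ECDF (quantile) scaling over min-max scaling is likewise an assumption.

```r
## Hypothetical helper (not part of SoftBart): rescale each column to [0, 1]
## using the empirical CDF of the training data, and apply the SAME
## training-set ECDFs to the test matrix so both share one scale.
normalize_ecdf <- function(X_train, X_test = X_train) {
  X_train <- as.matrix(X_train)
  X_test <- as.matrix(X_test)
  stopifnot(ncol(X_train) == ncol(X_test))
  train_out <- X_train
  test_out <- X_test
  for (j in seq_len(ncol(X_train))) {
    F_j <- stats::ecdf(X_train[, j])   # step function with values in [0, 1]
    train_out[, j] <- F_j(X_train[, j])
    test_out[, j] <- F_j(X_test[, j])
  }
  list(X_train = train_out, X_test = test_out)
}

## Usage: normalize raw predictors on an arbitrary scale
set.seed(1)
X_raw <- matrix(rnorm(50 * 3, sd = 5), nrow = 50, ncol = 3)
X_new <- matrix(rnorm(20 * 3, sd = 5), nrow = 20, ncol = 3)
scaled <- normalize_ecdf(X_raw, X_new)
range(scaled$X_train)
```

Because the ECDF is monotone, the ordering of values within each column is preserved, which is what tree splitting rules depend on.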

Value

Returns an object of type Rcpp_Forest. If forest is an Rcpp_Forest object, it has the following methods and fields.

  • forest$do_gibbs(X, Y, X_test, i) runs i iterations of the Bayesian backfitting algorithm and predicts on the test set X_test. The state of forest is also updated.

  • forest$do_gibbs_weighted(X, Y, weights, X_test, i) runs i iterations of the Bayesian backfitting algorithm and predicts on the test set X_test, assuming that Y is heteroskedastic with known weights. The state of forest is also updated.

  • forest$do_predict(X) returns the predictions from a matrix X of predictors.

  • forest$get_counts() returns the number of times each variable has been used in a splitting rule at the current state of forest.

  • forest$get_s() returns the splitting probabilities of the forest.

  • forest$get_sigma() returns the error standard deviation of the forest.

  • forest$get_sigma_mu() returns the standard deviation of the leaf node parameters.

  • forest$get_tree_counts() returns a matrix with a row for each group of predictors and a column for each tree that counts the number of times each group of predictors is used in each tree at the current state of forest.

  • forest$predict_iteration(X, i) returns the predictions from a matrix X of predictors at iteration i. Requires that the forest was created by MakeForest(hypers, opts) with opts$cache_trees = TRUE.

  • forest$set_s(s) sets the splitting probabilities of the forest to s.

  • forest$set_sigma(x) sets the error standard deviation of the forest to x.

  • forest$num_gibbs returns the total number of iterations for which the Gibbs sampler has been run.
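These methods are meant to be called step by step inside a user-written sampler, since each call to do_gibbs updates the state of forest. A minimal sketch, assuming the SoftBart package is installed: run a burn-in, then call do_gibbs for one iteration at a time and record the forest state with get_sigma() and get_counts() after each step. The burn-in length, number of saved draws, and simulated data are arbitrary illustrative choices, and the code does not rely on the exact shape of the value returned by do_gibbs.

```r
library(SoftBart)

set.seed(1)
## Simulated data: 5 predictors already in [0, 1]; only the first matters
n <- 200
P <- 5
X <- matrix(runif(n * P), nrow = n, ncol = P)
Y <- 10 * X[, 1] + rnorm(n)

forest <- MakeForest(Hypers(X, Y), Opts(), warn = FALSE)

## Burn-in: 100 iterations whose predictions are discarded
invisible(forest$do_gibbs(X, Y, X, 100))

## Sampling: one Gibbs iteration at a time, recording the current state
num_save <- 50
sigma_draws <- numeric(num_save)
counts <- matrix(0, nrow = num_save, ncol = P)
for (i in seq_len(num_save)) {
  forest$do_gibbs(X, Y, X, 1)
  sigma_draws[i] <- forest$get_sigma()
  ## Assumes the default grouping, i.e. one count per column of X
  counts[i, ] <- as.numeric(forest$get_counts())
}

## Posterior summaries
mean(sigma_draws)   # estimate of the error standard deviation
colMeans(counts)    # average number of splits on each predictor
forest$num_gibbs    # total iterations run so far
```

In this simulation the first column should typically attract most of the splits, since it is the only predictor that affects Y.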

Examples


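library(SoftBart)
## Simulate 100 observations on 10 predictors in [0, 1]
X <- matrix(runif(100 * 10), nrow = 100, ncol = 10)
Y <- rowSums(X) + rnorm(100)
## Create the forest and run 200 iterations, predicting on the training set
my_forest <- MakeForest(Hypers(X, Y), Opts())
mu_hat <- my_forest$do_gibbs(X, Y, X, 200)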


SoftBart documentation built on June 8, 2025, 9:40 p.m.