softbart: Fits the SoftBart model

View source: R/SoftBart.R

softbart R Documentation


Description

Runs the Markov chain for the semiparametric Gaussian model

Y = r(X) + \epsilon

and collects the output, where r(x) is modeled using a soft BART model.

Usage

softbart(X, Y, X_test, hypers = NULL, opts = Opts(), verbose = TRUE)

Arguments

X

A matrix of training data covariates.

Y

A vector of training data responses.

X_test

A matrix of test data covariates.

hypers

A list of hyperparameter values obtained from the Hypers function.

opts

A list of MCMC chain settings obtained from the Opts function.

verbose

If TRUE, progress of the chain will be printed to the console.

Value

Returns a list with the following components:

  • y_hat_train: predicted values for the training data for each iteration of the chain.

  • y_hat_test: predicted values for the test data for each iteration of the chain.

  • y_hat_train_mean: predicted values for the training data, averaged over iterations.

  • y_hat_test_mean: predicted values for the test data, averaged over iterations.

  • sigma: posterior samples of the error standard deviation.

  • sigma_mu: posterior samples of sigma_mu, the standard deviation of the leaf node parameters.

  • s: posterior samples of s.

  • alpha: posterior samples of alpha.

  • beta: posterior samples of beta.

  • gamma: posterior samples of gamma.

  • k: posterior samples of k = 0.5 / (sqrt(num_tree) * sigma_mu).

  • num_leaves_final: the number of leaves for each tree at the final iteration.
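The components above can be summarized with base R. The sketch below does not call softbart; it builds a mock list (fit_mock, a hypothetical stand-in) and assumes that y_hat_test is a matrix with one row per saved iteration and one column per test observation, which this page does not state explicitly. It shows how a posterior mean, a pointwise 95% credible interval, and the draws of k might be derived from such an object.

```r
set.seed(1)
num_save <- 200   # number of saved iterations (illustrative value)
n_test   <- 5     # number of test observations (illustrative value)
num_tree <- 50    # number of trees used in the k formula

# Mock object standing in for the list returned by softbart().
# ASSUMPTION: rows index iterations, columns index test observations.
fit_mock <- list(
  y_hat_test = matrix(
    rnorm(num_save * n_test, mean = rep(seq_len(n_test), each = num_save)),
    nrow = num_save
  ),
  sigma_mu = runif(num_save, 0.1, 0.3)
)

# Posterior mean prediction per test observation
# (the same quantity y_hat_test_mean is described as holding).
post_mean <- colMeans(fit_mock$y_hat_test)

# Pointwise 95% credible interval for r(x) at each test point:
# a 2 x n_test matrix with the 2.5% and 97.5% quantiles.
ci <- apply(fit_mock$y_hat_test, 2, quantile, probs = c(0.025, 0.975))

# Draws of k implied by the draws of sigma_mu, using the formula above.
k_draws <- 0.5 / (sqrt(num_tree) * fit_mock$sigma_mu)
```

These intervals describe uncertainty in r(x) itself, not in a new response, so they are narrower than prediction intervals.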

Examples


## NOTE: SET NUMBER OF BURN IN AND SAMPLE ITERATIONS HIGHER IN PRACTICE

num_burn <- 10 ## Should be ~ 5000
num_save <- 10 ## Should be ~ 5000

set.seed(1234)
f_fried <- function(x) 10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 + 
  10 * x[,4] + 5 * x[,5]

gen_data <- function(n_train, n_test, P, sigma) {
  X <- matrix(runif(n_train * P), nrow = n_train)
  mu <- f_fried(X)
  X_test <- matrix(runif(n_test * P), nrow = n_test)
  mu_test <- f_fried(X_test)
  Y <- mu + sigma * rnorm(n_train)
  Y_test <- mu_test + sigma * rnorm(n_test)
  
  return(list(X = X, Y = Y, mu = mu, X_test = X_test, Y_test = Y_test, mu_test = mu_test))
}

## Simulate dataset
sim_data <- gen_data(250, 100, 1000, 1)

## Fit the model
fit <- softbart(X = sim_data$X, Y = sim_data$Y, X_test = sim_data$X_test, 
                hypers = Hypers(sim_data$X, sim_data$Y, num_tree = 50, temperature = 1),
                opts = Opts(num_burn = num_burn, num_save = num_save, update_tau = TRUE))

## Plot the fit (note: interval estimates are not prediction intervals, 
## so they do not cover the predictions at the nominal rate)
plot(fit)

## Look at posterior model inclusion probabilities for each predictor. 

plot(posterior_probs(fit)[["post_probs"]], 
     col = ifelse(posterior_probs(fit)[["post_probs"]] > 0.5, scales::muted("blue"), 
                  scales::muted("green")), 
     pch = 20)


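## Root mean squared error of the posterior mean predictions,
## measured against the true regression function values
rmse(fit$y_hat_test_mean, sim_data$mu_test)
rmse(fit$y_hat_train_mean, sim_data$mu)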
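The two rmse calls compare the posterior mean predictions with the true regression function values rather than with the noisy responses. As a minimal sketch, assuming rmse follows the usual root mean squared error definition (this page does not show its source), an equivalent helper could look like the following; rmse_sketch is a hypothetical name, not part of the package.

```r
# Hypothetical helper: root mean squared error between two numeric vectors.
# ASSUMPTION: mirrors the standard definition sqrt(mean((x - y)^2)).
rmse_sketch <- function(x, y) {
  stopifnot(length(x) == length(y))
  sqrt(mean((x - y)^2))
}

# Small worked example: the squared errors are 0, 0 and 4,
# so the result equals sqrt(4 / 3).
err <- rmse_sketch(c(1, 2, 3), c(1, 2, 5))
```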


SoftBart documentation built on June 8, 2025, 9:40 p.m.