posterior_probs: BART Posterior Inclusion Probabilities

View source: R/posterior_probs.R

posterior_probs    R Documentation

BART Posterior Inclusion Probabilities

Description

Computes the posterior inclusion probability (PIP) of each predictor in a fitted SoftBart model, along with variable importances and the median probability model (MPM).

Usage

posterior_probs(fit)

Arguments

fit

An object of class softbart, softbart_regression, or softbart_probit.

Value

A list containing the following:

  • varimp: a vector containing the average number of times each predictor was used in a splitting rule.

  • post_probs: the posterior inclusion probabilities for each predictor.

  • median_probability_model: a vector containing the indices of the variables included in at least 50 percent of the samples.
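
The three components can be illustrated with a small base-R sketch. It assumes a hypothetical matrix var_counts with one row per saved sample and one column per predictor, holding the number of splitting rules that used each predictor. The matrix and its values are made up for illustration, and the computation is one plausible reading of the definitions above, not necessarily how posterior_probs() is implemented.

```r
## Hypothetical split counts: 4 saved samples (rows) by 4 predictors (columns).
## These numbers are invented for illustration only.
var_counts <- rbind(
  c(3, 0, 1, 0),
  c(2, 0, 0, 0),
  c(4, 1, 0, 0),
  c(1, 0, 2, 0)
)

## Average number of splitting rules per sample that use each predictor
varimp <- colMeans(var_counts)

## Fraction of samples in which each predictor appears at least once
post_probs <- colMeans(var_counts > 0)

## Indices of predictors present in at least 50 percent of the samples
median_probability_model <- which(post_probs >= 0.5)

print(varimp)
print(post_probs)
print(median_probability_model)
```

Under this reading, a predictor can have a high PIP with a low importance if it is used rarely but consistently, which is why the two quantities are reported separately.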

Examples

library(SoftBart)

## NOTE: SET THE NUMBER OF BURN-IN AND SAVED ITERATIONS HIGHER IN PRACTICE

num_burn <- 10 ## Should be ~ 5000
num_save <- 10 ## Should be ~ 5000

set.seed(1234)
## Friedman test function: only the first five columns of x affect the mean
f_fried <- function(x) 10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
  10 * x[,4] + 5 * x[,5]

gen_data <- function(n_train, n_test, P, sigma) {
  X <- matrix(runif(n_train * P), nrow = n_train)
  mu <- f_fried(X)
  X_test <- matrix(runif(n_test * P), nrow = n_test)
  mu_test <- f_fried(X_test)
  Y <- mu + sigma * rnorm(n_train)
  Y_test <- mu_test + sigma * rnorm(n_test)
  
  return(list(X = X, Y = Y, mu = mu, X_test = X_test, Y_test = Y_test, mu_test = mu_test))
}

## Simulate a dataset: 250 training rows, 100 test rows, 1000 predictors, noise SD 1
sim_data <- gen_data(250, 100, 1000, 1)

## Fit the model
fit <- softbart(X = sim_data$X, Y = sim_data$Y, X_test = sim_data$X_test, 
                hypers = Hypers(sim_data$X, sim_data$Y, num_tree = 50, temperature = 1),
                opts = Opts(num_burn = num_burn, num_save = num_save, update_tau = TRUE))
                
## Variable selection

post_probs <- posterior_probs(fit)
plot(post_probs$post_probs)
print(post_probs$median_probability_model)
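
Because f_fried depends only on the first five columns of X, the selected variables can be checked against that known set. The sketch below uses a hypothetical vector selected standing in for post_probs$median_probability_model, so it runs without a fitted model. The values are invented and are not output from the example above.

```r
## Hypothetical selection, standing in for post_probs$median_probability_model
selected <- c(1L, 2L, 4L, 5L, 17L)

## Predictors that actually enter f_fried
truth <- 1:5

## Relevant predictors that were selected
true_pos <- intersect(selected, truth)

## Irrelevant predictors that were selected
false_pos <- setdiff(selected, truth)

## Relevant predictors that were missed
missed <- setdiff(truth, selected)

print(true_pos)
print(false_pos)
print(missed)
```

With only 10 burn-in and 10 saved iterations, the real selection is unlikely to be reliable, so a comparison like this is most informative after raising num_burn and num_save.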



SoftBart documentation built on June 8, 2025, 9:40 p.m.