R/posterior_probs.R

Defines functions posterior_probs

Documented in posterior_probs

#' BART Posterior Inclusion Probabilities
#'
#' Computes the posterior inclusion probabilities (PIPs) for the fitted 
#' SoftBART model, as well as variable importances and the median probability
#' model (MPM).
#'
#' @param fit An object of class \code{softbart}, \code{softbart_regression}, or
#'   \code{softbart_probit}. Only its \code{var_counts} component is used.
#'
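#' @return A list of class \code{sb_postprobs} containing the following:
#' \itemize{
#'   \item \code{varimp}: a vector giving, for each predictor, the number of
#'     splitting rules that use it, averaged over the saved posterior samples.
#'   \item \code{post_probs}: the posterior inclusion probability of each
#'     predictor, i.e., the fraction of saved samples in which it is used in at
#'     least one splitting rule.
#'   \item \code{median_probability_model}: a vector containing the indices of
#'     the variables included in more than 50 percent of the samples (posterior
#'     inclusion probability greater than 0.5).
#' }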
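#' @details
#' Writing \eqn{c_{sj}} for entry \eqn{(s, j)} of \code{fit$var_counts}, the
#' number of splitting rules that use predictor \eqn{j} in saved iteration
#' \eqn{s = 1, \ldots, S}, the components of the returned list are computed as
#' \deqn{\mathrm{varimp}_j = \frac{1}{S} \sum_{s=1}^{S} c_{sj}}{varimp_j = (1/S) sum_s c_sj}
#' \deqn{p_j = \frac{1}{S} \sum_{s=1}^{S} \mathbf{1}(c_{sj} > 0)}{p_j = (1/S) sum_s 1(c_sj > 0)}
#' and the median probability model consists of the predictors with
#' \eqn{p_j > 0.5}.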
#' @export
#'
#' @examples
#' ## NOTE: SET NUMBER OF BURN IN AND SAMPLE ITERATIONS HIGHER IN PRACTICE
#' 
#' num_burn <- 10 ## Should be ~ 5000
#' num_save <- 10 ## Should be ~ 5000
#' 
#' set.seed(1234)
#' f_fried <- function(x) 10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 + 
#'   10 * x[,4] + 5 * x[,5]
#' 
#' gen_data <- function(n_train, n_test, P, sigma) {
#'   X <- matrix(runif(n_train * P), nrow = n_train)
#'   mu <- f_fried(X)
#'   X_test <- matrix(runif(n_test * P), nrow = n_test)
#'   mu_test <- f_fried(X_test)
#'   Y <- mu + sigma * rnorm(n_train)
#'   Y_test <- mu_test + sigma * rnorm(n_test)
#'   
#'   return(list(X = X, Y = Y, mu = mu, X_test = X_test, Y_test = Y_test, mu_test = mu_test))
#' }
#' 
#' ## Simulate a dataset with 1000 predictors, only the first 5 of which affect the response
#' sim_data <- gen_data(250, 100, 1000, 1)
#' 
#' ## Fit the model
#' fit <- softbart(X = sim_data$X, Y = sim_data$Y, X_test = sim_data$X_test, 
#'                 hypers = Hypers(sim_data$X, sim_data$Y, num_tree = 50, temperature = 1),
#'                 opts = Opts(num_burn = num_burn, num_save = num_save, update_tau = TRUE))
#'                 
#' ## Variable selection
#' 
#' post_probs <- posterior_probs(fit)
#' plot(post_probs$post_probs)
#' print(post_probs$median_probability_model)
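#' 
#' ## A hand-checkable illustration on a small made-up count matrix (not
#' ## SoftBART output). Rows stand for saved iterations and columns for
#' ## predictors; each entry counts the splitting rules using that predictor.
#' ## A plain list with a var_counts element stands in for a fitted model,
#' ## since posterior_probs() only reads fit$var_counts.
#' toy_counts <- rbind(c(2, 0, 1),
#'                     c(0, 0, 3),
#'                     c(1, 0, 0),
#'                     c(4, 0, 2))
#' toy_pp <- posterior_probs(list(var_counts = toy_counts))
#' toy_pp$varimp                    ## 1.75 0.00 1.50
#' toy_pp$post_probs                ## 0.75 0.00 0.75
#' toy_pp$median_probability_model  ## 1 3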
#' 
#'
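posterior_probs <- function(fit) {
  # fit$var_counts has one row per saved iteration and one column per
  # predictor; each entry counts the splitting rules using that predictor
  varimp                   <- colMeans(fit$var_counts)
  # Fraction of saved iterations in which each predictor is used at least once
  post_probs               <- colMeans(fit$var_counts > 0)
  # Median probability model: predictors with inclusion probability above 0.5
  median_probability_model <- which(post_probs > 0.5)
  
  out <- list(varimp = varimp, post_probs = post_probs, 
              median_probability_model = median_probability_model)
  class(out) <- "sb_postprobs"
  return(out)
}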

SoftBart documentation built on June 8, 2025, 9:40 p.m.