Nothing
#' BART Posterior Inclusion Probabilities
#'
#' Computes the posterior inclusion probabilities (PIPs) for the fitted
#' SoftBART model, as well as variable importances and the median probability
#' model (MPM).
#'
#' @param fit An object of class \code{softbart}, \code{softbart_regression}, or
#' \code{softbart_probit}.
#'
#' @return A list containing the following:
#' \itemize{
#' \item \code{varimp}: a vector containing the average number of times a predictor
#' was used in a splitting rule.
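#' \item \code{post_probs}: the posterior inclusion probabilities for each predictor,
#' i.e., the proportion of posterior samples in which the predictor is used in at
#' least one splitting rule.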
#' \item \code{median_probability_model}: a vector containing the indices of the
#' variables included in at least 50 percent of the posterior samples.
#' }
#' @export
#'
#' @examples
#' ## NOTE: SET NUMBER OF BURN IN AND SAMPLE ITERATIONS HIGHER IN PRACTICE
#'
#' num_burn <- 10 ## Should be ~ 5000
#' num_save <- 10 ## Should be ~ 5000
#'
#' set.seed(1234)
#' f_fried <- function(x) 10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
#' 10 * x[,4] + 5 * x[,5]
#'
#' gen_data <- function(n_train, n_test, P, sigma) {
#' X <- matrix(runif(n_train * P), nrow = n_train)
#' mu <- f_fried(X)
#' X_test <- matrix(runif(n_test * P), nrow = n_test)
#' mu_test <- f_fried(X_test)
#' Y <- mu + sigma * rnorm(n_train)
#' Y_test <- mu_test + sigma * rnorm(n_test)
#'
#' return(list(X = X, Y = Y, mu = mu, X_test = X_test, Y_test = Y_test, mu_test = mu_test))
#' }
#'
#' ## Simulate dataset
#' sim_data <- gen_data(250, 100, 1000, 1)
#'
#' ## Fit the model
#' fit <- softbart(X = sim_data$X, Y = sim_data$Y, X_test = sim_data$X_test,
#' hypers = Hypers(sim_data$X, sim_data$Y, num_tree = 50, temperature = 1),
#' opts = Opts(num_burn = num_burn, num_save = num_save, update_tau = TRUE))
#'
#' ## Variable selection
#'
#' post_probs <- posterior_probs(fit)
#' plot(post_probs$post_probs)
#' print(post_probs$median_probability_model)
#'
#'
posterior_probs <- function(fit) {
  # `[[` does exact name matching; `$` partially matches on lists, so a missing
  # `var_counts` component could otherwise resolve to a differently named one.
  var_counts <- fit[["var_counts"]]
  if (is.null(var_counts)) {
    stop("`fit` has no `var_counts` component; expected a fitted SoftBART ",
         "model.", call. = FALSE)
  }

  # Column means average each predictor's split counts over the rows
  # (the saved posterior samples).
  varimp <- colMeans(var_counts)

  # Fraction of samples in which each predictor appears in at least one split.
  post_probs <- colMeans(var_counts > 0)

  # Median probability model: predictors whose inclusion probability is at
  # least 1/2, as documented ("at least 50 percent"). `>=` rather than `>`
  # keeps predictors that appear in exactly half of the samples.
  median_probability_model <- which(post_probs >= 0.5)

  out <- list(
    varimp = varimp,
    post_probs = post_probs,
    median_probability_model = median_probability_model
  )
  class(out) <- "sb_postprobs"
  out
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.