# Nothing
##variable selection procedures from Bleich et al. (2013)
#' Perform Variable Selection using Three Threshold-based Procedures
#'
#' @description
#' Performs variable selection using the three thresholding methods introduced in Bleich et al. (2013).
#'
#' @details
#' See Bleich et al. (2013) for a complete description of the procedures outlined above as well as the corresponding vignette for a brief summary with examples.
#' @param bart_machine An object of class ``bartMachine''.
#' @param num_reps_for_avg Number of replicates to average over for the BART model's variable inclusion proportions.
#' @param num_permute_samples Number of permutations of the response to be made to generate the ``null'' permutation distribution.
#' @param num_trees_for_permute Number of trees to use in the variable selection procedure. As with \cr \code{\link{investigate_var_importance}}, a small number of trees should be used to force variables to compete for entry into the model. Note that this number is used to estimate both the ``true'' and ``null'' variable inclusion proportions.
#' @param alpha Cut-off level for the thresholds.
#' @param plot If TRUE, a plot showing which variables are selected by each of the procedures is generated.
#' @param num_var_plot Number of variables (in order of decreasing variable inclusion proportion) to be plotted.
#' @param bottom_margin A display parameter that adjusts the bottom margin of the graph if labels are clipped. The scale of this parameter is the same as set with \code{par(mar = c(....))} in R.
#' Higher values allow for more space if the crossed covariate names are long. Note that making this parameter too large will prevent plotting and the plot function in R will throw an error.
#' @param verbose If TRUE, prints progress messages.
#'
#' @return
#' Invisibly, returns a list with the following components:
#'
#' \item{important_vars_local_names}{Names of the variables chosen by the Local procedure.}
#' \item{important_vars_global_max_names}{Names of the variables chosen by the Global Max procedure.}
#' \item{important_vars_global_se_names}{Names of the variables chosen by the Global SE procedure.}
#' \item{important_vars_local_col_nums}{Column numbers of the variables chosen by the Local procedure.}
#' \item{important_vars_global_max_col_nums}{Column numbers of the variables chosen by the Global Max procedure.}
#' \item{important_vars_global_se_col_nums}{Column numbers of the variables chosen by the Global SE procedure.}
#' \item{var_true_props_avg}{The variable inclusion proportions for the actual data.}
#' \item{permute_mat}{The permutation distribution generated by permuting the response vector.}
#'
#' @references
#' J Bleich, A Kapelner, ST Jensen, and EI George. Variable Selection Inference for Bayesian
#' Additive Regression Trees. ArXiv e-prints, 2013.
#'
#' Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning
#' with Bayesian Additive Regression Trees. Journal of Statistical
#' Software, 70(4), 1-40. \doi{10.18637/jss.v070.i04}
#'
#' @seealso
#' \code{\link{var_selection_by_permute_cv}}, \code{\link{investigate_var_importance}}
#'
#' @author
#' Adam Kapelner and Justin Bleich
#'
#' @note
#' Although the reference only explores regression settings, this procedure is applicable to both regression and classification problems.
#' This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}.
#'
#' @examples
#' \dontrun{
#' #generate Friedman data
#' set.seed(11)
#' n = 300
#' p = 20 ##15 useless predictors
#' X = data.frame(matrix(runif(n * p), ncol = p))
#' y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
#'
#' ##build BART regression model (not actually used in variable selection)
#' bart_machine = bartMachine(X, y)
#'
#' #variable selection
#' var_sel = var_selection_by_permute(bart_machine)
#' print(var_sel$important_vars_local_names)
#' print(var_sel$important_vars_global_max_names)
#' }
#' @export
var_selection_by_permute = function(bart_machine, num_reps_for_avg = 10, num_permute_samples = 100, num_trees_for_permute = 20, alpha = 0.05, plot = TRUE, num_var_plot = Inf, bottom_margin = 10, verbose = TRUE){
  # Compares averaged variable inclusion proportions from the real data against
  # a permutation ("null") distribution and applies three cut-offs
  # (local / global max / global SE). Optionally draws two diagnostic plots.
  # Returns (invisibly) the selected names, their column numbers, the sorted
  # inclusion proportions and the (column-reordered) permutation matrix.
  assert_class(bart_machine, "bartMachine")
  assert_int(num_reps_for_avg, lower = 1)
  assert_int(num_permute_samples, lower = 1)
  assert_int(num_trees_for_permute, lower = 1)
  assert_number(alpha, lower = 0, upper = 1)
  assert_flag(plot)
  assert_number(num_var_plot, lower = 1)
  assert_number(bottom_margin, lower = 0)
  assert_flag(verbose)
  check_serialization(bart_machine) #ensure the Java object exists and fire an error if not

  feature_names = bart_machine$training_data_features_with_missing_features

  ##set up permute mat: one row per permutation, one column per feature
  permute_mat = matrix(NA, nrow = num_permute_samples, ncol = bart_machine$p)
  colnames(permute_mat) = feature_names

  if (verbose){
    cat("avg")
  }
  ##get props from actual data, then sort from high to low
  var_true_props_avg = get_averaged_true_var_props(bart_machine, num_reps_for_avg, num_trees_for_permute, verbose = verbose)
  var_true_props_avg = sort(var_true_props_avg, decreasing = TRUE)

  if (verbose){
    cat("null")
  }
  ##build null permutation distribution
  for (b in seq_len(num_permute_samples)){
    permute_mat[b, ] = get_null_permute_var_importances(bart_machine, num_trees_for_permute, verbose = verbose)
  }
  if (verbose){
    cat("\n")
  }

  #reorder the permutation columns to match the sorted proportions;
  #drop = FALSE keeps a matrix even when there is a single column or row,
  #which the column-wise summaries below require
  permute_mat = permute_mat[, names(var_true_props_avg), drop = FALSE]

  ##use local cutoff: per-variable (1 - alpha) quantile of the null draws
  pointwise_cutoffs = as.numeric(matrixStats::colQuantiles(permute_mat, probs = 1 - alpha))
  important_vars_pointwise_names = names(var_true_props_avg[var_true_props_avg > pointwise_cutoffs & var_true_props_avg > 0])
  important_vars_pointwise_col_nums = match(important_vars_pointwise_names, feature_names)

  ##use global max cutoff: (1 - alpha) quantile of the per-permutation maxima
  row_maxes = matrixStats::rowMaxs(permute_mat)
  max_cut = quantile(row_maxes, 1 - alpha)
  important_vars_simul_max_names = names(var_true_props_avg[var_true_props_avg >= max_cut & var_true_props_avg > 0])
  important_vars_simul_max_col_nums = match(important_vars_simul_max_names, feature_names)

  #use global se cutoff: mean + C * sd, with C found by bisection
  perm_mean = colMeans(permute_mat)
  perm_se = matrixStats::colSds(permute_mat)
  cover_constant = bisectK(tol = .01 , coverage = 1 - alpha, permute_mat = permute_mat, x_left = 1, x_right = 20, countLimit = 100, perm_mean = perm_mean, perm_se = perm_se)
  important_vars_simul_se_names = names(var_true_props_avg[which(var_true_props_avg >= perm_mean + cover_constant * perm_se & var_true_props_avg > 0)])
  important_vars_simul_se_col_nums = match(important_vars_simul_se_names, feature_names)

  if (plot){
    if (is.infinite(num_var_plot) || num_var_plot > bart_machine$p){
      num_var_plot = bart_machine$p
    }
    #variables that never entered the model are left out of the plots
    non_zero_idx = which(var_true_props_avg > 0)
    if (length(non_zero_idx) < length(var_true_props_avg)) {
      warning(paste(length(which(var_true_props_avg == 0)), "covariates with inclusion proportions of 0 omitted from plots."))
    }
    if (length(non_zero_idx) > 0){
      non_zero_idx = non_zero_idx[seq_len(min(num_var_plot, length(non_zero_idx)))]
      plot_df = data.frame(
        index = seq_len(length(non_zero_idx)),
        variable = names(var_true_props_avg)[non_zero_idx],
        prop = as.numeric(var_true_props_avg[non_zero_idx]),
        pointwise_cutoff = as.numeric(pointwise_cutoffs[non_zero_idx]),
        perm_mean = as.numeric(perm_mean[non_zero_idx]),
        perm_se = as.numeric(perm_se[non_zero_idx])
      )
      plot_df$pointwise_selected = plot_df$prop > plot_df$pointwise_cutoff
      plot_df$simul_threshold = plot_df$perm_mean + cover_constant * plot_df$perm_se
      #markers use the same >= comparisons as the selection rules above so a
      #variable sitting exactly on a threshold is drawn as selected
      plot_df$simul_marker = ifelse(
        plot_df$prop >= max_cut,
        "global_max",
        ifelse(plot_df$prop >= plot_df$simul_threshold, "global_se", "none")
      )
      #shared y-range so both panels are directly comparable
      y_max = max(c(plot_df$prop, plot_df$pointwise_cutoff, plot_df$simul_threshold, max_cut * 1.1), na.rm = TRUE)
      margin_cfg = ggplot2::margin(t = 5.5, r = 5.5, b = bottom_margin, l = 5.5)

      #panel 1: local procedure (green segments are the per-variable cutoffs)
      local_plot = ggplot2::ggplot(plot_df, ggplot2::aes(x = index, y = prop)) +
        ggplot2::geom_segment(
          ggplot2::aes(xend = index, y = 0, yend = pointwise_cutoff),
          color = "forestgreen"
        ) +
        ggplot2::geom_point(ggplot2::aes(shape = pointwise_selected)) +
        ggplot2::scale_shape_manual(values = c(`TRUE` = 16, `FALSE` = 1)) +
        ggplot2::scale_x_continuous(breaks = plot_df$index, labels = plot_df$variable) +
        ggplot2::scale_y_continuous(limits = c(0, y_max)) +
        ggplot2::labs(title = "Local Procedure", y = "proportion included", x = NULL) +
        ggplot2::theme_minimal() +
        ggplot2::theme(
          legend.position = "none",
          axis.text.x = ggplot2::element_text(angle = 90, hjust = 1, vjust = 0.5),
          plot.margin = margin_cfg
        )

      #panel 2: global procedures (red line = max cutoff, blue segments = SE thresholds)
      simul_plot = ggplot2::ggplot(plot_df, ggplot2::aes(x = index, y = prop)) +
        ggplot2::geom_segment(
          ggplot2::aes(xend = index, y = 0, yend = simul_threshold),
          color = "blue"
        ) +
        ggplot2::geom_hline(yintercept = max_cut, color = "red") +
        ggplot2::geom_point(ggplot2::aes(shape = simul_marker)) +
        ggplot2::scale_shape_manual(values = c(global_max = 16, global_se = 8, none = 1)) +
        ggplot2::scale_x_continuous(breaks = plot_df$index, labels = plot_df$variable) +
        ggplot2::scale_y_continuous(limits = c(0, y_max)) +
        ggplot2::labs(title = "Simul. Max and SE Procedures", y = "proportion included", x = NULL) +
        ggplot2::theme_minimal() +
        ggplot2::theme(
          legend.position = "none",
          axis.text.x = ggplot2::element_text(angle = 90, hjust = 1, vjust = 0.5),
          plot.margin = margin_cfg
        )
      plot_list = list(local_plot, simul_plot)
      # NOTE(review): rendering is gated on `verbose`, so plot = TRUE with
      # verbose = FALSE builds the panels but never draws them -- confirm
      # this is intended.
      if (verbose){
        #stack the two panels vertically on a fresh page
        grid::grid.newpage()
        grid::pushViewport(grid::viewport(layout = grid::grid.layout(2, 1)))
        for (i in seq_len(length(plot_list))){
          suppressMessages(print(
            plot_list[[i]],
            vp = grid::viewport(layout.pos.row = i, layout.pos.col = 1)
          ))
        }
        grid::popViewport()
      }
    }
  }

  #return an invisible list
  invisible(list(
    important_vars_local_names = important_vars_pointwise_names,
    important_vars_global_max_names = important_vars_simul_max_names,
    important_vars_global_se_names = important_vars_simul_se_names,
    important_vars_local_col_nums = as.numeric(important_vars_pointwise_col_nums),
    important_vars_global_max_col_nums = as.numeric(important_vars_simul_max_col_nums),
    important_vars_global_se_col_nums = as.numeric(important_vars_simul_se_col_nums),
    var_true_props_avg = var_true_props_avg,
    permute_mat = permute_mat
  ))
}
##private
get_averaged_true_var_props = function(bart_machine, num_reps_for_avg, num_trees_for_permute, verbose = TRUE){
  # Refits the model `num_reps_for_avg` times with `num_trees_for_permute`
  # trees and returns the element-wise mean of the inclusion proportions
  # reported by get_var_props_over_chain().
  props_total = numeric(bart_machine$p) # running sum, one slot per predictor
  for (rep_idx in 1 : num_reps_for_avg){
    refit = bart_machine_duplicate(bart_machine, num_trees = num_trees_for_permute)
    props_total = props_total + get_var_props_over_chain(refit)
    if (verbose){
      cat(".") # one dot per completed replicate
    }
  }
  # turn the running sum into an average
  props_total / num_reps_for_avg
}
##private
get_null_permute_var_importances = function(bart_machine, num_trees_for_permute, verbose = TRUE){
  # Produces one draw from the "null" distribution: the responses are shuffled
  # (without replacement) so they are disconnected from the predictors, a model
  # is built on the shuffled data, and its inclusion proportions are returned.
  shuffled_y = sample(bart_machine$y, replace = FALSE)

  # the remaining settings are copied from the supplied model;
  # only the tree count is overridden
  null_model = build_bart_machine(bart_machine$X, shuffled_y,
    num_trees = as.numeric(num_trees_for_permute),
    num_burn_in = bart_machine$num_burn_in,
    num_iterations_after_burn_in = bart_machine$num_iterations_after_burn_in,
    run_in_sample = FALSE,
    use_missing_data = bart_machine$use_missing_data,
    use_missing_data_dummies_as_covars = bart_machine$use_missing_data_dummies_as_covars,
    num_rand_samps_in_library = bart_machine$num_rand_samps_in_library,
    replace_missing_data_with_x_j_bar = bart_machine$replace_missing_data_with_x_j_bar,
    impute_missingness_with_rf_impute = bart_machine$impute_missingness_with_rf_impute,
    impute_missingness_with_x_j_bar_for_lm = bart_machine$impute_missingness_with_x_j_bar_for_lm,
    verbose = FALSE)

  null_props = get_var_props_over_chain(null_model)
  if (verbose){
    cat(".") # progress tick, one per permutation
  }
  null_props
}
##private - used to compute constant for global se method. simple bisection algo.
bisectK = function(tol, coverage, permute_mat, x_left, x_right, countLimit, perm_mean, perm_se){
  # Bisection search for the multiplier `guess` such that the fraction of rows
  # s of permute_mat with all(permute_mat[s, ] - perm_mean <= guess * perm_se)
  # is (approximately) `coverage`.
  #
  # tol          stop once half the bracket width drops below this
  # coverage     target proportion of rows that must be covered
  # permute_mat  matrix of null draws (rows = permutations, cols = variables)
  # x_left/right initial bracket for the multiplier
  # countLimit   maximum number of bisection steps
  # perm_mean    column means of permute_mat
  # perm_se      column standard deviations of permute_mat
  #
  # Returns the final midpoint of the bracket.
  count = 0
  guess = (x_left + x_right) / 2

  # The row condition is equivalent to
  #   max((permute_mat[s, ] - perm_mean) / perm_se) <= guess
  # so the standardized row maxima are computed once, outside the loop.
  centered = sweep(permute_mat, 2, perm_mean, "-")
  std_diffs = sweep(centered, 2, perm_se, "/")

  # A column with zero spread would give 0 / 0 = NaN and poison every row
  # maximum. For such a column the original inequality reduces to
  # (x - mean) <= 0, independent of `guess`: encode "always satisfied" as -Inf
  # and "never satisfied" as Inf so max() and the comparison behave accordingly.
  zero_se_cols = which(perm_se == 0)
  if (length(zero_se_cols) > 0){
    std_diffs[, zero_se_cols] = ifelse(centered[, zero_se_cols, drop = FALSE] <= 0, -Inf, Inf)
  }
  row_max_std_diffs = apply(std_diffs, 1, max)

  while ((x_right - x_left) / 2 >= tol && count < countLimit){
    empirical_coverage = mean(row_max_std_diffs <= guess)
    if (empirical_coverage == coverage){
      break
    } else if (empirical_coverage < coverage){
      # too few rows covered: the multiplier must grow
      x_left = guess
    } else {
      # more than enough rows covered: try a smaller multiplier
      x_right = guess
    }
    guess = (x_left + x_right) / 2
    count = count + 1
  }
  guess
}
##var selection -- choose best method via CV
#' Perform Variable Selection Using Cross-validation Procedure
#'
#' @description
#' Performs variable selection by cross-validating over the three threshold-based procedures outlined in Bleich et al. (2013) and selecting the single procedure that returns the lowest cross-validation RMSE.
#'
#' @details
#' See Bleich et al. (2013) for a complete description of the procedures outlined above as well as the corresponding vignette for a brief summary with examples.
#' @param bart_machine An object of class ``bartMachine''.
#' @param k_folds Number of folds to be used in cross-validation.
#' @param folds_vec An integer vector of indices specifying which fold each observation belongs to.
#' @param num_reps_for_avg Number of replicates to average over for the BART model's variable inclusion proportions.
#' @param num_permute_samples Number of permutations of the response to be made to generate the ``null'' permutation distribution.
#' @param num_trees_for_permute Number of trees to use in the variable selection procedure. As with \cr \code{\link{investigate_var_importance}}, a small number of trees should be used to force variables to compete for entry into the model. Note that this number is used to estimate both the ``true'' and ``null'' variable inclusion proportions.
#' @param alpha Cut-off level for the thresholds.
#' @param num_trees_pred_cv Number of trees to use for prediction on the hold-out portion of each fold. Once variables have been selected using the training portion of each fold, a new model is built using only those variables with \code{num_trees_pred_cv} trees in the sum-of-trees model. Forecasts for the holdout sample are made using this model. A larger number of trees is recommended to exploit the full forecasting power of BART.
#' @param verbose If TRUE, prints progress messages.
#'
#' @return
#' Returns a list with the following components:
#'
#' \item{best_method}{The name of the best variable selection procedure, as chosen via cross-validation.}
#' \item{important_vars_cv}{The variables chosen by the \code{best_method} above.}
#'
#' @references
#' J Bleich, A Kapelner, ST Jensen, and EI George. Variable Selection Inference for Bayesian
#' Additive Regression Trees. ArXiv e-prints, 2013.
#'
#' Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning
#' with Bayesian Additive Regression Trees. Journal of Statistical
#' Software, 70(4), 1-40. \doi{10.18637/jss.v070.i04}
#'
#' @seealso
#' \code{\link{var_selection_by_permute}}, \code{\link{investigate_var_importance}}
#'
#' @author
#' Adam Kapelner and Justin Bleich
#'
#' @note
#' This function can have substantial run-time.
#' This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}.
#'
#' @examples
#' \dontrun{
#' #generate Friedman data
#' set.seed(11)
#' n = 150
#' p = 100 ##95 useless predictors
#' X = data.frame(matrix(runif(n * p), ncol = p))
#' y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
#'
#' ##build BART regression model (not actually used in variable selection)
#' bart_machine = bartMachine(X, y)
#'
#' #variable selection via cross-validation
#' var_sel_cv = var_selection_by_permute_cv(bart_machine, k_folds = 3)
#' print(var_sel_cv$best_method)
#' print(var_sel_cv$important_vars_cv)
#' }
#' @export
var_selection_by_permute_cv = function(bart_machine, k_folds = 5, folds_vec = NULL, num_reps_for_avg = 5, num_permute_samples = 100, num_trees_for_permute = 20, alpha = 0.05, num_trees_pred_cv = 50, verbose = TRUE){
  # Cross-validates the three permutation-based selection rules: for every
  # fold, variables are selected on the training part, a reduced model is fit
  # and its out-of-sample L2 error recorded. The rule with the smallest total
  # error is then applied to the full data.
  # Returns list(best_method, important_vars_cv).
  assert_class(bart_machine, "bartMachine")
  assert_number(k_folds, lower = 2) # Inf ok
  assert_integerish(folds_vec, null.ok = TRUE)
  assert_int(num_reps_for_avg, lower = 1)
  assert_int(num_permute_samples, lower = 1)
  assert_int(num_trees_for_permute, lower = 1)
  assert_number(alpha, lower = 0, upper = 1)
  assert_int(num_trees_pred_cv, lower = 1)
  assert_flag(verbose)
  check_serialization(bart_machine) #ensure the Java object exists and fire an error if not

  #number of observations; needed by both the generated- and supplied-folds paths
  n = nrow(bart_machine$X)

  if (!is.null(folds_vec) && !inherits(folds_vec, "integer")){
    stop("folds_vec must be an a vector of integers specifying the indexes of each folds.")
  }

  #set up k folds
  if (is.null(folds_vec)){ ##if folds were not pre-set:
    #resolve leave-one-out BEFORE range-checking, otherwise Inf > n rejects it
    if (k_folds == Inf){
      k_folds = n
    }
    if (k_folds <= 1 || k_folds > n){
      stop("The number of folds must be at least 2 and less than or equal to n, use \"Inf\" for leave one out")
    }
    #random fold assignment: bin iid normal draws at their own quantiles
    temp = rnorm(n)
    folds_vec = cut(temp, breaks = quantile(temp, seq(0, 1, length.out = k_folds + 1)),
      include.lowest = TRUE, labels = FALSE)
  } else if (length(folds_vec) != n){
    stop("folds_vec must have one entry per training observation.")
  }
  #iterate over the labels actually present, so supplied folds need not be
  #numbered 1..K; when folds are supplied, k_folds is derived from them
  fold_ids = sort(unique(folds_vec))
  k_folds = length(fold_ids)

  #design matrix without its last column (the response); drop = FALSE keeps it
  #a matrix even with a single predictor
  model_matrix = bart_machine$model_matrix_training_data
  X_all = model_matrix[, -ncol(model_matrix), drop = FALSE]

  L2_err_mat = matrix(NA, nrow = k_folds, ncol = 3)
  colnames(L2_err_mat) = c("important_vars_local_names", "important_vars_global_max_names", "important_vars_global_se_names")

  for (k in seq_len(k_folds)){
    if (verbose){
      cat("cv #", k, "\n", sep = "")
    }
    #find out the indices of the holdout sample
    train_idx = which(folds_vec != fold_ids[k])
    test_idx = setdiff(seq_len(n), train_idx)

    #pull out training data
    training_X_k = X_all[train_idx, , drop = FALSE]
    training_y_k = bart_machine$y[train_idx]

    #make a temporary bart machine just so we can run the var selection for all three methods
    bart_machine_temp = bart_machine_duplicate(bart_machine, X = as.data.frame(training_X_k), y = training_y_k, run_in_sample = FALSE, verbose = FALSE)

    ##do variable selection
    bart_variables_select_obj_k = var_selection_by_permute(bart_machine_temp,
      num_permute_samples = num_permute_samples,
      num_trees_for_permute = num_trees_for_permute,
      num_reps_for_avg = num_reps_for_avg,
      alpha = alpha,
      plot = FALSE,
      verbose = verbose)

    #pull out test data (drop = FALSE: a leave-one-out fold has a single row)
    test_X_k = X_all[test_idx, , drop = FALSE]
    test_y_k = bart_machine$y[test_idx]

    if (verbose){
      cat("method")
    }
    for (method in colnames(L2_err_mat)){
      if (verbose){
        cat(".")
      }
      #pull out the appropriate vars
      vars_selected_by_method = bart_variables_select_obj_k[[method]]

      if (length(vars_selected_by_method) == 0){
        #nothing selected: predict ybar and take L2 error against it
        ybar_est = mean(training_y_k)
        L2_err_mat[k, method] = sum((test_y_k - ybar_est)^2)
      } else {
        #now build the bart machine based on reduced model
        training_X_k_red_by_vars_picked_by_method = data.frame(training_X_k[, vars_selected_by_method, drop = FALSE])
        colnames(training_X_k_red_by_vars_picked_by_method) = vars_selected_by_method #data.frame() may mangle names

        ##need to account for cov_prior_vec update
        bart_machine_temp = bart_machine_duplicate(bart_machine, X = training_X_k_red_by_vars_picked_by_method, y = training_y_k,
          num_trees = num_trees_pred_cv,
          run_in_sample = FALSE,
          cov_prior_vec = rep(1, times = ncol(training_X_k_red_by_vars_picked_by_method)), ##do not want old vec -- standard here
          verbose = FALSE)

        #and calculate oos-L2
        test_X_k_red_by_vars_picked_by_method = data.frame(test_X_k[, vars_selected_by_method, drop = FALSE])
        colnames(test_X_k_red_by_vars_picked_by_method) = vars_selected_by_method
        predict_obj = bart_predict_for_test_data(
          bart_machine_temp,
          test_X_k_red_by_vars_picked_by_method,
          test_y_k,
          verbose = verbose
        )
        #now record it
        L2_err_mat[k, method] = predict_obj$L2_err
      }
    }
    if (verbose){
      cat("\n")
    }
  }

  #now extract the lowest oos-L2 to find the "best" method for variable selection
  #(first one wins on ties)
  L2_err_by_method = colSums(L2_err_mat)
  min_var_selection_method = colnames(L2_err_mat)[which(L2_err_by_method == min(L2_err_by_method))]
  min_var_selection_method = min_var_selection_method[1]

  #now (finally) do var selection on the entire data and then return the vars from the best method found via cross-validation
  if (verbose){
    cat("final", "\n")
  }
  bart_variables_select_obj = var_selection_by_permute(bart_machine,
    num_permute_samples = num_permute_samples,
    num_trees_for_permute = num_trees_for_permute,
    num_reps_for_avg = num_reps_for_avg,
    alpha = alpha,
    plot = FALSE,
    verbose = verbose)

  #return vars from best method and method name
  list(best_method = min_var_selection_method, important_vars_cv = sort(bart_variables_select_obj[[min_var_selection_method]]))
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.