R/bart_package_variable_selection.R

Defines functions var_selection_by_permute_cv bisectK get_null_permute_var_importances get_averaged_true_var_props var_selection_by_permute

Documented in var_selection_by_permute var_selection_by_permute_cv

##variable selection procedures from Bleich et al. (2013)
#' Perform Variable Selection using Three Threshold-based Procedures
#'
#' @description
#' Performs variable selection using the three thresholding methods introduced in Bleich et al. (2013).
#'
#' @details
#' See Bleich et al. (2013) for a complete description of the procedures outlined above as well as the corresponding vignette for a brief summary with examples.
#' @param bart_machine An object of class ``bartMachine''.
#' @param num_reps_for_avg Number of replicates to average over when estimating the BART model's variable inclusion proportions.
#' @param num_permute_samples Number of permutations of the response to be made to generate the ``null'' permutation distribution.
#' @param num_trees_for_permute Number of trees to use in the variable selection procedure. As with \cr \code{\link{investigate_var_importance}}, a small number of trees should be used to force variables to compete for entry into the model. Note that this number is used to estimate both the ``true'' and ``null'' variable inclusion proportions.
#' @param alpha Cut-off level for the thresholds.
#' @param plot If TRUE, a plot showing which variables are selected by each of the procedures is generated.
#' @param num_var_plot Number of variables (in order of decreasing variable inclusion proportion) to be plotted.
#' @param bottom_margin A display parameter that adjusts the bottom margin of the graph if labels are clipped. The scale of this parameter is the same as set with \code{par(mar = c(....))} in R.
#'   Higher values allow for more space if the crossed covariate names are long. Note that making this parameter too large will prevent plotting and the plot function in R will throw an error.
#' @param verbose If TRUE, prints progress messages.
#'
#' @return
#' Invisibly, returns a list with the following components:
#' 
#'   \item{important_vars_local_names}{Names of the variables chosen by the Local procedure.}
#'   \item{important_vars_global_max_names}{Names of the variables chosen by the Global Max procedure.}
#'     \item{important_vars_global_se_names}{Names of the variables chosen by the Global SE procedure.}
#'     \item{important_vars_local_col_nums}{Column numbers of the variables chosen by the Local procedure.}
#'   \item{important_vars_global_max_col_nums}{Column numbers of the variables chosen by the Global Max procedure.}
#'     \item{important_vars_global_se_col_nums}{Column numbers of the variables chosen by the Global SE procedure.}
#'     \item{var_true_props_avg}{The variable inclusion proportions for the actual data.}
#'   \item{permute_mat}{The permutation distribution generated by permuting the response vector.}
#'
#' @references
#' J Bleich, A Kapelner, ST Jensen, and EI George. Variable Selection Inference for Bayesian
#' Additive Regression Trees. ArXiv e-prints, 2013.
#' 
#' Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning
#' with Bayesian Additive Regression Trees. Journal of Statistical
#' Software, 70(4), 1-40. \doi{10.18637/jss.v070.i04}
#'
#' @seealso
#' \code{\link{var_selection_by_permute_cv}}, \code{\link{investigate_var_importance}}
#'
#' @author
#' Adam Kapelner and Justin Bleich
#'
#' @note
#' Although the reference only explores regression settings, this procedure is applicable to both regression and classification problems.
#' This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}.
#'
#' @examples
#' \dontrun{
#' #generate Friedman data
#' set.seed(11)
#' n  = 300
#' p = 20 ##15 useless predictors
#' X = data.frame(matrix(runif(n * p), ncol = p))
#' y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
#' 
#' ##build BART regression model (not actually used in variable selection)
#' bart_machine = bartMachine(X, y)
#' 
#' #variable selection
#' var_sel = var_selection_by_permute(bart_machine)
#' print(var_sel$important_vars_local_names)
#' print(var_sel$important_vars_global_max_names)
#' }
#' @export
var_selection_by_permute = function(bart_machine, num_reps_for_avg = 10, num_permute_samples = 100, num_trees_for_permute = 20, alpha = 0.05, plot = TRUE, num_var_plot = Inf, bottom_margin = 10, verbose = TRUE){
	#validate every argument before any model is built
	assert_class(bart_machine, "bartMachine")
	assert_int(num_reps_for_avg, lower = 1)
	assert_int(num_permute_samples, lower = 1)
	assert_int(num_trees_for_permute, lower = 1)
	assert_number(alpha, lower = 0, upper = 1)
	assert_flag(plot)
	assert_number(num_var_plot, lower = 1)
	assert_number(bottom_margin, lower = 0)
	assert_flag(verbose)

	check_serialization(bart_machine) #ensure the Java object exists and fire an error if not

	feature_names = bart_machine$training_data_features_with_missing_features

	#null distribution: one row per permutation of the response, one column per feature
	permute_mat = matrix(NA, nrow = num_permute_samples, ncol = bart_machine$p)
	colnames(permute_mat) = feature_names

	if (verbose){
		cat("avg")
	}
	#inclusion proportions on the actual (unpermuted) data, averaged over replicates
	var_true_props_avg = get_averaged_true_var_props(bart_machine, num_reps_for_avg, num_trees_for_permute, verbose = verbose)

	#now sort from high to low
	var_true_props_avg = sort(var_true_props_avg, decreasing = TRUE)

	if (verbose){
		cat("null")
	}
	for (b in seq_len(num_permute_samples)){
		permute_mat[b, ] = get_null_permute_var_importances(bart_machine, num_trees_for_permute, verbose = verbose)
	}
	if (verbose){
		cat("\n")
	}

	#reorder the columns to match the sorted proportions; drop = FALSE keeps this a
	#matrix when there is a single feature or a single permutation, which the
	#column-wise matrixStats calls below require
	permute_mat = permute_mat[, names(var_true_props_avg), drop = FALSE]

	##use local cutoff: each variable is compared with the 1 - alpha quantile of its own null column
	pointwise_cutoffs = as.numeric(matrixStats::colQuantiles(permute_mat, probs = 1 - alpha))
	important_vars_pointwise_names = names(var_true_props_avg[var_true_props_avg > pointwise_cutoffs & var_true_props_avg > 0])
	important_vars_pointwise_col_nums = match(important_vars_pointwise_names, feature_names)

	##use global max cutoff: one threshold, the 1 - alpha quantile of the per-permutation maxima
	row_maxes = matrixStats::rowMaxs(permute_mat)
	max_cut = quantile(row_maxes, 1 - alpha)
	important_vars_simul_max_names = names(var_true_props_avg[var_true_props_avg >= max_cut & var_true_props_avg > 0])
	important_vars_simul_max_col_nums = match(important_vars_simul_max_names, feature_names)

	#use global se cutoff: per-variable threshold mean + c * sd, with c found by bisection
	perm_mean = colMeans(permute_mat)
	perm_se = matrixStats::colSds(permute_mat)
	cover_constant = bisectK(tol = .01 , coverage = 1 - alpha, permute_mat = permute_mat, x_left = 1, x_right = 20, countLimit = 100, perm_mean = perm_mean, perm_se = perm_se)
	important_vars_simul_se_names = names(var_true_props_avg[which(var_true_props_avg >= perm_mean + cover_constant * perm_se & var_true_props_avg > 0)])
	important_vars_simul_se_col_nums = match(important_vars_simul_se_names, feature_names)

	if (plot){
		if (is.infinite(num_var_plot) || num_var_plot > bart_machine$p){
			num_var_plot = bart_machine$p
		}
		non_zero_idx = which(var_true_props_avg > 0)
		if (length(non_zero_idx) < length(var_true_props_avg)) {
			warning(paste(sum(var_true_props_avg == 0), "covariates with inclusion proportions of 0 omitted from plots."))
		}
		if (length(non_zero_idx) > 0){
			#proportions are sorted decreasingly, so this keeps the top num_var_plot non-zero variables
			non_zero_idx = non_zero_idx[seq_len(min(num_var_plot, length(non_zero_idx)))]
			plot_df = data.frame(
				index = seq_along(non_zero_idx),
				variable = names(var_true_props_avg)[non_zero_idx],
				prop = as.numeric(var_true_props_avg[non_zero_idx]),
				pointwise_cutoff = as.numeric(pointwise_cutoffs[non_zero_idx]),
				perm_mean = as.numeric(perm_mean[non_zero_idx]),
				perm_se = as.numeric(perm_se[non_zero_idx])
			)
			plot_df$pointwise_selected = plot_df$prop > plot_df$pointwise_cutoff
			plot_df$simul_threshold = plot_df$perm_mean + cover_constant * plot_df$perm_se
			#markers use the same >= comparisons as the selections above so the plot
			#never disagrees with the returned variable lists
			plot_df$simul_marker = ifelse(
				plot_df$prop >= max_cut,
				"global_max",
				ifelse(plot_df$prop >= plot_df$simul_threshold, "global_se", "none")
			)

			#shared y range so the two panels are directly comparable
			y_max = max(c(plot_df$prop, plot_df$pointwise_cutoff, plot_df$simul_threshold, max_cut * 1.1), na.rm = TRUE)
			margin_cfg = ggplot2::margin(t = 5.5, r = 5.5, b = bottom_margin, l = 5.5)

			local_plot = ggplot2::ggplot(plot_df, ggplot2::aes(x = index, y = prop)) +
				ggplot2::geom_segment(
					ggplot2::aes(xend = index, y = 0, yend = pointwise_cutoff),
					color = "forestgreen"
				) +
				ggplot2::geom_point(ggplot2::aes(shape = pointwise_selected)) +
				ggplot2::scale_shape_manual(values = c(`TRUE` = 16, `FALSE` = 1)) +
				ggplot2::scale_x_continuous(breaks = plot_df$index, labels = plot_df$variable) +
				ggplot2::scale_y_continuous(limits = c(0, y_max)) +
				ggplot2::labs(title = "Local Procedure", y = "proportion included", x = NULL) +
				ggplot2::theme_minimal() +
				ggplot2::theme(
					legend.position = "none",
					axis.text.x = ggplot2::element_text(angle = 90, hjust = 1, vjust = 0.5),
					plot.margin = margin_cfg
				)

			simul_plot = ggplot2::ggplot(plot_df, ggplot2::aes(x = index, y = prop)) +
				ggplot2::geom_segment(
					ggplot2::aes(xend = index, y = 0, yend = simul_threshold),
					color = "blue"
				) +
				ggplot2::geom_hline(yintercept = max_cut, color = "red") +
				ggplot2::geom_point(ggplot2::aes(shape = simul_marker)) +
				ggplot2::scale_shape_manual(values = c(global_max = 16, global_se = 8, none = 1)) +
				ggplot2::scale_x_continuous(breaks = plot_df$index, labels = plot_df$variable) +
				ggplot2::scale_y_continuous(limits = c(0, y_max)) +
				ggplot2::labs(title = "Simul. Max and SE Procedures", y = "proportion included", x = NULL) +
				ggplot2::theme_minimal() +
				ggplot2::theme(
					legend.position = "none",
					axis.text.x = ggplot2::element_text(angle = 90, hjust = 1, vjust = 0.5),
					plot.margin = margin_cfg
				)

			plot_list = list(local_plot, simul_plot)
			#NOTE(review): the panels are only drawn when verbose is TRUE, so
			#plot = TRUE with verbose = FALSE builds them and discards them -- confirm this is intended
			if (verbose){
				#stack the two panels vertically on a fresh grid page
				grid::grid.newpage()
				grid::pushViewport(grid::viewport(layout = grid::grid.layout(2, 1)))
				for (i in seq_along(plot_list)){
					suppressMessages(print(
						plot_list[[i]],
						vp = grid::viewport(layout.pos.row = i, layout.pos.col = 1)
					))
				}
				grid::popViewport()
			}
		}
	}

	#return an invisible list
	invisible(list(
		important_vars_local_names = important_vars_pointwise_names,
		important_vars_global_max_names = important_vars_simul_max_names,
		important_vars_global_se_names = important_vars_simul_se_names,
		important_vars_local_col_nums = as.numeric(important_vars_pointwise_col_nums),
		important_vars_global_max_col_nums = as.numeric(important_vars_simul_max_col_nums),
		important_vars_global_se_col_nums = as.numeric(important_vars_simul_se_col_nums),
		var_true_props_avg = var_true_props_avg,
		permute_mat = permute_mat
	))
}

##private
# Sums the variable inclusion proportions obtained from num_reps_for_avg duplicates
# of bart_machine (each duplicated with num_trees_for_permute trees) and returns
# their average. Prints one "." per replicate when verbose is TRUE.
get_averaged_true_var_props = function(bart_machine, num_reps_for_avg, num_trees_for_permute, verbose = TRUE){
	#inclusion proportions from one freshly duplicated model
	props_from_one_rebuild = function(){
		rebuilt = bart_machine_duplicate(bart_machine, num_trees = num_trees_for_permute)
		props = get_var_props_over_chain(rebuilt)
		if (verbose){
			cat(".")
		}
		props
	}
	#fold the replicates into a running sum, starting from a zero vector of length p
	summed_props = Reduce(
		function(running_sum, rep_num) running_sum + props_from_one_rebuild(),
		1 : num_reps_for_avg,
		rep(0, bart_machine$p)
	)
	summed_props / num_reps_for_avg
}

##private
# Builds one BART model on a shuffled copy of the response and returns its variable
# inclusion proportions, i.e. a single draw for the null permutation distribution.
# Model settings other than the number of trees are copied from bart_machine.
# Prints a "." when verbose is TRUE.
get_null_permute_var_importances = function(bart_machine, num_trees_for_permute, verbose = TRUE){
	#shuffling y without replacement breaks the link between the covariates and the response
	y_permuted = sample(bart_machine$y, replace = FALSE)

	null_model = build_bart_machine(
		bart_machine$X,
		y_permuted,
		num_trees = as.numeric(num_trees_for_permute),
		num_burn_in = bart_machine$num_burn_in,
		num_iterations_after_burn_in = bart_machine$num_iterations_after_burn_in,
		run_in_sample = FALSE,
		use_missing_data = bart_machine$use_missing_data,
		use_missing_data_dummies_as_covars = bart_machine$use_missing_data_dummies_as_covars,
		num_rand_samps_in_library = bart_machine$num_rand_samps_in_library,
		replace_missing_data_with_x_j_bar = bart_machine$replace_missing_data_with_x_j_bar,
		impute_missingness_with_rf_impute = bart_machine$impute_missingness_with_rf_impute,
		impute_missingness_with_x_j_bar_for_lm = bart_machine$impute_missingness_with_x_j_bar_for_lm,
		verbose = FALSE
	)

	#only the inclusion proportions are needed from the null model
	null_props = get_var_props_over_chain(null_model)
	if (verbose){
		cat(".")
	}
	null_props
}

##private - used to compute constant for global se method. simple bisection algo.
#
# Searches the interval [x_left, x_right] for a multiplier `guess` such that the
# share of permutation rows s satisfying
#   permute_mat[s, j] - perm_mean[j] <= guess * perm_se[j]   for every column j
# is approximately `coverage`.
#
# tol          stop once half the bracket width falls below this value
# coverage     target share of rows that must be covered
# permute_mat  numeric matrix, one row per permutation and one column per variable
# x_left       lower end of the initial bracket
# x_right      upper end of the initial bracket
# countLimit   maximum number of bisection steps
# perm_mean    per-column centres, one per column of permute_mat
# perm_se      per-column scales, one per column of permute_mat
#
# Returns the final midpoint of the bracket.
bisectK = function(tol, coverage, permute_mat, x_left, x_right, countLimit, perm_mean, perm_se){
	count = 0
	guess = (x_left + x_right) / 2

	# A row is covered when (permute_mat[s, j] - perm_mean[j]) / perm_se[j] <= guess
	# for all j, i.e. when the row-wise maximum of the standardized differences is
	# <= guess. That maximum does not depend on guess, so compute it once.
	std_diffs = sweep(sweep(permute_mat, 2, perm_mean, "-"), 2, perm_se, "/")

	# A column with zero spread that sits exactly on its mean yields 0 / 0 = NaN.
	# In the unscaled inequality this is 0 <= guess * 0, which always holds, so
	# such entries must never decide the row maximum. Without this the NaN would
	# propagate into the coverage and the comparisons below would fail.
	std_diffs[is.nan(std_diffs)] = -Inf
	row_max_std_diffs = apply(std_diffs, 1, max)

	while ((x_right - x_left) / 2 >= tol && count < countLimit){
		empirical_coverage = mean(row_max_std_diffs <= guess)
		if (abs(empirical_coverage - coverage) < 1e-12){
			#target hit (up to floating point rounding)
			break
		} else if (empirical_coverage < coverage){
			#too few rows covered: the multiplier must grow
			x_left = guess
		} else {
			#too many rows covered: the multiplier can shrink
			x_right = guess
		}
		guess = (x_left + x_right) / 2
		count = count + 1
	}
	guess
}

##var selection -- choose best method via CV
#' Perform Variable Selection Using Cross-validation Procedure
#'
#' @description
#' Performs variable selection by cross-validating over the three threshold-based procedures outlined in Bleich et al. (2013) and selecting the single procedure that returns the lowest cross-validation RMSE.
#'
#' @details
#' See Bleich et al. (2013) for a complete description of the procedures outlined above as well as the corresponding vignette for a brief summary with examples.
#' @param bart_machine An object of class ``bartMachine''.
#' @param k_folds Number of folds to be used in cross-validation.
#' @param folds_vec An integer vector of indices specifying which fold each observation belongs to.
#' @param num_reps_for_avg Number of replicates to average over when estimating the BART model's variable inclusion proportions.
#' @param num_permute_samples Number of permutations of the response to be made to generate the ``null'' permutation distribution.
#' @param num_trees_for_permute Number of trees to use in the variable selection procedure. As with \cr \code{\link{investigate_var_importance}}, a small number of trees should be used to force variables to compete for entry into the model. Note that this number is used to estimate both the ``true'' and ``null'' variable inclusion proportions.
#' @param alpha Cut-off level for the thresholds.
#' @param num_trees_pred_cv Number of trees to use for prediction on the hold-out portion of each fold. Once variables have been selected using the training portion of each fold, a new model is built using only those variables with \code{num_trees_pred_cv} trees in the sum-of-trees model. Forecasts for the holdout sample are made using this model. A larger number of trees is recommended to exploit the full forecasting power of BART.
#' @param verbose If TRUE, prints progress messages.
#'
#' @return
#' Returns a list with the following components:
#' 
#'   \item{best_method}{The name of the best variable selection procedure, as chosen via cross-validation.}
#'   \item{important_vars_cv}{The variables chosen by the \code{best_method} above.}
#'
#' @references
#' J Bleich, A Kapelner, ST Jensen, and EI George. Variable Selection Inference for Bayesian
#' Additive Regression Trees. ArXiv e-prints, 2013.
#' 
#' Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning
#' with Bayesian Additive Regression Trees. Journal of Statistical
#' Software, 70(4), 1-40. \doi{10.18637/jss.v070.i04}
#'
#' @seealso
#' \code{\link{var_selection_by_permute}}, \code{\link{investigate_var_importance}}
#'
#' @author
#' Adam Kapelner and Justin Bleich
#'
#' @note
#' This function can have substantial run-time.
#' This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}.
#'
#' @examples
#' \dontrun{
#' #generate Friedman data
#' set.seed(11)
#' n  = 150
#' p = 100 ##95 useless predictors
#' X = data.frame(matrix(runif(n * p), ncol = p))
#' y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
#' 
#' ##build BART regression model (not actually used in variable selection)
#' bart_machine = bartMachine(X, y)
#' 
#' #variable selection via cross-validation
#' var_sel_cv = var_selection_by_permute_cv(bart_machine, k_folds = 3)
#' print(var_sel_cv$best_method)
#' print(var_sel_cv$important_vars_cv)
#' }
#' @export
var_selection_by_permute_cv = function(bart_machine, k_folds = 5, folds_vec = NULL, num_reps_for_avg = 5, num_permute_samples = 100, num_trees_for_permute = 20, alpha = 0.05, num_trees_pred_cv = 50, verbose = TRUE){
	#validate every argument before any model is built
	assert_class(bart_machine, "bartMachine")
	assert_number(k_folds, lower = 2) # Inf ok
	assert_integerish(folds_vec, null.ok = TRUE)
	assert_int(num_reps_for_avg, lower = 1)
	assert_int(num_permute_samples, lower = 1)
	assert_int(num_trees_for_permute, lower = 1)
	assert_number(alpha, lower = 0, upper = 1)
	assert_int(num_trees_pred_cv, lower = 1)
	assert_flag(verbose)

	check_serialization(bart_machine) #ensure the Java object exists and fire an error if not

	#number of training observations; needed on both the preset-folds and the random-folds path
	n = nrow(bart_machine$X)

	#translate the leave-one-out sentinel BEFORE range checking, otherwise Inf > n always trips the check
	if (is.infinite(k_folds)){
		k_folds = n
	}
	if (k_folds <= 1 || k_folds > n){
		stop("The number of folds must be at least 2 and less than or equal to n, use \"Inf\" for leave one out")
	}
	if (!is.null(folds_vec) && !inherits(folds_vec, "integer")){
		stop("folds_vec must be an a vector of integers specifying the indexes of each folds.")
	}

	#set up k folds
	if (is.null(folds_vec)){ ##if folds were not pre-set:
		#random fold assignment: bin random draws at their own quantiles
		temp = rnorm(n)
		folds_vec = cut(temp, breaks = quantile(temp, seq(0, 1, length.out = k_folds + 1)),
				include.lowest = TRUE, labels = FALSE)
	} else {
		if (length(folds_vec) != n){
			stop("folds_vec must have one entry per training observation.")
		}
		#relabel the folds as 1..k so the loop below visits every fold even when
		#the supplied labels are not consecutive integers starting at 1
		folds_vec = match(folds_vec, sort(unique(folds_vec)))
		k_folds = length(unique(folds_vec)) ##otherwise we know the folds, so just get k
	}

	L2_err_mat = matrix(NA, nrow = k_folds, ncol = 3)
	colnames(L2_err_mat) = c("important_vars_local_names", "important_vars_global_max_names", "important_vars_global_se_names")

	model_matrix = bart_machine$model_matrix_training_data
	response_col = ncol(model_matrix) ##last col is the response, so it is tossed from the design matrices

	for (k in seq_len(k_folds)){
		if (verbose){
			cat("cv #", k, "\n", sep = "")
		}
		#find out the indices of the holdout sample
		train_idx = which(folds_vec != k)
		test_idx = setdiff(seq_len(n), train_idx)

		#pull out training data; drop = FALSE keeps a matrix even for a single row or column
		training_X_k = model_matrix[train_idx, -response_col, drop = FALSE]
		training_y_k = bart_machine$y[train_idx]

		#make a temporary bart machine just so we can run the var selection for all three methods
		bart_machine_temp = bart_machine_duplicate(bart_machine, X = as.data.frame(training_X_k), y = training_y_k, run_in_sample = FALSE, verbose = FALSE)

		##do variable selection
		bart_variables_select_obj_k = var_selection_by_permute(bart_machine_temp,
				num_permute_samples = num_permute_samples,
				num_trees_for_permute = num_trees_for_permute,
				num_reps_for_avg = num_reps_for_avg,
				alpha = alpha,
				plot = FALSE,
				verbose = verbose)

		#pull out test data; drop = FALSE matters for single-row folds such as leave-one-out
		test_X_k = model_matrix[test_idx, -response_col, drop = FALSE]
		test_y_k = bart_machine$y[test_idx]

		if (verbose){
			cat("method")
		}
		for (method in colnames(L2_err_mat)){
			if (verbose){
				cat(".")
			}
			#pull out the appropriate vars
			vars_selected_by_method = bart_variables_select_obj_k[[method]]

			if (length(vars_selected_by_method) == 0){
				#nothing selected: predict the training mean and score the holdout against it
				ybar_est = mean(training_y_k)
				L2_err_mat[k, method] = sum((test_y_k - ybar_est)^2)
			} else {
				#now build the bart machine based on reduced model
				training_X_k_red_by_vars_picked_by_method = data.frame(training_X_k[, vars_selected_by_method, drop = FALSE])
				#data.frame() may alter column names, so restore the selected names explicitly
				colnames(training_X_k_red_by_vars_picked_by_method) = vars_selected_by_method

				bart_machine_reduced = bart_machine_duplicate(bart_machine, X = training_X_k_red_by_vars_picked_by_method, y = training_y_k,
						num_trees = num_trees_pred_cv,
						run_in_sample = FALSE,
						cov_prior_vec = rep(1, times = ncol(training_X_k_red_by_vars_picked_by_method)),   ##do not want old vec -- standard here
						verbose = FALSE)

				#and calculate oos-L2
				test_X_k_red_by_vars_picked_by_method = data.frame(test_X_k[, vars_selected_by_method, drop = FALSE])
				colnames(test_X_k_red_by_vars_picked_by_method) = vars_selected_by_method

				predict_obj = bart_predict_for_test_data(
					bart_machine_reduced,
					test_X_k_red_by_vars_picked_by_method,
					test_y_k,
					verbose = verbose
				)
				#now record it
				L2_err_mat[k, method] = predict_obj$L2_err
			}
		}
		if (verbose){
			cat("\n")
		}
	}

	#now extract the lowest oos-L2 to find the "best" method for variable selection (first one wins ties)
	L2_err_by_method = colSums(L2_err_mat)
	min_var_selection_method = colnames(L2_err_mat)[which.min(L2_err_by_method)]

	#now (finally) do var selection on the entire data and then return the vars from the best method found via cross-validation
	if (verbose){
		cat("final", "\n")
	}
	bart_variables_select_obj = var_selection_by_permute(bart_machine,
			num_permute_samples = num_permute_samples,
			num_trees_for_permute = num_trees_for_permute,
			num_reps_for_avg = num_reps_for_avg,
			alpha = alpha,
			plot = FALSE,
			verbose = verbose)

	#return vars from best method and method name
	list(best_method = min_var_selection_method, important_vars_cv = sort(bart_variables_select_obj[[min_var_selection_method]]))
}

Try the bartMachine package in your browser

Any scripts or data that you put into this service are public.

bartMachine documentation built on Jan. 19, 2026, 9:06 a.m.