R/bart_node_related_methods.R

Defines functions extract_node_data extract_raw_node_data get_projection_weights node_prediction_training_data_indices

Documented in extract_raw_node_data get_projection_weights node_prediction_training_data_indices

#' Gets node predictions indices of the training data for new data.
#'
#' @description
#' This returns a binary tensor for all Gibbs samples after burn-in for all trees and for all training observations.
#' @param bart_machine An object of class ``bartMachine''. The model must have been built with
#'   \code{flush_indices_to_save_RAM = FALSE}; otherwise an error is raised.
#' @param new_data A data frame of observations whose prediction nodes you wish to investigate. If \code{NULL}, the original training data is used.
#'
#' @return
#' Returns a binary tensor indicating whether the prediction node contained a training datum or not. For each observation in new data, the size of this tensor is the number of Gibbs samples after burn-in
#' times the number of trees times the number of training data observations. Thus the size of the full tensor is the number of observations in the new data times the three dimensional object just explained.
#'
#' @examples
#' \dontrun{
#' set.seed(11)
#' n = 50
#' X = data.frame(x1 = rnorm(n), x2 = runif(n))
#' y = X$x1 + rnorm(n)
#' bart_machine = bartMachine(X, y, flush_indices_to_save_RAM = FALSE)
#' idx = node_prediction_training_data_indices(bart_machine)
#' }
#' @export
node_prediction_training_data_indices = function(bart_machine, new_data = NULL){
  assert_class(bart_machine, "bartMachine")
  assert_data_frame(new_data, null.ok = TRUE)

  if (bart_machine$flush_indices_to_save_RAM){
    stop("Node prediction training data indices cannot be computed if \"flush_indices_to_save_RAM\" was used to construct the BART model.")
  }

  # Single entry point for the Java call so that a NULL result is reported
  # with one clear error regardless of which argument was passed.
  call_java <- function(arg){
    res <- .jcall(
      bart_machine$java_bart_machine,
      "[[[[Z",
      "getNodePredictionTrainingIndicies",
      arg,
      simplify = TRUE
    )
    if (is.null(res)){
      stop(
        "Java failed to produce node prediction training data indices. ",
        "Ensure the model was built with flush_indices_to_save_RAM = FALSE ",
        "and that the Java object is still valid.",
        call. = FALSE
      )
    }
    res
  }

  if (is.null(new_data)){
    # A Java null typed as double[][] stands in for "no new data".
    java_arg <- .jcast(.jnull(), new.class = "[[D", check = FALSE, convert.array = FALSE)
  } else {
    # new_data is a data frame (checked above). .jarray() cannot build a
    # double[][] from a data frame, so convert to a matrix first, exactly as
    # get_projection_weights() does.
    new_data_matrix <- as.matrix(new_data)
    if (is.numeric(new_data_matrix)){
      # Integer-only data would otherwise be sent as int[][] rather than double[][].
      storage.mode(new_data_matrix) <- "double"
    }
    java_arg <- .jarray(new_data_matrix, dispatch = TRUE)
  }
  call_java(java_arg)
}

#' Gets Training Sample Projection / Weights
#'
#' @description
#' Returns the matrix H where yhat is approximately equal to H y where yhat is the predicted values for \code{new_data}. If \code{new_data} is unspecified, yhat will be the in-sample fits.
#' If BART were the same as OLS, H would be an orthogonal projection matrix. Here it is a projection matrix, but clearly non-orthogonal. Unfortunately, I cannot get
#' this function to work correctly because of three possible reasons: (1) BART does not work by averaging tree predictions: it is a sum of trees model where each tree sees the residuals
#' via backfitting, (2) the prediction in each node is a Bayesian posterior draw which is close to ybar of the observations contained in the node if noise is gauged to be small and
#' (3) there are transformations of the original y variable. I believe I got close and I think I'm off by a constant multiple which is a function of the number of trees. I can
#' use regression to estimate the constant multiple and correct for it. Set \code{regression_kludge} to \code{TRUE} for this. Note that the weights do not add up to one here.
#' The intuition is that, due to the backfitting, there is multiple counting. But I'm not entirely sure.
#' @param bart_machine An object of class ``bartMachine''.
#' @param new_data Data that you wish to investigate the training sample projection / weights. If \code{NULL}, the original training data is used.
#' @param regression_kludge See explanation in the description. Default is \code{FALSE}.
#'
#' @return
#' Returns a matrix of proportions with number of rows equal to the number of rows of \code{new_data} and number of columns equal to the number of rows of the original training data, n.
#'
#' @examples
#' \dontrun{
#' options(java.parameters = c("-Xmx20g", "--add-modules=jdk.incubator.vector", "-XX:+UseZGC"))
#' pacman::p_load(bartMachine, tidyverse)
#' 
#' seed = 1984
#' set.seed(seed)
#' n = 100
#' x = rnorm(n, 0, 1)
#' sigma = 0.1
#' y = x + rnorm(n, 0, sigma)
#' 
#' num_trees = 200
#' num_iterations_after_burn_in = 1000
#' bart_mod = bartMachine(data.frame(x = x), y,
#' 	flush_indices_to_save_RAM = FALSE,
#' 	num_trees = num_trees,
#' 	num_iterations_after_burn_in = num_iterations_after_burn_in,
#' 	seed = seed)
#' bart_mod
#' 
#' n_star = 100
#' x_star = rnorm(n_star)
#' y_star = as.numeric(x_star + rnorm(n_star, 0, sigma))
#' yhat_star_bart = predict(bart_mod, data.frame(x = x_star))
#' 
#' Hstar = get_projection_weights(bart_mod, data.frame(x = x_star))
#' rowSums(Hstar)
#' yhat_star_projection = as.numeric(Hstar %*% y)
#' 
#' ggplot(data.frame(
#' 	yhat_star = yhat_star_bart,
#' 	yhat_star_projection = yhat_star_projection,
#' 	y_star = y_star)) +
#'   geom_point(aes(x = yhat_star_bart, y = yhat_star_projection), col = "green") +
#'   geom_abline(slope = 1, intercept = 0)
#' 
#' Hstar = get_projection_weights(bart_mod, data.frame(x = x_star), regression_kludge = TRUE)
#' rowSums(Hstar)
#' yhat_star_projection = as.numeric(Hstar %*% y)
#' 
#' ggplot(data.frame(
#' 	yhat_star = yhat_star_bart,
#' 	yhat_star_projection = yhat_star_projection,
#' 	y_star = y_star)) +
#'   geom_point(aes(x = yhat_star_bart, y = yhat_star_projection), col = "green") +
#'   geom_abline(slope = 1, intercept = 0)
#' 
#' }
#' @export
get_projection_weights = function(bart_machine, new_data = NULL, regression_kludge = FALSE){
  assert_class(bart_machine, "bartMachine")
  assert_data_frame(new_data, null.ok = TRUE)
  assert_flag(regression_kludge)

  if (bart_machine$flush_indices_to_save_RAM){
    stop("Node prediction training data indices cannot be computed if \"flush_indices_to_save_RAM\" was used to construct the BART model.")
  }

  # The kludge needs the model's own predictions to regress against, so it
  # falls back to the stored training data when no new data is supplied.
  if (regression_kludge && is.null(new_data)){
    new_data <- bart_machine$X
  }
  if (regression_kludge){
    bart_predictions <- predict(bart_machine, new_data)
  }

  # Either a Java null typed as double[][] or the new data as a Java matrix.
  java_input <- if (is.null(new_data)){
    .jcast(.jnull(), new.class = "[[D", check = FALSE, convert.array = FALSE)
  } else {
    .jarray(as.matrix(new_data), dispatch = TRUE)
  }
  projection <- .jcall(bart_machine$java_bart_machine, "[[D", "getProjectionWeights", java_input, simplify = TRUE)

  if (!regression_kludge){
    return(projection)
  }

  # Rescale by the slope of the model predictions regressed on the raw projected fits.
  projected_fits <- as.numeric(projection %*% bart_machine$y)
  slope <- as.numeric(coef(lm(bart_predictions ~ projected_fits))[2])
  projection * slope #scale it back
}

# Sentinel values that extract_node_data() compares node fields against; a
# field equal to its sentinel is reported as NA.
# BAD_FLAG_INT is -(2^31 - 1); BAD_FLAG_DOUBLE is the negative of the largest
# finite double.
BAD_FLAG_INT <- -2147483647
BAD_FLAG_DOUBLE <- -1.7976931348623157e+308
#' Gets Raw Node data
#'
#' @description
#' Returns a list object that contains all the information for all trees in a given Gibbs sample. Daughter nodes are nested
#' in the list structure recursively.
#' @param bart_machine An object of class ``bartMachine''.
#' @param g The Gibbs sample number. It must be a natural number between 1 and the number of iterations after burn in. Default is 1.
#'
#' @return
#' Returns a list object that contains all the information for all trees in a given Gibbs sample.
#' The list has one element per tree; each element is the nested list describing that tree's root node.
#'
#' @examples
#' \dontrun{
#' options(java.parameters = c("-Xmx20g", "--add-modules=jdk.incubator.vector", "-XX:+UseZGC"))
#' pacman::p_load(bartMachine)
#' 
#' seed = 1984
#' set.seed(seed)
#' n = 100
#' x = rnorm(n, 0, 1)
#' sigma = 0.1
#' y = x + rnorm(n, 0, sigma)
#' 
#' num_trees = 200
#' num_iterations_after_burn_in = 1000
#' bart_mod = bartMachine(data.frame(x = x), y,
#' 	flush_indices_to_save_RAM = FALSE,
#' 	num_trees = num_trees,
#' 	num_iterations_after_burn_in = num_iterations_after_burn_in,
#' 	seed = seed)
#' 
#' raw_node_data = extract_raw_node_data(bart_mod)
#' 
#' }
#' @export
extract_raw_node_data = function(bart_machine, g = 1){
  assert_class(bart_machine, "bartMachine")
  assert_int(g, lower = 1)

  # Scalar condition, so use the short-circuiting `||` rather than element-wise `|`.
  if (g < 1 || g > bart_machine$num_iterations_after_burn_in){
    stop("g is the gibbs sample number i.e. it must be a natural number between 1 and the number of iterations after burn in")
  }

  # g is 1-based in R; the Java method is handed g - 1 (presumably a 0-based index).
  raw_data_java <- .jcall(bart_machine$java_bart_machine, "[LbartMachine/bartMachineTreeNode;", "extractRawNodeInformation", as.integer(g - 1), simplify = TRUE)

  # One nested list per tree. seq_len() avoids the 1:0 pitfall and lapply()
  # avoids growing the result list element by element.
  lapply(seq_len(bart_machine$num_trees), function(m){
    extract_node_data(raw_data_java[[m]])
  })
}

# Converts one Java tree node into an R list and recurses into its children.
#
# node_java: a reference to a Java tree node exposing the fields and methods
#            read below.
# Returns a list holding the node's Java references, its scalar fields (with
# sentinel-valued fields replaced by NA), and `left` / `right` entries that are
# either the nested list for that child or NA when the child is a Java null.
extract_node_data = function(node_java){
  # Maps a field that still holds its sentinel value to NA; otherwise passes it through.
  na_if_flagged <- function(value, flag){
    if (value == flag) NA else value
  }

  # Child references are read once and reused for both the raw reference
  # entries and the recursion at the end.
  left_ref <- node_java$left
  right_ref <- node_java$right

  node_data <- list()
  node_data$java_obj <- node_java
  node_data$left_java_obj <- left_ref
  node_data$right_java_obj <- right_ref
  node_data$depth <- node_java$depth
  node_data$isLeaf <- node_java$isLeaf

  node_data$sendMissingDataRight <- node_java$sendMissingDataRight

  node_data$n_eta <- node_java$n_eta
  node_data$string_id <- node_java$stringID()
  node_data$is_stump <- node_java$isStump()
  node_data$string_location <- node_java$stringLocation()

  # Fields that may still carry a sentinel instead of a real value.
  node_data$splitAttributeM <- na_if_flagged(node_java$splitAttributeM, BAD_FLAG_INT)
  node_data$splitValue <- na_if_flagged(node_java$splitValue, BAD_FLAG_DOUBLE)
  node_data$y_pred <- na_if_flagged(node_java$y_pred, BAD_FLAG_DOUBLE)
  node_data$y_avg <- na_if_flagged(node_java$y_avg, BAD_FLAG_DOUBLE)
  node_data$posterior_var <- na_if_flagged(node_java$posterior_var, BAD_FLAG_DOUBLE)
  node_data$posterior_mean <- na_if_flagged(node_java$posterior_mean, BAD_FLAG_DOUBLE)

  # A Java null parent is recorded as NA.
  parent_ref <- node_java$parent
  node_data$parent_java_obj <- if (is.jnull(parent_ref)) NA else parent_ref

  # Recurse into each existing child; a missing child is recorded as NA.
  node_data$left <- if (is.jnull(left_ref)) NA else extract_node_data(left_ref)
  node_data$right <- if (is.jnull(right_ref)) NA else extract_node_data(right_ref)

  node_data
}
		
		

Try the bartMachine package in your browser

Any scripts or data that you put into this service are public.

bartMachine documentation built on Jan. 19, 2026, 9:06 a.m.