vc_softbart_regression: SoftBart Varying Coefficient Regression

View source: R/vc_softbart_regression.R

vc_softbart_regression    R Documentation

SoftBart Varying Coefficient Regression

Description

Fits the semiparametric varying coefficient regression model

Y = \alpha(X) + Z \beta(X) + \epsilon

in which the intercept \alpha(X) and the slope \beta(X) are both nonparametric functions of X, modeled with soft BART.
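
To make the model form concrete, the following base-R sketch simulates data from Y = \alpha(X) + Z \beta(X) + \epsilon. It does not call SoftBart; the functions alpha_fun and beta_fun, the sample size, and the noise level are hypothetical choices made only for illustration.

```r
set.seed(1)
n <- 200
X <- matrix(runif(n * 2), nrow = n)   # predictors that enter nonparametrically
Z <- rnorm(n)                         # variable that is treated linearly

# Hypothetical intercept and slope functions (illustration only)
alpha_fun <- function(x) sin(2 * pi * x[, 1])
beta_fun  <- function(x) 1 + x[, 2]^2

alpha <- alpha_fun(X)
beta  <- beta_fun(X)

# Given X, the mean of Y is linear in Z with intercept alpha(X) and slope beta(X)
Y <- alpha + Z * beta + 0.5 * rnorm(n)

sim <- data.frame(X1 = X[, 1], X2 = X[, 2], Z = Z, Y = Y)
str(sim)
```

With data in this shape, X1 and X2 would go on the right-hand side of formula and "Z" would be passed as linear_var_name.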

Usage

vc_softbart_regression(
  formula,
  linear_var_name,
  data,
  test_data,
  num_tree = 20,
  k = 2,
  hypers_intercept = NULL,
  hypers_slope = NULL,
  opts = NULL,
  verbose = TRUE,
  warn = TRUE
)

Arguments

formula

A model formula with a numeric variable on the left-hand side and the predictors that enter nonlinearly on the right-hand side.

linear_var_name

A string giving the name of the variable in data that is to be treated linearly.

data

A data frame consisting of the training data.

test_data

A data frame consisting of the testing data.

num_tree

The number of trees to use in the ensemble.

k

Determines the standard deviation of the leaf node parameters: 3 / k / sqrt(num_tree) for the intercept and, by default, 1 / k / sqrt(num_tree) for the slope. The slope value can be changed by supplying your own hyperparameters.

hypers_intercept

A list of hyperparameters constructed from the Hypers() function (num_tree, k, and sigma_mu are overridden by this function).

hypers_slope

A list of hyperparameters constructed from the Hypers() function (num_tree is overridden by this function).

opts

A list of options for running the chain constructed from the Opts() function (update_sigma is overridden by this function).

verbose

If TRUE, progress of the chain will be printed to the console.

warn

If TRUE, remind the user that they probably don't want the linear term to be included in the formula for the nonlinear part.
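
The leaf standard deviations described under k can be checked with a few lines of base R. This is only a sketch of the documented formulas; leaf_sd is a hypothetical helper and is not part of SoftBart.

```r
# Hypothetical helper: leaf standard deviations implied by the formulas
# given for the k argument (intercept: 3 / k / sqrt(num_tree),
# slope default: 1 / k / sqrt(num_tree)).
leaf_sd <- function(num_tree = 20, k = 2) {
  c(intercept = 3 / k / sqrt(num_tree),
    slope     = 1 / k / sqrt(num_tree))
}

# Defaults (num_tree = 20, k = 2): intercept about 0.335, slope about 0.112
leaf_sd()

# A larger k shrinks both priors toward zero
leaf_sd(num_tree = 20, k = 4)
```

Under these formulas the default prior for the intercept is three times as diffuse as the default prior for the slope.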

Value

Returns a list with the following components:

  • sigma_mu_alpha: samples of the standard deviation of the leaf node parameters for the intercept.

  • sigma_mu_beta: samples of the standard deviation of the leaf node parameters for the slope.

  • sigma: samples of the error standard deviation.

  • var_counts_alpha: a matrix with a column for each predictor group, giving the number of times that group is used in the intercept ensemble at each iteration.

  • var_counts_beta: a matrix with a column for each predictor group, giving the number of times that group is used in the slope ensemble at each iteration.

  • alpha_train: samples of the nonparametric intercept evaluated on the training set.

  • alpha_test: samples of the nonparametric intercept evaluated on the test set.

  • beta_train: samples of the nonparametric slope evaluated on the training set.

  • beta_test: samples of the nonparametric slope evaluated on the test set.

  • mu_train: samples of the predictions evaluated on the training set.

  • mu_test: samples of the predictions evaluated on the test set.

  • formula: the formula specified by the user.

  • ecdfs: empirical distribution functions, used by the predict function.

  • opts: the options used when running the chain.

  • mu_Y, sd_Y: used with the predict function to transform predictions.

  • alpha_forest: a forest object for the intercept; see the MakeForest documentation for more details.

  • beta_forest: a forest object for the slope; see the MakeForest documentation for more details.
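
The sample components above can be summarized with ordinary matrix operations. The sketch below assumes that the samples are stored with one row per saved iteration and one column per observation, which is what the use of colMeans(fitted_vc$mu_test) in the Examples suggests. To stay self-contained it uses a simulated stand-in matrix rather than a real fit, so the numbers carry no meaning.

```r
set.seed(1)
num_save <- 100   # number of saved iterations (stand-in value)
n_test   <- 5     # number of test observations (stand-in value)

# Stand-in for fitted_vc$beta_test (assumed layout: iterations x observations)
beta_test <- matrix(rnorm(num_save * n_test, mean = 2, sd = 0.3),
                    nrow = num_save, ncol = n_test)

# Posterior mean and 95% credible interval of the slope at each test point
beta_hat <- colMeans(beta_test)
beta_ci  <- apply(beta_test, 2, quantile, probs = c(0.025, 0.975))

summary_df <- data.frame(mean  = beta_hat,
                         lower = beta_ci[1, ],
                         upper = beta_ci[2, ])
print(summary_df)
```

The same pattern would apply to alpha_test, mu_test, and their training-set counterparts if they share this layout.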

Examples


## NOTE: SET NUMBER OF BURN IN AND SAMPLE ITERATIONS HIGHER IN PRACTICE

num_burn <- 10 ## Should be ~ 5000
num_save <- 10 ## Should be ~ 5000

set.seed(1234)
f_fried <- function(x) 10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
  10 * x[,4] + 5 * x[,5]

gen_data <- function(n_train, n_test, P, sigma) {
  X <- matrix(runif(n_train * P), nrow = n_train)
  Z <- rnorm(n_train)
  r <- f_fried(X)
  mu <- Z * r
  X_test <- matrix(runif(n_test * P), nrow = n_test)
  Z_test <- rnorm(n_test)
  r_test <- f_fried(X_test)
  mu_test <- Z_test * r_test
  Y <- mu + sigma * rnorm(n_train)
  Y_test <- mu_test + sigma * rnorm(n_test)

  return(list(X = X, Y = Y, Z = Z, r = r, mu = mu, X_test = X_test, Y_test =
              Y_test, Z_test = Z_test, r_test = r_test, mu_test = mu_test))
}

## Simulate dataset
sim_data <- gen_data(250, 250, 100, 1)

df <- data.frame(X = sim_data$X, Y = sim_data$Y, Z = sim_data$Z)
df_test <- data.frame(X = sim_data$X_test, Y = sim_data$Y_test, Z = sim_data$Z_test)

## Fit the model

opts <- Opts(num_burn = num_burn, num_save = num_save)
fitted_vc <- vc_softbart_regression(Y ~ . -Z, "Z", df, df_test, opts = opts)

## Plot results

plot(colMeans(fitted_vc$mu_test), sim_data$mu_test)
abline(a = 0, b = 1)


SoftBart documentation built on June 8, 2025, 9:40 p.m.