tests/testthat/helper-bartmachine-models.R

options(java.parameters = c("-Xmx20g", "--add-modules=jdk.incubator.vector", "-XX:+UseZGC"))
options(warn = 2)

library(bartMachine)

if (!exists("BART_TESTS")) {
  BART_TESTS <- list(
    core_count = 4,
    pdf_width = 8,
    pdf_height = 6,
    seed = 11,
    verbose = FALSE,
    default_n = 120,
    default_p = 4,
    default_class_p = 3,
    reg_coef_1 = 3,
    reg_coef_2 = 2,
    reg_noise_sd = 0.2,
    model_num_trees = 20,
    model_burn_in = 10,
    model_iter = 20,
    class_threshold = 0.5,
    small_data_n = 30,
    small_data_sd = 0.1,
    small_model_num_trees = 10,
    small_model_burn_in = 5,
    small_model_iter = 10,
    dummify_b_values = c(1, 2),
    arr_r = 2,
    posterior_subset_n = 5,
    credible_interval_cols = 2,
    prediction_samples_per_point = 20,
    covariate_index = 1,
    reg_feature_index_1 = 1,
    reg_feature_index_2 = 2,
    class_feature_index = 1,
    first_dim_index = 1,
    num_permute_samples = 5,
    linearity_num_trees = 10,
    linearity_num_burn_in = 5,
    linearity_num_iter = 10,
    num_replicates_for_avg = 1,
    num_trees_bottleneck = 5,
    num_var_plot = 3,
    min_plot_count = 1,
    pd_plot_prop_data = 0.2,
    rmse_tree_list = c(5, 10),
    rmse_num_replicates = 1,
    var_sel_k_folds = 2,
    var_sel_num_reps_for_avg = 1,
    var_sel_num_permute_samples = 5,
    var_sel_num_trees_for_permute = 5,
    var_sel_num_trees_pred_cv = 5,
    cv_data_n = 40,
    cv_data_sd = 0.1,
    cv_num_tree_cvs = c(5, 10),
    cv_num_tree_cvs_single = c(5),
    cv_k_cvs = c(2),
    cv_k_folds = 2,
    cv_num_burn_in = 5,
    cv_num_iter = 10,
    cv_num_trees = 5,
    raw_node_g = 1,
    test_idx_n = 20,
    test_files = c(
      "test-exports-basic.R",
      "test-exports-diagnostics.R",
      "test-exports-model-selection.R"
    ),
    np_grid = list(
      list(n = 60, p = 3, num_trees = 20, num_burn_in = 10, num_iter = 20),
      list(n = 120, p = 4, num_trees = 20, num_burn_in = 10, num_iter = 20),
      list(n = 200, p = 6, num_trees = 30, num_burn_in = 15, num_iter = 30),
      list(n = 240, p = 8, num_trees = 35, num_burn_in = 15, num_iter = 35),
      list(n = 320, p = 10, num_trees = 40, num_burn_in = 20, num_iter = 40),
      list(n = 400, p = 12, num_trees = 45, num_burn_in = 20, num_iter = 45),
      list(n = 480, p = 15, num_trees = 50, num_burn_in = 20, num_iter = 50)
    )
  )
}

get_test_env_int <- function(name, default) {
  val <- Sys.getenv(name, "")
  if (nzchar(val)) {
    as.integer(val)
  } else {
    default
  }
}

with_plot_device <- function(expr) {
  path <- tempfile(fileext = ".pdf")
  grDevices::pdf(path, width = BART_TESTS$pdf_width, height = BART_TESTS$pdf_height)
  on.exit(grDevices::dev.off(), add = TRUE)
  force(expr)
  invisible(path)
}

build_regression_model <- function() {
  set.seed(BART_TESTS$seed)
  n <- as.integer(Sys.getenv("N_TEST", as.character(BART_TESTS$default_n)))
  p <- as.integer(Sys.getenv("P_TEST", as.character(BART_TESTS$default_p)))
  X <- data.frame(matrix(runif(n * p), ncol = p))
  y <- BART_TESTS$reg_coef_1 * X[, BART_TESTS$reg_feature_index_1] +
    BART_TESTS$reg_coef_2 * X[, BART_TESTS$reg_feature_index_2] +
    rnorm(n, sd = BART_TESTS$reg_noise_sd)
  set_bart_machine_num_cores(BART_TESTS$core_count, verbose = BART_TESTS$verbose)
  model <- bartMachine(
    X,
    y,
    num_trees = get_test_env_int("MODEL_NUM_TREES", BART_TESTS$model_num_trees),
    num_burn_in = get_test_env_int("MODEL_BURN_IN", BART_TESTS$model_burn_in),
    num_iterations_after_burn_in = get_test_env_int("MODEL_ITER", BART_TESTS$model_iter),
    flush_indices_to_save_RAM = FALSE,
    verbose = BART_TESTS$verbose
  )
  set_bart_machine_num_cores(1, verbose = BART_TESTS$verbose)
  model
}

build_classification_model <- function() {
  set.seed(BART_TESTS$seed)
  n <- as.integer(Sys.getenv("N_CLASS_TEST", Sys.getenv("N_TEST", as.character(BART_TESTS$default_n))))
  p <- as.integer(Sys.getenv("P_CLASS_TEST", Sys.getenv("P_TEST", as.character(BART_TESTS$default_class_p))))
  X <- data.frame(matrix(runif(n * p), ncol = p))
  y <- factor(ifelse(
    X[, BART_TESTS$class_feature_index] > BART_TESTS$class_threshold,
    "yes",
    "no"
  ))
  set_bart_machine_num_cores(BART_TESTS$core_count, verbose = BART_TESTS$verbose)
  bartMachine(
    X,
    y,
    num_trees = get_test_env_int("MODEL_NUM_TREES", BART_TESTS$model_num_trees),
    num_burn_in = get_test_env_int("MODEL_BURN_IN", BART_TESTS$model_burn_in),
    num_iterations_after_burn_in = get_test_env_int("MODEL_ITER", BART_TESTS$model_iter),
    flush_indices_to_save_RAM = FALSE,
    verbose = BART_TESTS$verbose
  )
}

Try the bartMachine package in your browser

Any scripts or data that you put into this service are public.

bartMachine documentation built on Jan. 19, 2026, 9:06 a.m.