tests/testthat/test-predict.R

test_that("Prediction from trees with constant leaf", {
  # Small 6 x 3 covariate matrix and a container for `num_trees` trees.
  # The third argument of createForestSamples() is TRUE here; presumably it
  # flags a constant (basis-free) leaf model -- the basis tests pass FALSE.
  num_trees <- 10
  # fmt: skip
  X <- matrix(c(1.5, 8.7, 1.2,
                2.7, 3.4, 5.4,
                3.6, 1.2, 9.3,
                4.4, 5.4, 10.4,
                5.3, 9.3, 3.6,
                6.1, 10.4, 4.4),
              byrow = TRUE, nrow = 6)
  p <- ncol(X)
  forest_dataset <- createForestDataset(X)
  forest_samples <- createForestSamples(num_trees, 1, TRUE)

  # Initialize a forest with constant root predictions
  forest_samples$add_forest_with_constant_leaves(0.)

  # With a constant leaf there is no basis, so regular and "raw" predictions
  # must coincide
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, pred_raw)

  # Split the root of the first tree in the ensemble at X[,1] > 4.0
  forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.)

  # Regular and "raw" predictions must still coincide after the split
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, pred_raw)

  # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0
  forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5)

  # ... and after the second split
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, pred_raw)

  # The first tree now holds one split on feature 1 and one on feature 2
  split_counts <- forest_samples$get_tree_split_counts(0, 0, p)
  split_counts_expected <- c(1, 1, 0)
  expect_equal(split_counts, split_counts_expected)
})

test_that("Prediction from trees with univariate leaf basis", {
  # Small 6 x 3 covariate matrix, a single-column leaf basis `W`, and a
  # container for `num_trees` trees (third argument FALSE, i.e. not the
  # constant-leaf configuration used in the constant-leaf test).
  num_trees <- 10
  # fmt: skip
  X <- matrix(c(1.5, 8.7, 1.2,
                2.7, 3.4, 5.4,
                3.6, 1.2, 9.3,
                4.4, 5.4, 10.4,
                5.3, 9.3, 3.6,
                6.1, 10.4, 4.4),
              byrow = TRUE, nrow = 6)
  W <- as.matrix(c(-1, -1, -1, 1, 1, 1))
  p <- ncol(X)
  forest_dataset <- createForestDataset(X, W)
  forest_samples <- createForestSamples(num_trees, 1, FALSE)

  # Initialize a forest with constant root predictions
  forest_samples$add_forest_with_constant_leaves(0.)

  # Every leaf is 0, so regular and "raw" predictions are both zero and
  # therefore equal even though a basis is present
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, pred_raw)

  # Split the root of the first tree in the ensemble at X[,1] > 4.0
  forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.)

  # With a univariate basis, the prediction is the raw leaf value scaled by W
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  pred_manual <- pred_raw * W
  expect_equal(pred, pred_manual)

  # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0
  forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5)

  # The basis scaling must still hold after the second split
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  pred_manual <- pred_raw * W
  expect_equal(pred, pred_manual)

  # The first tree now holds one split on feature 1 and one on feature 2
  split_counts <- forest_samples$get_tree_split_counts(0, 0, p)
  split_counts_expected <- c(1, 1, 0)
  expect_equal(split_counts, split_counts_expected)
})

test_that("Prediction from trees with multivariate leaf basis", {
  # Small 6 x 3 covariate matrix, a 6 x `output_dim` leaf basis `W`, and a
  # container for `num_trees` trees with `output_dim`-dimensional leaves.
  num_trees <- 10
  output_dim <- 2
  num_samples <- 0
  # fmt: skip
  X <- matrix(c(1.5, 8.7, 1.2,
                2.7, 3.4, 5.4,
                3.6, 1.2, 9.3,
                4.4, 5.4, 10.4,
                5.3, 9.3, 3.6,
                6.1, 10.4, 4.4),
              byrow = TRUE, nrow = 6)
  n <- nrow(X)
  p <- ncol(X)
  W <- matrix(
    c(1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1),
    byrow = FALSE,
    nrow = 6
  )
  forest_dataset <- createForestDataset(X, W)
  forest_samples <- createForestSamples(num_trees, output_dim, FALSE)

  # Reference prediction: multiply each raw leaf value by the matching basis
  # entry, reshape to (row, basis column, sample), and sum over the basis
  # columns, yielding one column per forest sample. Assumes predict_raw()
  # flattens in the same (row, basis column, sample) order -- TODO confirm.
  manual_prediction <- function(pred_raw, num_samples) {
    weighted <- as.numeric(pred_raw) * as.numeric(W)
    dim(weighted) <- c(n, output_dim, num_samples)
    apply(weighted, 3, rowSums)
  }

  # Initialize a forest with constant root predictions
  forest_samples$add_forest_with_constant_leaves(c(1., 1.))
  num_samples <- num_samples + 1

  # Regular predictions must equal the basis-weighted sum of raw leaf values
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, manual_prediction(pred_raw, num_samples))

  # Split the root of the first tree in the ensemble at X[,1] > 4.0
  forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, c(-5., -1.), c(5., 1.))

  # The basis-weighted relationship must hold after the split
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, manual_prediction(pred_raw, num_samples))

  # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0
  forest_samples$add_numeric_split_tree(
    0,
    0,
    1,
    1,
    4.0,
    c(-7.5, 2.5),
    c(-2.5, 7.5)
  )

  # ... and after the second split
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, manual_prediction(pred_raw, num_samples))

  # The first tree now holds one split on feature 1 and one on feature 2;
  # count over all `p` features, consistent with the other prediction tests
  split_counts <- forest_samples$get_tree_split_counts(0, 0, p)
  split_counts_expected <- c(1, 1, 0)
  expect_equal(split_counts, split_counts_expected)
})

test_that("BART predictions with pre-summarization", {
  # Generate data: a step function of X[, 1] plus Gaussian noise
  n <- 100
  p <- 5
  X <- matrix(runif(n * p), ncol = p)
  # fmt: skip
  f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) +
           ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
           ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
           ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5))
  noise_sd <- 1
  y <- f_XW + rnorm(n, 0, noise_sd)

  # Hold out `test_set_pct` of the rows as a test set
  test_set_pct <- 0.2
  n_test <- round(test_set_pct * n)
  test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
  train_inds <- setdiff(seq_len(n), test_inds)
  X_test <- X[test_inds, ]
  X_train <- X[train_inds, ]
  y_train <- y[train_inds]

  # Sampler settings shared by both fits. Posterior predictions are expected
  # to have one row per test observation and one column per MCMC draw
  # (assumes GFR warm-start draws are not retained -- NOTE(review): confirm).
  num_gfr <- 10
  num_burnin <- 0
  num_mcmc <- 10
  expected_dim <- c(n_test, num_mcmc)

  # Fit a "classic" BART model
  bart_model <- bart(
    X_train = X_train,
    y_train = y_train,
    num_gfr = num_gfr,
    num_burnin = num_burnin,
    num_mcmc = num_mcmc
  )

  # Check that the default predict method returns a list of posterior draws
  pred <- predict(bart_model, X = X_test)
  expect_type(pred, "list")
  y_hat_posterior_test <- pred$y_hat
  expect_equal(dim(y_hat_posterior_test), expected_dim)

  # Check that the pre-aggregated predictions match with those computed by rowMeans
  pred_mean <- predict(bart_model, X = X_test, type = "mean")
  y_hat_mean_test <- pred_mean$y_hat
  expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))

  # Check that we warn and return a NULL when requesting terms that weren't fit
  expect_warning({
    pred_mean <- predict(
      bart_model,
      X = X_test,
      type = "mean",
      terms = c("rfx", "variance_forest")
    )
  })
  expect_null(pred_mean)

  # Fit a heteroskedastic BART model
  var_params <- list(num_trees = 20)
  het_bart_model <- bart(
    X_train = X_train,
    y_train = y_train,
    num_gfr = num_gfr,
    num_burnin = num_burnin,
    num_mcmc = num_mcmc,
    variance_forest_params = var_params
  )

  # Check that the default predict method returns a list of posterior draws
  pred <- predict(het_bart_model, X = X_test)
  expect_type(pred, "list")
  y_hat_posterior_test <- pred$y_hat
  sigma2_hat_posterior_test <- pred$variance_forest_predictions
  expect_equal(dim(y_hat_posterior_test), expected_dim)
  expect_equal(dim(sigma2_hat_posterior_test), expected_dim)

  # Check that the pre-aggregated predictions match with those computed by rowMeans
  pred_mean <- predict(het_bart_model, X = X_test, type = "mean")
  y_hat_mean_test <- pred_mean$y_hat
  sigma2_hat_mean_test <- pred_mean$variance_forest_predictions
  expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))
  expect_equal(sigma2_hat_mean_test, rowMeans(sigma2_hat_posterior_test))

  # Check that the "single-term" pre-aggregated predictions
  # match those computed by pre-aggregated predictions returned in a list
  y_hat_mean_test_single_term <- predict(
    het_bart_model,
    X = X_test,
    type = "mean",
    terms = "y_hat"
  )
  sigma2_hat_mean_test_single_term <- predict(
    het_bart_model,
    X = X_test,
    type = "mean",
    terms = "variance_forest"
  )
  expect_equal(y_hat_mean_test, y_hat_mean_test_single_term)
  expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term)
})

test_that("BCF predictions with pre-summarization", {
  # Generate data: prognostic function mu_x, treatment effect tau_x,
  # propensity pi_x, binary treatment Z and a noisy outcome y
  n <- 100
  g <- function(x) {
    ifelse(x[, 5] == 1, 2, ifelse(x[, 5] == 2, -1, -4))
  }
  x1 <- rnorm(n)
  x2 <- rnorm(n)
  x3 <- rnorm(n)
  x4 <- as.numeric(rbinom(n, 1, 0.5))
  x5 <- as.numeric(sample(1:3, n, replace = TRUE))
  X <- cbind(x1, x2, x3, x4, x5)
  mu_x <- 1 + g(X) + X[, 1] * X[, 3]
  tau_x <- 1 + 2 * X[, 2] * X[, 4]
  pi_x <- 0.8 *
    pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) +
    0.05 +
    runif(n) / 10
  Z <- rbinom(n, 1, pi_x)
  E_XZ <- mu_x + Z * tau_x
  snr <- 2
  y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr)

  # Encode the discrete covariates as ordered factors
  X <- as.data.frame(X)
  X$x4 <- factor(X$x4, ordered = TRUE)
  X$x5 <- factor(X$x5, ordered = TRUE)

  # Hold out `test_set_pct` of the rows as a test set
  test_set_pct <- 0.2
  n_test <- round(test_set_pct * n)
  test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
  train_inds <- setdiff(seq_len(n), test_inds)
  X_test <- X[test_inds, ]
  X_train <- X[train_inds, ]
  pi_test <- pi_x[test_inds]
  pi_train <- pi_x[train_inds]
  Z_test <- Z[test_inds]
  Z_train <- Z[train_inds]
  y_train <- y[train_inds]

  # Sampler settings shared by both fits. Posterior predictions are expected
  # to have one row per test observation and one column per MCMC draw
  # (assumes GFR warm-start draws are not retained -- NOTE(review): confirm).
  num_gfr <- 10
  num_burnin <- 0
  num_mcmc <- 10
  expected_dim <- c(n_test, num_mcmc)

  # Fit a "classic" BCF model
  bcf_model <- bcf(
    X_train = X_train,
    Z_train = Z_train,
    y_train = y_train,
    propensity_train = pi_train,
    X_test = X_test,
    Z_test = Z_test,
    propensity_test = pi_test,
    num_gfr = num_gfr,
    num_burnin = num_burnin,
    num_mcmc = num_mcmc
  )

  # Check that the default predict method returns a list of posterior draws
  pred <- predict(
    bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test
  )
  expect_type(pred, "list")
  y_hat_posterior_test <- pred$y_hat
  expect_equal(dim(y_hat_posterior_test), expected_dim)

  # Check that the pre-aggregated predictions match with those computed by rowMeans
  pred_mean <- predict(
    bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test,
    type = "mean"
  )
  y_hat_mean_test <- pred_mean$y_hat
  expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))

  # Check that we warn and return a NULL when requesting terms that weren't fit
  expect_warning({
    pred_mean <- predict(
      bcf_model,
      X = X_test,
      Z = Z_test,
      propensity = pi_test,
      type = "mean",
      terms = c("rfx", "variance_forest")
    )
  })
  expect_null(pred_mean)

  # Fit a heteroskedastic BCF model.
  # NOTE(review): bcf() is expected to warn here, but the specific warning
  # is not pinned; consider adding a `regexp =` once its text is confirmed.
  var_params <- list(num_trees = 20)
  expect_warning(
    het_bcf_model <- bcf(
      X_train = X_train,
      Z_train = Z_train,
      y_train = y_train,
      propensity_train = pi_train,
      X_test = X_test,
      Z_test = Z_test,
      propensity_test = pi_test,
      num_gfr = num_gfr,
      num_burnin = num_burnin,
      num_mcmc = num_mcmc,
      variance_forest_params = var_params
    )
  )

  # Check that the default predict method returns a list of posterior draws
  pred <- predict(
    het_bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test
  )
  expect_type(pred, "list")
  y_hat_posterior_test <- pred$y_hat
  sigma2_hat_posterior_test <- pred$variance_forest_predictions
  expect_equal(dim(y_hat_posterior_test), expected_dim)
  expect_equal(dim(sigma2_hat_posterior_test), expected_dim)

  # Check that the pre-aggregated predictions match with those computed by rowMeans
  pred_mean <- predict(
    het_bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test,
    type = "mean"
  )
  y_hat_mean_test <- pred_mean$y_hat
  sigma2_hat_mean_test <- pred_mean$variance_forest_predictions
  expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))
  expect_equal(sigma2_hat_mean_test, rowMeans(sigma2_hat_posterior_test))

  # Check that the "single-term" pre-aggregated predictions
  # match those computed by pre-aggregated predictions returned in a list
  y_hat_mean_test_single_term <- predict(
    het_bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test,
    type = "mean",
    terms = "y_hat"
  )
  sigma2_hat_mean_test_single_term <- predict(
    het_bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test,
    type = "mean",
    terms = "variance_forest"
  )
  expect_equal(y_hat_mean_test, y_hat_mean_test_single_term)
  expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term)
})

Try the stochtree package in your browser

Any scripts or data that you put into this service are public.

stochtree documentation built on Nov. 22, 2025, 9:06 a.m.