tests/testthat/test-residual.R

# Checks that ForestModel$propagate_basis_update() keeps the residual in sync
# with the active forest when the leaf basis changes: after rescaling the basis,
# the stored residual must equal the old residual plus the old ensemble
# prediction minus the new ensemble prediction.
test_that("Residual updates correctly propagated after forest sampling step", {
  # Setup
  # Create dataset: 6 observations of 3 covariates
  # fmt: skip
  X = matrix(c(1.5, 8.7, 1.2, 
                 2.7, 3.4, 5.4, 
                 3.6, 1.2, 9.3, 
                 4.4, 5.4, 10.4, 
                 5.3, 9.3, 3.6, 
                 6.1, 10.4, 4.4), 
               byrow = TRUE, nrow = 6)
  n = nrow(X)
  p = ncol(X)
  # Single-column basis of ones; it is rescaled later in the test
  W = matrix(1, nrow = n, ncol = 1)
  # Step-function outcome plus N(0, 1) noise. The draw is not seeded; the
  # check below is a bookkeeping identity rather than a fit-quality check,
  # so it is not expected to depend on the particular draw.
  y = as.matrix(ifelse(X[, 1] > 4, -5, 5) + rnorm(n, 0, 1))
  y_bar = mean(y)
  y_std = sd(y)
  resid = (y - y_bar) / y_std
  forest_dataset = createForestDataset(X, W)
  residual = createOutcome(resid)
  variable_weights = rep(1.0 / p, p)
  feature_types = rep(0L, p)

  # Forest parameters
  num_trees = 50
  alpha = 0.95
  beta = 2.0
  min_samples_leaf = 1
  current_sigma2 = 1.
  # 1x1 leaf scale matrix; `matrix()` (not `as.matrix()`) is what actually
  # honours the `nrow` / `ncol` arguments
  current_leaf_scale = matrix(1 / num_trees, nrow = 1, ncol = 1)
  cutpoint_grid_size = 100
  max_depth = 10
  a_forest = 0
  b_forest = 0

  # RNG (NOTE(review): -1 presumably requests a non-fixed seed -- confirm
  # against createCppRNG())
  cpp_rng = createCppRNG(-1)

  # Create forest sampler and forest container
  global_model_config = createGlobalModelConfig(
    global_error_variance = current_sigma2
  )
  forest_model_config = createForestModelConfig(
    feature_types = feature_types,
    num_trees = num_trees,
    num_observations = n,
    alpha = alpha,
    beta = beta,
    min_samples_leaf = min_samples_leaf,
    max_depth = max_depth,
    leaf_model_type = 0,
    leaf_model_scale = current_leaf_scale,
    variable_weights = variable_weights,
    variance_forest_shape = a_forest,
    variance_forest_scale = b_forest,
    cutpoint_grid_size = cutpoint_grid_size
  )
  forest_model = createForestModel(
    forest_dataset,
    forest_model_config,
    global_model_config
  )
  # NOTE(review): the trailing FALSE presumably marks leaves as non-constant,
  # so predictions are weighted by the basis W -- confirm against
  # createForestSamples() / createForest(). If leaves were constant, the basis
  # update would not change predictions and the assertion below would be
  # trivially satisfied.
  forest_samples = createForestSamples(num_trees, 1, FALSE)
  active_forest = createForest(num_trees, 1, FALSE)

  # Initialize the leaves of each tree in the forest (leaf model type 0,
  # initial value mean(resid))
  active_forest$prepare_for_sampler(
    forest_dataset,
    residual,
    forest_model,
    0,
    mean(resid)
  )
  # Presumably removes the initialized forest's predictions from the residual
  # (flags passed positionally -- see adjust_residual() for their meaning)
  active_forest$adjust_residual(
    forest_dataset,
    residual,
    forest_model,
    FALSE,
    FALSE
  )

  # Run the forest sampling algorithm (GFR) for a single iteration;
  # keep_forest = TRUE retains the sampled forest in forest_samples, which is
  # what the predictions below are read from
  forest_model$sample_one_iteration(
    forest_dataset,
    residual,
    forest_samples,
    active_forest,
    cpp_rng,
    forest_model_config,
    global_model_config,
    keep_forest = TRUE,
    gfr = TRUE
  )

  # Get the current residual after running the sampler
  initial_resid = residual$get_data()

  # Get initial prediction from the tree ensemble
  initial_yhat = as.numeric(forest_samples$predict(forest_dataset))

  # Update the basis vector
  scalar = 2.0
  W_update = W * scalar
  forest_dataset$update_basis(W_update)

  # Update residual to reflect adjusted basis
  forest_model$propagate_basis_update(forest_dataset, residual, active_forest)

  # Get updated prediction from the tree ensemble under the new basis
  updated_yhat = as.numeric(forest_samples$predict(forest_dataset))

  # Expected residual: add back the old fit, subtract the new fit
  expected_resid = initial_resid + initial_yhat - updated_yhat

  # Get the residual after the basis update has been propagated
  updated_resid = residual$get_data()

  # Assertion (testthat convention: actual value first, expected second)
  expect_equal(updated_resid, expected_resid)
})

Try the stochtree package in your browser

Any scripts or data that you put into this service are public.

stochtree documentation built on Nov. 22, 2025, 9:06 a.m.