# tests/testthat/test-residual.R

test_that("Residual updates correctly propagated after forest sampling step", {
    # Verifies that, after the basis of the forest dataset is rescaled and the
    # change is propagated through the forest model, the stored residual equals
    #   (residual before update) + (prediction before update)
    #                            - (prediction after update).

    # ---- Data --------------------------------------------------------------
    # Covariates: 6 observations x 3 features, filled row by row
    X <- matrix(
        c(1.5,  8.7,  1.2,
          2.7,  3.4,  5.4,
          3.6,  1.2,  9.3,
          4.4,  5.4, 10.4,
          5.3,  9.3,  3.6,
          6.1, 10.4,  4.4),
        byrow = TRUE, nrow = 6
    )
    # Basis: a single column of ones
    W <- matrix(c(1, 1, 1, 1, 1, 1), nrow = 6)
    n <- nrow(X)
    p <- ncol(X)

    # Outcome: step function of the first feature plus standard normal noise.
    # The noise is drawn without a fixed seed, so the data differ between runs;
    # the identity asserted at the end is expected to hold for any draw.
    y <- as.matrix(ifelse(X[, 1] > 4, -5, 5) + rnorm(n, 0, 1))
    y_bar <- mean(y)
    y_std <- sd(y)
    # Standardized outcome, used as the starting residual
    resid <- (y - y_bar) / y_std

    forest_dataset <- createForestDataset(X, W)
    residual <- createOutcome(resid)
    variable_weights <- rep(1.0 / p, p)
    feature_types <- rep(0L, p)

    # ---- Forest parameters -------------------------------------------------
    num_trees <- 50
    alpha <- 0.95
    beta <- 2.0
    min_samples_leaf <- 1
    current_sigma2 <- 1
    # 1 x 1 matrix; built with matrix() because as.matrix() silently ignores
    # nrow / ncol arguments
    current_leaf_scale <- matrix(1 / num_trees, nrow = 1, ncol = 1)
    cutpoint_grid_size <- 100
    max_depth <- 10
    a_forest <- 0
    b_forest <- 0

    # ---- RNG ---------------------------------------------------------------
    # NOTE(review): -1 presumably requests a non-deterministic seed -- confirm
    cpp_rng <- createCppRNG(-1)

    # ---- Model configuration, sampler and forest containers ----------------
    global_model_config <- createGlobalModelConfig(
        global_error_variance = current_sigma2
    )
    forest_model_config <- createForestModelConfig(
        feature_types = feature_types, num_trees = num_trees,
        num_observations = n, alpha = alpha, beta = beta,
        min_samples_leaf = min_samples_leaf, max_depth = max_depth,
        leaf_model_type = 0, leaf_model_scale = current_leaf_scale,
        variable_weights = variable_weights,
        variance_forest_shape = a_forest,
        variance_forest_scale = b_forest,
        cutpoint_grid_size = cutpoint_grid_size
    )
    forest_model <- createForestModel(
        forest_dataset, forest_model_config, global_model_config
    )
    forest_samples <- createForestSamples(num_trees, 1, FALSE)
    active_forest <- createForest(num_trees, 1, FALSE)

    # Initialize the leaves of each tree in the prognostic forest and bring
    # the residual in line with the initialized forest
    active_forest$prepare_for_sampler(
        forest_dataset, residual, forest_model, 0, mean(resid)
    )
    active_forest$adjust_residual(
        forest_dataset, residual, forest_model, FALSE, FALSE
    )

    # ---- One sampler iteration ---------------------------------------------
    forest_model$sample_one_iteration(
        forest_dataset, residual, forest_samples, active_forest,
        cpp_rng, forest_model_config, global_model_config,
        keep_forest = TRUE, gfr = TRUE
    )

    # Residual and ensemble prediction right after the sampler has run
    initial_resid <- residual$get_data()
    initial_yhat <- as.numeric(forest_samples$predict(forest_dataset))

    # ---- Basis update ------------------------------------------------------
    # Rescale the basis vector by a constant factor
    scalar <- 2.0
    W_update <- W * scalar
    forest_dataset$update_basis(W_update)

    # Update the residual to reflect the adjusted basis
    forest_model$propagate_basis_update(forest_dataset, residual, active_forest)

    # Ensemble prediction under the updated basis
    updated_yhat <- as.numeric(forest_samples$predict(forest_dataset))

    # Residual we expect: remove the old prediction's contribution and apply
    # the new one
    expected_resid <- initial_resid + initial_yhat - updated_yhat

    # Residual actually stored after the basis update was propagated
    updated_resid <- residual$get_data()

    # ---- Assertion ---------------------------------------------------------
    # testthat convention: observed value first, expected value second
    expect_equal(updated_resid, expected_resid)
})

# Try the stochtree package in your browser
#
# Any scripts or data that you put into this service are public.
#
# stochtree documentation built on April 4, 2025, 2:11 a.m.