Nothing
test_that("Prediction from trees with constant leaf", {
  # Purpose: with a constant leaf model, the regular predictions of a forest
  # container are expected to equal its "raw" predictions, both for a
  # root-only forest and after manually adding splits to the first tree.

  # Create dataset and forest container
  num_trees <- 10
  # fmt: skip
  X <- matrix(c(1.5,  8.7,  1.2,
                2.7,  3.4,  5.4,
                3.6,  1.2,  9.3,
                4.4,  5.4, 10.4,
                5.3,  9.3,  3.6,
                6.1, 10.4,  4.4),
              byrow = TRUE, nrow = 6)
  p <- ncol(X)
  forest_dataset <- createForestDataset(X)
  # NOTE(review): the third argument (TRUE) presumably flags a constant leaf
  # model -- confirm against createForestSamples()
  forest_samples <- createForestSamples(num_trees, 1, TRUE)

  # Initialize a forest with constant root predictions
  forest_samples$add_forest_with_constant_leaves(0.)

  # Root-only forest: regular and "raw" predictions should agree
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, pred_raw)

  # Split the root of the first tree in the ensemble at X[,1] > 4.0
  forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.)

  # After one split: regular and "raw" predictions should still agree
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, pred_raw)

  # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0
  forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5)

  # After two splits: regular and "raw" predictions should still agree
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, pred_raw)

  # Check the split count for the first tree in the ensemble: one split was
  # added on each of the first two features and none on the third
  split_counts <- forest_samples$get_tree_split_counts(0, 0, p)
  split_counts_expected <- c(1, 1, 0)
  expect_equal(split_counts, split_counts_expected)
})
test_that("Prediction from trees with univariate leaf basis", {
  # Purpose: with a single-column leaf basis W, the regular predictions of a
  # forest container are expected to equal its "raw" predictions multiplied
  # elementwise by W.

  # Create dataset (covariates X and one-column basis W) and forest container
  num_trees <- 10
  # fmt: skip
  X <- matrix(c(1.5,  8.7,  1.2,
                2.7,  3.4,  5.4,
                3.6,  1.2,  9.3,
                4.4,  5.4, 10.4,
                5.3,  9.3,  3.6,
                6.1, 10.4,  4.4),
              byrow = TRUE, nrow = 6)
  W <- as.matrix(c(-1, -1, -1, 1, 1, 1))
  p <- ncol(X)
  forest_dataset <- createForestDataset(X, W)
  # NOTE(review): the third argument (FALSE) presumably flags a non-constant
  # (basis-weighted) leaf model -- confirm against createForestSamples()
  forest_samples <- createForestSamples(num_trees, 1, FALSE)

  # Initialize a forest with constant root predictions
  forest_samples$add_forest_with_constant_leaves(0.)

  # Root-only forest with every leaf set to 0: regular and "raw" predictions
  # are compared directly (presumably both are all zero, so scaling by W
  # would not change the comparison)
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, pred_raw)

  # Split the root of the first tree in the ensemble at X[,1] > 4.0
  forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.)

  # After one split: regular predictions should equal the "raw" predictions
  # multiplied elementwise by the basis W
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  pred_manual <- pred_raw * W
  expect_equal(pred, pred_manual)

  # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0
  forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5)

  # After two splits: regular predictions should again equal raw * W
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  pred_manual <- pred_raw * W
  expect_equal(pred, pred_manual)

  # Check the split count for the first tree in the ensemble: one split was
  # added on each of the first two features and none on the third
  split_counts <- forest_samples$get_tree_split_counts(0, 0, p)
  split_counts_expected <- c(1, 1, 0)
  expect_equal(split_counts, split_counts_expected)
})
test_that("Prediction from trees with multivariate leaf basis", {
  # Purpose: with a two-column leaf basis W, the regular predictions of a
  # forest container are expected to equal the "raw" per-dimension
  # predictions multiplied elementwise by W and summed over the output
  # dimension.

  # Create dataset and forest container
  num_trees <- 10
  output_dim <- 2
  # Number of forests added to the container so far; incremented below and
  # used as the third dimension when reshaping the raw predictions
  num_samples <- 0
  # fmt: skip
  X <- matrix(c(1.5,  8.7,  1.2,
                2.7,  3.4,  5.4,
                3.6,  1.2,  9.3,
                4.4,  5.4, 10.4,
                5.3,  9.3,  3.6,
                6.1, 10.4,  4.4),
              byrow = TRUE, nrow = 6)
  n <- nrow(X)
  p <- ncol(X)
  # Basis filled column-wise: first column is all ones, second column is
  # -1 for the first three rows and 1 for the last three
  W <- matrix(
    c(1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1),
    byrow = FALSE,
    nrow = 6
  )
  forest_dataset <- createForestDataset(X, W)
  forest_samples <- createForestSamples(num_trees, output_dim, FALSE)

  # Manually combine raw predictions with the basis: multiply the flattened
  # raw predictions by the flattened basis, reshape to
  # (n, output_dim, num_samples), then sum across the output dimension
  # within each sample.
  manual_prediction <- function(raw_predictions) {
    weighted <- as.numeric(raw_predictions) * as.numeric(W)
    dim(weighted) <- c(n, output_dim, num_samples)
    apply(weighted, 3, rowSums)
  }

  # Initialize a forest with constant root predictions
  forest_samples$add_forest_with_constant_leaves(c(1., 1.))
  num_samples <- num_samples + 1

  # Root-only forest: regular predictions should match the manual
  # basis-weighted combination of the raw predictions
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, manual_prediction(pred_raw))

  # Split the root of the first tree in the ensemble at X[,1] > 4.0
  forest_samples$add_numeric_split_tree(
    0,
    0,
    0,
    0,
    4.0,
    c(-5., -1.),
    c(5., 1.)
  )

  # After one split: regular predictions should match the manual combination
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, manual_prediction(pred_raw))

  # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0
  forest_samples$add_numeric_split_tree(
    0,
    0,
    1,
    1,
    4.0,
    c(-7.5, 2.5),
    c(-2.5, 7.5)
  )

  # After two splits: regular predictions should match the manual combination
  pred <- forest_samples$predict(forest_dataset)
  pred_raw <- forest_samples$predict_raw(forest_dataset)
  expect_equal(pred, manual_prediction(pred_raw))

  # Check the split count for the first tree in the ensemble: one split was
  # added on each of the first two features and none on the third.
  # The feature count is passed as p (= ncol(X) = 3) rather than a literal.
  split_counts <- forest_samples$get_tree_split_counts(0, 0, p)
  split_counts_expected <- c(1, 1, 0)
  expect_equal(split_counts, split_counts_expected)
})
test_that("BART predictions with pre-summarization", {
  # Purpose: check that predict() with type = "mean" on a fitted BART model
  # agrees with row means of the per-draw predictions, that requesting only
  # terms that were not fit warns and yields NULL, and that single-term
  # requests agree with the corresponding list entries.

  # Generate data: y is a step function of the first covariate plus noise
  n <- 100
  p <- 5
  X <- matrix(runif(n * p), ncol = p)
  # fmt: skip
  f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) +
           ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
           ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
           ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5))
  noise_sd <- 1
  y <- f_XW + rnorm(n, 0, noise_sd)

  # Test-train split
  test_set_pct <- 0.2
  n_test <- round(test_set_pct * n)
  test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
  train_inds <- setdiff(seq_len(n), test_inds)
  X_test <- X[test_inds, ]
  X_train <- X[train_inds, ]
  y_train <- y[train_inds]

  # Sampler settings shared by both fits below.
  # NOTE(review): the expected number of prediction columns is taken to be
  # num_mcmc, i.e. this assumes only the MCMC draws are retained -- confirm
  # against the bart() defaults.
  num_mcmc <- 10
  expected_dim <- c(n_test, num_mcmc)

  # Fit a "classic" BART model
  bart_model <- bart(
    X_train = X_train,
    y_train = y_train,
    num_gfr = 10,
    num_burnin = 0,
    num_mcmc = num_mcmc
  )

  # Default predict output: y_hat holds one column per retained draw
  pred <- predict(bart_model, X = X_test)
  y_hat_posterior_test <- pred$y_hat
  expect_equal(dim(y_hat_posterior_test), expected_dim)

  # Check that the pre-aggregated predictions match those computed by rowMeans
  pred_mean <- predict(bart_model, X = X_test, type = "mean")
  y_hat_mean_test <- pred_mean$y_hat
  expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))

  # Check that we warn and return a NULL when requesting terms that weren't fit
  expect_warning({
    pred_mean <- predict(
      bart_model,
      X = X_test,
      type = "mean",
      terms = c("rfx", "variance_forest")
    )
  })
  expect_null(pred_mean)

  # Fit a heteroskedastic BART model
  var_params <- list(num_trees = 20)
  het_bart_model <- bart(
    X_train = X_train,
    y_train = y_train,
    num_gfr = 10,
    num_burnin = 0,
    num_mcmc = num_mcmc,
    variance_forest_params = var_params
  )

  # Default predict output: both the mean and the variance forest predictions
  # hold one column per retained draw
  pred <- predict(het_bart_model, X = X_test)
  y_hat_posterior_test <- pred$y_hat
  sigma2_hat_posterior_test <- pred$variance_forest_predictions
  expect_equal(dim(y_hat_posterior_test), expected_dim)
  expect_equal(dim(sigma2_hat_posterior_test), expected_dim)

  # Check that the pre-aggregated predictions match those computed by rowMeans
  pred_mean <- predict(het_bart_model, X = X_test, type = "mean")
  y_hat_mean_test <- pred_mean$y_hat
  sigma2_hat_mean_test <- pred_mean$variance_forest_predictions
  expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))
  expect_equal(sigma2_hat_mean_test, rowMeans(sigma2_hat_posterior_test))

  # Check that the "single-term" pre-aggregated predictions
  # match those computed by pre-aggregated predictions returned in a list
  y_hat_mean_test_single_term <- predict(
    het_bart_model,
    X = X_test,
    type = "mean",
    terms = "y_hat"
  )
  sigma2_hat_mean_test_single_term <- predict(
    het_bart_model,
    X = X_test,
    type = "mean",
    terms = "variance_forest"
  )
  expect_equal(y_hat_mean_test, y_hat_mean_test_single_term)
  expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term)
})
test_that("BCF predictions with pre-summarization", {
  # Purpose: check that predict() with type = "mean" on a fitted BCF model
  # agrees with row means of the per-draw predictions, that requesting only
  # terms that were not fit warns and yields NULL, and that single-term
  # requests agree with the corresponding list entries.

  # Generate data: prognostic term mu_x, treatment effect tau_x, propensity
  # pi_x, binary treatment Z and outcome y
  n <- 100
  g <- function(x) {
    ifelse(x[, 5] == 1, 2, ifelse(x[, 5] == 2, -1, -4))
  }
  x1 <- rnorm(n)
  x2 <- rnorm(n)
  x3 <- rnorm(n)
  x4 <- as.numeric(rbinom(n, 1, 0.5))
  x5 <- as.numeric(sample(1:3, n, replace = TRUE))
  X <- cbind(x1, x2, x3, x4, x5)
  mu_x <- 1 + g(X) + X[, 1] * X[, 3]
  tau_x <- 1 + 2 * X[, 2] * X[, 4]
  pi_x <- 0.8 *
    pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) +
    0.05 +
    runif(n) / 10
  Z <- rbinom(n, 1, pi_x)
  E_XZ <- mu_x + Z * tau_x
  snr <- 2
  y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr)

  # Convert covariates to a data frame with x4 and x5 as ordered factors
  X <- as.data.frame(X)
  X$x4 <- factor(X$x4, ordered = TRUE)
  X$x5 <- factor(X$x5, ordered = TRUE)

  # Test-train split
  test_set_pct <- 0.2
  n_test <- round(test_set_pct * n)
  test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
  train_inds <- setdiff(seq_len(n), test_inds)
  X_test <- X[test_inds, ]
  X_train <- X[train_inds, ]
  pi_test <- pi_x[test_inds]
  pi_train <- pi_x[train_inds]
  Z_test <- Z[test_inds]
  Z_train <- Z[train_inds]
  y_train <- y[train_inds]

  # Sampler settings shared by both fits below.
  # NOTE(review): the expected number of prediction columns is taken to be
  # num_mcmc, i.e. this assumes only the MCMC draws are retained -- confirm
  # against the bcf() defaults.
  num_mcmc <- 10
  expected_dim <- c(n_test, num_mcmc)

  # Fit a "classic" BCF model
  bcf_model <- bcf(
    X_train = X_train,
    Z_train = Z_train,
    y_train = y_train,
    propensity_train = pi_train,
    X_test = X_test,
    Z_test = Z_test,
    propensity_test = pi_test,
    num_gfr = 10,
    num_burnin = 0,
    num_mcmc = num_mcmc
  )

  # Default predict output: y_hat holds one column per retained draw
  pred <- predict(
    bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test
  )
  y_hat_posterior_test <- pred$y_hat
  expect_equal(dim(y_hat_posterior_test), expected_dim)

  # Check that the pre-aggregated predictions match those computed by rowMeans
  pred_mean <- predict(
    bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test,
    type = "mean"
  )
  y_hat_mean_test <- pred_mean$y_hat
  expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))

  # Check that we warn and return a NULL when requesting terms that weren't fit
  expect_warning({
    pred_mean <- predict(
      bcf_model,
      X = X_test,
      Z = Z_test,
      propensity = pi_test,
      type = "mean",
      terms = c("rfx", "variance_forest")
    )
  })
  expect_null(pred_mean)

  # Fit a heteroskedastic BCF model; the fit itself is expected to emit a
  # warning, so it is wrapped in expect_warning()
  var_params <- list(num_trees = 20)
  expect_warning(
    het_bcf_model <- bcf(
      X_train = X_train,
      Z_train = Z_train,
      y_train = y_train,
      propensity_train = pi_train,
      X_test = X_test,
      Z_test = Z_test,
      propensity_test = pi_test,
      num_gfr = 10,
      num_burnin = 0,
      num_mcmc = num_mcmc,
      variance_forest_params = var_params
    )
  )

  # Default predict output: both the mean and the variance forest predictions
  # hold one column per retained draw
  pred <- predict(
    het_bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test
  )
  y_hat_posterior_test <- pred$y_hat
  sigma2_hat_posterior_test <- pred$variance_forest_predictions
  expect_equal(dim(y_hat_posterior_test), expected_dim)
  expect_equal(dim(sigma2_hat_posterior_test), expected_dim)

  # Check that the pre-aggregated predictions match those computed by rowMeans
  pred_mean <- predict(
    het_bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test,
    type = "mean"
  )
  y_hat_mean_test <- pred_mean$y_hat
  sigma2_hat_mean_test <- pred_mean$variance_forest_predictions
  expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))
  expect_equal(sigma2_hat_mean_test, rowMeans(sigma2_hat_posterior_test))

  # Check that the "single-term" pre-aggregated predictions
  # match those computed by pre-aggregated predictions returned in a list
  y_hat_mean_test_single_term <- predict(
    het_bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test,
    type = "mean",
    terms = "y_hat"
  )
  sigma2_hat_mean_test_single_term <- predict(
    het_bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test,
    type = "mean",
    terms = "variance_forest"
  )
  expect_equal(y_hat_mean_test, y_hat_mean_test_single_term)
  expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term)
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.