context("test grf: Generalized Random Forests")

# Developer scratchpad. The guard below is a literal FALSE, so none of these
# statements execute when the file is run; they are kept only so they can be
# evaluated by hand during local development.
if (FALSE) {
  setwd("..")
  getwd()
  library(devtools)
  document()
  # load all R files in /R and datasets in /data. Ignores NAMESPACE:
  load_all("./")
  setwd("..")
  # Install with devtools, skipping vignettes and dependencies:
  install("sl3", build_vignettes = FALSE, dependencies = FALSE)
}
# Setup ----
library(grf)
set.seed(791)

# Simulated data ----
# n observations of p standard-normal covariates.
n <- 50
p <- 10
X <- matrix(rnorm(n * p), nrow = n, ncol = p)

# Outcome: the first covariate multiplied by fresh standard-normal noise.
Y <- X[, 1] * rnorm(n)

# Grid of 101 points: first column sweeps from -2 to 2, all others stay at 0.
X.test <- matrix(0, nrow = 101, ncol = p)
X.test[, 1] <- seq(-2, 2, length.out = 101)

# Outcome goes in the first column, followed by the covariate columns.
data <- data.frame(list(Y = Y, X = X))

# sl3 task ----
# First column name is the outcome; every remaining column is a covariate.
outcome <- names(data)[1]
covars <- names(data)[-1]
task <- sl3_Task$new(data, covariates = covars, outcome = outcome)
test_that("Lrnr_grf predictions match those of grf::quantile_forest", {
  # A single seed is shared by the sl3 learner and the direct grf call.
  rng_seed <- 496L

  # sl3 wrapper: train on the task, then predict on that same task.
  set.seed(rng_seed)
  wrapper <- Lrnr_grf$new(seed = rng_seed)
  wrapper_fit <- wrapper$train(task)
  wrapper_pred <- wrapper_fit$predict(task)

  # Direct call into grf on the raw X / Y, re-seeded so both paths start
  # from the same RNG state.
  set.seed(rng_seed)
  reference_forest <- grf::quantile_forest(
    X = X,
    Y = Y,
    seed = rng_seed,
    num.threads = 1L
  )

  # Request the quantile level(s) stored in the fitted learner's params.
  reference_out <- predict(
    reference_forest,
    quantiles = wrapper_fit$params$quantiles_pred
  )
  reference_pred <- as.numeric(reference_out$predictions)

  # Both routes should produce the same numeric predictions.
  expect_equal(wrapper_pred, reference_pred)
})
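# NOTE: the two lines below are web-page boilerplate captured along with this
# file; they are not R code and are kept only as comments.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.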