context("test-dbarts.R -- Lrnr_dbarts")
library(dbarts)
set.seed(99)
# Simulate a regression data set with a nonlinear mean surface.
#
# Draws an n x 10 matrix of uniform(0, 1) covariates, then a normal outcome
# whose mean depends only on the first five columns and whose standard
# deviation is `sigma`. Columns 6-10 do not enter the mean.
#
# @param n Number of observations (rows) to simulate.
# @param sigma Standard deviation of the outcome noise.
# @return A list with `x` (the n x 10 covariate matrix) and `y` (the
#   length-n outcome vector).
generateFriedmanData <- function(n = 100, sigma = 1.0) {
  n_covs <- 10
  # covariates are drawn first, so they consume the RNG stream before the noise
  covs <- matrix(runif(n * n_covs), nrow = n, ncol = n_covs)

  # the four additive pieces of the mean, summed left to right
  sine_term <- 10 * sin(pi * covs[, 1] * covs[, 2])
  square_term <- 20 * (covs[, 3] - 0.5)^2
  fourth_term <- 10 * covs[, 4]
  fifth_term <- 5 * covs[, 5]
  signal <- sine_term + square_term + fourth_term + fifth_term

  outcome_draws <- rnorm(n, mean = signal, sd = sigma)
  list(x = covs, y = outcome_draws)
}
# Regression fixture: simulate data with generateFriedmanData() and wrap it in
# an sl3 task that the tests following this block train on.
set.seed(11)
testData <- generateFriedmanData()
x <- data.frame(testData$x)
# seq_len(ncol(x)) indexes the columns directly; the previous
# seq_along(1:ncol(x)) was redundant and yields two indices for a
# zero-column frame
names(x) <- paste("Covs", seq_len(ncol(x)), sep = "_")
y <- data.frame(y = testData$y)
covars <- names(x)
outcome <- names(y)
# outcome column first, covariates after
data <- cbind.data.frame(y, x)
task <- sl3_Task$new(data, covariates = covars, outcome = outcome)
test_that("Lrnr_dbarts produces results matching those of dbarts::barts", {
  # root-mean-squared error of a prediction vector against the task outcome
  rmse <- function(predicted) {
    sqrt(mean((predicted - task$Y)^2))
  }

  # fit through the sl3 wrapper; the seed is reset before each fit so both
  # start from the same RNG state
  set.seed(123)
  sl3_learner <- make_learner(Lrnr_dbarts, verbose = FALSE)
  sl3_fit <- sl3_learner$train(task)
  sl3_preds <- sl3_fit$predict(task)
  rmse_sl3 <- rmse(sl3_preds)

  # fit by calling dbarts::bart directly on the same covariates and outcome
  set.seed(123)
  direct_fit <- dbarts::bart(
    x.train = data.frame(task$X),
    y.train = task$Y,
    keeptrees = TRUE,
    ndpost = 500,
    verbose = FALSE
  )
  raw_preds <- predict(direct_fit, newdata = task$X)
  # one mean per column of the predict() output (presumably one column per
  # observation -- confirm against the dbarts documentation)
  direct_preds <- rowMeans(t(raw_preds))
  rmse_classic <- rmse(direct_preds)

  # the two fits are compared on training-set RMSE, not prediction by
  # prediction
  expect_equal(rmse_sl3, rmse_classic, tolerance = 0.1)
})
test_that("Lrnr_dbarts with continuous outcome works", {
  # train on the file-level regression task with ndpost = 200
  dbart_learner <- Lrnr_dbarts$new(ndpost = 200)
  dbart_fit <- dbart_learner$train(task)
  mean_pred_sl3 <- mean(dbart_fit$predict(task))
  # loose sanity bound on the average prediction; expect_lt() reports the
  # observed value on failure, which expect_true(x < 20) does not
  expect_lt(mean_pred_sl3, 20)
})
## Classification
# Simulate a binary-outcome data set from a probit model.
#
# Resets the global RNG seed to 0 on every call, so repeated calls return the
# same data. Draws an 800 x 3 matrix of standard-normal covariates, maps the
# linear predictor through pnorm() to get success probabilities, and draws one
# Bernoulli outcome per row.
#
# @return A list with `X` (800 x 3 covariate matrix), `Z` (0/1 outcome
#   vector), and `p` (800 x 1 matrix of success probabilities).
generateProbitData <- function() {
  n_obs <- 800
  n_covs <- 3
  coefs <- c(0.12, -0.89, 0.3)

  # fixed seed: note this overwrites the caller's RNG state
  set.seed(0)
  design <- matrix(rnorm(n_covs * n_obs), nrow = n_obs, ncol = n_covs)
  success_prob <- pnorm(design %*% coefs)
  response <- rbinom(n_obs, size = 1, prob = success_prob)

  list(X = design, Z = response, p = success_prob)
}
# Classification fixture: simulate binary-outcome data with
# generateProbitData() and rebuild the file-level `task` (and the helper
# objects x, y, covars, outcome, data) around it.
testData <- generateProbitData()
x <- data.frame(testData$X)
# seq_len(ncol(x)) indexes the columns directly; the previous
# seq_along(1:ncol(x)) was redundant and yields two indices for a
# zero-column frame
names(x) <- paste("Covs", seq_len(ncol(x)), sep = "_")
y <- data.frame(y = testData$Z)
covars <- names(x)
outcome <- names(y)
# outcome column first, covariates after
data <- cbind.data.frame(y, x)
task <- sl3_Task$new(data, covariates = covars, outcome = outcome)
test_that("Lrnr_dbarts with binary outcome works", {
  # train with the learner's default arguments on the file-level task
  dbart_learner <- Lrnr_dbarts$new()
  dbart_fit <- dbart_learner$train(task)
  mean_pred_sl3 <- mean(dbart_fit$predict(task))
  # expect_lt() reports the observed value on failure, which
  # expect_true(x < 20) does not.
  # NOTE(review): a bound of 20 looks very loose for a 0/1 outcome; a tighter
  # bound depends on the scale Lrnr_dbarts predicts on -- confirm before
  # changing it.
  expect_lt(mean_pred_sl3, 20)
})
# test_that("Lrnr_dbarts can be serialized with appropriate param", {
# dbart_learner_serializable <- Lrnr_dbarts$new(serializeable=TRUE)
# dbarts_s_fit <- dbart_learner_serializable$train(task)
#
# preds <- dbarts_s_fit$predict()
# tmp <- tempfile()
# save(dbarts_s_fit, file=tmp)
# rm(dbarts_s_fit)
# load(tmp)
# preds_from_serialized <- dbarts_s_fit$predict()
# tmp <- tempfile()
#
# db_fit <- dbarts_s_fit$fit_object
# z <- db_fit$fit$storeState()
# invisible(db_fit$fit$state)
# predict(db_fit, as.matrix(task$X))
# save(db_fit, file=tmp)
# rm(db_fit)
# load(tmp)
# predict(db_fit, as.matrix(task$X))
# expect_equal(preds, preds_from_serialized)
# })
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.