# tests/testthat/test-dbarts.R

# Label this file's tests with the file name and the learner under test.
context("test-dbarts.R -- Lrnr_dbarts")

# dbarts supplies dbarts::bart(), the reference implementation that the
# learner is compared against further down.
library(dbarts)

# NOTE(review): no random numbers are drawn between this call and the
# set.seed(11) that precedes the fixture generation below, so this seed
# appears to have no effect -- confirm before relying on it.
set.seed(99)
# Simulate a regression data set of the Friedman form.
#
# Ten uniform(0, 1) covariates are drawn per observation; only the first five
# enter the mean function, the remaining five are pure noise columns.
#
# n:     number of observations to draw.
# sigma: standard deviation of the Gaussian noise added to the mean.
#
# Returns a list with `x` (an n-by-10 numeric matrix) and `y` (a numeric
# vector of length n).
generateFriedmanData <- function(n = 100, sigma = 1.0) {
  n_covs <- 10

  # Conditional mean of the outcome given the covariate matrix.
  mean_fn <- function(covs) {
    nonlinear_part <- 10 * sin(pi * covs[, 1] * covs[, 2]) +
      20 * (covs[, 3] - 0.5)^2
    nonlinear_part + 10 * covs[, 4] + 5 * covs[, 5]
  }

  # Covariates are drawn before the noise, so the RNG stream is consumed in
  # the order: all covariates first, then the outcomes.
  covariates <- matrix(runif(n * n_covs), nrow = n, ncol = n_covs)
  outcome <- rnorm(n, mean = mean_fn(covariates), sd = sigma)

  list(x = covariates, y = outcome)
}

# Build the shared regression fixture: a reproducible Friedman data set
# wrapped in an sl3 task. The objects created here are file-level globals
# that the regression tests below read (in particular `task`).
set.seed(11)
testData <- generateFriedmanData()

# Covariates as a data frame with columns named Covs_1, Covs_2, ...
x <- data.frame(testData$x)
names(x) <- paste0("Covs_", seq_len(ncol(x)))

# Outcome as a single-column data frame so its column name can be reused.
y <- data.frame(y = testData$y)

covars <- names(x)
outcome <- names(y)

# Outcome column first, followed by the covariate columns.
data <- cbind.data.frame(y, x)
task <- sl3_Task$new(data, covariates = covars, outcome = outcome)

# Compare the in-sample fit of the sl3 wrapper with a direct call to
# dbarts::bart() on the same data.
test_that("Lrnr_dbarts produces results matching those of dbarts::bart", {
  # Fit the sl3 wrapper and compute its in-sample RMSE. The seed is reset
  # before each of the two fits so both start from the same RNG state.
  set.seed(123)
  lrnr_dbarts <- make_learner(Lrnr_dbarts, verbose = FALSE)
  fit <- lrnr_dbarts$train(task)
  preds <- fit$predict(task)
  rmse_sl3 <- sqrt(mean((preds - task$Y)^2))

  # Fit the reference implementation directly on the task's data.
  set.seed(123)
  fit_classic <- dbarts::bart(
    x.train = data.frame(task$X), y.train = task$Y, keeptrees = TRUE,
    ndpost = 500, verbose = FALSE
  )

  # Collapse the prediction matrix to one value per observation.
  # NOTE(review): presumably predict() returns one row per posterior draw and
  # one column per observation -- confirm against the dbarts documentation.
  preds_classic <- colMeans(predict(fit_classic, newdata = task$X))
  rmse_classic <- sqrt(mean((preds_classic - task$Y)^2))

  # The comparison is on RMSE, not on individual predictions, and only up to
  # the stated tolerance; the two fits are not required to be identical.
  expect_equal(rmse_sl3, rmse_classic, tolerance = 0.1)
})


# Smoke test on the regression task: the learner trains, predicts, and the
# average prediction stays below a loose upper bound.
test_that("Lrnr_dbarts with continuous outcome works", {
  dbart_learner <- Lrnr_dbarts$new(ndpost = 200)
  dbart_fit <- dbart_learner$train(task)
  mean_pred_sl3 <- mean(dbart_fit$predict(task))

  # expect_lt() reports the actual value on failure, unlike expect_true().
  expect_lt(mean_pred_sl3, 20)
})

## Classification

# Simulate a binary-outcome data set from a probit model.
#
# The seed is fixed inside the function, so every call returns the same data
# and leaves the global RNG in the same state.
#
# Returns a list with `X` (an 800-by-3 matrix of standard normal covariates),
# `Z` (the 0/1 outcomes) and `p` (the success probabilities as an 800-by-1
# matrix).
generateProbitData <- function() {
  set.seed(0)

  n_obs <- 800
  n_covs <- 3
  coefs <- c(0.12, -0.89, 0.3)

  design <- matrix(rnorm(n_covs * n_obs), nrow = n_obs, ncol = n_covs)

  # Probit link: success probability is the normal CDF of the linear
  # predictor.
  success_prob <- pnorm(design %*% coefs)
  draws <- rbinom(n_obs, size = 1, prob = success_prob)

  list(X = design, Z = draws, p = success_prob)
}

# Build the shared classification fixture: the probit data set wrapped in an
# sl3 task. This overwrites the file-level globals of the same names, so the
# tests that follow see the binary-outcome `task`.
testData <- generateProbitData()

# Covariates as a data frame with columns named Covs_1, Covs_2, ...
x <- data.frame(testData$X)
names(x) <- paste0("Covs_", seq_len(ncol(x)))

# Binary outcome as a single-column data frame.
y <- data.frame(y = testData$Z)

covars <- names(x)
outcome <- names(y)

# Outcome column first, followed by the covariate columns.
data <- cbind.data.frame(y, x)
task <- sl3_Task$new(data, covariates = covars, outcome = outcome)

# Smoke test on the classification task: the learner trains, predicts, and
# the average prediction stays below a loose upper bound.
test_that("Lrnr_dbarts with binary outcome works", {
  dbart_learner <- Lrnr_dbarts$new()
  dbart_fit <- dbart_learner$train(task)
  mean_pred_sl3 <- mean(dbart_fit$predict(task))

  # NOTE(review): the bound of 20 matches the continuous-outcome test; if
  # predictions here are probabilities it can never fail -- confirm the
  # prediction scale and consider tightening.
  expect_lt(mean_pred_sl3, 20)
})


# test_that("Lrnr_dbarts can be serialized with appropriate param", {
#   dbart_learner_serializable <- Lrnr_dbarts$new(serializeable=TRUE)
#   dbarts_s_fit <- dbart_learner_serializable$train(task)
#
#   preds <- dbarts_s_fit$predict()
#   tmp <- tempfile()
#   save(dbarts_s_fit, file=tmp)
#   rm(dbarts_s_fit)
#   load(tmp)
#   preds_from_serialized <- dbarts_s_fit$predict()
#   tmp <- tempfile()
#
#   db_fit <- dbarts_s_fit$fit_object
#   z <- db_fit$fit$storeState()
#   invisible(db_fit$fit$state)
#   predict(db_fit, as.matrix(task$X))
#   save(db_fit, file=tmp)
#   rm(db_fit)
#   load(tmp)
#   predict(db_fit, as.matrix(task$X))
#   expect_equal(preds, preds_from_serialized)
# })
# jeremyrcoyle/sl3 documentation built on April 30, 2024, 10:16 p.m.