tests/testthat/test-caret.R

context("test-caret.R -- General testing for Caret package support")

library(sl3)
library(testthat)
library(caret)

# define test dataset
data(mtcars)
task <- sl3_Task$new(mtcars, covariates = c(
  "cyl", "disp", "hp", "drat", "wt", "qsec",
  "vs", "am", "gear", "carb"
), outcome = "mpg")

mtcars$mpg_binary <- as.numeric(mtcars$mpg < 19)
task_binaryY <- sl3_Task$new(mtcars, covariates = c(
  "cyl", "disp", "hp", "drat", "wt", "qsec",
  "vs", "am", "gear", "carb"
), outcome = "mpg_binary")

mtcars$mpg_categorical <- as.factor(ifelse(
  mtcars$mpg < 17, "low", ifelse(mtcars$mpg < 21.5, "medium", "high")
))
task_catY <- sl3_Task$new(mtcars, covariates = c(
  "cyl", "disp", "hp", "drat", "wt", "qsec",
  "vs", "am", "gear", "carb"
), outcome = "mpg_categorical")

test_learner <- function(learner, task, ...) {
  learner_obj <- learner$new(...)
  task2 <- task$clone()
  print(sprintf("Testing Learner: %s", learner_obj$name))
  # test learner training
  fit_obj <- learner_obj$train(task)
  test_that("Learner can be trained on data", expect_true(fit_obj$is_trained))

  # test learner prediction
  train_preds <- fit_obj$predict()
  test_that("Learner can generate training set predictions", expect_equal(
    sl3:::safe_dim(train_preds)[1],
    length(task$Y)
  ))

  holdout_preds <- fit_obj$predict(task2)
  test_that("Learner can generate holdout set predictions", expect_equal(
    train_preds,
    holdout_preds
  ))

  # test learner chaining
  chained_task <- fit_obj$chain()
  test_that("Chaining returns a task", expect_true(is(chained_task, "sl3_Task")))
  test_that("Chaining returns the correct number of rows", expect_equal(
    nrow(chained_task$X),
    nrow(task$X)
  ))
}

## test caret learner with the example of Random Forest:
test_learner(Lrnr_caret, task, method = "rf")
test_learner(Lrnr_caret, task_binaryY, method = "rf")
test_learner(Lrnr_caret, task_catY, method = "rf")
test_learner(Lrnr_caret, task, method = "xgbLinear")
test_learner(Lrnr_caret, task_binaryY, method = "xgbLinear")
test_learner(Lrnr_caret, task_catY, method = "xgbLinear")

test_that("Lrnr_caret RF match caret RF preds for continuous outcome", {
  ## instantiate Lrnr_caret, train on task, and predict on task
  lrnr_caret_rf <- Lrnr_caret$new(method = "rf", seed = 123)
  set.seed(1530)
  fit_lrnr_caret_rf <- lrnr_caret_rf$train(task)
  prd_lrnr_caret_rf <- fit_lrnr_caret_rf$predict()
  rmse_sl3 <- sqrt(mean((prd_lrnr_caret_rf - task$Y)^2))

  ## fit caret RF using the data from the task
  set.seed(1530)
  fit_caret_rf <- caret::train(
    x = task$X, y = task$Y, method = "rf", metric = "RMSE",
    trControl = caret::trainControl(
      method = "cv",
      indexOut = fit_lrnr_caret_rf$fit_object$control$indexOut
    ),
    seed = 123
  )
  prd_caret_rf <- as.numeric(predict(fit_caret_rf, newdata = task$X))
  rmse_classic <- sqrt(mean((prd_caret_rf - task$Y)^2))

  expect_equal(rmse_sl3, rmse_classic, tolerance = 0.1)
})

test_that("Lrnr_caret RF match caret RF preds for binary classification", {
  ## instantiate Lrnr_caret, train on task, and predict on task
  lrnr_caret_rf <- Lrnr_caret$new(method = "rf")
  set.seed(1530)
  fit_lrnr_caret_rf <- lrnr_caret_rf$train(task_binaryY)
  prd_lrnr_caret_rf <- fit_lrnr_caret_rf$predict()
  prd_lrnr_caret_rf <- as.numeric(prd_lrnr_caret_rf > 0.5)

  ## fit caret RF using the data from the task
  set.seed(1530)
  fit_caret_rf <- suppressWarnings(caret::train(
    x = task_binaryY$X, y = as.factor(task_binaryY$Y), method = "rf",
    metric = "Accuracy",
    trControl = caret::trainControl(
      method = "cv", indexOut = fit_lrnr_caret_rf$fit_object$control$indexOut
    )
  ))
  prd_caret_rf <- as.numeric(
    predict(fit_caret_rf, newdata = task$X, type = "prob")[, 2]
  )
  prd_caret_rf <- as.numeric(prd_caret_rf > 0.5)

  expect_equal(sum(abs(prd_lrnr_caret_rf - prd_caret_rf)), 0, tolerance = 1)
})

test_that("Lrnr_caret RF preds match caret RF preds for categorical outcome", {
  ## instantiate Lrnr_caret, train on task, and predict on task
  lrnr_caret_rf <- Lrnr_caret$new(method = "rf")
  set.seed(1530)
  fit_lrnr_caret_rf <- lrnr_caret_rf$train(task_catY)
  prd_lrnr_caret_rf <- fit_lrnr_caret_rf$predict()
  prd_lrnr_caret_rf <- unpack_predictions(fit_lrnr_caret_rf$predict())
  prd_lrnr_caret_rf <- max.col(prd_lrnr_caret_rf)

  ## fit caret RF using the data from the task
  set.seed(1530)
  fit_caret_rf <- suppressWarnings(caret::train(
    x = task_catY$X, y = task_catY$Y, method = "rf", metric = "Accuracy",
    trControl = caret::trainControl(
      method = "cv",
      indexOut = fit_lrnr_caret_rf$fit_object$control$indexOut
    )
  ))
  prd_caret_rf <- pack_predictions(
    predict(fit_caret_rf, newdata = task$X, type = "prob")
  )
  prd_caret_rf <- unpack_predictions(prd_caret_rf)
  prd_caret_rf <- max.col(prd_caret_rf)

  expect_equal(prd_lrnr_caret_rf, prd_caret_rf, tolerance = 1)
})

test_that("Lrnr_caret RF preds match caret RF preds for binary regression", {
  ## instantiate Lrnr_caret, train on task, and predict on task
  lrnr_caret_rf <- Lrnr_caret$new(
    method = "rf", metric = "RMSE",
    factor_binary_outcome = FALSE
  )
  set.seed(1530)
  fit_lrnr_caret_rf <- lrnr_caret_rf$train(task_binaryY)
  prd_lrnr_caret_rf <- fit_lrnr_caret_rf$predict()
  rmse_sl3 <- sqrt(mean((prd_lrnr_caret_rf - task_binaryY$Y)^2))

  ## fit caret RF using the data from the task
  set.seed(1530)
  fit_caret_rf <- suppressWarnings(caret::train(
    x = task_binaryY$X, y = task_binaryY$Y, method = "rf",
    metric = "RMSE", trControl = caret::trainControl(
      method = "cv",
      indexOut = fit_lrnr_caret_rf$fit_object$control$indexOut
    )
  ))
  prd_caret_rf <- as.numeric(predict(fit_caret_rf, newdata = task$X))
  rmse_classic <- sqrt(mean((prd_caret_rf - task_binaryY$Y)^2))

  expect_equal(rmse_sl3, rmse_classic, tolerance = 0.1)
})
jeremyrcoyle/sl3 documentation built on April 30, 2024, 10:16 p.m.