# tests/testthat/test-glmtree.R

# attach partykit, which provides the glmtree() reference implementation that
# the learner's predictions are compared against
library(partykit)
# NOTE(review): context() is deprecated in testthat's 3rd edition -- confirm
# which edition this package uses before removing it
context("test-Lrnr_glmtree.R -- General testing for GLMtree")

# define test dataset: mtcars is the data every sl3_Task in this file is
# built from
data(mtcars)

test_that("Lrnr_glmtree continuous outcome preds match those from glmtree", {
  # mpg is the outcome; the ten columns listed here are the covariates
  predictors <- c(
    "cyl", "disp", "hp", "drat", "wt", "qsec",
    "vs", "am", "gear", "carb"
  )
  mpg_task <- sl3_Task$new(mtcars, covariates = predictors, outcome = "mpg")

  # sl3 route: learner with default settings, trained on the task; predict()
  # is called without supplying a new task
  sl3_fit <- Lrnr_glmtree$new()$train(mpg_task)
  sl3_preds <- sl3_fit$predict()

  # reference route: partykit::glmtree fit directly on the task's data
  ref_fit <- partykit::glmtree(mpg ~ ., data = mpg_task$data)
  ref_preds <- predict(ref_fit, newdata = mpg_task$data)

  # the learner and partykit::glmtree should produce the same predictions
  expect_equal(sl3_preds, as.numeric(ref_preds))
})

test_that("Lrnr_glmtree binary outcome preds match those from glmtree", {
  # vs is the outcome here, so mpg takes its place among the covariates
  predictors <- c(
    "cyl", "disp", "hp", "drat", "wt", "qsec",
    "mpg", "am", "gear", "carb"
  )
  vs_task <- sl3_Task$new(mtcars, covariates = predictors, outcome = "vs")

  # sl3 route: learner with default settings, trained on the task; predict()
  # is called without supplying a new task
  sl3_fit <- Lrnr_glmtree$new()$train(vs_task)
  sl3_preds <- sl3_fit$predict()

  # reference route: partykit::glmtree fit directly on the task's data
  # NOTE(review): no family argument is passed to the reference fit -- confirm
  # this matches what Lrnr_glmtree does for a binary outcome
  ref_fit <- partykit::glmtree(vs ~ ., data = vs_task$data)
  ref_preds <- predict(ref_fit, newdata = vs_task$data)

  # the learner and partykit::glmtree should produce the same predictions
  expect_equal(sl3_preds, as.numeric(ref_preds))
})

test_that("Lrnr_glmtree includes offset correctly", {
  # drat is handed to the task as an offset, not as a covariate
  offset_task <- sl3_Task$new(
    mtcars,
    covariates = c("disp", "hp", "wt"),
    outcome = "mpg",
    offset = "drat"
  )

  # sl3 route: learner with explicit alpha and prune settings, trained on the
  # task; predict() is called without supplying a new task
  tuned_lrnr <- Lrnr_glmtree$new(alpha = 0.9, prune = "AIC")
  sl3_preds <- tuned_lrnr$train(offset_task)$predict()

  # reference route: partykit::glmtree with the same alpha and prune values,
  # the offset written into the model part of the formula and the three
  # covariates placed after the `|`
  task_data <- offset_task$data
  ref_fit <- partykit::glmtree(
    formula = as.formula(mpg ~ offset(drat) | disp + hp + wt),
    data = task_data,
    alpha = 0.9,
    prune = "AIC"
  )
  ref_preds <- predict(ref_fit, newdata = task_data)

  # the learner and partykit::glmtree should produce the same predictions
  expect_equal(sl3_preds, as.numeric(ref_preds))
})

test_that("Lrnr_glmtree errors when covariates in formula are misspecified", {
  offset_task <- sl3_Task$new(
    mtcars,
    covariates = c("disp", "hp", "wt"),
    outcome = "mpg",
    offset = "drat"
  )

  # the user-supplied formula uses drat as a regressor, but the task declares
  # drat as its offset and does not list it among the covariates
  bad_formula_lrnr <- Lrnr_glmtree$new(formula = "mpg ~ drat + disp")

  # training with this formula is expected to fail
  expect_error(bad_formula_lrnr$train(offset_task))
})

test_that("Lrnr_glmtree errors when outcome in formula is misspecified", {
  offset_task <- sl3_Task$new(
    mtcars,
    covariates = c("disp", "hp", "wt"),
    outcome = "mpg",
    offset = "drat"
  )

  # the user-supplied formula has hp on the left-hand side, whereas the task's
  # outcome is mpg
  bad_formula_lrnr <- Lrnr_glmtree$new(formula = "hp ~ wt + disp")

  # training with this formula is expected to fail
  expect_error(bad_formula_lrnr$train(offset_task))
})

test_that("Lrnr_glmtree errors when offset not in user-specified formula", {
  offset_task <- sl3_Task$new(
    mtcars,
    covariates = c("disp", "hp", "wt"),
    outcome = "mpg",
    offset = "drat"
  )

  # the task declares drat as an offset, but the user-supplied formula has no
  # offset term at all
  bad_formula_lrnr <- Lrnr_glmtree$new(formula = "mpg ~ wt + disp + hp")

  # training with this formula is expected to fail
  expect_error(bad_formula_lrnr$train(offset_task))
})
# tlverse/sl3 documentation built on Nov. 18, 2024, 12:46 a.m.