# tests/testthat/test-te_parsnip.R

test_that("parsnip_model works", {
  skip_if_not_installed("parsnip")
  library(parsnip)

  # Logistic regression specification put into classification mode.
  classif_spec <- set_mode(parsnip::logistic_reg(), "classification")

  wrapped <- parsnip_model(model_spec = classif_spec, save_path = tempdir())

  # The result must have the expected class, and its model_spec slot must
  # equal the specification that was passed in.
  expect_class(wrapped, "te_parsnip_model")
  expect_equal(wrapped@model_spec, classif_spec)
})

test_that("parsnip_model fails for invalid model_specifications", {
  skip_if_not_installed("parsnip")
  library(parsnip)
  out_dir <- tempdir()

  # Case 1: a spec in regression mode is expected to raise an error whose
  # message matches "mode".
  regression_tree <- set_mode(
    parsnip::decision_tree(tree_depth = 30),
    "regression"
  )
  regression_spec <- set_engine(regression_tree, "rpart")
  expect_error(
    parsnip_model(model_spec = regression_spec, save_path = out_dir),
    "mode"
  )

  # Case 2: a spec on which set_mode() was never called is expected to raise
  # an error whose message matches "mode".
  modeless_spec <- set_engine(
    parsnip::decision_tree(tree_depth = 30),
    "rpart"
  )
  expect_error(
    parsnip_model(model_spec = modeless_spec, save_path = out_dir),
    "mode"
  )

  # Case 3: a plain list instead of a parsnip spec is expected to raise an
  # error whose message matches "model_spec".
  not_a_spec <- list(1:10)
  expect_error(
    parsnip_model(model_spec = not_a_spec, save_path = out_dir),
    "model_spec"
  )
})

test_that("fit_weights_model works for parsnip models", {
  skip_if_not_installed("parsnip")
  skip_if_not_installed("rpart")
  library(parsnip)
  set.seed(12345)
  # Temporary directory for the saved model; withr removes it when the test
  # finishes.
  save_dir <- withr::local_tempdir(
    pattern = "model_fitter",
    tmpdir = tempdir(check = TRUE)
  )

  # Classification tree (rpart engine) limited to depth 5.
  tree_spec <- set_engine(
    set_mode(parsnip::decision_tree(tree_depth = 5), "classification"),
    "rpart"
  )

  fitter <- parsnip_model(model_spec = tree_spec, save_path = save_dir)
  fit <- fit_weights_model(
    fitter,
    data = data_censored,
    formula = treatment ~ age,
    "test_model"
  )

  expect_class(fit, "te_weights_fitted")
  expect_equal(fit@label, "test_model")

  # Fitted values: one per row (725 expected), none missing, and matching the
  # reference five-number summary plus mean.
  expect_numeric(fit@fitted, len = 725, any.missing = FALSE)
  expected_summary <- c(
    Min. = 0.317343173431734,
    `1st Qu.` = 0.317343173431734,
    Median = 0.475247524752475,
    Mean = 0.467586206896552,
    `3rd Qu.` = 0.561224489795918,
    Max. = 0.662337662337662
  )
  expect_equal(unclass(summary(fit@fitted)), expected_summary)

  # The fitted model must have been written to disk and be readable as a
  # parsnip fit for the rpart engine.
  model_file <- fit@summary$save_path$path
  expect_file_exists(model_file)
  reloaded <- readRDS(model_file)
  expect_class(reloaded, c("_rpart", "model_fit"))
})

test_that("fit_weights_model works parsnip logistic regression", {
  skip_if_not_installed("parsnip")
  library(parsnip)
  set.seed(12345)
  # Temporary directory for the saved model; withr removes it when the test
  # finishes.
  save_dir <- withr::local_tempdir(pattern = "model_fitter", tempdir(TRUE))

  # Logistic regression (glm engine) in classification mode.
  spec <- parsnip::logistic_reg() |>
    set_mode("classification") |>
    set_engine("glm")

  object <- parsnip_model(model_spec = spec, save_path = save_dir)
  result <- fit_weights_model(
    object,
    data = data_censored,
    formula = treatment ~ age,
    "test_model"
  )

  expect_class(result, "te_weights_fitted")
  expect_equal(result@label, "test_model")
  # Reference coefficient estimates and null degrees of freedom taken from the
  # "tidy" and "glance" entries of the fit summary.
  expect_equal(result@summary[["tidy"]]$estimate, c(1.88674470, -0.04206803))
  expect_equal(result@summary[["glance"]]$df.null, 724)

  # Fitted values: one per row (725 expected), none missing, and matching the
  # reference summary statistics.
  expect_numeric(result@fitted, len = 725, any.missing = FALSE)
  expect_equal(
    summary(result@fitted) |> unclass(),
    c(
      Min. = 0.198680456313765, `1st Qu.` = 0.384837510987622, Median = 0.456463281021062,
      Mean = 0.467586206896568, `3rd Qu.` = 0.55082963205744, Max. = 0.747901620375074
    )
  )
  # Assert the saved file exists before reading it, so a missing file is
  # reported as an expectation failure rather than a raw readRDS() error.
  expect_file_exists(result@summary$save_path$path)
  saved_model <- readRDS(result@summary$save_path$path)
  expect_class(saved_model, c("_glm", "model_fit"))
})
# CAM-Roche/RandomisedTrialsEmulation documentation built on April 14, 2025, 7:44 a.m.