# Constructor smoke test: wrapping a classification spec should yield a
# `te_parsnip_model` object whose `model_spec` slot equals the input spec.
test_that("parsnip_model works", {
  skip_if_not_installed("parsnip")
  library(parsnip)

  # No engine is set here; only the mode is specified explicitly.
  classification_spec <- set_mode(parsnip::logistic_reg(), "classification")

  model_object <- parsnip_model(
    model_spec = classification_spec,
    save_path = tempdir()
  )

  expect_class(model_object, "te_parsnip_model")
  expect_equal(model_object@model_spec, classification_spec)
})
# Validation test: `parsnip_model()` is expected to raise an error whose
# message matches "mode" for non-classification specs, and one matching
# "model_spec" when the argument is not a parsnip specification at all.
test_that("parsnip_model fails for invalid model_specifications", {
  skip_if_not_installed("parsnip")
  library(parsnip)

  out_dir <- tempdir()
  deep_tree <- parsnip::decision_tree(tree_depth = 30)

  # Case 1: mode explicitly set to regression.
  regression_spec <- set_engine(set_mode(deep_tree, "regression"), "rpart")
  expect_error(
    parsnip_model(model_spec = regression_spec, save_path = out_dir),
    "mode"
  )

  # Case 2: engine set, but the mode was never chosen.
  modeless_spec <- set_engine(deep_tree, "rpart")
  expect_error(
    parsnip_model(model_spec = modeless_spec, save_path = out_dir),
    "mode"
  )

  # Case 3: a plain list instead of a parsnip model_spec.
  not_a_spec <- list(1:10)
  expect_error(
    parsnip_model(model_spec = not_a_spec, save_path = out_dir),
    "model_spec"
  )
})
# End-to-end test with an rpart decision tree: fits the weights model on
# `data_censored`, then checks the returned object's class, label, fitted
# values (length and distribution), and that the fitted model was saved to
# disk as a parsnip `model_fit`.
test_that("fit_weights_model works for parsnip models", {
  skip_if_not_installed("parsnip")
  skip_if_not_installed("rpart")
  library(parsnip)
  set.seed(12345)

  # Per-test directory; withr removes it when the test finishes.
  model_dir <- withr::local_tempdir(
    pattern = "model_fitter",
    tmpdir = tempdir(check = TRUE)
  )

  tree_spec <- parsnip::decision_tree(tree_depth = 5)
  tree_spec <- set_mode(tree_spec, "classification")
  tree_spec <- set_engine(tree_spec, "rpart")

  fitter <- parsnip_model(model_spec = tree_spec, save_path = model_dir)
  fit_result <- fit_weights_model(
    fitter,
    data = data_censored,
    formula = treatment ~ age,
    "test_model"
  )

  expect_class(fit_result, "te_weights_fitted")
  expect_equal(fit_result@label, "test_model")
  expect_numeric(fit_result@fitted, len = 725, any.missing = FALSE)

  # Reference six-number summary of the fitted values.
  expected_summary <- c(
    Min. = 0.317343173431734,
    `1st Qu.` = 0.317343173431734,
    Median = 0.475247524752475,
    Mean = 0.467586206896552,
    `3rd Qu.` = 0.561224489795918,
    Max. = 0.662337662337662
  )
  expect_equal(unclass(summary(fit_result@fitted)), expected_summary)

  # The fitted model must have been written to disk and be readable.
  model_file <- fit_result@summary$save_path$path
  expect_file_exists(model_file)
  reloaded_fit <- readRDS(model_file)
  expect_class(reloaded_fit, c("_rpart", "model_fit"))
})
# End-to-end test with a glm logistic regression: fits the weights model on
# `data_censored`, then checks the returned object's class, label, the
# coefficient estimates and null degrees of freedom in the summary, the
# fitted values (length and distribution), and that the fitted model was
# saved to disk as a parsnip `model_fit`.
test_that("fit_weights_model works parsnip logistic regression", {
  skip_if_not_installed("parsnip")
  library(parsnip)
  set.seed(12345)

  # Per-test directory; withr removes it when the test finishes.
  save_dir <- withr::local_tempdir(pattern = "model_fitter", tempdir(TRUE))

  spec <- parsnip::logistic_reg() |>
    set_mode("classification") |>
    set_engine("glm")

  object <- parsnip_model(model_spec = spec, save_path = save_dir)
  result <- fit_weights_model(
    object,
    data = data_censored,
    formula = treatment ~ age,
    "test_model"
  )

  expect_class(result, "te_weights_fitted")
  expect_equal(result@label, "test_model")

  # Model summary: coefficient estimates and null degrees of freedom.
  expect_equal(result@summary[["tidy"]]$estimate, c(1.88674470, -0.04206803))
  expect_equal(result@summary[["glance"]]$df.null, 724)

  expect_numeric(result@fitted, len = 725, any.missing = FALSE)
  expect_equal(
    summary(result@fitted) |> unclass(),
    c(
      Min. = 0.198680456313765, `1st Qu.` = 0.384837510987622, Median = 0.456463281021062,
      Mean = 0.467586206896568, `3rd Qu.` = 0.55082963205744, Max. = 0.747901620375074
    )
  )

  # Assert the file exists before reading it, so a missing file is reported
  # as a failed expectation rather than a raw readRDS() error. This mirrors
  # the rpart test for the same function.
  expect_file_exists(result@summary$save_path$path)
  saved_model <- readRDS(result@summary$save_path$path)
  expect_class(saved_model, c("_glm", "model_fit"))
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.