Nothing
test_that("parsnip fit model works", {
# default params
expect_no_error(
model <- tabnet() %>%
parsnip::set_mode("regression") %>%
parsnip::set_engine("torch")
)
expect_no_error(
fit <- model %>%
parsnip::fit(Sale_Price ~ ., data = small_ames)
)
# some setup params
expect_no_error(
model <- tabnet(epochs = 2, learn_rate = 1e-5) %>%
parsnip::set_mode("regression") %>%
parsnip::set_engine("torch")
)
expect_no_error(
fit <- model %>%
parsnip::fit(Sale_Price ~ ., data = small_ames)
)
# new batch of setup params
expect_no_error(
model <- tabnet(penalty = 0.2, verbose = FALSE, early_stopping_tolerance = 1e-3) %>%
parsnip::set_mode("classification") %>%
parsnip::set_engine("torch")
)
expect_no_error(
fit <- model %>%
parsnip::fit(Overall_Cond ~ ., data = small_ames)
)
})
test_that("parsnip fit model works from a pretrained model", {
# default params
expect_no_error(
model <- tabnet(tabnet_model = ames_pretrain, from_epoch = 1, epoch = 1) %>%
parsnip::set_mode("regression") %>%
parsnip::set_engine("torch")
)
expect_no_error(
fit <- model %>%
parsnip::fit(Sale_Price ~ ., data = small_ames)
)
})
test_that("multi_predict works as expected", {
model <- tabnet(checkpoint_epoch = 1) %>%
parsnip::set_mode("regression") %>%
parsnip::set_engine("torch")
expect_no_error(
fit <- model %>%
parsnip::fit(Sale_Price ~ ., data = small_ames)
)
preds <- parsnip::multi_predict(fit, small_ames, epochs = c(1,2,3,4,5))
expect_equal(nrow(preds), nrow(small_ames))
expect_equal(nrow(preds$.pred[[1]]), 5)
})
test_that("Check we can finalize a workflow", {
model <- tabnet(penalty = tune(), epochs = tune()) %>%
parsnip::set_mode("regression") %>%
parsnip::set_engine("torch")
wf <- workflows::workflow() %>%
workflows::add_model(model) %>%
workflows::add_formula(Sale_Price ~ .)
wf <- tune::finalize_workflow(wf, tibble::tibble(penalty = 0.01, epochs = 1))
expect_no_error(
fit <- wf %>% parsnip::fit(data = small_ames)
)
expect_equal(rlang::eval_tidy(wf$fit$actions$model$spec$args$penalty), 0.01)
expect_equal(rlang::eval_tidy(wf$fit$actions$model$spec$args$epochs), 1)
})
test_that("Check we can finalize a workflow from a tune_grid", {
model <- tabnet(epochs = tune(), checkpoint_epochs = 1) %>%
parsnip::set_mode("regression") %>%
parsnip::set_engine("torch")
wf <- workflows::workflow() %>%
workflows::add_model(model) %>%
workflows::add_formula(Sale_Price ~ .)
custom_grid <- tidyr::crossing(epochs = c(1,2,3))
cv_folds <- small_ames %>%
rsample::vfold_cv(v = 2, repeats = 1)
at <- tune::tune_grid(
object = wf,
resamples = cv_folds,
grid = custom_grid,
metrics = yardstick::metric_set(yardstick::rmse),
control = tune::control_grid(verbose = F)
)
best_rmse <- tune::select_best(at, metric = "rmse")
expect_no_error(
final_wf <- tune::finalize_workflow(wf, best_rmse)
)
})
test_that("tabnet grid reduction - torch", {
mod <- tabnet() %>%
parsnip::set_engine("torch")
# A typical grid
reg_grid <- expand.grid(epochs = 1:3, penalty = 1:2)
reg_grid_smol <- tune::min_grid(mod, reg_grid)
expect_equal(reg_grid_smol$epochs, rep(3, 2))
expect_equal(reg_grid_smol$penalty, 1:2)
for (i in 1:nrow(reg_grid_smol)) {
expect_equal(reg_grid_smol$.submodels[[i]], list(epochs = 1:2))
}
# Unbalanced grid
reg_ish_grid <- expand.grid(epochs = 1:3, penalty = 1:2)[-3, ]
reg_ish_grid_smol <- tune::min_grid(mod, reg_ish_grid)
expect_equal(reg_ish_grid_smol$epochs, 2:3)
expect_equal(reg_ish_grid_smol$penalty, 1:2)
for (i in 2:nrow(reg_ish_grid_smol)) {
expect_equal(reg_ish_grid_smol$.submodels[[i]], list(epochs = 1:2))
}
# Grid with a third parameter
reg_grid_extra <- expand.grid(epochs = 1:3, penalty = 1:2, batch_size = 10:12)
reg_grid_extra_smol <- tune::min_grid(mod, reg_grid_extra)
expect_equal(reg_grid_extra_smol$epochs, rep(3, 6))
expect_equal(reg_grid_extra_smol$penalty, rep(1:2, each = 3))
expect_equal(reg_grid_extra_smol$batch_size, rep(10:12, 2))
for (i in 1:nrow(reg_grid_extra_smol)) {
expect_equal(reg_grid_extra_smol$.submodels[[i]], list(epochs = 1:2))
}
# Only epochs
only_epochs <- expand.grid(epochs = 1:3)
only_epochs_smol <- tune::min_grid(mod, only_epochs)
expect_equal(only_epochs_smol$epochs, 3)
expect_equal(only_epochs_smol$.submodels, list(list(epochs = 1:2)))
# No submodels
no_sub <- tibble::tibble(epochs = 1, penalty = 1:2)
no_sub_smol <- tune::min_grid(mod, no_sub)
expect_equal(no_sub_smol$epochs, rep(1, 2))
expect_equal(no_sub_smol$penalty, 1:2)
for (i in 1:nrow(no_sub_smol)) {
expect_length(no_sub_smol$.submodels[[i]], 0)
}
# different id names
mod_1 <- tabnet(epochs = tune("Amos")) %>%
parsnip::set_engine("torch")
reg_grid <- expand.grid(Amos = 1:3, penalty = 1:2)
reg_grid_smol <- tune::min_grid(mod_1, reg_grid)
expect_equal(reg_grid_smol$Amos, rep(3, 2))
expect_equal(reg_grid_smol$penalty, 1:2)
for (i in 1:nrow(reg_grid_smol)) {
expect_equal(reg_grid_smol$.submodels[[i]], list(Amos = 1:2))
}
all_sub <- expand.grid(Amos = 1:3)
all_sub_smol <- tune::min_grid(mod_1, all_sub)
expect_equal(all_sub_smol$Amos, 3)
expect_equal(all_sub_smol$.submodels[[1]], list(Amos = 1:2))
mod_2 <- tabnet(epochs = tune("Ade Tukunbo")) %>%
parsnip::set_engine("torch")
reg_grid <- expand.grid(`Ade Tukunbo` = 1:3, penalty = 1:2, ` \t123` = 10:11)
reg_grid_smol <- tune::min_grid(mod_2, reg_grid)
expect_equal(reg_grid_smol$`Ade Tukunbo`, rep(3, 4))
expect_equal(reg_grid_smol$penalty, rep(1:2, each = 2))
expect_equal(reg_grid_smol$` \t123`, rep(10:11, 2))
for (i in 1:nrow(reg_grid_smol)) {
expect_equal(reg_grid_smol$.submodels[[i]], list(`Ade Tukunbo` = 1:2))
}
})
test_that("Check workflow can use case_weight", {
small_ames_cw <- small_ames %>% dplyr::mutate(case_weight = hardhat::frequency_weights(Year_Sold))
model <- tabnet(epochs = 2) %>%
parsnip::set_mode("regression") %>%
parsnip::set_engine("torch")
wf <- workflows::workflow() %>%
workflows::add_model(model) %>%
workflows::add_formula(Sale_Price ~ .) %>%
workflows::add_case_weights(case_weight)
expect_no_error(
fit <- wf %>% parsnip::fit(data = small_ames_cw)
)
expect_no_error(
predict(fit, small_ames)
)
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.