# Nothing
# Fits bag_mlp() with the "nnet" engine (times = 3) in regression and
# classification mode, then checks the class of every stored member model and
# the shape/column names of the tibbles returned by predict().
test_that("check nnet parsnip interface", {
  skip_if_not_installed("nnet")
  skip_if_not_installed("modeldata")

  data(two_class_dat, package = "modeldata")

  # TRUE when every element of `objs` inherits from class `cls`.
  all_inherit <- function(objs, cls) {
    all(vapply(objs, inherits, logical(1), what = cls))
  }

  # Regression --------------------------------------------------------------

  set.seed(4779)
  # `regexp = NA` asserts that fitting raises no error.
  expect_error(
    reg_fit <- fit(
      set_mode(set_engine(bag_mlp(), "nnet", times = 3), "regression"),
      mpg ~ .,
      data = mtcars
    ),
    regexp = NA
  )

  # Each stored member is a parsnip fit wrapping an nnet object.
  reg_members <- reg_fit$fit$model_df$model
  expect_true(all_inherit(reg_members, "model_fit"))
  expect_true(all_inherit(lapply(reg_members, function(m) m$fit), "nnet"))

  # Predict on the first five rows with the outcome column (mpg) dropped.
  expect_error(
    reg_pred <- predict(reg_fit, mtcars[1:5, -1]),
    regexp = NA
  )
  expect_true(tibble::is_tibble(reg_pred))
  expect_equal(nrow(reg_pred), 5)
  expect_equal(names(reg_pred), ".pred")

  # Classification ----------------------------------------------------------

  set.seed(4779)
  expect_error(
    cls_fit <- fit(
      set_mode(set_engine(bag_mlp(), "nnet", times = 3), "classification"),
      Class ~ .,
      data = two_class_dat
    ),
    regexp = NA
  )

  cls_members <- cls_fit$fit$model_df$model
  expect_true(all_inherit(cls_members, "model_fit"))
  expect_true(all_inherit(lapply(cls_members, function(m) m$fit), "nnet"))

  # Hard class predictions (third column of two_class_dat is dropped).
  expect_error(
    cls_pred <- predict(cls_fit, two_class_dat[1:5, -3]),
    regexp = NA
  )
  expect_true(tibble::is_tibble(cls_pred))
  expect_equal(nrow(cls_pred), 5)
  expect_equal(names(cls_pred), ".pred_class")

  # Class probability predictions: one column per outcome level.
  expect_error(
    cls_prob <- predict(cls_fit, two_class_dat[1:5, -3], type = "prob"),
    regexp = NA
  )
  expect_true(tibble::is_tibble(cls_prob))
  expect_equal(nrow(cls_prob), 5)
  expect_equal(names(cls_prob), c(".pred_Class1", ".pred_Class2"))
})
# Checks that the package list registered for bag_mlp() with the "nnet"
# engine is c("nnet", "baguette") in both classification and regression mode.
test_that("mode specific package dependencies", {
  # Look up the registered `pkg` entries for the nnet engine in a given mode.
  nnet_pkgs_for <- function(model_mode) {
    pkg_tbl <- get_from_env("bag_mlp_pkgs")
    matched <- dplyr::filter(pkg_tbl, engine == "nnet", mode == model_mode)
    dplyr::pull(matched, pkg)
  }

  expect_identical(
    nnet_pkgs_for("classification"),
    list(c("nnet", "baguette"))
  )
  expect_identical(
    nnet_pkgs_for("regression"),
    list(c("nnet", "baguette"))
  )
})
# Compares baguette's internal nnet importance scores (nnet_imp_garson)
# against stored reference values for a single nnet model fit to mtcars.
test_that('variable importance', {
  skip_if_not_installed("nnet")

  # See inst/helper-objects-for-testing.R
  # Values from another implementation
  exp_vip <-
    tibble::tribble(
      ~predictor, ~importance,
      "cyl", 10.4382541964352,
      "disp", 6.40411349529868,
      "hp", 10.5671716879969,
      "drat", 11.816366389055,
      "wt", 9.78795915963821,
      "qsec", 18.3934915487232,
      "vs", 6.28589044459608,
      "am", 5.70032668136104,
      "gear", 12.9721190483202,
      "carb", 7.6343073485756
    )

  # Fixed seed so the nnet fit matches the one used for the reference values.
  set.seed(1)
  reg_mod <- nnet::nnet(mpg ~ ., data = mtcars, size = 3, trace = FALSE)
  baguette_imp <- baguette:::nnet_imp_garson(reg_mod)

  # Computed value first, reference second, matching the
  # expect_equal(object, expected) signature so failure output labels the
  # actual and expected values correctly.
  expect_equal(baguette_imp, exp_vip, tolerance = 0.0001)
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.