# File: tests/testthat/test-nnet.R

# Fits bagged nnet ensembles through parsnip's `bag_mlp()` in both regression
# and classification modes, then checks that every ensemble member is a parsnip
# `model_fit` wrapping an `nnet` object and that `predict()` returns tibbles
# with the expected number of rows and column names.
test_that('check nnet parsnip interface', {
  skip_if_not_installed("nnet")
  skip_if_not_installed("modeldata")

  # `data()` writes to the global environment by default; load into the
  # test's own environment instead so nothing leaks into other tests.
  data(two_class_dat, package = "modeldata", envir = environment())

  # Regression ---------------------------------------------------------------

  set.seed(4779)
  expect_error(
    reg_mod <- bag_mlp() %>%
      set_engine("nnet", times = 3) %>%
      set_mode("regression") %>%
      fit(mpg ~ ., data = mtcars),
    regexp = NA
  )
  expect_true(
    all(purrr::map_lgl(reg_mod$fit$model_df$model, ~ inherits(.x, "model_fit")))
  )
  expect_true(
    all(purrr::map_lgl(reg_mod$fit$model_df$model, ~ inherits(.x$fit, "nnet")))
  )

  # Predict on predictors only; the outcome is dropped by name, not position.
  reg_new_data <- mtcars[1:5, setdiff(names(mtcars), "mpg")]
  expect_error(
    reg_mod_pred <- predict(reg_mod, reg_new_data),
    regexp = NA
  )
  expect_s3_class(reg_mod_pred, "tbl_df")
  expect_equal(nrow(reg_mod_pred), 5)
  expect_named(reg_mod_pred, ".pred")

  # Classification -----------------------------------------------------------

  set.seed(4779)
  expect_error(
    class_mod <- bag_mlp() %>%
      set_engine("nnet", times = 3) %>%
      set_mode("classification") %>%
      fit(Class ~ ., data = two_class_dat),
    regexp = NA
  )
  expect_true(
    all(purrr::map_lgl(class_mod$fit$model_df$model, ~ inherits(.x, "model_fit")))
  )
  expect_true(
    all(purrr::map_lgl(class_mod$fit$model_df$model, ~ inherits(.x$fit, "nnet")))
  )

  # Predict on predictors only; the outcome is dropped by name, not position.
  class_new_data <- two_class_dat[1:5, setdiff(names(two_class_dat), "Class")]

  # Hard class predictions.
  expect_error(
    class_mod_pred <- predict(class_mod, class_new_data),
    regexp = NA
  )
  expect_s3_class(class_mod_pred, "tbl_df")
  expect_equal(nrow(class_mod_pred), 5)
  expect_named(class_mod_pred, ".pred_class")

  # Class probabilities: one column per outcome level.
  expect_error(
    class_mod_prob <- predict(class_mod, class_new_data, type = "prob"),
    regexp = NA
  )
  expect_s3_class(class_mod_prob, "tbl_df")
  expect_equal(nrow(class_mod_prob), 5)
  expect_named(class_mod_prob, c(".pred_Class1", ".pred_Class2"))
})

# The nnet engine registered for `bag_mlp()` must declare the same package
# dependencies ("nnet" then "baguette") in each mode it supports.
test_that('mode specific package dependencies', {
  pkg_tbl <- get_from_env("bag_mlp_pkgs")

  for (target_mode in c("classification", "regression")) {
    # `target_mode` is not a column name, so data masking resolves it from
    # the loop environment.
    nnet_rows <- dplyr::filter(pkg_tbl, engine == "nnet", mode == target_mode)
    expect_identical(
      dplyr::pull(nnet_rows, pkg),
      list(c("nnet", "baguette"))
    )
  }
})

# Checks baguette's internal Garson variable-importance calculation for nnet
# fits against reference values.
test_that('variable importance', {
  skip_if_not_installed("nnet")

  # Reference importances computed with another implementation for the
  # seeded fit below (see inst/helper-objects-for-testing.R).
  exp_vip <-
    tibble::tribble(
      ~predictor,      ~importance,
      "cyl",  10.4382541964352,
      "disp",  6.40411349529868,
      "hp",   10.5671716879969,
      "drat", 11.816366389055,
      "wt",    9.78795915963821,
      "qsec", 18.3934915487232,
      "vs",    6.28589044459608,
      "am",    5.70032668136104,
      "gear", 12.9721190483202,
      "carb",  7.6343073485756
    )

  # nnet draws random starting weights, so the seed pins the fitted network
  # that the reference values above correspond to.
  set.seed(1)
  reg_mod <- nnet::nnet(mpg ~ ., data = mtcars, size = 3, trace = FALSE)
  baguette_imp <- baguette:::nnet_imp_garson(reg_mod)

  # testthat convention: `object` (value under test) first, `expected` second.
  expect_equal(baguette_imp, exp_vip, tolerance = 0.0001)
})

# Try the baguette package in your browser
#
# Any scripts or data that you put into this service are public.
#
# baguette documentation built on April 4, 2025, 12:22 a.m.