# File: revdep/checks.noindex/tabnet/old/tabnet.Rcheck/tests/testthat/test-explain.R

test_that("explain provides correct result with data.frame", {

  # Seed both the R and the torch RNG so data generation and training are
  # reproducible; the importance assertions below rely on it.
  set.seed(2022)
  torch::torch_manual_seed(2022)

  n_obs <- 2000
  predictors <- data.frame(
    x = rnorm(n_obs),
    y = rnorm(n_obs),
    z = rnorm(n_obs)
  )

  # The outcome is an exact copy of column `x`, so `x` is expected to receive
  # the largest importance.
  outcome <- predictors$x

  fitted <- tabnet_fit(
    predictors, outcome,
    epochs = 15, num_steps = 1, batch_size = 512,
    attention_width = 1, num_shared = 1, num_independent = 1
  )

  importances <- fitted$fit$importances
  expect_equal(which.max(importances$importance), 1)
  expect_equal(importances$variables, colnames(predictors))

  explanation <- tabnet_explain(fitted, predictors)

  # The explanation has two components; the second holds a single element
  # here (presumably one per decision step, as num_steps = 1 -- TODO confirm).
  # Each checked piece has one row per input observation.
  expect_length(explanation, 2)
  expect_length(explanation[[2]], 1)
  expect_equal(nrow(explanation[[1]]), nrow(predictors))
  expect_equal(nrow(explanation[[2]][[1]]), nrow(predictors))

})

test_that("explain works for dataframe, formula and recipe", {

  # data.frame interface, regression.
  # NOTE(review): the pre-trained/fitted objects used in this section are not
  # defined in this file; presumably they come from a testthat helper/setup.
  expect_no_error(
    tabnet_explain(ames_pretrain_vsplit, new_data = small_ames)
  )
  expect_no_error(
    tabnet_explain(ames_fit_vsplit, new_data = small_ames)
  )

  # data.frame interface, classification
  expect_no_error(
    tabnet_explain(attr_pretrained_vsplit, attrix)
  )
  expect_no_error(
    tabnet_explain(attr_fitted_vsplit, attrix)
  )

  # Locals are named `pretrained` / `fitted` (not `tabnet_pretrain` /
  # `tabnet_fit`) so they never shadow the package functions being tested.

  # formula interface
  pretrained <- tabnet_pretrain(
    Sale_Price ~ ., data = small_ames, epochs = 3, valid_split = 0.2,
    num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1
  )
  expect_no_error(
    tabnet_explain(pretrained, new_data = small_ames)
  )

  # Fine-tune from the pretrained model, then explain the fitted model.
  fitted <- tabnet_fit(
    Sale_Price ~ ., data = small_ames, tabnet_model = pretrained, epochs = 3,
    num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1
  )
  expect_no_error(
    tabnet_explain(fitted, new_data = small_ames)
  )

  # recipe interface
  rec <- recipe(Sale_Price ~ ., data = small_ames) %>%
    step_zv(all_predictors()) %>%
    step_normalize(all_numeric_predictors())

  pretrained <- tabnet_pretrain(
    rec, data = small_ames, epochs = 3, valid_split = 0.2,
    num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1
  )
  expect_no_error(
    tabnet_explain(pretrained, new_data = small_ames)
  )

  fitted <- tabnet_fit(
    rec, data = small_ames, tabnet_model = pretrained, epochs = 3,
    num_steps = 1, attention_width = 1, num_shared = 1, num_independent = 1
  )
  expect_no_error(
    tabnet_explain(fitted, new_data = small_ames)
  )
})

test_that("support for vip on tabnet_fit and tabnet_pretrain", {

  # vip is an optional dependency; only run when it is available.
  skip_if_not_installed("vip")

  n_rows <- 1000
  toy <- data.frame(
    x = runif(n_rows),
    y = runif(n_rows),
    z = runif(n_rows)
  )
  target <- toy$x

  # Tiny single-epoch models: this test only checks that vip can consume both
  # the pretrained and the fitted object, not model quality.
  pretrained <- tabnet_pretrain(
    toy, target,
    epochs = 1, num_steps = 1, batch_size = 512,
    attention_width = 1, num_shared = 1, num_independent = 1
  )
  fitted <- tabnet_fit(
    toy, target,
    epochs = 1, num_steps = 1, batch_size = 512,
    attention_width = 1, num_shared = 1, num_independent = 1
  )

  expect_no_error(vip::vip(pretrained))
  expect_no_error(vip::vip(fitted))

})


test_that("Importance is skipped if skip_importance flag is used", {

  # Seed both RNGs: the second fit's importance ranking is asserted exactly.
  set.seed(2022)
  torch::torch_manual_seed(2022)

  n <- 1000
  x <- data.frame(
    x = rnorm(n),
    y = rnorm(n),
    z = rnorm(n)
  )
  # The outcome is a copy of column `x`, so `x` should rank first.
  y <- x$x

  # Both fits share every hyperparameter and differ only in `skip_importance`.
  fit_with <- function(skip_importance) {
    tabnet_fit(x, y, epochs = 15,
               num_steps = 1,
               batch_size = 512,
               attention_width = 1,
               num_shared = 1,
               num_independent = 1,
               skip_importance = skip_importance)
  }

  # With the flag set, the importances slot is left NULL.
  fit <- fit_with(skip_importance = TRUE)
  expect_null(fit$fit$importances)

  # Without it, importances are present, named after the predictors, and
  # `x` has the largest value.
  fit <- fit_with(skip_importance = FALSE)
  expect_equal(which.max(fit$fit$importances$importance), 1)
  expect_equal(fit$fit$importances$variables, colnames(x))

})

test_that("explain works for parsnip model", {

  spec <- tabnet() %>%
    parsnip::set_mode("regression") %>%
    parsnip::set_engine("torch")
  fitted <- spec %>%
    parsnip::fit(Sale_Price ~ ., data = small_ames)

  # tabnet_explain() must accept the object returned by parsnip::fit(), not
  # only a native tabnet fit. (No trailing comma: an empty extra argument
  # would be forwarded into expect_no_error()'s `...`.)
  expect_no_error(
    tabnet_explain(fitted, new_data = small_ames)
  )

})

test_that("explain works for multi-outcome classification model", {

  # NOTE(review): `x` and `y` are not defined in this file; presumably they
  # come from a testthat setup/helper file -- confirm.
  # NOTE(review): `z = y + 1` only makes sense for a numeric `y`, so this looks
  # like a multi-outcome regression despite the test name -- verify.
  outcomes <- data.frame(y = y, z = y + 1)
  fitted <- tabnet_fit(x, outcomes, epochs = 1)

  expect_no_error(tabnet_explain(fitted, new_data = x))

})
# AFIT-R/vip documentation built on Aug. 22, 2023, 8:59 a.m.