# tests/testthat/test_multinom_reg_nnet.R

# Work with the first 150 rows and a handful of columns; the tests below treat
# column 5 of this subset as the `class` outcome and the rest as predictors.
hpc <- hpc_data[seq_len(150), c(2:5, 8)]

# ------------------------------------------------------------------------------

# Reproducibly shuffle the rows, then keep the first 140 for training and hold
# out the final 10 for the prediction comparisons.
set.seed(352)
dat <- hpc[order(runif(nrow(hpc))), ]

tr_dat <- dat[1:140, ]
te_dat <- dat[141:150, ]

# ------------------------------------------------------------------------------

# Shared model specification for every test in this file.
#
# The penalty is supplied as a main argument so parsnip translates it to
# nnet's `decay`. Passing `penalty` through set_engine() instead hands it to
# nnet::multinom() as an unknown engine argument; multinom() forwards it via
# `...` to nnet.default(), which has no `penalty` formal, so it was silently
# ignored and the model was fit unpenalized.
basic_mod <-
  multinom_reg(penalty = 0.1) %>%
  set_engine("nnet")

# verbosity = 0 keeps fitting quiet; catch = FALSE lets any fitting error
# propagate instead of being captured in the returned object.
ctrl <- control_parsnip(verbosity = 0, catch = FALSE)

# ------------------------------------------------------------------------------

test_that('model fitting', {
  skip_if_not_installed("nnet")

  # Predictors are every column except the fifth; the outcome is `class`.
  x_train <- tr_dat[, -5]
  y_train <- tr_dat$class

  # Two xy-interface fits from the same seed must agree.
  set.seed(257)
  expect_error(
    first_fit <- fit_xy(basic_mod, x = x_train, y = y_train, control = ctrl),
    regexp = NA
  )

  set.seed(257)
  expect_error(
    second_fit <- fit_xy(basic_mod, x = x_train, y = y_train, control = ctrl),
    regexp = NA
  )

  # `elapsed` records run time, which legitimately differs between fits, so
  # align it before comparing the rest of the objects.
  first_fit$elapsed <- second_fit$elapsed
  expect_equal(first_fit, second_fit, ignore_formula_env = TRUE)

  # The formula interface should fit without error as well.
  expect_error(
    fit(basic_mod, class ~ ., data = tr_dat, control = ctrl),
    regexp = NA
  )
})


test_that('classification prediction', {
  skip_if_not_installed("nnet")

  new_x <- te_dat[, -5]

  set.seed(257)
  multinom_fit <- fit_xy(
    basic_mod,
    x = tr_dat[, -5],
    y = tr_dat$class,
    control = ctrl
  )

  # Class predictions taken directly from the underlying engine fit are the
  # reference that parsnip's `.pred_class` column must reproduce.
  engine_fit <- extract_fit_engine(multinom_fit)
  expected <- predict(engine_fit, as.matrix(new_x))
  observed <- predict(multinom_fit, new_x)$.pred_class

  expect_equal(expected, observed)
})


test_that('classification probabilities', {
  skip_if_not_installed("nnet")

  new_x <- te_dat[, -5]

  set.seed(257)
  multinom_fit <- fit_xy(
    basic_mod,
    x = tr_dat[, -5],
    y = tr_dat$class,
    control = ctrl
  )

  # Build the reference from the engine's probability output, naming the
  # columns the way parsnip does: `.pred_<level>`.
  engine_fit <- extract_fit_engine(multinom_fit)
  prob_matrix <- predict(engine_fit, as.matrix(new_x), type = "prob")
  expected <- as_tibble(prob_matrix, .name_repair = "minimal")
  names(expected) <- paste0(".pred_", multinom_fit$lvl)

  observed <- predict(multinom_fit, new_x, type = "prob")

  # Compare as plain data frames so only the values and names matter.
  expect_equal(as.data.frame(expected), as.data.frame(observed))
})

test_that('prob prediction with 1 row', {
  # Regression test for issue 612: probability prediction on a single row.
  skip_if_not_installed("nnet")

  one_row <- te_dat[1, -5]

  set.seed(257)
  multinom_fit <- fit_xy(
    basic_mod,
    x = tr_dat[, -5],
    y = tr_dat$class,
    control = ctrl
  )

  # For a single observation the engine's probabilities are not returned as a
  # one-row matrix, so coerce to a matrix and transpose into 1 x K before
  # converting to a tibble with parsnip-style `.pred_<level>` names.
  engine_fit <- extract_fit_engine(multinom_fit)
  raw_probs <- predict(engine_fit, as.matrix(one_row), type = "prob")
  expected <- tibble::as_tibble(
    t(as.matrix(raw_probs)),
    .name_repair = "minimal"
  )
  names(expected) <- paste0(".pred_", multinom_fit$lvl)

  observed <- predict(multinom_fit, one_row, type = "prob")

  expect_equal(expected, observed)
  expect_identical(nrow(observed), 1L)
})
# tidymodels/parsnip documentation built on March 25, 2024, 10:17 p.m.