# tests/testthat/test-mda.R

test_that("mda + predict() works", {
  skip_on_cran()
  skip_if_not_installed("mda")
  suppressPackageStartupMessages(library(mda))

  # Reference model plus the values the axed copies are compared against.
  mda_model <- mda(Species ~ ., data = iris)
  placeholder_call <- rlang::expr(dummy_call())
  disabled_fns <- c("print()", "summary()", "update()")
  new_obs <- iris[1:3, ]

  # axe_call(): the stored call is expected to be `dummy_call()`, update()
  # is expected to error, and the disabled functions are recorded.
  no_call <- axe_call(mda_model)
  expect_equal(no_call$call, placeholder_call)
  expect_error(update(no_call, subclasses = 4))
  expect_equal(attr(no_call, "butcher_disabled"), disabled_fns)

  # axe_data(): expected to return an object identical to its input.
  no_data <- axe_data(mda_model)
  expect_identical(no_data, mda_model)

  # axe_env(): the terms environment is expected to be the base environment.
  no_env <- axe_env(mda_model)
  expect_identical(attr(no_env$terms, ".Environment"), rlang::base_env())

  # axe_fitted(): fitted values are expected to be a 1 x 1 NA matrix.
  no_fitted <- axe_fitted(mda_model)
  expect_equal(no_fitted$fit$fitted.values, matrix(NA))

  # butcher(): predictions must match those of the original model.
  butchered <- butcher(mda_model)
  expect_equal(
    predict(butchered, new_obs),
    predict(mda_model, new_obs)
  )
})

test_that("mda + custom parsnip model + predict() works", {
  skip_on_cran()
  skip_if_not_installed("mda")
  skip_if_not_installed("parsnip")
  suppressPackageStartupMessages(library(mda))
  suppressPackageStartupMessages(library(parsnip))

  # Register a custom parsnip model ("mixture_da") backed by the mda engine.
  set_new_model("mixture_da")
  set_model_mode(model = "mixture_da", mode = "classification")
  set_model_engine(model = "mixture_da", mode = "classification", eng = "mda")

  # Expose the engine argument `subclasses` under the name `sub_classes`.
  set_model_arg(
    model = "mixture_da",
    eng = "mda",
    parsnip = "sub_classes",
    original = "subclasses",
    func = list(pkg = "foo", fun = "bar"),
    has_submodel = FALSE
  )

  # Fit through the formula interface of mda::mda().
  fit_spec <- list(
    interface = "formula",
    protect = c("formula", "data"),
    func = c(pkg = "mda", fun = "mda"),
    defaults = list()
  )
  set_fit(
    model = "mixture_da",
    eng = "mda",
    mode = "classification",
    value = fit_spec
  )

  encoding_opts <- list(
    predictor_indicators = "traditional",
    compute_intercept = TRUE,
    remove_intercept = TRUE,
    allow_sparse_x = FALSE
  )
  set_encoding(
    model = "mixture_da",
    eng = "mda",
    mode = "classification",
    options = encoding_opts
  )

  # User-facing constructor for the model specification.
  mixture_da <- function(mode = "classification",  sub_classes = NULL) {
    if (mode != "classification") {
      stop("`mode` should be 'classification'", call. = FALSE)
    }
    spec_args <- list(sub_classes = rlang::enquo(sub_classes))
    new_model_spec(
      "mixture_da",
      args = spec_args,
      eng_args = NULL,
      mode = mode,
      method = NULL,
      engine = NULL
    )
  }

  # Build the specification, attach the engine, and fit on iris.
  spec <- mixture_da(sub_classes = 2)
  spec <- set_engine(spec, "mda")
  mda_fit <- fit(spec, Species ~ ., data = iris)

  # axe_call(): checked on the engine fit stored in `$fit`.
  no_call <- axe_call(mda_fit)
  expect_equal(no_call$fit$call, rlang::expr(dummy_call()))
  expect_error(update(no_call$fit, subclasses = 4))
  expect_equal(
    attr(no_call$fit, "butcher_disabled"),
    c("print()", "summary()", "update()")
  )

  # axe_env(): the terms environment is expected to be the base environment.
  no_env <- axe_env(mda_fit)
  expect_identical(
    attr(no_env$fit$terms, ".Environment"),
    rlang::base_env()
  )

  # axe_fitted(): fitted values are expected to be a 1 x 1 NA matrix.
  no_fitted <- axe_fitted(mda_fit)
  expect_equal(no_fitted$fit$fit$fitted.values, matrix(NA))
})

test_that("fda + predict() works", {
  skip_on_cran()
  skip_if_not_installed("mda")
  suppressPackageStartupMessages(library(mda))

  # Work on a test-local copy of mtcars with `cyl` converted to a factor.
  mtcars[["cyl"]] <- as.factor(mtcars[["cyl"]])
  fda_model <- fda(cyl ~ ., data = mtcars)
  placeholder_call <- rlang::expr(dummy_call())
  disabled_fns <- c("print()", "summary()", "update()")
  new_obs <- mtcars[1:3, ]

  # axe_call(): the stored call is expected to be `dummy_call()`, update()
  # is expected to error, and the disabled functions are recorded.
  no_call <- axe_call(fda_model)
  expect_equal(no_call$call, placeholder_call)
  expect_error(update(no_call, subclasses = 4))
  expect_equal(attr(no_call, "butcher_disabled"), disabled_fns)

  # axe_data(): expected to return an object identical to its input.
  no_data <- axe_data(fda_model)
  expect_identical(no_data, fda_model)

  # axe_env(): the terms environment is expected to be the base environment.
  no_env <- axe_env(fda_model)
  expect_identical(attr(no_env$terms, ".Environment"), rlang::base_env())

  # axe_fitted(): fitted values are expected to be a 1 x 1 NA matrix.
  no_fitted <- axe_fitted(fda_model)
  expect_equal(no_fitted$fit$fitted.values, matrix(NA))

  # butcher(): predictions must match those of the original model.
  butchered <- butcher(fda_model)
  expect_equal(
    predict(butchered, new_obs),
    predict(fda_model, new_obs)
  )
})
# tidymodels/butcher documentation built on April 15, 2024, 9:18 p.m.