tests/testthat/test-expressions.R

library(parsnip)

skip_if_not_installed("modeldata")
library(modeldata)

skip_if_not_installed("ranger")
library(ranger)

skip_if_not_installed("kernlab")
library(kernlab)

skip_if_not_installed("nnet")
library(nnet)

## -----------------------------------------------------------------------------

# Fixture: Palmer penguins from modeldata. Rows containing any NA are removed
# before column selection, so a row is dropped even if its only NA lies in a
# column that is not kept. Only the four numeric measurements and the species
# label are retained.
data("penguins", package = "modeldata")

penguins <- dplyr::select(
  na.omit(penguins),
  bill_length_mm, bill_depth_mm, flipper_length_mm, body_mass_g, species
)

# Fixture: two-class data (outcome column `Class`) used by the logistic tests.
data("two_class_dat", package = "modeldata")

## -----------------------------------------------------------------------------

test_that("linear regression", {
  # Fitting with set_engine("glmnet") requires the glmnet package; skip rather
  # than error when it is unavailable.
  skip_if_not_installed("glmnet")

  reg_model <-
    linear_reg(penalty = .1) |>
    set_engine("glmnet") |>
    set_mode("regression") |>
    fit(
      body_mass_g ~ bill_length_mm + bill_depth_mm + flipper_length_mm,
      data = penguins
    )

  # Reference values: parsnip's own predictions.
  reg_preds_parsnip <- predict(reg_model, penguins)

  # Extract the prediction equation, evaluate it on the same data, and check
  # it reproduces parsnip's predictions.
  reg_eqns <- prediction_eqn(reg_model)
  reg_preds_eqns <- stack_predict(reg_eqns, data = penguins)

  expect_equal(reg_preds_parsnip, reg_preds_eqns)
})


## -----------------------------------------------------------------------------

test_that("logistic regression", {
  # Fitting with set_engine("glmnet") requires the glmnet package; skip rather
  # than error when it is unavailable.
  skip_if_not_installed("glmnet")

  bin_model <-
    logistic_reg(penalty = .1) |>
    set_engine("glmnet") |>
    set_mode("classification") |>
    fit(Class ~ ., data = two_class_dat)

  # Hard class predictions: equation-based results must match parsnip's.
  # The equation object and its predictions are kept in separate variables
  # so each name refers to exactly one kind of value.
  bin_cls_eqns <- prediction_eqn(bin_model, type = "class")
  bin_cls_parsnip <- predict(bin_model, two_class_dat, type = "class")
  bin_cls_preds_eqns <- stack_predict(bin_cls_eqns, data = two_class_dat)
  expect_equal(bin_cls_parsnip, bin_cls_preds_eqns)

  # Class probabilities: equation-based results must match parsnip's.
  bin_prob_eqns <- prediction_eqn(bin_model, type = "prob")
  bin_prob_parsnip <- predict(bin_model, two_class_dat, type = "prob")
  bin_prob_preds_eqns <- stack_predict(bin_prob_eqns, data = two_class_dat)
  expect_equal(bin_prob_parsnip, bin_prob_preds_eqns)
})


## -----------------------------------------------------------------------------

test_that("multiclass regression", {
  # Fitting with set_engine("glmnet") requires the glmnet package; skip rather
  # than error when it is unavailable.
  skip_if_not_installed("glmnet")

  mltn_model <-
    multinom_reg(penalty = .1) |>
    set_engine("glmnet") |>
    set_mode("classification") |>
    fit(species ~ ., data = penguins)

  # Hard class predictions: equation-based results must match parsnip's.
  # The equation object and its predictions are kept in separate variables
  # so each name refers to exactly one kind of value.
  mltn_cls_eqns <- prediction_eqn(mltn_model, type = "class")
  mltn_cls_parsnip <- predict(mltn_model, penguins, type = "class")
  mltn_cls_preds_eqns <- stack_predict(mltn_cls_eqns, data = penguins)
  expect_equal(mltn_cls_parsnip, mltn_cls_preds_eqns)

  # Class probabilities: equation-based results must match parsnip's.
  mltn_prob_eqns <- prediction_eqn(mltn_model, type = "prob")
  mltn_prob_parsnip <- predict(mltn_model, penguins, type = "prob")
  mltn_prob_preds_eqns <- stack_predict(mltn_prob_eqns, data = penguins)
  expect_equal(mltn_prob_parsnip, mltn_prob_preds_eqns)
})

Try the stacks package in your browser

Any scripts or data that you put into this service are public.

stacks documentation built on June 10, 2025, 9:14 a.m.