# tests/testthat/test_expressions.R

library(parsnip)

skip_if_not_installed("modeldata")
library(modeldata)

skip_if_not_installed("ranger")
library(ranger)

skip_if_not_installed("kernlab")
library(kernlab)

skip_if_not_installed("nnet")
library(nnet)

## -----------------------------------------------------------------------------

data("penguins", package = "modeldata")

# Remove incomplete rows first (NA in *any* column, including ones dropped
# below), then keep the three predictors, the numeric outcome used by the
# regression test, and the `species` class outcome.
penguins <- dplyr::select(
  na.omit(penguins),
  bill_length_mm,
  bill_depth_mm,
  flipper_length_mm,
  body_mass_g,
  species
)

data(two_class_dat, package = "modeldata")

## -----------------------------------------------------------------------------

test_that("linear regression", {
  # The model below uses the glmnet engine; skip rather than error when it
  # is not installed, consistent with the other optional dependencies.
  skip_if_not_installed("glmnet")

  reg_model <-
    linear_reg(penalty = .1) %>%
    set_engine("glmnet") %>%
    set_mode("regression") %>%
    fit(
      body_mass_g ~ bill_length_mm + bill_depth_mm + flipper_length_mm,
      data = penguins
    )

  # Reference: predictions from the fitted parsnip model itself.
  reg_preds_parsnip <- predict(reg_model, penguins)

  # Candidate: the output of prediction_eqn() evaluated with stack_predict()
  # on the same data must reproduce the parsnip predictions.
  reg_eqns <- prediction_eqn(reg_model)
  reg_preds_eqns <- stack_predict(reg_eqns, data = penguins)

  expect_equal(reg_preds_parsnip, reg_preds_eqns)
})


## -----------------------------------------------------------------------------

test_that("logistic regression", {
  # The model below uses the glmnet engine; skip rather than error when it
  # is not installed, consistent with the other optional dependencies.
  skip_if_not_installed("glmnet")

  bin_model <-
    logistic_reg(penalty = .1) %>%
    set_engine("glmnet") %>%
    set_mode("classification") %>%
    fit(Class ~ ., data = two_class_dat)

  # Hard class predictions: equations evaluated via stack_predict() must
  # match parsnip's predict(type = "class").
  bin_cls_eqns <- prediction_eqn(bin_model, type = "class")
  bin_cls_parsnip <- predict(bin_model, two_class_dat, type = "class")
  bin_cls_preds <- stack_predict(bin_cls_eqns, data = two_class_dat)
  expect_equal(bin_cls_parsnip, bin_cls_preds)

  # Class probabilities: equations evaluated via stack_predict() must match
  # parsnip's predict(type = "prob").
  bin_prob_eqns <- prediction_eqn(bin_model, type = "prob")
  bin_prob_parsnip <- predict(bin_model, two_class_dat, type = "prob")
  bin_prob_preds <- stack_predict(bin_prob_eqns, data = two_class_dat)
  expect_equal(bin_prob_parsnip, bin_prob_preds)
})


## -----------------------------------------------------------------------------

test_that("multinomial regression", {
  # The model below uses the glmnet engine; skip rather than error when it
  # is not installed, consistent with the other optional dependencies.
  skip_if_not_installed("glmnet")

  mltn_model <-
    multinom_reg(penalty = .1) %>%
    set_engine("glmnet") %>%
    set_mode("classification") %>%
    fit(species ~ ., data = penguins)

  # Hard class predictions: equations evaluated via stack_predict() must
  # match parsnip's predict(type = "class").
  mltn_cls_eqns <- prediction_eqn(mltn_model, type = "class")
  mltn_cls_parsnip <- predict(mltn_model, penguins, type = "class")
  mltn_cls_preds <- stack_predict(mltn_cls_eqns, data = penguins)
  expect_equal(mltn_cls_parsnip, mltn_cls_preds)

  # Class probabilities: equations evaluated via stack_predict() must match
  # parsnip's predict(type = "prob").
  mltn_prob_eqns <- prediction_eqn(mltn_model, type = "prob")
  mltn_prob_parsnip <- predict(mltn_model, penguins, type = "prob")
  mltn_prob_preds <- stack_predict(mltn_prob_eqns, data = penguins)
  expect_equal(mltn_prob_parsnip, mltn_prob_preds)
})

# Try the stacks package in your browser
#
# Any scripts or data that you put into this service are public.
#
# stacks documentation built on Nov. 6, 2023, 5:08 p.m.