tests/testthat/test-colcheck.R

library(testthat)
library(recipes)

rp1 <- recipe(mtcars, cyl ~ .)
rp2 <- recipe(mtcars, cyl ~ mpg + drat)

test_that("check_col works in the prep stage", {
  expect_error(rp1 %>% check_cols(everything()) %>% prep(), NA)
  expect_error(rp2 %>% check_cols(everything()) %>% prep(), NA)
  expect_error(rp2 %>% check_cols(cyl, mpg, drat) %>% prep(), NA)
  expect_error(rp2 %>% check_cols(cyl, mpg) %>% prep(), NA)
})


test_that("check_col works in the bake stage", {

  expect_error(rp1 %>% check_cols(everything()) %>% prep() %>% bake(mtcars),
               NA)
  expect_equal(rp1 %>% check_cols(everything()) %>% prep() %>% bake(mtcars),
               tibble(mtcars[ ,c(1, 3:11, 2)]))
  expect_error(rp2 %>% check_cols(cyl, mpg, drat) %>% prep %>% bake(mtcars), NA)
  expect_equal(rp2 %>% check_cols(cyl, mpg, drat) %>% prep %>% bake(mtcars),
               tibble(mtcars[ ,c(1, 5, 2)]))

  expect_error(
    rp1 %>% check_cols(everything()) %>% prep() %>% bake(mtcars),
    NA
  )
  expect_equal(
    rp1 %>% check_cols(everything()) %>% prep() %>% bake(mtcars),
    tibble(mtcars[, c(1, 3:11, 2)])
  )
  expect_error(rp2 %>% check_cols(cyl, mpg, drat) %>% prep() %>% bake(mtcars), NA)
  expect_equal(
    rp2 %>% check_cols(cyl, mpg, drat) %>% prep() %>% bake(mtcars),
    tibble(mtcars[, c(1, 5, 2)])
  )
  expect_snapshot(error = TRUE,
    rp1 %>% check_cols(everything()) %>% prep() %>% bake(mtcars[-1])
  )
  expect_snapshot(error = TRUE,
    rp2 %>% check_cols(cyl, mpg, drat) %>% prep() %>%
      bake(mtcars[, c(2, 5)])
  )
})

test_that("non-standard roles during bake/predict", {
  skip_if_not_installed("modeldata")
  skip_if_not_installed("workflows")
  skip_if_not_installed("parsnip")
  skip_if(packageVersion("workflows") < "0.2.6.9001")
  skip_if(packageVersion("parsnip") < "1.0.0")

  # ----------------------------------------------------------------------------

  library(workflows)
  library(parsnip)

  # ----------------------------------------------------------------------------

  data(Chicago, package = "modeldata", envir = current_env())

  Chicago <- Chicago %>% dplyr::select(ridership, date, Austin, Belmont)

  set.seed(1)
  Chicago$wts <- importance_weights(runif(nrow(Chicago)))

  # ----------------------------------------------------------------------------

  base_wflow <-
    workflow() %>%
    add_model(linear_reg())

  # ----------------------------------------------------------------------------
  # non-standard role used in a step

  ## no case weights, default blueprint
  role_rec <-
    recipe(ridership ~ date + Austin + Belmont, data = Chicago) %>%
    update_role(date, new_role = "date") %>%
    step_date(date)

  role_wflow <-
    base_wflow %>%
    add_recipe(role_rec)

  role_fit <- fit(role_wflow, data = Chicago)

  expect_error(predict(role_fit, head(Chicago)), NA)

  # This should require 'date' to predict.
  # The error comes from hardhat, so we don't snapshot it because we don't own it.
  expect_error(predict(role_fit, Chicago %>% select(-date)))

  # ----------------------------------------------------------------------------
  # non-standard role used in a step and case weights

  role_wts_rec <-
    recipe(ridership ~ ., data = Chicago) %>%
    update_role(date, new_role = "date") %>%
    step_date(date)

  role_wts_wflow <-
    base_wflow %>%
    add_recipe(role_wts_rec) %>%
    add_case_weights(wts)

  role_wts_fit <- fit(role_wts_wflow, data = Chicago)

  # This should require 'date' but not 'wts' to predict
  expect_error(predict(role_wts_fit, head(Chicago)), NA)
  expect_error(predict(role_wts_fit, head(Chicago) %>% select(-wts)), NA)
  expect_error(predict(role_wts_fit, head(Chicago) %>% select(-date)))

  # ----------------------------------------------------------------------------
  # Removing variable after use

  rm_rec <-
    recipe(ridership ~ date + Austin + Belmont, data = Chicago) %>%
    step_date(date, keep_original_cols = FALSE)

  rm_wflow <-
    base_wflow %>%
    add_recipe(rm_rec)

  rm_fit <- fit(rm_wflow, data = Chicago)

  # This should require 'date' to predict
  expect_error(predict(rm_fit, Chicago %>% select(-date)))

  # ----------------------------------------------------------------------------
  # Removing variable after use, with case weights

  rm_wts_rec <-
    recipe(ridership ~ ., data = Chicago) %>%
    step_date(date, keep_original_cols = FALSE)

  rm_wts_wflow <-
    base_wflow %>%
    add_recipe(rm_wts_rec) %>%
    add_case_weights(wts)

  rm_wts_fit <- fit(rm_wts_wflow, data = Chicago)

  # This should require 'date' but not 'wts' to predict
  expect_error(predict(rm_fit, Chicago %>% select(-wts)), NA)
  expect_error(predict(rm_fit, Chicago %>% select(-date)))
})

# Infrastructure ---------------------------------------------------------------

test_that("bake method errors when needed non-standard role columns are missing", {
  # Here for completeness
  # the main check_cols() purpose of this function is to test for this
  # event
  expect_true(TRUE)
})

test_that("empty printing", {
  rec <- recipe(mpg ~ ., mtcars)
  rec <- check_cols(rec)

  expect_snapshot(rec)

  rec <- prep(rec, mtcars)

  expect_snapshot(rec)
})

test_that("empty selection prep/bake is a no-op", {
  rec1 <- recipe(mpg ~ ., mtcars)
  rec2 <- check_cols(rec1)

  rec1 <- prep(rec1, mtcars)
  rec2 <- prep(rec2, mtcars)

  baked1 <- bake(rec1, mtcars)
  baked2 <- bake(rec2, mtcars)

  expect_identical(baked1, baked2)
})

test_that("empty selection tidy method works", {
  rec <- recipe(mpg ~ ., mtcars)
  rec <- check_cols(rec)

  expect <- tibble(terms = character(), id = character())

  expect_identical(tidy(rec, number = 1), expect)

  rec <- prep(rec, mtcars)

  expect_identical(tidy(rec, number = 1), expect)
})

test_that("printing", {
  rec <- recipe(mpg ~ ., mtcars) %>%
    check_cols(everything())

  expect_snapshot(print(rec))
  expect_snapshot(prep(rec))
})

Try the recipes package in your browser

Any scripts or data that you put into this service are public.

recipes documentation built on Aug. 26, 2023, 1:08 a.m.