tests/testthat/test-extract.R

library(parsnip)
library(rsample)
library(recipes)
data(Chicago, package = "modeldata")

lr_spec <- linear_reg() %>% set_engine("lm")

set.seed(1)
car_set_1 <-
  workflow_set(
    list(
      reg = recipe(mpg ~ ., data = mtcars) %>% step_log(disp),
      nonlin = mpg ~ wt + 1 / sqrt(disp)
    ),
    list(lm = lr_spec)
  ) %>%
  workflow_map("fit_resamples",
    resamples = vfold_cv(mtcars, v = 3),
    control = tune::control_resamples(save_pred = TRUE)
  )

# ------------------------------------------------------------------------------

test_that("extracts", {
  # workflows specific errors, so we don't capture their messages
  expect_error(extract_fit_engine(car_set_1, id = "reg_lm"))
  expect_error(extract_fit_parsnip(car_set_1, id = "reg_lm"))
  expect_error(extract_mold(car_set_1, id = "reg_lm"))
  expect_error(extract_recipe(car_set_1, id = "reg_lm"))

  expect_s3_class(
    extract_preprocessor(car_set_1, id = "reg_lm"),
    "recipe"
  )
  expect_s3_class(
    extract_spec_parsnip(car_set_1, id = "reg_lm"),
    "model_spec"
  )
  expect_s3_class(
    extract_workflow(car_set_1, id = "reg_lm"),
    "workflow"
  )
  expect_s3_class(
    extract_recipe(car_set_1, id = "reg_lm", estimated = FALSE),
    "recipe"
  )

  expect_equal(
    car_set_1 %>% extract_workflow("reg_lm"),
    car_set_1$info[[1]]$workflow[[1]]
  )

  expect_equal(
    car_set_1 %>% extract_workflow_set_result("reg_lm"),
    car_set_1$result[[1]]
  )

  expect_snapshot(error = TRUE, {
     car_set_1 %>% extract_workflow_set_result("Gideon Nav")
  })

  expect_snapshot(error = TRUE, {
     car_set_1 %>% extract_workflow("Coronabeth Tridentarius")
  })
})


test_that("extract parameter set from workflow set with untunable workflow", {
   rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
      recipes::step_rm(date, ends_with("away"))
   lm_model <- parsnip::linear_reg() %>%
      parsnip::set_engine("lm")
   bst_model <-
      parsnip::boost_tree(mode = "classification", trees = hardhat::tune("funky name \n")) %>%
      parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE)
   wf_set <- workflow_set(
      list(reg = rm_rec),
      list(lm = lm_model, bst = bst_model)
   )

   lm_info <- hardhat::extract_parameter_set_dials(wf_set, id = "reg_lm")
   check_parameter_set_tibble(lm_info)
   expect_equal(nrow(lm_info), 0)
})

test_that("extract parameter set from workflow set with tunable workflow", {
   rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
      recipes::step_rm(date, ends_with("away"))
   lm_model <- parsnip::linear_reg() %>%
      parsnip::set_engine("lm")
   bst_model <-
      parsnip::boost_tree(mode = "classification", trees = hardhat::tune("funky name \n")) %>%
      parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE)
   wf_set <- workflow_set(
      list(reg = rm_rec),
      list(lm = lm_model, bst = bst_model)
   )

   c5_info <- extract_parameter_set_dials(wf_set, id = "reg_bst")
   expect_equal(
      c5_info,
      extract_parameter_set_dials(bst_model)
   )
   check_parameter_set_tibble(c5_info)
   expect_equal(nrow(c5_info), 2)
   expect_true(all(c5_info$source == "model_spec"))
   expect_true(all(c5_info$component == "boost_tree"))
   expect_equal(c5_info$component_id, c("main", "engine"))
   nms <- c("trees", "rules")
   expect_equal(c5_info$name, nms)
   ids <- c("funky name \n", "rules")
   expect_equal(c5_info$id, ids)

   expect_equal(c5_info$object[[1]], dials::trees(c(1, 100)))
   expect_equal(c5_info$object[[2]], NA)

   c5_new_info <-
      c5_info %>%
      update(
         rules = dials::new_qual_param("logical",
                                       values = c(TRUE, FALSE),
                                       label = c(rules = "Rules"))
      )

   wf_set_2 <-
      wf_set %>%
      option_add(id = "reg_bst", param_info = c5_new_info)

   check_parameter_set_tibble(c5_new_info)
   expect_s3_class(c5_new_info$object[[2]], "qual_param")
   expect_equal(
      c5_new_info,
      extract_parameter_set_dials(wf_set_2, "reg_bst")
   )
})



test_that("extract single parameter from workflow set with untunable workflow", {
   rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
      recipes::step_rm(date, ends_with("away"))
   lm_model <- parsnip::linear_reg() %>%
      parsnip::set_engine("lm")
   bst_model <-
      parsnip::boost_tree(mode = "classification", trees = hardhat::tune("funky name \n")) %>%
      parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE)
   wf_set <- workflow_set(
      list(reg = rm_rec),
      list(lm = lm_model, bst = bst_model)
   )

   expect_error(
      hardhat::extract_parameter_dials(wf_set, id = "reg_lm", parameter = "non there")
   )
})

test_that("extract single parameter from workflow set with tunable workflow", {
   rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
      recipes::step_rm(date, ends_with("away"))
   lm_model <- parsnip::linear_reg() %>%
      parsnip::set_engine("lm")
   bst_model <-
      parsnip::boost_tree(mode = "classification", trees = hardhat::tune("funky name \n")) %>%
      parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE)
   wf_set <- workflow_set(
      list(reg = rm_rec),
      list(lm = lm_model, bst = bst_model)
   )

   expect_equal(
      hardhat::extract_parameter_dials(wf_set, id = "reg_bst", parameter = "funky name \n"),
      dials::trees(c(1, 100))
   )
   expect_equal(
      extract_parameter_dials(wf_set, id = "reg_bst", parameter = "rules"),
      NA
   )
})

Try the workflowsets package in your browser

Any scripts or data that you put into this service are public.

workflowsets documentation built on May 29, 2024, 3:12 a.m.