# tests/testthat/test-collect-predictions.R

skip_on_cran()

# ------------------------------------------------------------------------------

library(parsnip)
suppressPackageStartupMessages(library(rsample))
suppressPackageStartupMessages(library(tune))

# ------------------------------------------------------------------------------
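# Model specifications: an untuned linear regression (lm engine) and a
# K-nearest-neighbors regression with a tunable `neighbors` parameter (kknn).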

lr_spec <- linear_reg() %>% set_engine("lm")
knn_spec <-
  nearest_neighbor(neighbors = tune()) %>%
  set_engine("kknn") %>%
  set_mode("regression")

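# car_set_1: both formula preprocessors paired with the linear model and fit
# with 3-fold cross-validation, saving the out-of-sample predictions.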
set.seed(1)
car_set_1 <-
  workflow_set(
    list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)),
    list(lm = lr_spec)
  ) %>%
  workflow_map("fit_resamples",
    resamples = vfold_cv(mtcars, v = 3),
    control = tune::control_resamples(save_pred = TRUE)
  )

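# car_set_2: the same workflows, resampled with 3-fold cross-validation
# repeated twice, so each row of mtcars gets two held-out predictions.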
set.seed(1)
car_set_2 <-
  workflow_set(
    list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)),
    list(lm = lr_spec)
  ) %>%
  workflow_map("fit_resamples",
    resamples = vfold_cv(mtcars, v = 3, repeats = 2),
    control = tune::control_resamples(save_pred = TRUE)
  )

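# car_set_3: both formulas paired with the tunable KNN model and optimized with
# Bayesian search (3 initial configurations + 2 iterations = 5 candidates per
# workflow).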
set.seed(1)
car_set_3 <-
  workflow_set(
    list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)),
    list(knn = knn_spec)
  ) %>%
  workflow_map("tune_bayes",
    resamples = vfold_cv(mtcars, v = 3, repeats = 2),
    control = tune::control_bayes(save_pred = TRUE),
    seed = 1, iter = 2, initial = 3
  )

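# A combined set mixing resampling-only results (car_set_2) with
# iterative-search results (car_set_3).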
car_set_23 <- dplyr::bind_rows(car_set_2, car_set_3)

# ------------------------------------------------------------------------------


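# The helper below checks that predictions collected from a single result
# object match the rows for that workflow when predictions are collected from
# the entire workflow set.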
check_prediction_results <- function(ind, x, summarize = TRUE, ...) {
  id_val <- x$wflow_id[ind]

  # Predictions collected directly from the underlying tune/resample object
  orig <- collect_predictions(x$result[[ind]], summarize = summarize, ...)

  # Columns to compare; unsummarized results also keep the resample id column(s)
  cols <- c(".row", "mpg", ".config", ".pred")
  if (!summarize) {
    cols <- c(grep("^id", names(orig), value = TRUE), cols)
  }
  orig <- dplyr::select(orig, dplyr::all_of(cols))

  # Collecting from the whole workflow set and filtering back down to this
  # workflow should give identical predictions
  from_set <-
    collect_predictions(x, summarize = summarize, ...) %>%
    dplyr::filter(wflow_id == id_val) %>%
    dplyr::select(dplyr::all_of(cols))

  all.equal(orig, from_set)
}

# ------------------------------------------------------------------------------

test_that("collect predictions", {
  expect_error(
    res_car_set_1 <- collect_predictions(car_set_1),
    regexp = NA
  )
  expect_true(nrow(mtcars) * nrow(car_set_1) == nrow(res_car_set_1))

  expect_error(
    res_car_set_2 <- collect_predictions(car_set_2),
    regexp = NA
  )
  expect_true(nrow(mtcars) * nrow(car_set_2) == nrow(res_car_set_2))

  expect_error(
    res_car_set_2_reps <- collect_predictions(car_set_2, summarize = FALSE),
    regexp = NA
  )
  expect_true(nrow(mtcars) * nrow(car_set_2) * 2 == nrow(res_car_set_2_reps))

  expect_error(
    res_car_set_3 <- collect_predictions(car_set_3),
    regexp = NA
  )
  # 3 initial grid points + 2 Bayesian search iterations = 5 configurations per workflow
  expect_true(nrow(mtcars) * nrow(car_set_3) * 5 == nrow(res_car_set_3))

  expect_error(
    res_car_set_3_reps <- collect_predictions(car_set_3, summarize = FALSE),
    regexp = NA
  )
  # 5 configurations per workflow, each predicted in 2 repeats of the resamples
  expect_true(nrow(mtcars) * nrow(car_set_3) * 5 * 2 == nrow(res_car_set_3_reps))

  # ---------------------------------------------------------------------------
  # These don't seem to get captured by covr
  for (i in 1:nrow(car_set_1)) {
    expect_true(check_prediction_results(i, car_set_1))
  }
  for (i in 1:nrow(car_set_2)) {
    expect_true(check_prediction_results(i, car_set_2))
  }

  for (i in 1:nrow(car_set_1)) {
    expect_true(check_prediction_results(i, car_set_1, summarize = FALSE))
  }
  for (i in 1:nrow(car_set_2)) {
    expect_true(check_prediction_results(i, car_set_2, summarize = FALSE))
  }
})

skip_if(packageVersion("tune") <= "1.1.0")

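# collect_predictions() on a workflow set returns a common set of columns;
# model-specific tuning parameter columns (e.g. `neighbors` for the KNN
# workflows) are dropped so results from different workflows can be combined.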
test_that("dropping tuning parameter columns", {
  expect_equal(
    names(collect_predictions(car_set_1)),
    c("wflow_id", ".config", "preproc", "model", ".row", "mpg", ".pred")
  )
  expect_equal(
    names(collect_predictions(car_set_2)),
    c("wflow_id", ".config", "preproc", "model", ".row", "mpg", ".pred")
  )

  expect_equal(
    names(collect_predictions(car_set_1, summarize = FALSE)),
    c("wflow_id", ".config", "preproc", "model", "id", ".pred", ".row", "mpg")
  )
  expect_equal(
    names(collect_predictions(car_set_2, summarize = FALSE)),
    c("wflow_id", ".config", "preproc", "model", "id", "id2", ".pred", ".row", "mpg")
  )

  expect_error(
    best_iter <- collect_predictions(car_set_3, select_best = TRUE, metric = "rmse"),
    regexp = NA
  )
  expect_true(
    nrow(dplyr::distinct(best_iter[, c(".config", "wflow_id")])) == 2
  )
  expect_error(
    no_param <-
      workflowsets:::select_bare_predictions(car_set_3$result[[1]], metric = "rmse", TRUE),
    regexp = NA
  )
  expect_equal(names(no_param), c(".row", "mpg", ".config", ".iter", ".pred"))
})


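# Combining resampling-only and iterative-search results should keep the
# `.iter` column from the search results.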
test_that("mixed object types", {
  expect_true(".iter" %in% names(collect_predictions(car_set_23)))
})
