# tests/testthat/test-collect-metrics.R

# Compare the metrics collected from one workflow's result with the rows for
# that same workflow in the metrics collected from the whole object.
#
# Arguments:
#   ind  Row index into `x`; used as `x$wflow_id[ind]` and `x$result[[ind]]`.
#   x    Object with `wflow_id` and `result` columns that `collect_metrics()`
#        accepts.
#   ...  Forwarded unchanged to both `collect_metrics()` calls
#        (e.g. `summarize = FALSE`).
#
# Returns the value of `all.equal()`: `TRUE` when both results agree,
# otherwise a character vector describing the differences.
check_metric_results <- function(ind, x, ...) {
  id_val <- x$wflow_id[ind]

  # Pick the comparison columns from the *value* of `summarize`, not merely
  # from the presence of an argument with that name: `summarize = TRUE` must
  # use the same columns as the default call.
  dots <- list(...)
  if (isFALSE(dots[["summarize"]])) {
    cols <- c(".metric", ".estimator", ".estimate", ".config", "id")
  } else {
    cols <- c(".metric", ".estimator", "mean", "n", "std_err", ".config")
  }

  # Metrics taken directly from the individual result.
  orig <-
    collect_metrics(x$result[[ind]], ...) %>%
    dplyr::select(dplyr::all_of(cols))

  # Metrics taken from the whole object, narrowed to the same workflow.
  everything <-
    collect_metrics(x, ...) %>%
    dplyr::filter(wflow_id == id_val) %>%
    dplyr::select(dplyr::all_of(cols))

  all.equal(orig, everything)
}

# ------------------------------------------------------------------------------

test_that("collect summarized metrics", {
  # seq_len() gives an empty sequence for a zero-row input, whereas
  # 1:nrow(x) would iterate over c(1, 0) and index invalid rows.

  # Default call: nothing extra is forwarded to collect_metrics().
  for (i in seq_len(nrow(two_class_res))) {
    expect_true(check_metric_results(i, two_class_res))
  }
  for (i in seq_len(nrow(chi_features_res))) {
    expect_true(check_metric_results(i, chi_features_res))
  }

  # Same comparison with `summarize = FALSE` forwarded to collect_metrics().
  for (i in seq_len(nrow(two_class_res))) {
    expect_true(check_metric_results(i, two_class_res, summarize = FALSE))
  }
  for (i in seq_len(nrow(chi_features_res))) {
    expect_true(check_metric_results(i, chi_features_res, summarize = FALSE))
  }
})

test_that("ranking models", {
  # Expected number of rows per metric for each workflow id.
  rows_per_model <- c(
    none_cart = 10,
    none_glm = 1,
    none_mars = 2,
    yj_trans_cart = 10,
    yj_trans_glm = 1,
    yj_trans_mars = 2
  )

  # `regexp = NA` asserts that the call raises no error. The assignment inside
  # the expectation keeps the result for the row-count check that follows.
  # NOTE(review): the factor 2 is presumably the number of metrics -- confirm.
  expect_error(
    all_ranks <- rank_results(two_class_res),
    regexp = NA
  )
  expect_equal(nrow(all_ranks), sum(rows_per_model * 2))

  # With `select_best = TRUE` the expected count is two rows per workflow.
  expect_error(
    best_ranks <- rank_results(two_class_res, select_best = TRUE),
    regexp = NA
  )
  expect_equal(nrow(best_ranks), nrow(two_class_res) * 2)
})

# Try the workflowsets package in your browser
#
# Any scripts or data that you put into this service are public.
#
# workflowsets documentation built on April 7, 2023, 1:05 a.m.