tests/testthat/test-collect_parameters.R

if ((!on_cran()) || interactive()) {
  if (on_github()) {
    load(paste0(
      Sys.getenv("GITHUB_WORKSPACE"),
      "/tests/testthat/helper_data.Rda"
    ))
  } else {
    load(test_path("helper_data.Rda"))
  }
}

skip_if_not_installed("modeldata")
library(modeldata)

skip_if_not_installed("ranger")
library(ranger)

skip_if_not_installed("kernlab")
library(kernlab)

skip_if_not_installed("nnet")
library(nnet)

test_that("collect_parameters dispatch works", {
  skip_on_cran()

  expect_snapshot(error = TRUE, 1 |> collect_parameters())

  expect_snapshot(error = TRUE, mtcars |> collect_parameters())
})

test_that("collect_parameters errors informatively with bad arguments", {
  skip_on_cran()

  expect_snapshot(
    error = TRUE,
    st_reg_1 |> collect_parameters("the first one")
  )

  expect_snapshot(error = TRUE, stacks() |> collect_parameters("all of them"))
})

test_that("collect_parameters on a data stack works (regression)", {
  skip_on_cran()

  res <- collect_parameters(st_reg_1, "reg_res_svm")
  res2 <- collect_parameters(st_reg_2, "reg_res_sp")
  res3 <-
    collect_parameters(
      stacks() |> add_candidates(reg_res_lr, name = "lr"),
      "lr"
    )

  expect_s3_class(res, "tbl_df")
  expect_s3_class(res2, "tbl_df")
  expect_s3_class(res3, "tbl_df")

  expect_equal(ncol(res), 3)
  expect_equal(nrow(res), 5)

  expect_equal(ncol(res2), 2)
  expect_equal(nrow(res2), 10)

  expect_equal(ncol(res3), 1)
  expect_equal(nrow(res3), 1)
})

test_that("collect_parameters on a model stack works (regression)", {
  skip_on_cran()

  res <- collect_parameters(st_reg_1_, "reg_res_svm")
  res2 <- collect_parameters(st_reg_2 |> blend_predictions(), "reg_res_sp")

  expect_s3_class(res, "tbl_df")

  expect_equal(ncol(res), 4)
  expect_equal(nrow(res), 5)

  expect_equal(ncol(res2), 3)
  expect_equal(nrow(res2), 10)

  expect_true(
    all(
      c(
        "member",
        parsnip::extract_parameter_set_dials(
          st_reg_1_$model_defs$reg_res_svm
        ) |>
          dplyr::pull(id),
        "coef"
      ) %in%
        colnames(res)
    )
  )
})

# collecting parameters on a classification stack is a bit
# trickier, so test separately
test_that("collect_parameters works (classification)", {
  skip_on_cran()

  res <- collect_parameters(st_class_1, "class_res_rf")
  res2 <- collect_parameters(st_class_1 |> blend_predictions(), "class_res_rf")

  expect_s3_class(res, "tbl_df")
  expect_s3_class(res2, "tbl_df")

  expect_equal(ncol(res), 3)
  expect_equal(nrow(res), 10)

  expect_equal(ncol(res2), 6)
  expect_equal(nrow(res2), 57)
})

Try the stacks package in your browser

Any scripts or data that you put into this service are public.

stacks documentation built on June 10, 2025, 9:14 a.m.