# ------------------------------------------------------------------------------
# add_case_weights()
test_that("case weights + formula uses weights in the model", {
  df <- vctrs::data_frame(y = 1, x = 2, w = hardhat::frequency_weights(1))

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_formula(wf, y ~ x)
  wf <- add_case_weights(wf, w)

  wf <- fit(wf, df)

  # The weights recorded on the underlying engine fit should be the ones
  # taken from the `w` column
  expect_identical(wf$fit$fit$fit$weights, 1L)
})
test_that("case weights + variables uses weights in the model", {
  df <- vctrs::data_frame(y = 1, x = 2, w = hardhat::frequency_weights(1))

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_variables(wf, y, x)
  wf <- add_case_weights(wf, w)

  wf <- fit(wf, df)

  # The weights recorded on the underlying engine fit should be the ones
  # taken from the `w` column
  expect_identical(wf$fit$fit$fit$weights, 1L)
})
test_that("case weights + recipe uses weights in the model", {
  df <- vctrs::data_frame(y = 1, x = 2, w = hardhat::frequency_weights(1))

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  rec <- recipes::recipe(y ~ x + w, df)
  # Step that might use case weights
  rec <- recipes::step_center(rec, x)

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_recipe(wf, rec)
  wf <- add_case_weights(wf, w)

  wf <- fit(wf, df)

  expect_identical(wf$fit$fit$fit$weights, 1L)

  # `w` appears in the recipe formula, but it must end up under the
  # case weights role rather than among the predictors
  expect_named(wf$pre$mold$outcomes, "y")
  expect_named(wf$pre$mold$predictors, "x")
  expect_named(wf$pre$mold$extras$roles$case_weights, "w")
})
test_that("case weights are used with model formula override", {
  df <- vctrs::data_frame(
    y = 1,
    x = 2,
    z = 3,
    w = hardhat::frequency_weights(1)
  )

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  rec <- recipes::recipe(y ~ x + z + w, df)
  # Step that might use case weights
  rec <- recipes::step_center(rec, x)

  wf <- workflow()
  wf <- add_model(wf, spec, formula = y ~ x)
  wf <- add_recipe(wf, rec)
  wf <- add_case_weights(wf, w)

  wf <- fit(wf, df)

  expect_identical(wf$fit$fit$fit$weights, 1L)

  # Importantly, supplying a model formula that uses `~ .` doesn't include weights
  wf <- update_model(wf, spec, formula = y ~ .)
  wf <- fit(wf, df)

  expect_named(wf$fit$fit$fit$coefficients, c("(Intercept)", "x", "z"))
})
test_that("case weights + formula doesn't need case weights at predict time", {
  df <- vctrs::data_frame(
    y = c(1, 2, 0),
    x = c(2, 5, 2),
    w = hardhat::frequency_weights(c(2, 5, 10))
  )

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_formula(wf, y ~ x)
  wf <- add_case_weights(wf, w)

  wf <- fit(wf, df)

  # Predict with only the predictor column present
  df$y <- NULL
  df$w <- NULL

  # At `x = 2` the weighted mean of `y` is (2 * 1 + 10 * 0) / 12 = 1 / 6.
  # Compared with tolerance since the values are floating point.
  expect_equal(
    predict(wf, df)$.pred,
    c(1 / 6, 2, 1 / 6)
  )
})
test_that("case weights + variables doesn't need case weights at predict time", {
  df <- vctrs::data_frame(
    y = c(1, 2, 0),
    x = c(2, 5, 2),
    w = hardhat::frequency_weights(c(2, 5, 10))
  )

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_variables(wf, y, x)
  wf <- add_case_weights(wf, w)

  wf <- fit(wf, df)

  # Predict with only the predictor column present
  df$y <- NULL
  df$w <- NULL

  # At `x = 2` the weighted mean of `y` is (2 * 1 + 10 * 0) / 12 = 1 / 6.
  # Compared with tolerance since the values are floating point.
  expect_equal(
    predict(wf, df)$.pred,
    c(1 / 6, 2, 1 / 6)
  )
})
test_that("case weights + recipe doesn't need case weights at predict time", {
  df <- vctrs::data_frame(
    y = c(1, 2, 0),
    x = c(2, 5, 2),
    w = hardhat::frequency_weights(c(2, 5, 10))
  )

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  rec <- recipes::recipe(y ~ x + w, df)

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_recipe(wf, rec)
  wf <- add_case_weights(wf, w)

  wf <- fit(wf, df)

  # Predict with only the predictor column present
  df$y <- NULL
  df$w <- NULL

  # At `x = 2` the weighted mean of `y` is (2 * 1 + 10 * 0) / 12 = 1 / 6.
  # Compared with tolerance since the values are floating point.
  expect_equal(
    predict(wf, df)$.pred,
    c(1 / 6, 2, 1 / 6)
  )
})
test_that("case weights + recipe can optionally require case weights at predict time", {
  df <- vctrs::data_frame(
    y = c(1, 2, 0),
    x = c(2, 5, 2),
    w = hardhat::frequency_weights(c(2, 5, 10))
  )

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  rec <- recipes::recipe(y ~ x + w, df)
  # Opt in to requiring the case weights column at `bake()` time
  rec <- recipes::update_role_requirements(rec, "case_weights", bake = TRUE)

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_recipe(wf, rec)
  wf <- add_case_weights(wf, w)

  wf <- fit(wf, df)

  df$y <- NULL

  # Works while `w` is still present.
  # Compared with tolerance since the values are floating point.
  expect_equal(
    predict(wf, df)$.pred,
    c(1 / 6, 2, 1 / 6)
  )

  df$w <- NULL

  # Errors once `w` is gone
  expect_snapshot(error = TRUE, {
    predict(wf, df)
  })
})
test_that("case weights + recipe requires extra roles at predict time by default", {
  df <- vctrs::data_frame(
    y = c(1, 2, 0),
    x = c(2, 5, 2),
    w = c(2, 5, 10)
  )

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  # In this recipe, `w` is used as some `id` role that is actually required at
  # `bake()` time
  rec <- recipes::recipe(y ~ x + w, df)
  rec <- recipes::update_role(rec, w, new_role = "id")
  rec <- recipes::step_center(rec, w)

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_recipe(wf, rec)

  wf <- fit(wf, df)

  # NOTE(review): only the expected value `df$w - mean(df$w)` was present in
  # this file; the location of the centered `w` in the mold is assumed to be
  # `extras$roles$id$w` -- confirm against the upstream test.
  expect_equal(
    wf$pre$mold$extras$roles$id$w,
    df$w - mean(df$w)
  )

  df$y <- NULL

  # Unweighted fit: at `x = 2` the mean of `y` is (1 + 0) / 2 = 1 / 2
  expect_equal(
    predict(wf, df)$.pred,
    c(1 / 2, 2, 1 / 2)
  )

  df$w <- NULL

  # Errors once `w` is gone
  expect_snapshot(error = TRUE, {
    predict(wf, df)
  })
})
test_that("case weights + recipe can optionally not require extra roles at predict time", {
  df <- vctrs::data_frame(
    y = c(1, 2, 0),
    x = c(2, 5, 2),
    w = c(2, 5, 10)
  )

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  # In this recipe, `w` is used as some `id` role that isn't actually needed
  # at `bake()` time
  rec <- recipes::recipe(y ~ x + w, df)
  rec <- recipes::update_role(rec, w, new_role = "id")
  rec <- recipes::update_role_requirements(rec, "id", bake = FALSE)

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_recipe(wf, rec)

  wf <- fit(wf, df)

  df$y <- NULL

  # Works with `w`.
  # Unweighted fit: at `x = 2` the mean of `y` is (1 + 0) / 2 = 1 / 2
  expect_equal(
    predict(wf, df)$.pred,
    c(1 / 2, 2, 1 / 2)
  )

  df$w <- NULL

  # Works without `w`
  expect_equal(
    predict(wf, df)$.pred,
    c(1 / 2, 2, 1 / 2)
  )
})
test_that("case weights + recipe updates the case weights if the recipe filters rows", {
  df <- vctrs::data_frame(
    y = c(1, 2, 0),
    x = c(2, 5, 3),
    w = hardhat::frequency_weights(c(2, 5, 10))
  )

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  # The filter drops one of the rows, so the case weights need to be updated accordingly
  rec <- recipes::recipe(y ~ x + w, df)
  rec <- recipes::step_filter(rec, x > 2)

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_recipe(wf, rec)
  wf <- add_case_weights(wf, w)

  wf <- fit(wf, df)

  # NOTE(review): only the two expected values `df$w[c(2, 3)]` were present
  # in this file; the two locations checked here are assumed to be the stored
  # case weights and the case weights role column in the mold -- confirm
  # against the upstream test.
  expect_identical(
    wf$pre$case_weights,
    df$w[c(2, 3)]
  )
  expect_identical(
    wf$pre$mold$extras$roles$case_weights$w,
    df$w[c(2, 3)]
  )
})
test_that("case weights + recipe doesn't allow the recipe to drop the case weights column", {
  df <- vctrs::data_frame(
    y = c(1, 2, 0),
    x = c(2, 5, 3),
    w = hardhat::frequency_weights(c(2, 5, 10))
  )

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  # The recipe removes the `w` column entirely
  rec <- recipes::recipe(y ~ x + w, df)
  rec <- recipes::step_mutate(rec, w = NULL)

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_recipe(wf, rec)
  wf <- add_case_weights(wf, w)

  expect_snapshot(error = TRUE, {
    fit(wf, df)
  })
})
test_that("case weights + recipe doesn't allow the recipe to adjust the case weights column class", {
  df <- vctrs::data_frame(
    y = c(1, 2, 0),
    x = c(2, 5, 3),
    w = hardhat::frequency_weights(c(2, 5, 10))
  )

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  # The recipe strips the class from the `w` column
  rec <- recipes::recipe(y ~ x + w, df)
  rec <- recipes::step_mutate(rec, w = unclass(w))

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_recipe(wf, rec)
  wf <- add_case_weights(wf, w)

  expect_snapshot(error = TRUE, {
    fit(wf, df)
  })
})
test_that("case weights + recipe doesn't allow the recipe to change the name of the case weights column", {
  df <- vctrs::data_frame(
    y = c(1, 2, 0),
    x = c(2, 5, 3),
    w = hardhat::frequency_weights(c(2, 5, 10))
  )

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  # Adds a new case weights column, but removes the old one
  rec <- recipes::recipe(y ~ x + w, df)
  rec <- recipes::step_mutate(rec, w2 = w, w = NULL, role = "case_weights")

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_recipe(wf, rec)
  wf <- add_case_weights(wf, w)

  expect_snapshot(error = TRUE, {
    fit(wf, df)
  })
})
test_that("case weights + formula removes weights before formula evaluation", {
  df <- vctrs::data_frame(y = 1, x = 2, w = hardhat::frequency_weights(1))

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  wf <- workflow()
  wf <- add_model(wf, spec)
  # `.` would pick up `w` as a predictor if it were still in the data
  wf <- add_formula(wf, y ~ .)
  wf <- add_case_weights(wf, w)

  wf <- fit(wf, df)

  expect_identical(wf$pre$case_weights, df$w)
  expect_false("w" %in% names(wf$pre$mold$predictors))
})
test_that("case weights + variables removes weights before variables evaluation", {
  df <- vctrs::data_frame(y = 1, x = 2, w = hardhat::frequency_weights(1))

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  wf <- workflow()
  wf <- add_model(wf, spec)
  # `everything()` would pick up `w` as a predictor if it were still in the data
  wf <- add_variables(wf, y, everything())
  wf <- add_case_weights(wf, w)

  wf <- fit(wf, df)

  expect_identical(wf$pre$case_weights, df$w)
  expect_false("w" %in% names(wf$pre$mold$predictors))
})
test_that("case weights + recipe retains weights for use in recipe", {
  df <- vctrs::data_frame(y = 1, x = 2, w = hardhat::frequency_weights(1))

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  rec <- recipes::recipe(y ~ x + w, df)
  # Step that might use case weights
  rec <- recipes::step_center(rec, x)

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_recipe(wf, rec)
  wf <- add_case_weights(wf, w)

  # recipe won't run unless the `w` column is there
  wf <- fit(wf, df)

  expect_identical(wf$pre$case_weights, df$w)
  expect_identical(wf$pre$mold$extras$roles$case_weights$w, df$w)
})
test_that("case weights added after preprocessors get reordered", {
  df <- vctrs::data_frame(y = 1, x = 2, w = hardhat::frequency_weights(1))

  rec <- recipes::recipe(y ~ x + w, df)

  # In each case below the case weights action is added second, but is
  # expected to be listed first among the pre-processing actions.

  wf <- workflow()
  wf <- add_formula(wf, y ~ x)
  wf <- add_case_weights(wf, w)

  # Order matters
  expect_identical(names(wf$pre$actions), c("case_weights", "formula"))

  wf <- workflow()
  wf <- add_recipe(wf, rec)
  wf <- add_case_weights(wf, w)

  # Order matters
  expect_identical(names(wf$pre$actions), c("case_weights", "recipe"))

  wf <- workflow()
  wf <- add_variables(wf, y, x)
  wf <- add_case_weights(wf, w)

  # Order matters
  expect_identical(names(wf$pre$actions), c("case_weights", "variables"))
})
test_that("case weights `col` must exist in `data`", {
  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_formula(wf, mpg ~ .)
  # `foo` is not a column of `mtcars`
  wf <- add_case_weights(wf, foo)

  # Tidyselect error
  expect_error(fit(wf, mtcars))
})
test_that("case weights `col` can't select >1 columns in `data`", {
  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_formula(wf, mpg ~ .)
  # Selects two columns where only one is allowed
  wf <- add_case_weights(wf, c(cyl, disp))

  expect_snapshot(error = TRUE, fit(wf, mtcars))
})
test_that("case weights must inherit from the base case weights class", {
  # `weights` is a bare numeric column, not a hardhat case weights vector
  df <- vctrs::data_frame(y = 1, x = 1, weights = 1)

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_formula(wf, y ~ x)
  wf <- add_case_weights(wf, weights)

  expect_snapshot(error = TRUE, fit(wf, df))
})
# ------------------------------------------------------------------------------
# remove_case_weights()
test_that("can remove case weights (and keep preprocessor)", {
  df <- vctrs::data_frame(y = 1, x = 2, w = hardhat::frequency_weights(1))

  wf <- workflow()
  wf <- add_formula(wf, y ~ .)
  wf <- add_case_weights(wf, w)

  wf <- remove_case_weights(wf)

  # Only the formula action should remain
  expect_identical(names(wf$pre$actions), "formula")
})
test_that("removing case weights resets model, mold, and case-weights slots", {
  df <- vctrs::data_frame(y = 1, x = 2, w = hardhat::frequency_weights(1))

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_formula(wf, y ~ .)
  wf <- add_case_weights(wf, w)

  wf <- fit(wf, df)

  wf <- remove_case_weights(wf)

  # NOTE(review): no expectations were present in this file for this test;
  # the checks below are reconstructed from the test description, assuming a
  # reset slot is `NULL` -- confirm against the upstream test.
  expect_null(wf$fit$fit)
  expect_null(wf$pre$mold)
  expect_null(wf$pre$case_weights)
})
# ------------------------------------------------------------------------------
# update_case_weights()
test_that("updating case weights resets model, mold, and case-weights slots", {
  df <- vctrs::data_frame(
    y = 1,
    x = 2,
    w = hardhat::frequency_weights(1),
    z = hardhat::frequency_weights(2)
  )

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  wf <- workflow()
  wf <- add_model(wf, spec)
  wf <- add_formula(wf, y ~ x)
  wf <- add_case_weights(wf, w)

  wf <- fit(wf, df)

  # Switch the case weights column from `w` to `z` on an already fit workflow
  wf <- update_case_weights(wf, z)

  # NOTE(review): the three reset checks below were not present in this file;
  # they are reconstructed from the test description, assuming a reset slot
  # is `NULL` -- confirm against the upstream test.
  expect_null(wf$fit$fit)
  expect_null(wf$pre$mold)
  expect_null(wf$pre$case_weights)

  expect_identical(wf$pre$actions$case_weights$col, quo(z))
})
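# NOTE: the lines below are hosting-site boilerplate captured along with this
# file, not part of the tests; kept as comments so the file remains valid R.
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.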