Nothing
# Simulated three-class fixture built from `iris`: the observed species, a
# with-replacement resample of the species as the "prediction", and one random
# probability column per species level with each row rescaled to sum to 1.
set.seed(1311)

three_class <- data.frame(
  obs = iris$Species,
  pred = sample(iris$Species, replace = TRUE)
)

# `apply()` over rows hands back the rescaled rows as columns, so the result is
# transposed to get one row per observation again.
probs <- apply(
  matrix(runif(150 * 3), nrow = 150),
  MARGIN = 1,
  FUN = function(row) row / sum(row)
)
probs <- t(probs)
colnames(probs) <- levels(iris$Species)

three_class <- cbind(three_class, as.data.frame(probs))
###################################################################
# Checks which metric names (and, for the three-class case, which estimators)
# `metrics()` reports for each combination of columns it is given.
test_that("correct metrics returned", {
  hard_class_metrics <- c("accuracy", "kap")
  with_prob_metrics <- c("accuracy", "kap", "mn_log_loss", "roc_auc")

  # Truth and hard class prediction only
  two_class_hard <- metrics(two_class_example, truth, predicted)
  expect_equal(two_class_hard[[".metric"]], hard_class_metrics)

  # Truth, hard class prediction and one class-probability column
  two_class_prob <- metrics(two_class_example, truth, predicted, Class1)
  expect_equal(two_class_prob[[".metric"]], with_prob_metrics)

  # Three classes: columns named by string, one probability column per class
  three_class_res <- metrics(
    three_class, "obs", "pred", setosa, versicolor, virginica
  )
  expect_equal(three_class_res[[".metric"]], with_prob_metrics)
  expect_equal(
    three_class_res[[".estimator"]],
    c("multiclass", "multiclass", "multiclass", "hand_till")
  )

  # Numeric outcome, estimate column named by string
  numeric_res <- metrics(solubility_test, solubility, "prediction")
  expect_equal(numeric_res[[".metric"]], c("rmse", "rsq", "mae"))
})
###################################################################
# Column combinations that `metrics()` is expected to reject. Each call is run
# with `error = TRUE`, so the error output is what gets recorded in the
# snapshot. The snapshot stores the call text and is filed under the test
# description, so neither should be edited casually.
test_that("bad args", {
# `Class1` is given in the estimate position, where the other tests in this
# file pass `predicted`.
expect_snapshot(
error = TRUE,
metrics(two_class_example, truth, Class1)
)
# Same two columns with their positions swapped.
expect_snapshot(
error = TRUE,
metrics(two_class_example, Class1, truth)
)
# Only two of the three per-class columns of `three_class` are supplied.
expect_snapshot(
error = TRUE,
metrics(three_class, "obs", "pred", setosa, versicolor)
)
})
###################################################################
# Reference results built by calling each metric function individually and
# stacking the rows; later tests compare `metrics()` and `metric_set()` output
# against these.
class_res_1 <- dplyr::bind_rows(
  list(
    accuracy(two_class_example, truth, predicted),
    kap(two_class_example, truth, predicted),
    mn_log_loss(two_class_example, truth, Class1),
    roc_auc(two_class_example, truth, Class1)
  )
)

# Note: the `rmse()` call names the estimate column with a string while the
# other two use a bare column name.
reg_res_1 <- dplyr::bind_rows(
  list(
    rmse(solubility_test, solubility, "prediction"),
    rsq(solubility_test, solubility, prediction),
    mae(solubility_test, solubility, prediction)
  )
)
# Compares the estimates returned by `metrics()` with the per-metric reference
# values stored in `class_res_1` and `reg_res_1`.
test_that("correct results", {
  class_estimates <- class_res_1[[".estimate"]]
  # Rows of the reference that belong to the hard-class metrics
  is_hard_class <- class_res_1$.metric %in% c("accuracy", "kap")

  hard_only <- metrics(two_class_example, truth, predicted)
  expect_equal(hard_only[[".estimate"]], class_estimates[is_hard_class])

  with_prob <- metrics(two_class_example, truth, predicted, Class1)
  expect_equal(with_prob[[".estimate"]], class_estimates)

  numeric_res <- metrics(solubility_test, solubility, prediction)
  expect_equal(numeric_res[[".estimate"]], reg_res_1[[".estimate"]])
})
###################################################################
# Calls `metrics()` with an `options` argument and records the resulting
# output in a snapshot, then checks that the returned value is identical to a
# call made without `options`.
test_that("metrics() - `options` is deprecated", {
skip_if(getRversion() <= "3.5.3", "Base R used a different deprecated warning class.")
# Set lifecycle verbosity to "warning" for the duration of this test only.
rlang::local_options(lifecycle_verbosity = "warning")
# `out` is assigned inside the snapshot expression and reused below.
expect_snapshot({
out <- metrics(two_class_example, truth, predicted, Class1, options = 1)
})
expect_identical(
out,
metrics(two_class_example, truth, predicted, Class1)
)
})
###################################################################
# A metric set built from numeric metrics must reproduce the individually
# computed reference rows; the test also records the error for a non-function
# input and exercises one mixed class / class-probability set.
test_that("numeric metric sets", {
  reg_set <- metric_set(rmse, rsq, mae)
  reg_out <- reg_set(solubility_test, solubility, prediction)
  expect_equal(reg_out, reg_res_1)

  # ensure helpful messages are printed
  expect_snapshot(
    error = TRUE,
    metric_set(rmse, "x")
  )

  # Can mix class and class prob together
  mixed_set <- metric_set(accuracy, roc_auc)
  expect_no_error({
    mixed_set(two_class_example, truth, Class1, estimate = predicted)
  })
})
# Combines `rmse` and `accuracy` in a single `metric_set()` call; the call is
# run with `error = TRUE` and its error output is recorded in the snapshot.
test_that("mixing bad metric sets", {
expect_snapshot(
error = TRUE,
metric_set(rmse, accuracy)
)
})
# Building and then calling a metric set that combines `accuracy` and
# `roc_auc` must not raise an error.
test_that("can mix class and class prob metrics together", {
  # The assignment happens inside the expectation so that a construction
  # failure is reported as a failed expectation.
  expect_no_error({
    mixed_set <- metric_set(accuracy, roc_auc)
  })

  expect_no_error({
    mixed_set(two_class_example, truth, Class1, estimate = predicted)
  })
})
# A metric set holding only `brier_survival` must return the same result as
# calling `brier_survival()` directly on the same columns.
test_that("dynamic survival metric sets", {
  my_set <- metric_set(brier_survival)

  from_set <- my_set(lung_surv, surv_obj, .pred)
  direct <- brier_survival(lung_surv, surv_obj, .pred)

  expect_equal(from_set, direct)
})
# Building and then calling a metric set that combines `brier_survival` and
# `concordance_survival` must not raise an error.
test_that("can mix dynamic and static survival metric together", {
  # The assignment happens inside the expectation so that a construction
  # failure is reported as a failed expectation.
  expect_no_error({
    mixed_set <- metric_set(brier_survival, concordance_survival)
  })

  expect_no_error({
    mixed_set(lung_surv, surv_obj, .pred, estimate = .pred_time)
  })
})
# Building and then calling a metric set that combines `brier_survival`,
# `concordance_survival` and `brier_survival_integrated` must not raise an
# error.
#
# The description previously duplicated the one of the two-metric test above
# and did not mention the integrated metric, so a failure report could not be
# attributed to one test. This block has no snapshots, so renaming it does not
# affect any stored snapshot.
test_that("can mix dynamic, static, and integrated survival metrics together", {
  # The assignment happens inside the expectation so that a construction
  # failure is reported as a failed expectation.
  expect_no_error(
    mixed_set <- metric_set(
      brier_survival,
      concordance_survival,
      brier_survival_integrated
    )
  )

  expect_no_error(
    mixed_set(lung_surv, surv_obj, .pred, estimate = .pred_time)
  )
})
# Releveling both factor columns so that "Class2" comes first and passing
# `event_level = "second"` must give the same results as the original data
# with the default event level.
test_that("can supply `event_level` even with metrics that don't use it", {
  releveled <- two_class_example
  for (col in c("truth", "predicted")) {
    releveled[[col]] <- stats::relevel(releveled[[col]], "Class2")
  }

  # accuracy doesn't use it, and doesn't have it as an argument
  mixed <- metric_set(accuracy, recall, roc_auc)

  default_level <- mixed(two_class_example, truth, Class1, estimate = predicted)
  second_level <- mixed(
    releveled,
    truth,
    Class1,
    estimate = predicted,
    event_level = "second"
  )

  expect_equal(as.data.frame(default_level), as.data.frame(second_level))
})
# Metric sets carry both a type-specific class and the general "metric_set"
# class.
test_that("metric set functions are classed", {
  class_prob_set <- metric_set(accuracy, roc_auc)
  numeric_set <- metric_set(mae)

  # Type-specific classes
  expect_s3_class(class_prob_set, "class_prob_metric_set")
  expect_s3_class(numeric_set, "numeric_metric_set")

  # Shared class
  expect_s3_class(class_prob_set, "metric_set")
  expect_s3_class(numeric_set, "metric_set")
})
# The "metrics" attribute of a class/prob metric set keeps the original metric
# functions, with their names, classes and `direction` attributes.
test_that("metric set functions retain class/prob metric functions", {
  fns <- attr(metric_set(accuracy, roc_auc), "metrics")

  expect_equal(names(fns), c("accuracy", "roc_auc"))

  first_class <- class(fns[[1]])
  second_class <- class(fns[[2]])
  expect_equal(first_class, c("class_metric", "metric", "function"))
  expect_equal(second_class, c("prob_metric", "metric", "function"))

  directions <- vapply(fns, attr, character(1), which = "direction")
  expect_equal(directions, c(accuracy = "maximize", roc_auc = "maximize"))
})
# The "metrics" attribute of a numeric metric set keeps the original metric
# functions, with their names, classes and `direction` attributes.
test_that("metric set functions retain numeric metric functions", {
  fns <- attr(metric_set(mae, rmse), "metrics")
  numeric_metric_class <- c("numeric_metric", "metric", "function")

  expect_equal(names(fns), c("mae", "rmse"))

  expect_equal(class(fns[[1]]), numeric_metric_class)
  expect_equal(class(fns[[2]]), numeric_metric_class)

  directions <- vapply(fns, attr, character(1), which = "direction")
  expect_equal(directions, c(mae = "minimize", rmse = "minimize"))
})
# A metric passed with a `pkg::` prefix is stored under its bare name.
test_that("`metric_set()` labeling remove namespaces", {
  x <- metric_set(yardstick::mase, rmse)
  stored_names <- names(attr(x, "metrics"))

  expect_identical(stored_names, c("mase", "rmse"))
})
# Records the output produced by evaluating a three-metric numeric metric set
# in a snapshot.
test_that("print metric_set works", {
expect_snapshot(metric_set(rmse, rsq, ccc))
})
# `dplyr::as_tibble()` on a metric set must return a "tbl_df" object.
test_that("metric_set can be coerced to a tibble", {
  x <- metric_set(roc_auc, pr_auc, accuracy)
  as_tbl <- dplyr::as_tibble(x)

  expect_s3_class(as_tbl, "tbl_df")
})
# Passes two non-metric functions (`foobar`, whose environment is given the
# name "test", and `rlang::abort`) to `metric_set()` alongside real metrics and
# records the resulting error in a snapshot.
test_that("`metric_set()` errors contain env name for unknown functions (#128)", {
foobar <- function() {}
# Store env name in `name` attribute for `environmentName()` to find it
env <- rlang::new_environment(parent = globalenv())
attr(env, "name") <- "test"
rlang::fn_env(foobar) <- env
expect_snapshot(
error = TRUE,
metric_set(accuracy, foobar, sens, rlang::abort)
)
# NOTE(review): this expectation is identical to the one above, so the same
# snapshot is recorded twice. Removing it would also require updating the
# stored snapshot file - confirm before deduplicating.
expect_snapshot(
error = TRUE,
metric_set(accuracy, foobar, sens, rlang::abort)
)
})
# Passes a single non-metric function, whose environment is given the name
# "test", to `metric_set()` and records the resulting error in a snapshot.
test_that("`metric_set()` gives an informative error for a single non-metric function (#181)", {
foobar <- function() {}
# Store env name in `name` attribute for `environmentName()` to find it
env <- rlang::new_environment(parent = globalenv())
attr(env, "name") <- "test"
rlang::fn_env(foobar) <- env
expect_snapshot(
error = TRUE,
metric_set(foobar)
)
})
# `demographic_parity` and `equal_opportunity` are handed to `metric_set()`
# without being called first. Each case is run with `error = TRUE` and its
# error output is recorded in a snapshot.
test_that("errors informatively for unevaluated metric factories", {
# one bad metric
expect_snapshot(
error = TRUE,
metric_set(demographic_parity)
)
# one bad metric next to `roc_auc`
expect_snapshot(
error = TRUE,
metric_set(demographic_parity, roc_auc)
)
# two bad metrics
expect_snapshot(
error = TRUE,
metric_set(demographic_parity, equal_opportunity)
)
# two bad metrics next to `roc_auc`
expect_snapshot(
error = TRUE,
metric_set(demographic_parity, equal_opportunity, roc_auc)
)
})
# A metric set may combine a metric that takes `case_weights` with one that
# does not: the first estimate is weighted, the second is not.
test_that("all class metrics - `metric_set()` works with `case_weights`", {
  # Mock a metric that doesn't support weights
  accuracy_no_weights <- new_class_metric(
    function(data, truth, estimate, na_rm = TRUE, ...) {
      # Eat the `...` silently
      accuracy(
        data = data,
        truth = !!enquo(truth),
        estimate = !!enquo(estimate),
        na_rm = na_rm
      )
    },
    "maximize"
  )

  set <- metric_set(accuracy, accuracy_no_weights)

  # Only the first row is predicted correctly; the weights are 1, 1 and 2.
  df <- data.frame(
    truth = factor(c("x", "x", "y"), levels = c("x", "y")),
    estimate = factor(c("x", "y", "x"), levels = c("x", "y")),
    case_weights = c(1L, 1L, 2L)
  )

  res <- set(df, truth, estimate = estimate, case_weights = case_weights)

  expect_identical(res[[".estimate"]], c(1 / 4, 1 / 3))
})
# A numeric metric set may combine a metric that takes `case_weights` with one
# that does not: the first estimate is weighted, the second is not.
test_that("all numeric metrics - `metric_set()` works with `case_weights`", {
  # Mock a metric that doesn't support weights
  rmse_no_weights <- new_numeric_metric(
    function(data, truth, estimate, na_rm = TRUE, ...) {
      # Eat the `...` silently
      rmse(
        data = data,
        truth = !!enquo(truth),
        estimate = !!enquo(estimate),
        na_rm = na_rm
      )
    },
    "minimize"
  )

  set <- metric_set(rmse, rmse_no_weights)

  # Adds the weight column to a copy local to this test
  solubility_test$weight <- read_weights_solubility_test()

  weighted <- rmse(solubility_test, solubility, prediction, case_weights = weight)
  unweighted <- rmse(solubility_test, solubility, prediction)
  expected <- c(weighted[[".estimate"]], unweighted[[".estimate"]])

  res <- set(solubility_test, solubility, prediction, case_weights = weight)

  expect_identical(res[[".estimate"]], expected)
})
# A mixed class / class-probability metric set passes `case_weights` to the
# metrics that take them (`accuracy`, `roc_auc`) and still runs the one that
# does not.
test_that("class and prob metrics - `metric_set()` works with `case_weights`", {
  # Mock a metric that doesn't support weights
  accuracy_no_weights <- new_class_metric(
    function(data, truth, estimate, na_rm = TRUE, ...) {
      # Eat the `...` silently
      accuracy(
        data = data,
        truth = !!enquo(truth),
        estimate = !!enquo(estimate),
        na_rm = na_rm
      )
    },
    "maximize"
  )

  set <- metric_set(accuracy, accuracy_no_weights, roc_auc)

  # Adds the weight column to a copy local to this test
  two_class_example$weight <- read_weights_two_class_example()

  weighted_acc <- accuracy(two_class_example, truth, predicted, case_weights = weight)
  plain_acc <- accuracy(two_class_example, truth, predicted)
  weighted_auc <- roc_auc(two_class_example, truth, Class1, case_weights = weight)
  expected <- c(
    weighted_acc[[".estimate"]],
    plain_acc[[".estimate"]],
    weighted_auc[[".estimate"]]
  )

  res <- set(
    two_class_example,
    truth,
    Class1,
    estimate = predicted,
    case_weights = weight
  )

  expect_identical(res[[".estimate"]], expected)
})
# Calls a metric set with `case_weights` pointing at a column that is not in
# the data and records the resulting error, including its message chain, in a
# snapshot.
test_that("propagates 'caused by' error message when specifying the wrong column name", {
set <- metric_set(accuracy, kap)
# There is no `weight` column!
expect_snapshot(error = TRUE, {
set(two_class_example, truth, Class1, estimate = predicted, case_weights = weight)
})
})
# For each metric type, two `metric_tweak()` variants of the same metric are
# combined in one metric set; the result must be identical to row-binding the
# two direct calls made with the corresponding argument values.
test_that("metric_tweak and metric_set plays nicely together (#351)", {
  # Classification
  multi_ex <- data_three_by_three()

  j_index_macro <- metric_tweak("j_index", j_index, estimator = "macro")
  j_index_micro <- metric_tweak("j_index", j_index, estimator = "micro")
  class_set <- metric_set(j_index_macro, j_index_micro)

  class_ref <- dplyr::bind_rows(
    j_index(multi_ex, estimator = "macro"),
    j_index(multi_ex, estimator = "micro")
  )

  expect_identical(class_set(multi_ex), class_ref)

  # Probability
  roc_auc_first <- metric_tweak("roc_auc", roc_auc, event_level = "first")
  roc_auc_second <- metric_tweak("roc_auc", roc_auc, event_level = "second")
  prob_set <- metric_set(roc_auc_first, roc_auc_second)

  prob_ref <- dplyr::bind_rows(
    roc_auc(two_class_example, truth, Class1, event_level = "first"),
    roc_auc(two_class_example, truth, Class1, event_level = "second")
  )

  expect_identical(prob_set(two_class_example, truth, Class1), prob_ref)

  # regression
  ccc_bias <- metric_tweak("ccc", ccc, bias = TRUE)
  ccc_no_bias <- metric_tweak("ccc", ccc, bias = FALSE)
  reg_set <- metric_set(ccc_bias, ccc_no_bias)

  reg_ref <- dplyr::bind_rows(
    ccc(mtcars, truth = mpg, estimate = disp, bias = TRUE),
    ccc(mtcars, truth = mpg, estimate = disp, bias = FALSE)
  )

  expect_identical(reg_set(mtcars, truth = mpg, estimate = disp), reg_ref)

  # Static survival: a missing predicted time makes `na_rm` matter
  static_na <- lung_surv
  static_na$.pred_time[1] <- NA

  concordance_survival_na_rm <- metric_tweak(
    "concordance_survival",
    concordance_survival,
    na_rm = TRUE
  )
  concordance_survival_no_na_rm <- metric_tweak(
    "concordance_survival",
    concordance_survival,
    na_rm = FALSE
  )
  static_set <- metric_set(
    concordance_survival_na_rm,
    concordance_survival_no_na_rm
  )

  static_ref <- dplyr::bind_rows(
    concordance_survival(static_na, surv_obj, .pred_time, na_rm = TRUE),
    concordance_survival(static_na, surv_obj, .pred_time, na_rm = FALSE)
  )

  expect_identical(
    static_set(static_na, truth = surv_obj, estimate = .pred_time),
    static_ref
  )

  # dynamic survival: here the missing value is in the truth column
  dynamic_na <- lung_surv
  dynamic_na$surv_obj[1] <- NA

  brier_survival_na_rm <- metric_tweak(
    "brier_survival",
    brier_survival,
    na_rm = TRUE
  )
  brier_survival_no_na_rm <- metric_tweak(
    "brier_survival",
    brier_survival,
    na_rm = FALSE
  )
  dynamic_set <- metric_set(brier_survival_na_rm, brier_survival_no_na_rm)

  dynamic_ref <- dplyr::bind_rows(
    brier_survival(dynamic_na, surv_obj, .pred, na_rm = TRUE),
    brier_survival(dynamic_na, surv_obj, .pred, na_rm = FALSE)
  )

  expect_identical(
    dynamic_set(dynamic_na, truth = surv_obj, .pred),
    dynamic_ref
  )
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.