Nothing
test_that("binary roc curve uses equivalent of pROC `direction = <`", {
# In yardstick we do events (or cases) as the first level
truth <- factor(c("control", "case", "case"), levels = c("case", "control"))
# Make really bad predictions
# This would force `direction = "auto"` to choose `>`,
# which would be incorrect. We are required to force `direction = <` for
# our purposes of having `estimate` match the event
estimate <- c(.8, .2, .1)
# # pROC expects levels to be in the order of control, then event.
# roc <- pROC::roc(
# truth,
# estimate,
# levels = c("control", "case"),
# direction = "<"
# )
# expect_specificity <- c(0, roc$specificities)
# expect_sensitivity <- c(1, roc$sensitivities)
expect_specificity <- c(0, 0, 0, 0, 1)
expect_sensitivity <- c(1, 1, 0.5, 0, 0)
curve <- roc_curve_vec(truth, estimate)
expect_identical(curve$specificity, expect_specificity)
expect_identical(curve$sensitivity, expect_sensitivity)
})
test_that("ROC Curve", {
pROC_two_class_example_curve <- data_pROC_two_class_example_curve()
# Equal to pROC up to a reasonable tolerance
expect_equal(
roc_curve(two_class_example, truth, Class1),
pROC_two_class_example_curve
)
})
test_that("Multiclass ROC Curve", {
# HPC_CV takes too long
hpc_cv2 <- dplyr::filter(hpc_cv, Resample %in% c("Fold06", "Fold07", "Fold08", "Fold09", "Fold10"))
res <- roc_curve(hpc_cv2, obs, VF:L)
# structural tests
expect_equal(colnames(res), c(".level", ".threshold", "specificity", "sensitivity"))
expect_equal(unique(res$.level), levels(hpc_cv2$obs))
res_g <- roc_curve(dplyr::group_by(hpc_cv2, Resample), obs, VF:L)
# structural tests
expect_equal(colnames(res_g), c("Resample", ".level", ".threshold", "specificity", "sensitivity"))
})
# ------------------------------------------------------------------------------
test_that("grouped multiclass (one-vs-all) weighted example matches expanded equivalent", {
hpc_cv$weight <- rep(1, times = nrow(hpc_cv))
hpc_cv$weight[c(100, 200, 150, 2)] <- 5
hpc_cv <- dplyr::group_by(hpc_cv, Resample)
hpc_cv_expanded <- hpc_cv[vec_rep_each(seq_len(nrow(hpc_cv)), times = hpc_cv$weight), ]
expect_identical(
roc_curve(hpc_cv, obs, VF:L, case_weights = weight),
roc_curve(hpc_cv_expanded, obs, VF:L)
)
})
test_that("can use hardhat case weights", {
two_class_example$weight <- read_weights_two_class_example()
curve1 <- roc_curve(two_class_example, truth, Class1, case_weights = weight)
two_class_example$weight <- hardhat::importance_weights(two_class_example$weight)
curve2 <- roc_curve(two_class_example, truth, Class1, case_weights = weight)
expect_identical(curve1, curve2)
})
# ------------------------------------------------------------------------------
test_that("zero weights don't affect the curve", {
# If they weren't removed, we'd get a `NaN` from a division by zero issue
df <- dplyr::tibble(
truth = factor(c("b", "a", "b", "a", "a"), levels = c("a", "b")),
a = c(.75, .7, .4, .9, .8),
weight = c(0, 1, 3, 0, 5)
)
expect_identical(
roc_curve(df, truth, a, case_weights = weight),
roc_curve(df[df$weight != 0, ], truth, a, case_weights = weight)
)
})
# ------------------------------------------------------------------------------
test_that("Binary results are the same as scikit-learn", {
curve <- roc_curve(two_class_example, truth, Class1)
expect_identical(
curve,
read_pydata("py-roc-curve")$binary
)
})
test_that("Binary weighted results are the same as scikit-learn", {
two_class_example$weight <- read_weights_two_class_example()
curve <- roc_curve(two_class_example, truth, Class1, case_weights = weight)
expect_identical(
curve,
read_pydata("py-roc-curve")$case_weight$binary
)
})
# ------------------------------------------------------------------------------
test_that("roc_curve() - error is thrown when missing events", {
no_event <- dplyr::filter(two_class_example, truth == "Class2")
expect_snapshot(
error = TRUE,
roc_curve_vec(no_event$truth, no_event$Class1)[[".estimate"]]
)
})
test_that("roc_curve() - error is thrown when missing controls", {
no_control <- dplyr::filter(two_class_example, truth == "Class1")
expect_snapshot(
error = TRUE,
roc_curve_vec(no_control$truth, no_control$Class1)[[".estimate"]]
)
})
test_that("roc_curve() - multiclass one-vs-all approach results in error", {
no_event <- dplyr::filter(hpc_cv, Resample == "Fold01", obs == "VF")
expect_snapshot(
error = TRUE,
roc_curve_vec(no_event$obs, as.matrix(dplyr::select(no_event, VF:L)))[[".estimate"]]
)
})
test_that("roc_curve() - `options` is deprecated", {
skip_if(getRversion() <= "3.5.3", "Base R used a different deprecated warning class.")
rlang::local_options(lifecycle_verbosity = "warning")
expect_snapshot({
out <- roc_curve(two_class_example, truth, Class1, options = 1)
})
expect_identical(
out,
roc_curve(two_class_example, truth, Class1)
)
})
test_that("errors with class_pred input", {
skip_if_not_installed("probably")
cp_truth <- probably::as_class_pred(two_class_example$truth, which = 1)
fct_truth <- two_class_example$truth
fct_truth[1] <- NA
estimate <- two_class_example$Class1
expect_snapshot(
error = TRUE,
roc_curve_vec(cp_truth, estimate)
)
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.