test_that("Binary breaks functions work", {
x10 <- .cal_table_breaks(segment_logistic, Class, .pred_good, event_level = "first")
expect_equal(
x10$predicted_midpoint,
seq(0.05, 0.95, by = 0.10)
)
expect_s3_class(
cal_plot_breaks(segment_logistic, Class, .pred_good),
"ggplot"
)
x11 <- .cal_table_breaks(testthat_cal_binary())
expect_equal(
x11$predicted_midpoint,
rep(seq(0.05, 0.95, by = 0.10), times = 8)
)
expect_s3_class(
cal_plot_breaks(testthat_cal_binary()),
"ggplot"
)
brks_configs <-
bin_with_configs() %>% cal_plot_breaks(truth = Class, estimate = .pred_good)
expect_true(has_facet(brks_configs))
})
test_that("Binary breaks functions work with group argument", {
res <- segment_logistic %>%
dplyr::mutate(id = dplyr::row_number() %% 2) %>%
cal_plot_breaks(Class, .pred_good, .by = id)
expect_s3_class(res, "ggplot")
expect_equal(
res$data[0,],
dplyr::tibble(
id = factor(0, levels = paste(0:1)),
predicted_midpoint = double(), event_rate = double(), events = double(),
total = integer(), lower = double(), upper = double()
)
)
expect_equal(
rlang::expr_text(res$mapping$x),
"~predicted_midpoint"
)
expect_equal(
rlang::expr_text(res$mapping$colour),
"~id"
)
expect_equal(
rlang::expr_text(res$mapping$fill),
"~id"
)
expect_equal(
res$labels,
list(x = "Bin Midpoint", y = "Event Rate", colour = "id", fill = "id",
intercept = "intercept", slope = "slope", ymin = "lower",
ymax = "upper")
)
expect_equal(length(res$layers), 4)
expect_snapshot_error(
segment_logistic %>%
dplyr::mutate(group1 = 1, group2 = 2) %>%
cal_plot_breaks(Class, .pred_good, .by = c(group1, group2))
)
})
test_that("Multi-class breaks functions work", {
skip_if_not_installed("modeldata")
x10 <- .cal_table_breaks(species_probs, Species, dplyr::starts_with(".pred"))
expect_equal(
x10$predicted_midpoint,
rep(seq(0.05, 0.95, by = 0.10), times = 3)
)
expect_s3_class(
cal_plot_breaks(species_probs, Species),
"ggplot"
)
x11 <- .cal_table_breaks(testthat_cal_multiclass())
expect_equal(
sort(unique(x11$predicted_midpoint)),
seq(0.05, 0.95, by = 0.10)
)
multi_configs <- cal_plot_breaks(testthat_cal_multiclass())
# should be faceted by .config and class
expect_s3_class(multi_configs, "ggplot")
expect_true(inherits(multi_configs$facet, "FacetGrid"))
expect_error(
cal_plot_breaks(species_probs, Species, event_level = "second")
)
# ------------------------------------------------------------------------------
# multinomial outcome, binary logistic plots
multi_configs_from_tune <-
testthat_cal_multiclass() %>% cal_plot_breaks()
expect_s3_class(multi_configs_from_tune, "ggplot")
# should be faceted by .config and class
expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid"))
multi_configs_from_df <-
mnl_with_configs() %>% cal_plot_breaks(truth = obs, estimate = c(VF:L))
expect_s3_class(multi_configs_from_df, "ggplot")
# should be faceted by .config and class
expect_true(inherits(multi_configs_from_df$facet, "FacetGrid"))
})
test_that("breaks plot function errors - grouped_df", {
expect_snapshot_error(
cal_plot_breaks(dplyr::group_by(mtcars, vs))
)
})
test_that("Binary logistic functions work", {
skip_if_not_installed("modeldata")
x20 <- .cal_table_logistic(segment_logistic, Class, .pred_good)
model20 <- mgcv::gam(Class ~ s(.pred_good, k = 10),
data = segment_logistic,
family = binomial()
)
preds20 <- predict(model20,
data.frame(.pred_good = seq(0, 1, by = .01)),
type = "response"
)
expect_equal(sd(x20$prob), sd(preds20), tolerance = 0.000001)
expect_equal(mean(x20$prob), mean(1 - preds20), tolerance = 0.000001)
x21 <- cal_plot_logistic(segment_logistic, Class, .pred_good)
expect_s3_class(x21, "ggplot")
expect_false(has_facet(x21))
x22 <- .cal_table_logistic(testthat_cal_binary())
x22_1 <- testthat_cal_binary() %>%
tune::collect_predictions(summarize = TRUE) %>%
dplyr::group_by(.config) %>%
dplyr::group_map(~ {
model <- mgcv::gam(
class ~ s(.pred_class_1, k = 10),
data = .x,
family = binomial()
)
preds <- predict(model,
data.frame(.pred_class_1 = seq(0, 1, by = .01)),
type = "response"
)
1 - preds
}) %>%
purrr::reduce(c)
expect_equal(sd(x22$prob), sd(x22_1), tolerance = 0.000001)
expect_equal(mean(x22$prob), mean(x22_1), tolerance = 0.000001)
x23 <- cal_plot_logistic(testthat_cal_binary())
expect_s3_class(x23, "ggplot")
expect_true(has_facet(x23))
x24 <- .cal_table_logistic(segment_logistic, Class, .pred_good, smooth = FALSE)
model24 <- stats::glm(Class ~ .pred_good, data = segment_logistic, family = binomial())
preds24 <- predict(model24,
data.frame(.pred_good = seq(0, 1, by = .01)),
type = "response"
)
expect_equal(sd(x24$prob), sd(preds24), tolerance = 0.000001)
expect_equal(mean(x24$prob), mean(1 - preds24), tolerance = 0.000001)
x25 <- .cal_table_logistic(
segment_logistic,
Class,
.pred_poor,
event_level = "second"
)
expect_equal(
which(x25$prob == max(x25$prob)),
nrow(x25)
)
lgst_configs <-
bin_with_configs() %>% cal_plot_logistic(truth = Class, estimate = .pred_good)
expect_true(has_facet(lgst_configs))
# ------------------------------------------------------------------------------
# multinomial outcome, binary logistic plots
multi_configs_from_tune <-
testthat_cal_multiclass() %>% cal_plot_logistic(smooth = FALSE)
expect_s3_class(multi_configs_from_tune, "ggplot")
# should be faceted by .config and class
expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid"))
multi_configs_from_df <-
mnl_with_configs() %>% cal_plot_logistic(truth = obs, estimate = c(VF:L))
expect_s3_class(multi_configs_from_df, "ggplot")
# should be faceted by .config and class
expect_true(inherits(multi_configs_from_df$facet, "FacetGrid"))
})
test_that("Binary logistic functions work with group argument", {
res <- segment_logistic %>%
dplyr::mutate(id = dplyr::row_number() %% 2) %>%
cal_plot_logistic(Class, .pred_good, .by = id)
expect_s3_class(
res,
"ggplot"
)
expect_true(has_facet(res))
expect_s3_class(res, "ggplot")
expect_equal(
res$data[0,],
dplyr::tibble(
id = factor(0, levels = paste(0:1)),
estimate = double(), prob = double(), lower = double(), upper = double()
)
)
expect_equal(
rlang::expr_text(res$mapping$x),
"~estimate"
)
expect_equal(
rlang::expr_text(res$mapping$colour),
"~id"
)
expect_equal(
rlang::expr_text(res$mapping$fill),
"~id"
)
expect_equal(
res$labels,
list(x = "Probability", y = "Predicted Event Rate", colour = "id",
fill = "id", intercept = "intercept", slope = "slope", ymin = "lower",
ymax = "upper")
)
expect_equal(length(res$layers), 3)
expect_snapshot_error(
segment_logistic %>%
dplyr::mutate(group1 = 1, group2 = 2) %>%
cal_plot_logistic(Class, .pred_good, .by = c(group1, group2))
)
lgst_configs <-
bin_with_configs() %>% cal_plot_logistic(truth = Class, estimate = .pred_good)
expect_true(has_facet(lgst_configs))
})
test_that("logistic plot function errors - grouped_df", {
expect_snapshot_error(
cal_plot_logistic(dplyr::group_by(mtcars, vs))
)
})
test_that("Binary windowed functions work", {
skip_if_not_installed("modeldata")
x30 <- .cal_table_windowed(
segment_logistic,
truth = Class,
estimate = .pred_good,
step_size = 0.11,
window_size = 0.10
)
x30_1 <- segment_logistic %>%
dplyr::mutate(x = dplyr::case_when(
.pred_good <= 0.05 ~ 1,
.pred_good >= 0.06 & .pred_good <= 0.16 ~ 2,
.pred_good >= 0.17 & .pred_good <= 0.27 ~ 3,
.pred_good >= 0.28 & .pred_good <= 0.38 ~ 4,
.pred_good >= 0.39 & .pred_good <= 0.49 ~ 5,
.pred_good >= 0.50 & .pred_good <= 0.60 ~ 6,
.pred_good >= 0.61 & .pred_good <= 0.71 ~ 7,
.pred_good >= 0.72 & .pred_good <= 0.82 ~ 8,
.pred_good >= 0.83 & .pred_good <= 0.93 ~ 9,
.pred_good >= 0.94 & .pred_good <= 1 ~ 10,
)) %>%
dplyr::filter(!is.na(x)) %>%
dplyr::count(x)
expect_equal(
x30$total,
x30_1$n
)
x31 <- cal_plot_windowed(segment_logistic, Class, .pred_good)
expect_s3_class(x31, "ggplot")
expect_false(has_facet(x31))
x32 <- .cal_table_windowed(
testthat_cal_binary(),
step_size = 0.11,
window_size = 0.10
)
x32_1 <- testthat_cal_binary() %>%
tune::collect_predictions(summarize = TRUE) %>%
dplyr::mutate(x = dplyr::case_when(
.pred_class_1 <= 0.05 ~ 1,
.pred_class_1 >= 0.06 & .pred_class_1 <= 0.16 ~ 2,
.pred_class_1 >= 0.17 & .pred_class_1 <= 0.27 ~ 3,
.pred_class_1 >= 0.28 & .pred_class_1 <= 0.38 ~ 4,
.pred_class_1 >= 0.39 & .pred_class_1 <= 0.49 ~ 5,
.pred_class_1 >= 0.50 & .pred_class_1 <= 0.60 ~ 6,
.pred_class_1 >= 0.61 & .pred_class_1 <= 0.71 ~ 7,
.pred_class_1 >= 0.72 & .pred_class_1 <= 0.82 ~ 8,
.pred_class_1 >= 0.83 & .pred_class_1 <= 0.93 ~ 9,
.pred_class_1 >= 0.94 & .pred_class_1 <= 1 ~ 10,
)) %>%
dplyr::filter(!is.na(x)) %>%
dplyr::count(.config, x)
expect_equal(
x32$total,
x32_1$n
)
x33 <- cal_plot_windowed(testthat_cal_binary())
expect_s3_class(x33, "ggplot")
expect_true(has_facet(x33))
win_configs <-
bin_with_configs() %>% cal_plot_windowed(truth = Class, estimate = .pred_good)
expect_true(has_facet(win_configs))
# ------------------------------------------------------------------------------
# multinomial outcome, binary windowed plots
multi_configs_from_tune <-
testthat_cal_multiclass() %>% cal_plot_windowed()
expect_s3_class(multi_configs_from_tune, "ggplot")
# should be faceted by .config and class
expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid"))
multi_configs_from_df <-
mnl_with_configs() %>% cal_plot_windowed(truth = obs, estimate = c(VF:L))
expect_s3_class(multi_configs_from_df, "ggplot")
# should be faceted by .config and class
expect_true(inherits(multi_configs_from_df$facet, "FacetGrid"))
})
test_that("windowed plot function errors - grouped_df", {
expect_snapshot_error(
cal_plot_windowed(dplyr::group_by(mtcars, vs))
)
})
test_that("Event level handling works", {
x7 <- .cal_table_breaks(segment_logistic, Class, .pred_good, event_level = "second")
expect_equal(
which(x7$predicted_midpoint == min(x7$predicted_midpoint)),
which(x7$event_rate == max(x7$event_rate))
)
expect_snapshot_error(
.cal_table_breaks(segment_logistic, Class, .pred_good, event_level = "invalid")
)
})
test_that("Groups are respected", {
preds <- segment_logistic %>%
dplyr::mutate(source = "logistic") %>%
dplyr::bind_rows(segment_naive_bayes) %>%
dplyr::mutate(source = ifelse(is.na(source), "nb", source)) %>%
dplyr::group_by(source)
x40 <- .cal_table_breaks(preds, Class, .pred_good)
expect_equal(as.integer(table(x40$source)), c(10, 10))
expect_equal(unique(x40$source), c("logistic", "nb"))
x41 <- .cal_table_logistic(preds, Class, .pred_good)
expect_equal(as.integer(table(x41$source)), c(101, 101))
expect_equal(unique(x41$source), c("logistic", "nb"))
x42 <- .cal_table_windowed(preds, Class, .pred_good)
expect_equal(as.integer(table(x42$source)), c(21, 21))
expect_equal(unique(x42$source), c("logistic", "nb"))
})
test_that("Groupings that may not match work", {
model <- glm(Class ~ .pred_good, segment_logistic, family = "binomial")
preds <- 1 - predict(model, segment_logistic, type = "response")
combined <- dplyr::bind_rows(
dplyr::mutate(segment_logistic, source = "original"),
dplyr::mutate(segment_logistic, .pred_good = preds, source = "glm")
)
x50 <- combined %>%
dplyr::group_by(source) %>%
.cal_table_breaks(Class, .pred_good)
expect_equal(
unique(x50$predicted_midpoint),
seq(0.05, 0.95, by = 0.10)
)
x51 <- combined %>%
dplyr::group_by(source) %>%
.cal_table_windowed(
truth = Class,
estimate = .pred_good,
step_size = 0.11,
window_size = 0.10
)
x51_1 <- combined %>%
dplyr::mutate(x = dplyr::case_when(
.pred_good <= 0.05 ~ 1,
.pred_good >= 0.06 & .pred_good <= 0.16 ~ 2,
.pred_good >= 0.17 & .pred_good <= 0.27 ~ 3,
.pred_good >= 0.28 & .pred_good <= 0.38 ~ 4,
.pred_good >= 0.39 & .pred_good <= 0.49 ~ 5,
.pred_good >= 0.50 & .pred_good <= 0.60 ~ 6,
.pred_good >= 0.61 & .pred_good <= 0.71 ~ 7,
.pred_good >= 0.72 & .pred_good <= 0.82 ~ 8,
.pred_good >= 0.83 & .pred_good <= 0.93 ~ 9,
.pred_good >= 0.94 & .pred_good <= 1 ~ 10,
)) %>%
dplyr::filter(!is.na(x)) %>%
dplyr::count(source, x)
expect_equal(
x51$total,
x51_1$n
)
})
test_that("Numeric groups are supported", {
grp_df <- segment_logistic
grp_df$num_group <- rep(c(1, 2), times = 505)
p <- grp_df %>%
cal_plot_breaks(Class, .pred_good, .by = num_group)
expect_s3_class(p, "ggplot")
})
test_that("Some general exceptions", {
expect_error(
.cal_table_breaks(tune::ames_grid_search),
"The `tune_results` object does not contain columns with predictions"
)
expect_warning(
cal_plot_breaks(segment_logistic, Class),
)
})
# ------------------------------------------------------------------------------
test_that("regression functions work", {
skip_if(R.version[["arch"]] != "aarch64") # see note below
obj <- testthat_cal_reg()
res <- cal_plot_regression(boosting_predictions_oob, outcome, .pred)
expect_s3_class(res, "ggplot")
expect_equal(
res$data[0,],
dplyr::tibble(outcome = numeric(0), .pred = numeric(0), id = character(0))
)
expect_equal(
rlang::expr_text(res$mapping$x),
"~outcome"
)
expect_equal(
rlang::expr_text(res$mapping$y),
"~.pred"
)
expect_null(res$mapping$colour)
expect_null(res$mapping$fill)
expect_equal(
res$labels,
list(x = "Observed", y = "Predicted", colour = "colour", fill = "fill",
intercept = "intercept", slope = "slope")
)
expect_equal(length(res$layers), 3)
res <- cal_plot_regression(boosting_predictions_oob, outcome, .pred, .by = id)
expect_s3_class(res, "ggplot")
expect_equal(
res$data[0,],
dplyr::tibble(outcome = numeric(0), .pred = numeric(0), id = character(0))
)
expect_equal(
rlang::expr_text(res$mapping$x),
"~outcome"
)
expect_equal(
rlang::expr_text(res$mapping$y),
"~.pred"
)
expect_null(res$mapping$colour)
expect_null(res$mapping$fill)
expect_equal(
res$labels,
list(x = "Observed", y = "Predicted", colour = "colour", fill = "fill",
intercept = "intercept", slope = "slope")
)
expect_equal(length(res$layers), 3)
res <- cal_plot_regression(obj)
expect_s3_class(res, "ggplot")
skip_if_not_installed("tune", "1.2.0")
expect_equal(
res$data[0,],
dplyr::tibble(.pred = numeric(0), .row = numeric(0),
predictor_01 = integer(0), outcome = numeric(0),
.config = character())
)
expect_equal(
rlang::expr_text(res$mapping$x),
"~outcome"
)
expect_equal(
rlang::expr_text(res$mapping$y),
"~.pred"
)
expect_null(res$mapping$colour)
expect_null(res$mapping$fill)
expect_equal(
res$labels,
list(x = "Observed", y = "Predicted", colour = "colour", fill = "fill",
intercept = "intercept", slope = "slope")
)
expect_equal(length(res$layers), 3)
res <- print(cal_plot_regression(obj), alpha = 1 / 5, smooth = FALSE)
expect_s3_class(res, "ggplot")
skip_if_not_installed("tune", "1.2.0")
expect_equal(
res$data[0,],
dplyr::tibble(.pred = numeric(0), .row = numeric(0),
predictor_01 = integer(0), outcome = numeric(0),
.config = character())
)
expect_equal(
rlang::expr_text(res$mapping$x),
"~outcome"
)
expect_equal(
rlang::expr_text(res$mapping$y),
"~.pred"
)
expect_null(res$mapping$colour)
expect_null(res$mapping$fill)
expect_equal(
res$labels,
list(x = "Observed", y = "Predicted", colour = "colour", fill = "fill",
intercept = "intercept", slope = "slope")
)
expect_equal(length(res$layers), 3)
res <- cal_plot_regression(boosting_predictions_oob, outcome, .pred, smooth = FALSE)
expect_s3_class(res, "ggplot")
expect_equal(
res$data[0,],
dplyr::tibble(outcome = numeric(0), .pred = numeric(0),
id = character())
)
expect_equal(
rlang::expr_text(res$mapping$x),
"~outcome"
)
expect_equal(
rlang::expr_text(res$mapping$y),
"~.pred"
)
expect_null(res$mapping$colour)
expect_null(res$mapping$fill)
expect_equal(
res$labels,
list(x = "Observed", y = "Predicted", colour = "colour", fill = "fill",
intercept = "intercept", slope = "slope")
)
expect_equal(length(res$layers), 3)
})
test_that("regression plot function errors - grouped_df", {
expect_snapshot_error(
cal_plot_regression(dplyr::group_by(mtcars, vs))
)
})
# ------------------------------------------------------------------------------
test_that("don't facet if there is only one .config", {
class_data <- testthat_cal_binary()
class_data$.predictions <- lapply(
class_data$.predictions,
function(x) dplyr::filter(x, .config == "Preprocessor1_Model1")
)
res_breaks <- cal_plot_breaks(class_data)
expect_null(res_breaks$data[[".config"]])
expect_s3_class(res_breaks, "ggplot")
res_logistic <- cal_plot_logistic(class_data)
expect_null(res_logistic$data[[".config"]])
expect_s3_class(res_logistic, "ggplot")
res_windowed <- cal_plot_windowed(class_data)
expect_null(res_windowed$data[[".config"]])
expect_s3_class(res_windowed, "ggplot")
reg_data <- testthat_cal_reg()
reg_data$.predictions <- lapply(
reg_data$.predictions,
function(x) dplyr::filter(x, .config == "Preprocessor01_Model1")
)
res_regression <- cal_plot_regression(reg_data)
expect_null(res_regression$data[[".config"]])
expect_s3_class(res_regression, "ggplot")
})
test_that("custom names for cal_plot_breaks()", {
data(segment_logistic)
segment_logistic_1 <- dplyr::rename(segment_logistic, good_prob = .pred_good)
p <- cal_plot_breaks(segment_logistic_1, Class, good_prob)
expect_s3_class(p, "ggplot")
})
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.