# (No file-level description.)
context("model_list tests setup")

# Setup ------------------------------------------------------------------------
# Shared fixtures used by every test in this file. The seed is set once and the
# statements below are kept in their original order so the random sampling
# (and any randomness in model fitting) sees the same RNG stream.
set.seed(25705)

# 50 complete rows of pima_diabetes, prepared once per outcome type.
short <- dplyr::sample_n(na.omit(pima_diabetes), 50)
dreg <- prep_data(short, outcome = pedigree)
dcla <- prep_data(short, outcome = diabetes)
dmul <- prep_data(dplyr::sample_n(iris, 100), outcome = Species)

# Implicit test that warning not issued for missing resampled performance metrics:
r_models <- suppressWarnings(
  tune_models(
    dreg, pedigree,
    tune_depth = 2, n_folds = 3,
    models = c("rf", "glm")
  )
)
c_models <- tune_models(
  dcla, diabetes,
  tune_depth = 2, n_folds = 3,
  model_name = "great_name", models = "xgb"
)
m_models <- tune_models(
  d = dmul, outcome = Species,
  tune_depth = 2, n_folds = 3,
  models = "rf"
)
c_pr <- flash_models(
  dcla, diabetes,
  metric = "PR", n_folds = 2, models = "xgb"
)
single_model_tune <- tune_models(
  dcla, diabetes,
  models = "xgb", n_folds = 2, tune_depth = 2
)

# Untuned (flash) model lists, one per outcome type.
r_flash <- flash_models(dreg, pedigree, n_folds = 2, models = "xgb")
c_flash <- flash_models(dcla, diabetes, n_folds = 2, models = "xgb")
m_flash <- flash_models(dmul, Species, n_folds = 2, models = "xgb")

# "original_data_str" attribute of the classification model list.
ods <- attr(c_models, "original_data_str")
context("Checking model_list constructors") # ----------------------------------

test_that("as.model_list fails if model_class is unsupported", {
  # A model_class string that is not one of the supported types must error.
  expect_error(
    as.model_list(model_class = "what diabetes i even?")
  )
})
test_that("as.model_list errors if input isn't a caret model", {
  # Plain integer vector.
  expect_error(
    as.model_list(1:5, model_class = "regression")
  )
  # Character scalar.
  expect_error(
    as.model_list("HEY")
  )
  # A model fit directly with ranger rather than passed through caret.
  expect_error(
    as.model_list(ranger::ranger(pedigree ~ ., dreg))
  )
})
test_that("model_lists have original data str as zero-row DF with right names and classes", {
  expect_s3_class(ods, "data.frame")
  expect_equal(nrow(ods), 0)
  # Select every column except the outcome with a logical mask. The previous
  # `-which(names(short) == "diabetes")` form silently selects *zero* columns
  # when the name is absent (negating an empty index), which would make the
  # comparison below meaningless instead of failing loudly.
  keep_cols <- names(short) != "diabetes"
  expect_equivalent(ods, short[0, keep_cols])
})
test_that("model_list's original_data_str is the same as predict's return", {
  # Take the zero-row skeleton of the predictions and drop their first two
  # columns before comparing against the stored data structure.
  predictions <- predict(c_models, short)
  skeleton <- predictions[0, -seq_len(2)]
  expect_equal(skeleton, ods, check.attributes = FALSE)
})
test_that("model_lists have r, hcai, and other-package versions as attrs", {
  # Names that must be present in the "versions" attribute.
  required <- c("r_version", "hcai_version", "other_packages")
  version_names <- names(attributes(r_models)$versions)
  expect_true(all(required %in% version_names))
})
context("Checking model_list generics") # --------------------------------------

test_that("plot.model_list works on regression_list", {
  # Assert inheritance rather than the exact class vector: comparing
  # class(x) to c("gg", "ggplot") breaks as soon as the plotting package adds
  # or reorders classes, and the other plot tests in this file already use
  # expect_s3_class().
  expect_ggplot <- function(p) {
    expect_s3_class(p, "gg")
    expect_s3_class(p, "ggplot")
  }

  # Via S3 dispatch and via the method called directly.
  expect_ggplot(plot(r_models, print = FALSE))
  expect_ggplot(plot.model_list(r_models, print = FALSE))

  # A regression list trained with Rsquared as the metric also plots.
  r2 <- flash_models(dreg, pedigree, models = "rf", n_folds = 3, metric = "Rsquared")
  expect_s3_class(plot(r2, print = FALSE), "gg")
})
test_that("plot.model_list works on classification_list", {
  # Assert inheritance rather than the exact class vector: comparing
  # class(x) to c("gg", "ggplot") breaks as soon as the plotting package adds
  # or reorders classes, and the other plot tests in this file already use
  # expect_s3_class().
  expect_ggplot <- function(p) {
    expect_s3_class(p, "gg")
    expect_s3_class(p, "ggplot")
  }

  # Default metric: via S3 dispatch and via the method called directly.
  expect_ggplot(plot(c_models, print = FALSE))
  expect_ggplot(plot.model_list(c_models, print = FALSE))

  # With PR as the metric. The original repeated plot(c_pr, ...) three times;
  # one dispatch call and one direct call cover the same ground.
  expect_ggplot(plot(c_pr, print = FALSE))
  expect_ggplot(plot.model_list(c_pr, print = FALSE))
})
test_that("print.model_list works", {
  # Each model type prints something and mentions its own type.
  reg_text <- capture_output(r_models, print = TRUE)
  expect_true(nchar(reg_text) > 0)
  expect_true(grepl("regression", reg_text, ignore.case = TRUE))

  cla_text <- capture_output(c_models, print = TRUE)
  expect_true(nchar(cla_text) > 0)
  expect_true(grepl("classification", cla_text, ignore.case = TRUE))

  mul_text <- capture_output(m_models, print = TRUE)
  expect_true(nchar(mul_text) > 0)
  expect_true(grepl("multiclass", mul_text, ignore.case = TRUE))

  # With PR as the metric
  pr_text <- capture_output(c_pr, print = TRUE)
  expect_true(nchar(pr_text) > 0)
  expect_true(grepl("PR", pr_text, ignore.case = TRUE))

  # Model name: r_models was built without model_name, c_models with
  # model_name = "great_name".
  reg_name_text <- capture_output(r_models, print = TRUE)
  expect_true(
    grepl("Model Name: pedigree", reg_name_text, ignore.case = TRUE)
  )
  cla_name_text <- capture_output(c_models, print = TRUE)
  expect_true(
    grepl("Model Name: great_name", cla_name_text, ignore.case = TRUE)
  )
})
test_that("summary.model_list works", {
  # Regression: summary prints hyperparameter info and returns a named list.
  reg_text <- capture_output(reg_summary <- summary(r_models), print = TRUE)
  expect_true(nchar(reg_text) > 0)
  expect_true(grepl("hyperparameters", reg_text, ignore.case = TRUE))
  expect_true(is.list(reg_summary))
  expect_true(rlang::is_named(reg_summary))

  # Classification: same contract.
  cla_text <- capture_output(cla_summary <- summary(c_models), print = TRUE)
  expect_true(nchar(cla_text) > 0)
  expect_true(grepl("hyperparameters", cla_text, ignore.case = TRUE))
  expect_true(is.list(cla_summary))
  expect_true(rlang::is_named(cla_summary))

  # With PR as the metric
  pr_text <- capture_output(pr_summary <- summary(c_pr), print = TRUE)
  expect_true(nchar(pr_text) > 0)
  expect_true(grepl("Precision", pr_text, ignore.case = TRUE))
  expect_true(is.list(pr_summary))
  expect_true(rlang::is_named(pr_summary))
})
context("Checking model_list generics on untuned model_lists") #----------------

test_that("print.model_list works with untuned_model_lists", {
  # Passing NA as the pattern asserts that printing raises no warning.
  expect_warning(reg_out <- capture_output(print(r_flash)), NA)
  expect_warning(cla_out <- capture_output(print(c_flash)), NA)
  expect_warning(mul_out <- capture_output(print(m_flash)), NA)

  # No "Inf" may appear in any printed output.
  expect_false(grepl("Inf", reg_out))
  expect_false(grepl("Inf", cla_out))
  expect_false(grepl("Inf", mul_out))

  # The outcome variable is reported for each model type.
  expect_true(grepl("Target: pedigree", reg_out))
  expect_true(grepl("Target: diabetes", cla_out))
  expect_true(grepl("Target: Species", mul_out))

  # Text specific to untuned model lists and the multiclass metric.
  expect_true(grepl("Models have not been tuned", reg_out))
  expect_true(grepl("selected hyperparameter values", cla_out))
  expect_true(grepl("Accuracy", mul_out))
})
test_that("summary.model_list works with untuned_model_lists", {
  # Passing NA as the pattern asserts that summarizing raises no warning.
  expect_warning(reg_out <- capture_output(summary(r_flash)), NA)
  expect_warning(cla_out <- capture_output(summary(c_flash)), NA)
  expect_warning(mul_out <- capture_output(summary(m_flash)), NA)

  # No "Inf" may appear in any summary output.
  expect_false(grepl("Inf", reg_out))
  expect_false(grepl("Inf", cla_out))
  expect_false(grepl("Inf", mul_out))

  # Strings that must be absent from / present in the summaries.
  expect_false(grepl("0 rows", reg_out))
  expect_false(grepl("Best performance:", reg_out))
  expect_true(grepl("Best algorithm:", cla_out))
  expect_true(grepl("Kappa", mul_out))
})
test_that("plot.model_list works with message untuned_model_lists", {
  # Plotting untuned lists must not warn (NA pattern = expect no warning) ...
  expect_warning(reg_plot <- plot(r_flash, print = FALSE), NA)
  expect_warning(cla_plot <- plot(c_flash, print = FALSE), NA)

  # ... but it does emit a message.
  expect_message(plot(c_flash, print = FALSE))

  # The returned objects are ggplot objects.
  expect_s3_class(reg_plot, "gg")
  expect_s3_class(cla_plot, "gg")
})
context("Testing model list utilities") # --------------------------------------

test_that("change_metric_names changes AUC to AUPR and prints", {
  m <- change_metric_names(c_pr)
  # The first model's results table must carry the renamed PR metrics.
  # (The original asserted this identical condition twice; once is enough.)
  renamed <- c("AUPR", "Precision", "Recall")
  expect_true(all(renamed %in% names(m[[1]]$results)))
  # Printing the renamed object must not error (NA pattern = expect no error).
  expect_error(capture_output(print(m)), NA)
})
test_that("change_metric_names changes ROC to AUROC", {
  m <- change_metric_names(c_models)
  # The first model's results table must carry the renamed ROC metrics.
  # (The original asserted this identical condition twice; once is enough.)
  renamed <- c("AUROC", "Sens", "Spec")
  expect_true(all(renamed %in% names(m[[1]]$results)))
  # Printing the renamed object must not error (NA pattern = expect no error).
  expect_error(capture_output(print(m)), NA)
})
test_that("Change PR metric doesn't change object class", {
  # Model lists keep their class set after renaming, whichever metric they use.
  renamed_pr <- change_metric_names(c_pr)
  expect_setequal(class(renamed_pr), class(c_pr))

  renamed_roc <- change_metric_names(c_models)
  expect_setequal(class(renamed_roc), class(c_models))

  # The same holds when the input is a predictions object.
  preds <- predict(c_models)
  renamed_preds <- change_metric_names(preds)
  expect_setequal(class(renamed_preds), class(preds))
})
test_that("model_lists have time model trained attribute", {
  # Every fixture, tuned or flash, must carry a POSIXt "timestamp" attribute.
  fixtures <- list(
    r_models, c_models, m_models, c_pr,
    r_flash, c_flash, m_flash
  )
  for (fixture in fixtures) {
    stamp <- attr(fixture, "timestamp")
    expect_true(lubridate::is.POSIXt(stamp))
  }
})
test_that("model_lists only carry training data in recipe", {
  # The individual fitted models must have no trainingData element.
  first_reg <- r_models[[1]]
  expect_null(first_reg$trainingData)

  first_cla <- c_models[[1]]
  expect_null(first_cla$trainingData)

  second_reg <- r_models[[2]]
  expect_null(second_reg$trainingData)

  first_mul <- m_models[[1]]
  expect_null(first_mul$trainingData)
})
test_that("[ extracts models by index", {
  # Warnings here are due to tiny models that didn't converge.
  one_flash <- suppressWarnings(r_flash[1])
  expect_s3_class(one_flash, "model_list")

  second_only <- suppressWarnings(r_models[2])
  expect_s3_class(second_only, "model_list")

  both <- suppressWarnings(r_models[1:2])
  expect_s3_class(both, "model_list")

  # Selecting every index gives back an equivalent model list.
  everything <- suppressWarnings(r_models[seq_along(r_models)])
  expect_equivalent(everything, r_models)
})
test_that("[ extracts by name, index, or logical vector", {
  reg_names <- names(r_models)
  cla_names <- names(c_models)

  # Index vs. name, first model.
  expect_equivalent(
    suppressWarnings(r_models[1]),
    suppressWarnings(r_models[reg_names[1]])
  )
  expect_equivalent(
    suppressWarnings(c_models[1]),
    suppressWarnings(c_models[cla_names[1]])
  )

  # Index vs. logical mask.
  expect_equivalent(
    suppressWarnings(r_models[1]),
    suppressWarnings(r_models[c(TRUE, FALSE)])
  )

  # Index vs. name, second model and both models.
  expect_equivalent(
    suppressWarnings(r_models[2]),
    suppressWarnings(r_models[reg_names[2]])
  )
  expect_equivalent(
    suppressWarnings(r_models[1:2]),
    suppressWarnings(r_models[reg_names[1:2]])
  )

  # An all-TRUE mask gives back an equivalent model list.
  expect_equivalent(
    suppressWarnings(r_models),
    suppressWarnings(r_models[c(TRUE, TRUE)])
  )
})
test_that("metrics and predict are same for extracted best model", {
  # Pull the best model out by name and compare it with the full list.
  info <- extract_model_info(c_models)
  best_only <- c_models[info$best_model_name]

  expect_equivalent(evaluate(c_models), evaluate(best_only))
  expect_equivalent(predict(c_models), predict(best_only))
})
test_that("performance drops if the best model is pulled out of model_list", {
  mi <- extract_model_info(r_models)

  # Indices of every model other than the best one.
  is_best <- names(r_models) %in% mi$best_model_name
  not_best <- which(!is_best)

  # Compare the stored performance on the selected metric before and after
  # removing the best model; the full list must have the smaller value.
  full_perf <- attr(r_models, "performance")
  reduced_perf <- attr(r_models[not_best], "performance")
  original_rmse <- full_perf[mi$metric]
  subsetted_rmse <- reduced_perf[mi$metric]
  expect_true(original_rmse < subsetted_rmse)
})
test_that("extract_model_info has correct has_training_data item", {
  # Work on an explicitly named local copy; the recipe attribute is edited
  # step by step below.
  ml <- r_models
  has_data <- function(x) extract_model_info(x)[["has_training_data"]]

  # Untouched recipe.
  expect_true(has_data(ml))

  # True if one: template removed, orig_data still present.
  attr(ml, "recipe")$template <- NULL
  expect_true(has_data(ml))

  # True if one: orig_data moved into template, then orig_data removed.
  attr(ml, "recipe")$template <- attr(ml, "recipe")$orig_data
  attr(ml, "recipe")$orig_data <- NULL
  expect_true(has_data(ml))

  # False if none: both removed.
  attr(ml, "recipe")$template <- NULL
  attr(ml, "recipe")$orig_data <- NULL
  expect_false(has_data(ml))
})
test_that("extraction from model list doesn't change model-trained timestamp", {
  # Subsetting with `[` must keep the "timestamp" attribute unchanged.
  trained_at <- attr(c_models, "timestamp")
  extracted_at <- attr(c_models[1], "timestamp")
  expect_equal(trained_at, extracted_at)
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.