library(civis)
context("civis_ml_utils")

# Shared fixtures ----
# Saved model objects that the tests below index into by position.
model_list <- readRDS("data/civis_ml_models.rds")

# Logical masks over model_list. vapply() guarantees a logical vector (sapply()
# would hand back list() for empty input, which is not a valid subscript), and
# inherits() is the direct S3 class check.
is_classif <- vapply(model_list, inherits, logical(1),
                     what = "civis_ml_classifier")
is_noval <- vapply(model_list, function(m) is.null(m$metrics), logical(1))

# One regex per model: optional job id, a space, then the run id.
job_ids <- lapply(model_list, function(x) x[["job"]][["id"]])
run_ids <- lapply(model_list, function(x) x[["run"]][["id"]])
id_regex <- paste(paste0("(", job_ids, ")*"), paste0("(", run_ids, ")"))

# Value stored at model_info$model$model for each classifier / non-classifier.
class_algo <- sapply(model_list[is_classif],
                     function(x) x$model_info$model$model)
reg_algo <- sapply(model_list[!is_classif],
                   function(x) x$model_info$model$model)

# Positions of the models used by the feature-importance and coefficient
# tests. The complements are taken by negative index so that membership is
# decided by position, not by comparing whole model objects with %in%.
feat_imp_idx <- c(2, 3, 4, 10, 11, 12, 15, 16)
feat_imp_mods <- model_list[feat_imp_idx]
feat_imp_err_mods <- model_list[-feat_imp_idx]

coef_idx <- c(1, 7, 8, 9, 18)
coef_mods <- model_list[coef_idx]
no_coef_mods <- model_list[-coef_idx]
# Pairwise regex test: element i of `string` is matched against element i of
# `pattern`. Returns whatever mapply() simplifies the per-pair grepl() results
# to (a logical vector when every pair yields a single value).
str_detect_multiple <- function(string, pattern) {
  matches_one <- function(txt, rx) grepl(rx, txt)
  mapply(matches_one, string, pattern)
}
# get_train_template_id() is checked against a mocked template table, so no
# API call is made.
test_that("get_train_template_id works", {
  template_table <- data.frame(
    id = c(11219, 11220, 10615, 11028, 11219, 7020),
    version = c("prod", "prod", "dev", "dev", "v2.2", "v0.5"),
    name = c("training", "prediction", "training", "registration",
             "training", "training"),
    stringsAsFactors = FALSE
  )

  with_mock(
    `civis::get_template_ids_all_versions` = function(...) template_table,
    expect_equal(get_train_template_id("prod"), 11219),
    expect_equal(get_train_template_id("dev"), 10615),
    expect_equal(get_train_template_id("v2.2"), 11219),
    expect_equal(get_train_template_id("v0.5"), 7020),
    # A version that is absent from the table must raise an error.
    expect_error(get_train_template_id("foo"))
  )
})
# get_template_ids_all_versions() is fed a mocked fetch_until() result and
# must return one row per civis-civisml alias, dropping the unrelated one.
test_that("get_template_ids_all_versions works", {
  # Builds one alias record with the same field order as the mocked payload.
  alias_record <- function(id, object_id, alias, user_id, display_name) {
    list(id = id,
         objectId = object_id,
         objectType = "template_script",
         alias = alias,
         userId = user_id,
         displayName = display_name)
  }

  alias_payload <- list(
    alias_record(11, 12367, "civis-ShapefileExport", 3001,
                 "Export Shapefile"),
    alias_record(14, 11219, "civis-civisml-training", 400,
                 "Model Training"),
    alias_record(21, 10615, "civis-civisml-training-dev", 400,
                 "Model Training - DEV ONLY"),
    alias_record(26, 11221, "civis-civisml-registration-v2-2", 1750,
                 "Trained Model Registration, v2.2")
  )

  expected <- data.frame(
    id = c(11219, 10615, 11221),
    version = c("prod", "dev", "v2.2"),
    name = c("training", "training", "registration"),
    stringsAsFactors = FALSE
  )

  with_mock(
    `civis::fetch_until` = function(...) alias_payload,
    expect_equal(get_template_ids_all_versions(), expected)
  )
})
# get_job_type_version() parses a template alias into job type and version,
# and errors on aliases that do not have the expected shape.
test_that("get_job_type_version works", {
  # alias -> expected parse result
  parsed <- list(
    "civis-civisml-training" =
      list(job_type = "training", version = "prod"),
    "civis-civisml-training-v2-3" =
      list(job_type = "training", version = "v2.3"),
    "civis-civisml-training-dev" =
      list(job_type = "training", version = "dev"),
    "civis-civisml-training-foo-bar" =
      list(job_type = "training", version = "foo-bar")
  )
  for (alias in names(parsed)) {
    expect_equal(get_job_type_version(alias), parsed[[alias]])
  }

  # Aliases that must be rejected.
  malformed <- c("foo-bar",
                 "civis-civisml",
                 "civis-civisml-",
                 "civis-civisml-training-",
                 "civis-civisml-training-foobar-")
  for (alias in malformed) {
    expect_error(get_job_type_version(alias))
  }
})
# Captured output of each classifier: line 1 must match its algorithm name
# and line 3 must match its id regex.
test_that("print.civis_ml_classifier works", {
  classifiers <- model_list[is_classif]
  printed <- lapply(classifiers, function(mod) utils::capture.output(mod))

  header <- lapply(printed, `[[`, 1)
  mapply(expect_match, header, class_algo)

  id_line <- lapply(printed, `[[`, 3)
  ids_found <- str_detect_multiple(id_line, id_regex[is_classif])
  expect_true(all(ids_found))
})
# Captured output of each non-classifier model: line 1 must match its
# algorithm name, line 3 its id regex, and the collapsed text is checked
# against metric and "weight"/"Time" patterns.
test_that("print.civis_ml_regressor works", {
  regressors <- model_list[!is_classif]
  printed <- lapply(regressors, function(mod) utils::capture.output(mod))
  n_models <- length(printed)

  header <- lapply(printed, `[[`, 1)
  mapply(expect_match, header, reg_algo)

  id_line <- lapply(printed, `[[`, 3)
  ids_found <- str_detect_multiple(id_line, id_regex[!is_classif])
  expect_true(all(ids_found))

  # Collapse each model's lines into one string for whole-output matching.
  flat <- lapply(printed, paste0, collapse = "")
  all_but_last <- flat[1:(n_models - 1)]
  lapply(all_but_last, expect_match, "(MAD)*(RMSE)*(R-squared)")
  lapply(all_but_last, expect_match, "weight")
  lapply(flat[n_models - 1], expect_match, "(weight)*(Time)")
})
# With digits = 2, the last token on printed lines 6-8 must have exactly two
# characters after the decimal point.
test_that("print.civis_ml digits works", {
  first_regressor <- model_list[!is_classif][[1]]
  printed <- capture.output(print(first_regressor, digits = 2))
  value_lines <- printed[6:8]

  # Last space-separated token of each line, then its part after the ".".
  last_token <- lapply(strsplit(value_lines, " "), tail, 1)
  fraction <- lapply(sapply(last_token, strsplit, "\\."), tail, 1)

  expect_equal(sapply(fraction, nchar), rep(2, 3))
})
# get_metric() with no further arguments must return the model's
# metrics$metrics element unchanged, for every saved model.
test_that("get_metrics returns metrics", {
  invisible(lapply(model_list, function(mod) {
    stored <- mod$metrics$metrics
    expect_equal(stored, get_metric(mod))
  }))
})
# A non-model argument must trip the is_civis_ml() assertion.
test_that("get_metrics throws error if not model", {
  expect_error(
    get_metric("HIPPO_INA_STRING"),
    "is_civis_ml\\(model\\) is not TRUE"
  )
})
# get_model_data() must return model_info$data as a whole, or a single named
# entry of it when a name is supplied.
test_that("get_model_data returns model data", {
  invisible(lapply(model_list, function(mod) {
    stored <- mod$model_info$data
    expect_equal(stored, get_model_data(mod))
    expect_equal(get_model_data(mod, "target_columns"),
                 mod$model_info$data$target_columns)
  }))
})
# is_multiclass() is checked on three saved models.
test_that("is_multiclass works", {
  # Model 1 is expected to be reported as multiclass.
  first_model <- model_list[[1]]
  expect_true(is_multiclass(first_model))

  # Model 7: binary classifier, so not multiclass.
  binary_model <- model_list[[7]]
  expect_false(is_multiclass(binary_model))

  # First non-classifier (regression) model.
  first_regressor <- model_list[!is_classif][[1]]
  expect_false(is_multiclass(first_regressor))
})
# is_multitarget() must agree with "n_unique_targets has more than one entry"
# for every saved model.
test_that("is_multitarget works", {
  has_many_targets <- function(mod) {
    length(mod$model_info$data$n_unique_targets) > 1
  }
  observed <- sapply(model_list, is_multitarget)
  expected <- sapply(model_list, has_many_targets)
  expect_equal(observed, expected)
})
# get_predict_template_id() is checked against a mocked template table, once
# with a saved model and once with a minimal stand-in that only carries
# job$fromTemplateId.
test_that("get_predict_template_id returns correct template for train/predict version ", {
  template_table <- data.frame(
    id = c(9968, 9969, 9968, 9969, 10582, 10583),
    version = c("prod", "prod", "v2.2", "v2.2", "v2.1", "v2.1"),
    name = c("training", "prediction", "training", "prediction",
             "training", "prediction"),
    stringsAsFactors = FALSE
  )

  saved_model <- model_list[[1]]
  with_mock(
    `civis::get_template_ids_all_versions` = function(...) template_table,
    expect_equal(get_predict_template_id(saved_model), 9969)
  )

  # 10582 is the v2.1 training id in the table; 10583 is its prediction pair.
  stub_model <- list(job = list(fromTemplateId = 10582))
  with_mock(
    `civis::get_template_ids_all_versions` = function(...) template_table,
    expect_equal(get_predict_template_id(stub_model), 10583)
  )
})
# Feature importances computed from feat_imp_mods must equal the saved
# reference values.
test_that("get_feature_importance returns correct feature importance matrix when available", {
  expected <- readRDS("data/feature_importances.rds")
  observed <- lapply(feat_imp_mods, get_feature_importance)
  expect_equal(expected, observed)
})
# Every model in feat_imp_err_mods must raise the "not available" error.
test_that("models with no feature importance throw errors for get_feature_importance", {
  invisible(lapply(feat_imp_err_mods, function(mod) {
    expect_error(get_feature_importance(mod),
                 "Feature importance data not available.")
  }))
})
# coef() on each model in coef_mods must equal the saved reference values.
test_that("coef.civis_ml returns correct coefficients when available", {
  expected <- readRDS("data/model_coefficients.rds")
  observed <- lapply(coef_mods, coef)
  expect_equal(expected, observed)
})
# coef() must return NULL for every model in no_coef_mods.
test_that("coef.civis_ml returns NULL when coefficients are unavailable", {
  invisible(lapply(no_coef_mods, function(mod) {
    expect_null(coef(mod))
  }))
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.