# Load the packaged fixtures used throughout this file, then build the
# ensembles that the tests below share.
utils::data(models.reg)
utils::data(X.reg)
utils::data(Y.reg)
utils::data(models.class)
utils::data(X.class)
utils::data(Y.class)
utils::data(iris)

# Base learners for the multiclass test: the first four iris columns are the
# predictors and the fifth column is the outcome.
models_multiclass <- caretList(
  x = iris[, -5L],
  y = iris[[5L]],
  methodList = c("rpart", "glmnet")
)

# Stacked ensembles (classification and regression), both fitted with
# preProcess = "pca".
ens.class <- caretStack(
  models.class,
  preProcess = "pca",
  method = "glm",
  metric = "ROC"
)
ens.reg <- caretStack(
  models.reg,
  preProcess = "pca",
  method = "lm"
)
# Helper functions

# Expect every element (column) of `x` to contain only finite values,
# i.e. no NA, NaN, Inf or -Inf anywhere.
expect_all_finite <- function(x) {
  column_is_finite <- vapply(
    x,
    function(col) all(is.finite(col)),
    logical(1L)
  )
  testthat::expect_true(all(column_is_finite))
}
######################################################################
testthat::context("caretStack")
######################################################################
testthat::test_that("caretStack creates valid ensemble models", {
  # Both shared ensembles must carry the caretStack class.
  for (ens in list(ens.class, ens.reg)) {
    testthat::expect_s3_class(ens, "caretStack")
  }

  # The S3 methods return objects of the expected classes.
  class_summary <- summary(ens.class)
  testthat::expect_s3_class(class_summary, "summary.caretStack")
  class_plot <- plot(ens.class)
  testthat::expect_s3_class(class_plot, "ggplot")

  # Printing lists the names of the ensembled base models.
  testthat::expect_output(
    print(ens.class),
    "The following models were ensembled: rf, glm, rpart, treebag"
  )
})
testthat::test_that("predict works for classification and regression ensembles", {
  # Each case pairs an ensemble with the predictors it is evaluated on.
  cases <- list(
    class = list(ens = ens.class, X = X.class),
    reg = list(ens = ens.reg, X = X.reg)
  )
  for (case in cases) {
    ens <- case[["ens"]]
    X <- case[["X"]]
    # Outer loop over newdata and inner loop over se visits the same four
    # combinations, in the same order, as
    # expand.grid(se = c(TRUE, FALSE), newdata = c(TRUE, FALSE)).
    for (use_newdata in c(TRUE, FALSE)) {
      # With newdata, predict on the first 50 rows; with NULL, one row per
      # row of X is expected back.
      newdata <- if (use_newdata) X[1L:50L, ] else NULL
      expected_rows <- if (use_newdata) 50L else nrow(X)
      for (with_se in c(TRUE, FALSE)) {
        pred <- predict(ens, newdata = newdata, se = with_se)
        testthat::expect_s3_class(pred, "data.table")
        testthat::expect_identical(nrow(pred), expected_rows)
        expect_all_finite(pred)
      }
    }
  }
})
testthat::test_that("Classification predictions return correct output", {
  # return_class_only = TRUE yields one factor level per input row.
  class_labels <- predict(ens.class, X.class, return_class_only = TRUE)
  testthat::expect_s3_class(class_labels, "factor")
  testthat::expect_length(class_labels, nrow(X.class))

  # Excluding class 1 leaves a single set of prediction/interval columns.
  pred_excluding_first <- predict(
    ens.class, X.class,
    se = TRUE, excluded_class_id = 1L
  )
  testthat::expect_named(pred_excluding_first, c("pred", "lwr", "upr"))

  # excluded_class_id = 0L keeps columns for both classes.
  pred_all_classes <- predict(
    ens.class, X.class,
    se = TRUE, excluded_class_id = 0L
  )
  testthat::expect_named(
    pred_all_classes,
    c("pred.No", "pred.Yes", "lwr.No", "lwr.Yes", "upr.No", "upr.Yes")
  )
})
testthat::test_that("Regression predictions return correct output", {
  # With se = TRUE a regression ensemble returns the point prediction plus
  # lower and upper bounds.
  reg_pred <- predict(ens.reg, X.reg, se = TRUE, excluded_class_id = 0L)
  testthat::expect_named(reg_pred, c("pred", "lwr", "upr"))
})
testthat::test_that("Predictions are reproducible", {
  # Reset the RNG to the same seed before each call so the two predictions
  # are made from an identical random state.
  seeded_predict <- function() {
    set.seed(42L)
    predict(ens.class, X.class, se = TRUE, level = 0.8)
  }
  p1 <- seeded_predict()
  p2 <- seeded_predict()
  testthat::expect_equivalent(p1, p2)
})
testthat::test_that("caretStack handles missing data correctly", {
  ens.reg.subset <- caretEnsemble::caretStack(models.reg[2L:3L], method = "lm")

  # Seed the RNG so the same 20 rows and the same column receive NAs on
  # every run; otherwise a failure could not be reproduced.
  set.seed(42L)
  X_reg_na <- X.reg
  na_rows <- sample.int(nrow(X_reg_na), 20L)
  # The last column is never selected (ncol - 1).
  na_col <- sample.int(ncol(X_reg_na) - 1L, 1L)
  X_reg_na[na_rows, na_col] <- NA

  # Rows containing missing predictors must still yield one prediction row
  # each. (The original ran this identical predict/assert pair twice; the
  # verbatim duplicate added no coverage and has been removed.)
  pred.reg <- predict(ens.reg.subset, newdata = X_reg_na)
  testthat::expect_identical(nrow(pred.reg), nrow(X_reg_na))
})
testthat::test_that("caretStack handles multiclass problems", {
  # NOTE(review): excluded_class_id = 4L together with the 3 expected
  # columns presumably means no iris class is excluded - confirm.
  multiclass_stack <- caretEnsemble::caretStack(
    models_multiclass,
    method = "rpart",
    excluded_class_id = 4L
  )
  iris_pred <- predict(multiclass_stack, newdata = iris)

  # One row per iris observation, three prediction columns, all finite.
  testthat::expect_identical(nrow(iris_pred), 150L)
  testthat::expect_identical(ncol(iris_pred), 3L)
  expect_all_finite(iris_pred)
})
testthat::test_that("caretStack works with different stacking algorithms", {
  # Pair each base-model list with the predictors used for prediction, so
  # the matching data does not have to be looked up per iteration.
  fixtures <- list(
    list(models = models.reg, X = X.reg),
    list(models = models.class, X = X.class)
  )
  for (method in c("glm", "rf", "gbm", "glmnet")) {
    for (fixture in fixtures) {
      model_list <- fixture[["models"]]
      X <- fixture[["X"]]

      # gbm alone is fitted with verbose = FALSE (presumably to silence its
      # training log).
      stack <- if (method == "gbm") {
        caretEnsemble::caretStack(model_list, method = method, verbose = FALSE)
      } else {
        caretEnsemble::caretStack(model_list, method = method)
      }

      testthat::expect_s3_class(stack, "caretStack")
      testthat::expect_identical(stack$ens_model$method, method)

      predictions <- predict(stack, newdata = X)
      testthat::expect_identical(nrow(predictions), nrow(X))
    }
  }
})
testthat::test_that("caretStack handles custom preprocessing", {
  preprocess <- c("center", "scale", "pca")
  # The fitted preProcess object is expected to list the requested steps
  # followed by an "ignore" entry.
  expected_steps <- c(preprocess, "ignore")

  for (model_list in list(models.class, models.reg)) {
    stack <- caretEnsemble::caretStack(
      model_list,
      method = "glm",
      preProcess = preprocess
    )
    testthat::expect_s3_class(stack, "caretStack")
    testthat::expect_named(stack$ens_model$preProcess$method, expected_steps)
  }
})
testthat::test_that("caretStack handles custom performance function", {
  # Summary function reporting one metric, "default": the share of
  # observations whose prediction equals the observed value.
  custom_summary <- function(data, lev = NULL, model = NULL) {
    share_matching <- mean(data$obs == data$pred)
    c(default = share_matching)
  }

  # 3-fold cross-validation scored with the custom summary; the same
  # control object serves both the classification and regression stacks.
  cv_control <- caret::trainControl(
    method = "cv",
    number = 3L,
    summaryFunction = custom_summary
  )

  for (model_list in list(models.class, models.reg)) {
    stack <- caretEnsemble::caretStack(
      model_list,
      method = "glm",
      metric = "default",
      trControl = cv_control
    )
    testthat::expect_s3_class(stack, "caretStack")
    # The custom metric must show up in the resampling results.
    testthat::expect_true("default" %in% names(stack$ens_model$results))
  }
})
testthat::test_that("caretStack handles new data correctly", {
  set.seed(42L)
  n_new <- 50L
  # NOTE(review): the rows are drawn from nrow(X.reg) but also index X.class
  # and Y.class - assumes both fixtures have the same number of rows; confirm.
  new_rows <- sample.int(nrow(X.reg), n_new)

  # Fit each stack on the sampled rows passed through new_X / new_y.
  stack_class <- caretStack(
    models.class,
    metric = "ROC",
    method = "rpart",
    new_X = X.class[new_rows, ],
    new_y = Y.class[new_rows]
  )
  stack_reg <- caretStack(
    models.reg,
    method = "glm",
    new_X = X.reg[new_rows, ],
    new_y = Y.reg[new_rows]
  )
  testthat::expect_s3_class(stack_class, "caretStack")
  testthat::expect_s3_class(stack_reg, "caretStack")

  # Predict on the full predictor sets and check the shape of the output.
  pred_class <- predict(stack_class, X.class)
  pred_reg <- predict(stack_reg, X.reg)
  testthat::expect_s3_class(pred_class, "data.table")
  testthat::expect_s3_class(pred_reg, "data.table")
  testthat::expect_identical(nrow(pred_class), nrow(X.class))
  testthat::expect_identical(nrow(pred_reg), nrow(X.reg))
  testthat::expect_identical(ncol(pred_class), 2L)
  testthat::expect_identical(ncol(pred_reg), 1L)
})
testthat::test_that("caretStack coerces lists to caretLists", {
  # A plain, unnamed list holding the first two classification models.
  models <- lapply(1L:2L, function(i) models.class[[i]])

  testthat::expect_warning(
    caretStack(models, method = "glm", tuneLength = 1L),
    "Attempting to coerce all.models to a caretList."
  )
})
testthat::test_that("caretStack errors if new_X is provided but not new_y", {
  # Supplying only new_X must be rejected with this message.
  expected_message <- "Both new_X and new_y must be NULL, or neither."
  testthat::expect_error(
    caretStack(models.class, new_X = X.class),
    expected_message
  )
})
testthat::test_that("caretStack errors if new_y is provided but not new_X", {
  # Supplying only new_y must be rejected with this message.
  expected_message <- "Both new_X and new_y must be NULL, or neither."
  testthat::expect_error(
    caretStack(models.class, new_y = Y.class),
    expected_message
  )
})
######################################################################
testthat::context("S3 methods for caretStack")
######################################################################
testthat::test_that("print", {
  ensembled_line <- "The following models were ensembled: rf, glm, rpart, treebag"
  resampling_line <- "Resampling: Cross-Validated (5 fold)"

  for (ens in list(ens.class, ens.reg)) {
    testthat::expect_output(print(ens), ensembled_line)
    testthat::expect_output(print(ens), "Linear")
    # Matched literally because the text contains regex metacharacters.
    testthat::expect_output(print(ens), resampling_line, fixed = TRUE)
  }
})
testthat::test_that("summary", {
  for (ens in list(ens.class, ens.reg)) {
    ens_summary <- summary(ens)
    testthat::expect_s3_class(ens_summary, "summary.caretStack")

    # The printed summary names the base models and contains the importance
    # and accuracy sections.
    testthat::expect_output(
      print(ens_summary),
      "The following models were ensembled: rf, glm, rpart, treebag"
    )
    testthat::expect_output(print(ens_summary), "Model Importance:")
    testthat::expect_output(print(ens_summary), "Model accuracy:")
  }
})
testthat::test_that("plot", {
  # plot() on either ensemble returns a ggplot object.
  for (ens in list(ens.class, ens.reg)) {
    testthat::expect_s3_class(plot(ens), "ggplot")
  }
})
testthat::test_that("dotplot", {
  # lattice::dotplot() on either ensemble returns a trellis object.
  for (ens in list(ens.class, ens.reg)) {
    testthat::expect_s3_class(lattice::dotplot(ens), "trellis")
  }
})
testthat::test_that("autoplot", {
  # NOTE(review): X.reg is passed for the classification ensemble as well;
  # the varImp test selects X.class / X.reg per ensemble - confirm that
  # using X.reg for both is intended here.
  for (ens in list(ens.class, ens.reg)) {
    testthat::expect_s3_class(ggplot2::autoplot(ens, X.reg), "ggplot")
  }
})
######################################################################
testthat::context("varImp")
######################################################################
testthat::test_that("varImp works for classification and regression", {
  # Pair each ensemble with its own predictors for the new-data variant.
  cases <- list(
    list(ens = ens.class, X = X.class),
    list(ens = ens.reg, X = X.reg)
  )
  for (case in cases) {
    ens <- case[["ens"]]
    model_names <- names(ens$models)

    # Importance without new data: a double vector named after the base
    # models whose entries sum to 1.
    imp <- caret::varImp(ens)
    testthat::expect_type(imp, "double")
    testthat::expect_named(imp, model_names)
    testthat::expect_equal(sum(imp), 1.0, tolerance = 1e-6)

    # The same contract holds when new data is supplied.
    imp_new_data <- caret::varImp(ens, case[["X"]])
    testthat::expect_type(imp_new_data, "double")
    testthat::expect_named(imp_new_data, model_names)
    testthat::expect_equal(sum(imp_new_data), 1.0, tolerance = 1e-6)
  }
})
######################################################################
testthat::context("wtd.sd")
######################################################################
testthat::test_that("wtd.sd calculates weighted standard deviation correctly", {
  values <- c(1L, 2L, 3L, 4L, 5L)
  equal_weights <- c(1L, 1L, 1L, 1L, 1L)

  # Equal weights reproduce the ordinary sample standard deviation.
  testthat::expect_equal(
    caretEnsemble::wtd.sd(values, equal_weights),
    stats::sd(values),
    tolerance = 0.001
  )

  # Unequal weights must give a different answer.
  uneven_weights <- c(2L, 1L, 1L, 1L, 1L)
  uneven_sd <- caretEnsemble::wtd.sd(values, uneven_weights)
  testthat::expect_false(isTRUE(all.equal(uneven_sd, stats::sd(values))))

  # NA propagates by default and is dropped with na.rm = TRUE.
  values_with_na <- c(1L, 2L, NA, 4L, 5L)
  testthat::expect_true(
    is.na(caretEnsemble::wtd.sd(values_with_na, equal_weights))
  )
  testthat::expect_false(
    is.na(caretEnsemble::wtd.sd(values_with_na, equal_weights, na.rm = TRUE))
  )

  # Mismatched lengths are an error.
  testthat::expect_error(
    caretEnsemble::wtd.sd(values, equal_weights[-1L]),
    "'x' and 'w' must have the same length"
  )

  # Known value, and the same result when the weights are scaled by 100.
  skewed_values <- c(10L, 10L, 10L, 20L)
  skewed_weights <- c(0.1, 0.1, 0.1, 0.7)
  testthat::expect_equal(
    caretEnsemble::wtd.sd(skewed_values, w = skewed_weights),
    5.291503,
    tolerance = 0.001
  )
  testthat::expect_equal(
    caretEnsemble::wtd.sd(skewed_values, w = skewed_weights * 100L),
    caretEnsemble::wtd.sd(skewed_values, w = skewed_weights),
    tolerance = 0.001
  )
})
######################################################################
testthat::context("set_excluded_class_id")
######################################################################
testthat::test_that("set_excluded_class_id warning if unset", {
  # Simulate an ensemble that has no excluded_class_id stored.
  old_ensemble <- ens.class
  old_ensemble$excluded_class_id <- NULL
  is_class <- isClassifier(old_ensemble)

  # Assign inside the expectation instead of using expect_warning()'s return
  # value: depending on the testthat edition that return value is either the
  # expression's value or the captured condition. The assignment is
  # evaluated in this test's environment, so new_ensemble is always the
  # updated ensemble. The call is also namespace-qualified like every other
  # expectation in this file.
  testthat::expect_warning(
    new_ensemble <- set_excluded_class_id(old_ensemble, is_class),
    "No excluded_class_id set. Setting to 1L."
  )
  testthat::expect_identical(new_ensemble$excluded_class_id, 1L)
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.