# Nothing
# Load required data
# Packaged fixture datasets are loaded by name; several of them (for example
# models.reg, models.class, X.reg, X.class, Y.class) are referenced by the
# tests further down in this file.
utils::data(models.reg)
utils::data(X.reg)
utils::data(Y.reg)
utils::data(models.class)
utils::data(X.class)
utils::data(Y.class)
# Sonar comes from the mlbench package and is used by the custom tuning list test.
utils::data(Sonar, package = "mlbench")
# Set up test environment
# Seed the RNG before the shared ensembles are built.
set.seed(1234L)
# NOTE(review): `k` is not referenced anywhere else in the visible file —
# confirm whether it is still needed.
k <- 2L
# Shared ensembles built once from the fixture model lists and reused by
# several tests below.
ens.reg <- caretEnsemble(models.reg)
ens.class <- caretEnsemble(models.class)
# Helper function for prediction tests
#
# Runs a common set of prediction expectations against a single ensemble.
#
# Arguments:
#   ens           - the ensemble under test; isClassifier(ens) selects the
#                   classification or regression expectations.
#   newdata       - predictor data handed to predict(); its row count is the
#                   expected number of prediction rows.
#   one_row_preds - expected prediction(s) for the first row of `newdata`:
#                   c(<Yes>, <No>) for classifiers, a single value otherwise.
#
# Called for its expectations only; the return value is not meaningful.
test_predictions <- function(ens, newdata, one_row_preds) {
  is_class <- isClassifier(ens)
  N <- nrow(newdata)
  # Predictions made without newdata are only compared loosely against the
  # newdata predictions, with a wider tolerance for classifiers.
  stacked_tolerance <- if (is_class) 0.35 else 0.05

  pred_stacked <- predict(ens)
  pred <- predict(ens, newdata = newdata)
  # Predict with se = TRUE on the ensemble and data that were passed in, so
  # the helper exercises the object under test rather than the global
  # regression fixtures.
  pred_se <- predict(ens, newdata = newdata, se = TRUE)
  pred_one <- predict(ens, newdata = newdata[1L, , drop = FALSE])

  # Every prediction flavour must come back as a data.table.
  testthat::expect_s3_class(pred_stacked, "data.table")
  testthat::expect_s3_class(pred, "data.table")
  testthat::expect_s3_class(pred_se, "data.table")
  testthat::expect_s3_class(pred_one, "data.table")

  # One prediction row per input row.
  testthat::expect_identical(nrow(pred_stacked), N)
  testthat::expect_identical(nrow(pred), N)
  testthat::expect_identical(nrow(pred_se), N)
  testthat::expect_identical(nrow(pred_one), 1L)

  testthat::expect_equal(pred_stacked, pred, tolerance = stacked_tolerance)

  if (is_class) {
    # Classifiers return one column per class level.
    testthat::expect_identical(ncol(pred_stacked), 2L)
    testthat::expect_identical(ncol(pred), 2L)
    testthat::expect_identical(ncol(pred_one), 2L)
    testthat::expect_equivalent(pred_one$Yes, one_row_preds[1L], tolerance = 0.05)
    testthat::expect_equivalent(pred_one$No, one_row_preds[2L], tolerance = 0.05)
  } else {
    testthat::expect_equivalent(pred_one$pred, one_row_preds[1L], tolerance = 0.05)
  }
}
######################################################################
testthat::context("Metric and residual extraction")
######################################################################
# Fits a linear model on iris through caret and checks that residuals() gives
# back one numeric residual per training row.
testthat::test_that("We can extract residuals from train regression objects", {
  data(iris)
  predictors <- iris[, 1L:2L]
  outcome <- iris[, 3L]
  fit <- caret::train(predictors, outcome, method = "lm")
  fit_residuals <- stats::residuals(fit)
  testthat::expect_is(fit_residuals, "numeric")
  testthat::expect_length(fit_residuals, 150L)
})
######################################################################
testthat::context("Ensembling and prediction")
######################################################################
# The shared regression ensemble has the right class and passes the common
# prediction checks.
testthat::test_that("We can ensemble regression models", {
  expected_first_row <- 5.04
  testthat::expect_s3_class(ens.reg, "caretEnsemble")
  test_predictions(ens.reg, X.reg, expected_first_row)
})
# Same checks for the shared classification ensemble; expected values are
# given in the order the helper reads them.
testthat::test_that("We can ensemble classification models", {
  expected_first_row <- c(0.02, 0.98)
  testthat::expect_s3_class(ens.class, "caretEnsemble")
  test_predictions(ens.class, X.class, expected_first_row)
})
######################################################################
testthat::context("Ensembling with models of differing predictors")
######################################################################
# Trains four glm models on different column subsets of the same design
# matrix, ensembles them, and checks predictions plus the error raised when
# newdata contains a missing value.
testthat::test_that("We can ensemble models of different predictors", {
  data(iris)
  iris_y <- iris[, 1L]
  iris_x <- model.matrix(~., iris[, -1L])
  set.seed(482L)
  my_control <- defaultControl(target = iris_y)

  # Design-matrix columns dropped for each model, keyed by model name.
  dropped_columns <- list(
    glm1 = c(-1L, -2L, -6L),
    glm2 = c(-1L, -3L, -6L),
    glm3 = c(-1L, -2L, -3L, -6L),
    glm4 = c(-1L, -4L, -6L)
  )
  fitted_models <- lapply(dropped_columns, function(cols) {
    caret::train(x = iris_x[, cols], y = iris_y, method = "glm", trControl = my_control)
  })
  nestedList <- as.caretList(fitted_models)

  # Predictions from the list: one row per observation, one column per model.
  pred_list <- predict(nestedList, newdata = iris_x)
  testthat::expect_s3_class(pred_list, "data.table")
  testthat::expect_identical(nrow(pred_list), 150L)
  testthat::expect_identical(ncol(pred_list), length(nestedList))

  # Predictions from the ensemble built on top of the list.
  ensNest <- caretEnsemble(nestedList)
  testthat::expect_s3_class(ensNest, "caretEnsemble")
  pred.nest <- predict(ensNest, newdata = iris_x)
  testthat::expect_s3_class(pred.nest, "data.table")
  testthat::expect_identical(nrow(pred.nest), 150L)

  # A single NA in newdata must be rejected.
  iris_x_with_na <- iris_x
  iris_x_with_na[2L, 3L] <- NA
  testthat::expect_error(
    predict(ensNest, newdata = iris_x_with_na),
    "is.finite(newdata) are not all TRUE",
    fixed = TRUE
  )
})
######################################################################
testthat::context("Ensembles with custom models")
######################################################################
# Builds an ensemble from two custom model definitions (each given a `method`
# name) plus a built-in model, then checks that a custom model definition
# without a `method` name is rejected.
testthat::test_that("Ensembles using custom models work correctly", {
  set.seed(1234L)

  # Custom model definitions are taken from getModelInfo() and labelled via
  # their `method` element.
  custom.rf <- getModelInfo("rf", regex = FALSE)[[1L]]
  custom.rf$method <- "custom.rf"
  custom.rpart <- getModelInfo("rpart", regex = FALSE)[[1L]]
  custom.rpart$method <- "custom.rpart"

  # The first spec is deliberately unnamed; the expected names below show it
  # ends up called "custom.rf".
  tune.list <- list(
    caretModelSpec(method = custom.rf, tuneLength = 1L),
    myrpart = caretModelSpec(method = custom.rpart, tuneLength = 1L),
    treebag = caretModelSpec(method = "treebag", tuneLength = 1L)
  )
  cl <- caretList(X.class, Y.class, tuneList = tune.list)
  cs <- caretEnsemble(cl)
  testthat::expect_is(cs, "caretEnsemble")
  testthat::expect_named(cs$models, c("custom.rf", "myrpart", "treebag"))
  test_predictions(cs, X.class, c(0.0198, 0.9802))

  # A custom model definition with no `method` element set must be refused.
  tune.list_bad <- list(
    caretModelSpec(method = getModelInfo("rf", regex = FALSE)[[1L]], tuneLength = 1L),
    treebag = caretModelSpec(method = "treebag", tuneLength = 1L)
  )
  # No trControl is passed, matching the successful caretList() call above;
  # the earlier `train.control` argument referred to an object that is not
  # defined in this file.
  testthat::expect_error(
    caretList(X.class, Y.class, tuneList = tune.list_bad),
    "Custom models must be defined with a \"method\" attribute"
  )
})
# Removing the saved predictions from a trained model must make
# stackedTrainResiduals() fail with an explanatory message.
testthat::test_that("Ensembles fails if predictions are not saved", {
  expected_message <-
    "No predictions saved during training. Please set savePredictions = 'final' in trainControl"
  model_without_preds <- models.reg[[1L]]
  model_without_preds[["pred"]] <- NULL
  testthat::expect_error(stackedTrainResiduals(model_without_preds), expected_message)
})
######################################################################
testthat::context("Variable importance and plotting")
######################################################################
# For both shared ensembles, with and without normalisation, varImp() must
# return a named numeric vector of the expected length; normalised
# importances must lie in [0, 1] and sum to one.
testthat::test_that("caret::varImp.caretEnsemble works", {
  set.seed(2239L)

  # Classifier importances are expected per model and class level, with the
  # two class entries of each model adjacent: rf_No, rf_Yes, glm_No, ...
  classifier_names <- paste(
    rep(c("rf", "glm", "rpart", "treebag"), each = 2L),
    c("No", "Yes"),
    sep = "_"
  )

  for (ensemble in list(ens.class, ens.reg)) {
    for (normalized in c(TRUE, FALSE)) {
      importance <- caret::varImp(ensemble, normalize = normalized)
      testthat::expect_is(importance, "numeric")

      if (isClassifier(ensemble)) {
        expected_length <- length(ensemble$models) * 2L
        expected_names <- classifier_names
      } else {
        expected_length <- length(ensemble$models)
        expected_names <- names(ensemble$models)
      }
      testthat::expect_length(importance, expected_length)
      testthat::expect_named(importance, expected_names)

      if (normalized) {
        testthat::expect_true(all(importance >= 0.0))
        testthat::expect_true(all(importance <= 1.0))
        testthat::expect_equal(sum(importance), 1.0, tolerance = 1e-6)
      }
    }
  }
})
# plot() on either shared ensemble must give a ggplot whose data has five
# rows, the last four of which are named after the component models.
testthat::test_that("plot.caretEnsemble works", {
  check_plot <- function(ensemble) {
    plot_obj <- plot(ensemble)
    plot_data <- plot_obj$data
    testthat::expect_is(plot_obj, "ggplot")
    testthat::expect_identical(nrow(plot_data), 5L)
    testthat::expect_named(ensemble$models, plot_data$model_name[-1L])
  }
  check_plot(ens.class)
  check_plot(ens.reg)
})
# autoplot() on either shared ensemble, with and without explicit xvars, must
# give an object that is both a ggplot and a patchwork; autoplot() on a single
# component model must raise the "not supported" error.
testthat::test_that("ggplot2::autoplot.caretEnsemble works", {
  unsupported_pattern <- "Objects of class (.*?) are not supported by autoplot"
  for (ensemble in list(ens.class, ens.reg)) {
    default_plot <- ggplot2::autoplot(ensemble)
    xvars_plot <- ggplot2::autoplot(ensemble, xvars = c("Petal.Length", "Petal.Width"))
    testthat::expect_is(default_plot, "ggplot")
    testthat::expect_is(xvars_plot, "ggplot")
    testthat::expect_is(default_plot, "patchwork")
    testthat::expect_is(xvars_plot, "patchwork")

    single_model <- ens.reg$models[[1L]]
    testthat::expect_error(ggplot2::autoplot(single_model), unsupported_pattern)
  }
})
# summary() must run silently on each shared ensemble, and the printed summary
# must mention the ensemble's metric and every component model name.
testthat::test_that("summary.caretEnsemble works", {
  for (ens in list(ens.class, ens.reg)) {
    # Use the loop variable so the classification and the regression ensemble
    # are both checked, not only ens.class.
    smry <- testthat::expect_silent(summary(ens))
    testthat::expect_output(print(smry), ens$ens_model$metric)
    for (name in names(ens$models)) {
      testthat::expect_output(print(smry), name)
    }
  }
})
# predict() with the first class excluded must return a data.table whose
# columns are pred/lwr/upr when se = TRUE, and a single prediction column
# ("Yes" for the classifier, "pred" for the regressor) when se = FALSE.
testthat::test_that("predict.caretEnsemble works with and without se and weights", {
  interval_columns <- c("pred", "lwr", "upr")
  for (ensemble in list(ens.class, ens.reg)) {
    point_column <- if (isClassifier(ensemble)) "Yes" else "pred"
    for (with_se in c(FALSE, TRUE)) {
      preds <- predict(ensemble, newdata = X.reg, se = with_se, excluded_class_id = 1L)
      testthat::expect_s3_class(preds, "data.table")
      expected_columns <- if (with_se) interval_columns else point_column
      testthat::expect_named(preds, expected_columns)
    }
  }
})
# Trains a caretList on Sonar from four model specs with different tuning
# settings, checks how many tuning results each model produced, checks that
# an extra argument reached nnet, and ensembles the list.
testthat::test_that("We can train and ensemble models with custom tuning lists", {
  target <- "Class"
  sonar_x <- Sonar[, setdiff(names(Sonar), target)]
  sonar_y <- Sonar[, target]

  rpart_grid <- data.table::data.table(.cp = c(0.01, 0.001, 0.1, 1.0))
  model_specs <- list(
    rpart = caretModelSpec(method = "rpart", tuneGrid = rpart_grid),
    knn = caretModelSpec(method = "knn", tuneLength = 9L),
    lda = caretModelSpec(method = "lda2", tuneLength = 1L),
    nnet = caretModelSpec(method = "nnet", tuneLength = 2L, trace = FALSE, softmax = FALSE)
  )
  custom_list <- caretList(x = sonar_x, y = sonar_y, tuneList = model_specs)
  testthat::expect_is(custom_list, "caretList")

  # Expected number of rows in each model's tuning results.
  expected_result_rows <- c(rpart = 4L, knn = 9L, lda = 1L, nnet = 4L)
  for (model_name in names(expected_result_rows)) {
    testthat::expect_identical(
      nrow(custom_list[[model_name]]$results),
      expected_result_rows[[model_name]]
    )
  }

  # softmax = FALSE from the spec must be visible on the fitted nnet model.
  testthat::expect_false(custom_list[["nnet"]]$finalModel$softmax)

  custom_ensemble <- caretEnsemble(custom_list)
  testthat::expect_is(custom_ensemble, "caretEnsemble")
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.