Nothing
#' Benchmark Multiple Models with Cross-Validation and Model-Specific Parameters
#'
#' Perform k-fold cross-validation on a list of models, using model-specific
#' parameters. Each model is scored in turn with \code{cross_val_score()}, and
#' the per-fold scores together with their mean are collected per model.
#' Supports verbose messages and a progress bar.
#'
#' @param models A named list of \code{Model$new(...)} objects to benchmark.
#'   Entries without a name are still benchmarked; they are labelled
#'   \code{model_<i>} in verbose output and never receive entries from
#'   \code{params}.
#' @param X A data frame or matrix of predictors.
#' @param y A vector of outcomes (factor for classification, numeric for
#'   regression).
#' @param cv Integer, number of cross-validation folds (default 5).
#' @param scoring Scoring metric: "rmse", "mae", "accuracy", or "f1"
#'   (default: auto-detected based on task). Passed through to
#'   \code{cross_val_score()}.
#' @param params Optional named list of lists, each sublist containing extra
#'   arguments to pass to the corresponding model's \code{fit()} call. Names
#'   must match `models`; entries whose name matches no model are ignored with
#'   a warning.
#' @param cl Optional number of clusters for parallel processing. Passed
#'   through to \code{cross_val_score()}.
#' @param show_progress Logical, whether to show a progress bar over the
#'   models (default FALSE).
#' @param verbose Logical, whether to print messages about each model
#'   (default TRUE).
#'
#' @return A list with one element per model, carrying the same names as
#'   `models`. Each element is a list with components \code{avg_score} (the
#'   mean of the CV scores) and \code{scores} (the scores returned by
#'   \code{cross_val_score()}).
#'
#' @examples
#' \dontrun{
#' library(randomForest)
#'
#' X <- iris[, 1:4]
#' y <- iris$Species
#'
#' models <- list(
#'   glm = Model$new(caret::train),
#'   rf = Model$new(randomForest::randomForest),
#'   xgb = Model$new(caret::train)
#' )
#'
#' params <- list(
#'   glm = list(method = "glmnet",
#'              tuneGrid = data.frame(alpha = 0, lambda = 0.01),
#'              trControl = caret::trainControl(method = "none")),
#'   rf = list(ntree = 150),
#'   xgb = list(method = "xgbTree",
#'              tuneGrid = data.frame(nrounds = 150, max_depth = 3, eta = 0.3,
#'                                    gamma = 0, colsample_bytree = 1,
#'                                    min_child_weight = 1, subsample = 1),
#'              trControl = caret::trainControl(method = "none"))
#' )
#'
#' results <- benchmark(models, X, y, cv = 5, params = params,
#'                      show_progress = TRUE, verbose = TRUE)
#' print(results)
#' }
#' @export
benchmark <- function(models, X, y, cv = 5L, scoring = NULL, params = NULL,
                      cl = NULL, show_progress = FALSE, verbose = TRUE) {
  n_models <- length(models)
  results <- vector("list", n_models)
  names(results) <- names(models)

  # Normalise the model names once so the loop only deals with scalars.
  # names() is NULL for a fully unnamed list and "" for unnamed entries of a
  # partially named one; both are treated as "unnamed".
  model_names <- names(models)
  if (is.null(model_names)) {
    model_names <- rep("", n_models)
  }
  unnamed <- is.na(model_names) | !nzchar(model_names)

  # Labels are used for messages only; unnamed entries get a positional label
  # so verbose output is never silently empty.
  labels <- model_names
  labels[unnamed] <- sprintf("model_%d", which(unnamed))

  # A params entry that matches no model would otherwise be dropped silently,
  # which usually indicates a typo in the name.
  if (!is.null(params)) {
    unknown <- setdiff(names(params), model_names[!unnamed])
    unknown <- unknown[!is.na(unknown) & nzchar(unknown)]
    if (length(unknown) > 0L) {
      warning(
        "Ignoring `params` entries that match no model: ",
        paste(unknown, collapse = ", "),
        call. = FALSE
      )
    }
  }

  # txtProgressBar() requires max > min, so skip the bar for an empty list.
  show_bar <- show_progress && n_models > 0L
  if (show_bar) {
    pb <- utils::txtProgressBar(min = 0, max = n_models, style = 3)
    # Close the bar even if a model fails part-way through the benchmark.
    on.exit(close(pb), add = TRUE)
  }

  for (i in seq_along(models)) {
    if (verbose) {
      cat(sprintf("\n[%d/%d] Fitting model: %s\n", i, n_models, labels[i]))
    }

    # Extract model-specific parameters if provided; only named models can
    # have an entry in `params`.
    has_params <- !unnamed[i] && !is.null(params) &&
      model_names[i] %in% names(params)
    extra_args <- if (has_params) params[[model_names[i]]] else list()

    scores <- cross_val_score(
      model = models[[i]],
      X = X,
      y = y,
      cv = cv,
      scoring = scoring,
      show_progress = FALSE,
      cl = cl,
      fit_params = extra_args
    )
    results[[i]] <- list(avg_score = mean(scores), scores = scores)

    if (verbose) {
      cat(sprintf(
        "Mean CV score for %s: %.4f\n", labels[i], results[[i]]$avg_score
      ))
    }
    if (show_bar) {
      utils::setTxtProgressBar(pb, i)
    }
  }

  results
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.