#' Spark ML -- Tuning
#'
#' Perform hyper-parameter tuning using either K-fold cross validation or train-validation split.
#'
#' @details \code{ml_cross_validator()} performs k-fold cross validation while \code{ml_train_validation_split()} performs tuning on one pair of train and validation datasets.
#'
#' \code{estimator_param_maps} is a named list with one element per stage to
#' tune. Each name is matched against the beginning of the stage uids and must
#' identify exactly one stage; otherwise an error is raised (in the example
#' below, \code{random_forest} identifies the random forest classifier stage).
#' Each element is itself a named list giving the candidate values of each
#' parameter. Every combination of the supplied values, across all listed
#' stages, is evaluated.
#'
#' @return The object returned depends on the class of \code{x}.
#'
#' \itemize{
#' \item \code{spark_connection}: When \code{x} is a \code{spark_connection}, the function returns an instance of a \code{ml_cross_validator} or \code{ml_train_validation_split} object.
#'
#' \item \code{ml_pipeline}: When \code{x} is a \code{ml_pipeline}, the function returns a \code{ml_pipeline} with
#' the tuning estimator appended to the pipeline.
#'
#' \item \code{tbl_spark}: When \code{x} is a \code{tbl_spark}, a tuning estimator is constructed then
#' immediately fit with the input \code{tbl_spark}, returning a \code{ml_cross_validator_model} or a
#' \code{ml_train_validation_split_model} object.
#' }
#'
#' @param x A \code{spark_connection}, \code{ml_pipeline}, or a \code{tbl_spark}.
#' @param uid A character string used to uniquely identify the ML estimator.
#' @param ... Optional arguments; currently unused.
#' @param estimator A \code{ml_estimator} object.
#' @param estimator_param_maps A named list of stages and hyper-parameter sets to tune. See details.
#' @param evaluator A \code{ml_evaluator} object, see \link{ml_evaluator}.
#' @param collect_sub_models Whether to collect a list of sub-models trained during tuning.
#' If set to \code{FALSE}, then only the single best sub-model will be available after fitting.
#' If set to \code{TRUE}, then all sub-models will be available. Warning: For large models, collecting
#' all sub-models can cause OOMs on the Spark driver.
#' @param parallelism The number of threads to use when running parallel algorithms. Default is 1 for serial execution.
#' @template roxlate-ml-seed
#' @name ml-tuning
NULL
# Validate a user-supplied hyper-parameter grid against the estimator's stages.
#
# Arguments:
# - expanded_params: list named by user-supplied stage names; each element is
#   a list of candidate parameter sets (named lists) for that stage.
# - stage_jobjs: list of stage Java objects, named by stage uid.
# - current_param_list: parameters currently set on each stage, in the same
#   order as `stage_jobjs`.
#
# Each user-supplied name must be a prefix of exactly one stage uid, otherwise
# an error is raised. For every candidate set, the stage's validator is run on
# the user values merged with the stage's current values and its constructor
# defaults, in that order of precedence. Only the validated user-supplied
# parameters are kept. Returns a list named by the matched stage uids.
ml_validate_params <- function(expanded_params, stage_jobjs, current_param_list) {
  stage_uids <- as.character(names(stage_jobjs))

  # Resolve each user-supplied name to the index of the single stage it
  # identifies. `startsWith()` does a literal prefix match, so characters such
  # as "." in the name are not interpreted as regex metacharacters.
  stage_indices <- purrr::imap_int(expanded_params, function(param_sets, user_input_name) {
    matched <- which(startsWith(stage_uids, as.character(user_input_name)))
    if (length(matched) > 1) {
      stop("The name ", user_input_name, " matches more than one stage in the pipeline.",
        call. = FALSE
      )
    }
    if (length(matched) == 0) {
      stop("The name ", user_input_name, " matches no stages in the pipeline.",
        call. = FALSE
      )
    }
    matched
  })

  validated <- purrr::map2(expanded_params, stage_indices, function(param_sets, matched) {
    stage_jobj <- stage_jobjs[[matched]]

    # The following depend only on the stage, not on the candidate set, so
    # they are computed once per stage.
    # Parameters currently specified in the pipeline object.
    current_params <- current_param_list[[matched]] %>%
      ml_map_param_list_names()
    # Defaults from the stage constructor's formals. Arguments without a
    # literal default (symbols/calls) and NULL defaults are dropped.
    default_params <- stage_jobj %>%
      ml_get_stage_constructor() %>%
      formals() %>%
      as.list() %>%
      purrr::discard(~ is.symbol(.x) || is.language(.x)) %>%
      purrr::compact()
    validation_function <- ml_get_stage_validator(stage_jobj)

    purrr::map(param_sets, function(params) {
      # Merge the arguments to validate. The precedence is as follows:
      # 1. User specified values in `estimator_param_maps`
      # 2. Values already set in pipeline `estimator`
      # 3. Default arguments based on constructor function
      # Each name appears at most once, so a lower-precedence value can never
      # shadow a higher-precedence one.
      input_param_names <- names(params)
      current_params_keep <- setdiff(names(current_params), input_param_names)
      default_params_keep <- setdiff(
        names(default_params),
        union(input_param_names, names(current_params))
      )
      args_to_validate <- c(
        params,
        current_params[current_params_keep],
        default_params[default_params_keep]
      )

      # Use the stage-specific validator when one is registered, otherwise
      # fall back to the generic estimator validation.
      validated_args <- if (is.null(validation_function)) {
        params_validate_estimator(stage_jobj, args_to_validate)
      } else {
        do.call(validation_function, list(args_to_validate))
      }

      # Return only the (validated) parameters the user asked to tune.
      validated_args[input_param_names]
    })
  })

  rlang::set_names(validated, stage_uids[stage_indices])
}
# Expand a named list of candidate values into every combination of them.
# The first element varies fastest. The result is a list with one entry per
# combination, each entry being a named list holding one value per input
# element.
cross_compat <- function(x) {
  grid <- rlang::exec(vctrs::vec_expand_grid, !!!x, .vary = "fastest")
  purrr::transpose(grid)
}
# Convert one R-side param map into a Spark `ParamMap` jobj.
#
# `param_map` is a list named by stage uid; each element is a named list of
# parameter values (R-side names). Each value is `put` into a new ParamMap
# against the corresponding Scala `Param` object, obtained by invoking the
# Scala-named param accessor on the stage in `stage_jobjs`. NULL values are
# skipped.
ml_spark_param_map <- function(param_map, sc, stage_jobjs) {
  # Build the (Param jobj, value) pairs for a single stage.
  stage_pairs <- function(param_set, stage_uid) {
    pairs <- purrr::imap(param_set, function(value, param_name) {
      # Get the Param object by calling `[stage].[param]` in Scala.
      scala_param <- invoke(
        stage_jobjs[[stage_uid]],
        ml_map_param_names(param_name, "rs")
      )
      list(param_jobj = scala_param, value = value)
    })
    purrr::discard(pairs, function(pair) is.null(pair[["value"]]))
  }

  all_pairs <- purrr::list_flatten(unname(purrr::imap(param_map, stage_pairs)))

  purrr::reduce(
    all_pairs,
    function(acc, pair) invoke(acc, "put", pair$param_jobj, pair$value),
    .init = invoke_new(sc, "org.apache.spark.ml.param.ParamMap")
  )
}
# Read the estimator param maps of a tuning jobj back into R.
#
# Each Spark ParamMap is converted with `sparklyr.MLUtils.paramMapToNestedList`
# and every element of the resulting list is passed through
# `ml_map_param_list_names()` (presumably mapping Scala param names to their R
# equivalents; see that helper).
ml_get_estimator_param_maps <- function(jobj) {
  sc <- spark_connection(jobj)
  spark_maps <- invoke(jobj, "getEstimatorParamMaps")
  nested_lists <- purrr::map(spark_maps, function(param_map) {
    invoke_static(sc, "sparklyr.MLUtils", "paramMapToNestedList", param_map)
  })
  purrr::map(nested_lists, function(stage_params) {
    lapply(stage_params, ml_map_param_list_names)
  })
}
# Create a Spark tuning validator jobj of class `class`.
#
# The estimator, evaluator and seed are set on the new object. When both an
# estimator and `estimator_param_maps` are supplied, the grid is expanded per
# stage, validated against the stages' current and default parameters,
# expanded across stages, converted into Spark ParamMaps and attached via
# `sparklyr.MLUtils.setParamMaps`. Returns the validator jobj.
ml_new_validator <- function(sc, class, uid, estimator, evaluator,
                             estimator_param_maps, seed) {
  uid <- cast_string(uid)
  possibly_spark_jobj <- possibly_null(spark_jobj)

  param_maps <- NULL
  if (!is.null(estimator) && !is.null(estimator_param_maps)) {
    # Stage uid -> stage jobj; a lone estimator is treated as a single stage.
    if (inherits(estimator, "ml_pipeline")) {
      stage_jobjs <- invoke_static(
        sc, "sparklyr.MLUtils", "uidStagesMapping", spark_jobj(estimator)
      )
    } else {
      stage_jobjs <- rlang::set_names(
        list(spark_jobj(estimator)), ml_uid(estimator)
      )
    }

    # Parameters currently set on each stage, as R lists.
    extracted_maps <- purrr::map(stage_jobjs, invoke, "extractParamMap")
    current_param_list <- purrr::map(extracted_maps, function(param_map) {
      invoke_static(sc, "sparklyr.MLUtils", "paramMapToList", param_map)
    })

    per_stage_grids <- purrr::map(estimator_param_maps, cross_compat)
    validated <- ml_validate_params(per_stage_grids, stage_jobjs, current_param_list)
    param_maps <- purrr::map(cross_compat(validated), ml_spark_param_map, sc, stage_jobjs)
  }

  validator <- invoke_new(sc, class, uid)
  validator <- jobj_set_param(validator, "setEstimator", possibly_spark_jobj(estimator))
  validator <- jobj_set_param(validator, "setEvaluator", possibly_spark_jobj(evaluator))
  validator <- jobj_set_param(validator, "setSeed", seed)

  if (is.null(param_maps)) {
    return(validator)
  }
  invoke_static(sc, "sparklyr.MLUtils", "setParamMaps", validator, param_maps)
}
# Construct the R wrapper for a tuning estimator jobj.
#
# The estimator, evaluator and estimator param maps are fetched from the jobj
# through `possibly_null()`, so a component that cannot be retrieved
# (presumably because it is unset; confirm against possibly_null) becomes NULL.
new_ml_tuning <- function(jobj, ..., class = character()) {
  # Invoke `getter` on the jobj and wrap the result in its R constructor.
  fetch_component <- function(getter) {
    possibly_null(function() ml_call_constructor(invoke(jobj, getter)))()
  }

  new_ml_estimator(
    jobj,
    estimator = fetch_component("getEstimator"),
    evaluator = fetch_component("getEvaluator"),
    estimator_param_maps = possibly_null(ml_get_estimator_param_maps)(jobj),
    ...,
    class = c(class, "ml_tuning")
  )
}
# Construct the R wrapper for a fitted tuning model jobj.
#
# The estimator, evaluator and best model are fetched from the jobj and
# wrapped with `ml_call_constructor()`; the estimator param maps are read back
# with `ml_get_estimator_param_maps()`. Unlike `new_ml_tuning()`, failures
# here are not converted to NULL.
new_ml_tuning_model <- function(jobj, ..., class = character()) {
  # Invoke `getter` on the jobj and wrap the result in its R constructor.
  wrap_from <- function(getter) ml_call_constructor(invoke(jobj, getter))

  new_ml_transformer(
    jobj,
    estimator = wrap_from("getEstimator"),
    evaluator = wrap_from("getEvaluator"),
    estimator_param_maps = ml_get_estimator_param_maps(jobj),
    best_model = wrap_from("bestModel"),
    ...,
    class = c(class, "ml_tuning_model")
  )
}
# Print the header shared by tuning estimators and models.
#
# Always prints the class and uid of `x`. When a hyperparameter grid is set,
# it also prints the estimator and evaluator (with the evaluator's metric),
# then the fold count (`type = "cv"`) or train ratio (`type = "tvs"`), and
# finally the number of parameter sets. The last line is written without a
# trailing newline.
print_tuning_info <- function(x, type = c("cv", "tvs")) {
  type <- match.arg(type)
  num_sets <- length(x$estimator_param_maps)

  ml_print_class(x)
  ml_print_uid(x)

  # Without a hyperparameter grid there is nothing tuning-specific to show.
  if (num_sets == 0) {
    return(invisible(NULL))
  }

  cat(" (Parameters -- Tuning)\n")

  estimator <- x$estimator
  if (!is.null(estimator)) {
    cat(paste0(" estimator: ", ml_short_type(estimator), "\n"))
    cat(" ")
    ml_print_uid(estimator)
  }

  evaluator <- x$evaluator
  if (!is.null(evaluator)) {
    cat(paste0(" evaluator: ", ml_short_type(evaluator), "\n"))
    cat(" ")
    ml_print_uid(evaluator)
    cat(" with metric", ml_param(evaluator, "metric_name"), "\n")
  }

  switch(type,
    cv = cat(" num_folds:", x$num_folds, "\n"),
    tvs = cat(" train_ratio:", x$train_ratio, "\n")
  )

  set_label <- if (num_sets == 1) "set]" else "sets]"
  cat(" [Tuned over", num_sets, "hyperparameter", set_label)
}
# Print a "(Best Model)" section: the printed representation of
# `x$best_model`, with every line indented by one space.
print_best_model <- function(x) {
  cat("\n (Best Model)\n")
  model_lines <- capture.output(print(x$best_model))
  indented <- paste0(" ", model_lines)
  cat(indented, sep = "\n")
}
# Print a summary of a fitted tuning model.
#
# Shows the model type and uid, what was tuned (estimator, metric, number of
# parameter sets), the validation strategy (k-fold cross validation for
# `type = "cv"`, train/validation ratio for `type = "tvs"`), the estimator and
# evaluator uids, and finally the metrics data frame (`avg_metrics_df` for cv,
# `validation_metrics_df` for tvs).
print_tuning_summary <- function(x, type = c("cv", "tvs")) {
  type <- match.arg(type)
  is_cv <- identical(type, "cv")
  num_sets <- length(x$estimator_param_maps)
  estimator <- x$estimator
  evaluator <- x$evaluator

  # Header.
  cat(paste0("Summary for ", ml_short_type(x)), "\n")
  cat(" ")
  ml_print_uid(x)
  cat("\n")

  # What was tuned, with which metric, and how.
  cat(paste0("Tuned ", ml_short_type(estimator), "\n"))
  cat(paste0(" with metric ", ml_param(evaluator, "metric_name"), "\n"))
  set_word <- if (num_sets == 1) "set" else "sets"
  cat(paste0(" over ", num_sets, " hyperparameter ", set_word), "\n")
  strategy <- if (is_cv) {
    paste0(x$num_folds, "-fold cross validation")
  } else {
    paste0(x$train_ratio, "/", 1 - x$train_ratio, " train-validation split")
  }
  cat(" via", strategy)
  cat("\n\n")

  # Components.
  cat(paste0("Estimator: ", ml_short_type(estimator), "\n"))
  cat(" ")
  ml_print_uid(estimator)
  cat(paste0("Evaluator: ", ml_short_type(evaluator), "\n"))
  cat(" ")
  ml_print_uid(evaluator)
  cat("\n")

  cat("Results Summary:", "\n")
  print(if (is_cv) x$avg_metrics_df else x$validation_metrics_df)
}
#' @rdname ml-tuning
#' @param model A cross validation or train-validation-split model.
#' @return For cross validation, \code{ml_sub_models()} returns a nested
#' list of models, where the first layer represents fold indices and the
#' second layer represents param maps. For train-validation split,
#' \code{ml_sub_models()} returns a list of models, corresponding to the
#' order of the estimator param maps.
#' @export
ml_sub_models <- function(model) {
  # `sub_models` is stored on the model as a function that retrieves the
  # sub-models when called; it is NULL when sub-model collection was not
  # enabled at tuning time.
  sub_models_fn <- model$sub_models
  if (is.null(sub_models_fn)) {
    stop(
      "Cannot extract sub models. `collect_sub_models` must be set to TRUE in ",
      "ml_cross_validator() or ml_train_validation_split().",
      call. = FALSE
    )
  }
  sub_models_fn()
}
#' @rdname ml-tuning
#' @return \code{ml_validation_metrics()} returns a data frame of performance
#' metrics and the hyperparameter combinations they were obtained with.
#' @examples
#' \dontrun{
#' sc <- spark_connect(master = "local")
#' iris_tbl <- sdf_copy_to(sc, iris, name = "iris_tbl", overwrite = TRUE)
#'
#' # Create a pipeline
#' pipeline <- ml_pipeline(sc) %>%
#' ft_r_formula(Species ~ .) %>%
#' ml_random_forest_classifier()
#'
#' # Specify hyperparameter grid
#' grid <- list(
#' random_forest = list(
#' num_trees = c(5, 10),
#' max_depth = c(5, 10),
#' impurity = c("entropy", "gini")
#' )
#' )
#'
#' # Create the cross validator object
#' cv <- ml_cross_validator(
#' sc,
#' estimator = pipeline, estimator_param_maps = grid,
#' evaluator = ml_multiclass_classification_evaluator(sc),
#' num_folds = 3,
#' parallelism = 4
#' )
#'
#' # Train the models
#' cv_model <- ml_fit(cv, iris_tbl)
#'
#' # Print the metrics
#' ml_validation_metrics(cv_model)
#' }
#'
#' @export
ml_validation_metrics <- function(model) {
  # Cross validation stores averaged metrics across folds.
  if (inherits(model, "ml_cross_validator_model")) {
    return(model$avg_metrics_df)
  }
  # Train-validation split stores metrics on the single validation set.
  if (inherits(model, "ml_train_validation_split_model")) {
    return(model$validation_metrics_df)
  }
  stop("ml_validation_metrics() must be called on `ml_cross_validator_model` ",
    "or `ml_train_validation_split_model`.",
    call. = FALSE
  )
}
# Flatten a list of param maps into a data frame, one row per param map when
# every parameter value is a scalar.
#
# Each param map is a list with one element per stage (a named list of
# parameter values). Every column name is suffixed with "_<k>", where k is the
# stage's position within the param map, so identically named parameters of
# different stages do not collide.
param_maps_to_df <- function(param_maps) {
  # Convert a single param map into its (column-suffixed) row(s).
  map_to_row <- function(param_map) {
    stage_frames <- lapply(param_map, data.frame, stringsAsFactors = FALSE)
    suffixed <- lapply(seq_along(stage_frames), function(k) {
      dplyr::rename_all(stage_frames[[k]], function(col) paste(col, k, sep = "_"))
    })
    dplyr::bind_cols(suffixed)
  }

  dplyr::bind_rows(lapply(param_maps, map_to_row))
}
# Validate and normalize the arguments shared by the tuning constructors.
#
# Casts `collect_sub_models`, `parallelism` and `seed` with the project's cast
# helpers. If the seed cast returns NULL, the `seed` entry is removed, because
# assigning NULL through `[[<-` deletes a list element. Then checks that any
# supplied `estimator`, `estimator_param_maps` and `evaluator` have the
# expected types. Returns the updated argument list.
validate_args_tuning <- function(.args) {
  .args[["collect_sub_models"]] <- cast_scalar_logical(.args[["collect_sub_models"]])
  .args[["parallelism"]] <- cast_scalar_integer(.args[["parallelism"]])
  .args[["seed"]] <- cast_nullable_scalar_integer(.args[["seed"]])

  # These are user-facing errors: `call. = FALSE` (as elsewhere in this file)
  # avoids showing users the internal helper call.
  if (!is.null(.args[["estimator"]]) && !inherits(.args[["estimator"]], "ml_estimator")) {
    stop("`estimator` must be an `ml_estimator`.", call. = FALSE)
  }
  if (!is.null(.args[["estimator_param_maps"]]) && !rlang::is_bare_list(.args[["estimator_param_maps"]])) {
    stop("`estimator_param_maps` must be a list.", call. = FALSE)
  }
  if (!is.null(.args[["evaluator"]]) && !inherits(.args[["evaluator"]], "ml_evaluator")) {
    stop("`evaluator` must be an `ml_evaluator`.", call. = FALSE)
  }
  .args
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.