Nothing
#' Plot tuning search results
#'
#' @param object A tibble of results from [tune_grid()] or [tune_bayes()].
#' @param type A single character value. Choices are `"marginals"` (for a plot
#' of each tuning parameter versus performance; see Details below), `"parameters"`
#' (each parameter versus search iteration), or `"performance"` (performance
#' versus iteration). The latter two choices are only used for [tune_bayes()].
#' @param metric A character vector of metric names to plot, or `NULL`. By
#' default (`NULL`), all metrics will be shown via facets.
#' @param width A number for the width of the confidence interval bars when
#' `type = "performance"`. A value of zero prevents them from being shown. By
#' default (`NULL`), the width is scaled to the number of search iterations.
#' @param ... For plots with a regular grid, this is passed to `format()` and is
#' applied to a parameter used to color points. Otherwise, it is not used.
#' @return A `ggplot2` object.
#' @details
#'
#' When the results of `tune_grid()` are used with `autoplot()`, it tries to
#' determine whether a _regular grid_ was used.
#'
#' ## Regular grids
#'
#' For regular grids with one or more numeric tuning parameters, the parameter
#' with the most unique values is used on the x-axis. If there are categorical
#' parameters, the first is used to color the geometries. All other parameters
#' are used in column faceting.
#'
#' The plot has the performance metric(s) on the y-axis. If there are multiple
#' metrics, these are row-faceted.
#'
#' If there are more than five tuning parameters besides the one on the x-axis
#' (i.e., more than six in total), the "marginal effects" plots are used instead.
#'
#' ## Irregular grids
#'
#' For space-filling or random grids, a _marginal_ effect plot is created. A
#' panel is made for each numeric parameter so that each parameter is on the
#' x-axis and performance is on the y-axis. If there are multiple metrics, these
#' are row-faceted.
#'
#' A single categorical parameter is shown as colors. If there are two or more
#' non-numeric parameters, an error is given. A similar result occurs if only
#' non-numeric parameters are in the grid. In these cases, we suggest using
#' `collect_metrics()` and `ggplot()` to create a plot that is appropriate for
#' the data.
#'
#' If a parameter has a transformation associated with it (as
#' determined by the parameter object used to create it), the plot shows the
#' values in the transformed units (and is labeled with the transformation type).
#'
#' Parameters are labeled using the labels found in the parameter object
#' _except_ when an identifier was used (e.g. `neighbors = tune("K")`).
#'
#' @seealso [tune_grid()], [tune_bayes()]
#' @examplesIf tune:::should_run_examples()
#' # For grid search:
#' data("example_ames_knn")
#'
#' # Plot the tuning parameter values versus performance
#' autoplot(ames_grid_search, metric = "rmse")
#'
#'
#' # For iterative search:
#' # Plot the tuning parameter values versus performance
#' autoplot(ames_iter_search, metric = "rmse", type = "marginals")
#'
#' # Plot tuning parameters versus iterations
#' autoplot(ames_iter_search, metric = "rmse", type = "parameters")
#'
#' # Plot performance over iterations
#' autoplot(ames_iter_search, metric = "rmse", type = "performance")
#' @export
autoplot.tune_results <-
  function(object,
           type = c("marginals", "parameters", "performance"),
           metric = NULL,
           width = NULL,
           ...) {
    # Validate `type` against the three allowed plot kinds.
    type <- match.arg(type)

    # The "parameters" and "performance" plots need an `.iter` column.
    if (type != "marginals" && !(".iter" %in% names(object))) {
      rlang::abort(paste0("`type = ", type, "` is only used iterative search results."))
    }

    # Every tuning parameter needs a parameter object for labels/transforms.
    pset <- .get_tune_parameters(object)
    lacks_object <- is.na(pset$object)
    if (any(lacks_object)) {
      quoted_ids <- paste0("'", pset$id[lacks_object], "'", collapse = ", ")
      rlang::abort(paste0(
        "Some parameters do not have corresponding parameter objects ",
        "and cannot be used with `autoplot()`: ",
        quoted_ids
      ))
    }

    # Dispatch to the plotting helper; only the regular-grid plot uses `...`.
    switch(type,
      parameters = plot_param_vs_iter(object),
      performance = plot_perf_vs_iter(object, metric, width),
      marginals = if (use_regular_grid_plot(object)) {
        plot_regular_grid(object, metric = metric, ...)
      } else {
        plot_marginals(object, metric = metric)
      }
    )
  }
#' @export
autoplot.resample_results <- function(object, ...) {
  # No plot is provided for this class; always signal an error.
  no_method_msg <- "There is no `autoplot()` implementation for `resample_results`."
  rlang::abort(no_method_msg)
}
# ------------------------------------------------------------------------------
# Return the `parameters` attribute of `x`, or `NULL` when it is absent.
get_param_object <- function(x) {
  # `exact = TRUE` prevents partial matching of longer attribute names.
  attr(x, "parameters", exact = TRUE)
}
# Names of the tuning-parameter columns in `x`.
#
# Uses the ids from the `parameters` attribute when present; otherwise returns
# every column of `collect_metrics(x)` that is not a known metric/bookkeeping
# column.
get_param_columns <- function(x) {
  param_info <- get_param_object(x)
  if (!is.null(param_info)) {
    return(param_info$id)
  }
  bookkeeping <- c(
    ".metric", ".estimator", "mean", "n",
    "std_err", ".iter", ".config"
  )
  metric_cols <- names(collect_metrics(x))
  metric_cols[!metric_cols %in% bookkeeping]
}
# Use the user-given id for the parameter or the parameter label?
get_param_label <- function(x, id_val) {
x <- tibble::as_tibble(x)
y <- dplyr::filter(x, id == id_val) %>% dplyr::slice(1)
num_param <- sum(x$name == y$name)
no_special_id <- y$name == y$id
if (no_special_id && num_param == 1) {
res <- y$object[[1]]$label
} else {
res <- id_val
}
res
}
# ------------------------------------------------------------------------------
# Does `x` look like a (nearly) full factorial design?
#
# Consider the full factorial over the distinct values of each column, and its
# full outer join with `x`. This returns whether the fraction of joined rows
# that come from `x` is at least `cutoff`. That fraction equals
#   n / (n + <factorial combinations absent from x>),
# so it is computed arithmetically instead of materializing the factorial.
# The factorial has prod(k_j) rows and can exhaust memory, e.g. 5 columns with
# 60 distinct values each would need ~7.8e8 rows.
#
# @param x A data frame of tuning parameter values.
# @param cutoff Minimum fraction needed to call the design factorial.
# @return `TRUE`/`FALSE`; `NA` when `x` has no rows and a non-factor column.
is_factorial <- function(x, cutoff = 0.95) {
  n_rows <- nrow(x)
  # Values each column contributes to the factorial. For factors this counts
  # every level (plus `NA` if present), matching how `tidyr::crossing()`
  # expands factor columns.
  count_values <- function(col) {
    if (is.factor(col)) {
      nlevels(col) + (anyNA(col) && !anyNA(levels(col)))
    } else {
      length(unique(col))
    }
  }
  n_combos <- prod(vapply(x, count_values, numeric(1)))
  # Each distinct row of `x` matches exactly one combination of the factorial.
  n_observed <- nrow(dplyr::distinct(x))
  n_rows / (n_rows + n_combos - n_observed) >= cutoff
}
# Heuristically decide whether `grid` (one row per candidate) is a regular grid.
#
# One column is always regular. Grids with 2-5 columns that are (nearly) full
# factorials are regular. Otherwise the decision uses the largest per-column
# fraction of unique values and the parameter/candidate ratio.
is_regular_grid <- function(grid) {
  n_cand <- nrow(grid)
  n_param <- ncol(grid)

  if (n_param == 1) {
    return(TRUE)
  }
  if (n_param <= 5 && is_factorial(grid)) {
    return(TRUE)
  }

  frac_unique <- vapply(grid, function(col) length(unique(col)), integer(1)) / n_cand
  top_frac <- max(frac_unique, na.rm = TRUE)
  param_ratio <- n_param / n_cand

  # Thresholds derived from simulation data and a C5.0 tree
  if (top_frac > 1 / 2) {
    FALSE
  } else if (top_frac <= 1 / 6) {
    TRUE
  } else {
    param_ratio > 0.05
  }
}
# This will eventually change and use a parameters object.
use_regular_grid_plot <- function(x) {
dat <- collect_metrics(x)
param_cols <- get_param_columns(x)
grd <- dat %>%
dplyr::select(all_of(param_cols)) %>%
distinct()
is_regular_grid(grd)
}
# ------------------------------------------------------------------------------
# Plot the resampled performance estimate against search iteration.
#
# Points show the mean of each metric per iteration. For rows with
# `.iter > 0` and a positive standard error, 95% t-based confidence interval
# bars are added.
#
# @param x Tuning results containing an `.iter` column.
# @param metric Optional character vector of metrics to keep.
# @param width Error bar width; `NULL` scales it to the largest iteration
#   (`max(.iter) / 75`) and `0` suppresses the bars.
# @return A ggplot object.
plot_perf_vs_iter <- function(x, metric = NULL, width = NULL) {
  if (is.null(width)) {
    width <- max(x$.iter) / 75
  }
  x <- estimate_tune_results(x)
  if (!is.null(metric)) {
    x <- x %>% dplyr::filter(.metric %in% metric)
  }
  x <- x %>% dplyr::filter(!is.na(mean))
  # A t interval for the mean of `n` resamples uses `n - 1` degrees of freedom.
  # Rows with a positive standard error should have n >= 2; `pmax()` only
  # stops `qt()` from warning on any rows that `ifelse()` discards anyway.
  search_iter <-
    x %>%
    dplyr::filter(.iter > 0 & std_err > 0) %>%
    dplyr::mutate(const = ifelse(n > 1, qt(0.975, pmax(n - 1, 1)), 0))
  p <-
    ggplot(x, aes(x = .iter, y = mean)) +
    geom_point() +
    xlab("Iteration")
  # Scalar condition: use short-circuiting `&&`.
  if (nrow(search_iter) > 0 && width > 0) {
    p <-
      p +
      geom_errorbar(
        data = search_iter,
        aes(ymin = mean - const * std_err, ymax = mean + const * std_err),
        width = width
      )
  }
  if (length(unique(x$.metric)) > 1) {
    p <- p + facet_wrap(~.metric, scales = "free_y")
  } else {
    p <- p + ylab(unique(x$.metric))
  }
  p
}
# Plot each numeric tuning parameter's value against search iteration.
#
# Parameters with a transformation are shown in transformed units. Their panel
# name carries the transformation name; other panels use the parameter label
# (or id, see `get_param_label()`).
plot_param_vs_iter <- function(x) {
  param_cols <- get_param_columns(x)
  pset <- get_param_object(x)
  if (is.null(pset)) {
    rlang::abort("`autoplot()` requires objects made with tune version 0.1.0 or later.")
  }

  # Collect the per-iteration estimates and find the numeric parameters.
  res <- estimate_tune_results(x)
  numeric_mask <- purrr::map_lgl(
    dplyr::select(res, dplyr::all_of(param_cols)),
    is.numeric
  )
  plot_cols <- param_cols[numeric_mask]

  # Transform values where needed and rename columns to their display labels.
  for (param_id in param_cols) {
    param_obj <- pset$object[[which(pset$id == param_id)]]
    display <- get_param_label(pset, param_id)
    if (!is.null(param_obj$trans)) {
      res[[param_id]] <- param_obj$trans$transform(res[[param_id]])
      display <- paste0(display, " (", param_obj$trans$name, ")")
    }
    names(res)[names(res) == param_id] <- display
    plot_cols[plot_cols == param_id] <- display
  }

  # One row per (iteration, parameter) so each parameter gets its own facet.
  long <- tidyr::pivot_longer(
    dplyr::select(res, .iter, dplyr::all_of(plot_cols)),
    cols = dplyr::all_of(plot_cols)
  )

  ggplot(long, aes(x = .iter, y = value)) +
    geom_point() +
    xlab("Iteration") +
    ylab("") +
    facet_wrap(~name, scales = "free_y")
}
# Marginal-effects plot: performance versus each numeric tuning parameter.
#
# Each numeric parameter gets its own x-axis panel. A single non-numeric
# parameter (if any) is mapped to color. Parameters with fewer than two
# distinct values are dropped. For racing results, point alpha and size show
# the number of resamples.
#
# @param x Tuning results.
# @param metric Optional character vector of metrics to keep.
# @return A ggplot object.
plot_marginals <- function(x, metric = NULL) {
  param_cols <- get_param_columns(x)
  pset <- get_param_object(x)
  if (is.null(pset)) {
    rlang::abort("`autoplot()` requires objects made with tune version 0.1.0 or later.")
  }
  # ----------------------------------------------------------------------------
  # Collect and filter resampling results
  is_race <- inherits(x, "tune_race")
  x <- collect_metrics(x)
  if (!is.null(metric)) {
    x <- x %>% dplyr::filter(.metric %in% metric)
  }
  x <- x %>% dplyr::filter(!is.na(mean))
  # ----------------------------------------------------------------------------
  # Drop parameters with a single value (nothing to plot against), then split
  # the remainder into numeric (x-axis panels) and non-numeric (color).
  param_vals <- x %>% dplyr::select(dplyr::all_of(param_cols))
  is_num <- purrr::map_lgl(param_vals, is.numeric)
  num_val <- purrr::map_int(param_vals, ~ length(unique(.x)))
  if (any(num_val < 2)) {
    rm_param <- param_cols[num_val < 2]
    param_cols <- param_cols[num_val >= 2]
    is_num <- is_num[num_val >= 2]
    x <- x %>% dplyr::select(-dplyr::all_of(rm_param))
  }
  num_param_cols <- param_cols[is_num]
  chr_param_cols <- param_cols[!is_num]
  if (length(chr_param_cols) > 1) {
    rlang::abort("Currently cannot autoplot grids with 2+ non-numeric parameters.")
  }
  if (length(chr_param_cols) > 0 && length(num_param_cols) == 0) {
    rlang::abort("Currently cannot autoplot grids with only non-numeric parameters.")
  }
  # ----------------------------------------------------------------------------
  # Transform and re-label when needed. Previous vectors of names are updated.
  for (prm in param_cols) {
    pobj <- pset$object[[which(pset$id == prm)]]
    lab <- get_param_label(pset, prm)
    if (!is.null(pobj$trans)) {
      x[[prm]] <- pobj$trans$transform(x[[prm]])
      new_name <- paste0(lab, " (", pobj$trans$name, ")")
    } else {
      new_name <- lab
    }
    names(x)[names(x) == prm] <- new_name
    num_param_cols[num_param_cols == prm] <- new_name
    chr_param_cols[chr_param_cols == prm] <- new_name
    param_cols[param_cols == prm] <- new_name
  }
  # ----------------------------------------------------------------------------
  # Stack numeric parameters for faceting.
  x <-
    x %>%
    dplyr::rename(`# resamples` = n) %>%
    dplyr::select(dplyr::all_of(param_cols), mean, `# resamples`, .metric) %>%
    tidyr::pivot_longer(cols = dplyr::all_of(num_param_cols))
  # ----------------------------------------------------------------------------
  p <- ggplot(x, aes(x = value, y = mean))
  if (length(chr_param_cols) > 0) {
    if (is_race) {
      # The resample count column is `# resamples` (renamed above); there is
      # no `resamples` column to map to size.
      p <- p + geom_point(aes(
        col = !!sym(chr_param_cols),
        alpha = `# resamples`,
        size = `# resamples`
      ))
    } else {
      p <- p + geom_point(aes(col = !!sym(chr_param_cols)), alpha = .7)
    }
    p <- p + ggplot2::labs(color = chr_param_cols)
  } else {
    if (is_race) {
      p <- p + geom_point(aes(alpha = `# resamples`, size = `# resamples`))
    } else {
      p <- p + geom_point(alpha = .7)
    }
  }
  if (length(unique(x$.metric)) > 1) {
    if (length(num_param_cols) == 1) {
      p <-
        p +
        facet_wrap(~.metric, scales = "free_y") +
        xlab(num_param_cols) +
        ylab("")
    } else {
      p <-
        p +
        ggplot2::facet_grid(.metric ~ name, scales = "free") +
        xlab("") +
        ylab("")
    }
  } else {
    if (length(num_param_cols) == 1) {
      p <- p + xlab(num_param_cols) + ylab(unique(x$.metric))
    } else {
      p <-
        p +
        facet_wrap(~name, scales = "free_x") +
        xlab("") +
        ylab(unique(x$.metric))
    }
  }
  p
}
# Plot performance for a regular grid of candidates.
#
# The numeric parameter with the most distinct values goes on the x-axis (or
# the first non-numeric one if none are numeric). The first remaining
# parameter is mapped to color/grouping, after `format(...)` if numeric. Any
# others become column facets, and multiple metrics become row facets. With
# more than five non-x-axis parameters, falls back to `plot_marginals()`.
#
# @param x Tuning results.
# @param metric Optional character vector of metrics to keep.
# @param ... Passed to `format()` for a numeric color parameter.
# @return A ggplot object.
plot_regular_grid <- function(x, metric = NULL, ...) {
  # Collect and filter resampling results
  is_race <- inherits(x, "tune_race")
  dat <- collect_metrics(x)
  if (!is.null(metric)) {
    dat <- dat %>% dplyr::filter(.metric %in% metric)
    if (nrow(dat) == 0) {
      # Collapse the metric names so several requested metrics still give a
      # single message (identical to before for a single metric).
      rlang::abort(paste0(
        "After filtering for metric ",
        paste0("'", metric, "'", collapse = ", "),
        ", there were no data points."
      ))
    }
  }
  dat <- dat %>% dplyr::filter(!is.na(mean))
  multi_metrics <- length(unique(dat$.metric)) > 1
  # ----------------------------------------------------------------------------
  # Get information about parameters
  param_cols <- get_param_columns(x)
  pset <- get_param_object(x)
  if (is.null(pset)) {
    rlang::abort("`autoplot()` requires objects made with tune version 0.1.0 or later.")
  }
  grd <- dat %>% dplyr::select(dplyr::all_of(param_cols))
  # ----------------------------------------------------------------------------
  # Determine which parameter goes on the x-axis and their types
  is_num <- purrr::map_lgl(grd, is.numeric)
  num_param_cols <- param_cols[is_num]
  chr_param_cols <- param_cols[!is_num]
  num_values <- purrr::map_int(grd[, num_param_cols], ~ length(unique(.x)))
  num_values <- sort(num_values, decreasing = TRUE)
  if (!any(is_num)) {
    x_col <- chr_param_cols[1]
    grp_cols <- chr_param_cols[-1]
  } else {
    x_col <- names(num_values)[1]
    grp_cols <- c(chr_param_cols, names(num_values)[-1])
  }
  g <- length(grp_cols)
  # ----------------------------------------------------------------------------
  if (g > 5) {
    return(plot_marginals(x, metric))
  }
  # ----------------------------------------------------------------------------
  # The x-axis transformation (if any) is applied through the ggplot scale, so
  # the data stay in their original units.
  x_col_prm <- pset$object[[which(pset$id == x_col)]]
  if (inherits(x_col_prm, "quant_param") && !is.null(x_col_prm$trans)) {
    trans <- x_col_prm$trans
  } else {
    trans <- NULL
  }
  # Re-label parameters. Previous vectors of names are updated.
  for (prm in param_cols) {
    new_name <- get_param_label(pset, prm)
    names(dat)[names(dat) == prm] <- new_name
    grp_cols[grp_cols == prm] <- new_name
    x_col[x_col == prm] <- new_name
    param_cols[param_cols == prm] <- new_name
  }
  # ----------------------------------------------------------------------------
  dat <-
    dat %>%
    dplyr::rename(`# resamples` = n) %>%
    dplyr::select(dplyr::all_of(param_cols), mean, `# resamples`, .metric) %>%
    tidyr::pivot_longer(cols = dplyr::all_of(x_col))
  # ------------------------------------------------------------------------------
  if (g >= 1) {
    # The first grouping parameter is assigned to the color aesthetic. If it
    # is numeric, it is converted to character using format().
    col_name <- grp_cols[1]
    if (is.numeric(dat[[col_name]])) {
      dat[[col_name]] <- format(dat[[col_name]], ...)
    }
    col_sym <- rlang::sym(col_name)
    p <- ggplot(dat, aes(value, y = mean, col = !!col_sym, group = !!col_sym))
    # `col_name` holds the parameter id or label, so use it in the key.
    p <- p + ggplot2::labs(color = col_name, x = x_col)
    if (g >= 2) {
      # With 2 - 5 grouping parameters, the others are assigned to column
      # facets. Row facets are for performance metrics.
      facets <- purrr::map(grp_cols[-1], sym)
      facets <- rlang::quos(!!!facets)
      if (multi_metrics) {
        p <- p + facet_grid(
          rows = vars(.metric), vars(!!!facets),
          labeller = ggplot2::labeller(.cols = ggplot2::label_both),
          scales = "free_y"
        )
      } else {
        p <-
          p + facet_wrap(vars(!!!facets),
            labeller = ggplot2::labeller(.cols = ggplot2::label_both)
          )
      }
    } else if (multi_metrics) {
      p <- p + facet_grid(rows = vars(.metric), scales = "free_y")
    }
  } else {
    # Only a single parameter and potentially multiple metrics. Label the
    # x-axis like every other parameter: the object's label unless a custom
    # id was given (see `get_param_label()`).
    p <- ggplot(dat, aes(x = value, y = mean)) + xlab(x_col)
    if (multi_metrics) {
      p <- p + facet_wrap(~.metric, scales = "free_y", ncol = 1)
    }
  }
  if (is_race) {
    p <- p + geom_point(aes(alpha = `# resamples`, size = `# resamples`))
  } else {
    p <- p + geom_point(size = 1)
  }
  if (multi_metrics) {
    p <- p + ylab("")
  } else {
    p <- p + ylab(dat$.metric[1])
  }
  if (any(is_num)) {
    p <- p + geom_line()
  }
  if (!is.null(trans)) {
    p <- p + ggplot2::scale_x_continuous(trans = trans)
  }
  p
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.