R/plots.R

Defines functions plot_regular_grid plot_marginals plot_param_vs_iter plot_perf_vs_iter use_regular_grid_plot is_regular_grid is_factorial get_param_label get_param_columns get_param_object autoplot.resample_results autoplot.tune_results

Documented in autoplot.tune_results

#' Plot tuning search results
#'
#' @param object A tibble of results from [tune_grid()] or [tune_bayes()].
#' @param type A single character value. Choices are `"marginals"` (for a plot
#'  of each predictor versus performance; see Details below), `"parameters"`
#'  (each parameter versus search iteration), or `"performance"` (performance
#'  versus iteration). The latter two choices are only used for [tune_bayes()].
#' @param metric A character vector or `NULL` for which metric to plot. By
#' default, all metrics will be shown via facets.
#' @param width A number for the width of the confidence interval bars when
#' `type = "performance"`. A value of zero prevents them from being shown.
#' @param ... For plots with a regular grid, this is passed to `format()` and is
#' applied to a parameter used to color points. Otherwise, it is not used.
#' @return A `ggplot2` object.
#' @details
#'
#' When the results of `tune_grid()` are used with `autoplot()`, it tries to
#'  determine whether a _regular grid_ was used.
#'
#' ## Regular grids
#'
#'   For regular grids with one or more numeric tuning parameters, the parameter
#'  with the most unique values is used on the x-axis. If there are categorical
#'  parameters, the first is used to color the geometries. All other parameters
#'  are used in column faceting.
#'
#'   The plot has the performance metric(s) on the y-axis. If there are multiple
#'  metrics, these are row-faceted.
#'
#'   If there are more than six tuning parameters (i.e., more than five besides
#'  the one on the x-axis), the "marginal effects" plots are used instead.
#'
#' ## Irregular grids
#'
#' For space-filling or random grids, a _marginal_ effect plot is created. A
#'  panel is made for each numeric parameter so that each parameter is on the
#'  x-axis and performance is on the y-axis. If there are multiple metrics, these
#'  are row-faceted.
#'
#' A single categorical parameter is shown as colors. If there are two or more
#'  non-numeric parameters, an error is given. A similar result occurs if only
#'  non-numeric parameters are in the grid. In these cases, we suggest using
#'  `collect_metrics()` and `ggplot()` to create a plot that is appropriate for
#'  the data.
#'
#' If a parameter has a transformation associated with it (as
#' determined by the parameter object used to create it), the plot shows the
#' values in the transformed units (and is labeled with the transformation type).
#'
#' Parameters are labeled using the labels found in the parameter object
#' _except_ when an identifier was used (e.g. `neighbors = tune("K")`).
#'
#' @seealso [tune_grid()], [tune_bayes()]
#' @examplesIf tune:::should_run_examples()
#' # For grid search:
#' data("example_ames_knn")
#'
#' # Plot the tuning parameter values versus performance
#' autoplot(ames_grid_search, metric = "rmse")
#'
#'
#' # For iterative search:
#' # Plot the tuning parameter values versus performance
#' autoplot(ames_iter_search, metric = "rmse", type = "marginals")
#'
#' # Plot tuning parameters versus iterations
#' autoplot(ames_iter_search, metric = "rmse", type = "parameters")
#'
#' # Plot performance over iterations
#' autoplot(ames_iter_search, metric = "rmse", type = "performance")
#' @export
autoplot.tune_results <-
  function(object,
           type = c("marginals", "parameters", "performance"),
           metric = NULL,
           width = NULL,
           ...) {
    type <- match.arg(type)

    # The "parameters" and "performance" plots are indexed by search
    # iteration, so they need the `.iter` column of iterative search results.
    has_iter <- ".iter" %in% names(object)
    if (!has_iter && type != "marginals") {
      rlang::abort(
        paste0("`type = ", type, "` is only used with iterative search results.")
      )
    }

    # Every tuning parameter needs a parameter object; the plotting helpers use
    # it to find labels and transformations.
    pset <- .get_tune_parameters(object)
    if (any(is.na(pset$object))) {
      p_names <- pset$id[is.na(pset$object)]
      msg <-
        paste0(
          "Some parameters do not have corresponding parameter objects ",
          "and cannot be used with `autoplot()`: ",
          paste0("'", p_names, "'", collapse = ", ")
        )
      rlang::abort(msg)
    }

    # Dispatch to the requested plot. For marginal plots, a regular grid gets
    # the richer grid layout; anything else gets the per-parameter panels.
    if (type == "parameters") {
      plot_param_vs_iter(object)
    } else if (type == "performance") {
      plot_perf_vs_iter(object, metric, width)
    } else if (use_regular_grid_plot(object)) {
      plot_regular_grid(object, metric = metric, ...)
    } else {
      plot_marginals(object, metric = metric)
    }
  }

#' @export
autoplot.resample_results <- function(object, ...) {
  # No plot method exists for this class; always signal an error.
  msg <- "There is no `autoplot()` implementation for `resample_results`."
  rlang::abort(msg)
}

# ------------------------------------------------------------------------------

# Return the "parameters" attribute attached to a tuning result, or NULL when
# the object carries no such attribute.
get_param_object <- function(x) {
  # `exact = TRUE` disables attr()'s partial matching, so only an attribute
  # named exactly "parameters" is returned.
  attr(x, "parameters", exact = TRUE)
}

# Names of the tuning-parameter columns of a tuning result.
#
# Uses the ids of the attached parameter set when present; otherwise, every
# column of `collect_metrics(x)` that is not a known bookkeeping column.
get_param_columns <- function(x) {
  prm <- get_param_object(x)
  if (!is.null(prm)) {
    return(prm$id)
  }

  non_param_cols <- c(
    ".metric", ".estimator", "mean", "n",
    "std_err", ".iter", ".config"
  )
  metric_cols <- names(collect_metrics(x))
  metric_cols[!(metric_cols %in% non_param_cols)]
}

# Use the user-given id for the parameter or the parameter label?
# Choose the display label for the parameter whose id is `id_val`.
#
# The label stored in the parameter object is used only when the id equals the
# parameter name (no user identifier such as `tune("K")`) and no other
# parameter shares that name. In every other case the id itself is returned.
get_param_label <- function(x, id_val) {
  param_tbl <- tibble::as_tibble(x)
  first_hit <- dplyr::slice(dplyr::filter(param_tbl, id == id_val), 1)

  name_count <- sum(param_tbl$name == first_hit$name)
  uses_default_id <- first_hit$name == first_hit$id
  if (!(uses_default_id && name_count == 1)) {
    return(id_val)
  }
  first_hit$object[[1]]$label
}

# ------------------------------------------------------------------------------

# Does `x` cover (nearly) every combination of its columns' distinct values?
#
# `x`: a data frame of candidate parameter values (one column per parameter).
# `cutoff`: the minimum fraction of factorial cells that must be observed.
# Returns a single logical; it is NA when `x` has no rows (0 / 0).
#
# The number of cells is computed arithmetically instead of by materializing
# the full crossing of all unique values. The crossing can need hundreds of
# millions of rows for a random grid with several numeric parameters.
# Factor columns are counted by their observed values, not their unused
# levels.
is_factorial <- function(x, cutoff = 0.95) {
  n_values <- vapply(x, function(col) length(unique(col)), integer(1))
  n_cells <- prod(n_values)
  # Count distinct rows so that duplicated candidates cannot inflate coverage.
  n_observed <- nrow(unique(x))
  n_observed / n_cells >= cutoff
}


# Heuristically decide whether a grid of candidate parameter values is regular.
#
# `grid`: a data frame with one column per tuning parameter and one row per
# distinct candidate. Returns a single logical.
is_regular_grid <- function(grid) {
  num_points <- nrow(grid)
  p <- ncol(grid)

  # A single parameter is always treated as a regular grid.
  if (p == 1) {
    return(TRUE)
  }

  # With few parameters, a (nearly) complete factorial design is regular.
  if (p <= 5 && is_factorial(grid)) {
    return(TRUE)
  }

  pct_unique <- vapply(grid, function(col) length(unique(col)), integer(1)) / num_points
  max_pct_unique <- max(pct_unique, na.rm = TRUE)
  np_ratio <- p / num_points

  # Derived from simulation data and a C5.0 tree. Written as one exhaustive
  # chain so that every input yields exactly one answer.
  if (max_pct_unique > 1 / 2) {
    FALSE
  } else if (max_pct_unique <= 1 / 6) {
    TRUE
  } else {
    np_ratio > 0.05
  }
}

# This will eventually change and use a parameters object.
# Should the regular-grid layout be used for these tuning results?
# Checks the distinct parameter combinations found in the collected metrics.
# NOTE: this will eventually change to use a parameters object.
use_regular_grid_plot <- function(x) {
  metrics <- collect_metrics(x)
  param_cols <- get_param_columns(x)
  candidate_grid <- dplyr::distinct(
    dplyr::select(metrics, dplyr::all_of(param_cols))
  )
  is_regular_grid(candidate_grid)
}

# ------------------------------------------------------------------------------

# Plot estimated performance against search iteration.
#
# `x`: iterative tuning results (must contain `.iter`).
# `metric`: optional character vector of metrics to keep.
# `width`: width of the error bars; NULL picks a value scaled to the number of
#   iterations, and 0 suppresses the bars.
# Returns a ggplot object; several metrics are shown as facets.
plot_perf_vs_iter <- function(x, metric = NULL, width = NULL) {
  if (is.null(width)) {
    width <- max(x$.iter) / 75
  }
  x <- estimate_tune_results(x)
  if (!is.null(metric)) {
    x <- x %>% dplyr::filter(.metric %in% metric)
  }
  x <- x %>% dplyr::filter(!is.na(mean))

  # 95% t-interval half-width multiplier for the mean over `n` resamples.
  # A t-interval for a mean uses `n - 1` degrees of freedom; this assumes
  # `std_err` is the usual sd / sqrt(n) (TODO confirm in
  # estimate_tune_results()). `pmax()` keeps qt() away from df = 0, which
  # would only emit NaN warnings for rows that ifelse() zeroes out anyway.
  search_iter <-
    x %>%
    dplyr::filter(.iter > 0 & std_err > 0) %>%
    dplyr::mutate(const = ifelse(n > 1, stats::qt(0.975, pmax(n - 1, 1)), 0))

  p <-
    ggplot(x, aes(x = .iter, y = mean)) +
    geom_point() +
    xlab("Iteration")

  if (nrow(search_iter) > 0 && width > 0) {
    p <-
      p +
      geom_errorbar(
        data = search_iter,
        aes(ymin = mean - const * std_err, ymax = mean + const * std_err),
        width = width
      )
  }

  if (length(unique(x$.metric)) > 1) {
    p <- p + facet_wrap(~.metric, scales = "free_y")
  } else {
    p <- p + ylab(unique(x$.metric))
  }
  p
}

# Plot each numeric tuning parameter's value against search iteration.
# One facet is drawn per parameter. Parameters that have a transformation are
# shown in transformed units, with the transformation name in the facet label.
plot_param_vs_iter <- function(x) {
  param_cols <- get_param_columns(x)
  pset <- get_param_object(x)
  if (is.null(pset)) {
    rlang::abort("`autoplot()` requires objects made with tune version 0.1.0 or later.")
  }

  # ----------------------------------------------------------------------------
  # Summarize the resampling results; only numeric parameters are plotted.

  res <- estimate_tune_results(x)
  param_is_num <- purrr::map_lgl(
    dplyr::select(res, dplyr::all_of(param_cols)),
    is.numeric
  )
  plot_cols <- param_cols[param_is_num]

  # ----------------------------------------------------------------------------
  # Put each parameter on its plotting scale and give it a display name.

  for (param_id in param_cols) {
    param_obj <- pset$object[[which(pset$id == param_id)]]
    display_name <- get_param_label(pset, param_id)
    if (!is.null(param_obj$trans)) {
      res[[param_id]] <- param_obj$trans$transform(res[[param_id]])
      display_name <- paste0(display_name, " (", param_obj$trans$name, ")")
    }
    names(res)[names(res) == param_id] <- display_name
    plot_cols[plot_cols == param_id] <- display_name
  }

  # ----------------------------------------------------------------------------
  # Stack the numeric parameters so each one becomes its own facet.

  long_res <- tidyr::pivot_longer(
    dplyr::select(res, .iter, dplyr::all_of(plot_cols)),
    cols = dplyr::all_of(plot_cols)
  )

  ggplot(long_res, aes(x = .iter, y = value)) +
    geom_point() +
    xlab("Iteration") +
    ylab("") +
    facet_wrap(~name, scales = "free_y")
}

# Marginal-effects plot: one panel per numeric parameter (x-axis) versus the
# performance estimate (y-axis).
#
# `x`: tuning results. `metric`: optional character vector of metrics to keep.
# A single non-numeric parameter is mapped to color. Two or more non-numeric
# parameters, or only non-numeric parameters, raise an error. Parameters that
# take a single value are dropped. For racing results, the number of resamples
# is mapped to point size and transparency.
plot_marginals <- function(x, metric = NULL) {
  param_cols <- get_param_columns(x)
  pset <- get_param_object(x)
  if (is.null(pset)) {
    rlang::abort("`autoplot()` requires objects made with tune version 0.1.0 or later.")
  }

  # ----------------------------------------------------------------------------
  # Collect and filter resampling results

  is_race <- inherits(x, "tune_race")

  x <- collect_metrics(x)
  if (!is.null(metric)) {
    x <- x %>% dplyr::filter(.metric %in% metric)
  }
  x <- x %>% dplyr::filter(!is.na(mean))

  # ----------------------------------------------------------------------------
  # Check parameter types and drop parameters with a single value

  param_dat <- x %>% dplyr::select(dplyr::all_of(param_cols))
  is_num <- purrr::map_lgl(param_dat, is.numeric)
  num_val <- purrr::map_int(param_dat, ~ length(unique(.x)))

  # A parameter with fewer than two values has no marginal effect to show.
  if (any(num_val < 2)) {
    rm_param <- param_cols[num_val < 2]
    param_cols <- param_cols[num_val >= 2]
    is_num <- is_num[num_val >= 2]
    x <- x %>% dplyr::select(-dplyr::all_of(rm_param))
  }

  num_param_cols <- param_cols[is_num]
  chr_param_cols <- param_cols[!is_num]
  if (length(chr_param_cols) > 1) {
    rlang::abort("Currently cannot autoplot grids with 2+ non-numeric parameters.")
  }
  if (length(chr_param_cols) == 1 && length(num_param_cols) == 0) {
    rlang::abort("Currently cannot autoplot grids with only non-numeric parameters.")
  }

  # ----------------------------------------------------------------------------
  # Transform and re-label when needed. Previous vectors of names are updated.

  for (prm in param_cols) {
    pobj <- pset$object[[which(pset$id == prm)]]
    lab <- get_param_label(pset, prm)
    if (!is.null(pobj$trans)) {
      x[[prm]] <- pobj$trans$transform(x[[prm]])
      new_name <- paste0(lab, " (", pobj$trans$name, ")")
    } else {
      new_name <- lab
    }
    names(x)[names(x) == prm] <- new_name
    num_param_cols[num_param_cols == prm] <- new_name
    chr_param_cols[chr_param_cols == prm] <- new_name
    param_cols[param_cols == prm] <- new_name
  }

  # ----------------------------------------------------------------------------
  # Stack numeric parameters for faceting.

  x <-
    x %>%
    dplyr::rename(`# resamples` = n) %>%
    dplyr::select(dplyr::all_of(param_cols), mean, `# resamples`, .metric) %>%
    tidyr::pivot_longer(cols = dplyr::all_of(num_param_cols))

  # ----------------------------------------------------------------------------

  p <- ggplot(x, aes(x = value, y = mean))

  if (length(chr_param_cols) > 0) {
    if (is_race) {
      # The column was renamed to `# resamples` above; `resamples` does not exist.
      p <- p + geom_point(aes(
        col = !!sym(chr_param_cols),
        alpha = `# resamples`,
        size = `# resamples`
      ))
    } else {
      p <- p + geom_point(aes(col = !!sym(chr_param_cols)), alpha = .7)
    }
    p <- p + ggplot2::labs(color = chr_param_cols)
  } else if (is_race) {
    p <- p + geom_point(aes(alpha = `# resamples`, size = `# resamples`))
  } else {
    p <- p + geom_point(alpha = .7)
  }

  if (length(unique(x$.metric)) > 1) {
    if (length(num_param_cols) == 1) {
      p <-
        p +
        facet_wrap(~.metric, scales = "free_y") +
        xlab(num_param_cols) +
        ylab("")
    } else {
      p <-
        p +
        ggplot2::facet_grid(.metric ~ name, scales = "free") +
        xlab("") +
        ylab("")
    }
  } else {
    if (length(num_param_cols) == 1) {
      p <- p + xlab(num_param_cols) + ylab(unique(x$.metric))
    } else {
      p <-
        p +
        facet_wrap(~name, scales = "free_x") +
        xlab("") +
        ylab(unique(x$.metric))
    }
  }

  p
}


# Plot performance for a regular grid.
#
# The numeric parameter with the most unique values goes on the x-axis; if
# there is no numeric parameter, the first categorical one does. The first
# remaining ("grouping") parameter is mapped to color; any others become
# column facets, and metrics become row facets. Falls back to
# `plot_marginals()` when there are more than five grouping parameters.
#
# `x`: tuning results. `metric`: optional character vector of metrics to keep.
# `...`: passed to `format()` when the color parameter is numeric.
plot_regular_grid <- function(x, metric = NULL, ...) {
  # Collect and filter resampling results

  is_race <- inherits(x, "tune_race")

  dat <- collect_metrics(x)
  if (!is.null(metric)) {
    dat <- dat %>% dplyr::filter(.metric %in% metric)
    if (nrow(dat) == 0) {
      # Collapse so that several requested metrics still give one message.
      rlang::abort(paste0(
        "After filtering for metric '", paste0(metric, collapse = "', '"),
        "', there were no data points."
      ))
    }
  }
  dat <- dat %>% dplyr::filter(!is.na(mean))
  multi_metrics <- length(unique(dat$.metric)) > 1

  # ----------------------------------------------------------------------------
  # Get information about parameters

  param_cols <- get_param_columns(x)
  pset <- get_param_object(x)
  if (is.null(pset)) {
    rlang::abort("`autoplot()` requires objects made with tune version 0.1.0 or later.")
  }

  grd <- dat %>% dplyr::select(dplyr::all_of(param_cols))

  # ----------------------------------------------------------------------------
  # Determine which parameter goes on the x-axis and their types

  is_num <- purrr::map_lgl(grd, is.numeric)
  num_param_cols <- param_cols[is_num]
  chr_param_cols <- param_cols[!is_num]

  num_values <- purrr::map_int(grd[, num_param_cols], ~ length(unique(.x)))
  num_values <- sort(num_values, decreasing = TRUE)

  if (!any(is_num)) {
    x_col <- chr_param_cols[1]
    grp_cols <- chr_param_cols[-1]
  } else {
    x_col <- names(num_values)[1]
    grp_cols <- c(chr_param_cols, names(num_values)[-1])
  }

  g <- length(grp_cols)

  # ----------------------------------------------------------------------------

  if (g > 5) {
    return(plot_marginals(x, metric))
  }

  # ----------------------------------------------------------------------------
  # Transform and re-label when needed. Previous vectors of names are updated.

  # A transformation on the x-axis parameter is applied through the x scale
  # (see the end of this function) rather than to the data.
  x_col_prm <- pset$object[[which(pset$id == x_col)]]
  if (inherits(x_col_prm, "quant_param") && !is.null(x_col_prm$trans)) {
    trans <- x_col_prm$trans
  } else {
    trans <- NULL
  }

  for (prm in param_cols) {
    new_name <- get_param_label(pset, prm)
    names(dat)[names(dat) == prm] <- new_name
    grp_cols[grp_cols == prm] <- new_name
    x_col[x_col == prm] <- new_name
    param_cols[param_cols == prm] <- new_name
  }

  # ----------------------------------------------------------------------------

  dat <-
    dat %>%
    dplyr::rename(`# resamples` = n) %>%
    dplyr::select(dplyr::all_of(param_cols), mean, `# resamples`, .metric) %>%
    tidyr::pivot_longer(cols = dplyr::all_of(x_col))

  # ------------------------------------------------------------------------------

  if (g >= 1) {
    # There is at least one grouping parameter. The first is assigned to the
    # color aesthetic. If it is numeric, it is converted to character using
    # format().
    col_col <- grp_cols[1]
    if (is.numeric(dat[[col_col]])) {
      dat[[col_col]] <- format(dat[[col_col]], ...)
    }
    # `col_col` is a string, so convert it to a symbol and inject it with `!!`.
    col_sym <- rlang::sym(col_col)
    p <- ggplot(dat, aes(value, y = mean, col = !!col_sym, group = !!col_sym))
    # `col_col` holds either the parameter id or the parameter label; use it
    # as the legend title.
    p <- p + ggplot2::labs(color = col_col, x = x_col)

    if (g >= 2) {
      # With 2 - 5 grouping parameters, the others are assigned to column
      # facets. Row facets are used for performance metrics.
      facets <- rlang::syms(grp_cols[-1])
      facets <- rlang::quos(!!!facets)
      if (multi_metrics) {
        p <- p + facet_grid(
          rows = vars(.metric), vars(!!!facets),
          labeller = ggplot2::labeller(.cols = ggplot2::label_both),
          scales = "free_y"
        )
      } else {
        p <-
          p + facet_wrap(vars(!!!facets),
            labeller = ggplot2::labeller(.cols = ggplot2::label_both)
          )
      }
    } else if (multi_metrics) {
      p <- p + facet_grid(rows = vars(.metric), scales = "free_y")
    }
  } else {
    # Only a single parameter and potentially multiple metrics. `x_col`
    # already holds the label chosen by get_param_label(), so a user-supplied
    # identifier (e.g. `tune("K")`) is respected as documented.
    p <- ggplot(dat, aes(x = value, y = mean)) + xlab(x_col)
    if (multi_metrics) {
      p <- p + facet_wrap(~.metric, scales = "free_y", ncol = 1)
    }
  }

  if (is_race) {
    p <- p + geom_point(aes(alpha = `# resamples`, size = `# resamples`))
  } else {
    p <- p + geom_point(size = 1)
  }

  if (multi_metrics) {
    p <- p + ylab("")
  } else {
    p <- p + ylab(dat$.metric[1])
  }

  if (any(is_num)) {
    p <- p + geom_line()
  }

  if (!is.null(trans)) {
    p <- p + ggplot2::scale_x_continuous(trans = trans)
  }

  p
}

Try the tune package in your browser

Any scripts or data that you put into this service are public.

tune documentation built on Aug. 24, 2023, 1:09 a.m.