R/rank_results.R

Defines functions get_num_resamples rank_results

Documented in rank_results

#' Rank the results by a metric
#'
#' This function sorts the results by a specific performance metric.
#'
#' @inheritParams collect_metrics.workflow_set
#' @param rank_metric A character string for a metric.
#' @param select_best A logical giving whether the results should only contain
#' the numerically best submodel per workflow.
#' @details
#' Rank 1 is the best result. "Best" depends on the direction of the ranking
#' metric: the largest mean for metrics that are maximized, the smallest mean
#' for metrics that are minimized, and the mean closest to zero for metrics
#' whose direction is `"zero"`.
#'
#' If some models have the exact same performance,
#' `rank(value, ties.method = "random")` is used (with a reproducible seed) so
#' that all ranks are integers.
#'
#' For racing results, configurations that were evaluated on fewer resamples
#' than the total number are removed before ranking.
#'
#' No columns are returned for the tuning parameters since they are likely to
#' be different (or not exist) for some models. The `wflow_id` and `.config`
#' columns can be used to determine the corresponding parameter values.
#' @return A tibble with columns: `wflow_id`, `.config`, `.metric`, `mean`,
#' `std_err`, `n`, `preprocessor`, `model`, and `rank`.
#'
#' @includeRmd man-roxygen/example_data.Rmd note
#'
#' @examples
#' chi_features_res
#'
#' rank_results(chi_features_res)
#' rank_results(chi_features_res, select_best = TRUE)
#' rank_results(chi_features_res, rank_metric = "rsq")
#' @export
rank_results <- function(x, rank_metric = NULL, select_best = FALSE) {
  metric_info <- pick_metric(x, rank_metric)
  metric <- metric_info$metric
  direction <- metric_info$direction

  # One row of workflow information per workflow, keyed by `wflow_id`.
  wflow_info <- dplyr::bind_cols(
    purrr::map_dfr(x$info, I),
    dplyr::select(x, wflow_id)
  )

  # All metrics for all configurations; this is what is ultimately returned
  # (with the rank attached), so `mean` is never modified here.
  results <- collect_metrics(x) %>%
    dplyr::select(wflow_id, .config, .metric, mean, std_err, n) %>%
    dplyr::full_join(wflow_info, by = "wflow_id") %>%
    dplyr::select(-comment, -workflow)

  # Per-workflow flags: was it a racing result, and how many resamples exist.
  types <- x %>%
    dplyr::full_join(wflow_info, by = "wflow_id") %>%
    dplyr::mutate(
      is_race = purrr::map_lgl(result, ~ inherits(.x, "tune_race")),
      num_rs = purrr::map_int(result, get_num_resamples)
    ) %>%
    dplyr::select(wflow_id, is_race, num_rs)

  # Only the rows of the ranking metric are used to compute ranks.
  ranked <-
    dplyr::full_join(results, types, by = "wflow_id") %>%
    dplyr::filter(.metric == metric)

  if (any(ranked$is_race)) {
    # remove any racing results with fewer resamples than the total number
    rm_rows <-
      ranked %>%
      dplyr::filter(is_race & (num_rs > n)) %>%
      dplyr::select(wflow_id, .config) %>%
      dplyr::distinct()
    if (nrow(rm_rows) > 0) {
      ranked <- dplyr::anti_join(ranked, rm_rows, by = c("wflow_id", ".config"))
      results <- dplyr::anti_join(results, rm_rows, by = c("wflow_id", ".config"))
    }
  }

  # Transform the ranking column so that "smaller is better" in every case.
  # Only the working copy in `ranked` is changed; the `mean` values reported
  # from `results` keep their original sign.
  if (direction == "maximize") {
    ranked$mean <- -ranked$mean
  } else if (direction == "zero") {
    # the best value is the one closest to zero, regardless of sign
    ranked$mean <- abs(ranked$mean)
  }

  if (select_best) {
    # keep a single (best) configuration per workflow
    best_by_wflow <-
      dplyr::group_by(ranked, wflow_id) %>%
      dplyr::slice_min(mean, with_ties = FALSE) %>%
      dplyr::ungroup() %>%
      dplyr::select(wflow_id, .config)
    ranked <- dplyr::inner_join(ranked, best_by_wflow, by = c("wflow_id", ".config"))
  }

  # ensure reproducible rankings when there are ties
  withr::with_seed(
    1,
    {
      ranked <-
        ranked %>%
        dplyr::mutate(rank = rank(mean, ties.method = "random")) %>%
        dplyr::select(wflow_id, .config, rank)
    }
  )

  # Attach the rank to every metric row of the retained configurations.
  dplyr::inner_join(results, ranked, by = c("wflow_id", ".config")) %>%
    dplyr::arrange(rank) %>%
    dplyr::rename(preprocessor = preproc)
}

# Count the distinct `id` rows found across the elements of `x$splits`.
# Each element's `id` is row-bound into one data frame, duplicates are
# dropped, and the number of remaining rows is returned.
get_num_resamples <- function(x) {
  split_ids <- purrr::map(x$splits, function(split) split$id)
  unique_ids <- dplyr::distinct(dplyr::bind_rows(split_ids))
  nrow(unique_ids)
}

Try the workflowsets package in your browser

Any scripts or data that you put into this service are public.

workflowsets documentation built on April 7, 2023, 1:05 a.m.