R/plot.R

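Defines functions tidy_perf_data plot_model_performance get_hp_performance combine_hp_performance plot_hp_performance shared_ggprotos plot_mean_roc plot_mean_prc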

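Documented in combine_hp_performance get_hp_performance plot_hp_performance plot_mean_prc plot_mean_roc plot_model_performance shared_ggprotos tidy_perf_data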

#' Plot performance metrics for multiple ML runs with different parameters
#'
#' Draws one boxplot per ML method for each performance metric (the CV and
#' test metrics produced by `tidy_perf_data()`), with a dashed horizontal
#' reference line at 0.5 and the y axis limited to `[0, 1]`.
#' ggplot2 and tidyr are required to use this function.
#'
#' @param performance_df dataframe of performance results from multiple calls to `run_ml()`
#'
#' @return A ggplot2 plot of performance.
#'
#' @export
#' @author Begüm Topçuoglu, \email{topcuoglu.begum@@gmail.com}
#' @author Kelly Sovacool, \email{sovacool@@umich.edu}
#'
#' @examples
#' \dontrun{
#' # call `run_ml()` multiple times with different seeds
#' results_lst <- lapply(seq(100, 104), function(seed) {
#'   run_ml(otu_small, "glmnet", seed = seed)
#' })
#' # extract and combine the performance results
#' perf_df <- lapply(results_lst, function(result) {
#'   result[["performance"]]
#' }) %>%
#'   dplyr::bind_rows()
#' # plot the performance results
#' p <- plot_model_performance(perf_df)
#'
#'
#' # call `run_ml()` with different ML methods
#' param_grid <- expand.grid(
#'   seeds = seq(100, 104),
#'   methods = c("glmnet", "rf")
#' )
#' results_mtx <- mapply(
#'   function(seed, method) {
#'     run_ml(otu_mini_bin, method, seed = seed, kfold = 2)
#'   },
#'   param_grid$seeds, param_grid$methods
#' )
#' # extract and combine the performance results
#' perf_df2 <- dplyr::bind_rows(results_mtx["performance", ])
#' # plot the performance results
#' p <- plot_model_performance(perf_df2)
#'
#' # you can continue adding layers to customize the plot
#' p +
#'   ggplot2::theme_classic() +
#'   ggplot2::scale_color_brewer(palette = "Dark2") +
#'   ggplot2::coord_flip()
#' }
plot_model_performance <- function(performance_df) {
  abort_packages_not_installed("ggplot2", "tidyr")
  # long format with `method`, `metric` and `value` columns
  tidy_df <- tidy_perf_data(performance_df)
  perf_aes <- ggplot2::aes(
    x = .data$method,
    y = .data$value,
    color = .data$metric
  )
  ggplot2::ggplot(tidy_df, perf_aes) +
    ggplot2::geom_boxplot() +
    # dashed reference at 0.5 (presumably chance level for AUC-like metrics)
    ggplot2::geom_hline(yintercept = 0.5, linetype = "dashed") +
    # NOTE(review): ylim() drops values outside [0, 1] rather than zooming,
    # so metrics outside that range would vanish from the boxplots -- confirm
    ggplot2::ylim(0, 1) +
    ggplot2::labs(y = "Performance", x = NULL) +
    ggplot2::theme(legend.title = ggplot2::element_blank())
}

#' Tidy the performance dataframe
#'
#' Used by `plot_model_performance()`. Keeps the `method` column, the single
#' cross-validation metric column (the one whose name starts with
#' `cv_metric_`) and its matching test metric column (same name without the
#' `cv_metric_` prefix), then pivots the two metric columns into long format.
#' Metrics are relabelled `"CV <metric>"` and `"Test <metric>"`.
#' tidyr is required to use this function.
#'
#' @inheritParams plot_model_performance
#' @return Tidy dataframe with model performance metrics, with columns
#'   `method`, `metric` and `value`.
#' @export
#' @author Begüm Topçuoglu, \email{topcuoglu.begum@@gmail.com}
#' @author Kelly Sovacool, \email{sovacool@@umich.edu}
#' @examples
#' \dontrun{
#' # call `run_ml()` multiple times with different seeds
#' results_lst <- lapply(seq(100, 104), function(seed) {
#'   run_ml(otu_small, "glmnet", seed = seed)
#' })
#' # extract and combine the performance results
#' perf_df <- lapply(results_lst, function(result) {
#'   result[["performance"]]
#' }) %>%
#'   dplyr::bind_rows()
#' # make it pretty!
#' tidy_perf_data(perf_df)
#' }
tidy_perf_data <- function(performance_df) {
  abort_packages_not_installed("tidyr")
  cv_colname <- performance_df %>%
    dplyr::select(dplyr::starts_with("cv_metric_")) %>%
    colnames()
  # the pivot below compares exactly one CV metric with its test counterpart;
  # fail with a clear message instead of a cryptic subsetting error
  if (length(cv_colname) != 1L) {
    stop(
      "`performance_df` must have exactly one column starting with ",
      "'cv_metric_', but found ", length(cv_colname), ".",
      call. = FALSE
    )
  }
  # e.g. "cv_metric_AUC" -> "AUC"
  test_colname <- sub("^cv_metric_", "", cv_colname)
  metric_cols <- c(cv_colname, test_colname)
  # all_of() replaces `.data[[...]]`, which is deprecated inside tidyselect
  performance_df %>%
    dplyr::select(dplyr::all_of(c("method", metric_cols))) %>%
    tidyr::pivot_longer(
      cols = dplyr::all_of(metric_cols),
      names_to = "metric"
    ) %>%
    dplyr::mutate(metric = dplyr::if_else(
      startsWith(.data$metric, "cv_metric_"),
      sub("^cv_metric_", "CV ", .data$metric),
      paste("Test", .data$metric)
    ))
}

#' Get hyperparameter performance metrics
#'
#' Selects the hyperparameter columns and the performance metric column from
#' the resampling results of a trained model. It also reports which of those
#' hyperparameters actually took more than one value.
#'
#' @param trained_model trained model (e.g. from `run_ml()`). Must provide
#'   `$metric`, `$results` and `$modelInfo$parameters$parameter`.
#'
#' @return
#'
#' Named list:
#' - `dat`: Dataframe of performance metric for each group of hyperparameters.
#' - `params`: Hyperparameters tuned, i.e. those with more than one distinct
#'   value in `dat`.
#' - `metric`: Performance metric used.
#' @export
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#' @author Kelly Sovacool \email{sovacool@@umich.edu}
#'
#' @examples
#' get_hp_performance(otu_mini_bin_results_glmnet$trained_model)
get_hp_performance <- function(trained_model) {
  metric <- trained_model$metric
  dat <- dplyr::select(
    trained_model$results,
    dplyr::all_of(trained_model$modelInfo$parameters$parameter),
    dplyr::all_of(metric)
  )
  # vapply guarantees a logical vector (sapply's return type depends on input)
  is_varying <- vapply(
    dat,
    function(column) length(unique(column)) > 1,
    logical(1)
  )
  # a hyperparameter counts as tuned only if it varied; never report the metric
  params <- setdiff(names(dat)[is_varying], metric)
  list(
    dat = dat,
    params = params,
    metric = metric
  )
}

#' Combine hyperparameter performance metrics for multiple train/test splits
#'
#' Combine hyperparameter performance metrics for multiple train/test splits generated by, for instance, [looping in R](http://www.schlosslab.org/mikropml/articles/parallel.html) or using a [snakemake workflow](https://github.com/SchlossLab/mikropml-snakemake-workflow) on a high-performance computer.
#'
#' Each model is passed through `get_hp_performance()`. The resulting `dat`
#' dataframes are row-bound, and the `params` and `metric` values are pooled
#' and de-duplicated.
#'
#' @param trained_model_lst List of trained models.
#'
#' @return
#'
#' Named list:
#' - `dat`: Dataframe of performance metric for each group of hyperparameters
#' - `params`: Hyperparameters tuned.
#' - `metric`: Performance metric used.
#' @export
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' results <- lapply(seq(100, 102), function(seed) {
#'   run_ml(otu_small, "glmnet", seed = seed, cv_times = 2, kfold = 2)
#' })
#' models <- lapply(results, function(x) x$trained_model)
#' combine_hp_performance(models)
#' }
combine_hp_performance <- function(trained_model_lst) {
  hp_perf_lst <- lapply(trained_model_lst, get_hp_performance)
  # base replacement for purrr::transpose(): collect one field across all models
  collect_field <- function(field) lapply(hp_perf_lst, `[[`, field)
  list(
    dat = dplyr::bind_rows(collect_field("dat")),
    params = unique(unlist(collect_field("params"))),
    metric = unique(unlist(collect_field("metric")))
  )
}

#' Plot hyperparameter performance metrics
#'
#' Groups `dat` by `param_col` and computes the mean and standard deviation
#' of `metric_col` within each group. The mean is plotted as a line with
#' points, and error bars span mean - sd to mean + sd. A group with a single
#' observation has an `NA` standard deviation (`stats::sd()`), so it gets no
#' usable error bar. ggplot2 is required to use this function.
#'
#' @param dat dataframe of hyperparameters and performance metric (e.g. from `get_hp_performance()` or `combine_hp_performance()`)
#' @param param_col hyperparameter to be plotted. must be a column in `dat`.
#' @param metric_col performance metric. must be a column in `dat`.
#'
#' @return ggplot of hyperparameter performance.
#'
#' @export
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#' @author Kelly Sovacool \email{sovacool@@umich.edu}
#'
#' @examples
#' # plot for a single `run_ml()` call
#' hp_metrics <- get_hp_performance(otu_mini_bin_results_glmnet$trained_model)
#' hp_metrics
#' plot_hp_performance(hp_metrics$dat, lambda, AUC)
#' \dontrun{
#' # plot for multiple `run_ml()` calls
#' results <- lapply(seq(100, 102), function(seed) {
#'   run_ml(otu_small, "glmnet", seed = seed)
#' })
#' models <- lapply(results, function(x) x$trained_model)
#' hp_metrics <- combine_hp_performance(models)
#' plot_hp_performance(hp_metrics$dat, lambda, AUC)
#' }
plot_hp_performance <- function(dat, param_col, metric_col) {
  abort_packages_not_installed("ggplot2")
  metric_name <- rlang::as_name(rlang::enquo(metric_col))
  # symbols for the summary columns created below, e.g. mean_AUC / sd_AUC
  mean_sym <- rlang::sym(paste0("mean_", metric_name))
  sd_sym <- rlang::sym(paste0("sd_", metric_name))
  by_param <- dplyr::group_by(dat, {{ param_col }})
  # the ymin/ymax summaries refer back to the mean/sd columns created first
  dat_sum <- dplyr::summarise(
    by_param,
    "mean_{{ metric_col }}" := mean({{ metric_col }}),
    "sd_{{ metric_col }}" := stats::sd({{ metric_col }}),
    ymin_metric = !!mean_sym - !!sd_sym,
    ymax_metric = !!mean_sym + !!sd_sym
  )
  error_bars <- ggplot2::geom_errorbar(
    ggplot2::aes(
      ymin = .data$ymin_metric,
      ymax = .data$ymax_metric
    ),
    width = .001
  )
  ggplot2::ggplot(dat_sum, ggplot2::aes(x = {{ param_col }}, y = !!mean_sym)) +
    ggplot2::geom_line() +
    ggplot2::geom_point() +
    error_bars
}


#' Get plot layers shared by `plot_mean_roc` and `plot_mean_prc`
#'
#' Returns, in drawing order, a filled ribbon layer, a line layer, an equal
#' aspect-ratio coordinate system, a y scale fixed to `[-0.01, 1.01]` with no
#' padding, the black-and-white theme, and a theme tweak that hides legend
#' titles.
#'
#' @param ribbon_fill ribbon fill color (default: "#D9D9D9")
#' @param line_color  line color (default: "#000000")
#'
#' @return list of ggproto objects to add to a ggplot
#'
#' @keywords internal
#' @author Kelly Sovacool \email{sovacool@@umich.edu}
#'
shared_ggprotos <- function(ribbon_fill = "#D9D9D9",
                            line_color = "#000000") {
  # ribbon comes first so the line is drawn on top of it
  ribbon_layer <- ggplot2::geom_ribbon(fill = ribbon_fill)
  line_layer <- ggplot2::geom_line(color = line_color)
  unit_y_scale <- ggplot2::scale_y_continuous(
    expand = c(0, 0),
    limits = c(-0.01, 1.01)
  )
  list(
    ribbon_layer,
    line_layer,
    ggplot2::coord_equal(),
    unit_y_scale,
    ggplot2::theme_bw(),
    ggplot2::theme(legend.title = ggplot2::element_blank())
  )
}

#' @describeIn plot_curves Plot mean sensitivity over specificity
#'
#' @inheritParams shared_ggprotos
#' @param dat sensitivity, specificity, and precision data calculated by
#'   `calc_mean_roc()` or `calc_mean_prc()`. `plot_mean_roc()` uses the columns
#'   `specificity`, `mean_sensitivity`, `lower` and `upper`. `plot_mean_prc()`
#'   uses `recall`, `lower`, `upper` and the `ycol` column.
#'
#' @export
plot_mean_roc <- function(dat,
                          ribbon_fill = "#C6DBEF", line_color = "#08306B") {
  # NULL placeholders for the bare column names used in aes(), presumably to
  # quiet R CMD check's "no visible binding" note
  specificity <- mean_sensitivity <- lower <- upper <- NULL
  abort_packages_not_installed("ggplot2")
  roc_aes <- ggplot2::aes(
    x = specificity, y = mean_sensitivity,
    ymin = lower, ymax = upper
  )
  # NOTE(review): with the reversed x scale below, intercept 1 / slope 1
  # appears intended to trace the chance diagonal -- confirm
  chance_line <- ggplot2::geom_abline(
    intercept = 1, slope = 1,
    linetype = "dashed", color = "grey50"
  )
  ggplot2::ggplot(dat, roc_aes) +
    shared_ggprotos(ribbon_fill = ribbon_fill, line_color = line_color) +
    chance_line +
    ggplot2::scale_x_reverse(expand = c(0, 0), limits = c(1.01, -0.01)) +
    ggplot2::labs(x = "Specificity", y = "Mean Sensitivity")
}

#' @describeIn plot_curves Plot mean precision over recall
#'
#' @inheritParams shared_ggprotos
#' @inheritParams plot_mean_roc
#' @param baseline_precision baseline precision from `calc_baseline_precision()`;
#'   when not `NULL`, it is drawn as a dashed horizontal line.
#' @param ycol column for the y axis (Default: `mean_precision`)
#'
#' @export
plot_mean_prc <- function(dat, baseline_precision = NULL, ycol = mean_precision,
                          ribbon_fill = "#C7E9C0", line_color = "#00441B") {
  # NULL placeholders for the bare column names used in aes(), presumably to
  # quiet R CMD check's "no visible binding" note
  recall <- mean_precision <- lower <- upper <- NULL
  abort_packages_not_installed("ggplot2")
  prc_aes <- ggplot2::aes(
    x = recall, y = {{ ycol }},
    ymin = lower, ymax = upper
  )
  prc_plot <- ggplot2::ggplot(dat, prc_aes) +
    shared_ggprotos(ribbon_fill = ribbon_fill, line_color = line_color) +
    ggplot2::scale_x_continuous(expand = c(0, 0), limits = c(-0.01, 1.01)) +
    ggplot2::labs(x = "Recall", y = "Mean Precision")
  if (is.null(baseline_precision)) {
    return(prc_plot)
  }
  prc_plot +
    ggplot2::geom_hline(
      yintercept = baseline_precision,
      linetype = "dashed", color = "grey50"
    )
}

#' @name plot_curves
#' @title Plot ROC and PRC curves
#'
#' @return A ggplot2 plot. The mean curve is drawn as a line over a ribbon
#'   spanning the `lower` to `upper` columns of `dat`. ggplot2 is required.
#'
#' @author Courtney Armour
#' @author Kelly Sovacool \email{sovacool@@umich.edu}
#'
#' @examples
#' \dontrun{
#' library(dplyr)
#' # get performance for multiple models
#' get_sensspec_seed <- function(seed) {
#'   ml_result <- run_ml(otu_mini_bin, "glmnet", seed = seed)
#'   sensspec <- calc_model_sensspec(
#'     ml_result$trained_model,
#'     ml_result$test_data,
#'     "dx"
#'   ) %>%
#'     mutate(seed = seed)
#'   return(sensspec)
#' }
#' sensspec_dat <- purrr::map_dfr(seq(100, 102), get_sensspec_seed)
#'
#' # plot ROC & PRC
#' sensspec_dat %>%
#'   calc_mean_roc() %>%
#'   plot_mean_roc()
#' baseline_prec <- calc_baseline_precision(otu_mini_bin, "dx", "cancer")
#' sensspec_dat %>%
#'   calc_mean_prc() %>%
#'   plot_mean_prc(baseline_precision = baseline_prec)
#' }
NULL
SchlossLab/mikropml documentation built on Aug. 24, 2023, 9:51 p.m.