#' Plot performance metrics for multiple ML runs with different parameters
#'
#' ggplot2 and tidyr are required to use this function.
#'
#' @param performance_df dataframe of performance results from multiple calls to `run_ml()`
#'
#' @return A ggplot2 plot of performance.
#' @export
#' @author Begüm Topçuoglu, \email{topcuoglu.begum@@gmail.com}
#' @author Kelly Sovacool, \email{sovacool@@umich.edu}
#'
#' @examples
#' \dontrun{
#' # call `run_ml()` multiple times with different seeds
#' results_lst <- lapply(seq(100, 104), function(seed) {
#' run_ml(otu_small, "glmnet", seed = seed)
#' })
#' # extract and combine the performance results
#' perf_df <- lapply(results_lst, function(result) {
#' result[["performance"]]
#' }) %>%
#' dplyr::bind_rows()
#' # plot the performance results
#' p <- plot_model_performance(perf_df)
#'
#'
#' # call `run_ml()` with different ML methods
#' param_grid <- expand.grid(
#' seeds = seq(100, 104),
#' methods = c("glmnet", "rf")
#' )
#' results_mtx <- mapply(
#' function(seed, method) {
#' run_ml(otu_mini_bin, method, seed = seed, kfold = 2)
#' },
#' param_grid$seeds, param_grid$methods
#' )
#' # extract and combine the performance results
#' perf_df2 <- dplyr::bind_rows(results_mtx["performance", ])
#' # plot the performance results
#' p <- plot_model_performance(perf_df2)
#'
#' # you can continue adding layers to customize the plot
#' p +
#' theme_classic() +
#' scale_color_brewer(palette = "Dark2") +
#' coord_flip()
#' }
plot_model_performance <- function(performance_df) {
  # Plotting needs ggplot2; the reshaping done by tidy_perf_data() needs tidyr.
  abort_packages_not_installed("ggplot2", "tidyr")

  # Long format: one row per method / metric / value combination.
  tidy_perf <- tidy_perf_data(performance_df)

  box_mapping <- ggplot2::aes(
    x = .data$method,
    y = .data$value,
    color = .data$metric
  )

  # One box per method and metric, with a dashed horizontal reference line
  # at 0.5 and the y axis fixed to [0, 1].
  ggplot2::ggplot(tidy_perf, box_mapping) +
    ggplot2::geom_boxplot() +
    ggplot2::geom_hline(yintercept = 0.5, linetype = "dashed") +
    ggplot2::ylim(0, 1) +
    ggplot2::labs(y = "Performance", x = NULL) +
    ggplot2::theme(legend.title = ggplot2::element_blank())
}
#' Tidy the performance dataframe
#'
#' Used by `plot_model_performance()`.
#'
#' @inheritParams plot_model_performance
#' @return Tidy dataframe with model performance metrics.
#' @export
#' @author Begüm Topçuoglu, \email{topcuoglu.begum@@gmail.com}
#' @author Kelly Sovacool, \email{sovacool@@umich.edu}
#' @examples
#' \dontrun{
#' # call `run_ml()` multiple times with different seeds
#' results_lst <- lapply(seq(100, 104), function(seed) {
#' run_ml(otu_small, "glmnet", seed = seed)
#' })
#' # extract and combine the performance results
#' perf_df <- lapply(results_lst, function(result) {
#' result[["performance"]]
#' }) %>%
#' dplyr::bind_rows()
#' # make it pretty!
#' tidy_perf_data(perf_df)
#' }
tidy_perf_data <- function(performance_df) {
  abort_packages_not_installed("tidyr")

  # The cross-validation column is named "cv_metric_<metric>"; the matching
  # test-set column is the same name with the "cv_metric_" prefix removed.
  cv_colname <- performance_df %>%
    dplyr::select(dplyr::starts_with("cv_metric_")) %>%
    colnames()
  if (length(cv_colname) != 1) {
    stop(
      "`performance_df` must contain exactly one column starting with ",
      "'cv_metric_', but ", length(cv_colname), " were found.",
      call. = FALSE
    )
  }
  test_colname <- gsub("cv_metric_", "", cv_colname)
  metric_cols <- c(cv_colname, test_colname)

  # select() and pivot_longer(cols = ) take tidyselect expressions, where the
  # `.data` pronoun is deprecated; all_of() is the supported way to select
  # columns named by a character vector.
  performance_df %>%
    dplyr::select(dplyr::all_of(c("method", metric_cols))) %>%
    tidyr::pivot_longer(
      cols = dplyr::all_of(metric_cols),
      names_to = "metric"
    ) %>%
    # Relabel for display: "cv_metric_<m>" becomes "CV <m>" and "<m>" becomes
    # "Test <m>".
    dplyr::mutate(metric = dplyr::case_when(
      startsWith(.data$metric, "cv_metric_") ~
        gsub("cv_metric_", "CV ", .data$metric),
      TRUE ~ paste("Test", .data$metric)
    ))
}
#' Get hyperparameter performance metrics
#'
#' @param trained_model trained model (e.g. from `run_ml()`)
#'
#' @return
#'
#' Named list:
#' - `dat`: Dataframe of performance metric for each group of hyperparameters.
#' - `params`: Hyperparameters tuned.
#' - `metric`: Performance metric used.
#' @export
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#' @author Kelly Sovacool \email{sovacool@@umich.edu}
#'
#' @examples
#' get_hp_performance(otu_mini_bin_results_glmnet$trained_model)
get_hp_performance <- function(trained_model) {
  metric <- trained_model$metric

  # Keep only the hyperparameter columns listed by the model plus the
  # performance metric column.
  dat <- trained_model$results %>%
    dplyr::select(
      dplyr::all_of(trained_model$modelInfo$parameters$parameter),
      dplyr::all_of(metric)
    )

  # A column counts as "tuned" when it takes more than one distinct value.
  # vapply() guarantees a logical vector regardless of the input, unlike
  # sapply() whose return type depends on the data.
  varies <- vapply(dat, function(x) length(unique(x)) > 1, logical(1))

  # The metric column usually varies too, but it is not a hyperparameter.
  params <- setdiff(names(dat)[varies], metric)

  list(
    dat = dat,
    params = params,
    metric = metric
  )
}
#' Combine hyperparameter performance metrics for multiple train/test splits
#'
#' Combine hyperparameter performance metrics for multiple train/test splits generated by, for instance, [looping in R](http://www.schlosslab.org/mikropml/articles/parallel.html) or using a [snakemake workflow](https://github.com/SchlossLab/mikropml-snakemake-workflow) on a high-performance computer.
#'
#' @param trained_model_lst List of trained models.
#'
#' @return
#'
#' Named list:
#' - `dat`: Dataframe of performance metric for each group of hyperparameters.
#' - `params`: Hyperparameters tuned.
#' - `metric`: Performance metric used.
#' @export
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' results <- lapply(seq(100, 102), function(seed) {
#' run_ml(otu_small, "glmnet", seed = seed, cv_times = 2, kfold = 2)
#' })
#' models <- lapply(results, function(x) x$trained_model)
#' combine_hp_performance(models)
#' }
combine_hp_performance <- function(trained_model_lst) {
  # One list(dat, params, metric) per trained model.
  hp_lst <- lapply(trained_model_lst, get_hp_performance)

  # Collect a single field across all models. Plain lapply() does the job of
  # transposing the list, so the optional purrr dependency is not needed.
  collect_field <- function(field) {
    lapply(hp_lst, function(hp) hp[[field]])
  }

  list(
    # Stack the per-model hyperparameter performance tables row-wise.
    dat = dplyr::bind_rows(collect_field("dat")),
    # De-duplicate hyperparameter and metric names across models.
    params = unique(unlist(collect_field("params"))),
    metric = unique(unlist(collect_field("metric")))
  )
}
#' Plot hyperparameter performance metrics
#'
#' @param dat dataframe of hyperparameters and performance metric (e.g. from `get_hp_performance()` or `combine_hp_performance()`)
#' @param param_col hyperparameter to be plotted, given as an unquoted column name. Must be a column in `dat`.
#' @param metric_col performance metric, given as an unquoted column name. Must be a column in `dat`.
#'
#' @return ggplot of hyperparameter performance.
#'
#' @export
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#' @author Kelly Sovacool \email{sovacool@@umich.edu}
#'
#' @examples
#' # plot for a single `run_ml()` call
#' hp_metrics <- get_hp_performance(otu_mini_bin_results_glmnet$trained_model)
#' hp_metrics
#' plot_hp_performance(hp_metrics$dat, lambda, AUC)
#' \dontrun{
#' # plot for multiple `run_ml()` calls
#' results <- lapply(seq(100, 102), function(seed) {
#' run_ml(otu_small, "glmnet", seed = seed)
#' })
#' models <- lapply(results, function(x) x$trained_model)
#' hp_metrics <- combine_hp_performance(models)
#' plot_hp_performance(hp_metrics$dat, lambda, AUC)
#' }
plot_hp_performance <- function(dat, param_col, metric_col) {
  abort_packages_not_installed("ggplot2")

  # Derive the summary column names ("mean_<metric>", "sd_<metric>") once and
  # keep symbols for them so they can be unquoted below.
  metric_name <- rlang::as_name(rlang::enquo(metric_col))
  mean_colname <- paste0("mean_", metric_name)
  sd_colname <- paste0("sd_", metric_name)
  mean_sym <- rlang::sym(mean_colname)
  sd_sym <- rlang::sym(sd_colname)

  # Per hyperparameter value: mean and standard deviation of the metric, plus
  # error-bar bounds at mean -/+ one standard deviation.
  by_param <- dplyr::group_by(dat, {{ param_col }})
  dat_sum <- dplyr::summarise(
    by_param,
    "mean_{{ metric_col }}" := mean({{ metric_col }}),
    "sd_{{ metric_col }}" := stats::sd({{ metric_col }}),
    ymin_metric = !!mean_sym - !!sd_sym,
    ymax_metric = !!mean_sym + !!sd_sym
  )

  base_mapping <- ggplot2::aes(
    x = {{ param_col }},
    y = !!mean_sym
  )
  errorbar_mapping <- ggplot2::aes(
    ymin = .data$ymin_metric,
    ymax = .data$ymax_metric
  )

  ggplot2::ggplot(dat_sum, base_mapping) +
    ggplot2::geom_line() +
    ggplot2::geom_point() +
    ggplot2::geom_errorbar(errorbar_mapping, width = .001)
}
#' Get plot layers shared by `plot_mean_roc` and `plot_mean_prc`
#'
#' @param ribbon_fill ribbon fill color (default: "#D9D9D9")
#' @param line_color line color (default: "#000000")
#'
#' @return list of ggproto objects to add to a ggplot
#'
#' @keywords internal
#' @author Kelly Sovacool \email{sovacool@@umich.edu}
#'
shared_ggprotos <- function(ribbon_fill = "#D9D9D9",
                            line_color = "#000000") {
  # Ribbon first, then the line, so the line is drawn on top of the ribbon.
  ribbon_layer <- ggplot2::geom_ribbon(fill = ribbon_fill)
  line_layer <- ggplot2::geom_line(color = line_color)

  # y axis limited to [-0.01, 1.01] with no additional expansion.
  y_scale <- ggplot2::scale_y_continuous(
    expand = c(0, 0),
    limits = c(-0.01, 1.01)
  )
  no_legend_title <- ggplot2::theme(legend.title = ggplot2::element_blank())

  # Returned as an unnamed list so it can be added to a plot with `+`.
  list(
    ribbon_layer,
    line_layer,
    ggplot2::coord_equal(),
    y_scale,
    ggplot2::theme_bw(),
    no_legend_title
  )
}
#' @describeIn plot_curves Plot mean sensitivity over specificity
#'
#' @inheritParams shared_ggprotos
#' @param dat sensitivity, specificity, and precision data calculated by `calc_mean_roc()`
#'
#' @export
plot_mean_roc <- function(dat,
                          ribbon_fill = "#C6DBEF", line_color = "#08306B") {
  # Give the column names used in aes() a local binding; they are looked up
  # in `dat` when the plot is built.
  specificity <- NULL
  mean_sensitivity <- NULL
  lower <- NULL
  upper <- NULL
  abort_packages_not_installed("ggplot2")

  roc_mapping <- ggplot2::aes(
    x = specificity,
    y = mean_sensitivity,
    ymin = lower,
    ymax = upper
  )

  # Shared ribbon/line/theme layers, a dashed grey reference line, and an
  # x axis running from 1.01 down to -0.01.
  ggplot2::ggplot(dat, roc_mapping) +
    shared_ggprotos(ribbon_fill = ribbon_fill, line_color = line_color) +
    ggplot2::geom_abline(
      intercept = 1,
      slope = 1,
      linetype = "dashed",
      color = "grey50"
    ) +
    ggplot2::scale_x_reverse(expand = c(0, 0), limits = c(1.01, -0.01)) +
    ggplot2::labs(x = "Specificity", y = "Mean Sensitivity")
}
#' @describeIn plot_curves Plot mean precision over recall
#'
#' @inheritParams shared_ggprotos
#' @inheritParams plot_mean_roc
#' @param baseline_precision baseline precision from `calc_baseline_precision()`
#' @param ycol column for the y axis (Default: `mean_precision`)
#'
#' @export
plot_mean_prc <- function(dat, baseline_precision = NULL, ycol = mean_precision,
                          ribbon_fill = "#C7E9C0", line_color = "#00441B") {
  # Give the column names used in aes() (and in the default for `ycol`) a
  # local binding; they are looked up in `dat` when the plot is built.
  recall <- NULL
  mean_precision <- NULL
  lower <- NULL
  upper <- NULL
  abort_packages_not_installed("ggplot2")

  prc_mapping <- ggplot2::aes(
    x = recall,
    y = {{ ycol }},
    ymin = lower,
    ymax = upper
  )

  prc_plot <- ggplot2::ggplot(dat, prc_mapping) +
    shared_ggprotos(ribbon_fill = ribbon_fill, line_color = line_color) +
    ggplot2::scale_x_continuous(expand = c(0, 0), limits = c(-0.01, 1.01)) +
    ggplot2::labs(x = "Recall", y = "Mean Precision")

  # Without a baseline there is nothing more to draw.
  if (is.null(baseline_precision)) {
    return(prc_plot)
  }

  # Dashed grey horizontal line marking the supplied baseline precision.
  prc_plot +
    ggplot2::geom_hline(
      yintercept = baseline_precision,
      linetype = "dashed", color = "grey50"
    )
}
#' @name plot_curves
#' @title Plot ROC and PRC curves
#'
#' @author Courtney Armour
#' @author Kelly Sovacool \email{sovacool@@umich.edu}
#'
#' @examples
#' \dontrun{
#' library(dplyr)
#' # get performance for multiple models
#' get_sensspec_seed <- function(seed) {
#' ml_result <- run_ml(otu_mini_bin, "glmnet", seed = seed)
#' sensspec <- calc_model_sensspec(
#' ml_result$trained_model,
#' ml_result$test_data,
#' "dx"
#' ) %>%
#' mutate(seed = seed)
#' return(sensspec)
#' }
#' sensspec_dat <- purrr::map_dfr(seq(100, 102), get_sensspec_seed)
#'
#' # plot ROC & PRC
#' sensspec_dat %>%
#' calc_mean_roc() %>%
#' plot_mean_roc()
#' baseline_prec <- calc_baseline_precision(otu_mini_bin, "dx", "cancer")
#' sensspec_dat %>%
#' calc_mean_prc() %>%
#' plot_mean_prc(baseline_precision = baseline_prec)
#' }
NULL
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.