Nothing
#' Plot performance metrics for multiple ML runs with different parameters
#'
#' Draws one boxplot per ML method for every cross-validation and test
#' performance value (colored by metric) on a fixed 0-1 y axis, with a dashed
#' horizontal reference line at 0.5.
#'
#' Both ggplot2 and tidyr must be installed to use this function.
#'
#' @param performance_df dataframe of performance results from multiple calls to `run_ml()`
#'
#' @return A ggplot2 plot of performance.
#' @export
#' @author Begüm Topçuoglu, \email{topcuoglu.begum@@gmail.com}
#' @author Kelly Sovacool, \email{sovacool@@umich.edu}
#'
#' @examples
#' \dontrun{
#' # call `run_ml()` multiple times with different seeds
#' results_lst <- lapply(seq(100, 104), function(seed) {
#'   run_ml(otu_small, "glmnet", seed = seed)
#' })
#' # extract and combine the performance results
#' perf_df <- lapply(results_lst, function(result) {
#'   result[["performance"]]
#' }) %>%
#'   dplyr::bind_rows()
#' # plot the performance results
#' p <- plot_model_performance(perf_df)
#'
#'
#' # call `run_ml()` with different ML methods
#' param_grid <- expand.grid(
#'   seeds = seq(100, 104),
#'   methods = c("glmnet", "rf")
#' )
#' results_mtx <- mapply(
#'   function(seed, method) {
#'     run_ml(otu_mini_bin, method, seed = seed, kfold = 2)
#'   },
#'   param_grid$seeds, param_grid$methods
#' )
#' # extract and combine the performance results
#' perf_df2 <- dplyr::bind_rows(results_mtx["performance", ])
#' # plot the performance results
#' p <- plot_model_performance(perf_df2)
#'
#' # you can continue adding layers to customize the plot
#' p +
#'   ggplot2::theme_classic() +
#'   ggplot2::scale_color_brewer(palette = "Dark2") +
#'   ggplot2::coord_flip()
#' }
plot_model_performance <- function(performance_df) {
  abort_packages_not_installed("ggplot2", "tidyr")
  # Long format: one row per method / metric / value.
  long_perf <- tidy_perf_data(performance_df)
  perf_mapping <- ggplot2::aes(
    x = .data$method,
    y = .data$value,
    color = .data$metric
  )
  ggplot2::ggplot(long_perf, perf_mapping) +
    ggplot2::geom_boxplot() +
    ggplot2::geom_hline(yintercept = 0.5, linetype = "dashed") +
    ggplot2::ylim(0, 1) +
    ggplot2::labs(y = "Performance", x = NULL) +
    ggplot2::theme(legend.title = ggplot2::element_blank())
}
#' Tidy the performance dataframe
#'
#' Used by `plot_model_performance()`. Keeps the `method` column, the single
#' cross-validation metric column (named `cv_metric_<METRIC>`) and the matching
#' test metric column (`<METRIC>`), then pivots the two metric columns into
#' long format.
#'
#' @inheritParams plot_model_performance
#' @return Tidy dataframe with model performance metrics, with columns
#'   `method`, `metric` (e.g. `"CV AUC"` from column `cv_metric_AUC` and
#'   `"Test AUC"` from column `AUC`), and `value`.
#' @export
#' @author Begüm Topçuoglu, \email{topcuoglu.begum@@gmail.com}
#' @author Kelly Sovacool, \email{sovacool@@umich.edu}
#' @examples
#' \dontrun{
#' # call `run_ml()` multiple times with different seeds
#' results_lst <- lapply(seq(100, 104), function(seed) {
#'   run_ml(otu_small, "glmnet", seed = seed)
#' })
#' # extract and combine the performance results
#' perf_df <- lapply(results_lst, function(result) {
#'   result[["performance"]]
#' }) %>%
#'   dplyr::bind_rows()
#' # make it pretty!
#' tidy_perf_data(perf_df)
#' }
tidy_perf_data <- function(performance_df) {
  abort_packages_not_installed("tidyr")
  cv_colname <- performance_df %>%
    dplyr::select(dplyr::starts_with("cv_metric_")) %>%
    colnames()
  # The reshaping below supports exactly one CV metric column; fail early with
  # a clear message instead of an opaque subsetting error.
  if (length(cv_colname) != 1) {
    stop(
      "`performance_df` must contain exactly one column starting with ",
      "'cv_metric_', but found ", length(cv_colname), ".",
      call. = FALSE
    )
  }
  # e.g. "cv_metric_AUC" -> "AUC": the test metric column shares the suffix.
  test_colname <- gsub("cv_metric_", "", cv_colname)
  metric_cols <- c(cv_colname, test_colname)
  # all_of() replaces `.data[[...]]` inside tidyselect, which is deprecated
  # since tidyselect 1.2.0.
  performance_df %>%
    dplyr::select(dplyr::all_of(c("method", metric_cols))) %>%
    tidyr::pivot_longer(
      cols = dplyr::all_of(metric_cols),
      names_to = "metric"
    ) %>%
    dplyr::mutate(metric = dplyr::case_when(
      startsWith(.data$metric, "cv_metric_") ~
        gsub("cv_metric_", "CV ", .data$metric),
      TRUE ~ paste("Test", .data$metric)
    ))
}
#' Get hyperparameter performance metrics
#'
#' Selects the hyperparameter columns (named in
#' `trained_model$modelInfo$parameters$parameter`) and the performance metric
#' column (`trained_model$metric`) from `trained_model$results`.
#'
#' @param trained_model trained model (e.g. from `run_ml()`)
#'
#' @return
#'
#' Named list:
#' - `dat`: Dataframe of performance metric for each group of hyperparameters.
#' - `params`: Hyperparameters tuned, i.e. hyperparameter columns of `dat`
#'   that take more than one distinct value.
#' - `metric`: Performance metric used.
#' @export
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#' @author Kelly Sovacool \email{sovacool@@umich.edu}
#'
#' @examples
#' get_hp_performance(otu_mini_bin_results_glmnet$trained_model)
get_hp_performance <- function(trained_model) {
  metric <- trained_model$metric
  dat <- dplyr::select(
    trained_model$results,
    dplyr::all_of(trained_model$modelInfo$parameters$parameter),
    dplyr::all_of(metric)
  )
  # vapply() guarantees a logical vector (sapply() can change its return type
  # on unusual input). Columns with a single distinct value were not varied.
  is_varied <- vapply(dat, function(x) length(unique(x)) > 1, logical(1))
  # The metric column itself always varies; it is not a hyperparameter.
  params <- setdiff(names(dat)[is_varied], metric)
  list(
    dat = dat,
    params = params,
    metric = metric
  )
}
#' Combine hyperparameter performance metrics for multiple train/test splits
#'
#' Combine hyperparameter performance metrics for multiple train/test splits generated by, for instance, [looping in R](http://www.schlosslab.org/mikropml/articles/parallel.html) or using a [snakemake workflow](https://github.com/SchlossLab/mikropml-snakemake-workflow) on a high-performance computer.
#'
#' Calls `get_hp_performance()` on every model, stacks the resulting `dat`
#' dataframes, and collects the unique `params` and `metric` values.
#'
#' @param trained_model_lst List of trained models.
#'
#' @return
#'
#' Named list:
#' - `dat`: Dataframe of performance metric for each group of hyperparameters
#' - `params`: Hyperparameters tuned.
#' - `metric`: Performance metric used.
#' @export
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' results <- lapply(seq(100, 102), function(seed) {
#'   run_ml(otu_small, "glmnet", seed = seed, cv_times = 2, kfold = 2)
#' })
#' models <- lapply(results, function(x) x$trained_model)
#' combine_hp_performance(models)
#' }
combine_hp_performance <- function(trained_model_lst) {
  hp_lst <- lapply(trained_model_lst, get_hp_performance)
  # Base-R replacement for purrr::transpose(): gather one named field from
  # every get_hp_performance() result, so purrr is no longer needed.
  collect_field <- function(field) lapply(hp_lst, `[[`, field)
  list(
    dat = dplyr::bind_rows(collect_field("dat")),
    params = unique(unlist(collect_field("params"))),
    metric = unique(unlist(collect_field("metric")))
  )
}
#' Plot hyperparameter performance metrics
#'
#' For each value of `param_col`, plots the mean of `metric_col` as points
#' joined by a line, with error bars spanning the mean plus/minus one standard
#' deviation.
#'
#' @param dat dataframe of hyperparameters and performance metric (e.g. from `get_hp_performance()` or `combine_hp_performance()`)
#' @param param_col hyperparameter to be plotted (unquoted column name). Must be a column in `dat`.
#' @param metric_col performance metric (unquoted column name). Must be a column in `dat`.
#'
#' @return ggplot of hyperparameter performance.
#'
#' @export
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#' @author Kelly Sovacool \email{sovacool@@umich.edu}
#'
#' @examples
#' # plot for a single `run_ml()` call
#' hp_metrics <- get_hp_performance(otu_mini_bin_results_glmnet$trained_model)
#' hp_metrics
#' plot_hp_performance(hp_metrics$dat, lambda, AUC)
#' \dontrun{
#' # plot for multiple `run_ml()` calls
#' results <- lapply(seq(100, 102), function(seed) {
#'   run_ml(otu_small, "glmnet", seed = seed)
#' })
#' models <- lapply(results, function(x) x$trained_model)
#' hp_metrics <- combine_hp_performance(models)
#' plot_hp_performance(hp_metrics$dat, lambda, AUC)
#' }
plot_hp_performance <- function(dat, param_col, metric_col) {
  abort_packages_not_installed("ggplot2")
  metric_name <- rlang::as_name(rlang::enquo(metric_col))
  # Symbols for the summary columns created below, e.g. mean_AUC / sd_AUC.
  mean_sym <- rlang::sym(paste0("mean_", metric_name))
  sd_sym <- rlang::sym(paste0("sd_", metric_name))
  summary_df <- dat %>%
    dplyr::group_by({{ param_col }}) %>%
    dplyr::summarise(
      "mean_{{ metric_col }}" := mean({{ metric_col }}),
      "sd_{{ metric_col }}" := stats::sd({{ metric_col }}),
      ymin_metric = !!mean_sym - !!sd_sym,
      ymax_metric = !!mean_sym + !!sd_sym
    )
  error_bars <- ggplot2::geom_errorbar(
    ggplot2::aes(ymin = .data$ymin_metric, ymax = .data$ymax_metric),
    width = .001
  )
  ggplot2::ggplot(
    summary_df,
    ggplot2::aes(x = {{ param_col }}, y = !!mean_sym)
  ) +
    ggplot2::geom_line() +
    ggplot2::geom_point() +
    error_bars
}
#' Get plot layers shared by `plot_mean_roc` and `plot_mean_prc`
#'
#' Builds the common components of the mean ROC/PRC plots: a filled ribbon,
#' a line, a fixed 1:1 aspect ratio, a y scale limited to just beyond 0-1,
#' the black-and-white theme, and a legend without a title.
#'
#' @param ribbon_fill ribbon fill color (default: "#D9D9D9")
#' @param line_color line color (default: "#000000")
#'
#' @return list of ggproto objects to add to a ggplot with `+`
#'
#' @keywords internal
#' @author Kelly Sovacool \email{sovacool@@umich.edu}
#'
shared_ggprotos <- function(ribbon_fill = "#D9D9D9",
                            line_color = "#000000") {
  # expand = c(0, 0) drops ggplot2's default padding; the limits pad the
  # [0, 1] range by 0.01 on each side instead.
  y_scale <- ggplot2::scale_y_continuous(
    expand = c(0, 0),
    limits = c(-0.01, 1.01)
  )
  untitled_legend <- ggplot2::theme(legend.title = ggplot2::element_blank())
  # Order matters: the ribbon is drawn first so the line sits on top of it.
  list(
    ggplot2::geom_ribbon(fill = ribbon_fill),
    ggplot2::geom_line(color = line_color),
    ggplot2::coord_equal(),
    y_scale,
    ggplot2::theme_bw(),
    untitled_legend
  )
}
#' @describeIn plot_curves Plot mean sensitivity over specificity
#'
#' @inheritParams shared_ggprotos
#' @param dat sensitivity, specificity, and precision data calculated by
#'   `calc_mean_roc()` (for `plot_mean_roc()`) or `calc_mean_prc()`
#'   (for `plot_mean_prc()`)
#'
#' @export
plot_mean_roc <- function(dat,
                          ribbon_fill = "#C6DBEF", line_color = "#08306B") {
  abort_packages_not_installed("ggplot2")
  # The `.data` pronoun looks columns up only in `dat`, so a missing column
  # raises an informative error. This replaces the old `x <- NULL` placeholder
  # bindings, which a missing column would have silently resolved to.
  dat %>%
    ggplot2::ggplot(ggplot2::aes(
      x = .data$specificity, y = .data$mean_sensitivity,
      ymin = .data$lower, ymax = .data$upper
    )) +
    shared_ggprotos(ribbon_fill = ribbon_fill, line_color = line_color) +
    # Dashed reference diagonal; drawn with the reversed x scale below
    # (intended to trace sensitivity = 1 - specificity).
    ggplot2::geom_abline(intercept = 1, slope = 1, linetype = "dashed", color = "grey50") +
    # Reversed axis, so limits are given from high to low.
    ggplot2::scale_x_reverse(expand = c(0, 0), limits = c(1.01, -0.01)) +
    ggplot2::labs(x = "Specificity", y = "Mean Sensitivity")
}
#' @describeIn plot_curves Plot mean precision over recall
#'
#' @inheritParams shared_ggprotos
#' @inheritParams plot_mean_roc
#' @param baseline_precision baseline precision from `calc_baseline_precision()`.
#'   If not `NULL`, drawn as a dashed horizontal line.
#' @param ycol column for the y axis (Default: `mean_precision`)
#'
#' @export
plot_mean_prc <- function(dat, baseline_precision = NULL, ycol = mean_precision,
                          ribbon_fill = "#C7E9C0", line_color = "#00441B") {
  # `ycol` defaults to the bare name `mean_precision`. This local binding
  # presumably exists to quiet R CMD check's "no visible binding" note for that
  # default. Inside aes() the data mask takes precedence, so the column in
  # `dat` is what gets plotted.
  mean_precision <- NULL
  abort_packages_not_installed("ggplot2")
  # `.data` makes the fixed columns resolve strictly from `dat` (a missing
  # column errors instead of silently falling back to a placeholder).
  prc_plot <- dat %>%
    ggplot2::ggplot(ggplot2::aes(
      x = .data$recall, y = {{ ycol }},
      ymin = .data$lower, ymax = .data$upper
    )) +
    shared_ggprotos(ribbon_fill = ribbon_fill, line_color = line_color) +
    ggplot2::scale_x_continuous(expand = c(0, 0), limits = c(-0.01, 1.01)) +
    ggplot2::labs(x = "Recall", y = "Mean Precision")
  if (!is.null(baseline_precision)) {
    prc_plot <- prc_plot +
      ggplot2::geom_hline(
        yintercept = baseline_precision,
        linetype = "dashed", color = "grey50"
      )
  }
  prc_plot
}
#' @name plot_curves
#' @title Plot ROC and PRC curves
#'
#' @description
#' `plot_mean_roc()` plots mean sensitivity against specificity, and
#' `plot_mean_prc()` plots mean precision (or the `ycol` column) against
#' recall. Each draws the mean curve as a line over a ribbon spanning the
#' `lower` and `upper` columns of `dat`.
#'
#' @return A ggplot2 plot.
#'
#' @author Courtney Armour
#' @author Kelly Sovacool \email{sovacool@@umich.edu}
#'
#' @examples
#' \dontrun{
#' library(dplyr)
#' # get performance for multiple models
#' get_sensspec_seed <- function(seed) {
#' ml_result <- run_ml(otu_mini_bin, "glmnet", seed = seed)
#' sensspec <- calc_model_sensspec(
#' ml_result$trained_model,
#' ml_result$test_data,
#' "dx"
#' ) %>%
#' mutate(seed = seed)
#' return(sensspec)
#' }
#' sensspec_dat <- purrr::map_dfr(seq(100, 102), get_sensspec_seed)
#'
#' # plot ROC & PRC
#' sensspec_dat %>%
#' calc_mean_roc() %>%
#' plot_mean_roc()
#' baseline_prec <- calc_baseline_precision(otu_mini_bin, "dx", "cancer")
#' sensspec_dat %>%
#' calc_mean_prc() %>%
#' plot_mean_prc(baseline_precision = baseline_prec)
#' }
NULL
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.