#' Plot Model Performance Comparison
#'
#' @description
#' Create visualizations to compare the performance of two models based on
#' their performance metrics generated by [mrIMLperformance].
#'
#' @details
#' The metric that is compared depends on `mode`: the `mcc` column is used for
#' `"classification"` and the `rmse` column for `"regression"`. Differences are
#' computed as the second model minus the first model. Within each model,
#' metric values lying more than 1.5 times the interquartile range below the
#' first quartile or above the third quartile are flagged as outliers and
#' labelled on the box plot.
#'
#' If both inputs report the same model name, the names are made unique (a
#' numeric suffix is appended to the second one) so that the two models remain
#' distinguishable in the plots and in the wide data frame.
#'
#' @param ModelPerf1,ModelPerf2 Two objects of model performance metrics to
#'   compare, as created by [mrIMLperformance], see **Examples**. The first
#'   element of each object must be a data frame with `response` and
#'   `model_name` columns plus the metric column selected by `mode`. Both data
#'   frames must have identical dimensions and identical `response` columns.
#' @param mode A character string describing the mode of the models. Should be
#'   either "regression" or "classification". The default is "classification".
#'
#' @returns A list containing:
#' * `$performance_plot`: A box plot of model performance metrics.
#' * `$performance_diff_plot`: A bar plot of the differences in
#'   performance metrics.
#' * `$performance_diff_df`: A data frame in wide format containing model
#'   performance metrics and their differences (column `diff_mod1_2`).
#'
#' @examples
#' library(tidymodels)
#'
#' data <- MRFcov::Bird.parasites
#' Y <- data %>%
#'   select(-scale.prop.zos) %>%
#'   select(order(everything()))
#' X <- data %>%
#'   select(scale.prop.zos)
#'
#' # Specify a random forest tidy model
#' model_rf <- rand_forest(
#'   trees = 50, # 50 trees are set for brevity. Aim to start with 1000
#'   mode = "classification",
#'   mtry = tune(),
#'   min_n = tune()
#' ) %>%
#'   set_engine("randomForest")
#' model_lm <- logistic_reg()
#'
#' MR_perf_rf <- mrIMLpredicts(
#'   X = X,
#'   Y = Y,
#'   Model = model_rf,
#'   prop = 0.7,
#'   k = 2,
#'   racing = FALSE
#' ) %>%
#'   mrIMLperformance()
#' MR_perf_lm <- mrIMLpredicts(
#'   X = X,
#'   Y = Y,
#'   Model = model_lm,
#'   prop = 0.7,
#'   k = 2,
#'   racing = FALSE
#' ) %>%
#'   mrIMLperformance()
#'
#' perf_comp <- mrPerformancePlot(
#'   ModelPerf1 = MR_perf_rf,
#'   ModelPerf2 = MR_perf_lm
#' )
#'
#' perf_comp[[1]]
#' perf_comp[[2]]
#' perf_comp[[3]]
#'
#' @export
mrPerformancePlot <- function(ModelPerf1 = NULL,
                              ModelPerf2 = NULL,
                              mode = "classification") {
  # Fail early with a clear message instead of a subscript error on `[[1]]`
  if (length(ModelPerf1) == 0L || length(ModelPerf2) == 0L) {
    stop("ModelPerf1 and ModelPerf2 must both be supplied.", call. = FALSE)
  }
  # Extract performance data frames
  model1_df <- ModelPerf1[[1]]
  model2_df <- ModelPerf2[[1]]
  if (!is.data.frame(model1_df) || !is.data.frame(model2_df)) {
    stop(
      "The first element of ModelPerf1 and ModelPerf2 must be a data frame.",
      call. = FALSE
    )
  }
  # Check that the two data frames match
  if (!identical(dim(model1_df), dim(model2_df)) ||
    !identical(model1_df$response, model2_df$response)) {
    stop(
      "mrIMLperformance objects must be fit to the same data.",
      call. = FALSE
    )
  }
  # Get model names for plotting
  mod_names <- c(
    as.character(unique(model1_df$model_name)),
    as.character(unique(model2_df$model_name))
  )
  # Two fits of the same model type share a name; without disambiguation they
  # would be merged into one group in the box plot and in the wide data frame.
  if (length(mod_names) == 2L && identical(mod_names[1], mod_names[2])) {
    mod_names <- make.unique(mod_names, sep = "_")
    model1_df$model_name <- rep(mod_names[1], nrow(model1_df))
    model2_df$model_name <- rep(mod_names[2], nrow(model2_df))
  }
  # Set the performance metric
  metric_name <- switch(mode,
    "classification" = "mcc",
    "regression" = "rmse",
    stop("Mode must be either regression or classification.", call. = FALSE)
  )
  if (!(metric_name %in% names(model1_df)) ||
    !(metric_name %in% names(model2_df))) {
    stop(
      "Performance data must contain a '", metric_name,
      "' column; check that `mode` matches the fitted models.",
      call. = FALSE
    )
  }
  # Flag values outside the 1.5 * IQR whiskers. Missing metric values are
  # ignored when computing the quartiles and are themselves never flagged.
  findoutlier <- function(x) {
    q <- stats::quantile(x, c(0.25, 0.75), na.rm = TRUE, names = FALSE)
    spread <- 1.5 * stats::IQR(x, na.rm = TRUE)
    (x < (q[1] - spread)) | (x > (q[2] + spread))
  }
  # Outliers are detected per model, before the two data frames are stacked
  model_compare_df <- lapply(
    list(model1_df, model2_df),
    function(df) {
      dplyr::mutate(
        df,
        metric = .data[[metric_name]],
        outlier = ifelse(findoutlier(.data$metric), .data$metric, NA_real_)
      )
    }
  ) %>%
    dplyr::bind_rows()
  # Create boxplot of model performance metrics, labelling outlying values
  p1 <- ggplot2::ggplot(
    model_compare_df,
    ggplot2::aes(
      x = .data$model_name,
      y = .data$metric,
      label = round(.data$outlier, 4)
    )
  ) +
    ggplot2::geom_boxplot() +
    ggplot2::geom_text(
      na.rm = TRUE, hjust = -0.5
    ) +
    ggplot2::theme_bw() +
    ggplot2::labs(y = toupper(metric_name))
  # Data frame of individual taxa. Columns are selected by name with
  # `all_of()` because `.data$` is deprecated in tidyselect contexts.
  wide_df <- model_compare_df %>%
    dplyr::select(
      tidyselect::all_of(c("response", "model_name", "metric", "outlier"))
    ) %>%
    tidyr::pivot_wider(
      names_from = tidyselect::all_of("model_name"),
      values_from = tidyselect::all_of(c("metric", "outlier")),
      names_glue = "{.value}_{.name}"
    )
  # Locate the per-model metric columns by name rather than by position;
  # they appear in model order (model 1 first), so the difference is
  # model 2 minus model 1.
  metric_cols <- names(wide_df)[startsWith(names(wide_df), "metric_")]
  wide_df$diff_mod1_2 <- wide_df[[metric_cols[2]]] - wide_df[[metric_cols[1]]]
  # Reshape back to long format for plotting
  long_df <- wide_df %>%
    tidyr::pivot_longer(
      cols = tidyselect::starts_with("diff_"),
      names_to = "comparison",
      values_to = "difference"
    )
  # Create bar plot of differences in performance metrics
  p2 <- ggplot2::ggplot(
    long_df,
    ggplot2::aes(
      y = stats::reorder(.data$response, .data$difference, decreasing = TRUE),
      x = .data$difference
    )
  ) +
    ggplot2::geom_bar(stat = "identity") +
    ggplot2::labs(
      y = "Response",
      x = paste0("Difference in ", toupper(metric_name)),
      title = paste0(mod_names[2], " vs ", mod_names[1])
    ) +
    ggplot2::theme_bw()
  list(
    performance_plot = p1,
    performance_diff_plot = p2,
    performance_diff_df = wide_df
  )
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.