R/mrPerformancePlot.R

Defines functions mrPerformancePlot

Documented in mrPerformancePlot

#' Plot Model Performance Comparison
#' 
#' @description
#' Create visualizations to compare the performance of two models based on their
#' performance metrics generated by [mrIMLperformance].
#'
#' @details
#' The compared metric is the Matthews correlation coefficient (`mcc`) for
#' classification models and the root mean squared error (`rmse`) for
#' regression models. Per-response differences are computed as
#' `ModelPerf2` minus `ModelPerf1`. Within each model, metric values lying more
#' than 1.5 * IQR below the lower quartile or above the upper quartile are
#' flagged as outliers and labelled on the box plot.
#'
#' If both models share the same `model_name`, the names are made unique with
#' [make.unique()] (e.g. `"rand_forest"` and `"rand_forest.1"`) so the two
#' models remain distinguishable.
#' 
#' @param ModelPerf1,ModelPerf2 Two objects returned by [mrIMLperformance] for
#' models fit to the same data. The first element of each must be a data frame
#' of per-response performance metrics containing `response` and `model_name`
#' columns, see **Examples**.
#' @param mode A character string describing the mode of the models. Should be
#' either "regression" or "classification". The default is "classification".
#' 
#' @returns A list containing:
#' * `$performance_plot`: A box plot of model performance metrics.
#' * `$performance_diff_plot`: A bar plot of the per-response differences in
#' performance metrics (`ModelPerf2` minus `ModelPerf1`).
#' * `$performance_diff_df`: A data frame in wide format containing model
#' performance metrics, outlier values and their differences (`diff_mod1_2`).
#' 
#' @examples
#' library(tidymodels)
#'
#' data <- MRFcov::Bird.parasites
#' Y <- data %>%
#'   select(-scale.prop.zos) %>%
#'   select(order(everything()))
#' X <- data %>%
#'   select(scale.prop.zos)
#'
#' # Specify a random forest tidy model
#' model_rf <- rand_forest(
#'   trees = 50, # 50 trees are set for brevity. Aim to start with 1000
#'   mode = "classification",
#'   mtry = tune(),
#'   min_n = tune()
#' ) %>%
#'   set_engine("randomForest")
#' model_lm <- logistic_reg()
#'
#' MR_perf_rf <- mrIMLpredicts(
#'   X = X,
#'   Y = Y,
#'   Model = model_rf,
#'   prop = 0.7,
#'   k = 2,
#'   racing = FALSE
#' ) %>%
#'   mrIMLperformance()
#' MR_perf_lm <- mrIMLpredicts(
#'   X = X,
#'   Y = Y,
#'   Model = model_lm,
#'   prop = 0.7,
#'   k = 2,
#'   racing = FALSE
#' ) %>%
#'   mrIMLperformance()
#' 
#' perf_comp <- mrPerformancePlot(
#'   ModelPerf1 = MR_perf_rf,
#'   ModelPerf2 = MR_perf_lm
#' )
#'
#' perf_comp[[1]]
#' perf_comp[[2]]
#' perf_comp[[3]]
#' 
#' @export
mrPerformancePlot <- function(ModelPerf1 = NULL,
                              ModelPerf2 = NULL,
                              mode = "classification") {
  # Both inputs are required; with NULL the "same data" check below would
  # trivially pass (NULL vs NULL) and fail later with an obscure error.
  if (is.null(ModelPerf1) || is.null(ModelPerf2)) {
    stop(
      "`ModelPerf1` and `ModelPerf2` must both be supplied.",
      call. = FALSE
    )
  }

  # Set the performance metric (validates `mode` before doing any work)
  metric_name <- switch(
    mode,
    "classification" = "mcc",
    "regression" = "rmse",
    stop("Mode must be either regression or classification.", call. = FALSE)
  )

  # Extract performance data frames
  model1_df <- ModelPerf1[[1]]
  model2_df <- ModelPerf2[[1]]

  # Fail early with a clear message, e.g. when `mode` does not match the
  # metrics actually present in the performance data frames.
  required_cols <- c("response", "model_name", metric_name)
  missing_cols <- union(
    setdiff(required_cols, names(model1_df)),
    setdiff(required_cols, names(model2_df))
  )
  if (length(missing_cols) > 0) {
    stop(
      "mrIMLperformance objects are missing column(s): ",
      paste(missing_cols, collapse = ", "),
      ". Check that `mode` matches the fitted models.",
      call. = FALSE
    )
  }

  # Check that the two data frames match
  if (!identical(dim(model1_df), dim(model2_df)) ||
      !identical(model1_df$response, model2_df$response)) {
    stop(
      "mrIMLperformance objects must be fit to the same data.",
      call. = FALSE
    )
  }

  # Model names used for labelling and for the wide-format column names.
  # make.unique() keeps the models distinguishable when they share a name;
  # otherwise pivot_wider() would merge them into a single column.
  # NOTE(review): assumes `model_name` is constant within each data frame.
  mod_names <- make.unique(c(
    as.character(unique(model1_df$model_name))[1],
    as.character(unique(model2_df$model_name))[1]
  ))

  # Flag values more than 1.5 * IQR beyond the lower/upper quartile.
  # Missing metric values are ignored when computing the quartiles and are
  # themselves never flagged (the comparison yields NA).
  findoutlier <- function(x) {
    q <- stats::quantile(x, c(0.25, 0.75), na.rm = TRUE, names = FALSE)
    iqr <- q[2] - q[1]
    x < q[1] - 1.5 * iqr | x > q[2] + 1.5 * iqr
  }

  # Stack both models, with outliers detected separately within each model
  model_compare_df <- Map(
    function(df, name) {
      df %>%
        dplyr::mutate(
          model_name = name,
          metric = .data[[metric_name]],
          outlier = ifelse(findoutlier(.data$metric), .data$metric, NA)
        )
    },
    list(model1_df, model2_df),
    mod_names
  ) %>%
    dplyr::bind_rows()

  # Create boxplot of model performance metrics
  p1 <- model_compare_df %>%
    ggplot2::ggplot(
      ggplot2::aes(
        x = .data$model_name,
        y = .data$metric,
        label = round(.data$outlier, 4)
      )
    ) +
    ggplot2::geom_boxplot() +
    ggplot2::geom_text(na.rm = TRUE, hjust = -0.5) +
    ggplot2::theme_bw() +
    ggplot2::labs(y = toupper(metric_name))

  # Data frame of individual taxa: one row per response, with
  # metric_<model> and outlier_<model> columns for each model
  wide_df <- model_compare_df %>%
    dplyr::select("response", "model_name", "metric", "outlier") %>%
    tidyr::pivot_wider(
      names_from = "model_name",
      values_from = c("metric", "outlier"),
      names_glue = "{.value}_{model_name}"
    )

  # Difference is ModelPerf2 minus ModelPerf1. Columns are looked up by name
  # rather than position so a change in column order cannot flip the sign.
  metric_cols <- paste0("metric_", mod_names)
  wide_df$diff_mod1_2 <-
    wide_df[[metric_cols[2]]] - wide_df[[metric_cols[1]]]

  # Reshape back to long format for plotting
  long_df <- wide_df %>%
    tidyr::pivot_longer(
      cols = tidyselect::starts_with("diff_"),
      names_to = "comparison",
      values_to = "difference"
    )

  # Create bar plot of differences in performance metrics
  p2 <- long_df %>%
    ggplot2::ggplot(
      ggplot2::aes(
        y = stats::reorder(
          .data$response,
          .data$difference,
          decreasing = TRUE
        ),
        x = .data$difference
      )
    ) +
    ggplot2::geom_col() +
    ggplot2::labs(
      y = "Response",
      x = paste0("Difference in ", toupper(metric_name)),
      title = paste0(mod_names[2], " vs ", mod_names[1])
    ) +
    ggplot2::theme_bw()

  list(
    performance_plot = p1,
    performance_diff_plot = p2,
    performance_diff_df = wide_df
  )
}
nfj1380/mrIML documentation built on June 2, 2025, 1:03 a.m.