R/plot_confusion_matrix.R

Defines functions plot_confusion_matrix

Documented in plot_confusion_matrix

#   __________________ #< e5efa5aa075de36169902b4bf6e14ce8 ># __________________
#   Plot a confusion matrix                                                 ####


#' @title Plot a confusion matrix
#' @description
#'  \Sexpr[results=rd, stage=render]{lifecycle::badge("experimental")}
#'
#'  Creates a \code{\link[ggplot2:ggplot]{ggplot2}} object representing a confusion matrix with counts,
#'  overall percentages, row percentages and column percentages. An extra row and column with sum tiles and the
#'  total count can be added.
#'
#'  The confusion matrix can be created with \code{\link[cvms:evaluate]{evaluate()}}. See \code{`Examples`}.
#'
#'  While this function is intended to be very flexible (hence the large number of arguments),
#'  the defaults should work in most cases for most users. See the \code{Examples}.
#'
#'  \strong{NEW}: Our
#'  \href{https://huggingface.co/spaces/ludvigolsen/plot_confusion_matrix}{
#'  \strong{Plot Confusion Matrix} web application}
#'  allows using this function without code. Select from multiple design templates
#'  or make your own.
#' @author Ludvig Renbo Olsen, \email{r-pkgs@@ludvigolsen.dk}
#' @export
#' @family plotting functions
#' @param conf_matrix Confusion matrix \code{tibble} with each combination of
#'  targets and predictions along with their counts.
#'
#'  E.g. for a binary classification:
#'
#'  \tabular{rrr}{
#'   \strong{Target} \tab \strong{Prediction} \tab \strong{N} \cr
#'   class_1 \tab class_1 \tab 5 \cr
#'   class_1 \tab class_2 \tab 9 \cr
#'   class_2 \tab class_1 \tab 3 \cr
#'   class_2 \tab class_2 \tab 2 \cr
#'  }
#'
#'  As created with the various evaluation functions in \code{cvms}, like
#'  \code{\link[cvms:evaluate]{evaluate()}}.
#'
#'  An additional \code{`sub_col`} column (\code{character}) can be specified
#'  as well. Its content will replace the bottom text (`counts` by default or
#'  `normalized` when \code{`counts_on_top`} is enabled).
#'
#'  \strong{Note}: If you supply the results from \code{\link[cvms:evaluate]{evaluate()}}
#'  or \code{\link[cvms:confusion_matrix]{confusion_matrix()}} directly,
#'  the confusion matrix \code{tibble} is extracted automatically, if possible.
#' @param target_col Name of column with target levels.
#' @param prediction_col Name of column with prediction levels.
#' @param counts_col Name of column with a count for each combination
#'  of the target and prediction levels.
#' @param sub_col Name of column with text to replace the bottom text
#'  (`counts` by default or `normalized` when \code{`counts_on_top`} is enabled).
#'
#'  It simply replaces the text, so the settings that style it keep their
#'  usual names (e.g. \code{`font_counts`}). When other settings cause no
#'  bottom text to be displayed (e.g. \code{`add_counts` = FALSE}),
#'  this text is not displayed either.
#' @param class_order Names of the classes in \code{`conf_matrix`} in the desired order.
#'  When \code{NULL}, the classes are ordered alphabetically.
#'  Note that the entire set of unique classes from both \code{`target_col`}
#'  and \code{`prediction_col`} must be specified.
#' @param add_sums Add tiles with the row/column sums. Also adds a total count tile. (Logical)
#'
#'  The appearance of these tiles can be specified in \code{`sums_settings`}.
#'
#'  Note: Adding the sum tiles with a palette requires the \code{ggnewscale} package.
#' @param add_counts Add the counts to the middle of the tiles. (Logical)
#' @param add_normalized Normalize the counts to percentages and
#'  add to the middle of the tiles. (Logical)
#' @param add_row_percentages Add the row percentages,
#'  i.e. how big a part of its row the tile makes up. (Logical)
#'
#'  By default, the row percentage is placed to the right of the tile, rotated 90 degrees.
#' @param add_col_percentages Add the column percentages,
#'  i.e. how big a part of its column the tile makes up. (Logical)
#'
#'  By default, the column percentage is placed at the bottom of the tile.
#' @param add_arrows Add the arrows to the row and col percentages. (Logical)
#'
#'  Note: Adding the arrows requires the \code{rsvg} and \code{ggimage} packages.
#' @param add_zero_shading Add image of skewed lines to zero-tiles. (Logical)
#'
#'  Note: Adding the zero-shading requires the \code{rsvg} and \code{ggimage} packages.
#' @param amount_3d_effect Amount of 3D effect (tile overlay) to add.
#'  Passed as whole number from \code{0} (no effect) up to \code{6} (biggest effect).
#'  This helps separate tiles with the same intensities.
#'
#'  Note: The overlay may not fit the tiles in many-class cases that haven't been tested.
#'  If the boxes do not overlap properly, simply turn it off.
#' @param diag_percentages_only Whether to only have row and column percentages in the diagonal tiles. (Logical)
#' @param rm_zero_percentages Whether to remove row and column percentages when the count is \code{0}. (Logical)
#' @param rm_zero_text Whether to remove counts and normalized percentages when the count is \code{0}. (Logical)
#' @param counts_on_top Switch the counts and normalized counts,
#'  such that the counts are on top. (Logical)
#' @param rotate_y_text Whether to rotate the y-axis text to
#'  be vertical instead of horizontal. (Logical)
#' @param font_counts \code{list} of font settings for the counts.
#'  Can be provided with \code{\link[cvms:font]{font()}}.
#' @param font_normalized \code{list} of font settings for the normalized counts.
#'  Can be provided with \code{\link[cvms:font]{font()}}.
#' @param font_row_percentages \code{list} of font settings for the row percentages.
#'  Can be provided with \code{\link[cvms:font]{font()}}.
#' @param font_col_percentages \code{list} of font settings for the column percentages.
#'  Can be provided with \code{\link[cvms:font]{font()}}.
#' @param intensity_by The measure that should control the color intensity of the tiles.
#'  Either \code{`counts`}, \code{`normalized`} or one of \code{`log counts`,
#'  `log2 counts`, `log10 counts`, `arcsinh counts`}.
#'
#'  For `normalized`, the color limits become \code{0-100} (except when
#'  \code{`intensity_lims`} are specified), which makes the intensities
#'  easier to compare across plots.
#'
#'  For the `log*` and `arcsinh` versions, the log/arcsinh transformed counts are used.
#'
#'  \strong{Note}: In `log*` transformed counts, 0-counts are set to `0`, so they
#'  won't be distinguishable from 1-counts (whose log is also `0`).
#' @param intensity_lims A specific range of values for the color intensity of
#'  the tiles. Given as a numeric vector with \code{c(min, max)}.
#'
#'  This allows having the same intensity scale across plots for better comparison
#'  of prediction sets.
#' @param intensity_beyond_lims What to do with values beyond the
#'  \code{`intensity_lims`}. One of \code{"truncate", "grey"}.
#' @param arrow_size Size of arrow icons. (Numeric)
#'
#'  Is divided by \code{sqrt(nrow(conf_matrix))} and passed on
#'  to \code{\link[ggimage:geom_icon]{ggimage::geom_icon()}}.
#' @param arrow_nudge_from_text Distance from the percentage text to the arrow. (Numeric)
#' @param digits Number of digits to round to (percentages only).
#'  Must be a non-negative whole number.
#'
#'  Can be set for each font individually via the \code{font_*} arguments,
#'  where a negative number means no rounding.
#' @param palette Color scheme. Passed directly to \code{`palette`} in
#'  \code{\link[ggplot2:scale_fill_distiller]{ggplot2::scale_fill_distiller}}.
#'
#'  Try these palettes: \code{"Greens"}, \code{"Oranges"},
#'  \code{"Greys"}, \code{"Purples"}, \code{"Reds"},
#'  as well as the default \code{"Blues"}.
#'
#'  Alternatively, pass a named list with limits of a custom gradient as e.g.
#'  \code{`list("low"="#B1F9E8", "high"="#239895")`}. These are passed to
#'  \code{\link[ggplot2:scale_fill_gradient]{ggplot2::scale_fill_gradient}}.
#' @param tile_border_color Color of the tile borders. Passed as \emph{\code{`colour`}} to
#' \code{\link[ggplot2:geom_tile]{ggplot2::geom_tile}}.
#' @param tile_border_size Size of the tile borders. Passed as \emph{\code{`size`}} to
#' \code{\link[ggplot2:geom_tile]{ggplot2::geom_tile}}.
#' @param tile_border_linetype Linetype for the tile borders. Passed as \emph{\code{`linetype`}} to
#' \code{\link[ggplot2:geom_tile]{ggplot2::geom_tile}}.
#' @param sums_settings A list of settings for the appearance of the sum tiles.
#'  Can be provided with \code{\link[cvms:sum_tile_settings]{sum_tile_settings()}}.
#' @param darkness How dark the darkest colors should be, between \code{0} and \code{1}, where \code{1} is darkest.
#'
#'  Technically, a lower value increases the upper limit in
#'  \code{\link[ggplot2:scale_fill_distiller]{ggplot2::scale_fill_distiller}}.
#' @param theme_fn The \code{ggplot2} theme function to apply.
#' @param place_x_axis_above Move the x-axis text to the top and reverse the levels such that
#'  the "correct" diagonal goes from top left to bottom right. (Logical)
#' @details
#'  Inspired by Antoine Sachet's answer at https://stackoverflow.com/a/53612391/11832955
#' @return
#'  A \code{ggplot2} object representing a confusion matrix.
#'  Color intensity depends on either the counts (default) or the overall percentages.
#'
#'  By default, each tile has the normalized count
#'  (overall percentage) and count in the middle, the
#'  column percentage at the bottom, and the
#'  row percentage to the right and rotated 90 degrees.
#'
#'  In the "correct" diagonal (upper left to bottom right, by default),
#'  the column percentages are the class-level sensitivity scores,
#'  while the row percentages are the class-level positive predictive values.
#' @examples
#' \donttest{
#' # Attach cvms
#' library(cvms)
#' library(ggplot2)
#'
#' # Two classes
#'
#' # Create targets and predictions data frame
#' data <- data.frame(
#'   "target" = c("A", "B", "A", "B", "A", "B", "A", "B",
#'                "A", "B", "A", "B", "A", "B", "A", "A"),
#'   "prediction" = c("B", "B", "A", "A", "A", "B", "B", "B",
#'                    "B", "B", "A", "B", "A", "A", "A", "A"),
#'   stringsAsFactors = FALSE
#' )
#'
#' # Evaluate predictions and create confusion matrix
#' evaluation <- evaluate(
#'   data = data,
#'   target_col = "target",
#'   prediction_cols = "prediction",
#'   type = "binomial"
#' )
#'
#' # Inspect confusion matrix tibble
#' evaluation[["Confusion Matrix"]][[1]]
#'
#' # Plot confusion matrix
#' # Supply confusion matrix tibble directly
#' plot_confusion_matrix(evaluation[["Confusion Matrix"]][[1]])
#' # Plot first confusion matrix in evaluate() output
#' plot_confusion_matrix(evaluation)
#'
#' }
#' \dontrun{
#' # Add sum tiles
#' plot_confusion_matrix(evaluation, add_sums = TRUE)
#' }
#' \donttest{
#'
#' # Add labels to diagonal row and column percentages
#' # This example assumes "B" is the positive class
#' # but you could write anything as prefix to the percentages
#' plot_confusion_matrix(
#'     evaluation,
#'     font_row_percentages = font(prefix=c("NPV = ", "", "", "PPV = ")),
#'     font_col_percentages = font(prefix=c("Spec = ", "", "", "Sens = "))
#' )
#'
#' # Three (or more) classes
#'
#' # Create targets and predictions data frame
#' data <- data.frame(
#'   "target" = c("A", "B", "C", "B", "A", "B", "C",
#'                "B", "A", "B", "C", "B", "A"),
#'   "prediction" = c("C", "B", "A", "C", "A", "B", "B",
#'                    "C", "A", "B", "C", "A", "C"),
#'   stringsAsFactors = FALSE
#' )
#'
#' # Evaluate predictions and create confusion matrix
#' evaluation <- evaluate(
#'   data = data,
#'   target_col = "target",
#'   prediction_cols = "prediction",
#'   type = "multinomial"
#' )
#'
#' # Inspect confusion matrix tibble
#' evaluation[["Confusion Matrix"]][[1]]
#'
#' # Plot confusion matrix
#' # Supply confusion matrix tibble directly
#' plot_confusion_matrix(evaluation[["Confusion Matrix"]][[1]])
#' # Plot first confusion matrix in evaluate() output
#' plot_confusion_matrix(evaluation)
#'
#' }
#' \dontrun{
#' # Add sum tiles
#' plot_confusion_matrix(evaluation, add_sums = TRUE)
#' }
#' \donttest{
#'
#' # Counts only
#' plot_confusion_matrix(
#'   evaluation[["Confusion Matrix"]][[1]],
#'   add_normalized = FALSE,
#'   add_row_percentages = FALSE,
#'   add_col_percentages = FALSE
#' )
#'
#' # Change color palette to green
#' # Change theme to `theme_light`.
#' plot_confusion_matrix(
#'   evaluation[["Confusion Matrix"]][[1]],
#'   palette = "Greens",
#'   theme_fn = ggplot2::theme_light
#' )
#'
#' }
#' \dontrun{
#' # Change color palette to custom gradient
#' # with a different gradient for sum tiles
#' plot_confusion_matrix(
#'   evaluation[["Confusion Matrix"]][[1]],
#'   palette = list("low" = "#B1F9E8", "high" = "#239895"),
#'   sums_settings = sum_tile_settings(
#'     palette = list("low" = "#e9e1fc", "high" = "#BE94E6")
#'   ),
#'   add_sums = TRUE
#' )
#' }
#' \donttest{
#'
#' # The output is a ggplot2 object
#' # that you can add layers to
#' # Here we change the axis labels
#' plot_confusion_matrix(evaluation[["Confusion Matrix"]][[1]]) +
#'   ggplot2::labs(x = "True", y = "Guess")
#'
#' # Replace the bottom tile text
#' # with some information
#' # First extract confusion matrix
#' # Then add new column with text
#' cm <- evaluation[["Confusion Matrix"]][[1]]
#' cm[["Trials"]] <- c(
#'   "(8/9)", "(3/9)", "(1/9)",
#'   "(3/9)", "(7/9)", "(4/9)",
#'   "(1/9)", "(2/9)", "(8/9)"
#'  )
#'
#' # Now plot with the `sub_col` argument specified
#' plot_confusion_matrix(cm, sub_col="Trials")
#'
#' }
plot_confusion_matrix <- function(conf_matrix,
                                  target_col = "Target",
                                  prediction_col = "Prediction",
                                  counts_col = "N",
                                  sub_col = NULL,
                                  class_order = NULL,
                                  add_sums = FALSE,
                                  add_counts = TRUE,
                                  add_normalized = TRUE,
                                  add_row_percentages = TRUE,
                                  add_col_percentages = TRUE,
                                  diag_percentages_only = FALSE,
                                  rm_zero_percentages = TRUE,
                                  rm_zero_text = TRUE,
                                  add_zero_shading = TRUE,
                                  amount_3d_effect = 1,
                                  add_arrows = TRUE,
                                  counts_on_top = FALSE,
                                  palette = "Blues",
                                  intensity_by = "counts",
                                  intensity_lims = NULL,
                                  intensity_beyond_lims = "truncate",
                                  theme_fn = ggplot2::theme_minimal,
                                  place_x_axis_above = TRUE,
                                  rotate_y_text = TRUE,
                                  digits = 1,
                                  font_counts = font(),
                                  font_normalized = font(),
                                  font_row_percentages = font(),
                                  font_col_percentages = font(),
                                  arrow_size = 0.048,
                                  arrow_nudge_from_text = 0.065,
                                  tile_border_color = NA,
                                  tile_border_size = 0.1,
                                  tile_border_linetype = "solid",
                                  sums_settings = sum_tile_settings(),
                                  darkness = 0.8) {

  if (length(intersect(class(conf_matrix), c("cfm_results", "eval_results"))) > 0 &&
      "Confusion Matrix" %in% colnames(conf_matrix) &&
      nrow(conf_matrix) > 0){
    if (nrow(conf_matrix) > 1){
      warning(paste0("'conf_matrix' has more than one row. Extracting first confu",
                     "sion matrix with 'conf_matrix[['Confusion Matrix']][[1]]'."))
    }
    conf_matrix <- conf_matrix[["Confusion Matrix"]][[1]]
  }

  # Check arguments ####
  assert_collection <- checkmate::makeAssertCollection()
  checkmate::assert_data_frame(
    x = conf_matrix,
    any.missing = FALSE,
    min.rows = 4,
    min.cols = 3,
    col.names = "named",
    add = assert_collection
  )
  # String
  checkmate::assert_string(x = target_col, min.chars = 1, add = assert_collection)
  checkmate::assert_string(x = prediction_col, min.chars = 1, add = assert_collection)
  checkmate::assert_string(x = counts_col, min.chars = 1, add = assert_collection)
  checkmate::assert_string(x = sub_col, min.chars = 1, add = assert_collection, null.ok = TRUE)
  checkmate::assert_string(x = tile_border_linetype, na.ok = TRUE, add = assert_collection)
  checkmate::assert_string(x = tile_border_color, na.ok = TRUE, add = assert_collection)
  checkmate::assert_character(x = class_order, null.ok = TRUE, any.missing = FALSE, add = assert_collection)
  checkmate::assert_string(x = intensity_by, add = assert_collection)
  checkmate::assert_string(x = intensity_beyond_lims, add = assert_collection)
  checkmate::assert_choice(x = intensity_beyond_lims, choices = c("truncate", "grey"), add = assert_collection)
  checkmate::assert_string(x = sums_settings[["intensity_by"]], null.ok = TRUE, add = assert_collection)
  checkmate::assert_string(x = sums_settings[["intensity_beyond_lims"]], null.ok = TRUE, add = assert_collection)
  checkmate::assert_choice(x = sums_settings[["intensity_beyond_lims"]], choices = c("truncate", "grey"), null.ok = TRUE, add = assert_collection)


  checkmate::assert(
    checkmate::check_string(x = palette, na.ok = TRUE, null.ok = TRUE),
    checkmate::check_list(x = palette, types = "character", names = "unique")
  )
  if (is.list(palette)){
    checkmate::assert_names(
      x = names(palette),
      must.include = c("high", "low"),
      add = assert_collection
    )
  }

  # Flag
  checkmate::assert_flag(x = add_sums, add = assert_collection)
  checkmate::assert_flag(x = add_counts, add = assert_collection)
  checkmate::assert_flag(x = add_normalized, add = assert_collection)
  checkmate::assert_flag(x = add_row_percentages, add = assert_collection)
  checkmate::assert_flag(x = add_col_percentages, add = assert_collection)
  checkmate::assert_flag(x = add_zero_shading, add = assert_collection)
  checkmate::assert_flag(x = counts_on_top, add = assert_collection)
  checkmate::assert_flag(x = place_x_axis_above, add = assert_collection)
  checkmate::assert_flag(x = rotate_y_text, add = assert_collection)
  checkmate::assert_flag(x = diag_percentages_only, add = assert_collection)
  checkmate::assert_flag(x = rm_zero_percentages, add = assert_collection)
  checkmate::assert_flag(x = rm_zero_text, add = assert_collection)
  checkmate::assert_choice(x = amount_3d_effect,
                           choices = 0:6,
                           add = assert_collection)

  # Function
  checkmate::assert_function(x = theme_fn, add = assert_collection)

  # Number
  checkmate::assert_number(x = darkness, lower = 0, upper = 1, add = assert_collection)
  checkmate::assert_number(x = tile_border_size, lower = 0, add = assert_collection)
  checkmate::assert_number(x = arrow_size, lower = 0, add = assert_collection)
  checkmate::assert_number(x = arrow_nudge_from_text, lower = 0, add = assert_collection)
  checkmate::assert_count(x = digits, add = assert_collection)
  checkmate::assert_numeric(
    x = intensity_lims,
    len = 2,
    null.ok = TRUE,
    any.missing = FALSE,
    unique = TRUE,
    sorted = TRUE,
    finite = TRUE,
    add = assert_collection
  )

  # List
  checkmate::assert_list(x = font_counts, names = "named", add = assert_collection)
  checkmate::assert_list(x = font_normalized, names = "named", add = assert_collection)
  checkmate::assert_list(x = font_row_percentages, names = "named", add = assert_collection)
  checkmate::assert_list(x = font_col_percentages, names = "named", add = assert_collection)
  checkmate::assert_list(x = sums_settings, names = "named", add = assert_collection)
  checkmate::reportAssertions(assert_collection)

  # Names
  checkmate::assert_names(
    x = colnames(conf_matrix),
    must.include = c(target_col, prediction_col, counts_col),
    add = assert_collection
  )
  available_font_settings <- names(font())
  checkmate::assert_names(
    x = names(font_counts),
    subset.of = available_font_settings,
    add = assert_collection
  )
  checkmate::assert_names(
    x = names(font_normalized),
    subset.of = available_font_settings,
    add = assert_collection
  )
  checkmate::assert_names(
    x = names(font_row_percentages),
    subset.of = available_font_settings,
    add = assert_collection
  )
  checkmate::assert_names(
    x = names(font_col_percentages),
    subset.of = available_font_settings,
    add = assert_collection
  )
  checkmate::assert_names(
    x = intensity_by,
    subset.of = c("counts", "normalized", "log counts", "log2 counts", "log10 counts", "arcsinh counts"),
    add = assert_collection
  )
  if (!is.null(sums_settings[["intensity_by"]])){
    checkmate::assert_names(
      x = sums_settings[["intensity_by"]],
      subset.of = c("counts", "normalized", "log counts", "log2 counts", "log10 counts", "arcsinh counts"),
      add = assert_collection
    )
  }

  if (!is.null(class_order)) {
    classes_in_data <- unique(c(
      conf_matrix[[target_col]],
      conf_matrix[[prediction_col]]
    ))
    if (length(classes_in_data) != length(class_order) ||
        length(setdiff(classes_in_data, class_order)) != 0) {
      assert_collection$push(
        "when 'class_order' is specified, it must contain the same levels as 'conf_matrix'."
      )
    }
  } else {
    class_order <- sort(unique(
      c(conf_matrix[[target_col]],
        conf_matrix[[prediction_col]])
    ))
  }

  # Check `sum_settings` `palette`
  if (!is.null(sums_settings[["palette"]])) {
    checkmate::assert(
      checkmate::check_string(x = sums_settings[["palette"]], na.ok = TRUE),
      checkmate::check_list(
        x = sums_settings[["palette"]],
        types = "character",
        names = "unique"
      )
    )
    if (is.list(sums_settings[["palette"]])) {
      checkmate::assert_names(
        x = names(sums_settings[["palette"]]),
        must.include = c("high", "low"),
        add = assert_collection
      )
    }

    if (palettes_are_equal(palette, sums_settings[["palette"]])) {
      assert_collection$push("'palette' and 'sums_settings[['palette']]' cannot be the same palette.")
    }
  }

  # Check that N are (>= 0)
  checkmate::assert_numeric(
    x = conf_matrix[[counts_col]],
    lower = 0,
    add = assert_collection
  )

  checkmate::reportAssertions(assert_collection)
  # End of argument checks ####

  # When 'rsvg', 'ggimage' or 'ggnewscale' is missing
  user_has_rsvg <- requireNamespace("rsvg", quietly = TRUE)
  user_has_ggimage <- requireNamespace("ggimage", quietly = TRUE)
  user_has_ggnewscale <- requireNamespace("ggnewscale", quietly = TRUE)
  use_ggimage <- all(user_has_rsvg, user_has_ggimage)
  if (!isTRUE(use_ggimage)){
    if (!isTRUE(user_has_ggimage))
      warning("'ggimage' is missing. Will not plot arrows and zero-shading.")
    if (!isTRUE(user_has_rsvg))
      warning("'rsvg' is missing. Will not plot arrows and zero-shading.")
    add_arrows <- FALSE
    add_zero_shading <- FALSE
  }
  if (isTRUE(add_sums) &&
      !isTRUE(user_has_ggnewscale) &&
      is.null(sums_settings[["tile_fill"]])){
    warning("'ggnewscale' is missing. Will not use palette for sum tiles.")
    sums_settings[["tile_fill"]] <- "#F5F5F5"
  }

  #### Update font settings ####

  font_top_size <- 4.3
  font_bottom_size <- 2.8
  big_counts <- isTRUE(counts_on_top) || !isTRUE(add_normalized)

  # Font for counts
  font_counts <- update_font_setting(font_counts, defaults = list(
    "size" = ifelse(isTRUE(big_counts), font_top_size, font_bottom_size), "digits" = -1
  ), initial_vals = list(
    "nudge_y" = function(x) {
      x + dplyr::case_when(
        !isTRUE(counts_on_top) &&
          isTRUE(add_normalized) ~ -0.16,
        TRUE ~ 0
      )
    }
  ))

  # Font for normalized counts
  font_normalized <- update_font_setting(font_normalized,
    defaults = list(
      "size" = ifelse(!isTRUE(big_counts), font_top_size, font_bottom_size),
      "suffix" = "%", "digits" = digits
    ),
    initial_vals = list(
      "nudge_y" = function(x) {
        x + dplyr::case_when(
          isTRUE(counts_on_top) &&
            isTRUE(add_counts) ~ -0.16,
          TRUE ~ 0
        )
      }
    )
  )

  # Font for row percentages
  font_row_percentages <- update_font_setting(font_row_percentages, defaults = list(
    "size" = 2.35, "prefix" = "",
    "suffix" = "%", "fontface" = "italic",
    "digits" = digits, "alpha" = 0.85
  ), initial_vals = list(
    "nudge_x" = function(x) {
      x + 0.41
    },
    "angle" = function(x) {
      x + 90
    }
  ))

  # Font for column percentages
  font_col_percentages <- update_font_setting(font_col_percentages, defaults = list(
    "size" = 2.35, "prefix" = "",
    "suffix" = "%", "fontface" = "italic",
    "digits" = digits, "alpha" = 0.85
  ), initial_vals = list(
    "nudge_y" = function(x) {
      x - 0.41
    }
  ))

  #### Update sum tile settings ####

  if (isTRUE(add_sums)) {

    if (!is.list(palette) &&
        palette == "Greens" &&
        is.null(sums_settings[["palette"]])) {
      sums_settings[["palette"]] <- "Blues"
    }

    # Defaults are defined in update_sum_tile_settings()
    sums_settings <- update_sum_tile_settings(
      settings = sums_settings,
      defaults = list()
    )

  }

  #### Prepare dataset ####

  # Extract needed columns
  cm <- tibble::tibble(
    "Target" = factor(as.character(conf_matrix[[target_col]]), levels = class_order[class_order %in% unique(conf_matrix[[target_col]])]),
    "Prediction" = factor(as.character(conf_matrix[[prediction_col]]), levels = class_order[class_order %in% unique(conf_matrix[[prediction_col]])]),
    "N" = as.integer(conf_matrix[[counts_col]])
  )
  cm <- calculate_normalized(cm)

  # Prepare text versions of the numerics
  cm[["N_text"]] <- preprocess_numeric(cm[["N"]], font_counts, rm_zero_text = rm_zero_text)
  cm[["Normalized_text"]] <- preprocess_numeric(
    cm[["Normalized"]],
    font_normalized,
    rm_zero_text = rm_zero_text,
    rm_zeroes_post_rounding = FALSE # Only remove where N==0
  )

  # Set color intensity metric
  cm <- set_intensity(cm, intensity_by)

  # Get min and max intensity scores and their range
  # We need to do this before adding sums
  intensity_measures <- get_intensity_range(
    data = cm,
    intensity_by = intensity_by,
    intensity_lims = intensity_lims
  )

  # Handle intensities outside of the allowed range
  cm$Intensity <- handle_beyond_intensity_limits(
    intensities = cm$Intensity,
    intensity_range = intensity_measures,
    intensity_beyond_lims = intensity_beyond_lims
  )

  # Calculate number of tiles to size by
  num_predict_classes <- length(unique(cm[["Prediction"]])) + as.integer(isTRUE(add_sums))

  # Arrow icons
  arrow_icons <- list("up" = get_figure_path("caret_up_sharp.svg"),
                      "down" = get_figure_path("caret_down_sharp.svg"),
                      "left" = get_figure_path("caret_back_sharp.svg"),
                      "right" = get_figure_path("caret_forward_sharp.svg"))

  # Scaling arrow size
  arrow_size <- arrow_size / num_predict_classes

  # Add icons depending on where the tile will be in the image
  cm <- set_arrows(cm, place_x_axis_above = place_x_axis_above,
                   icons = arrow_icons)

  if (isTRUE(use_ggimage) &&
      isTRUE(add_zero_shading)){
    # Add image path for skewed lines for when there's an N=0
    cm[["image_skewed_lines"]] <- ifelse(cm[["N"]] == 0,
                                         get_figure_path("skewed_lines.svg"),
                                         get_figure_path("empty_square.svg"))
  }

  # Calculate column sums
  if (isTRUE(add_col_percentages) || isTRUE(add_sums)) {
    column_sums <- cm %>%
      dplyr::group_by(.data$Target) %>%
      dplyr::summarize(Class_N = sum(.data$N))
  }

  # Calculate row sums
  if (isTRUE(add_col_percentages) || isTRUE(add_sums)) {
    row_sums <- cm %>%
      dplyr::group_by(.data$Prediction) %>%
      dplyr::summarize(Prediction_N = sum(.data$N))
  }

  # Calculate column percentages
  if (isTRUE(add_col_percentages)) {
    cm <- cm %>%
      dplyr::left_join(column_sums, by = "Target") %>%
      dplyr::mutate(
        Class_Percentage = 100 * (.data$N / .data$Class_N),
        Class_Percentage_text = preprocess_numeric(
          .data$Class_Percentage, font_col_percentages
        )
      )
  }

  # Calculate row percentages
  if (isTRUE(add_row_percentages)) {
    cm <- cm %>%
      dplyr::left_join(row_sums, by = "Prediction") %>%
      dplyr::mutate(
        Prediction_Percentage = 100 * (.data$N / .data$Prediction_N),
        Prediction_Percentage_text = preprocess_numeric(
          .data$Prediction_Percentage, font_row_percentages
        )
      )
  }

  # Signal that current rows are not for the sum tiles
  cm[["is_sum"]] <- FALSE

  # Add sums
  if (isTRUE(add_sums)){

    # Create data frame with data for
    # the sum column, sum row, and total counts tile
    column_sums[["Prediction"]] <- "Total"
    row_sums[["Target"]] <-  "Total"
    column_sums <- dplyr::rename(column_sums, N = "Class_N")
    row_sums <- dplyr::rename(row_sums, N = "Prediction_N")
    column_sums <- calculate_normalized(column_sums)
    row_sums <- calculate_normalized(row_sums)
    total_count <- dplyr::tibble(
      "Target" = "Total", "Prediction" = "Total",
      "N" = sum(cm$N), "Normalized" = 100
    )
    sum_data <- dplyr::bind_rows(column_sums, row_sums, total_count)

    # Prepare text versions of the numerics
    sum_data[["N_text"]] <- preprocess_numeric(sum_data[["N"]], font_counts)
    sum_data[["Normalized_text"]] <- preprocess_numeric(sum_data[["Normalized"]], font_normalized)

    # Set total counts tile text
    sum_data[nrow(sum_data), "N_text"] <- ifelse(
      isTRUE(counts_on_top) || !isTRUE(add_normalized),
      as.character(sum_data[nrow(sum_data), "N"]),
      ""
    )
    sum_data[nrow(sum_data), "Normalized_text"] <- ifelse(
      isTRUE(counts_on_top), "", paste0(sum_data[nrow(sum_data), "N"]))

    sums_intensity_by <- sums_settings[["intensity_by"]]
    if (is.null(sums_intensity_by)) sums_intensity_by <- intensity_by

    sums_intensity_beyond_lims <- sums_settings[["intensity_beyond_lims"]]
    if (is.null(sums_intensity_beyond_lims)) sums_intensity_beyond_lims <- intensity_beyond_lims

    # Set color intensity metric
    sum_data <- set_intensity(sum_data, sums_intensity_by)

    # Get min and max intensity scores and their range
    # We need to do this before adding sum_data
    sums_intensity_measures <- get_intensity_range(
      data = head(sum_data, nrow(sum_data) - 1),
      intensity_by = sums_intensity_by,
      intensity_lims = sums_settings[["intensity_lims"]]
    )

    # Handle intensities outside of the allowed range
    sum_data[1:(nrow(sum_data) - 1), "Intensity"] <- handle_beyond_intensity_limits(
      intensities = sum_data[1:(nrow(sum_data) - 1), "Intensity"]$Intensity,
      intensity_range=sums_intensity_measures,
      intensity_beyond_lims=sums_intensity_beyond_lims
    )

    # Get color limits
    sums_color_limits <- get_color_limits(sums_intensity_measures, darkness)

    sum_data[["image_skewed_lines"]] <- get_figure_path("empty_square.svg")

    # Set flag for whether a row is a total score
    sum_data[["is_sum"]] <- TRUE

    # Combine cm and sum_data
    cm <-  dplyr::bind_rows(cm, sum_data)

    # Set arrow icons to empty square image
    cm[cm[["Target"]] == "Total" | cm[["Prediction"]] == "Total",] <- empty_tile_percentages(
      cm[cm[["Target"]] == "Total" | cm[["Prediction"]] == "Total",])

    # Set class order and labels
    if (isTRUE(place_x_axis_above)) {
      class_order <- c("Total", class_order)
    } else {
      class_order <- c(class_order, "Total")
    }
    class_labels <- class_order
    class_labels[class_labels == "Total"] <- sums_settings[["label"]]
    cm[["Target"]] <- factor(
      cm[["Target"]],
      levels = class_order[class_order %in% unique(cm[["Target"]])],
      labels = class_labels
    )
    cm[["Prediction"]] <- factor(
      cm[["Prediction"]],
      levels = class_order[class_order %in% unique(cm[["Prediction"]])],
      labels = class_labels
    )
  }

  # Assign 3D effect image
  if (isTRUE(use_ggimage) &&
      amount_3d_effect > 0){
    # Add image path with slight 3D effect
    if (FALSE){  # Debugging
      cm[["image_3d"]] <- get_figure_path("square_overlay_bordered.png")
    } else {
      cm[["image_3d"]] <- get_figure_path(paste0("square_overlay_", amount_3d_effect, ".png"))
    }
  }

  # If sub column is specified

  if (!is.null(sub_col)) {
    # We overwrite the Normalized/N text with the sub column
    if (isTRUE(big_counts)) {
      text_removed <- cm[["Normalized_text"]] == ""
      cm[["Normalized_text"]] <- ""
      cm[seq_len(nrow(conf_matrix)), "Normalized_text"] <-
        as.character(conf_matrix[[sub_col]])
      cm[text_removed, "Normalized_text"] <- ""
    } else {
      text_removed <- cm[["N_text"]] == ""
      cm[["N_text"]] <- ""
      cm[seq_len(nrow(conf_matrix)), "N_text"] <-
        as.character(conf_matrix[[sub_col]])
      cm[text_removed, "N_text"] <- ""
    }
  }

  # Remove percentages outside the diagonal
  if (isTRUE(diag_percentages_only)) {
    cm[cm[["Target"]] != cm[["Prediction"]],] <- empty_tile_percentages(
      cm[cm[["Target"]] != cm[["Prediction"]],])
  }
  if (isTRUE(rm_zero_percentages)){
    # Remove percentages when the count is 0
    cm[cm[["N"]] == 0,] <- empty_tile_percentages(cm[cm[["N"]] == 0,])
  }

  #### Prepare for plotting ####

  # To avoid the extremely dark colors
  # where the black font does not work that well
  # We add a bit to the range, so our max intensity
  # will not appear to be the extreme
  color_limits <- get_color_limits(intensity_measures, darkness)

  #### Create plot ####

  # Create plot
  pl <- cm %>%
    ggplot2::ggplot(ggplot2::aes(
      x = .data$Target,
      y = .data$Prediction,
      fill = .data$Intensity
    )) +
    ggplot2::labs(
      x = "Target",
      y = "Prediction",
      fill = ifelse(grepl("counts", intensity_by), "N", "Normalized"),
      label = "N"
    ) +
    ggplot2::geom_tile(
      colour = tile_border_color,
      linewidth = tile_border_size,
      linetype = tile_border_linetype,
      show.legend = FALSE
    ) +
    theme_fn() +
    ggplot2::coord_equal()


  # Add fill colors that differ by N
  if (is.list(palette)) {
    pl <- pl +
      ggplot2::scale_fill_gradient(low = palette[["low"]],
                                   high = palette[["high"]],
                                   limits = color_limits)
  } else {
    pl <- pl +
      ggplot2::scale_fill_distiller(palette = palette,
                                    direction = 1,
                                    limits = color_limits)
  }

  pl <- pl +
    # Remove the guide
    ggplot2::guides(fill = "none") +
    ggplot2::theme(
      # Rotate y-axis text
      axis.text.y = ggplot2::element_text(
        angle = ifelse(isTRUE(rotate_y_text), 90, 0),
        hjust = ifelse(isTRUE(rotate_y_text), 0.5, 1),
        vjust = ifelse(isTRUE(rotate_y_text), 0.5, 0)
      ),
      # Add margin to axis labels
      axis.title.y = ggplot2::element_text(margin = ggplot2::margin(0, 6, 0, 0)),
      axis.title.x.top = ggplot2::element_text(margin = ggplot2::margin(0, 0, 6, 0)),
      axis.title.x.bottom = ggplot2::element_text(margin = ggplot2::margin(6, 0, 0, 0))
    )

  #### Add sum tiles ####

  if (isTRUE(add_sums)){
    if (is.null(sums_settings[["tile_fill"]])){
      pl <- pl +
        ggnewscale::new_scale_fill() +
        ggplot2::geom_tile(
          data = cm[cm[["is_sum"]] & cm[["Target"]] != cm[["Prediction"]],],
          mapping = ggplot2::aes(fill = .data$Intensity),
          colour = sums_settings[["tile_border_color"]],
          linewidth = sums_settings[["tile_border_size"]],
          linetype = sums_settings[["tile_border_linetype"]],
          show.legend = FALSE
        )

      if (is.list(sums_settings[["palette"]])) {
        pl <- pl +
          ggplot2::scale_fill_gradient(low = sums_settings[["palette"]][["low"]],
                                       high = sums_settings[["palette"]][["high"]],
                                       limits = sums_color_limits)
      } else {
        pl <- pl +
          ggplot2::scale_fill_distiller(palette = sums_settings[["palette"]],
                                        direction = 1,
                                        limits = sums_color_limits)
      }

    } else {
      pl <- pl +
        ggplot2::geom_tile(
          data = cm[cm[["is_sum"]] & cm[["Target"]] != cm[["Prediction"]],],
          fill = sums_settings[["tile_fill"]],
          colour = sums_settings[["tile_border_color"]],
          linewidth = sums_settings[["tile_border_size"]],
          linetype = sums_settings[["tile_border_linetype"]],
          show.legend = FALSE
        )
    }

    # Add special total count tile
    pl <- pl +
      ggplot2::geom_tile(
        data = cm[cm[["is_sum"]] & cm[["Target"]] == cm[["Prediction"]],],
        colour = sums_settings[["tc_tile_border_color"]],
        linewidth = sums_settings[["tc_tile_border_size"]],
        linetype = sums_settings[["tc_tile_border_linetype"]],
        fill = sums_settings[["tc_tile_fill"]]
      )
  }

  #### Add zero shading ####

  if (isTRUE(use_ggimage) &&
      isTRUE(add_zero_shading) &&
      any(cm[["N"]] == 0)){
    pl <- pl + ggimage::geom_image(
      ggplot2::aes(image = .data$image_skewed_lines),
      by = "height", size = 0.90 / num_predict_classes)
  }

  #### Add 3D effect ####

  if (isTRUE(use_ggimage) &&
      amount_3d_effect > 0) {
    num_rows_ <- num_predict_classes
    overlay_size_subtract <- 0.043 / 2 ^ (num_rows_ - 2)
    overlay_size_subtract <-
      dplyr::case_when(num_rows_ >= 5 ~ overlay_size_subtract + 0.002,
                       TRUE ~ overlay_size_subtract)

    pl <- pl + ggimage::geom_image(
      ggplot2::aes(image = .data$image_3d),
      by = "width",
      size = 1.0 / num_rows_ - overlay_size_subtract
    )
  }

  #### Add numbers to plot ####

  if (isTRUE(add_counts)) {
    # Preset the arguments for the geom_text function to avoid repetition
    text_geom <- purrr::partial(
      ggplot2::geom_text,
      size = font_counts[["size"]],
      alpha = font_counts[["alpha"]],
      nudge_x = font_counts[["nudge_x"]],
      nudge_y = font_counts[["nudge_y"]],
      angle = font_counts[["angle"]],
      family = font_counts[["family"]],
      fontface = font_counts[["fontface"]],
      hjust = font_counts[["hjust"]],
      vjust = font_counts[["vjust"]],
      lineheight = font_counts[["lineheight"]]
    )

    # Add count labels to middle of the regular tiles
    pl <- pl +
      text_geom(
        data = cm[!cm[["is_sum"]], ],
        ggplot2::aes(label = .data$N_text),
        color = font_counts[["color"]]
      )

    # Add count labels to middle of the sum tiles
    if (isTRUE(add_sums)){
      tmp_color <- ifelse(is.null(sums_settings[["font_color"]]),
                          font_counts[["color"]],
                          sums_settings[["font_color"]])
      tmp_tc_color <- ifelse(is.null(sums_settings[["tc_font_color"]]),
                             font_counts[["color"]],
                             sums_settings[["tc_font_color"]])

      # Add count labels to middle of sum tiles
      pl <- pl +
        text_geom(
          data = cm[cm[["is_sum"]] & cm[["Target"]] != cm[["Prediction"]], ],
          ggplot2::aes(label = .data$N_text),
          color = tmp_color
        )

      # Add count label to middle of total count tile
      pl <- pl +
        text_geom(
          data = cm[cm[["is_sum"]] & cm[["Target"]] == cm[["Prediction"]], ],
          ggplot2::aes(label = .data$N_text),
          color = tmp_tc_color
        )
    }
  }

  if (isTRUE(add_normalized)) {
    # Preset the arguments for the geom_text function to avoid repetition
    text_geom <- purrr::partial(
      ggplot2::geom_text,
      size = font_normalized[["size"]],
      alpha = font_normalized[["alpha"]],
      nudge_x = font_normalized[["nudge_x"]],
      nudge_y = font_normalized[["nudge_y"]],
      angle = font_normalized[["angle"]],
      family = font_normalized[["family"]],
      fontface = font_normalized[["fontface"]],
      hjust = font_normalized[["hjust"]],
      vjust = font_normalized[["vjust"]],
      lineheight = font_normalized[["lineheight"]]
    )

    # Add percentage labels to middle of the regular tiles
    pl <- pl +
      text_geom(
        data = cm[!cm[["is_sum"]], ],
        ggplot2::aes(label = .data$Normalized_text),
        color = font_normalized[["color"]]
      )

    if (isTRUE(add_sums)){
      # Get font color for sum tiles
      # If not set, we use the same as for the other tiles
      tmp_color <- ifelse(is.null(sums_settings[["font_color"]]),
                          font_normalized[["color"]],
                          sums_settings[["font_color"]])
      tmp_tc_color <- ifelse(is.null(sums_settings[["tc_font_color"]]),
                             font_normalized[["color"]],
                             sums_settings[["tc_font_color"]])

      # Add percentage labels to middle of sum tiles
      pl <- pl +
        text_geom(
          data = cm[cm[["is_sum"]] & cm[["Target"]] != cm[["Prediction"]], ],
          ggplot2::aes(label = .data$Normalized_text),
          color = tmp_color
        )

      # Add percentage label to middle of total count tile
      pl <- pl +
        text_geom(
          data = cm[cm[["is_sum"]] & cm[["Target"]] == cm[["Prediction"]], ],
          ggplot2::aes(label = .data$Normalized_text),
          color = tmp_tc_color
        )
    }
  }

  # Place x-axis text on top of the plot
  if (isTRUE(place_x_axis_above)) {
    pl <- pl +
      ggplot2::scale_x_discrete(
        position = "top",
        limits = rev(levels(cm$Target))
      )
  }

  # Add row percentages
  if (isTRUE(add_row_percentages)) {
    pl <- pl + ggplot2::geom_text(ggplot2::aes(label = .data$Prediction_Percentage_text),
      size = font_row_percentages[["size"]],
      color = font_row_percentages[["color"]],
      alpha = font_row_percentages[["alpha"]],
      nudge_x = font_row_percentages[["nudge_x"]],
      nudge_y = font_row_percentages[["nudge_y"]],
      angle = font_row_percentages[["angle"]],
      family = font_row_percentages[["family"]],
      fontface = font_row_percentages[["fontface"]],
      hjust = font_row_percentages[["hjust"]],
      vjust = font_row_percentages[["vjust"]],
      lineheight = font_row_percentages[["lineheight"]]
    )
  }

  # Add column percentages
  if (isTRUE(add_col_percentages)) {
    pl <- pl +
      ggplot2::geom_text(ggplot2::aes(label = .data$Class_Percentage_text),
        size = font_col_percentages[["size"]],
        color = font_col_percentages[["color"]],
        alpha = font_col_percentages[["alpha"]],
        nudge_x = font_col_percentages[["nudge_x"]],
        nudge_y = font_col_percentages[["nudge_y"]],
        angle = font_col_percentages[["angle"]],
        family = font_col_percentages[["family"]],
        fontface = font_col_percentages[["fontface"]],
        hjust = font_col_percentages[["hjust"]],
        vjust = font_col_percentages[["vjust"]],
        lineheight = font_col_percentages[["lineheight"]]
      )
  }

  #### Add arrow icons ####

  if (isTRUE(use_ggimage) &&
      isTRUE(add_col_percentages) &&
      isTRUE(add_arrows)){
    pl <- pl +
      ggimage::geom_image(
        ggplot2::aes(image = .data$down_icon),
        by = "height",
        size = arrow_size,
        nudge_x = font_col_percentages[["nudge_x"]],
        nudge_y = font_col_percentages[["nudge_y"]] - arrow_nudge_from_text
      ) +
      ggimage::geom_image(
        ggplot2::aes(image = .data$up_icon),
        by = "height",
        size = arrow_size,
        nudge_x = font_col_percentages[["nudge_x"]],
        nudge_y = font_col_percentages[["nudge_y"]] +
          arrow_nudge_from_text - (arrow_size/2)
      )
  }

  if (isTRUE(use_ggimage) &&
      isTRUE(add_row_percentages) &&
      isTRUE(add_arrows)){
    pl <- pl +
      ggimage::geom_image(
        ggplot2::aes(image = .data$right_icon),
        by = "height",
        size = arrow_size,
        nudge_x = font_row_percentages[["nudge_x"]] +
          arrow_nudge_from_text - (arrow_size / 2),
        nudge_y = font_row_percentages[["nudge_y"]]
      ) +
      ggimage::geom_image(
        ggplot2::aes(image = .data$left_icon),
        by = "height",
        size = arrow_size,
        nudge_x = font_row_percentages[["nudge_x"]] - arrow_nudge_from_text,
        nudge_y = font_row_percentages[["nudge_y"]]
      )
  }

  pl
}

# Try the cvms package in your browser

# Any scripts or data that you put into this service are public.

# cvms documentation built on Sept. 11, 2024, 6:22 p.m.