R/utils_plotting.R

Defines functions plot_class_stats plot_class_p plot_performance plot_confusion plot_roc plot_regression

Documented in plot_class_p plot_class_stats plot_confusion plot_performance plot_regression plot_roc

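# This script provides plotting functions for `predx` and `caretx` objects:
# fitted versus outcome values (regression), ROC curves, confusion matrices,
# model performance statistics, and class-specific squared distances and
# class assignment probabilities.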

# Plotting for regression -----

#' Plot fitted values versus outcome for regression objects.
#'
#' @description
#' Generates a point plot of the fitted and outcome values for
#' regression predictions, with an optional fitted trend line and an optional
#' calibration line (slope 1, intercept 0).
#'
#' @details The fitted trend is generated by \code{\link[ggplot2]{geom_smooth}}.
#' For multi-class or binary predictions, NULL is returned with a warning.
#' The function stops with an error if `x_var` or `y_var` is not a column of
#' the prediction data.
#'
#' @param predx_object a `predx` object.
#' @param x_var name of the variable presented in the X axis.
#' @param y_var name of the variable presented in the Y axis.
#' @param point_size size of the plot points.
#' @param point_shape shape of the points.
#' @param point_color color (fill) of the data points.
#' @param point_wjitter horizontal jittering of the points.
#' @param point_hjitter vertical jittering of the points.
#' @param point_alpha plot point alpha.
#' @param show_trend logical, should a trend line be displayed?
#' @param trend_method a method for fitting the trend line, see
#' \code{\link[ggplot2]{geom_smooth}} for details, defaults to 'lm'.
#' @param show_calibration logical, should a line with slope 1 and
#' intercept 0 be displayed in the plot?
#' @param line_size size of the calibration or trend line.
#' @param plot_title plot title.
#' @param plot_subtitle plot subtitle.
#' @param plot_tag plot tag, number of complete observations if not specified by the user.
#' @param x_lab X axis title.
#' @param y_lab Y axis title.
#' @param cust_theme customized plot theme provided by the user.
#' @param ... extra arguments passed to \code{\link[ggplot2]{geom_smooth}}.
#'
#' @return returns a ggplot object, or NULL for non-regression predictions.

  plot_regression <- function(predx_object,
                              x_var = '.outcome',
                              y_var = '.fitted',
                              point_size = 2,
                              point_shape = 21,
                              point_color = 'steelblue',
                              point_wjitter = 0.01,
                              point_hjitter = 0.01,
                              point_alpha = 0.75,
                              show_trend = TRUE,
                              trend_method = 'lm',
                              show_calibration = TRUE,
                              line_size = 0.5,
                              plot_title = NULL,
                              plot_subtitle = NULL,
                              plot_tag = NULL,
                              x_lab = x_var,
                              y_lab = y_var,
                              cust_theme = ggplot2::theme_classic(), ...) {

    ## entry control ------

    stopifnot(is_predx(predx_object))
    stopifnot(is.logical(show_trend))
    stopifnot(is.logical(show_calibration))
    stopifnot(inherits(cust_theme, 'theme'))
    stopifnot(is.numeric(line_size))

    if(predx_object$type %in% c('multi_class', 'binary')) {

      warning(paste('Regression plots for the multi-class or',
                    'binary predictions are not available.'),
              call. = FALSE)

      return(NULL)

    }

    plot_data <- components(predx_object, 'data')

    ## fail early: ggplot2 evaluates the aesthetic mappings only when the plot
    ## is drawn, so a misspelled variable name would otherwise surface later
    ## with an obscure error at print time
    missing_vars <- setdiff(c(x_var, y_var), names(plot_data))

    if(length(missing_vars) > 0) {

      stop(paste0('Variable(s) not found in the prediction data: ',
                  paste(missing_vars, collapse = ', ')),
           call. = FALSE)

    }

    if(is.null(plot_tag)) {

      plot_tag <- paste('n =', nobs(predx_object))

    }

    ## plotting -------

    ## shape 21 is a filled point, hence point_color is applied as 'fill'
    reg_plot <-
      ggplot(plot_data,
             aes(x = .data[[x_var]],
                 y = .data[[y_var]])) +
      ggplot2::geom_point(shape = point_shape,
                          size = point_size,
                          alpha = point_alpha,
                          fill = point_color,
                          position = ggplot2::position_jitter(width = point_wjitter,
                                                              height = point_hjitter)) +
      cust_theme +
      ggplot2::labs(title = plot_title,
                    subtitle = plot_subtitle,
                    tag = plot_tag,
                    x = x_lab,
                    y = y_lab)

    if(show_trend) {

      reg_plot <- reg_plot +
        ggplot2::geom_smooth(method = trend_method,
                             size = line_size, ...)

    }

    ## calibration line: perfect agreement of the fitted and outcome values
    if(show_calibration) {

      reg_plot <- reg_plot +
        ggplot2::geom_abline(intercept = 0,
                             slope = 1,
                             color = 'black',
                             size = line_size)

    }

    reg_plot

  }

# Binary classification: ROC plot ------

#' Plot a receiver-operator characteristic curve.
#'
#' @description
#' Generates a ROC plot for the fitted and outcome values.
#' Optionally, a custom annotation inside the plot may be added.
#'
#' @details The plot is generated by \code{\link[plotROC]{geom_roc}}.
#' The marker is the predicted probability of the second class of the
#' `predx` object. For regression or multi-class predictions, NULL is
#' returned with a warning.
#'
#' @param predx_object a `predx` object.
#' @param line_color color of the ROC line.
#' @param line_size size of the ROC line.
#' @param cutoffs_at numeric, between 0 and 1, indicates the
#' cut points to be presented in the ROC curve.
#' @param point_size size of the cutoff point.
#' @param plot_title plot title.
#' @param plot_subtitle plot subtitle.
#' @param plot_tag plot tag, contains the number of complete observations and
#' events if not specified by the user.
#' @param annotation_txt annotation text.
#' @param annotation_color annotation text color.
#' @param annotation_size size of the annotation text and of the cutoff label.
#' @param annotation_x annotation x position.
#' @param annotation_y annotation y position.
#' @param annotation_hjust horizontal justification of the annotation text.
#' @param annotation_vjust vertical justification of the annotation text.
#' @param cust_theme customized plot theme provided by the user.
#' @param ... extra arguments passed to \code{\link[plotROC]{geom_roc}}.
#'
#' @return returns a ggplot object, or NULL for non-binary predictions.

  plot_roc <- function(predx_object,
                       line_color = 'steelblue',
                       line_size = 0.5,
                       cutoffs_at = 0.5,
                       point_size = 0.3,
                       plot_title = NULL,
                       plot_subtitle = NULL,
                       plot_tag = NULL,
                       annotation_txt = NULL,
                       annotation_color = line_color,
                       annotation_size = 2.75,
                       annotation_x = 0.6,
                       annotation_y = 0.3,
                       annotation_hjust = 0,
                       annotation_vjust = 1,
                       cust_theme = NULL, ...) {

    ## entry control ------

    stopifnot(is_predx(predx_object))

    ## NSE variable used in mutate() below
    .outcome <- NULL

    if(!is.null(cust_theme)) stopifnot(inherits(cust_theme, 'theme'))

    if(predx_object$type %in% c('regression', 'multi_class')) {

      warning(paste('ROC plots for the multi-class or regression',
                    'predictions are not available.'),
              call. = FALSE)

      return(NULL)

    }

    if(is.null(plot_tag)) {

      ## events: observations of the second outcome class, i.e. the class
      ## coded as 1 in the ROC data below. Counting it explicitly (instead of
      ## taking the second row of a count table) stays correct when one of
      ## the classes has no observations.
      outcome <- predx_object$data[['.outcome']]

      event_class <- if(is.factor(outcome)) {

        levels(outcome)[2]

      } else {

        sort(unique(outcome))[2]

      }

      n_events <- sum(outcome == event_class, na.rm = TRUE)

      plot_tag <-
        paste0('total: n = ', nobs(predx_object),
               ', events: n = ', n_events)

    }

    ## geom_roc() expects the disease status coded as 0/1:
    ## the first factor level becomes 0, the second one becomes 1
    if(is.factor(predx_object$data[['.outcome']])) {

      data <- mutate(predx_object$data,
                     .outcome = as.numeric(.outcome) - 1)

    } else {

      data <- predx_object$data

    }

    ## plotting -------

    roc_plot <-
      ggplot(data,
             aes(d = .data[['.outcome']],
                 m = .data[[predx_object$classes[2]]])) +
      plotROC::geom_roc(labelsize = annotation_size,
                        cutoffs.at = cutoffs_at,
                        pointsize = point_size,
                        color = line_color,
                        size = line_size, ...) +
      plotROC::style_roc() +
      ggplot2::geom_abline(slope = 1,
                           intercept = 0,
                           linetype = 'dashed') +
      ggplot2::labs(title = plot_title,
                    subtitle = plot_subtitle,
                    tag = plot_tag)

    if(!is.null(cust_theme)) {

      roc_plot <- roc_plot +
        cust_theme

    }

    if(!is.null(annotation_txt)) {

      roc_plot <- roc_plot +
        ggplot2::annotate('text',
                          label = annotation_txt,
                          x = annotation_x,
                          y = annotation_y,
                          hjust = annotation_hjust,
                          vjust = annotation_vjust,
                          size = annotation_size)

    }

    roc_plot

  }

# Classification models: plots of confusion matrix -----

#' Plot confusion matrix for a classification predx object.
#'
#' @description Generates a heat map representation of the confusion matrix.
#' Outcome is presented in the X axis, fitted is presented in the Y axis.
#'
#' @param predx_object a `predx` object.
#' @param scale indicates how the table is to be scaled:
#' 'none' returns the counts (default),
#' 'fraction' returns the fraction of all observations,
#' 'percent' returns the percent of all observations.
#' @param show_labels logical, indicates if counts/fractions/percents are shown
#' in the plot.
#' @param label_size size of the label text.
#' @param label_color color of the label text.
#' @param signif_digits significant digits for rounding of the label values.
#' @param plot_title plot title.
#' @param plot_subtitle plot subtitle.
#' @param plot_tag plot tag, contains a reference to scale and the
#' observation number if not specified.
#' @param x_lab X axis label.
#' @param y_lab Y axis label.
#' @param cust_theme customized plot theme provided by the user.
#' @return returns a ggplot object, or NULL (with a warning) for regression
#' predictions.

  plot_confusion <- function(predx_object,
                             scale = c('none', 'fraction', 'percent'),
                             show_labels = TRUE,
                             label_size = 2.75,
                             label_color = 'black',
                             signif_digits = 2,
                             plot_title = NULL,
                             plot_subtitle = NULL,
                             plot_tag = NULL,
                             x_lab = '.outcome',
                             y_lab = '.fitted',
                             cust_theme = ggplot2::theme_classic()) {

    ## entry control --------

    stopifnot(is_predx(predx_object))
    stopifnot(is.logical(show_labels))

    ## resolve the scale choice to a single value: without this, the default
    ## three-element vector would produce a three-element plot tag
    scale <- match.arg(scale)

    if(!is.null(cust_theme)) stopifnot(inherits(cust_theme, 'theme'))

    if(predx_object$type == 'regression') {

      warning(paste('Confusion matrix is not available',
                    'for regression predictions.'),
              call. = FALSE)

      return(NULL)

    }

    if(is.null(plot_tag)) {

      scale_lab <- c('none' = 'Counts',
                     'fraction' = 'Fraction of total',
                     'percent' = '% of total')

      plot_tag <- paste0(scale_lab[scale],
                         ', total: n = ',
                         nobs(predx_object))

    }

    ## plotting ------

    ## NOTE(review): presumably confusion() returns a table; as.data.frame()
    ## then yields one row per cell with the '.outcome', '.fitted' and 'Freq'
    ## columns used in the mappings below
    conf <- as.data.frame(confusion(predx_object, scale = scale))

    conf_plot <- ggplot(conf,
                        aes(x = .data[['.outcome']],
                            y = .data[['.fitted']],
                            fill = .data[['Freq']])) +
      ggplot2::geom_tile(color = 'black') +
      ggplot2::scale_fill_gradient2(low = 'steelblue',
                                    high = 'firebrick',
                                    mid = 'white') +
      cust_theme +
      ggplot2::labs(title = plot_title,
                    subtitle = plot_subtitle,
                    tag = plot_tag,
                    x = x_lab,
                    y = y_lab,
                    fill = NULL)

    if(show_labels) {

      conf_plot <- conf_plot +
        ggplot2::geom_text(aes(label = signif(.data[['Freq']], signif_digits)),
                           size = label_size,
                           hjust = 0.5,
                           vjust = 0.5,
                           color = label_color)

    }

    conf_plot

  }

# Plots of performance stats --------

#' Plots of performance metrics in the training, resample and test data.
#'
#' @description
#' This internally utilized function takes a `caretx` model and plots selected
#' performance stats appropriate for the prediction type in the test, resample
#' and, optionally, training data set.
#' Scatter plots are generated, the data set type is color-coded and,
#' optionally, indicated in the plot.
#'
#' @details
#' The plotted performance stats are:
#'
#' * for regression: pseudo-R-squared (Y axis) and
#' root mean square error (RMSE, X axis), Spearman's coefficient for
#' correlation of the outcome and prediction is represented by the point size.
#'
#' * for binary classification: 1 - Brier score (Y axis) and Cohen's kappa
#' (X axis), the overall accuracy is represented by the point size.
#'
#' * for multi-class classification: 2 - Brier score (Y axis) and Cohen's
#' kappa (X axis), the overall accuracy is represented by the point size.
#'
#' Numbers of complete observations are displayed in the plot subtitle,
#' if no user-provided subtitle is specified.
#' An error is raised if no predictions are available or if the prediction
#' type is not supported.
#'
#' @param caretx_object a `caretx` object.
#' @param newdata optional, a data frame passed as `newdata` to `predict()`
#' for the `caretx` object.
#' @param plot_subtitle plot subtitle, see: `Details`.
#' @param plot_tag plot tag.
#' @param show_txt logical, should the data set names (training, CV and test)
#' be displayed in the plot?
#' @param txt_size size of the text labels of the data sets.
#' @param cust_theme a custom ggplot theme
#' @param ... extra arguments passed to \code{\link[ggrepel]{geom_text_repel}}
#'
#' @return a ggplot graphic.

  plot_performance <- function(caretx_object,
                               newdata = NULL,
                               plot_subtitle = NULL,
                               plot_tag = NULL,
                               show_txt = FALSE,
                               txt_size = 2.75,
                               cust_theme = ggplot2::theme_classic(), ...) {

    ## entry control -------

    stopifnot(is_caretx(caretx_object))
    stopifnot(is.numeric(txt_size))

    if(!inherits(cust_theme, 'theme')) {

      stop("'cust_theme' needs to be a valid ggplot 'theme' object.",
           call. = FALSE)

    }

    data_labs <-
      c(train = 'training',
        cv = 'CV',
        test = 'test')

    data_colors <-
      c(train = 'steelblue',
        cv = 'gray40',
        test = 'firebrick4')

    ## NSE variables used in the aesthetic mappings below
    RMSE <- NULL
    rsq <- NULL
    spearman <- NULL
    dataset <- NULL
    brier_score <- NULL
    correct_rate <- NULL

    ## predictions ------

    ## compact() drops empty (NULL) prediction sets
    preds <- compact(predict(caretx_object, newdata = newdata))

    if(length(preds) == 0) {

      stop('No predictions are available for plotting.', call. = FALSE)

    }

    ## plot subtitles ------

    if(is.null(plot_subtitle)) {

      n_numbers <- map_dbl(preds, nobs)

      plot_subtitle <-
        map2_chr(data_labs[names(n_numbers)], n_numbers,
                 paste, sep = ': n = ')

      plot_subtitle <- paste(plot_subtitle, collapse = ', ')

    }

    ## plotting data -----

    ## one row per data set, one column per performance statistic
    plot_tbl <- map(preds, summary)

    plot_tbl <- map(plot_tbl, ~.x[c('statistic', 'estimate')])

    plot_tbl <- reduce(plot_tbl, left_join, by = 'statistic')

    plot_tbl <- set_names(plot_tbl, c('statistic', names(preds)))

    plot_tbl <- t(column_to_rownames(plot_tbl, 'statistic'))

    plot_tbl <- rownames_to_column(as.data.frame(plot_tbl), 'dataset')

    ## base plots -------

    pred_type <- preds[[1]]$type

    ## 'kappa' in the mappings below refers to a column of plot_tbl
    sc_plot <-
      switch(pred_type,
             regression = ggplot(plot_tbl,
                                 aes(x = RMSE,
                                     y = rsq,
                                     size = spearman,
                                     fill = dataset)) +
               ggplot2::labs(title = 'Regression model performance',
                             subtitle = plot_subtitle,
                             tag = plot_tag,
                             x = 'RMSE',
                             y = expression('pseudo-R'^2),
                             size = expression("Spearman's " * rho)),
             binary = ggplot(plot_tbl,
                             aes(x = kappa,
                                 y = 1 - brier_score,
                                 size = correct_rate,
                                 fill = dataset)) +
               ggplot2::labs(title = 'Binary classification model performance',
                             subtitle = plot_subtitle,
                             tag = plot_tag,
                             x = expression("Cohen's " * kappa),
                             y = '1 - Brier score',
                             size = 'Accuracy'),
             multi_class = ggplot(plot_tbl,
                                  aes(x = kappa,
                                      y = 2 - brier_score,
                                      size = correct_rate,
                                      fill = dataset)) +
               ggplot2::labs(title = 'Multi-category classification model performance',
                             subtitle = plot_subtitle,
                             tag = plot_tag,
                             x = expression("Cohen's " * kappa),
                             y = '2 - Brier score',
                             size = 'Accuracy'),
             stop(paste0("Unsupported prediction type: '", pred_type, "'."),
                  call. = FALSE))

    ## common plot format -------

    sc_plot <- sc_plot +
      ggplot2::geom_point(shape = 21,
                          color = 'black') +
      ggplot2::scale_fill_manual(values = data_colors,
                                 labels = data_labs,
                                 name = 'Data set') +
      cust_theme

    if(show_txt) {

      sc_plot <- sc_plot +
        ggrepel::geom_text_repel(aes(label = unname(data_labs[dataset]),
                                     color = dataset),
                                 size = txt_size,
                                 show.legend = FALSE, ...) +
        ggplot2::scale_color_manual(values = data_colors,
                                    labels = data_labs,
                                    name = 'Data set')

    }

    sc_plot

  }

# Brier scores and class assignment p in the outcome classes -----

#' Squared distance to outcome and class assignment probability
#' in the outcome classes.
#'
#' @description
#' `plot_class_p`: This internally used function plots squared distances
#' to the outcome (as defined by Brier et al.) and class-assignment
#' probabilities for the outcome classes as scatter plots.
#' The correct/false class assignment is color-coded.
#' The observations are sorted by the statistic value.
#' Numbers of complete observations are indicated in the plot subtitle
#' (if not provided by the user) and class n numbers can be displayed in
#' the plot facets (`show_class_n` set to TRUE).
#'
#' `plot_class_stats`: The function plots squared distances
#' to the outcome (as defined by Brier et al.) and class-assignment
#' probabilities for the outcome classes as box plots. Class n numbers are
#' indicated in the X axis.
#'
#' @details
#' For regression, NULL and a warning is returned.
#'
#'
#' @references
#' Brier GW. VERIFICATION OF FORECASTS EXPRESSED IN TERMS OF PROBABILITY.
#' Mon Weather Rev (1950) 78:1–3.
#' doi:10.1175/1520-0493(1950)078<0001:vofeit>2.0.co;2
#' @references
#' Goldstein-Greenwood J. A Brief on Brier Scores | UVA Library. (2021)
#' Available at: https://library.virginia.edu/data/articles/a-brief-on-brier-scores
#'
#' @param predx_object a `predx` class object.
#' @param plot_subtitle plot subtitle.
#' @param plot_tag plot tag.
#' @param show_class_n logical, should n numbers for classes be presented
#' in the plot facets?
#' @param flip logical, exchange the X and Y axes in the plots?
#' @param point_size size of the data points.
#' @param hide_obs_labels logical, hide labels of single observations?
#' @param label_misclassified logical, should misclassified observations be
#' labeled with their numbers?
#' @param txt_size size of the text in the observation labels. Ignored if
#' `label_misclassified` is set to FALSE.
#' @param txt_color color of the text in the observation labels. Ignored if
#' `label_misclassified` is set to FALSE.
#' @param point_alpha alpha of the data points.
#' @param point_hjitter height of the data point jittering.
#' @param point_wjitter width of the data point jittering.
#' @param box_alpha alpha of the box plots.
#' @param cust_theme a custom ggplot theme.
#'
#' @return a list of ggplot graphics: one plot for squared distances
#' (`square_dist`) and one plot for class assignment probabilities (`winner_p`).

  plot_class_p <- function(predx_object,
                           plot_subtitle = NULL,
                           plot_tag = NULL,
                           show_class_n = TRUE,
                           flip = FALSE,
                           point_size = 2,
                           hide_obs_labels = TRUE,
                           label_misclassified = TRUE,
                           txt_size = 2.75,
                           txt_color = 'firebrick',
                           cust_theme = ggplot2::theme_classic()) {

    ## entry control ------

    stopifnot(is_predx(predx_object))
    stopifnot(is.logical(show_class_n))
    stopifnot(is.numeric(point_size))
    stopifnot(is.logical(hide_obs_labels))
    stopifnot(is.logical(label_misclassified))

    if(!inherits(cust_theme, 'theme')) {

      stop("'cust_theme' has to be a valid ggplot 'theme' object.",
           call. = FALSE)

    }

    if(predx_object$type == 'regression') {

      warning('Class-specific plots are not available for regression.',
              call. = FALSE)

      return(NULL)

    }

    ## NSE variables used in mutate() and in the aesthetic mappings below
    correct <- NULL
    .outcome <- NULL
    .fitted <- NULL

    .observation <- NULL
    .resample <- NULL
    square_dist <- NULL
    winner_p <- NULL
    plot_lab <- NULL

    ## plotting data ------

    sq_tbl <- squared(predx_object)
    class_p_tbl <- classp(predx_object)

    sq_tbl <- select(sq_tbl,
                     any_of(c('.observation', '.resample',
                              '.outcome', '.fitted',
                              'square_dist')))

    class_p_tbl <- select(class_p_tbl,
                          any_of(c('.observation', '.resample',
                                   'winner_p')))

    ## cross-validation predictions are identified by observation AND resample
    if(predx_object$prediction == 'cv') {

      by_vars <- c('.observation', '.resample')

    } else {

      by_vars <- '.observation'

    }

    plot_tbl <- left_join(sq_tbl, class_p_tbl, by = by_vars)

    ## only misclassified observations get a text label (NA otherwise)
    plot_tbl <- mutate(plot_tbl,
                       correct = ifelse(.outcome == .fitted,
                                        'correct', 'misclassified'),
                       correct = factor(correct,
                                        c('correct', 'misclassified')),
                       plot_lab = ifelse(correct == 'misclassified',
                                         .observation, NA))

    ## plot subtitle and labeller ------

    if(show_class_n) {

      n_numbers <-
        count(components(predx_object, 'data'), .outcome, .drop = FALSE)

      facet_labs <- map2_chr(n_numbers[[1]], n_numbers[[2]],
                             paste, sep = '\nn = ')

      facet_labs <- set_names(facet_labs, n_numbers[[1]])

      facet_labs <- ggplot2::as_labeller(facet_labs)

    } else {

      facet_labs <- 'label_value'

    }

    if(is.null(plot_subtitle)) plot_subtitle <- paste('n =', nobs(predx_object))

    ## base plots --------

    ## observations are ordered by the statistic value within each facet
    if(!flip) {

      sc_plots <-
        list(square_dist = ggplot(plot_tbl,
                                  aes(x = reorder(.observation, square_dist),
                                      y = square_dist,
                                      fill = correct)),
             winner_p = ggplot(plot_tbl,
                               aes(x = reorder(.observation, winner_p),
                                   y = winner_p,
                                   fill = correct)))

      sc_plots <- map(sc_plots,
                      ~.x +
                        ggplot2::facet_grid(. ~ .outcome,
                                            labeller = facet_labs,
                                            scales = 'free',
                                            space = 'free'))


    } else {

      sc_plots <-
        list(square_dist = ggplot(plot_tbl,
                                  aes(y = reorder(.observation, square_dist),
                                      x = square_dist,
                                      fill = correct)),
             winner_p = ggplot(plot_tbl,
                               aes(y = reorder(.observation, winner_p),
                                   x = winner_p,
                                   fill = correct)))

      sc_plots <- map(sc_plots,
                      ~.x +
                        ggplot2::facet_grid(.outcome ~ .,
                                            labeller = facet_labs,
                                            scales = 'free',
                                            space = 'free'))

    }

    ## points, titles and labels --------

    plot_lst <-
      list(x = sc_plots,
           y = c('Square distance to outcome',
                 'Class assignment probability'),
           z = c('square distance', 'p'))

    if(!flip) {

      sc_plots <-
        pmap(plot_lst,
             function(x, y, z) x +
               ggplot2::labs(title = y,
                             subtitle = plot_subtitle,
                             tag = plot_tag,
                             y = z,
                             x = 'observation'))

    } else {

      sc_plots <-
        pmap(plot_lst,
             function(x, y, z) x +
               ggplot2::labs(title = y,
                             subtitle = plot_subtitle,
                             tag = plot_tag,
                             x = z,
                             y = 'Observation'))

    }

    sc_plots <-
      map(sc_plots,
          ~.x +
            ggplot2::geom_point(shape = 21,
                                size = point_size) +
            ggplot2::scale_fill_manual(values = c(correct = 'steelblue',
                                                  misclassified = 'firebrick'),
                                       name = '') +
            cust_theme)

    if(hide_obs_labels) {

      ## the observation axis is X in the default layout and Y when flipped
      hide_theme <- if(flip) {

        ggplot2::theme(axis.text.y = ggplot2::element_blank(),
                       axis.ticks.y = ggplot2::element_blank(),
                       panel.grid.major.y = ggplot2::element_blank())

      } else {

        ggplot2::theme(axis.text.x = ggplot2::element_blank(),
                       axis.ticks.x = ggplot2::element_blank(),
                       panel.grid.major.x = ggplot2::element_blank())

      }

      sc_plots <- map(sc_plots, ~.x + hide_theme)

    }

    if(label_misclassified) {

      ## na.rm = TRUE: correctly classified observations have NA labels by
      ## design and should not trigger 'removed rows' warnings
      sc_plots <-
        map(sc_plots,
            ~.x +
              ggrepel::geom_text_repel(aes(label = plot_lab),
                                       size = txt_size,
                                       color = txt_color,
                                       na.rm = TRUE))

    }

    set_names(sc_plots,
              c('square_dist', 'winner_p'))

  }

#' @rdname plot_class_p

  plot_class_stats <- function(predx_object,
                               plot_subtitle = NULL,
                               plot_tag = NULL,
                               point_size = 2,
                               point_hjitter = 0,
                               point_wjitter = 0.1,
                               point_alpha = 0.75,
                               box_alpha = 0.5,
                               cust_theme = ggplot2::theme_classic()) {

    ## Box plots of per-observation squared distances to the outcome and of
    ## class assignment (winner) probabilities, one box per outcome class.
    ## Returns a named list with the 'square_dist' and 'winner_p' plots, or
    ## NULL with a warning for regression predictions.

    ## input checks -------

    stopifnot(is_predx(predx_object))
    stopifnot(is.numeric(point_size))

    if(!inherits(cust_theme, 'theme')) {

      stop("'cust_theme' has to be a valid ggplot 'theme' object.",
           call. = FALSE)

    }

    if(predx_object$type == 'regression') {

      warning('Class-specific plots are not available for regression.',
              call. = FALSE)

      return(NULL)

    }

    ## NSE variable used in the aesthetic mappings and count() below
    .outcome <- NULL

    ## per-observation statistics -------

    sq_stats <- squared(predx_object)
    p_stats <- classp(predx_object)

    sq_stats <- select(sq_stats,
                       any_of(c('.observation', '.resample',
                                '.outcome', '.fitted',
                                'square_dist')))

    p_stats <- select(p_stats,
                      any_of(c('.observation', '.resample',
                               'winner_p')))

    ## cross-validation predictions are identified by observation AND resample
    join_keys <- if(predx_object$prediction == 'cv') {

      c('.observation', '.resample')

    } else {

      '.observation'

    }

    stat_tbl <- left_join(sq_stats, p_stats, by = join_keys)

    ## axis labels with class n numbers, subtitle -------

    class_counts <-
      count(components(predx_object, 'data'), .outcome, .drop = FALSE)

    axis_labs <- set_names(paste(class_counts[[1]], class_counts[[2]],
                                 sep = '\nn = '),
                           class_counts[[1]])

    if(is.null(plot_subtitle)) {

      plot_subtitle <- paste('n =', nobs(predx_object))

    }

    ## plots ---------

    ## builds a single box plot with jittered points for one statistic
    make_box_plot <- function(stat_var, plot_title, axis_title) {

      ggplot(stat_tbl,
             aes(x = .outcome,
                 y = .data[[stat_var]],
                 fill = .outcome)) +
        ggplot2::geom_boxplot(alpha = box_alpha,
                              outlier.color = NA) +
        ggplot2::geom_point(shape = 21,
                            size = point_size,
                            alpha = point_alpha,
                            color = 'black',
                            position = ggplot2::position_jitter(width = point_wjitter,
                                                                height = point_hjitter)) +
        ggplot2::scale_x_discrete(labels = axis_labs) +
        cust_theme +
        ggplot2::labs(title = plot_title,
                      y = axis_title,
                      subtitle = plot_subtitle,
                      tag = plot_tag)

    }

    set_names(list(make_box_plot('square_dist',
                                 'Square distance to outcome',
                                 'square distance'),
                   make_box_plot('winner_p',
                                 'Class assignment probability',
                                 'p')),
              c('square_dist', 'winner_p'))

  }

# END -----
PiotrTymoszuk/caretExtra documentation built on Oct. 15, 2023, 10:03 p.m.