# This script provides plotting functions to visualize correlation of the fitted and outcome values.
# Plotting for regression -----
#' Scatter plot of fitted versus outcome values for regression predictions.
#'
#' @description
#' Draws a jittered point plot of two variables stored in the data component
#' of a `predx` object (by default `.outcome` in the X axis and `.fitted` in
#' the Y axis). A fitted trend and a calibration line can be added as options.
#'
#' @details
#' The fitted trend is generated by \code{\link[ggplot2]{geom_smooth}}.
#' The calibration line is drawn by \code{\link[ggplot2]{geom_abline}} with
#' slope 1 and intercept 0.
#' For binary and multi-class predictions, a warning is raised and `NULL`
#' is returned.
#'
#' @param predx_object a `predx` object.
#' @param x_var name of the variable presented in the X axis.
#' @param y_var name of the variable presented in the Y axis.
#' @param point_size size of the plot points.
#' @param point_shape shape of the points.
#' @param point_color color of the data points, used as the point fill.
#' @param point_wjitter jittering width (horizontal) of the points.
#' @param point_hjitter jittering height (vertical) of the points.
#' @param point_alpha alpha of the plot points.
#' @param show_trend logical, should a trend line be displayed?
#' @param trend_method a method for fitting the trend line, see
#' \code{\link[ggplot2]{geom_smooth}} for details, defaults to 'lm'.
#' @param show_calibration logical, should a line with slope 1 and
#' intercept 0 be displayed in the plot?
#' @param line_size size of the calibration and trend lines.
#' @param plot_title plot title.
#' @param plot_subtitle plot subtitle.
#' @param plot_tag plot tag. If `NULL`, the number of observations returned by
#' `nobs()` for the object is displayed.
#' @param x_lab X axis title, defaults to `x_var`.
#' @param y_lab Y axis title, defaults to `y_var`.
#' @param cust_theme customized plot theme provided by the user.
#' @param ... extra arguments passed to \code{\link[ggplot2]{geom_smooth}}.
#'
#' @return a ggplot object or, for classification predictions, `NULL`
#' with a warning.
plot_regression <- function(predx_object,
                            x_var = '.outcome',
                            y_var = '.fitted',
                            point_size = 2,
                            point_shape = 21,
                            point_color = 'steelblue',
                            point_wjitter = 0.01,
                            point_hjitter = 0.01,
                            point_alpha = 0.75,
                            show_trend = TRUE,
                            trend_method = 'lm',
                            show_calibration = TRUE,
                            line_size = 0.5,
                            plot_title = NULL,
                            plot_subtitle = NULL,
                            plot_tag = NULL,
                            x_lab = x_var,
                            y_lab = y_var,
                            cust_theme = ggplot2::theme_classic(), ...) {

  ## input validation ------

  stopifnot(is_predx(predx_object),
            is.logical(show_trend),
            is.logical(show_calibration),
            inherits(cust_theme, 'theme'),
            is.numeric(line_size))

  # the plot makes sense only for numeric outcomes
  classification_types <- c('multi_class', 'binary')

  if (predx_object$type %in% classification_types) {

    warning(paste('Regression plots for the multi-class or',
                  'binary predictions are not available.'),
            call. = FALSE)

    return(NULL)

  }

  # default tag: number of observations
  if (is.null(plot_tag)) plot_tag <- paste('n =', nobs(predx_object))

  ## scatter plot ------

  base_plot <- ggplot(components(predx_object, 'data'),
                      aes(x = .data[[x_var]],
                          y = .data[[y_var]]))

  jitter_pos <- ggplot2::position_jitter(width = point_wjitter,
                                         height = point_hjitter)

  point_layer <- ggplot2::geom_point(shape = point_shape,
                                     size = point_size,
                                     alpha = point_alpha,
                                     fill = point_color,
                                     position = jitter_pos)

  plot_labels <- ggplot2::labs(title = plot_title,
                               subtitle = plot_subtitle,
                               tag = plot_tag,
                               x = x_lab,
                               y = y_lab)

  reg_plot <- base_plot + point_layer + cust_theme + plot_labels

  ## optional lines: trend first, calibration on top ------

  if (show_trend) {

    reg_plot <- reg_plot +
      ggplot2::geom_smooth(method = trend_method,
                           size = line_size, ...)

  }

  if (show_calibration) {

    reg_plot <- reg_plot +
      ggplot2::geom_abline(intercept = 0,
                           slope = 1,
                           color = 'black',
                           size = line_size)

  }

  reg_plot

}
# Binary classification: ROC plot ------
#' Plot a receiver-operator characteristic curve.
#'
#' @description
#' Generates a ROC plot for the fitted and outcome values of a binary
#' classification prediction.
#' Optionally, a custom text annotation may be added inside the plot.
#'
#' @details
#' The plot is generated by \code{\link[plotROC]{geom_roc}} and styled with
#' \code{\link[plotROC]{style_roc}}. A dashed line with slope 1 and
#' intercept 0 is added to the plot.
#' If the `.outcome` variable is a factor, it is converted to numeric codes
#' starting from 0 prior to plotting. The marker variable is the column of the
#' prediction data named after the second element of the object's `classes`.
#' For regression and multi-class predictions, a warning is raised and `NULL`
#' is returned.
#'
#' @param predx_object a `predx` object.
#' @param line_color color of the ROC line.
#' @param line_size size of the ROC line.
#' @param cutoffs_at numeric, between 0 and 1, indicates the
#' cut points to be presented in the ROC curve.
#' @param point_size size of the cutoff point.
#' @param plot_title plot title.
#' @param plot_subtitle plot subtitle.
#' @param plot_tag plot tag. If `NULL`, it contains the total number of
#' observations and the number of events, i.e. the count in the second row
#' of the `.outcome` frequency table.
#' @param annotation_txt annotation text. No annotation is added if `NULL`.
#' @param annotation_color annotation text color, defaults to `line_color`.
#' @param annotation_size size of the annotation text and of the cutoff label.
#' @param annotation_x annotation x position.
#' @param annotation_y annotation y position.
#' @param annotation_hjust horizontal justification of the annotation text.
#' @param annotation_vjust vertical justification of the annotation text.
#' @param cust_theme customized plot theme provided by the user or `NULL`.
#' If provided, it is added on the top of the `plotROC` style.
#' @param ... extra arguments passed to \code{\link[plotROC]{geom_roc}}.
#'
#' @return a ggplot object or, for regression and multi-class predictions,
#' `NULL` with a warning.
plot_roc <- function(predx_object,
                     line_color = 'steelblue',
                     line_size = 0.5,
                     cutoffs_at = 0.5,
                     point_size = 0.3,
                     plot_title = NULL,
                     plot_subtitle = NULL,
                     plot_tag = NULL,
                     annotation_txt = NULL,
                     annotation_color = line_color,
                     annotation_size = 2.75,
                     annotation_x = 0.6,
                     annotation_y = 0.3,
                     annotation_hjust = 0,
                     annotation_vjust = 1,
                     cust_theme = NULL, ...) {

  ## entry control ------

  stopifnot(is_predx(predx_object))

  # local binding for the variable used in data-masked calls below
  .outcome <- NULL

  if (!is.null(cust_theme)) stopifnot(inherits(cust_theme, 'theme'))

  if (predx_object$type %in% c('regression', 'multi_class')) {

    warning(paste('ROC plots for the multi-class or regression',
                  'predictions are not available.'),
            call. = FALSE)

    return(NULL)

  }

  # default tag: total n and the count of the second outcome category
  if (is.null(plot_tag)) {

    plot_tag <-
      paste0('total: n = ', nobs(predx_object),
             ', events: n = ', count(predx_object$data, .outcome)[2, 2])

  }

  # geom_roc() gets the outcome as numeric codes: factor levels are
  # re-coded to integers starting from 0
  if (is.factor(predx_object$data[['.outcome']])) {

    data <- mutate(predx_object$data,
                   .outcome = as.numeric(.outcome) - 1)

  } else {

    data <- predx_object$data

  }

  ## plotting -------

  roc_plot <-
    ggplot(data,
           aes(d = .data[['.outcome']],
               m = .data[[predx_object$classes[2]]])) +
    plotROC::geom_roc(labelsize = annotation_size,
                      cutoffs.at = cutoffs_at,
                      pointsize = point_size,
                      color = line_color,
                      size = line_size, ...) +
    plotROC::style_roc() +
    ggplot2::geom_abline(slope = 1,
                         intercept = 0,
                         linetype = 'dashed') +
    ggplot2::labs(title = plot_title,
                  subtitle = plot_subtitle,
                  tag = plot_tag)

  # the user's theme is added after style_roc(), so that its settings win
  if (!is.null(cust_theme)) {

    roc_plot <- roc_plot + cust_theme

  }

  if (!is.null(annotation_txt)) {

    # fix: 'annotation_color' was documented but never passed to the
    # annotation layer, so the text was always drawn in the default color
    roc_plot <- roc_plot +
      ggplot2::annotate('text',
                        label = annotation_txt,
                        x = annotation_x,
                        y = annotation_y,
                        hjust = annotation_hjust,
                        vjust = annotation_vjust,
                        size = annotation_size,
                        color = annotation_color)

  }

  roc_plot

}
# Classification models: plots of confusion matrix -----
#' Plot confusion matrix for a classification predx object.
#'
#' @description Generates a heat map representation of the confusion matrix.
#' Outcome is presented in the X axis, fitted is presented in the Y axis.
#'
#' @details
#' The confusion matrix is computed by `confusion()`, converted to a data
#' frame and its `Freq` column is coded by the tile fill color.
#' For regression predictions, a warning is raised and `NULL` is returned.
#'
#' @param predx_object a `predx` object.
#' @param scale indicates, how the table is to be scaled.
#' 'none' returns the counts (default),
#' 'fraction' returns the fraction of all observations,
#' 'percent' returns the percent of all observations.
#' @param show_labels logical, indicates if counts/fractions/percents are shown
#' in the plot.
#' @param label_size size of the label text.
#' @param label_color color of the label text.
#' @param signif_digits significant digits for rounding of the label values.
#' @param plot_title plot title.
#' @param plot_subtitle plot subtitle.
#' @param plot_tag plot tag, contains a reference to scale and the
#' observation number if not specified.
#' @param x_lab X axis label.
#' @param y_lab Y axis label.
#' @param cust_theme customized plot theme provided by the user.
#'
#' @return a ggplot object or, for regression predictions, `NULL`
#' with a warning.
plot_confusion <- function(predx_object,
                           scale = c('none', 'fraction', 'percent'),
                           show_labels = TRUE,
                           label_size = 2.75,
                           label_color = 'black',
                           signif_digits = 2,
                           plot_title = NULL,
                           plot_subtitle = NULL,
                           plot_tag = NULL,
                           x_lab = '.outcome',
                           y_lab = '.fitted',
                           cust_theme = ggplot2::theme_classic()) {

  ## entry control --------

  stopifnot(is_predx(predx_object))
  stopifnot(is.logical(show_labels))

  if (!is.null(cust_theme)) stopifnot(inherits(cust_theme, 'theme'))

  if (predx_object$type == 'regression') {

    warning(paste('Confusion matrix in not available',
                  'for regression predictions.'),
            call. = FALSE)

    return(NULL)

  }

  # fix: resolve the choice vector to a single value. Without this step,
  # the default 'scale' had three elements and the automatic plot tag
  # became a character vector of length three.
  scale <- match.arg(scale)

  if (is.null(plot_tag)) {

    scale_lab <- c('none' = 'Counts',
                   'fraction' = 'Fraction of total',
                   'percent' = '% of total')

    plot_tag <- paste0(scale_lab[[scale]],
                       ', total: n = ',
                       nobs(predx_object))

  }

  ## plotting ------

  # long-format confusion table with `.outcome`, `.fitted` and `Freq` columns
  conf <- as.data.frame(confusion(predx_object, scale = scale))

  conf_plot <- ggplot(conf,
                      aes(x = .data[['.outcome']],
                          y = .data[['.fitted']],
                          fill = .data[['Freq']])) +
    ggplot2::geom_tile(color = 'black') +
    ggplot2::scale_fill_gradient2(low = 'steelblue',
                                  high = 'firebrick',
                                  mid = 'white') +
    cust_theme +
    ggplot2::labs(title = plot_title,
                  subtitle = plot_subtitle,
                  tag = plot_tag,
                  x = x_lab,
                  y = y_lab,
                  fill = NULL)

  # optional text labels with the rounded tile values
  if (show_labels) {

    conf_plot <- conf_plot +
      ggplot2::geom_text(aes(label = signif(.data[['Freq']], signif_digits)),
                         size = label_size,
                         hjust = 0.5,
                         vjust = 0.5,
                         color = label_color)

  }

  conf_plot

}
# Plots of performance stats --------
#' Scatter plot of performance metrics of a `caretx` model.
#'
#' @description
#' This internally utilized function computes predictions for a `caretx`
#' model, summarizes them and plots selected performance statistics
#' appropriate for the prediction type.
#' One point per data set (labeled as training, CV and test) is drawn,
#' the data set is coded by the point fill color and, optionally, indicated
#' with a text label in the plot.
#'
#' @details
#' The plotted performance stats are:
#'
#' * for regression: root mean square error (`RMSE`, X axis) and
#' pseudo-R-squared (`rsq`, Y axis). Spearman's coefficient for
#' correlation of the outcome and prediction (`spearman`) is represented
#' by the point size.
#'
#' * for binary classification: Cohen's kappa (X axis) and
#' 1 - Brier score (Y axis). The overall accuracy (`correct_rate`) is
#' represented by the point size.
#'
#' * for multi-class classification: Cohen's kappa (X axis) and
#' 2 - Brier score (Y axis). The overall accuracy (`correct_rate`) is
#' represented by the point size.
#'
#' Numbers of observations in the data sets are displayed in the plot subtitle,
#' if no user-provided subtitle is specified.
#'
#' @param caretx_object a `caretx` object.
#' @param newdata optional, passed as the `newdata` argument to `predict()`
#' called for the `caretx` object.
#' @param plot_subtitle plot subtitle, see: `Details`.
#' @param plot_tag plot tag.
#' @param show_txt logical, should the data set names (training, CV and test)
#' be displayed in the plot?
#' @param txt_size size of the text labels of the data sets.
#' @param cust_theme a custom ggplot theme
#' @param ... extra arguments passed to \code{\link[ggrepel]{geom_text_repel}}
#'
#' @return a ggplot graphic.
plot_performance <- function(caretx_object,
                             newdata = NULL,
                             plot_subtitle = NULL,
                             plot_tag = NULL,
                             show_txt = FALSE,
                             txt_size = 2.75,
                             cust_theme = ggplot2::theme_classic(), ...) {

  ## input validation -------

  stopifnot(is_caretx(caretx_object),
            is.numeric(txt_size))

  if (!inherits(cust_theme, 'theme')) {

    stop("'cust_theme' needs to be a valid ggplot 'theme' object.",
         call. = FALSE)

  }

  ## labels and colors of the data sets -------

  data_labs <- c(train = 'training',
                 cv = 'CV',
                 test = 'test')

  data_colors <- c(train = 'steelblue',
                   cv = 'gray40',
                   test = 'firebrick4')

  # local bindings for the variables referenced in aes() calls
  RMSE <- rsq <- spearman <- dataset <- NULL
  brier_score <- correct_rate <- NULL

  ## predictions: empty elements are dropped ------

  preds <- compact(predict(caretx_object, newdata = newdata))

  ## default subtitle: n numbers per data set ------

  if (is.null(plot_subtitle)) {

    n_numbers <- map_dbl(preds, nobs)

    n_labels <- map2_chr(data_labs[names(n_numbers)], n_numbers,
                         paste, sep = ': n = ')

    plot_subtitle <- paste(n_labels, collapse = ', ')

  }

  ## plotting data: data sets in rows, statistics in columns -----

  stat_lst <- map(preds,
                  function(x) summary(x)[c('statistic', 'estimate')])

  stat_tbl <- reduce(stat_lst, left_join, by = 'statistic')

  stat_tbl <- set_names(stat_tbl, c('statistic', names(preds)))

  stat_mtx <- t(column_to_rownames(stat_tbl, 'statistic'))

  plot_tbl <- rownames_to_column(as.data.frame(stat_mtx), 'dataset')

  ## base plots: mappings and labels specific for the prediction type -------

  pred_type <- preds[[1]]$type

  if (pred_type == 'regression') {

    type_aes <- aes(x = RMSE,
                    y = rsq,
                    size = spearman,
                    fill = dataset)

    type_labs <-
      ggplot2::labs(title = 'Regression model performance',
                    subtitle = plot_subtitle,
                    tag = plot_tag,
                    x = 'RMSE',
                    y = expression('pseudo-R'^2),
                    size = expression("Spearman's " * rho))

    sc_plot <- ggplot(plot_tbl, type_aes) + type_labs

  } else if (pred_type == 'binary') {

    type_aes <- aes(x = kappa,
                    y = 1 - brier_score,
                    size = correct_rate,
                    fill = dataset)

    type_labs <-
      ggplot2::labs(title = 'Binary classification model performance',
                    subtitle = plot_subtitle,
                    tag = plot_tag,
                    x = expression("Cohen's " * kappa),
                    y = '1 - Brier score',
                    size = 'Accuracy')

    sc_plot <- ggplot(plot_tbl, type_aes) + type_labs

  } else if (pred_type == 'multi_class') {

    type_aes <- aes(x = kappa,
                    y = 2 - brier_score,
                    size = correct_rate,
                    fill = dataset)

    type_labs <-
      ggplot2::labs(title = 'Multi-category classification model performance',
                    subtitle = plot_subtitle,
                    tag = plot_tag,
                    x = expression("Cohen's " * kappa),
                    y = '2 - Brier score',
                    size = 'Accuracy')

    sc_plot <- ggplot(plot_tbl, type_aes) + type_labs

  }

  ## elements shared by all prediction types -------

  sc_plot <- sc_plot +
    ggplot2::geom_point(shape = 21,
                        color = 'black') +
    ggplot2::scale_fill_manual(values = data_colors,
                               labels = data_labs,
                               name = 'Data set') +
    cust_theme

  ## optional text labels with the data set names ------

  if (show_txt) {

    sc_plot <- sc_plot +
      ggrepel::geom_text_repel(aes(label = unname(data_labs[dataset]),
                                   color = dataset),
                               size = txt_size,
                               show.legend = FALSE, ...) +
      ggplot2::scale_color_manual(values = data_colors,
                                  labels = data_labs,
                                  name = 'Data set')

  }

  sc_plot

}
# Brier scores and class assignment p in the outcome classes -----
#' Squared distance to outcome and class assignment probability
#' in the outcome classes.
#'
#' @description
#' `plot_class_p`: This internally used function plots squared distances
#' to the outcome (as defined by Brier et al.) and class-assignment
#' probabilities for the outcome classes as scatter plots.
#' The correct/false class assignment is color-coded.
#' The observations are sorted by the statistic value.
#' The number of observations is indicated in the plot subtitle
#' (if not provided by the user) and class n numbers can be displayed in
#' the plot facets (`show_class_n` set to TRUE).
#'
#' `plot_class_stats`: The function plots squared distances
#' to the outcome (as defined by Brier et al.) and class-assignment
#' probabilities for the outcome classes as box plots. Class n numbers are
#' indicated in the X axis.
#'
#' @details
#' For regression, NULL and a warning is returned.
#' The plotting data are obtained by joining the outputs of `squared()` and
#' `classp()` by the observation identifier and, for cross-validation
#' predictions, by the resample identifier.
#'
#' @references
#' Brier GW. VERIFICATION OF FORECASTS EXPRESSED IN TERMS OF PROBABILITY.
#' Mon Weather Rev (1950) 78:1–3.
#' doi:10.1175/1520-0493(1950)078<0001:vofeit>2.0.co;2
#' @references
#' Goldstein-Greenwood J. A Brief on Brier Scores | UVA Library. (2021)
#' Available at: https://library.virginia.edu/data/articles/a-brief-on-brier-scores
#'
#' @param predx_object a `predx` class object.
#' @param plot_subtitle plot subtitle.
#' @param plot_tag plot tag.
#' @param show_class_n logical, should n numbers for classes be presented
#' in the plot facets?
#' @param flip logical, exchange the X and Y axes in the plots?
#' @param point_size size of the data points.
#' @param hide_obs_labels logical, hide labels of single observations?
#' @param label_misclassified logical, should misclassified observation be
#' labeled with their numbers?
#' @param txt_size size of the text in the observation labels. Ignored if
#' `label_misclassified` is set to FALSE.
#' @param txt_color color of the text in the observation labels. Ignored if
#' `label_misclassified` is set to FALSE.
#' @param point_alpha alpha of the data points.
#' @param point_hjitter height of the data point jittering.
#' @param point_wjitter width of the data point jittering.
#' @param box_alpha alpha of the box plots.
#' @param cust_theme a custom ggplot theme.
#'
#' @return a list of ggplot graphics: one plot for squared distances
#' (`square_dist`) and one plot for class assignment probabilities (`winner_p`).
plot_class_p <- function(predx_object,
                         plot_subtitle = NULL,
                         plot_tag = NULL,
                         show_class_n = TRUE,
                         flip = FALSE,
                         point_size = 2,
                         hide_obs_labels = TRUE,
                         label_misclassified = TRUE,
                         txt_size = 2.75,
                         txt_color = 'firebrick',
                         cust_theme = ggplot2::theme_classic()) {

  ## input validation ------

  stopifnot(is_predx(predx_object),
            is.logical(show_class_n),
            is.numeric(point_size),
            is.logical(hide_obs_labels),
            is.logical(label_misclassified))

  if (!inherits(cust_theme, 'theme')) {

    stop("'cust_theme' has to be a valid ggplot 'theme' object.",
         call. = FALSE)

  }

  if (predx_object$type == 'regression') {

    warning('Class-specific plots are not available for regression.',
            call. = FALSE)

    return(NULL)

  }

  # local bindings for the variables referenced in data-masked calls
  correct <- .outcome <- .fitted <- NULL
  .observation <- .resample <- NULL
  square_dist <- winner_p <- NULL

  ## plotting data: squared distances joined with winner probabilities ------

  sq_tbl <- squared(predx_object)
  class_p_tbl <- classp(predx_object)

  sq_tbl <- select(sq_tbl,
                   any_of(c('.observation', '.resample',
                            '.outcome', '.fitted',
                            'square_dist')))

  class_p_tbl <- select(class_p_tbl,
                        any_of(c('.observation', '.resample',
                                 'winner_p')))

  # for CV predictions, the resample identifier is a part of the join key
  by_vars <-
    if (predx_object$prediction == 'cv') {

      c('.observation', '.resample')

    } else {

      '.observation'

    }

  plot_tbl <- left_join(sq_tbl, class_p_tbl, by = by_vars)

  # classification status and text labels for misclassified observations only
  plot_tbl <- mutate(plot_tbl,
                     correct = ifelse(.outcome == .fitted,
                                      'correct', 'misclassified'),
                     correct = factor(correct,
                                      c('correct', 'misclassified')),
                     plot_lab = ifelse(correct == 'misclassified',
                                       .observation, NA))

  ## facet labeller and plot subtitle ------

  if (show_class_n) {

    n_numbers <-
      count(components(predx_object, 'data'), .outcome, .drop = FALSE)

    facet_labs <- map2_chr(n_numbers[[1]], n_numbers[[2]],
                           paste, sep = '\nn = ')

    facet_labs <- set_names(facet_labs, n_numbers[[1]])

    facet_labs <- ggplot2::as_labeller(facet_labs)

  } else {

    facet_labs <- 'label_value'

  }

  if (is.null(plot_subtitle)) plot_subtitle <- paste('n =', nobs(predx_object))

  ## base plots: observations ordered by the plotted statistic --------

  if (flip) {

    base_plots <-
      list(square_dist = ggplot(plot_tbl,
                                aes(y = reorder(.observation, square_dist),
                                    x = square_dist,
                                    fill = correct)),
           winner_p = ggplot(plot_tbl,
                             aes(y = reorder(.observation, winner_p),
                                 x = winner_p,
                                 fill = correct)))

  } else {

    base_plots <-
      list(square_dist = ggplot(plot_tbl,
                                aes(x = reorder(.observation, square_dist),
                                    y = square_dist,
                                    fill = correct)),
           winner_p = ggplot(plot_tbl,
                             aes(x = reorder(.observation, winner_p),
                                 y = winner_p,
                                 fill = correct)))

  }

  ## assembly of a single plot --------

  # The components are added in a fixed order: facets, titles, points,
  # fill scale, the user's theme, the optional theme hiding the observation
  # axis labels and, finally, the optional labels of misclassified
  # observations.

  assemble_plot <- function(x, y, z) {

    if (flip) {

      out <- x +
        ggplot2::facet_grid(.outcome ~ .,
                            labeller = facet_labs,
                            scales = 'free',
                            space = 'free') +
        ggplot2::labs(title = y,
                      subtitle = plot_subtitle,
                      tag = plot_tag,
                      x = z,
                      y = 'Observation')

    } else {

      out <- x +
        ggplot2::facet_grid(. ~ .outcome,
                            labeller = facet_labs,
                            scales = 'free',
                            space = 'free') +
        ggplot2::labs(title = y,
                      subtitle = plot_subtitle,
                      tag = plot_tag,
                      y = z,
                      x = 'observation')

    }

    out <- out +
      ggplot2::geom_point(shape = 21,
                          size = point_size) +
      ggplot2::scale_fill_manual(values = c(correct = 'steelblue',
                                            misclassified = 'firebrick'),
                                 name = '') +
      cust_theme

    # the observation axis is Y in flipped plots and X otherwise
    if (hide_obs_labels) {

      if (flip) {

        out <- out +
          ggplot2::theme(axis.text.y = ggplot2::element_blank(),
                         axis.ticks.y = ggplot2::element_blank(),
                         panel.grid.major.y = ggplot2::element_blank())

      } else {

        out <- out +
          ggplot2::theme(axis.text.x = ggplot2::element_blank(),
                         axis.ticks.x = ggplot2::element_blank(),
                         panel.grid.major.x = ggplot2::element_blank())

      }

    }

    if (label_misclassified) {

      out <- out +
        ggrepel::geom_text_repel(aes(label = plot_lab),
                                 size = txt_size,
                                 color = txt_color)

    }

    out

  }

  ## titles and axis labels for the two statistics --------

  sc_plots <-
    pmap(list(x = base_plots,
              y = c('Square distance to outcome',
                    'Class assignment probability'),
              z = c('square distance', 'p')),
         assemble_plot)

  set_names(sc_plots,
            c('square_dist', 'winner_p'))

}
#' @rdname plot_class_p
plot_class_stats <- function(predx_object,
                             plot_subtitle = NULL,
                             plot_tag = NULL,
                             point_size = 2,
                             point_hjitter = 0,
                             point_wjitter = 0.1,
                             point_alpha = 0.75,
                             box_alpha = 0.5,
                             cust_theme = ggplot2::theme_classic()) {

  # Box plots of squared distances to the outcome and of class assignment
  # probabilities, split by the outcome class. Returns a list of two ggplot
  # objects named `square_dist` and `winner_p`, or NULL with a warning
  # for regression predictions.

  ## input validation -------

  stopifnot(is_predx(predx_object),
            is.numeric(point_size))

  if (!inherits(cust_theme, 'theme')) {

    stop("'cust_theme' has to be a valid ggplot 'theme' object.",
         call. = FALSE)

  }

  if (predx_object$type == 'regression') {

    warning('Class-specific plots are not available for regression.',
            call. = FALSE)

    return(NULL)

  }

  # local binding for the variable referenced in data-masked calls
  .outcome <- NULL

  ## plotting data: squared distances joined with winner probabilities -------

  sq_tbl <- squared(predx_object)
  class_p_tbl <- classp(predx_object)

  sq_tbl <- select(sq_tbl,
                   any_of(c('.observation', '.resample',
                            '.outcome', '.fitted',
                            'square_dist')))

  class_p_tbl <- select(class_p_tbl,
                        any_of(c('.observation', '.resample',
                                 'winner_p')))

  # for CV predictions, the resample identifier is a part of the join key
  by_vars <-
    if (predx_object$prediction == 'cv') {

      c('.observation', '.resample')

    } else {

      '.observation'

    }

  plot_tbl <- left_join(sq_tbl, class_p_tbl, by = by_vars)

  ## X axis labels with class n numbers and the plot subtitle -------

  n_numbers <-
    count(components(predx_object, 'data'), .outcome, .drop = FALSE)

  x_labs <- map2_chr(n_numbers[[1]], n_numbers[[2]],
                     paste, sep = '\nn = ')

  x_labs <- set_names(x_labs, n_numbers[[1]])

  if (is.null(plot_subtitle)) plot_subtitle <- paste('n =', nobs(predx_object))

  ## box plots, one per statistic ---------

  # the jitter position is created separately for each plot
  draw_box_plot <- function(var, title, y_lab) {

    jitter_pos <- ggplot2::position_jitter(width = point_wjitter,
                                           height = point_hjitter)

    ggplot(plot_tbl,
           aes(x = .outcome,
               y = .data[[var]],
               fill = .outcome)) +
      ggplot2::geom_boxplot(alpha = box_alpha,
                            outlier.color = NA) +
      ggplot2::geom_point(shape = 21,
                          size = point_size,
                          alpha = point_alpha,
                          color = 'black',
                          position = jitter_pos) +
      ggplot2::scale_x_discrete(labels = x_labs) +
      cust_theme +
      ggplot2::labs(title = title,
                    y = y_lab,
                    subtitle = plot_subtitle,
                    tag = plot_tag)

  }

  plot_specs <- list(var = c('square_dist', 'winner_p'),
                     title = c('Square distance to outcome',
                               'Class assignment probability'),
                     y_lab = c('square distance', 'p'))

  box_plots <- pmap(plot_specs, draw_box_plot)

  set_names(box_plots,
            c('square_dist', 'winner_p'))

}
# END -----
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.