plot_probabilities: Plot predicted probabilities

Description

[Experimental]

Creates a ggplot2 line plot object with the probabilities of either the target classes or the predicted classes.

The observations are ordered by the highest probability.

The line shows the average probability per observation, while the points (see `add_points`) show the actual probabilities per observation.

The meaning of the horizontal lines depends on the settings. They show either recall, precision, or accuracy scores, depending on the `probability_of` and `apply_facet` arguments.

Usage

plot_probabilities(
  data,
  target_col,
  probability_cols,
  predicted_class_col = NULL,
  obs_id_col = NULL,
  group_col = NULL,
  probability_of = "target",
  positive = 2,
  order = "centered",
  theme_fn = ggplot2::theme_minimal,
  color_scale = ggplot2::scale_colour_brewer(palette = "Dark2"),
  apply_facet = length(probability_cols) > 1,
  smoothe = FALSE,
  add_points = !is.null(obs_id_col),
  add_hlines = TRUE,
  add_caption = TRUE,
  show_x_scale = FALSE,
  line_settings = list(),
  smoothe_settings = list(),
  point_settings = list(),
  hline_settings = list(),
  facet_settings = list(),
  ylim = c(0, 1)
)

Arguments

data

data.frame with probabilities, target classes and (optional) predicted classes. Can also include observation identifiers and a grouping variable.

Example for binary classification:

Classifier  Observation  Probability  Target  Prediction
SVM         1            0.3          cl_1    cl_1
SVM         2            0.7          cl_1    cl_2
NB          1            0.2          cl_2    cl_1
NB          2            0.8          cl_2    cl_2
...         ...          ...          ...     ...

Example for multiclass classification:

Classifier  Observation  cl_1  cl_2  cl_3  Target  Prediction
SVM         1            0.2   0.1   0.7   cl_1    cl_3
SVM         2            0.3   0.5   0.2   cl_1    cl_2
NB          1            0.8   0.1   0.1   cl_2    cl_1
NB          2            0.1   0.6   0.3   cl_3    cl_2
...         ...          ...   ...   ...   ...     ...

You can have multiple rows per observation ID per group. If, for instance, we have run repeated cross-validation of 3 classifiers, we would have one predicted probability per fold column per classifier.

Data frames in this format are created by the various validation functions in cvms, like cross_validate_fn().

target_col

Name of column with target levels.

probability_cols

Name(s) of the column(s) with predicted probabilities.

For binary classification, this should be one column with the probability of the second class (alphabetically).

For multiclass classification, this should be one column per class. These probabilities must sum to 1 row-wise.
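
The relation between the two layouts can be sketched in base R. The data below is hypothetical and only mirrors the binary example table; the snippet illustrates the "second class alphabetically" convention and the row-wise sum requirement, and is not code from cvms.

```r
# Hypothetical binary data: `Probability` is the probability of the
# second class (alphabetically), as described above
binary_data <- data.frame(
  Probability = c(0.3, 0.7, 0.2, 0.8),
  Target = c("cl_1", "cl_1", "cl_2", "cl_2"),
  stringsAsFactors = FALSE
)

# Sort the class names; the second one is the class the column refers to
classes <- sort(unique(binary_data$Target))
second_class <- classes[2]

# Expand to one column per class, as in the multiclass layout
expanded <- data.frame(
  1 - binary_data$Probability,
  binary_data$Probability
)
colnames(expanded) <- classes

# Multiclass probabilities must sum to 1 row-wise
# (checked with a small floating point tolerance)
sums_to_one <- all(abs(rowSums(expanded) - 1) < 1e-8)
```

A check like `sums_to_one` can be useful before plotting multiclass probabilities that come from outside cvms.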

predicted_class_col

Name of column with predicted classes.

This is required when probability_of = "prediction" and/or add_hlines = TRUE.

obs_id_col

Name of column with observation identifiers for grouping the x-axis. When NULL, each row is an observation.

Use case: when you have multiple predicted probabilities per observation by a classifier (e.g. from repeated cross-validation).

Can also be a grouping variable that you wish to aggregate.
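
As a rough illustration of what grouping by an observation ID means, the base R sketch below averages repeated predictions per observation within each classifier. The data and column names are hypothetical, and using the mean as the summary is an assumption based on the description of the line; it is not taken from the cvms source.

```r
# Hypothetical predictions: two repetitions per observation per classifier
preds <- data.frame(
  Classifier = rep(c("SVM", "NB"), each = 4),
  ID = rep(c(1, 1, 2, 2), times = 2),
  Probability = c(0.3, 0.5, 0.7, 0.9, 0.2, 0.4, 0.8, 0.6)
)

# Average probability per observation ID within each classifier
# (assumption: the summary is a plain mean)
avg_probs <- aggregate(
  Probability ~ Classifier + ID,
  data = preds,
  FUN = mean
)
```

The individual rows of `preds` correspond to what the points would show, and `avg_probs` to one value per observation per group.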

group_col

Name of column with groups. The plot elements are split by these groups and can be identified by their color.

E.g. the classifier responsible for the prediction.

N.B. With more than 8 groups, the default `color_scale` might run out of colors.

probability_of

Whether to plot the probabilities of the target classes ("target") or the predicted classes ("prediction").

For each row, we extract the probability of either the target class or the predicted class. Both are useful to plot, as they show the behavior of the classifier in a way a confusion matrix doesn't. One classifier might be very certain in its predictions (whether wrong or right), whereas another might be less certain.
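
The per-row extraction can be sketched in base R with matrix indexing. The data is hypothetical and follows the multiclass example table; `extract_probability()` is an illustrative helper, not a cvms function.

```r
# Hypothetical multiclass predictions (one probability column per class)
preds <- data.frame(
  cl_1 = c(0.2, 0.3, 0.8, 0.1),
  cl_2 = c(0.1, 0.5, 0.1, 0.6),
  cl_3 = c(0.7, 0.2, 0.1, 0.3),
  Target = c("cl_1", "cl_1", "cl_2", "cl_3"),
  Prediction = c("cl_3", "cl_2", "cl_1", "cl_2"),
  stringsAsFactors = FALSE
)
probability_cols <- c("cl_1", "cl_2", "cl_3")
prob_matrix <- as.matrix(preds[, probability_cols])

# Illustrative helper: pick one probability per row by
# (row index, column index) pairs
extract_probability <- function(class_labels) {
  col_idx <- match(class_labels, probability_cols)
  unname(prob_matrix[cbind(seq_len(nrow(prob_matrix)), col_idx)])
}

# Corresponds to probability_of = "target"
prob_of_target <- extract_probability(preds$Target)

# Corresponds to probability_of = "prediction"
prob_of_prediction <- extract_probability(preds$Prediction)
```

In this toy data, the probabilities of the predicted classes are all at least 0.5, while the probabilities of the target classes are low whenever the prediction is wrong, which is the contrast the two settings are meant to show.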

positive

TODO

order

How to order the probabilities. (Character)

One of: "descending", "ascending", and "centered".
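
The three orderings can be illustrated on a plain numeric vector. The "centered" arrangement below is only one plausible interpretation (largest value in the middle, falling off towards both ends) and is an assumption; `center_max_sketch()` is a hypothetical helper and may differ from what cvms does internally.

```r
probs <- c(0.9, 0.2, 0.6, 0.4, 0.8)

descending <- sort(probs, decreasing = TRUE)
ascending <- sort(probs)

# Assumed "centered" arrangement: the largest value goes in the middle
# and the remaining values decrease towards both ends
center_max_sketch <- function(x) {
  x <- sort(x, decreasing = TRUE)
  odd_ranks <- x[seq_along(x) %% 2 == 1]   # ranks 1, 3, 5, ...
  even_ranks <- x[seq_along(x) %% 2 == 0]  # ranks 2, 4, ...
  # Even ranks rise towards the middle; odd ranks fall away from it
  c(rev(even_ranks), odd_ranks)
}

centered <- center_max_sketch(probs)
```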

theme_fn

The ggplot2 theme function to apply.

color_scale

ggplot2 color scale object for adding discrete colors to the plot.

E.g. the output of ggplot2::scale_colour_brewer() or ggplot2::scale_colour_viridis_d().

N.B. The palette should contain at least as many colors as there are groups in the `group_col` column.
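
When there are more groups than the default palette offers, one option is to generate the colors yourself and wrap them in a manual scale. The sketch below assumes 10 groups (a made-up number) and uses grDevices::hcl.colors(), which requires R 3.6.0 or later; the choice of the "Dark 3" palette is arbitrary.

```r
# Hypothetical number of levels in the `group_col` column
n_groups <- 10

# Generate one color per group with a qualitative HCL palette
group_colors <- grDevices::hcl.colors(n_groups, palette = "Dark 3")

# Wrap the colors in a discrete ggplot2 scale (only if ggplot2 is installed)
if (requireNamespace("ggplot2", quietly = TRUE)) {
  color_scale <- ggplot2::scale_colour_manual(values = group_colors)
}
```

The resulting `color_scale` object could then be passed as the `color_scale` argument.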

apply_facet

Whether to use ggplot2::facet_wrap(). (Logical)

By default, faceting is applied when there is more than one probability column (multiclass).

smoothe

Whether to use ggplot2::geom_smooth() instead of ggplot2::geom_line(). This also adds a 95% confidence interval by default.

Settings can be passed via the `smoothe_settings` argument.

add_points

Add a point for each predicted probability. These are grouped on the x-axis by the `obs_id_col` column. (Logical)

add_hlines

Add horizontal lines. (Logical)

The meaning of these lines depends on the `probability_of` and `apply_facet` arguments:

apply_facet  probability_of  Metric
FALSE        "target"        Accuracy
FALSE        "prediction"    Accuracy
TRUE         "target"        Recall / Sensitivity
TRUE         "prediction"    Precision / PPV

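
The metrics named in the table can be sketched in base R on a small set of hypothetical targets and predictions. This only illustrates the standard definitions of the three metrics; how cvms computes them internally, and how they are split by `group_col`, is not shown here.

```r
# Hypothetical target and predicted classes
targets <- c("cl_1", "cl_1", "cl_2", "cl_3", "cl_3", "cl_2")
predictions <- c("cl_1", "cl_2", "cl_2", "cl_3", "cl_1", "cl_2")
classes <- sort(unique(c(targets, predictions)))

# No faceting: a single overall accuracy
accuracy <- mean(targets == predictions)

# Faceting with probability_of = "target":
# recall (sensitivity) per target class
recall <- sapply(classes, function(cl) {
  mean(predictions[targets == cl] == cl)
})

# Faceting with probability_of = "prediction":
# precision (PPV) per predicted class
precision <- sapply(classes, function(cl) {
  mean(targets[predictions == cl] == cl)
})
```

Recall conditions on the target class and precision on the predicted class, which matches the way the facets are split in the two settings.
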
add_caption

Whether to add a caption explaining the plot. This is dynamically generated and intended as a starting point. (Logical)

You can overwrite the text with ggplot2::labs(caption = "...").

show_x_scale

TODO

line_settings

Named list of arguments for ggplot2::geom_line().

The mapping argument is set separately.

Any argument not in the list will use its default value.

Default: list(size = 0.5)

N.B. Ignored when smoothe = TRUE.
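
One reading of "any argument not in the list will use its default value" is that the supplied list is merged into the documented defaults. A common way to express such a merge in R is utils::modifyList(); whether cvms uses this function internally is an assumption. The user-supplied values below are made up for illustration.

```r
# Documented default for `line_settings`
defaults <- list(size = 0.5)

# Hypothetical user input that adds a new argument for geom_line()
user_settings <- list(linetype = "dashed")
merged <- utils::modifyList(defaults, user_settings)

# Hypothetical user input that overrides a default
overridden <- utils::modifyList(defaults, list(size = 1))
```

Under this reading, `merged` keeps the default `size` and gains `linetype`, while `overridden` replaces `size`. The same pattern would apply to the other `*_settings` arguments.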

smoothe_settings

Named list of arguments for ggplot2::geom_smooth().

The mapping argument is set separately.

Any argument not in the list will use its default value.

Default: list(size = 0.5, alpha = 0.18, level = 0.95, se = TRUE)

N.B. Only used when smoothe = TRUE.

point_settings

Named list of arguments for ggplot2::geom_point().

The mapping argument is set separately.

Any argument not in the list will use its default value.

Default: list(size = 0.1, alpha = 0.4)

hline_settings

Named list of arguments for ggplot2::geom_hline().

The mapping argument is set separately.

Any argument not in the list will use its default value.

Default: list(size = 0.35, alpha = 0.5)

facet_settings

Named list of arguments for ggplot2::facet_wrap().

The facets argument is set separately.

Any argument not in the list will use its default value.

Commonly set arguments are nrow and ncol.

ylim

Limits for the y-scale.

Details

TODO

Value

A ggplot2 object with a line plot, faceted when apply_facet = TRUE.

Author(s)

Ludvig Renbo Olsen, r-pkgs@ludvigolsen.dk

See Also

Other plotting functions: font(), plot_confusion_matrix(), plot_metric_density(), plot_probabilities_ecdf(), sum_tile_settings()

Examples


# Attach packages
library(cvms)
library(ggplot2)
library(dplyr)

#
# Multiclass
#

# Plot probabilities of target classes
# From repeated cross-validation of three classifiers

# plot_probabilities(
#   data = predicted.musicians,
#   target_col = "Target",
#   probability_cols = c("A", "B", "C", "D"),
#   predicted_class_col = "Predicted Class",
#   group_col = "Classifier",
#   obs_id_col = "ID",
#   probability_of = "target"
# )

# Plot probabilities of predicted classes
# From repeated cross-validation of three classifiers

# plot_probabilities(
#   data = predicted.musicians,
#   target_col = "Target",
#   probability_cols = c("A", "B", "C", "D"),
#   predicted_class_col = "Predicted Class",
#   group_col = "Classifier",
#   obs_id_col = "ID",
#   probability_of = "prediction"
# )

# Center probabilities

# plot_probabilities(
#   data = predicted.musicians,
#   target_col = "Target",
#   probability_cols = c("A", "B", "C", "D"),
#   predicted_class_col = "Predicted Class",
#   group_col = "Classifier",
#   obs_id_col = "ID",
#   probability_of = "prediction",
#   order = "centered"
# )

#
# Binary
#

# Filter the predicted.musicians dataset
# binom_data <- predicted.musicians %>%
#   dplyr::filter(
#     Target %in% c("A", "B")
#   ) %>%
#   # "B" is the second class alphabetically
#   dplyr::rename(Probability = B) %>%
#   dplyr::mutate(`Predicted Class` = ifelse(
#     Probability > 0.5, "B", "A")) %>%
#   dplyr::select(-dplyr::all_of(c("A", "C", "D")))

# Plot probabilities of target classes
# From repeated cross-validation of three classifiers

# plot_probabilities(
#   data = binom_data,
#   target_col = "Target",
#   probability_cols = "Probability",
#   predicted_class_col = "Predicted Class",
#   group_col = "Classifier",
#   obs_id_col = "ID",
#   probability_of = "target"
# )

# Plot probabilities of predicted classes

# plot_probabilities(
#   data = binom_data,
#   target_col = "Target",
#   probability_cols = "Probability",
#   predicted_class_col = "Predicted Class",
#   group_col = "Classifier",
#   obs_id_col = "ID",
#   probability_of = "prediction",
#   ylim = c(0.5, 1)
# )



cvms documentation built on July 9, 2023, 6:56 p.m.