knitr::opts_chunk$set( collapse = TRUE, comment = "#>" )
There are three main metric types in yardstick: class, class probability, and
numeric. Each type of metric has a standardized argument syntax, and all metrics
return the same kind of output (a tibble with 3 columns: `.metric`, `.estimator`,
and `.estimate`). This standardization allows metrics to be easily grouped
together and used with grouped data frames to compute on multiple resamples at
once. Below are the five types of metrics (the three main types plus static and
dynamic survival metrics), along with the types of input they take.
1) Class metrics (hard predictions)
- `truth` - factor
- `estimate` - factor
2) Class probability metrics (soft predictions)
- `truth` - factor
- `estimate / ...` - multiple numeric columns containing class probabilities
3) Numeric metrics
- `truth` - numeric
- `estimate` - numeric
4) Static survival metrics
- `truth` - Surv
- `estimate` - numeric
5) Dynamic survival metrics
- `truth` - Surv
- `...` - list of data.frames, each containing the 3 columns `.eval_time`, `.pred_survival`, and `.weight_censored`
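As a minimal sketch of the shared syntax, the numeric metric `rmse()` can be applied to a small hand-made data frame. The data frame `toy` and its column names `observed` and `predicted` are made up for illustration and are not part of yardstick.

```r
library(yardstick)

# Hypothetical toy data: four observed values and their predictions.
toy <- data.frame(
  observed  = c(1, 2, 3, 4),
  predicted = c(1.5, 2, 2.5, 4)
)

# Numeric metrics take a numeric `truth` column and a numeric `estimate` column.
res <- rmse(toy, truth = observed, estimate = predicted)
res

# The result is a tibble with the three standard columns.
names(res)
```

Other numeric metrics, such as `mae()` or `rsq()`, accept the same `truth` and `estimate` arguments and return a result of the same shape.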
In the following example, the hpc_cv data set is used. It contains class
probabilities and class predictions for a linear discriminant analysis fit to
the HPC data set of Kuhn and Johnson (2013). The model is fit with 10-fold
cross-validation, and the predictions for all folds are included.
library(yardstick)
library(dplyr)

data("hpc_cv")

hpc_cv %>%
  group_by(Resample) %>%
  slice(1:3)
1 metric, 1 resample
hpc_cv %>% filter(Resample == "Fold01") %>% accuracy(obs, pred)
1 metric, 10 resamples
hpc_cv %>% group_by(Resample) %>% accuracy(obs, pred)
2 metrics, 10 resamples
class_metrics <- metric_set(accuracy, kap)

hpc_cv %>%
  group_by(Resample) %>%
  class_metrics(obs, estimate = pred)
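A metric set can also mix class metrics with class probability metrics. The sketch below assumes the class probability columns of hpc_cv are `VF`, `F`, `M`, and `L` (selected as `VF:L`); the names `mixed_metrics`, `per_fold`, and `fold_summary` are chosen here for illustration. When the two types are mixed, the probability columns are passed through `...` and the hard predictions through the named `estimate` argument.

```r
library(yardstick)
library(dplyr)

data("hpc_cv")

# Combine a hard-prediction metric with a class probability metric.
mixed_metrics <- metric_set(accuracy, roc_auc)

# One row per metric per fold: probabilities via `...`, classes via `estimate`.
per_fold <- hpc_cv %>%
  group_by(Resample) %>%
  mixed_metrics(obs, VF:L, estimate = pred)

# Because every metric returns the same columns, the per-fold results
# can be averaged with ordinary dplyr verbs.
fold_summary <- per_fold %>%
  group_by(.metric) %>%
  summarise(mean_estimate = mean(.estimate), n_folds = n())

fold_summary
```

This is one possible way to aggregate over resamples; it relies only on the standardized output format rather than on any dedicated yardstick helper.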
Below is a table of all of the metrics available in yardstick, grouped
by type.
library(knitr)
library(dplyr)

yardns <- asNamespace("yardstick")
fns <- lapply(names(yardns), get, envir = yardns)
names(fns) <- names(yardns)

get_metrics <- function(fns, type) {
  where <- vapply(fns, inherits, what = type, FUN.VALUE = logical(1))
  paste0("`", sort(names(fns[where])), "()`")
}

all_metrics <- bind_rows(
  tibble(type = "class", metric = get_metrics(fns, "class_metric")),
  tibble(type = "class prob", metric = get_metrics(fns, "prob_metric")),
  tibble(type = "numeric", metric = get_metrics(fns, "numeric_metric")),
  tibble(type = "dynamic survival", metric = get_metrics(fns, "dynamic_survival_metric")),
  tibble(type = "static survival", metric = get_metrics(fns, "static_survival_metric"))
)

kable(all_metrics, format = "html")