collect_predictions: Obtain and format results produced by tuning functions

View source: R/collect.R

collect_predictions    R Documentation

Description

Obtain and format results produced by tuning functions

Usage

collect_predictions(x, ...)

## Default S3 method:
collect_predictions(x, ...)

## S3 method for class 'tune_results'
collect_predictions(x, summarize = FALSE, parameters = NULL, ...)

collect_metrics(x, ...)

## S3 method for class 'tune_results'
collect_metrics(x, summarize = TRUE, ...)

## S3 method for class 'tune_race'
collect_metrics(x, summarize = TRUE, ...)

collect_notes(x, ...)

## S3 method for class 'tune_results'
collect_notes(x, ...)

Arguments

x

The results of tune_grid(), tune_bayes(), fit_resamples(), or last_fit(). For collect_predictions(), the control option save_pred = TRUE should have been used.

...

Not currently used.

summarize

A logical; should metrics be summarized over resamples (TRUE), or should the values for each individual resample be returned (FALSE)? Note that if x was created by last_fit(), summarize has no effect. For the other object types, the method of summarizing predictions is detailed below.

parameters

An optional tibble of tuning parameter values that can be used to filter the predicted values before processing. This tibble should contain only columns for the tuning parameter identifiers (e.g. "my_param" if tune("my_param") was used).
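As a rough illustration of what filtering by parameters means, the following base R sketch keeps only the predictions whose tuning parameter values match a supplied table. The data, the column values, and the use of merge() are assumptions made for this illustration; this is not the code that tune runs internally.

```r
# Hypothetical unsummarized predictions for two candidate values of a
# tuning parameter whose id is "df" (all values are made up).
preds <- data.frame(
  id    = c("Fold1", "Fold1", "Fold2", "Fold2"),
  .row  = c(1, 1, 2, 2),
  df    = c(3, 4, 3, 4),
  .pred = c(21.3, 20.9, 18.2, 18.6)
)

# A `parameters`-style table: one column per tuning parameter id.
params <- data.frame(df = 3)

# Keep only the predictions whose parameter values appear in `params`.
filtered <- merge(preds, params, by = names(params))
filtered
```

With a real tuning result, the equivalent request would be passed through the parameters argument rather than done by hand.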

Value

A tibble. The column names depend on the results and the mode of the model.

For collect_metrics() and collect_predictions(), when unsummarized, there are columns for each tuning parameter (using the id from tune(), if any). collect_metrics() also has the columns .metric and .estimator. When the results are summarized, there are columns for mean, n, and std_err. When they are not summarized, there are additional columns for the resampling identifier(s) and .estimate.
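To make the relationship between the per-resample and summarized metric columns concrete, here is a small base R sketch. The values are invented, and it assumes that std_err is the standard error of the mean (standard deviation divided by the square root of n); treat it as an illustration rather than as the package's implementation.

```r
# Hypothetical per-resample metric values, shaped like the unsummarized
# output for a single tuning parameter combination.
per_resample <- data.frame(
  id         = c("Fold1", "Fold2", "Fold3", "Fold4"),
  .metric    = "rmse",
  .estimator = "standard",
  .estimate  = c(2.0, 3.0, 4.0, 3.0)
)

est <- per_resample$.estimate

# Assumption: std_err is the standard error of the mean, sd / sqrt(n).
summarized <- data.frame(
  .metric    = "rmse",
  .estimator = "standard",
  mean       = mean(est),
  n          = length(est),
  std_err    = sd(est) / sqrt(length(est))
)
summarized
```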

For collect_predictions(), there are additional columns for the resampling identifier(s), columns for the predicted values (e.g., .pred, .pred_class, etc.), and a column for the outcome(s) using the original column name(s) in the data.

collect_predictions() can summarize the various results over replicate out-of-sample predictions. For example, when using the bootstrap, each row in the original training set has multiple holdout predictions (across assessment sets). To convert these results to a format where every training set sample has a single predicted value, the results are averaged over replicate predictions.

For regression cases, the numeric predictions are simply averaged. For classification models, the problem is more complex. When class probabilities are used, these are averaged and then re-normalized to make sure that they add to one. If hard class predictions also exist in the data, then these are determined from the summarized probability estimates (so that they match). If only hard class predictions are in the results, then the mode is used to summarize.
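The summarization rules described above can be sketched in base R on toy data. Everything here (the data frames, the class levels "yes" and "no", and the helper vote()) is hypothetical and only mirrors the description; it is not the code used by collect_predictions().

```r
# Regression: average the replicate numeric predictions for each row.
reg_preds <- data.frame(
  .row  = c(1, 1, 2, 2, 2),
  .pred = c(20, 22, 15, 16, 17)
)
reg_summary <- tapply(reg_preds$.pred, reg_preds$.row, mean)

# Classification with probabilities: average, then re-normalize so that
# the class probabilities for each row sum to one.
prob_cols  <- c(".pred_yes", ".pred_no")
prob_preds <- data.frame(
  .row      = c(1, 1, 2, 2),
  .pred_yes = c(0.9, 0.6, 0.2, 0.4),
  .pred_no  = c(0.1, 0.4, 0.8, 0.6)
)
prob_means <- aggregate(
  prob_preds[prob_cols],
  by  = list(.row = prob_preds$.row),
  FUN = mean
)
prob_means[prob_cols] <- prob_means[prob_cols] / rowSums(prob_means[prob_cols])

# Hard class predictions are derived from the summarized probabilities
# so that the two always agree.
class_levels <- c("yes", "no")
best <- max.col(as.matrix(prob_means[prob_cols]), ties.method = "first")
prob_means$.pred_class <- class_levels[best]

# Classification with hard predictions only: take the most frequent class.
class_preds <- data.frame(
  .row        = c(1, 1, 1, 2, 2, 2),
  .pred_class = c("yes", "yes", "no", "no", "no", "yes")
)
vote <- function(x) names(which.max(table(x)))
class_summary <- tapply(class_preds$.pred_class, class_preds$.row, vote)

reg_summary
prob_means
class_summary
```

Deriving the hard class from the averaged probabilities, rather than voting separately, avoids cases where the reported class disagrees with the reported probabilities.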

collect_notes() returns a tibble with columns for the resampling indicators, the location (preprocessor, model, etc.), type (error or warning), and the notes.

Examples


library(tune)

# Load precomputed example objects (ames_wflow, ames_grid_search, ...)
data("example_ames_knn")
# The parameters for the model:
extract_parameter_set_dials(ames_wflow)

# Summarized over resamples
collect_metrics(ames_grid_search)

# Per-resample values
collect_metrics(ames_grid_search, summarize = FALSE)


# ---------------------------------------------------------------------------

library(parsnip)
library(rsample)
library(dplyr)
library(recipes)
library(tibble)

lm_mod <- linear_reg() %>% set_engine("lm")
set.seed(93599150)
car_folds <- vfold_cv(mtcars, v = 2, repeats = 3)
ctrl <- control_resamples(save_pred = TRUE)

spline_rec <-
  recipe(mpg ~ ., data = mtcars) %>%
  step_ns(disp, deg_free = tune("df"))

grid <- tibble(df = 3:6)

resampled <-
  lm_mod %>%
  tune_grid(spline_rec, resamples = car_folds, control = ctrl, grid = grid)

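# One row per holdout prediction, per resample and candidate value of df
collect_predictions(resampled) %>% arrange(.row)

# Replicate predictions averaged so each row has one value per candidate
collect_predictions(resampled, summarize = TRUE) %>% arrange(.row)

# Same, restricted to the first candidate in the grid (df = 3)
collect_predictions(resampled, summarize = TRUE, parameters = grid[1, ]) %>%
  arrange(.row)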


tune documentation built on March 19, 2022, 2:14 a.m.