Tuning and comparing models

knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.align = "center"
)
library(klaR)
library(mda)
library(rpart)
library(earth)
library(tidymodels)
library(discrim)
theme_set(theme_bw() + theme(legend.position = "top"))

Workflow sets are collections of tidymodels workflow objects that are created as a set. A workflow object is a combination of a preprocessor (e.g. a formula or recipe) and a parsnip model specification.
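
For reference, a single workflow is assembled one object at a time like this (a minimal sketch; the outcome name y and the logistic regression specification are placeholders, not part of the analysis below):

# One preprocessor (a formula) paired with one parsnip model specification
single_workflow <- 
  workflow() %>% 
  add_formula(y ~ .) %>% 
  add_model(logistic_reg() %>% set_engine("glm"))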

For some problems, users might want to try different combinations of preprocessing options, models, and/or predictor sets. Instead of creating a large number of individual objects, a cohort of workflows can be created simultaneously.

In this example we'll use a small, two-dimensional data set for illustrating classification models. The data are in the modeldata package:

library(tidymodels)

data(parabolic)
str(parabolic)

Let's hold back 25% of the data for a test set:

set.seed(1)
split <- initial_split(parabolic)

train_set <- training(split)
test_set <- testing(split)

Visually, we can see that the predictors are mildly correlated and some type of nonlinear class boundary is probably needed.

ggplot(train_set, aes(x = X1, y = X2, col = class)) + 
  geom_point(alpha = 0.5) + 
  coord_fixed(ratio = 1) + 
  scale_color_brewer(palette = "Dark2")

Defining the models

We'll fit two types of discriminant analysis (DA) models (regularized DA and flexible DA using MARS, multivariate adaptive regression splines) as well as a simple classification tree. Let's create those parsnip model objects:

library(discrim)

mars_disc_spec <- 
  discrim_flexible(prod_degree = tune()) %>% 
  set_engine("earth")

reg_disc_spec <- 
  discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) %>% 
  set_engine("klaR")

cart_spec <- 
  decision_tree(cost_complexity = tune(), min_n = tune()) %>% 
  set_engine("rpart") %>% 
  set_mode("classification")

Next, we'll need a resampling method. Let's use the bootstrap:

set.seed(2)
train_resamples <- bootstraps(train_set)

We have a simple data set so a basic formula will suffice for our preprocessing. (If we needed more complex feature engineering, we could use a recipe as a preprocessor instead.)
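
For illustration, a recipe that could stand in for the formula might look like this (a sketch only; the normalization step is illustrative and is not used below):

# An illustrative recipe preprocessor; recipes can be passed to workflow_set()
# via the `preproc` list just like formulas
norm_rec <- 
  recipe(class ~ ., data = train_set) %>% 
  step_normalize(all_numeric_predictors())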

The workflow set takes a named list of preprocessors and a named list of parsnip model specifications, and can cross them to find all combinations. For our case, it will just make a set of workflows for our models:

all_workflows <- 
  workflow_set(
    preproc = list("formula" = class ~ .),
    models = list(regularized = reg_disc_spec, mars = mars_disc_spec, cart = cart_spec)
  )
all_workflows

Adding options to the models

We can add any specific options that we think are important for tuning or resampling using the option_add() function.

For illustration, let's use the extract argument of the control function to save the fitted workflow. We can then pick which workflow should use this option with the id argument:

all_workflows <- 
   all_workflows %>% 
   option_add(id = "formula_cart", 
              control = control_grid(extract = function(x) x))
all_workflows

Keep in mind that this will save the fitted workflow for each resample and each tuning parameter combination that we evaluate.
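
If memory is a concern, a smaller object can be saved instead (a sketch, not used in this analysis): extract only the underlying engine fit rather than the entire workflow.

# A lighter-weight alternative: keep only the rpart fit, not the whole workflow
lean_control <- control_grid(extract = function(x) extract_fit_engine(x))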

Tuning the models

Since these models all have tuning parameters, we can apply the workflow_map() function to execute grid search for each of these model-specific arguments. The default function to apply across the workflows is tune_grid() but other tune_*() functions and fit_resamples() can be used by passing the function name as the first argument.
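
For example (a sketch that is not evaluated here), the same workflows could be tuned with Bayesian optimization by naming that function instead of the default:

# Not evaluated: pass the tuning function by name; additional arguments such as
# `iter` are forwarded to it. Any tune_grid()-specific control options set with
# option_add() would need to be replaced with control_bayes() equivalents.
bayes_workflows <- 
  all_workflows %>% 
  workflow_map("tune_bayes", resamples = train_resamples, iter = 10, verbose = TRUE)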

Let's use the same grid size for each model. For the MARS model, there are only two possible values of its tuning parameter (prod_degree), but tune_grid() is forgiving about our request for 20 candidate values.

The verbose option provides a concise listing for which workflow is being processed:

all_workflows <- 
  all_workflows %>% 
  # Specifying arguments here adds to any previously set with `option_add()`:
  workflow_map(resamples = train_resamples, grid = 20, verbose = TRUE)
all_workflows

The result column now has the results of each tune_grid() call.

From these results, we can get quick assessments of how well these models classified the data:

rank_results(all_workflows, rank_metric = "roc_auc")

# or a handy plot: 
autoplot(all_workflows, metric = "roc_auc")

Examining specific model results

It looks like the MARS model did well. We can plot its results and pull out its tuning object as well:

autoplot(all_workflows, metric = "roc_auc", id = "formula_mars")

Not much of a difference in performance; it may be prudent to use the additive model (via prod_degree = 1).

We can also pull out the results of tune_grid() for this model:

mars_results <- 
  all_workflows %>% 
  extract_workflow_set_result("formula_mars")
mars_results
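
The same comparison can also be made numerically with collect_metrics():

# Resampled ROC AUC for each candidate value of prod_degree
collect_metrics(mars_results) %>% 
  filter(.metric == "roc_auc")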

Let's get that workflow object and finalize the model:

mars_workflow <- 
  all_workflows %>% 
  extract_workflow("formula_mars")
mars_workflow

mars_workflow_fit <- 
  mars_workflow %>% 
  finalize_workflow(tibble(prod_degree = 1)) %>% 
  fit(data = train_set)
mars_workflow_fit

Let's see how well this model works on the test set:

# Make a grid to predict the whole space:
grid <-
  crossing(X1 = seq(min(train_set$X1), max(train_set$X1), length.out = 250),
           X2 = seq(min(train_set$X2), max(train_set$X2), length.out = 250))

grid <- 
  grid %>% 
  bind_cols(predict(mars_workflow_fit, grid, type = "prob"))

We can produce a contour plot for the class boundary, then overlay the data:

ggplot(grid, aes(x = X1, y = X2)) + 
  geom_contour(aes(z = .pred_Class2), breaks = 0.5, col = "black") + 
  geom_point(data = test_set, aes(col = class), alpha = 0.5) + 
  coord_fixed(ratio = 1) + 
  scale_color_brewer(palette = "Dark2")
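
The plot gives a visual impression of the boundary; a quantitative test-set summary can also be computed (a small addition to the original analysis, using the Class1/Class2 factor levels seen in the predictions above):

# Bind hard-class and probability predictions to the test set
test_pred <- augment(mars_workflow_fit, new_data = test_set)

# Test-set accuracy and ROC AUC (Class1 is the first factor level)
accuracy(test_pred, truth = class, estimate = .pred_class)
roc_auc(test_pred, truth = class, .pred_Class1)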

The workflow set allows us to screen many models to find one that does very well. This can be combined with parallel processing and, especially, racing methods from the finetune package to optimize efficiency.
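
For instance, racing can be requested through the same mapping interface (a sketch that assumes the finetune package is installed; not evaluated here):

# Not evaluated: ANOVA-based racing drops poor candidates early in resampling.
# Any tune_grid()-specific control options set earlier would need to be
# swapped for control_race() first.
library(finetune)

race_workflows <- 
  all_workflows %>% 
  workflow_map("tune_race_anova", resamples = train_resamples, grid = 20, seed = 3)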

Extracting information from the results

Recall that we added an option to the CART model to extract the model results. Let's pull out the CART tuning results and see what we have:

cart_res <- 
   all_workflows %>% 
   extract_workflow_set_result("formula_cart")
cart_res

The .extracts column contains a tibble for each resample, and each of those tibbles has 20 rows (one per tuning parameter candidate) with a fitted workflow in each row. Since cart_res has nrow(cart_res) rows (one per resample), that amounts to nrow(cart_res) * 20 fitted workflows in total.

Let's slim that down by keeping the ones that correspond to the best tuning parameters:

# Get the best results
best_cart <- select_best(cart_res, metric = "roc_auc")

cart_wflows <- 
   cart_res %>% 
   select(id, .extracts) %>% 
   unnest(cols = .extracts) %>% 
   inner_join(best_cart)

cart_wflows

What can we do with these? Let's write a function to return the number of terminal nodes in the tree.

num_nodes <- function(wflow) {
   wflow %>% 
      # Pull out the underlying rpart model
      extract_fit_engine() %>% 
      # The 'frame' element is a data frame with a 'var' column
      # that marks terminal nodes with "<leaf>"
      pluck("frame") %>% 
      # Convert to a tibble
      as_tibble() %>% 
      # Keep only the rows for terminal nodes
      filter(var == "<leaf>") %>% 
      # Count them
      nrow() 
}

cart_wflows$.extracts[[1]] %>% num_nodes()

Now let's create a column with the results for each resample:

cart_wflows <- 
   cart_wflows %>% 
   mutate(num_nodes = map_int(.extracts, num_nodes))
cart_wflows

The mean of this new column gives the average number of terminal nodes for this model across the resamples.
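
Computed directly (the same calculation the original inline expression performed):

# Average number of terminal nodes across resamples for the best candidate
round(mean(cart_wflows$num_nodes), 1)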


