```r
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  eval = reticulate::py_module_available("keras")
)
# Suppress verbose Keras output for the vignette
options(keras.fit_verbose = 0)
set.seed(123)
```
Transfer learning is a technique in which a model developed for one task is reused as the starting point for a model on a second task. It is especially popular in computer vision, where pre-trained models such as ResNet50, which was trained on the large ImageNet dataset, can serve as ready-made feature extractors.
The kerasnip package makes it easy to incorporate these pre-trained Keras Applications directly into a tidymodels workflow. This vignette will demonstrate how to:
1. Define a kerasnip model that uses a pre-trained ResNet50 as a frozen base layer.
2. Train and evaluate that model within a standard tidymodels workflow.

First, we load the necessary packages.
```r
library(kerasnip)
library(tidymodels)
library(keras3)
```
We'll use the CIFAR-10 dataset, which consists of 60,000 32x32 color images in 10 classes (50,000 for training and 10,000 for testing). keras3 provides dataset_cifar10() to download it.
ResNet50 was pre-trained on ImageNet, whose 1,000 classes differ from those in CIFAR-10. Rather than retraining the whole network, we will reuse its learned features and train only a new classification head to predict the 10 CIFAR-10 classes.
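The CIFAR-10 labels arrive as the integers 0 to 9. If you prefer readable class names in predictions and metrics, one option is to supply them as factor labels. This is an optional step that the rest of the vignette does not use; the vector below follows the standard CIFAR-10 label order, and labels_int is a made-up example.

```r
# Standard CIFAR-10 class names, in label order 0-9
cifar10_classes <- c("airplane", "automobile", "bird", "cat", "deer",
                     "dog", "frog", "horse", "ship", "truck")

# Made-up integer labels for illustration
labels_int <- c(3, 8, 8, 0)

# Map each integer label to its class name
labels_factor <- factor(labels_int, levels = 0:9, labels = cifar10_classes)
as.character(labels_factor)  # "cat" "ship" "ship" "airplane"
```

Keeping levels = 0:9 ensures that every class is a level, even when a small subset happens to miss one.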
```r
# Load the CIFAR-10 dataset (downloaded on first use)
cifar10 <- dataset_cifar10()

# Separate training and test data
x_train <- cifar10$train$x  # array: 50000 x 32 x 32 x 3
y_train <- cifar10$train$y  # matrix: 50000 x 1, integer labels 0-9
x_test <- cifar10$test$x
y_test <- cifar10$test$y

# Rescale pixel values from [0, 255] to [0, 1].
# Note: ResNet50's ImageNet weights were trained on inputs prepared with its
# own preprocessing (BGR channels, ImageNet mean subtraction on the 0-255
# scale); plain rescaling keeps this example short.
x_train <- x_train / 255
x_test <- x_test / 255

# Convert outcomes to factors for tidymodels
y_train_factor <- factor(y_train[, 1])
y_test_factor <- factor(y_test[, 1])

# For tidymodels, it's best to work with data frames.
# We'll use a list-column holding one 32 x 32 x 3 image array per row.
train_df <- tibble::tibble(
  x = lapply(seq_len(nrow(x_train)), function(i) x_train[i, , , , drop = TRUE]),
  y = y_train_factor
)
test_df <- tibble::tibble(
  x = lapply(seq_len(nrow(x_test)), function(i) x_test[i, , , , drop = TRUE]),
  y = y_test_factor
)

# Use a smaller subset for faster vignette execution
train_df_small <- train_df[1:500, ]
test_df_small <- test_df[1:100, ]
```
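The list-column conversion splits the 4-D image array into one 3-D array per row. The sketch below checks that bookkeeping on a small synthetic array (5 images of 4 x 4 pixels with 3 channels, chosen only to keep it fast) and shows one way to stack the rows back into a batch array. How kerasnip itself converts the list-column is not shown in this vignette, so the stacking step is an illustrative assumption, not kerasnip's code.

```r
# Synthetic stand-in for x_train: 5 images, 4 x 4 pixels, 3 channels
set.seed(1)
images <- array(runif(5 * 4 * 4 * 3), dim = c(5, 4, 4, 3))

# Split into a list with one 4 x 4 x 3 array per image, as in train_df$x
image_list <- lapply(seq_len(dim(images)[1]), function(i) images[i, , , , drop = TRUE])
dim(image_list[[1]])  # 4 4 3

# Stack back: simplify2array() puts the image index last (4 x 4 x 3 x 5),
# so aperm() moves it to the front to recover the batch layout (5 x 4 x 4 x 3)
stacked <- aperm(simplify2array(image_list), c(4, 1, 2, 3))
all.equal(stacked, images)  # TRUE
```

Note that drop = TRUE removes every length-1 dimension, so for single-channel (grayscale) images the channel axis would be dropped as well.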
A common way to implement transfer learning in Keras is the Functional API. We will define a model where:
1. The base is a pre-trained ResNet50, with its final classification layer removed (include_top = FALSE).
2. The weights of the base are frozen (trainable = FALSE) so that only our new layers are trained.
3. A new classification "head" is added, consisting of a flatten layer and a dense output layer.
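For comparison, here is a minimal sketch of the same architecture written directly with keras3, without kerasnip. It assumes 32 x 32 x 3 inputs and 10 classes, and passes weights = NULL only so that it runs without downloading the ImageNet weights (use weights = "imagenet" for actual transfer learning). Calling the base with training = FALSE follows the Keras transfer-learning guide's advice to keep batch normalization in inference mode; the kerasnip block defined below omits it.

```r
library(keras3)

inputs <- keras_input(shape = c(32, 32, 3))

# Base without its ImageNet classification layer.
# weights = NULL only avoids a download; use "imagenet" in practice.
resnet_base <- application_resnet50(
  weights = NULL,
  include_top = FALSE,
  input_shape = c(32, 32, 3)
)
resnet_base$trainable <- FALSE  # freeze every layer in the base

# New head: flatten the base's features and predict 10 classes
outputs <- inputs |>
  resnet_base(training = FALSE) |>
  layer_flatten() |>
  layer_dense(units = 10, activation = "softmax")

model <- keras_model(inputs = inputs, outputs = outputs)

# Only the dense layer's kernel and bias remain trainable
length(model$trainable_weights)
```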
```r
# Input block: shape is determined automatically from the data
input_block <- function(input_shape) {
  layer_input(shape = input_shape)
}

# ResNet50 base block
resnet_base_block <- function(tensor) {
  # The base model is not trainable; we use it for feature extraction.
  resnet_base <- application_resnet50(
    weights = "imagenet",
    include_top = FALSE
  )
  resnet_base$trainable <- FALSE
  resnet_base(tensor)
}

# New classification head
flatten_block <- function(tensor) {
  tensor |> layer_flatten()
}

output_block_functional <- function(tensor, num_classes) {
  tensor |> layer_dense(units = num_classes, activation = "softmax")
}
```
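It helps to know what the new head actually receives. The arithmetic below traces a 32 x 32 CIFAR-10 image through the stride-2 layers of the Keras ResNet50 (a 7 x 7 stem convolution, a 3 x 3 max-pool, and a strided 1 x 1 convolution at the start of each of the last three stages) using the standard output-size formula floor((n + 2 * pad - kernel) / stride) + 1. It then counts the parameters left to train in the dense layer, assuming 10 classes.

```r
# Output size of a convolution or pooling layer along one spatial axis
conv_out <- function(n, kernel, stride, pad = 0) {
  floor((n + 2 * pad - kernel) / stride) + 1
}

size <- 32                                               # CIFAR-10 height/width
size <- conv_out(size, kernel = 7, stride = 2, pad = 3)  # stem convolution: 16
size <- conv_out(size, kernel = 3, stride = 2, pad = 1)  # stem max-pooling: 8
# conv2 keeps the size; conv3, conv4 and conv5 each halve it: 4, 2, 1
for (stage in 1:3) size <- conv_out(size, kernel = 1, stride = 2)

flatten_units <- size * size * 2048      # last stage has 2048 channels
head_params <- flatten_units * 10 + 10   # dense weights plus one bias per class
head_params  # 20490
```

So the frozen base turns each image into a 1 x 1 x 2048 feature map, the flatten block yields 2,048 features, and only 20,490 weights are updated during training.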
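Next, we connect these blocks into a kerasnip specification using create_keras_functional_spec(). Each entry of layer_blocks names a block, and inp_spec() maps a block's input argument (here, tensor) to the output of the named upstream block.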
```r
create_keras_functional_spec(
  model_name = "resnet_transfer",
  layer_blocks = list(
    input = input_block,
    resnet_base = inp_spec(resnet_base_block, "input"),
    flatten = inp_spec(flatten_block, "resnet_base"),
    output = inp_spec(output_block_functional, "flatten")
  ),
  mode = "classification"
)
```
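Conceptually, layer_blocks describes a small graph in which each block reads from the upstream block named in inp_spec(). The toy base-R sketch below imitates that idea with plain arithmetic functions standing in for layers; run_graph() and the node structure are hypothetical illustrations, not kerasnip's implementation.

```r
# Hypothetical toy: evaluate named blocks, each reading from a named upstream block.
# Assumes the nodes are listed in dependency order, as in layer_blocks above.
run_graph <- function(nodes, x) {
  outputs <- list()
  for (name in names(nodes)) {
    node <- nodes[[name]]
    # The first node receives the raw input; others read an upstream output
    input <- if (is.null(node$from)) x else outputs[[node$from]]
    outputs[[name]] <- node$fn(input)
  }
  outputs[[length(outputs)]]  # the last block's output is the model output
}

nodes <- list(
  input       = list(fn = identity, from = NULL),
  resnet_base = list(fn = function(t) t * 2, from = "input"),        # stand-in for the base
  flatten     = list(fn = function(t) as.vector(t), from = "resnet_base"),
  output      = list(fn = function(t) sum(t), from = "flatten")
)

run_graph(nodes, matrix(1:4, nrow = 2))  # 20
```

Each from field plays the role of the second argument to inp_spec() in the real specification.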
create_keras_functional_spec() generated a new model specification function, resnet_transfer(), which we can now use within a tidymodels workflow.
```r
spec_functional <- resnet_transfer(
  fit_epochs = 5,
  fit_validation_split = 0.2
) |>
  set_engine("keras")

rec_functional <- recipe(y ~ x, data = train_df_small)

wf_functional <- workflow() |>
  add_recipe(rec_functional) |>
  add_model(spec_functional)

fit_functional <- fit(wf_functional, data = train_df_small)

# Evaluate on the test set
predictions <- predict(fit_functional, new_data = test_df_small)
bind_cols(predictions, test_df_small) |>
  accuracy(truth = y, estimate = .pred_class)
```
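yardstick's accuracy() reports the proportion of predictions that match the true class. The base-R sketch below shows that calculation, plus the corresponding confusion table, on four made-up predictions; yardstick's conf_mat() produces a tidy version of the same table.

```r
# Made-up labels for illustration only
truth    <- factor(c("0", "1", "2", "1"), levels = c("0", "1", "2"))
estimate <- factor(c("0", "1", "1", "1"), levels = c("0", "1", "2"))

# Accuracy = share of predictions equal to the true class
mean(estimate == truth)  # 0.75

# Confusion table: rows are predictions, columns are true classes
table(Prediction = estimate, Truth = truth)
```

Correct predictions sit on the diagonal of the table, so accuracy is the diagonal sum divided by the total count.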
Even with only 500 training images and 5 epochs, the pre-trained ResNet50 features give the model a reasonable starting point. The exact accuracy will vary from run to run.
This vignette demonstrated how kerasnip bridges the world of pre-trained Keras applications with the structured, reproducible workflows of tidymodels.
The Functional API provides a direct way to perform transfer learning: attach a new head to a frozen base model.
This approach lets you leverage deep learning models trained on massive datasets, which can substantially improve performance on smaller, domain-specific tasks.
```r
# Clean up: remove the resnet_transfer() specification created above
remove_keras_spec("resnet_transfer")
```