Nothing
## ----include = FALSE----------------------------------------------------------
# Evaluate the vignette's code chunks only when the Python `keras` module is
# reachable through reticulate; otherwise the code is shown but not run.
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  eval = reticulate::py_module_available("keras")
)
# Suppress verbose Keras output for the vignette. `keras.fit_verbose` is the
# option read by the legacy `keras` R package; keras3's fit(), evaluate() and
# predict() default their `verbose` argument from `keras.verbose` instead,
# so both are set.
options(keras.fit_verbose = 0, keras.verbose = 0)
# NOTE(review): set.seed() seeds R's RNG only; seeding the Python/backend
# RNGs as well would need keras3::set_random_seed() once keras is available.
set.seed(123)
## ----load-packages------------------------------------------------------------
library(kerasnip)
library(tidymodels)
library(keras3)
## ----data-prep----------------------------------------------------------------
# Load CIFAR-10. Image arrays are indexed below as images[i, , , ], so they
# are assumed to be 4-D with one image per first-dimension slice; labels are
# indexed as labels[, 1], i.e. a one-column matrix.
cifar10 <- dataset_cifar10()

# Class levels come from the FULL label sets (the same levels factor() would
# derive from all labels). The small subsets below therefore keep every
# level even if a class happens to be absent from a subset; yardstick
# metrics need truth and predicted factors with identical levels.
class_levels <- sort(unique(c(cifar10$train$y[, 1], cifar10$test$y[, 1])))

# Build a tidymodels-friendly tibble from the first `n` images.
#   images: 4-D image array, one image per first-dimension slice.
#   labels: one-column label matrix aligned with `images`.
#   n:      number of leading images to keep (clamped to what is available).
#   levels: factor levels for the outcome column.
# Returns a tibble whose list-column `x` holds one image array per row, with
# pixels rescaled from [0, 255] to [0, 1], and whose `y` is the label factor.
# Subsetting happens BEFORE rescaling/splitting, so only the rows actually
# used are converted to double and copied into the list-column.
as_image_tibble <- function(images, labels, n = dim(images)[1],
                            levels = sort(unique(labels[, 1]))) {
  idx <- seq_len(min(n, dim(images)[1]))
  pixels <- images[idx, , , , drop = FALSE] / 255
  tibble::tibble(
    x = lapply(idx, function(i) pixels[i, , , , drop = TRUE]),
    y = factor(labels[idx, 1], levels = levels)
  )
}

# Use a smaller subset for faster vignette execution
train_df_small <- as_image_tibble(
  cifar10$train$x, cifar10$train$y,
  n = 500L, levels = class_levels
)
test_df_small <- as_image_tibble(
  cifar10$test$x, cifar10$test$y,
  n = 100L, levels = class_levels
)
## ----define-functional-blocks-------------------------------------------------
# Input block: kerasnip fills in `input_shape` automatically from the data,
# and the block simply returns the matching Keras input tensor.
input_block <- \(input_shape) layer_input(shape = input_shape)
# ResNet50 base block
resnet_base_block <- function(tensor) {
  # ImageNet-pretrained ResNet50 without its classification top, frozen so
  # that it acts purely as a feature extractor during fit().
  resnet_base <- application_resnet50(
    weights = "imagenet",
    include_top = FALSE
  )
  resnet_base$trainable <- FALSE
  # Call the frozen base in inference mode, as the Keras transfer-learning
  # guide recommends. Its BatchNormalization layers then keep using their
  # stored ImageNet statistics, which stays correct even if the base is
  # later unfrozen for fine-tuning.
  # NOTE(review): the vignette feeds pixels rescaled to [0, 1], whereas the
  # ImageNet ResNet50 weights expect inputs prepared by ResNet50's own
  # preprocess_input(); confirm whether that preprocessing should be added.
  resnet_base(tensor, training = FALSE)
}
# New classification head.
# flatten_block(): collapse the incoming feature tensor (the ResNet base's
# output, per the layer graph defined below) into one vector per example.
flatten_block <- function(tensor) {
  layer_flatten(tensor)
}
# Final dense layer emitting one softmax probability per class;
# `num_classes` is presumably supplied by kerasnip from the outcome levels.
output_block_functional <- function(tensor, num_classes) {
  layer_dense(
    tensor,
    units = num_classes,
    activation = "softmax"
  )
}
## ----create-functional-spec---------------------------------------------------
# Create the `resnet_transfer()` model-specification function used below.
# Each named entry of `layer_blocks` is one node of the functional graph;
# inp_spec() names the earlier node that feeds the block's tensor argument,
# which wires up: input -> resnet_base -> flatten -> output.
create_keras_functional_spec(
  mode = "classification",
  model_name = "resnet_transfer",
  layer_blocks = list(
    input = input_block,
    resnet_base = inp_spec(resnet_base_block, "input"),
    flatten = inp_spec(flatten_block, "resnet_base"),
    output = inp_spec(output_block_functional, "flatten")
  )
)
## ----fit-functional-model, cache=TRUE-----------------------------------------
# Model specification: 5 epochs, with 20% of the training rows held out for
# validation, using the "keras" engine.
spec_functional <- set_engine(
  resnet_transfer(fit_epochs = 5, fit_validation_split = 0.2),
  "keras"
)
# The recipe only declares roles: `y` is the outcome and the image
# list-column `x` is the single predictor.
rec_functional <- recipe(y ~ x, data = train_df_small)
wf_functional <- workflow() |>
  add_recipe(rec_functional) |>
  add_model(spec_functional)
fit_functional <- fit(wf_functional, data = train_df_small)
# Evaluate on the test set: attach predicted classes to the held-out rows
# and compute classification accuracy.
predictions <- predict(fit_functional, new_data = test_df_small)
predictions |>
  bind_cols(test_df_small) |>
  accuracy(truth = y, estimate = .pred_class)
## ----cleanup, include=FALSE---------------------------------------------------
# Remove the "resnet_transfer" model specification that
# create_keras_functional_spec() created earlier in this script.
remove_keras_spec("resnet_transfer")
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.