Keras 3 is a deep learning framework that works with TensorFlow, JAX, and PyTorch interchangeably. This notebook will walk you through key Keras 3 workflows.
Let's start by installing Keras 3:
# Install the keras3 R package, then run its one-time Python-side setup.
install.packages("keras3")
keras3::install_keras()
We're going to be using the TensorFlow backend here, but you can edit the string below to `"jax"` or `"torch"`, restart the runtime (your R session), and the whole notebook will run just the same!
This entire guide is backend-agnostic.
# `shape()` and `set_random_seed()` are excluded from tensorflow so they
# do not mask the keras3 functions of the same name.
library(tensorflow, exclude = c("shape", "set_random_seed"))
library(keras3)

# Note that you must configure the backend
# before calling any other keras functions.
# The backend cannot be changed once the
# package is imported.
use_backend("tensorflow")
Let's start with the Hello World of ML: training a convnet to classify MNIST digits.
Here's the data:
# Load the data and split it between train and test sets.
# `%<-%` destructures the nested list returned by dataset_mnist().
c(c(x_train, y_train), c(x_test, y_test)) %<-% keras3::dataset_mnist()

# Scale images to the [0, 1] range
x_train <- x_train / 255
x_test <- x_test / 255

# Make sure images have shape (28, 28, 1).
# array_reshape() uses row-major (C-style) ordering, unlike `dim<-`.
x_train <- array_reshape(x_train, c(-1, 28, 28, 1))
x_test <- array_reshape(x_test, c(-1, 28, 28, 1))

dim(x_train)
## [1] 60000 28 28 1
# Confirm the test images were reshaped the same way as the training set.
x_test |> dim()
## [1] 10000 28 28 1
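Here's our model.

Different model-building options that Keras offers include:

- The Sequential API (what we use below)
- The Functional API (most typical)
- Writing your own models yourself via subclassing (for advanced use cases)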
# Model parameters
num_classes <- 10
input_shape <- c(28, 28, 1)

# Layers piped onto a Sequential model are appended to it in place,
# so the result of the pipe does not need to be reassigned.
model <- keras_model_sequential(input_shape = input_shape)
model |>
  layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") |>
  layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") |>
  layer_max_pooling_2d(pool_size = c(2, 2)) |>
  layer_conv_2d(filters = 128, kernel_size = c(3, 3), activation = "relu") |>
  layer_conv_2d(filters = 128, kernel_size = c(3, 3), activation = "relu") |>
  layer_global_average_pooling_2d() |>
  layer_dropout(rate = 0.5) |>
  layer_dense(units = num_classes, activation = "softmax")
Here's our model summary:
# Print a layer-by-layer overview: output shapes and parameter counts.
model |> summary()
## Model: "sequential"
## ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
## ┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
## ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
## │ conv2d (Conv2D)                 │ (None, 26, 26, 64)     │           640 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ conv2d_1 (Conv2D)               │ (None, 24, 24, 64)     │        36,928 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ max_pooling2d (MaxPooling2D)    │ (None, 12, 12, 64)     │             0 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ conv2d_2 (Conv2D)               │ (None, 10, 10, 128)    │        73,856 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ conv2d_3 (Conv2D)               │ (None, 8, 8, 128)      │       147,584 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ global_average_pooling2d        │ (None, 128)            │             0 │
## │ (GlobalAveragePooling2D)        │                        │               │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dropout (Dropout)               │ (None, 128)            │             0 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dense (Dense)                   │ (None, 10)             │         1,290 │
## └─────────────────────────────────┴────────────────────────┴───────────────┘
##  Total params: 260,298 (1016.79 KB)
##  Trainable params: 260,298 (1016.79 KB)
##  Non-trainable params: 0 (0.00 B)
We use `compile()` to specify the optimizer, the loss function, and the metrics to monitor. Note that with the JAX and TensorFlow backends, XLA compilation is turned on by default.
# Configure training: optimizer, loss, and the metrics to report.
# The "sparse" loss/metric variants take integer class labels
# rather than one-hot vectors.
compile(
  model,
  optimizer = "adam",
  loss = "sparse_categorical_crossentropy",
  metrics = list(metric_sparse_categorical_accuracy(name = "acc"))
)
Let's train and evaluate the model. We'll set aside a validation split of 15% of the data during training to monitor generalization on unseen data.
batch_size <- 128
epochs <- 10

callbacks <- list(
  # Save a model snapshot at the end of every epoch.
  callback_model_checkpoint(filepath = "model_at_epoch_{epoch}.keras"),
  # Stop once validation loss has not improved for 2 consecutive epochs.
  callback_early_stopping(monitor = "val_loss", patience = 2)
)

# Hold out the last 15% of the training data to monitor generalization.
model |> fit(
  x_train, y_train,
  batch_size = batch_size,
  epochs = epochs,
  validation_split = 0.15,
  callbacks = callbacks
)
## Epoch 1/10 ## 399/399 - 7s - 18ms/step - acc: 0.7496 - loss: 0.7385 - val_acc: 0.9641 - val_loss: 0.1228 ## Epoch 2/10 ## 399/399 - 3s - 7ms/step - acc: 0.9382 - loss: 0.2037 - val_acc: 0.9769 - val_loss: 0.0774 ## Epoch 3/10 ## 399/399 - 3s - 7ms/step - acc: 0.9567 - loss: 0.1458 - val_acc: 0.9816 - val_loss: 0.0636 ## Epoch 4/10 ## 399/399 - 3s - 7ms/step - acc: 0.9658 - loss: 0.1163 - val_acc: 0.9866 - val_loss: 0.0468 ## Epoch 5/10 ## 399/399 - 3s - 9ms/step - acc: 0.9719 - loss: 0.0975 - val_acc: 0.9880 - val_loss: 0.0433 ## Epoch 6/10 ## 399/399 - 3s - 8ms/step - acc: 0.9758 - loss: 0.0853 - val_acc: 0.9874 - val_loss: 0.0413 ## Epoch 7/10 ## 399/399 - 3s - 7ms/step - acc: 0.9765 - loss: 0.0782 - val_acc: 0.9891 - val_loss: 0.0398 ## Epoch 8/10 ## 399/399 - 3s - 8ms/step - acc: 0.9797 - loss: 0.0678 - val_acc: 0.9881 - val_loss: 0.0419 ## Epoch 9/10 ## 399/399 - 3s - 7ms/step - acc: 0.9805 - loss: 0.0652 - val_acc: 0.9897 - val_loss: 0.0381 ## Epoch 10/10 ## 399/399 - 3s - 8ms/step - acc: 0.9831 - loss: 0.0576 - val_acc: 0.9912 - val_loss: 0.0340
# Evaluate on the held-out test set. Print the resulting loss/metric
# values; previously they were computed and then silently discarded.
score <- model |> evaluate(x_test, y_test, verbose = 0)
print(score)
During training, we were saving a model at the end of each epoch. You can also save the model in its latest state like this:
# Save the model in its latest state, replacing any existing file.
model |>
  save_model("final_model.keras", overwrite = TRUE)
And reload it like this:
# Restore the model from final_model.keras.
model <- "final_model.keras" |>
  load_model()
Next, you can query predictions of class probabilities with `predict()`:
# Class probabilities: one row per test image, one column per class.
predictions <- predict(model, x_test)
## 313/313 - 1s - 2ms/step
# Expect (number of test images) x (number of classes).
predictions |> dim()
## [1] 10000 10
That's it for the basics!
Keras enables you to write custom Layers, Models, Metrics, Losses, and Optimizers that work across TensorFlow, JAX, and PyTorch with the same codebase. Let's take a look at custom layers first.
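The `op_` namespace contains:

- An implementation of the NumPy API, e.g. `op_stack()` or `op_matmul()`.
- A set of neural-network-specific ops that are absent from NumPy, such as `op_conv()` or `op_binary_crossentropy()`.

Let's make a custom `Dense` layer that works with all backends: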
# A minimal backend-agnostic dense layer: computes
# activation(op_matmul(inputs, w) + b) using only keras3 `op_*` functions.
layer_my_dense <- Layer(
  classname = "MyDense",

  # units: number of output units.
  # activation: optional function applied to the output
  #   (e.g. `activation_softmax`); NULL means no activation.
  # ...: forwarded to the base Layer constructor.
  initialize = function(units, activation = NULL, name = NULL, ...) {
    super$initialize(name = name, ...)
    self$units <- units
    self$activation <- activation
  },

  # Create the weights once the size of the input's last axis is known.
  build = function(input_shape) {
    input_dim <- tail(input_shape, 1)
    self$w <- self$add_weight(
      shape = shape(input_dim, self$units),
      initializer = initializer_glorot_normal(),
      name = "kernel",
      trainable = TRUE
    )
    self$b <- self$add_weight(
      shape = shape(self$units),
      initializer = initializer_zeros(),
      name = "bias",
      trainable = TRUE
    )
  },

  call = function(inputs) {
    # Use Keras ops to create backend-agnostic layers/metrics/etc.
    x <- op_matmul(inputs, self$w) + self$b
    if (!is.null(self$activation)) {
      x <- self$activation(x)
    }
    x
  }
)
Next, let's make a custom `Dropout` layer that relies on the `random_*` namespace:
# A backend-agnostic dropout layer built on keras3's `random_dropout()`.
layer_my_dropout <- Layer(
  "MyDropout",

  # rate: fraction of input units to drop.
  # seed: optional seed passed to the layer's random seed generator.
  # ...: forwarded to the base Layer constructor. Previously it was
  #   accepted but silently dropped; layer_my_dense forwards it too.
  initialize = function(rate, name = NULL, seed = NULL, ...) {
    super$initialize(name = name, ...)
    self$rate <- rate
    # Use seed_generator for managing RNG state.
    # It is a state element and its seed variable is
    # tracked as part of `layer$variables`.
    self$seed_generator <- random_seed_generator(seed)
  },

  # NOTE(review): dropout is applied on every call, including predict()
  # and evaluate(). Gating it on a `training` argument only works if the
  # enclosing model's call() also accepts `training`. Confirm before changing.
  call = function(inputs) {
    # Use `keras3::random_*` for random ops.
    random_dropout(inputs, self$rate, seed = self$seed_generator)
  }
)
Next, let's write a custom subclassed model that uses our two custom layers:
# A subclassed model: a convolutional feature extractor followed by
# the custom `layer_my_dropout()` and `layer_my_dense()` layers.
MyModel <- Model(
  "MyModel",

  # num_classes: number of output units of the final softmax layer.
  # ...: forwarded to the base Model constructor.
  initialize = function(num_classes, ...) {
    super$initialize(...)
    # No input shape is given, so this Sequential is built on first call.
    self$conv_base <- keras_model_sequential() |>
      layer_conv_2d(64, kernel_size = c(3, 3), activation = "relu") |>
      layer_conv_2d(64, kernel_size = c(3, 3), activation = "relu") |>
      layer_max_pooling_2d(pool_size = c(2, 2)) |>
      layer_conv_2d(128, kernel_size = c(3, 3), activation = "relu") |>
      layer_conv_2d(128, kernel_size = c(3, 3), activation = "relu") |>
      layer_global_average_pooling_2d()
    self$dp <- layer_my_dropout(rate = 0.5)
    self$dense <- layer_my_dense(
      units = num_classes,
      activation = activation_softmax
    )
  },

  call = function(inputs) {
    inputs |>
      self$conv_base() |>
      self$dp() |>
      self$dense()
  }
)
Let's compile it and fit it:
model <- MyModel(num_classes = 10)

model |> compile(
  loss = loss_sparse_categorical_crossentropy(),
  optimizer = optimizer_adam(learning_rate = 1e-3),
  metrics = list(
    metric_sparse_categorical_accuracy(name = "acc")
  )
)

model |> fit(
  x_train, y_train,
  batch_size = batch_size,
  epochs = 1, # For speed
  validation_split = 0.15
)
## 399/399 - 7s - 18ms/step - acc: 0.7344 - loss: 0.7749 - val_acc: 0.9259 - val_loss: 0.2411
All Keras models can be trained and evaluated on a wide variety of data sources, independently of the backend you're using. This includes:

- `tf_dataset` objects
- PyTorch `DataLoader` objects
- Keras `PyDataset` objects

They all work whether you're using TensorFlow, JAX, or PyTorch as your Keras backend.

Let's try this out with `tf_dataset`:
# Exclude `shape` so it does not mask keras3::shape().
library(tfdatasets, exclude = "shape")

# fit() ignores its `shuffle` argument when given a dataset, so shuffle
# the training examples here. By default the order is reshuffled each epoch.
train_dataset <- list(x_train, y_train) |>
  tensor_slices_dataset() |>
  dataset_shuffle(buffer_size = 10000) |>
  dataset_batch(batch_size) |>
  dataset_prefetch(buffer_size = tf$data$AUTOTUNE)

test_dataset <- list(x_test, y_test) |>
  tensor_slices_dataset() |>
  dataset_batch(batch_size) |>
  dataset_prefetch(buffer_size = tf$data$AUTOTUNE)

model <- MyModel(num_classes = 10)

model |> compile(
  loss = loss_sparse_categorical_crossentropy(),
  optimizer = optimizer_adam(learning_rate = 1e-3),
  metrics = list(
    metric_sparse_categorical_accuracy(name = "acc")
  )
)

model |> fit(train_dataset, epochs = 1, validation_data = test_dataset)
## 469/469 - 8s - 17ms/step - acc: 0.7493 - loss: 0.7476 - val_acc: 0.9123 - val_loss: 0.2965
This concludes our short overview of the new multi-backend capabilities of Keras 3. Next, you can learn about:

How to customize what happens in `fit()`

Want to implement a non-standard training algorithm yourself but still want to benefit from the power and usability of `fit()`? It's easy to customize `fit()` to support arbitrary use cases:
Enjoy the library! 🚀
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.