Nothing
#' Freeze and unfreeze weights
#'
#' Freeze weights in a model or layer so that they are no longer trainable.
#'
#' @param object Keras model or layer object
#' @param from Layer instance, layer name, or layer index within model.
#'   Negative indices count from the end (`-1` is the last layer). `NULL` (the
#'   default) means the first layer. If both `from` and `to` are `NULL`, the
#'   whole `object` is frozen/unfrozen.
#' @param to Layer instance, layer name, or layer index within model.
#'   Negative indices count from the end. `NULL` (the default) means the last
#'   layer.
#' @param which layer names, integer positions (negative values count from the
#'   end), layers, logical vector (of `length(object$layers)`), or a function
#'   (or formula) returning a logical vector. Supply either `which` or
#'   `from`/`to`, not both.
#'
#' @note The `from` and `to` layer arguments are both inclusive.
#'
#' When applied to a model, the freeze or unfreeze is a global operation over
#' all layers in the model (i.e. layers not within the specified range will be
#' set to the opposite value, e.g. unfrozen for a call to freeze).
#'
#' Models must be compiled again after weights are frozen or unfrozen.
#'
#' @details
#' # Examples
#'
#' ```{r, strip.white = FALSE}
#' # instantiate a VGG16 model
#' conv_base <- application_vgg16(
#' weights = "imagenet",
#' include_top = FALSE,
#' input_shape = c(150, 150, 3)
#' )
#'
#' # freeze its weights
#' freeze_weights(conv_base)
#'
#' # Note the "Trainable" column
#' conv_base
#'
#' # create a composite model that includes the base + more layers
#' model <- keras_model_sequential(input_batch_shape = shape(conv_base$input)) |>
#' conv_base() |>
#' layer_flatten() |>
#' layer_dense(units = 256, activation = "relu") |>
#' layer_dense(units = 1, activation = "sigmoid")
#'
#' # compile
#' model |> compile(
#' loss = "binary_crossentropy",
#' optimizer = optimizer_rmsprop(learning_rate = 2e-5),
#' metrics = c("accuracy")
#' )
#'
#' model
#' print(model, expand_nested = TRUE)
#'
#'
#' # unfreeze weights from "block5_conv1" on
#' unfreeze_weights(conv_base, from = "block5_conv1")
#'
#' # compile again since we froze or unfroze weights
#' model |> compile(
#' loss = "binary_crossentropy",
#' optimizer = optimizer_rmsprop(learning_rate = 2e-5),
#' metrics = c("accuracy")
#' )
#'
#' conv_base
#' print(model, expand_nested = TRUE)
#'
#' # freeze only the last 5 layers
#' freeze_weights(conv_base, from = -5)
#' conv_base
#' # freeze only the last 5 layers, a different way
#' unfreeze_weights(conv_base, to = -6)
#' conv_base
#'
#' # Freeze only layers of a certain type, e.g., BatchNorm layers
#' batch_norm_layer_class_name <- class(layer_batch_normalization())[1]
#' is_batch_norm_layer <- function(x) inherits(x, batch_norm_layer_class_name)
#'
#' model <- application_efficientnet_b0()
#' freeze_weights(model, which = is_batch_norm_layer)
#' # print(model)
#'
#' # equivalent to:
#' for(layer in model$layers) {
#' if(is_batch_norm_layer(layer))
#' layer$trainable <- FALSE
#' else
#' layer$trainable <- TRUE
#' }
#' ```
#' @returns The input `object` with frozen (or, for `unfreeze_weights()`,
#' unfrozen) weights is returned, invisibly. Note, `object` is modified in
#' place, and the return value is only provided to make usage with the pipe
#' convenient.
#' @export
freeze_weights <- function(object, from = NULL, to = NULL, which = NULL) {
  # `which` is an alternative to a `from`/`to` range; combining it with either
  # endpoint is ambiguous, so reject it before anything is modified.
  if (!is.null(which)) {
    if (!is.null(from) || !is.null(to))
      stop("both `which` and `from`/`to` can not be supplied")
    return(apply_which_trainable(object, which, FALSE))
  }
  if (is.null(from) && is.null(to)) {
    # no range requested: freeze the whole object through its own flag
    object$trainable <- FALSE
  } else {
    # keep the object trainable so layers outside the range can end up
    # unfrozen; apply_trainable() then sets every layer's flag individually
    object$trainable <- TRUE
    apply_trainable(object, from, to, FALSE)
  }
  # return the object invisibly (for chaining)
  invisible(object)
}
#' @rdname freeze_weights
#' @export
unfreeze_weights <- function(object, from = NULL, to = NULL, which = NULL) {
  # `which` is an alternative to a `from`/`to` range; combining it with either
  # endpoint is ambiguous. Check this before `object` is modified.
  if (!is.null(which)) {
    if (!is.null(from) || !is.null(to))
      stop("both `which` and `from`/`to` can not be supplied")
    return(apply_which_trainable(object, which, TRUE))
  }
  # object always trainable after unfreeze
  object$trainable <- TRUE
  # with a range, layers inside it are unfrozen and layers outside are frozen
  if (!is.null(from) || !is.null(to))
    apply_trainable(object, from, to, TRUE)
  # return the object invisibly (for chaining)
  invisible(object)
}
# Set `trainable` on every layer of `object`: layers from `from` to `to`
# (inclusive) receive `trainable`, all other layers receive `!trainable`.
#
# `from` / `to` may each be a layer instance, a layer name, or a 1-based index
# (negative values count from the end); `NULL` means the first / last layer.
# Errors if `object` has no layers, if an endpoint does not identify one of
# `object$layers`, or if `from` comes after `to`. Returns `object` invisibly.
apply_trainable <- function(object, from, to, trainable) {
  layers <- object$layers
  n_layers <- length(layers)
  if (n_layers == 0)
    stop("`object` has no layers to freeze or unfreeze")
  layer_names <- vapply(layers, function(l) l$name, "", USE.NAMES = FALSE)

  # Map a layer instance, name, or (possibly negative) index to a 1-based
  # position in `layers`; NULL maps to `default`.
  resolve_position <- function(x, default, arg) {
    if (is.null(x))
      return(default)
    if (is_layer(x))
      x <- x$name
    pos <- if (is.numeric(x)) {
      # truncate fractional indices the same way `[[` does
      as.integer(if (x < 0) n_layers + x + 1 else x)
    } else {
      match(x, layer_names)
    }
    # an unknown name/layer used to silently flip every layer; fail instead
    if (length(pos) != 1 || is.na(pos) || pos < 1 || pos > n_layers)
      stop("`", arg, "` does not identify a layer in `object`")
    pos
  }

  from_pos <- resolve_position(from, 1L, "from")
  to_pos <- resolve_position(to, n_layers, "to")
  if (from_pos > to_pos)
    stop("`from` must not come after `to`")

  # apply trainable property: inside the range gets `trainable`, outside the
  # opposite value
  for (i in seq_len(n_layers)) {
    layer <- layers[[i]]
    in_range <- i >= from_pos && i <= to_pos
    layer$trainable <- if (in_range) trainable else !trainable
  }
  invisible(object)
}
# Set `trainable` on the layers of `object` selected by `which`, and
# `!trainable` on all remaining layers; the object itself is left trainable.
#
# `which` may be a character vector of layer names, a numeric vector of
# 1-based positions (negative values count from the end), a single layer or a
# list of layers, a logical vector with one element per layer, or a function /
# formula (converted with rlang::as_function()) that returns TRUE/FALSE for
# each layer. Every selector is resolved before any layer's flag is rewritten,
# so an invalid selector errors without a partial update. Returns `object`
# invisibly.
apply_which_trainable <- function(object, which, trainable) {
  # presumably, since user is being selective, some parts of the object are
  # still trainable
  object$trainable <- TRUE
  layers <- object$layers
  n_layers <- length(layers)
  names(layers) <- vapply(layers, function(l) l$name, "", USE.NAMES = FALSE)

  if (inherits(which, "formula"))
    which <- rlang::as_function(which)
  if (is.function(which))
    which <- vapply(layers, which, TRUE, USE.NAMES = FALSE)
  if (is.logical(which)) {
    # a mask of any other length would silently select the wrong layers
    if (length(which) != n_layers)
      stop(
        "a logical `which` must have one element per layer (", n_layers,
        "), not ", length(which)
      )
    which <- base::which(which)
  }
  # a lone layer instance is a single selection, not a sequence to iterate
  if (is_layer(which))
    which <- list(which)

  # Map one selector to the layer it designates, erroring on anything that
  # does not identify a layer of `object`.
  resolve_layer <- function(i) {
    if (is_layer(i))
      return(i)
    if (is.character(i)) {
      pos <- match(i, names(layers))
    } else if (is.numeric(i)) {
      # truncate fractional indices the same way `[[` does
      pos <- as.integer(if (i < 0) n_layers + i + 1 else i)
    } else {
      stop(
        "`which` must be: layer names, index position, layers, ",
        "logical vector, or a function returning a logical vector"
      )
    }
    if (length(pos) != 1 || is.na(pos) || pos < 1 || pos > n_layers)
      stop("`which` does not identify a layer in `object`: ", format(i))
    layers[[pos]]
  }
  selected <- lapply(which, resolve_layer)

  # invert all the layers, then set just the flag for the requested layers
  for (layer in layers)
    layer$trainable <- !trainable
  for (layer in selected)
    layer$trainable <- trainable
  invisible(object)
}
# TRUE when `object`'s class vector includes the Keras base layer class
# "keras.src.layers.layer.Layer"; FALSE otherwise (including for NULL and
# plain R values).
is_layer <- function(object) {
  base_layer_class <- "keras.src.layers.layer.Layer"
  isTRUE(inherits(object, what = base_layer_class))
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.