mlr_callback_set.unfreeze    R Documentation
Unfreezes selected weights (parameters of the network) after a given number of epochs or steps (batches).
Super class: mlr3torch::CallbackSet -> CallbackSetUnfreeze
new(): Creates a new instance of this R6 class.
CallbackSetUnfreeze$new(starting_weights, unfreeze)
starting_weights (Select)
A Select denoting the weights that are trainable from the start; all other weights start out frozen.
unfreeze (data.table)
A data.table with a list column weights (each entry a Select) and a column epoch or batch.
In each row, the Select indicates which parameters to unfreeze, and the epoch or batch value indicates when to unfreeze them.
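The description above allows a batch column in place of epoch. The following is a minimal sketch of such a schedule, reusing the parameter names from the example at the end of this page. The object name unfreeze_by_batch and the batch numbers 10 and 20 are hypothetical, and this page does not say whether the batch count restarts at every epoch or runs across the whole training.

```r
library(mlr3torch)   # provides select_name()
library(data.table)

# Hypothetical schedule keyed by batch instead of epoch:
# unfreeze "0.weight" at batch 10, then "3.weight" and "6.weight" at batch 20.
unfreeze_by_batch = data.table(
  batch = c(10, 20),
  # list column: one Select per row
  weights = list(select_name("0.weight"), select_name(c("3.weight", "6.weight")))
)
unfreeze_by_batch
```

Such a table would be passed as cb.unfreeze.unfreeze, in the same way as the epoch-based table in the example.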
on_begin(): Sets the starting weights.
CallbackSetUnfreeze$on_begin()
on_epoch_begin(): Unfreezes weights if training has reached an epoch listed in unfreeze.
CallbackSetUnfreeze$on_epoch_begin()
on_batch_begin(): Unfreezes weights if training has reached a batch listed in unfreeze.
CallbackSetUnfreeze$on_batch_begin()
clone(): The objects of this class are cloneable with this method.
CallbackSetUnfreeze$clone(deep = FALSE)
deep: Whether to make a deep clone.
Other Callback:
TorchCallback,
as_torch_callback(),
as_torch_callbacks(),
callback_set(),
mlr3torch_callbacks,
mlr_callback_set,
mlr_callback_set.checkpoint,
mlr_callback_set.progress,
mlr_callback_set.tb,
mlr_context_torch,
t_clbk(),
torch_callback()
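Examples:
library(mlr3torch)
library(data.table)

task = tsk("iris")
cb = t_clbk("unfreeze")
mlp = lrn("classif.mlp", callbacks = cb,
  # Every parameter except the four listed ones is trainable from the start.
  cb.unfreeze.starting_weights = select_invert(
    select_name(c("0.weight", "3.weight", "6.weight", "6.bias"))
  ),
  # Unfreeze "0.weight" at epoch 2, then "3.weight" and "6.weight" at epoch 5.
  cb.unfreeze.unfreeze = data.table(
    epoch = c(2, 5),
    weights = list(select_name("0.weight"), select_name(c("3.weight", "6.weight")))
  ),
  epochs = 6, batch_size = 150, neurons = c(1, 1, 1)
)
mlp$train(task)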
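The names passed to select_name() have to match the parameter names of the network. A minimal sketch for listing those names is to train the same MLP without the callback and inspect the fitted network. It assumes that the trained model exposes the torch module as $model$network and that the module's $parameters field is a named list; the object name mlp_plain is hypothetical.

```r
library(mlr3torch)

task = tsk("iris")
# Same architecture as in the example, without the unfreeze callback.
mlp_plain = lrn("classif.mlp", epochs = 1, batch_size = 150, neurons = c(1, 1, 1))
mlp_plain$train(task)

# Assumption: $model$network is the trained torch nn_module.
param_names = names(mlp_plain$model$network$parameters)
print(param_names)
```

Any name in param_names can then be used in select_name(), or matched by other Selects, when building starting_weights and the unfreeze table.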