mlr_callback_set.unfreeze
Unfreezes selected weights (parameters of the network) after a specified number of epochs or batches (steps).
Super class: mlr3torch::CallbackSet -> CallbackSetUnfreeze
new()
Creates a new instance of this R6 class.
CallbackSetUnfreeze$new(starting_weights, unfreeze)
starting_weights (Select)
A Select denoting the weights that are trainable from the start.
unfreeze (data.table)
A data.table with a column weights (a list column of Selects) and a column epoch or batch. The selector in each row indicates which parameters to unfreeze, while the epoch or batch value in that row indicates when to do so.
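To make the row-wise pairing of steps and selectors concrete, here is a minimal base-R sketch of how such a schedule can be read. It does not use mlr3torch: `sel_name()` is a hypothetical stand-in that models a selector as a function from all parameter names to the selected subset, `due_at()` is a hypothetical helper, and the parameter names are an assumed subset taken from the example further below.

```r
# Hypothetical selector stand-in (not the mlr3torch select_name()): a
# selector maps the vector of all parameter names to the selected subset.
sel_name = function(wanted) {
  force(wanted)
  function(param_names) intersect(param_names, wanted)
}

# Same shape as the `unfreeze` argument: one row per event, with a step
# column (`epoch` here) and a list column of selectors called `weights`.
schedule = data.frame(epoch = c(2, 5))
schedule$weights = list(sel_name("0.weight"), sel_name(c("3.weight", "6.weight")))

# Names of the parameters scheduled for unfreezing at step `t`;
# returns character(0) when no row matches.
due_at = function(schedule, t, param_names, step_col = "epoch") {
  rows = which(schedule[[step_col]] == t)
  selected = lapply(schedule$weights[rows], function(select) select(param_names))
  unique(as.character(unlist(selected)))
}

params = c("0.weight", "0.bias", "3.weight", "3.bias", "6.weight", "6.bias")
due_at(schedule, 5, params)  # returns c("3.weight", "6.weight")
```

A batch-based table would have a `batch` column instead of `epoch`, and the same lookup would be called with `step_col = "batch"`.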
on_begin()
Sets the starting weights: only the parameters selected by starting_weights are trainable when training begins.
CallbackSetUnfreeze$on_begin()
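The effect described for on_begin() can be modelled with a named logical vector that plays the role of each parameter's requires_grad flag. This is a sketch under stated assumptions, not the callback's implementation: `sel_name()`, `sel_invert()` and `initial_trainable()` are hypothetical stand-ins for the real Select objects and internals, and no torch tensors are involved.

```r
# Hypothetical selector stand-ins (not the mlr3torch API): selectors are
# modelled as functions from parameter names to the selected subset.
sel_name = function(wanted) {
  force(wanted)
  function(param_names) intersect(param_names, wanted)
}
sel_invert = function(select) {
  force(select)
  function(param_names) setdiff(param_names, select(param_names))
}

# Model of on_begin(): TRUE marks a trainable parameter; only the
# parameters chosen by `starting_weights` start out as TRUE.
initial_trainable = function(param_names, starting_weights) {
  stats::setNames(param_names %in% starting_weights(param_names), param_names)
}

params = c("0.weight", "0.bias", "3.weight", "3.bias", "6.weight", "6.bias")
start = sel_invert(sel_name(c("0.weight", "3.weight", "6.weight", "6.bias")))
initial_trainable(params, start)
```

With an inverted selection, every parameter not explicitly listed (here "0.bias" and "3.bias") is trainable from the start.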
on_epoch_begin()
Unfreezes the weights scheduled for the current epoch, if the unfreeze table has an entry for it.
CallbackSetUnfreeze$on_epoch_begin()
on_batch_begin()
Unfreezes the weights scheduled for the current batch, if the unfreeze table has an entry for it.
CallbackSetUnfreeze$on_batch_begin()
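Both hooks follow the same pattern: at the start of a step, look up the matching rows and mark the selected parameters as trainable. The base-R sketch below models that with a batch-based schedule. `sel_name()` and `unfreeze_step()` are hypothetical, the named logical vector stands in for requires_grad, and numbering batches 1, 2, 3, ... is an assumption of the sketch, not a statement about how mlr3torch counts batches.

```r
# Hypothetical selector stand-in (not the mlr3torch select_name()).
sel_name = function(wanted) {
  force(wanted)
  function(param_names) intersect(param_names, wanted)
}

# Model of on_epoch_begin() / on_batch_begin(): at step `t`, set the flag of
# every parameter chosen by a schedule row whose step value equals `t`.
unfreeze_step = function(trainable, schedule, t, step_col) {
  for (i in which(schedule[[step_col]] == t)) {
    trainable[schedule$weights[[i]](names(trainable))] = TRUE
  }
  trainable
}

# A batch-based schedule: unfreeze at batches 10 and 20.
schedule = data.frame(batch = c(10, 20))
schedule$weights = list(sel_name("0.weight"), sel_name(c("3.weight", "6.weight")))

params = c("0.weight", "0.bias", "3.weight", "3.bias", "6.weight", "6.bias")
trainable = stats::setNames(params %in% c("0.bias", "3.bias"), params)
for (b in 1:25) {
  trainable = unfreeze_step(trainable, schedule, b, step_col = "batch")
}
trainable
```

Unfreezing only ever sets flags to TRUE, so a parameter that no row selects (here "6.bias") stays frozen for the whole run.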
clone()
The objects of this class are cloneable with this method.
CallbackSetUnfreeze$clone(deep = FALSE)
deep
Whether to make a deep clone.
Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_callback_set.tb, mlr_context_torch, t_clbk(), torch_callback()
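library(mlr3)
library(mlr3torch)
library(data.table)

task = tsk("iris")
cb = t_clbk("unfreeze")
mlp = lrn("classif.mlp", callbacks = cb,
  # Freeze these four parameters at the start; all others are trainable.
  cb.unfreeze.starting_weights = select_invert(
    select_name(c("0.weight", "3.weight", "6.weight", "6.bias"))
  ),
  # Unfreeze "0.weight" at epoch 2 and "3.weight" and "6.weight" at epoch 5;
  # "6.bias" is never unfrozen.
  cb.unfreeze.unfreeze = data.table(
    epoch = c(2, 5),
    weights = list(select_name("0.weight"), select_name(c("3.weight", "6.weight")))
  ),
  epochs = 6, batch_size = 150, neurons = c(1, 1, 1)
)
mlp$train(task)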
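To see which parameters the example trains in each of its 6 epochs, the schedule can be traced without torch. This base-R sketch uses hypothetical stand-ins `sel_name()` and `sel_invert()` for the real selectors, assumes unfreezing takes effect at the beginning of the listed epoch (as on_epoch_begin() suggests) with epochs numbered from 1, and uses only the six parameter names that appear in the example. The real MLP may have further parameters; because starting_weights is an inverted selection, those would be trainable from the start.

```r
# Hypothetical selector stand-ins (not the mlr3torch functions).
sel_name = function(wanted) {
  force(wanted)
  function(param_names) intersect(param_names, wanted)
}
sel_invert = function(select) {
  force(select)
  function(param_names) setdiff(param_names, select(param_names))
}

# Assumed subset of the MLP's parameter names, taken from the example.
params = c("0.weight", "0.bias", "3.weight", "3.bias", "6.weight", "6.bias")
starting_weights = sel_invert(sel_name(c("0.weight", "3.weight", "6.weight", "6.bias")))
unfreeze = data.frame(epoch = c(2, 5))
unfreeze$weights = list(sel_name("0.weight"), sel_name(c("3.weight", "6.weight")))

# Start of training: only the starting weights are trainable.
trainable = starting_weights(params)
history = vector("list", 6)
for (epoch in 1:6) {
  # Start of each epoch: add the parameters scheduled for this epoch.
  for (i in which(unfreeze$epoch == epoch)) {
    trainable = union(trainable, unfreeze$weights[[i]](params))
  }
  history[[epoch]] = trainable
}
history
```

Under these assumptions, epoch 1 trains only the two biases, "0.weight" joins from epoch 2, "3.weight" and "6.weight" join from epoch 5, and "6.bias" is never trained.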