mlr_context_torch | R Documentation
Context for training a torch learner.
This is the (mostly read-only) information that callbacks have access to through the argument ctx.
For more information on callbacks, see CallbackSet.
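A minimal sketch of how a callback reads the context. It assumes mlr3torch and a working torch installation, that stage functions created with torch_callback() reach the context as self$ctx, and that lrn("classif.mlp") accepts a callbacks argument. The callback id "info" and the object name cb_info are hypothetical.

```r
library(mlr3torch)

# Hypothetical callback: on_begin runs once before training starts.
# Inside a stage function, the context is available as self$ctx.
cb_info = torch_callback("info",
  on_begin = function() {
    cat("Training", self$ctx$learner$id, "for",
      self$ctx$total_epochs, "epochs\n")
  }
)

# epochs and batch_size must be set for torch learners.
learner = lrn("classif.mlp", epochs = 2, batch_size = 50,
  callbacks = cb_info)
learner$train(tsk("iris"))
```

The stage function only reads fields. Apart from a few fields such as terminate, the context is meant to be treated as read-only.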
learner (Learner)
The torch learner.
task_train (Task)
The training task.
task_valid (Task or NULL)
The validation task.
loader_train (torch::dataloader)
The data loader for training.
loader_valid (torch::dataloader or NULL)
The data loader for validation.
measures_train (list() of Measures)
Measures used for training.
measures_valid (list() of Measures)
Measures used for validation.
network (torch::nn_module)
The torch network.
optimizer (torch::optimizer)
The optimizer.
loss_fn (torch::nn_module)
The loss function.
total_epochs (integer(1))
The total number of epochs the learner is trained for.
last_scores_train (named list() or NULL)
The scores from the last training batch. Names are the ids of the training measures.
If LearnerTorch sets eval_freq to a value other than 1, this is NULL in all epochs that don't evaluate the model.
last_scores_valid (named list() or NULL)
The scores from the last validation batch. Names are the ids of the validation measures.
If LearnerTorch sets eval_freq to a value other than 1, this is NULL in all epochs that don't evaluate the model.
last_loss (numeric(1))
The loss from the last training batch.
epoch (integer(1))
The current epoch.
step (integer(1))
The current iteration.
prediction_encoder (function())
The learner's prediction encoder.
batch (named list() of torch_tensors)
The current batch.
terminate (logical(1))
If this field is set to TRUE at the end of an epoch, training stops.
device (torch::torch_device)
The device.
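A minimal sketch that uses last_loss, epoch and terminate together. It assumes the same mlr3torch setup as above, that on_batch_end runs after every training batch and on_epoch_end after every epoch, and that the training loop checks terminate at the end of each epoch as described for that field. The callback id "stopper" and the objects seen and cb_stop are hypothetical.

```r
library(mlr3torch)

# Hypothetical environment that collects values written by the callback.
seen = new.env()

cb_stop = torch_callback("stopper",
  on_batch_end = function() {
    # last_loss holds the loss of the training batch just processed.
    seen$losses = c(seen$losses, self$ctx$last_loss)
  },
  on_epoch_end = function() {
    seen$last_epoch = self$ctx$epoch
    # Setting terminate to TRUE at the end of an epoch stops training.
    if (self$ctx$epoch >= 2L) self$ctx$terminate = TRUE
  }
)

# iris has 150 rows, so a batch size of 50 gives 3 batches per epoch.
learner = lrn("classif.mlp", epochs = 10, batch_size = 50,
  callbacks = cb_stop)
learner$train(tsk("iris"))

seen$last_epoch     # epoch after which training stopped
length(seen$losses) # number of training batches processed
```

terminate is one of the few fields a callback is expected to write. The other fields are read from self$ctx.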
new()
Creates a new instance of this R6 class.
ContextTorch$new(
  learner,
  task_train,
  task_valid = NULL,
  loader_train,
  loader_valid = NULL,
  measures_train = NULL,
  measures_valid = NULL,
  network,
  optimizer,
  loss_fn,
  total_epochs,
  prediction_encoder,
  eval_freq = 1L,
  device
)
learner (Learner)
The torch learner.
task_train (Task)
The training task.
task_valid (Task or NULL)
The validation task.
loader_train (torch::dataloader)
The data loader for training.
loader_valid (torch::dataloader or NULL)
The data loader for validation.
measures_train (list() of Measures or NULL)
Measures used for training. Default is NULL.
measures_valid (list() of Measures or NULL)
Measures used for validation. Default is NULL.
network (torch::nn_module)
The torch network.
optimizer (torch::optimizer)
The optimizer.
loss_fn (torch::nn_module)
The loss function.
total_epochs (integer(1))
The total number of epochs the learner is trained for.
prediction_encoder (function())
The learner's prediction encoder.
eval_freq (integer(1))
The evaluation frequency in epochs, i.e. how often the measures are evaluated. Default is 1L.
device (character(1))
The device.
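A minimal sketch of how eval_freq affects the scores that callbacks see. It assumes that the torch learner exposes eval_freq and measures_train as hyperparameters and that last_scores_train is already set when on_epoch_end runs. The callback id "scores" and the objects seen and cb_scores are hypothetical.

```r
library(mlr3torch)

# Hypothetical environment that collects one flag per epoch.
seen = new.env()

cb_scores = torch_callback("scores",
  on_epoch_end = function() {
    s = self$ctx$last_scores_train
    # NULL means that no evaluation took place in this epoch.
    seen$evaluated = c(seen$evaluated, !is.null(s))
    if (!is.null(s)) seen$names = names(s)
  }
)

# Measures are evaluated only every eval_freq epochs.
learner = lrn("classif.mlp", epochs = 4, batch_size = 50,
  eval_freq = 2, measures_train = msrs("classif.acc"),
  callbacks = cb_scores)
learner$train(tsk("iris"))

seen$evaluated # one logical per epoch
seen$names     # ids of the training measures, if any epoch was evaluated
```

Checking for NULL before using the scores keeps a callback correct for any eval_freq setting.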
clone()
The objects of this class are cloneable with this method.
ContextTorch$clone(deep = FALSE)
deep
Whether to make a deep clone.
Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_callback_set.tb, mlr_callback_set.unfreeze, t_clbk(), torch_callback()