| mlr_context_torch | R Documentation |
Context for training a torch learner.
This is the (mostly read-only) information that callbacks can access through ctx.
For more information on callbacks, see CallbackSet.
Public fields
learner (Learner)
The torch learner.
task_train (Task)
The training task.
task_valid (Task or NULL)
The validation task.
loader_train (torch::dataloader)
The data loader for training.
loader_valid (torch::dataloader or NULL)
The data loader for validation.
measures_train (list() of Measures)
Measures used for training.
measures_valid (list() of Measures)
Measures used for validation.
network (torch::nn_module)
The torch network.
optimizer (torch::optimizer)
The optimizer.
loss_fn (torch::nn_module)
The loss function.
total_epochs (integer(1))
The total number of epochs the learner is trained for.
last_scores_train (named list() or NULL)
The scores from the last training batch. Names are the ids of the training measures.
If the eval_freq of the LearnerTorch is different from 1, this is NULL in all epochs
that do not evaluate the model.
last_scores_valid (named list() or NULL)
The scores from the last validation batch. Names are the ids of the validation measures.
If the eval_freq of the LearnerTorch is different from 1, this is NULL in all epochs
that do not evaluate the model.
last_loss (numeric(1))
The loss from the last training batch.
y_hat (torch_tensor)
The model's prediction for the current batch.
epoch (integer(1))
The current epoch.
step (integer(1))
The current iteration.
prediction_encoder (function())
The learner's prediction encoder.
batch (named list() of torch_tensors)
The current batch.
terminate (logical(1))
If this field is set to TRUE at the end of an epoch, training stops.
device (torch::torch_device)
The device.
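The fields above are typically read (and, for terminate, written) from inside a custom callback through self$ctx. The following is a minimal sketch, assuming the torch_callback() constructor from mlr3torch with the stage functions on_batch_end and on_epoch_end; the callback id "loss_tracker", its max_epoch argument and the losses field are hypothetical names chosen for illustration. The callback collects last_loss after every training batch and sets terminate to TRUE once epoch reaches max_epoch.

```r
library(mlr3torch)

# Hypothetical callback "loss_tracker" (illustrative names, not part of mlr3torch).
cb_loss = torch_callback("loss_tracker",
  initialize = function(max_epoch = 2L) {
    # max_epoch: hypothetical argument, the last epoch that should be run
    self$max_epoch = max_epoch
    self$losses = numeric(0)
  },
  on_batch_end = function() {
    # ctx$last_loss holds the loss of the training batch that was just processed
    self$losses = c(self$losses, self$ctx$last_loss)
  },
  on_epoch_end = function() {
    # Setting ctx$terminate to TRUE at the end of an epoch stops training
    if (self$ctx$epoch >= self$max_epoch) {
      self$ctx$terminate = TRUE
    }
  }
)
```

Such a callback could then be attached to a torch learner, for example lrn("classif.mlp", callbacks = cb_loss, epochs = 10, batch_size = 32); training would stop after the second epoch even though epochs is 10.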
Methods
new()
Creates a new instance of this R6 class.
Usage: ContextTorch$new(learner, task_train, task_valid = NULL, loader_train, loader_valid = NULL, measures_train = NULL, measures_valid = NULL, network, optimizer, loss_fn, total_epochs, prediction_encoder, eval_freq = 1L, device)
Arguments
learner (Learner)
The torch learner.
task_train (Task)
The training task.
task_valid (Task or NULL)
The validation task.
loader_train (torch::dataloader)
The data loader for training.
loader_valid (torch::dataloader or NULL)
The data loader for validation.
measures_train (list() of Measures or NULL)
Measures used for training. Default is NULL.
measures_valid (list() of Measures or NULL)
Measures used for validation. Default is NULL.
network (torch::nn_module)
The torch network.
optimizer (torch::optimizer)
The optimizer.
loss_fn (torch::nn_module)
The loss function.
total_epochs (integer(1))
The total number of epochs the learner is trained for.
prediction_encoder (function())
The learner's prediction encoder.
See the section "Inheriting" of LearnerTorch.
eval_freq (integer(1))
The evaluation frequency: the model is evaluated every eval_freq epochs. Default is 1L.
device (character(1))
The device.
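With an eval_freq greater than 1, last_scores_train and last_scores_valid are NULL in every epoch that does not evaluate the model, so a callback that reads them should check for this first. The sketch below assumes the torch_callback() constructor from mlr3torch and assumes that validation scores, when computed, are already stored in ctx by the time the on_epoch_end stage runs; the callback id "valid_printer" is a hypothetical name. It uses length() instead of is.null() so that an empty list is treated the same way as NULL.

```r
library(mlr3torch)

# Hypothetical callback "valid_printer": prints the validation scores at the
# end of every epoch in which they were computed.
cb_print = torch_callback("valid_printer",
  on_epoch_end = function() {
    scores = self$ctx$last_scores_valid
    # With eval_freq > 1, the scores are NULL in epochs without evaluation
    if (length(scores) == 0L) {
      return(invisible(NULL))
    }
    # scores is a named list whose names are the ids of the validation measures
    cat(sprintf("epoch %d: %s\n", self$ctx$epoch,
      paste(names(scores), unlist(scores), sep = " = ", collapse = ", ")))
  }
)
```

Skipping silently keeps the output aligned with the epochs in which the model was actually evaluated, rather than printing placeholders for the skipped ones.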
clone()
The objects of this class are cloneable with this method.
Usage: ContextTorch$clone(deep = FALSE)
Arguments
deep
Whether to make a deep clone.
Other Callback:
TorchCallback,
as_torch_callback(),
as_torch_callbacks(),
callback_set(),
mlr3torch_callbacks,
mlr_callback_set,
mlr_callback_set.checkpoint,
mlr_callback_set.progress,
mlr_callback_set.tb,
mlr_callback_set.unfreeze,
t_clbk(),
torch_callback()