TorchLoss
This wraps a torch::nn_loss and annotates it with metadata, most importantly a ParamSet.
The loss function is created for the given parameter values by calling the $generate() method.
This class is usually used to configure the loss function of a torch learner, e.g. when constructing a learner or in a ModelDescriptor.
For a list of available losses, see mlr3torch_losses.
Items from this dictionary can be retrieved using t_loss().
Parameters
Defined by the constructor argument param_set.
If no parameter set is provided during construction, it is constructed by creating one parameter for each argument of the wrapped loss function; these parameters are of type ParamUty.
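The inference rule can be illustrated with a small base-R sketch. This is not mlr3torch's implementation: the helper `infer_param_ids()` and the toy constructor `toy_loss()` are hypothetical. The sketch only collects the argument names of the wrapped function, each of which would become one untyped (ParamUty) parameter.

```r
# Hypothetical stand-in for a loss constructor such as torch::nn_mse_loss.
toy_loss <- function(reduction = "mean", weight = NULL) {
  list(reduction = reduction, weight = weight)
}

# Sketch (not mlr3torch's code): one untyped parameter per formal argument
# of the wrapped function, recorded here as an id plus a class label.
infer_param_ids <- function(fn) {
  ids <- names(formals(fn))
  data.frame(id = ids, class = rep("ParamUty", length(ids)),
             stringsAsFactors = FALSE)
}

infer_param_ids(toy_loss)
```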
Super class
mlr3torch::TorchDescriptor -> TorchLoss
Public fields
task_types (character())
The task types this loss supports.
Methods
new()
Creates a new instance of this R6 class.
TorchLoss$new(torch_loss, task_types = NULL, param_set = NULL, id = NULL, label = NULL, packages = NULL, man = NULL)
torch_loss (nn_loss or function)
The loss module, or a function that generates the loss module.
It can have an argument task, which is provided when the loss is instantiated.
task_types (character())
The task types supported by this loss.
param_set (ParamSet or NULL)
The parameter set. If NULL (default), it is inferred from torch_loss.
id (character(1))
The id of the new object.
label (character(1))
Label for the new instance.
packages (character())
The R packages this object depends on.
man (character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object.
The referenced help page can be opened via the $help() method.
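The man string is a simple "[pkg]::[topic]" pair. The base-R sketch below shows how such a string splits into a package and a topic; `parse_man()` is a hypothetical helper, not part of mlr3torch, and "torch::nn_mse_loss" is only an illustrative value.

```r
# Hypothetical helper: split a `man` string of the form "pkg::topic".
parse_man <- function(man) {
  parts <- strsplit(man, "::", fixed = TRUE)[[1]]
  # Reject strings without exactly one non-empty package and topic.
  if (length(parts) != 2L || any(parts == "")) {
    stop("`man` must have the form 'pkg::topic'")
  }
  list(package = parts[[1]], topic = parts[[2]])
}

parse_man("torch::nn_mse_loss")
```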
print()
Prints the object.
TorchLoss$print(...)
... (any)
generate()
Instantiates the loss function.
TorchLoss$generate(task = NULL)
task (Task)
The task. Must be provided if the loss function requires a task.
Returns: torch_loss, the instantiated loss module.
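A base-R sketch of how a task might be forwarded, assuming the task is passed only to generators that declare a `task` argument. This is an assumption for illustration, not mlr3torch's source. The names `generate_sketch()`, `plain()`, `task_aware()` and the list `mock_task` are hypothetical; a real generator would return a torch nn_loss module and receive an mlr3 Task.

```r
# Hypothetical sketch: call a generator with the configured parameter values,
# adding `task` only if the generator declares a `task` argument.
generate_sketch <- function(generator, values = list(), task = NULL) {
  args <- values
  if ("task" %in% names(formals(generator))) {
    if (is.null(task)) stop("This loss requires a task.")
    args$task <- task
  }
  do.call(generator, args)
}

# Task-independent generator: `task` is never passed to it.
plain <- function(reduction = "mean") list(reduction = reduction)

# Task-aware generator: reads the class labels from a mock task object.
task_aware <- function(task, reduction = "mean") {
  list(n_classes = length(task$class_names), reduction = reduction)
}

mock_task <- list(class_names = c("a", "b", "c"))
generate_sketch(plain, list(reduction = "sum"))
generate_sketch(task_aware, task = mock_task)
```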
clone()
The objects of this class are cloneable with this method.
TorchLoss$clone(deep = FALSE)
deep
Whether to make a deep clone.
Other Torch Descriptor:
TorchCallback,
TorchDescriptor,
TorchOptimizer,
as_torch_callbacks(),
as_torch_loss(),
as_torch_optimizer(),
mlr3torch_losses,
mlr3torch_optimizers,
t_clbk(),
t_loss(),
t_opt()
Examples
library(mlr3torch)

# Create a new torch loss
torch_loss = TorchLoss$new(torch_loss = nn_mse_loss, task_types = "regr")
torch_loss
# the parameters are inferred
torch_loss$param_set
# Retrieve a loss from the dictionary:
torch_loss = t_loss("mse", reduction = "mean")
# inspect the retrieved TorchLoss
torch_loss
torch_loss$param_set
torch_loss$label
torch_loss$task_types
torch_loss$id
# Create the loss function
loss_fn = torch_loss$generate()
loss_fn
# Is the same as
nn_mse_loss(reduction = "mean")
# open the help page of the wrapped loss function
# torch_loss$help()
# Use in a learner
learner = lrn("regr.mlp", loss = t_loss("mse"))
# The parameters of the loss are added to the learner's parameter set
learner$param_set
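In the learner's parameter set, the loss parameters appear under a prefix. We understand this prefix to be `loss.` (e.g. `loss.reduction`), but treat the exact prefix as an assumption. The base-R sketch below uses a hypothetical helper, `loss_values()`, with made-up values to show how such prefixed values could be routed back to the loss.

```r
# Hypothetical sketch: pick the values meant for the loss out of a flat,
# prefixed list of learner parameter values and strip the prefix.
loss_values <- function(values, prefix = "loss.") {
  keep <- startsWith(names(values), prefix)
  routed <- values[keep]
  names(routed) <- substring(names(routed), nchar(prefix) + 1L)
  routed
}

learner_values <- list(epochs = 10, opt.lr = 0.01, loss.reduction = "sum")
loss_values(learner_values)
```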