luz_metric (R Documentation)
Creates a new luz metric
luz_metric(
  name = NULL,
  ...,
  private = NULL,
  active = NULL,
  parent_env = parent.frame(),
  inherit = NULL
)
name
  String naming the new metric.
...
  Named list of public methods. You should implement at least initialize, update and compute.
private
  An optional list of private members, which can be functions and non-functions.
active
  An optional list of active binding functions.
parent_env
  An environment to use as the parent of newly-created objects.
inherit
  An R6ClassGenerator object to inherit from; in other words, a
  superclass. This is captured as an unevaluated expression which is
evaluated in |
In order to implement a new luz_metric we need to implement 3 methods:
initialize: defines the metric's initial state. This function is called for each epoch for both training and validation loops.
update: updates the metric's internal state. This function is called at every training and validation step with the predictions obtained by the model and the target values obtained from the dataloader.
compute: uses the internal state to compute metric values. This function is called whenever we need to obtain the current metric value. For example, it is called at every training step for metrics displayed in the progress bar, but only once per epoch to record its value when the progress bar is not displayed.
Optionally, you can implement an abbrev field that gives the metric an abbreviation that will be used when displaying metric information in the console or tracking record. If no abbrev is passed, the class name will be used.
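As a minimal sketch of the three methods and the optional abbrev field, the hypothetical metric below accumulates absolute errors across batches. The name metric_mean_abs_error and its fields are illustrative and not part of luz. The sketch assumes the luz and torch packages are installed, and that a metric built with luz_metric() can be driven by hand in the same way as the built-in metrics: call the generator, then $new(), then update() and compute().

```r
library(luz)

# Hypothetical metric for illustration only; luz already ships luz_metric_mae().
metric_mean_abs_error <- luz_metric(
  # Abbreviation shown in progress output
  abbrev = "MAE",
  # Reset the running state
  initialize = function() {
    self$sum_abs_error <- 0
    self$n <- 0
  },
  # Accumulate the absolute error and the number of elements seen
  update = function(preds, target) {
    abs_error <- torch::torch_abs(preds - target)
    self$sum_abs_error <- self$sum_abs_error + abs_error$sum()$item()
    self$n <- self$n + abs_error$numel()
  },
  # Return a plain R number computed from the running state
  compute = function() {
    self$sum_abs_error / self$n
  }
)

# Drive the metric by hand: instantiate, update on two batches, compute.
metric <- metric_mean_abs_error()
metric <- metric$new()
metric$update(torch::torch_tensor(c(1, 2, 3)), torch::torch_tensor(c(1, 4, 6)))
metric$update(torch::torch_tensor(c(0, 0)), torch::torch_tensor(c(1, 1)))
mae <- metric$compute()
```

Because the state is kept as running sums, compute() can be called after any number of update() calls and reflects every batch seen since the metric was initialized.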
Let's take a look at the implementation of luz_metric_accuracy so you can see how to implement a new one:
luz_metric_accuracy <- luz_metric(
  # An abbreviation to be shown in progress bars, or
  # when printing progress
  abbrev = "Acc",
  # Initial setup for the metric. Metrics are initialized
  # every epoch, for both training and validation
  initialize = function() {
    self$correct <- 0
    self$total <- 0
  },
  # Run at every training or validation step and updates
  # the internal state. The update function takes `preds`
  # and `target` as parameters.
  update = function(preds, target) {
    pred <- torch::torch_argmax(preds, dim = 2)
    self$correct <- self$correct + (pred == target)$
      to(dtype = torch::torch_float())$
      sum()$
      item()
    self$total <- self$total + pred$numel()
  },
  # Use the internal state to query the metric value
  compute = function() {
    self$correct/self$total
  }
)
Note: It is good practice for the compute method to return regular R values instead of torch tensors; other parts of luz expect that.
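To illustrate the note, the short sketch below contrasts a torch tensor with the plain R value obtained from it via $item(). It assumes the torch package is installed; the variable names are illustrative.

```r
library(torch)

# Made-up per-observation hits (1 = correct, 0 = incorrect)
hits <- torch_tensor(c(1, 0, 1, 1))

tensor_total <- hits$sum()         # still a torch tensor
r_total <- hits$sum()$item()       # plain R numeric scalar

inherits(tensor_total, "torch_tensor")
is.numeric(r_total)
```

Storing the result of $item() in the metric state, as the accuracy example does, means compute() works with ordinary R numbers from the start.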
Returns a new luz metric.
Other luz_metrics: luz_metric_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_accuracy(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()
luz_metric_accuracy <- luz_metric(
  # An abbreviation to be shown in progress bars, or
  # when printing progress
  abbrev = "Acc",
  # Initial setup for the metric. Metrics are initialized
  # every epoch, for both training and validation
  initialize = function() {
    self$correct <- 0
    self$total <- 0
  },
  # Run at every training or validation step and updates
  # the internal state. The update function takes `preds`
  # and `target` as parameters.
  update = function(preds, target) {
    pred <- torch::torch_argmax(preds, dim = 2)
    self$correct <- self$correct + (pred == target)$
      to(dtype = torch::torch_float())$
      sum()$
      item()
    self$total <- self$total + pred$numel()
  },
  # Use the internal state to query the metric value
  compute = function() {
    self$correct/self$total
  }
)
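An accuracy metric of this kind can be checked by hand on a small batch. The sketch below uses the luz_metric_accuracy() that ships with luz and assumes the torch and luz packages are installed. The prediction scores and targets are made up for illustration. Note that torch for R uses 1-based class indices, so torch_argmax() returns values starting at 1.

```r
library(torch)
library(luz)

# Instantiate the metric; this runs `initialize`.
metric <- luz_metric_accuracy()
metric <- metric$new()

# Four observations, three classes (made-up scores, one row per observation).
preds <- torch_tensor(matrix(
  c(0.90, 0.05, 0.05,
    0.10, 0.80, 0.10,
    0.20, 0.20, 0.60,
    0.10, 0.70, 0.20),
  nrow = 4, byrow = TRUE
))
# An R integer vector becomes a long tensor of 1-based class labels.
target <- torch_tensor(c(1L, 2L, 3L, 1L))

metric$update(preds, target)
accuracy <- metric$compute()
```

The predicted classes are 1, 2, 3 and 2, so three of the four rows match the targets and the accuracy should come out as 0.75.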