luz_metric_binary_auroc: Computes the area under the ROC

luz_metric_binary_auroc    R Documentation

Computes the area under the ROC

Description

To avoid storing all predictions and targets for an epoch, we compute confusion matrices across a range of pre-established thresholds.
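The idea can be sketched in base R: keep running confusion-matrix counts per threshold, add to them batch by batch, and integrate the resulting ROC points at the end. This is a minimal illustration, not luz's actual implementation. The helper names (new_state, update_state, compute_auroc), the `>=` comparison, and the trapezoidal integration are all assumptions made for the sketch.

```r
# Hypothetical helpers for illustration only; none of these are luz functions.

# Running confusion-matrix counts, one entry per threshold.
new_state <- function(thresholds) {
  n <- length(thresholds)
  list(thresholds = thresholds,
       tp = numeric(n), fp = numeric(n), fn = numeric(n), tn = numeric(n))
}

# Add one batch to the running counts; the batch itself is not stored.
update_state <- function(state, preds, targets) {
  # pred_pos[i, j] is TRUE when prediction j is >= threshold i (assumed rule).
  pred_pos <- outer(state$thresholds, preds, function(t, p) p >= t)
  is_pos <- targets == 1
  state$tp <- state$tp + rowSums(pred_pos[, is_pos, drop = FALSE])
  state$fp <- state$fp + rowSums(pred_pos[, !is_pos, drop = FALSE])
  state$fn <- state$fn + rowSums(!pred_pos[, is_pos, drop = FALSE])
  state$tn <- state$tn + rowSums(!pred_pos[, !is_pos, drop = FALSE])
  state
}

# Turn the counts into ROC points and integrate with the trapezoidal rule.
compute_auroc <- function(state) {
  tpr <- state$tp / (state$tp + state$fn)
  fpr <- state$fp / (state$fp + state$tn)
  ord <- order(fpr, tpr)
  x <- c(0, fpr[ord], 1)  # anchor the curve at (0, 0) and (1, 1)
  y <- c(0, tpr[ord], 1)
  sum(diff(x) * (head(y, -1) + tail(y, -1)) / 2)
}

actual <- c(1, 1, 1, 0, 0, 0)
predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)

state <- new_state(seq(0, 1, length.out = 200))
for (idx in list(1:2, 3:4, 5:6)) {
  state <- update_state(state, predicted[idx], actual[idx])
}
auc <- compute_auroc(state)
auc  # approximately 0.889 (8 of 9 positive/negative pairs are ranked correctly)
```

Memory use depends on the number of thresholds rather than on the number of observations seen, which is the trade-off the description refers to. The result is an approximation whose accuracy depends on how well the thresholds separate the predicted values.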

Usage

luz_metric_binary_auroc(
  num_thresholds = 200,
  thresholds = NULL,
  from_logits = FALSE
)

Arguments

num_thresholds

Number of thresholds used to compute confusion matrices. Used only when thresholds is NULL; in that case, thresholds are created by taking num_thresholds values linearly spaced in the unit interval.

thresholds

(optional) If thresholds are passed, they are used to compute the confusion matrices and num_thresholds is ignored.

from_logits

Boolean indicating whether predictions are logits. If TRUE, a sigmoid is applied to map them to the unit interval.
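How the three arguments interact can be sketched in base R as follows. The helpers resolve_thresholds and to_unit_interval are hypothetical names used only for illustration. Whether luz includes the endpoints 0 and 1 in the default grid, or reorders user-supplied thresholds, is not stated above and is assumed here.

```r
# Hypothetical helpers for illustration only; these are not luz functions.

resolve_thresholds <- function(num_thresholds = 200, thresholds = NULL) {
  if (is.null(thresholds)) {
    # Default: num_thresholds values linearly spaced in the unit interval
    # (endpoints included here by assumption).
    seq(0, 1, length.out = num_thresholds)
  } else {
    # User-supplied thresholds take precedence; num_thresholds is ignored.
    thresholds
  }
}

to_unit_interval <- function(preds, from_logits = FALSE) {
  # stats::plogis() is the logistic (sigmoid) function 1 / (1 + exp(-x)).
  if (from_logits) stats::plogis(preds) else preds
}

default_grid <- resolve_thresholds(num_thresholds = 5)
default_grid  # 0.00 0.25 0.50 0.75 1.00

custom_grid <- resolve_thresholds(num_thresholds = 5, thresholds = c(0.2, 0.8))
custom_grid   # 0.2 0.8

probs <- to_unit_interval(c(-2, 0, 2), from_logits = TRUE)
probs         # all values lie strictly between 0 and 1; a logit of 0 maps to 0.5
```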

See Also

Other luz_metrics: luz_metric_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_accuracy(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse(), luz_metric()

Examples

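if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  # Ground-truth labels and predicted probabilities for six observations.
  actual <- c(1, 1, 1, 0, 0, 0)
  predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)

  y_true <- torch_tensor(actual)
  y_pred <- torch_tensor(predicted)

  # Use the predicted values themselves as thresholds, then instantiate the metric.
  m <- luz_metric_binary_auroc(thresholds = predicted)
  m <- m$new()

  # Update the metric in three batches, as would happen during an epoch.
  m$update(y_pred[1:2], y_true[1:2])
  m$update(y_pred[3:4], y_true[3:4])
  m$update(y_pred[5:6], y_true[5:6])

  # AUROC accumulated over all batches.
  m$compute()
}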

mlverse/torchlight documentation built on Sept. 19, 2024, 11:22 p.m.