metric_categorical_crossentropy: Computes the crossentropy metric between the labels and predictions

metric_categorical_crossentropy    R Documentation

Computes the crossentropy metric between the labels and predictions.

Description

This is the crossentropy metric class to be used when there are multiple label classes (2 or more). It assumes that labels are one-hot encoded, e.g., when the label values are c(2, 0, 1), then y_true is rbind(c(0, 0, 1), c(1, 0, 0), c(0, 1, 0)).
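
For instance, integer class labels can be converted to this one-hot layout with to_categorical() (a small illustrative sketch):

# Integer labels c(2, 0, 1) map to the one-hot rows shown above:
# rbind(c(0, 0, 1), c(1, 0, 0), c(0, 1, 0))
to_categorical(c(2, 0, 1), num_classes = 3)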

Usage

metric_categorical_crossentropy(
  y_true,
  y_pred,
  from_logits = FALSE,
  label_smoothing = 0,
  axis = -1L,
  ...,
  name = "categorical_crossentropy",
  dtype = NULL
)

Arguments

y_true

Tensor of one-hot true targets.

y_pred

Tensor of predicted targets.

from_logits

(Optional) Whether y_pred is expected to be a logits tensor. By default, we assume that y_pred encodes a probability distribution.

label_smoothing

(Optional) Float in [0, 1]. When > 0, label values are smoothed, meaning the confidence on label values is relaxed. For example, label_smoothing = 0.2 means that we will use a value of 0.1 for label "0" and 0.9 for label "1" (see the sketch after this argument list).

axis

(Optional) Defaults to -1. The dimension along which entropy is computed.

...

For forward/backward compatibility.

name

(Optional) string name of the metric instance.

dtype

(Optional) data type of the metric result.
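
The label_smoothing transform referenced above can be illustrated with a minimal sketch, assuming the commonly used formula y * (1 - s) + s / k for k classes (illustrative only, not necessarily this metric's exact internals):

# Hypothetical helper, for illustration only.
smooth_labels <- function(y_true, s) {
  k <- ncol(y_true)
  y_true * (1 - s) + s / k
}

# With k = 2 classes and s = 0.2, hard 0/1 labels become 0.1/0.9,
# matching the example in the label_smoothing description.
smooth_labels(rbind(c(0, 1), c(1, 0)), s = 0.2)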

Value

If y_true and y_pred are missing, a Metric instance is returned. The Metric instance can be passed directly to compile(metrics = ), or used as a standalone object. See ?Metric for example usage. If y_true and y_pred are provided, a tensor with the computed value is returned.
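
For example, calling the function directly (a brief sketch reusing the inputs from the Examples below) returns the per-sample crossentropy values rather than a Metric instance:

y_true <- rbind(c(0, 1, 0), c(0, 0, 1))
y_pred <- rbind(c(0.05, 0.95, 0), c(0.1, 0.8, 0.1))
# Per-sample values, roughly c(0.051, 2.303) for this input.
metric_categorical_crossentropy(y_true, y_pred)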

Examples

Standalone usage:

# EPSILON = 1e-7, y = y_true, y' = y_pred
# y' = clip(y_pred, EPSILON, 1 - EPSILON)
# y' = rbind(c(0.05, 0.95, EPSILON), c(0.1, 0.8, 0.1))
# xent = -sum(y * log(y'), axis = -1)
#      = -(log(0.95), log(0.1))
#      = c(0.051, 2.302)
# Reduced xent = (0.051 + 2.302) / 2
m <- metric_categorical_crossentropy()
m$update_state(rbind(c(0, 1, 0), c(0, 0, 1)),
               rbind(c(0.05, 0.95, 0), c(0.1, 0.8, 0.1)))
m$result()
## tf.Tensor(1.1769392, shape=(), dtype=float32)

m$reset_state()
m$update_state(rbind(c(0, 1, 0), c(0, 0, 1)),
               rbind(c(0.05, 0.95, 0), c(0.1, 0.8, 0.1)),
               sample_weight = c(0.3, 0.7))
m$result()
## tf.Tensor(1.6271976, shape=(), dtype=float32)
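
The two results above can be reproduced with a small base-R sketch of the calculation spelled out in the comments (EPSILON clipping included for completeness; this is not the metric's internal code):

eps    <- 1e-7
y_true <- rbind(c(0, 1, 0), c(0, 0, 1))
y_pred <- rbind(c(0.05, 0.95, 0), c(0.1, 0.8, 0.1))
y_clip <- pmin(pmax(y_pred, eps), 1 - eps)

# Per-sample crossentropy: -sum over classes of y * log(y')
xent <- -rowSums(y_true * log(y_clip))
xent                                  # roughly c(0.0513, 2.3026)

mean(xent)                            # unweighted result, ~1.1769
weighted.mean(xent, w = c(0.3, 0.7))  # weighted result,   ~1.6272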

Usage with compile() API:

model %>% compile(
  optimizer = 'sgd',
  loss = 'mse',
  metrics = list(metric_categorical_crossentropy()))
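
The compile() call above assumes a model has already been defined. A minimal, self-contained sketch (assuming the keras3 sequential API; layer sizes and toy data here are illustrative only):

model <- keras_model_sequential(input_shape = 4) %>%
  layer_dense(units = 3, activation = "softmax")

model %>% compile(
  optimizer = 'sgd',
  loss = 'mse',
  metrics = list(metric_categorical_crossentropy()))

x <- matrix(runif(40), ncol = 4)
y <- to_categorical(sample(0:2, 10, replace = TRUE), num_classes = 3)

# The metric is then reported per epoch during fit() and by evaluate().
model %>% fit(x, y, epochs = 2, verbose = 0)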

See Also

Other losses:
Loss()
loss_binary_crossentropy()
loss_binary_focal_crossentropy()
loss_categorical_crossentropy()
loss_categorical_focal_crossentropy()
loss_categorical_hinge()
loss_cosine_similarity()
loss_ctc()
loss_dice()
loss_hinge()
loss_huber()
loss_kl_divergence()
loss_log_cosh()
loss_mean_absolute_error()
loss_mean_absolute_percentage_error()
loss_mean_squared_error()
loss_mean_squared_logarithmic_error()
loss_poisson()
loss_sparse_categorical_crossentropy()
loss_squared_hinge()
loss_tversky()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_hinge()
metric_huber()
metric_kl_divergence()
metric_log_cosh()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_poisson()
metric_sparse_categorical_crossentropy()
metric_squared_hinge()

Other metrics:
Metric()
custom_metric()
metric_auc()
metric_binary_accuracy()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_binary_iou()
metric_categorical_accuracy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_cosine_similarity()
metric_f1_score()
metric_false_negatives()
metric_false_positives()
metric_fbeta_score()
metric_hinge()
metric_huber()
metric_iou()
metric_kl_divergence()
metric_log_cosh()
metric_log_cosh_error()
metric_mean()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_iou()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_mean_wrapper()
metric_one_hot_iou()
metric_one_hot_mean_iou()
metric_poisson()
metric_precision()
metric_precision_at_recall()
metric_r2_score()
metric_recall()
metric_recall_at_precision()
metric_root_mean_squared_error()
metric_sensitivity_at_specificity()
metric_sparse_categorical_accuracy()
metric_sparse_categorical_crossentropy()
metric_sparse_top_k_categorical_accuracy()
metric_specificity_at_sensitivity()
metric_squared_hinge()
metric_sum()
metric_top_k_categorical_accuracy()
metric_true_negatives()
metric_true_positives()

Other probabilistic metrics:
metric_binary_crossentropy()
metric_kl_divergence()
metric_poisson()
metric_sparse_categorical_crossentropy()

