loss_dice: Computes the Dice loss value between 'y_true' and 'y_pred'.

Description

Formula (the sums run over the dimensions given in axis, or over all elements when axis is NULL):

loss = 1 - (2 * sum(y_true * y_pred)) / (sum(y_true) + sum(y_pred))
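To make the formula concrete, here is a minimal base-R sketch that evaluates it over all elements, using the same arrays as the example further down. It is not the keras3 implementation: the name dice_loss_sketch() is hypothetical, and the small eps added to the denominator is an assumption of this sketch (a guard against division by zero), not something documented on this page.

```r
# Hypothetical base-R sketch of the Dice loss formula; not the keras3 code.
# `eps` is an assumed small constant that guards against division by zero.
dice_loss_sketch <- function(y_true, y_pred, eps = 1e-7) {
  intersection <- sum(y_true * y_pred)        # sum(y_true * y_pred)
  denominator  <- sum(y_true) + sum(y_pred)   # sum(y_true) + sum(y_pred)
  1 - (2 * intersection) / (denominator + eps)
}

y_true <- array(c(1, 1, 0, 0,
                  1, 1, 0, 0), dim = c(2, 2, 2, 1))
y_pred <- array(c(0, 0.4, 0,   0,
                  1,   0, 1, 0.9), dim = c(2, 2, 2, 1))

dice_loss_sketch(y_true, y_pred)   # about 0.6164384
```

Here the intersection sums to 1.4, sum(y_true) to 4 and sum(y_pred) to 3.3, so with eps = 0 the result is exactly 1 - 2.8 / 7.3, which also prints as 0.6164384.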

Usage

loss_dice(
  y_true,
  y_pred,
  ...,
  reduction = "sum_over_batch_size",
  name = "dice",
  axis = NULL,
  dtype = NULL
)

Arguments

y_true

Tensor of true targets.

y_pred

Tensor of predicted targets.

...

For forward/backward compatibility.

reduction

Type of reduction to apply to the loss. In almost all cases this should be "sum_over_batch_size". Supported options are "sum", "sum_over_batch_size", or NULL (no reduction).

name

String, name for the object.

axis

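Dimensions over which the loss is calculated, e.g. c(2, 3, 4) to compute one loss per sample for a 4-D input. Defaults to NULL, which sums over all elements and yields a single scalar loss.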

dtype

The dtype of the loss's computations. Defaults to NULL, which means using config_floatx(). config_floatx() returns "float32" unless set to a different value (via config_set_floatx()).
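To illustrate how the reduction options combine per-sample losses, here is a hypothetical helper, reduce_losses(), that works on plain R vectors. It assumes that "sum_over_batch_size" divides the summed losses by their count and that NULL leaves them unreduced; it ignores sample weights and does not reproduce keras3 internals.

```r
# Hypothetical helper, not part of keras3: reduces a vector of per-sample losses.
# Assumes "sum_over_batch_size" = sum / number of values and NULL = no reduction;
# sample weights are ignored.
reduce_losses <- function(losses, reduction = "sum_over_batch_size") {
  if (is.null(reduction)) return(losses)   # NULL: return the losses unchanged
  switch(reduction,
         sum_over_batch_size = sum(losses) / length(losses),
         sum = sum(losses),
         stop("unsupported reduction: ", reduction))
}

# Per-sample Dice losses, as in the axis = c(2, 3, 4) example
per_sample <- c(0.5, 0.7575758)
reduce_losses(per_sample)            # about 0.6287879
reduce_losses(per_sample, "sum")     # about 1.257576
reduce_losses(per_sample, NULL)      # 0.5000000 0.7575758
```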

Value

If y_true and y_pred are provided, the Dice loss value; otherwise, a Loss() instance.

Example

y_true <- array(c(1, 1, 0, 0,
                  1, 1, 0, 0), dim = c(2, 2, 2, 1))
y_pred <- array(c(0, 0.4, 0,   0,
                  1,   0, 1, 0.9), dim = c(2, 2, 2, 1))

axis <- c(2, 3, 4)
loss <- loss_dice(y_true, y_pred, axis = axis)
stopifnot(shape(loss) == shape(2))
loss
## tf.Tensor([0.50000001 0.75757576], shape=(2), dtype=float64)

loss <- loss_dice(y_true, y_pred)
stopifnot(shape(loss) == shape())
loss
## tf.Tensor(0.6164383614186526, shape=(), dtype=float64)
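The per-sample values printed above can be reproduced by hand. This base-R sketch keeps the first (batch) dimension and sums over the remaining ones, which is what axis = c(2, 3, 4) requests. The helper dice_loss_per_sample() is hypothetical, and its small eps in the denominator is an assumed guard rather than documented behavior.

```r
# Hypothetical base-R sketch: one Dice loss per sample, summing over
# dimensions 2, 3 and 4 (like axis = c(2, 3, 4)). `eps` is an assumed guard.
dice_loss_per_sample <- function(y_true, y_pred, eps = 1e-7) {
  intersection <- apply(y_true * y_pred, 1, sum)                # per-sample overlap
  denominator  <- apply(y_true, 1, sum) + apply(y_pred, 1, sum)  # per-sample totals
  1 - (2 * intersection) / (denominator + eps)
}

y_true <- array(c(1, 1, 0, 0,
                  1, 1, 0, 0), dim = c(2, 2, 2, 1))
y_pred <- array(c(0, 0.4, 0,   0,
                  1,   0, 1, 0.9), dim = c(2, 2, 2, 1))

dice_loss_per_sample(y_true, y_pred)   # 0.5000000 0.7575758
```

For the first sample the intersection is 1 and the denominator 2 + 2, giving 1 - 2/4 = 0.5; for the second, the intersection is 0.4 and the denominator 2 + 1.3, giving 1 - 0.8/3.3, about 0.7575758.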

See Also

Other losses:
Loss()
loss_binary_crossentropy()
loss_binary_focal_crossentropy()
loss_categorical_crossentropy()
loss_categorical_focal_crossentropy()
loss_categorical_hinge()
loss_cosine_similarity()
loss_ctc()
loss_hinge()
loss_huber()
loss_kl_divergence()
loss_log_cosh()
loss_mean_absolute_error()
loss_mean_absolute_percentage_error()
loss_mean_squared_error()
loss_mean_squared_logarithmic_error()
loss_poisson()
loss_sparse_categorical_crossentropy()
loss_squared_hinge()
loss_tversky()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_categorical_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_hinge()
metric_huber()
metric_kl_divergence()
metric_log_cosh()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_poisson()
metric_sparse_categorical_crossentropy()
metric_squared_hinge()


rstudio/keras documentation built on July 8, 2024, 3:07 p.m.