Loss | R Documentation |
Loss class
Use this to define a custom loss class. Note that in most cases you do not need to subclass Loss to define a custom loss: you can also pass a bare R function, or a named R function defined with custom_metric(), as a loss function to compile().
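The two alternatives to subclassing mentioned above can be sketched as follows. This is a minimal illustration, not the package's own example; `bare_mse` and `named_mse` are hypothetical names introduced here, and the sketch assumes keras3 and a working backend are installed.

```r
library(keras3)

# A bare R function with the (y_true, y_pred) signature.
# `bare_mse` is a hypothetical name used only for this sketch.
bare_mse <- function(y_true, y_pred) {
  op_mean(op_square(y_pred - y_true), axis = -1)
}

# The same function given an explicit name via custom_metric().
# `named_mse` is likewise a hypothetical name.
named_mse <- custom_metric("named_mse", bare_mse)

model <- keras_model_sequential(input_shape = 10) |> layer_dense(10)

# Either form can be passed as the loss; the second compile() call
# simply replaces the configuration set by the first.
model |> compile(loss = bare_mse)
model |> compile(loss = named_mse)
```

Subclassing Loss is mainly worthwhile when the loss needs its own constructor arguments or state.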
Loss(
  classname,
  call = NULL,
  ...,
  public = list(),
  private = list(),
  inherit = NULL,
  parent_env = parent.frame()
)
classname | String, the name of the custom class. (Conventionally, CamelCase.) |
call | function(y_true, y_pred). Method to be implemented by subclasses: a function that contains the logic for loss calculation using y_true and y_pred. |
..., public | Additional methods or public members of the custom class. |
private | Named list of R objects (typically, functions) to include in instance private environments. |
inherit | What the custom class will subclass. By default, the base keras class. |
parent_env | The R environment that all class methods will have as a grandparent. |
Example subclass implementation:
loss_custom_mse <- Loss(
  classname = "CustomMeanSquaredError",
  call = function(y_true, y_pred) {
    op_mean(op_square(y_pred - y_true), axis = -1)
  }
)

# Usage in compile()
model <- keras_model_sequential(input_shape = 10) |> layer_dense(10)
model |> compile(loss = loss_custom_mse())

# Standalone usage
mse <- loss_custom_mse(name = "my_custom_mse_instance")

y_true <- op_arange(20) |> op_reshape(c(4, 5))
y_pred <- op_arange(20) |> op_reshape(c(4, 5)) * 2
(loss <- mse(y_true, y_pred))
## tf.Tensor(123.5, shape=(), dtype=float32)

loss2 <- (y_pred - y_true)^2 |>
  op_mean(axis = -1) |>
  op_mean()

stopifnot(all.equal(as.array(loss), as.array(loss2)))

sample_weight <- array(c(.25, .25, 1, 1))
(weighted_loss <- mse(y_true, y_pred, sample_weight = sample_weight))
## tf.Tensor(112.8125, shape=(), dtype=float32)

weighted_loss2 <- (y_true - y_pred)^2 |>
  op_mean(axis = -1) |>
  op_multiply(sample_weight) |>
  op_mean()

stopifnot(all.equal(as.array(weighted_loss), as.array(weighted_loss2)))
A function that returns Loss instances, similar to the builtin loss functions.
Loss class:
initialize(name=NULL, reduction="sum_over_batch_size", dtype=NULL)
Args:
name: Optional name for the loss instance.
reduction: Type of reduction to apply to the loss. In almost all cases this should be "sum_over_batch_size". Supported options are "sum", "sum_over_batch_size", "mean", "mean_with_sample_weight", "none" or NULL. "sum" sums the loss, "sum_over_batch_size" and "mean" sum the loss and divide by the sample size, and "mean_with_sample_weight" sums the loss and divides by the sum of the sample weights. "none" and NULL perform no aggregation. Defaults to "sum_over_batch_size".
dtype: The dtype of the loss's computations. Defaults to NULL, which means using config_floatx(). config_floatx() is "float32" unless set to a different value (via config_set_floatx()). If a keras$DTypePolicy is provided, then the compute_dtype will be utilized.
__call__(y_true, y_pred, sample_weight=NULL)
Call the loss instance as a function, optionally with sample_weight.
get_config()
Readonly properties: dtype
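The reduction options described under initialize() can be illustrated with plain arithmetic. The sketch below uses base R only and follows the wording of the argument description; it is not the Keras implementation. `reduce_sketch` is a hypothetical helper, and the per-sample values are the row means of the squared errors from the example on this page (6, 51, 146, 291).

```r
# Per-sample losses (row means of squared error in the example above)
# and the sample weights used there.
per_sample    <- c(6, 51, 146, 291)
sample_weight <- c(0.25, 0.25, 1, 1)

# Hypothetical helper that mirrors the documented reduction options.
reduce_sketch <- function(values, weights = rep(1, length(values)),
                          reduction = "sum_over_batch_size") {
  weighted <- values * weights
  if (is.null(reduction) || reduction == "none") {
    return(weighted)                       # no aggregation
  }
  switch(reduction,
    sum = sum(weighted),
    sum_over_batch_size = ,                # falls through to "mean"
    mean = sum(weighted) / length(weighted),
    mean_with_sample_weight = sum(weighted) / sum(weights),
    stop("unsupported reduction: ", reduction)
  )
}

reduce_sketch(per_sample, sample_weight, "sum")                     # 451.25
reduce_sketch(per_sample, sample_weight, "sum_over_batch_size")     # 112.8125
reduce_sketch(per_sample, sample_weight, "mean_with_sample_weight") # 180.5
reduce_sketch(per_sample, sample_weight, NULL)                      # 1.5 12.75 146 291
```

The "sum_over_batch_size" result matches the weighted loss of 112.8125 shown in the example, since that loss instance uses the default reduction.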
All R function custom methods (public and private) will have the following symbols in scope:
self: The custom class instance.
super: The custom class superclass.
private: An R environment specific to the class instance. Any objects assigned here are invisible to the Keras framework.
__class__ and as.symbol(classname): the custom class type object.
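As a rough illustration of how self and super might be used, the sketch below defines a loss with its own constructor argument. The class name `ScaledMeanSquaredError`, the generator `loss_scaled_mse`, and the `scale` argument are all hypothetical and introduced only for this example; it assumes that an `initialize` method passed through `...` is treated as the class constructor and that keras3 and a backend are installed.

```r
library(keras3)

loss_scaled_mse <- Loss(
  classname = "ScaledMeanSquaredError",

  # Custom constructor: keep `scale` on the instance and forward the
  # remaining arguments (name, reduction, dtype) to the base class.
  initialize = function(scale = 1, ...) {
    super$initialize(...)
    self$scale <- scale
  },

  # Per-sample mean squared error, multiplied by the stored scale.
  call = function(y_true, y_pred) {
    op_mean(op_square(y_pred - y_true), axis = -1) * self$scale
  }
)

scaled <- loss_scaled_mse(scale = 2, name = "scaled_mse")

y_true <- op_arange(20) |> op_reshape(c(4, 5))
y_pred <- op_arange(20) |> op_reshape(c(4, 5)) * 2
scaled_loss <- scaled(y_true, y_pred)
```

With the same inputs as the example on this page, the result should be twice the unscaled loss, because the default reduction is applied after `call()` returns.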
Other losses:
loss_binary_crossentropy()
loss_binary_focal_crossentropy()
loss_categorical_crossentropy()
loss_categorical_focal_crossentropy()
loss_categorical_hinge()
loss_circle()
loss_cosine_similarity()
loss_ctc()
loss_dice()
loss_hinge()
loss_huber()
loss_kl_divergence()
loss_log_cosh()
loss_mean_absolute_error()
loss_mean_absolute_percentage_error()
loss_mean_squared_error()
loss_mean_squared_logarithmic_error()
loss_poisson()
loss_sparse_categorical_crossentropy()
loss_squared_hinge()
loss_tversky()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_categorical_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_hinge()
metric_huber()
metric_kl_divergence()
metric_log_cosh()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_poisson()
metric_sparse_categorical_crossentropy()
metric_squared_hinge()