| Loss | R Documentation |
Loss class

Use this to define a custom loss class. Note that in most cases you do not
need to subclass Loss to define a custom loss: you can also pass a bare R
function, or a named R function defined with custom_metric(), as a loss
function to compile().
Loss(
classname,
call = NULL,
...,
public = list(),
private = list(),
inherit = NULL,
parent_env = parent.frame()
)
Arguments:

classname
String, the name of the custom class. (Conventionally, CamelCase.)

call
function(y_true, y_pred)
Method to be implemented by subclasses: a function that contains the
logic for loss calculation using y_true and y_pred.

..., public
Additional methods or public members of the custom class (a sketch of
public members in use follows the examples below).

private
Named list of R objects (typically, functions) to include in
instance private environments.

inherit
What the custom class will subclass. By default, the base keras class.

parent_env
The R environment that all class methods will have as a grandparent.
Example subclass implementation:
loss_custom_mse <- Loss(
classname = "CustomMeanSquaredError",
call = function(y_true, y_pred) {
op_mean(op_square(y_pred - y_true), axis = -1)
}
)
# Usage in compile()
model <- keras_model_sequential(input_shape = 10) |> layer_dense(10)
model |> compile(loss = loss_custom_mse())
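# Alternatively (a minimal sketch): a bare R function can be passed
# directly as the loss, with no Loss subclass needed.
model |> compile(
  loss = function(y_true, y_pred) op_mean(op_square(y_pred - y_true), axis = -1)
)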
# Standalone usage
mse <- loss_custom_mse(name = "my_custom_mse_instance")
y_true <- op_arange(20) |> op_reshape(c(4, 5))
y_pred <- op_arange(20) |> op_reshape(c(4, 5)) * 2
(loss <- mse(y_true, y_pred))
## tf.Tensor(143.5, shape=(), dtype=float32)
loss2 <- (y_pred - y_true)^2 |> op_mean(axis = -1) |> op_mean()
stopifnot(all.equal(as.array(loss), as.array(loss2)))
sample_weight <- array(c(.25, .25, 1, 1))
(weighted_loss <- mse(y_true, y_pred, sample_weight = sample_weight))
## tf.Tensor(129.0625, shape=(), dtype=float32)
weighted_loss2 <- (y_true - y_pred)^2 |>
op_mean(axis = -1) |>
op_multiply(sample_weight) |>
op_mean()
stopifnot(all.equal(as.array(weighted_loss),
as.array(weighted_loss2)))
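A subclass can also carry extra public members, such as a constructor that
accepts additional parameters and a matching get_config() method. A minimal
sketch, assuming the standard keras3 subclassing pattern (the
ScaledMeanSquaredError class and its scale field are invented for
illustration):

loss_scaled_mse <- Loss(
  classname = "ScaledMeanSquaredError",
  call = function(y_true, y_pred) {
    # scale the per-sample mean squared error by a user-supplied factor
    self$scale * op_mean(op_square(y_pred - y_true), axis = -1)
  },
  public = list(
    initialize = function(scale = 1, ...) {
      super$initialize(...)
      self$scale <- scale
    },
    get_config = function() {
      config <- super$get_config()
      config$scale <- self$scale
      config
    }
  )
)
scaled_mse <- loss_scaled_mse(scale = 0.5)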
Value:

A function that returns Loss instances, similar to the builtin loss
functions.
Methods defined by the base Loss class:

initialize(name = NULL, reduction = "sum_over_batch_size", dtype = NULL)
Args:
name: Optional name for the loss instance.
reduction: Type of reduction to apply to the loss. In almost all cases
this should be "sum_over_batch_size". Supported options are
"sum", "sum_over_batch_size", "mean",
"mean_with_sample_weight", "none", or NULL. "sum" sums the loss,
"sum_over_batch_size" and "mean" sum the loss and divide by the
sample size, and "mean_with_sample_weight" sums the loss and
divides by the sum of the sample weights. "none" and NULL
perform no aggregation. Defaults to "sum_over_batch_size".
(A sketch of reduction and dtype in use follows this section.)
dtype: The dtype of the loss's computations. Defaults to NULL, which
means using config_floatx(). config_floatx() is "float32"
unless set to a different value (via config_set_floatx()). If a
keras$DTypePolicy is provided, then its compute_dtype will be
used.
__call__(y_true, y_pred, sample_weight=NULL)
Call the loss instance as a function, optionally with sample_weight.
get_config()
Readonly properties:

dtype
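A minimal sketch of the initialize() arguments in use, assuming the
loss_custom_mse() constructor defined above:

# No aggregation: one loss value per sample
mse_none <- loss_custom_mse(reduction = NULL)
# Sum the per-sample losses into a scalar
mse_sum <- loss_custom_mse(reduction = "sum")
# Compute the loss in float64 rather than the default config_floatx()
mse_f64 <- loss_custom_mse(dtype = "float64")

y_true <- op_arange(20) |> op_reshape(c(4, 5))
y_pred <- op_arange(20) |> op_reshape(c(4, 5)) * 2
mse_none(y_true, y_pred)  # tensor of shape (4)
mse_sum(y_true, y_pred)   # scalar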
All R function custom methods (public and private) will have the following symbols in scope:
self: The custom class instance.
super: The custom class superclass.
private: An R environment specific to the class instance.
Any objects assigned here are invisible to the Keras framework
(see the sketch below).
__class__ and as.symbol(classname): the custom class type object.
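For example, a helper placed in private is callable from the class methods
via the private symbol but stays invisible to Keras. A hypothetical sketch
(the HuberishLoss class and the pseudo_huber() helper are invented for
illustration):

loss_huberish <- Loss(
  classname = "HuberishLoss",
  call = function(y_true, y_pred) {
    # the private helper is reachable through `private` inside methods
    op_mean(private$pseudo_huber(y_pred - y_true), axis = -1)
  },
  private = list(
    pseudo_huber = function(error, delta = 1) {
      # smooth approximation of the Huber loss
      delta^2 * (op_sqrt(1 + op_square(error / delta)) - 1)
    }
  )
)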
Other losses:
loss_binary_crossentropy()
loss_binary_focal_crossentropy()
loss_categorical_crossentropy()
loss_categorical_focal_crossentropy()
loss_categorical_hinge()
loss_circle()
loss_cosine_similarity()
loss_ctc()
loss_dice()
loss_hinge()
loss_huber()
loss_kl_divergence()
loss_log_cosh()
loss_mean_absolute_error()
loss_mean_absolute_percentage_error()
loss_mean_squared_error()
loss_mean_squared_logarithmic_error()
loss_poisson()
loss_sparse_categorical_crossentropy()
loss_squared_hinge()
loss_tversky()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_categorical_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_hinge()
metric_huber()
metric_kl_divergence()
metric_log_cosh()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_poisson()
metric_sparse_categorical_crossentropy()
metric_squared_hinge()