Nothing
#' @keywords internal
.validate_loss_fn = function(fn) {
# ---- Arity check ----
if (length(formals(fn)) < 2L) {
cli::cli_abort(c(
"Custom loss function must accept at least 2 arguments: {.arg input} and {.arg target}.",
i = "Use a lambda like {.code \\(input, target) torch::nnf_mse_loss(input, target)}."
), class = "loss_fn_arity_error")
}
# ---- Dry-run probe ----
if (requireNamespace("torch", quietly = TRUE)) {
dummy_input = torch::torch_randn(c(2L, 1L))
dummy_target = torch::torch_randn(c(2L, 1L))
out = tryCatch(
fn(dummy_input, dummy_target),
error = function(e) {
cli::cli_abort(c(
"Dry-run of custom loss function failed.",
x = "{e$message}",
i = "Ensure your function accepts two {.cls torch_tensor}s and returns a scalar tensor."
), class = "loss_fn_probe_error")
}
)
if (!inherits(out, "torch_tensor")) {
cli::cli_abort(c(
"Custom loss function must return a {.cls torch_tensor}.",
x = "Got {.cls {class(out)}}.",
i = "Ensure your function returns the result of a {.pkg torch} operation."
), class = "loss_fn_output_error")
}
if (out$numel() != 1L) {
cli::cli_warn(c(
"Custom loss function returned a non-scalar tensor with {out$numel()} elements.",
i = "Loss functions should return a scalar. Did you forget to reduce (e.g. {.code $mean()})?"
), class = "loss_fn_shape_warning")
}
}
# ---- Build the call-time type-guarded wrapper ----
function(input, target) {
out = fn(input, target)
if (!inherits(out, "torch_tensor")) {
cli::cli_abort(c(
"Custom loss function must return a {.cls torch_tensor}.",
x = "Got {.cls {class(out)}}."
), class = "loss_fn_output_error")
}
out
}
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.