#' @template surv_measure
#' @templateVar title Negative Log-Likelihood
#' @templateVar fullname MeasureSurvLogloss
#' @templateVar eps 1e-15
#' @template param_eps
#' @template param_se
#' @template param_erv
#'
#' @description
#' Calculates the cross-entropy, or negative log-likelihood (NLL) or logarithmic (log), loss.
#' @section Parameter details:
#' - `IPCW` (`logical(1)`)\cr
#' If `TRUE` (default) then returns the \eqn{L_{RNLL}} score (which is proper), otherwise the \eqn{L_{NLL}} score (improper). See Sonabend et al. (2024) for more details.
#'
#' @details
#' The Log Loss, in the context of probabilistic predictions, is defined as the
#' negative log probability density function, \eqn{f}, evaluated at the
#' observation time (event or censoring), \eqn{t},
#' \deqn{L_{NLL}(f, t) = -\log[f(t)]}
#'
#' The standard error of the Log Loss, L, is approximated via,
#' \deqn{se(L) = sd(L)/\sqrt{N}}{se(L) = sd(L)/\sqrt N}
#' where \eqn{N} is the number of observations in the test set, and \eqn{sd} is the standard
#' deviation.
#'
#' The **Re-weighted Negative Log-Likelihood** (RNLL) or IPCW (Inverse Probability Censoring Weighted) Log Loss is defined by
#' \deqn{L_{RNLL}(f, t, \delta) = - \frac{\delta \log[f(t)]}{G(t)}}
#' where \eqn{\delta} is the censoring indicator and \eqn{G(t)} is the Kaplan-Meier estimator of the
#' censoring distribution.
#' So only observations that have experienced the event are taken into account
#' for RNLL (i.e. \eqn{\delta = 1}) and both \eqn{f(t), G(t)} are calculated only at the event times.
#' If only censored observations exist in the test set, `NaN` is returned.
#'
#' @template details_trainG
#'
#' @references
#' `r format_bib("sonabend2024")`
#'
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @export
MeasureSurvLogloss = R6Class("MeasureSurvLogloss",
  inherit = MeasureSurv,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #' @param ERV (`logical(1)`)\cr
    #' Standardize measure against a Kaplan-Meier baseline
    #' (Explained Residual Variation)
    initialize = function(ERV = FALSE) {
      # `ERV` is used as a scalar condition (`if (ERV)`, `!ERV`) below, so
      # require exactly one non-missing logical instead of any logical vector.
      assert_logical(ERV, any.missing = FALSE, len = 1L)

      # Named `param_set` so the local does not shadow the `ps()` constructor.
      param_set = ps(
        eps = p_dbl(0, 1, default = 1e-15),
        se = p_lgl(default = FALSE),
        IPCW = p_lgl(default = TRUE),
        ERV = p_lgl(default = FALSE)
      )
      param_set$set_values(eps = 1e-15, se = FALSE, IPCW = TRUE, ERV = ERV)

      super$initialize(
        id = "surv.logloss",
        # With ERV the measure is declared on (-Inf, 1] and maximized;
        # without it the measure is declared on [0, Inf) and minimized.
        range = if (ERV) c(-Inf, 1) else c(0, Inf),
        minimize = !ERV,
        predict_type = "distr",
        packages = "distr6",
        label = "Log Loss",
        man = "mlr3proba::mlr_measures_surv.logloss",
        param_set = param_set
      )

      invisible(self)
    }
  ),

  private = list(
    # Scores a prediction with the current parameter values.
    #
    # - If the `ERV` parameter is set, the work is delegated to
    #   `.scoring_rule_erv()` and its result is returned unchanged.
    # - Otherwise `surv_logloss()` is called once; its result is summarized
    #   as `sd / sqrt(length)` when `se` is set and as the mean otherwise.
    #   NOTE(review): `surv_logloss()` is defined elsewhere; it is assumed to
    #   return one numeric loss per observation - confirm.
    #
    # `task` and `train_set` must be supplied together or both be NULL;
    # supplying exactly one of them is an error.
    .score = function(prediction, task, train_set, ...) {
      pv = self$param_set$values

      if (pv$ERV) {
        return(.scoring_rule_erv(self, prediction, task, train_set))
      }

      has_task = !is.null(task)
      has_train_set = !is.null(train_set)
      if (xor(has_task, has_train_set)) {
        stop("Either 'task' and 'train_set' should be passed to measure or neither.")
      }

      # After the check above both flags agree, so testing one is sufficient.
      train = if (has_task) task$truth(train_set) else NULL

      # Compute the losses once and pick the summary afterwards.
      ll = surv_logloss(prediction$truth, prediction$data$distr, pv$eps, pv$IPCW, train) # nolint

      if (pv$se) {
        sd(ll) / sqrt(length(ll))
      } else {
        mean(ll)
      }
    }
  )
)
# Register the measure class under the key "surv.logloss", the same string
# that the class passes as its `id` in `initialize()`.
register_measure("surv.logloss", MeasureSurvLogloss)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.