#' @title Prediction Object for Survival
#'
#' @description
#' This object stores the predictions returned by a learner of class [LearnerSurv].
#'
#' The `task_type` is set to `"surv"`.
#'
#' For accessing survival and hazard functions, as well as other complex methods
#' from a [PredictionSurv] object, see public methods on [distr6::ExoticStatistics()]
#' and example below.
#'
#' @family Prediction
#' @export
#' @examples
#' library(mlr3)
#' task = tsk("rats")
#' learner = lrn("surv.kaplan")
#' p = learner$train(task, row_ids = 1:26)$predict(task, row_ids = 27:30)
#' head(as.data.table(p))
#'
#' p$distr # distr6::Matdist class (test obs x time points)
#'
#' # survival probabilities of the 4 test rats at two time points
#' p$distr$survival(c(20, 100))
PredictionSurv = R6Class("PredictionSurv",
inherit = Prediction,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @details
#' Upon **initialization**, if the `distr` input is a [Distribution][distr6::Distribution],
#' we try to coerce it either to a survival matrix or a survival array and store it
#' in the `$data$distr` slot for internal use.
#'
#' If the stored `$data$distr` is a [Distribution][distr6::Distribution] object,
#' the active field `$distr` (**external user API**) returns it without modification.
#' Otherwise, if `$data$distr` is a survival matrix or array, `$distr`
#' constructs a distribution out of the `$data$distr` object, which will be a
#' [Matdist][distr6::Matdist] or [Arrdist][distr6::Arrdist] respectively.
#'
#' Note that if a survival 3d array is stored in `$data$distr`, the `$distr`
#' field returns an [Arrdist][distr6::Arrdist] initialized with `which.curve = 0.5`
#' by default (i.e. the median curve). This means that measures that require
#' a `distr` prediction like [MeasureSurvGraf], [MeasureSurvRCLL], etc.
#' will use the median survival probabilities.
#' Note that it is possible to manually change `which.curve` after construction
#' of the predicted distribution but we advise against this as it may lead to
#' inconsistent results.
#'
#' @param task ([TaskSurv])\cr
#' Task, used to extract defaults for `row_ids` and `truth`.
#'
#' @param row_ids (`integer()`)\cr
#' Row ids of the predicted observations, i.e. the row ids of the test set.
#'
#' @param truth (`survival::Surv()`)\cr
#' True (observed) response.
#'
#' @param crank (`numeric()`)\cr
#' Numeric vector of predicted continuous rankings (or relative risks). One element for each
#' observation in the test set. For a pair of continuous ranks, a higher rank indicates that
#' the observation is more likely to experience the event.
#'
#' @param distr (`matrix()|[distr6::Arrdist]|[distr6::Matdist]|[distr6::VectorDistribution]`)\cr
#' Either a matrix of predicted survival probabilities, a [distr6::VectorDistribution],
#' a [distr6::Matdist] or an [distr6::Arrdist].
#' If a matrix/array then column names must be given and correspond to survival times.
#' Rows of matrix correspond to individual predictions. It is advised that the
#' first column should be time `0` with all entries `1` and the last
#' with all entries `0`. If a `VectorDistribution` then each distribution in the vector
#' should correspond to a predicted survival distribution.
#'
#' @param lp (`numeric()`)\cr
#' Numeric vector of linear predictor scores. One element for each observation in the test
#' set. \eqn{lp = X\beta} where \eqn{X} is a matrix of covariates and \eqn{\beta} is a vector
#' of estimated coefficients.
#'
#' @param response (`numeric()`)\cr
#' Numeric vector of predicted survival times.
#' One element for each observation in the test set.
#'
#' @param check (`logical(1)`)\cr
#' If `TRUE`, performs argument checks and predict type conversions.
initialize = function(task = NULL, row_ids = task$row_ids, truth = task$truth(), crank = NULL,
distr = NULL, lp = NULL, response = NULL, check = TRUE) {
if (inherits(distr, "Distribution")) {
# coerce to matrix/array if possible
distr = private$.simplify_distr(distr)
}
pdata = list(row_ids = row_ids, truth = truth, crank = crank, distr = distr,
lp = lp, response = response)
pdata = discard(pdata, is.null)
class(pdata) = c("PredictionDataSurv", "PredictionData")
if (check) {
pdata = check_prediction_data(pdata)
}
self$task_type = "surv"
self$man = "mlr3proba::PredictionSurv"
self$data = pdata
self$predict_types = intersect(c("crank", "distr", "lp", "response"), names(pdata))
}
),
active = list(
#' @field truth (`Surv`)\cr
#' True (observed) outcome.
truth = function() {
self$data$truth
},
#' @field crank (`numeric()`)\cr
#' Access the stored predicted continuous ranking.
crank = function() {
self$data$crank %??% rep(NA_real_, length(self$data$row_ids))
},
#' @field distr ([distr6::Matdist]|[distr6::Arrdist]|[distr6::VectorDistribution])\cr
#' Convert the stored survival array or matrix to a survival distribution.
distr = function() {
if (inherits(self$data$distr, "Distribution")) {
return(self$data$distr)
}
private$.distrify_survarray(self$data$distr)
},
#' @field lp (`numeric()`)\cr
#' Access the stored predicted linear predictor.
lp = function() {
self$data$lp
},
#' @field response (`numeric()`)\cr
#' Access the stored predicted survival time.
response = function() {
self$data$response
}
),
private = list(
.censtype = NULL,
.distr = function() self$data$distr %??% NA_real_,
.simplify_distr = function(x) {
if (inherits(x, c("Matdist", "Arrdist"))) {
1 - gprm(x, "cdf") # matrix or 3d array
} else {
if (!inherits(x, "VectorDistribution")) {
stop("'x' is not a 'VectorDistribution'")
}
## check all distributions equal - return x if not
if (x$distlist) {
return(x)
}
## check all distributions are WeightedDiscrete - return x if not
if (all(x$modelTable$Distribution != "WeightedDiscrete")) {
return(x)
}
times = x$getParameterValue("x")
time1 = times[[1L]]
## check all times equal - return x if not
if (!all(map_lgl(times, identical, y = time1))) {
return(x)
}
surv = 1 - do.call(rbind, x$getParameterValue("cdf"))
rownames(surv) = NULL
colnames(surv) = time1
surv
}
},
.distrify_survarray = function(x) {
if (inherits(x, "array") && nrow(x) > 0L) { # can be matrix as well
# create Matdist or Arrdist (default => median curve)
distr6::as.Distribution(1 - x, fun = "cdf",
decorators = c("CoreStatistics", "ExoticStatistics"))
} else {
NULL
}
}
)
)
#' @export
as.data.table.PredictionSurv = function(x, ...) { # nolint
tab = as.data.table(x$data[c("row_ids", "crank", "lp", "response")])
tab$time = x$data$truth[, 1L]
tab$status = as.logical(x$data$truth[, 2L])
if ("distr" %in% x$predict_types) {
# annoyingly need this many lists to get nice printing
tab$distr = list(list(list(r6_private(x)$.distr())))
}
setcolorder(tab, c("row_ids", "time", "status"))[]
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.