R/PipeOpPredClassifSurvIPCW.R

#' @title PipeOpPredClassifSurvIPCW
#' @name mlr_pipeops_trafopred_classifsurv_IPCW
#'
#' @description
#' Transform [PredictionClassif] to [PredictionSurv] using the **I**nverse
#' **P**robability of **C**ensoring **W**eights (IPCW) method by Vock et al. (2016).
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
#' or with the associated sugar function [mlr3pipelines::po()]:
#' ```
#' PipeOpPredClassifSurvIPCW$new()
#' mlr_pipeops$get("trafopred_classifsurv_IPCW")
#' po("trafopred_classifsurv_IPCW")
#' ```
#'
#' @section Input and Output Channels:
#' During training, both inputs and the single output are `NULL`.
#'
#' During prediction, the input is a [PredictionClassif] and a `list`
#' containing the original row ids, observed times, event (status) indicators
#' and the cutoff time \eqn{\tau}, all generated by [PipeOpTaskSurvClassifIPCW].
#' The classification learner must predict probabilities
#' (`predict_type = "prob"`), with the event class labelled `"1"`.
#'
#' The output is the input [PredictionClassif] transformed to a [PredictionSurv].
#' Each input classification probability prediction corresponds to the
#' probability of having the event up to the specified cutoff time
#' \eqn{\hat{\pi}(\bold{X}_i) = P(T_i < \tau|\bold{X}_i)},
#' see Vock et al. (2016) and [PipeOpTaskSurvClassifIPCW].
#' Therefore, these predictions serve as **continuous risk scores** that can be
#' directly interpreted as `crank` predictions in the right-censored survival
#' setting. We also map them to the survival distribution prediction `distr`,
#' at the specified cutoff time point \eqn{\tau}, i.e. as
#' \eqn{S_i(\tau) = 1 - \hat{\pi}(\bold{X}_i)}.
#' Survival measures that use the survival distribution (e.g. [ISBS][mlr_measures_surv.brier])
#' should be evaluated exactly at the cutoff time point \eqn{\tau}, see the
#' example in [pipeline_survtoclassif_IPCW].
#'
#' @references
#' `r format_bib("vock_2016")`
#'
#' @seealso [pipeline_survtoclassif_IPCW]
#' @family Transformation PipeOps
#' @export
PipeOpPredClassifSurvIPCW = R6Class("PipeOpPredClassifSurvIPCW",
  inherit = mlr3pipelines::PipeOp,

  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #' @param id (character(1))\cr
    #' Identifier of the resulting object.
    initialize = function(id = "trafopred_classifsurv_IPCW") {
      super$initialize(
        id = id,
        input = data.table(
          name = c("input", "data"),
          train = c("NULL", "NULL"),
          predict = c("PredictionClassif", "list")
        ),
        output = data.table(
          name = "output",
          train = "NULL",
          predict = "PredictionSurv"
        )
      )
    }
  ),

  active = list(
    #' @field predict_type (`character(1)`)\cr
    #' Returns the active predict type of this PipeOp, which is `"crank"`
    predict_type = function(rhs) {
      assert_ro_binding(rhs)
      "crank"
    }
  ),

  private = list(
    # Builds a PredictionSurv from the classification probabilities of the
    # event class "1" and the original survival data passed alongside them.
    .predict = function(input) {
      pred = input[[1]] # classification predictions
      data = input[[2]] # row_ids, times, status, tau

      # The transformation needs the predicted probability of the event
      # class "1"; fail early with an informative message instead of an
      # opaque subsetting error (e.g. when only responses were predicted).
      prob = pred$prob
      if (is.null(prob) || !("1" %in% colnames(prob))) {
        stop(
          "PipeOp '", self$id, "' requires probability predictions with a ",
          "column for the event class '1'; set predict_type = \"prob\" ",
          "for the classification learner.",
          call. = FALSE
        )
      }

      # risk => prob of having the event up until the cutoff time
      risk = prob[, "1"]
      # single-column survival matrix S_i(tau) = 1 - risk, column named by tau
      surv = matrix(data = 1 - risk, ncol = 1)
      colnames(surv) = data$tau

      p = PredictionSurv$new(
        # the original row ids
        row_ids = data$row_ids,
        # the original truth (times, status)
        truth = Surv(time = data$times, event = data$status),
        crank = risk,
        distr = surv
      )

      list(p)
    },

    # Nothing is learned during training. The single output channel is
    # declared with train type "NULL", so exactly one NULL is returned
    # (instead of forwarding the list of inputs).
    .train = function(input) {
      self$state = list()
      list(NULL)
    }
  )
)

# Register the class under the dictionary key "trafopred_classifsurv_IPCW",
# the key used with po() / mlr_pipeops$get(); register_pipeop() is a
# package-internal helper defined in another file.
register_pipeop(
  "trafopred_classifsurv_IPCW",
  PipeOpPredClassifSurvIPCW
)
mlr-org/mlr3proba documentation built on April 12, 2025, 4:38 p.m.