# File: R/PipeOpDistrCompositor.R

#' @title PipeOpDistrCompositor
#' @name mlr_pipeops_distrcompose
#' @template param_pipelines
#'
#' @description
#' `r lifecycle::badge("experimental")`
#'
#' Estimates (or 'composes') a survival distribution from a predicted baseline
#' survival distribution (`distr`) and a linear predictor (`lp`) from two [PredictionSurv]s.
#'
#' Compositor Assumptions:
#' * The baseline `distr` is a discrete estimator, e.g. [surv.kaplan][LearnerSurvKaplan].
#' * The composed `distr` is of a linear form
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops] or with the associated sugar
#' function [mlr3pipelines::po()]:
#' ```
#' PipeOpDistrCompositor$new()
#' mlr_pipeops$get("distrcompose")
#' po("distrcompose")
#' ```
#'
#' @section Input and Output Channels:
#' [PipeOpDistrCompositor] has two input channels, `"base"` and `"pred"`.
#' Both input channels take `NULL` during training and [PredictionSurv] during prediction.
#'
#' [PipeOpDistrCompositor] has one output channel named `"output"`, producing
#' `NULL` during training and a [PredictionSurv] during prediction.
#'
#' The output during prediction is the [PredictionSurv] from the `"pred"` input
#' but with an extra (or overwritten) column for the `distr` predict type; which
#' is composed from the `distr` of `"base"` and the `lp` of `"pred"`.
#' If no `lp` predictions have been made or exist, then the `"pred"` is returned unchanged.
#'
#' @section State:
#' The `$state` is left empty (`list()`).
#'
#' @section Parameters:
#' The parameters are:
#' * `form` :: `character(1)` \cr
#'    Determines the form that the predicted linear survival model should take. This is one of
#'    accelerated failure time (`aft`), proportional hazards (`ph`), or proportional odds (`po`).
#'    Default `aft`.
#' * `overwrite` :: `logical(1)` \cr
#'    If `FALSE` (default) then if the "pred" input already has a `distr`, the compositor does
#'    nothing and returns the given [PredictionSurv]. If `TRUE`, then the `distr` is overwritten
#'    with the `distr` composed from `lp` - this is useful for changing the prediction
#'    `distr` from one model form to another.
#' * `scale_lp` :: `logical(1)` \cr
#'    This option is only applicable to `form` equal to `"aft"`. If `TRUE`, it
#'    min-max scales the linear prediction scores to be in the interval \eqn{[0,1]},
#'    avoiding extrapolation of the baseline \eqn{S_0(t)} on the transformed time
#'    points \eqn{\frac{t}{\exp(lp)}}, as these will be \eqn{\in [\frac{t}{e}, t]},
#'    and so always smaller than the maximum time point for which we have estimated
#'    \eqn{S_0(t)}.
#'    Note that this is just a **heuristic** to get reasonable results in the
#'    case you observe survival predictions to be e.g. constant after the AFT
#'    composition and it definitely provides no guarantee for creating calibrated
#'    distribution predictions (as none of these methods do). Therefore, it is
#'    set to `FALSE` by default.
#'
#' @section Internals:
#' The respective `form`s above have respective survival distributions:
#'    \deqn{aft: S(t) = S_0(\frac{t}{\exp(lp)})}
#'    \deqn{ph: S(t) = S_0(t)^{\exp(lp)}}
#'    \deqn{po: S(t) = \frac{S_0(t)}{\exp(-lp) + (1-\exp(-lp)) S_0(t)}}
#' where \eqn{S_0} is the estimated baseline survival distribution, and \eqn{lp} is the
#' predicted linear predictor.
#'
#' For an example use of the `"aft"` composition using Kaplan-Meier as a baseline
#' distribution, see Norman et al. (2024).
#'
#' @seealso [pipeline_distrcompositor]
#' @references
#' `r format_bib("norman_2024")`
#' @export
#' @family survival compositors
#' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines"), quietly = TRUE)
#' \dontrun{
#'   library(mlr3)
#'   library(mlr3pipelines)
#'   task = tsk("rats")
#'
#'   base = lrn("surv.kaplan")$train(task)$predict(task)
#'   pred = lrn("surv.coxph")$train(task)$predict(task)
#'   # let's change the distribution prediction of Cox (Breslow-based) to an AFT form:
#'   pod = po("distrcompose", param_vals = list(form = "aft", overwrite = TRUE))
#'   pod$predict(list(base = base, pred = pred))[[1]]
#' }
PipeOpDistrCompositor = R6Class("PipeOpDistrCompositor",
  inherit = mlr3pipelines::PipeOp,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function(id = "distrcompose", param_vals = list()) {
      # every hyperparameter is tagged "predict": none is read during training
      param_set = ps(
        form = p_fct(default = "aft", levels = c("aft", "ph", "po"), tags = "predict"),
        overwrite = p_lgl(default = FALSE, tags = "predict"),
        scale_lp = p_lgl(default = FALSE, tags = "predict")
      )
      # set the values explicitly so `.predict()` can read them from `$values`
      param_set$set_values(form = "aft", overwrite = FALSE, scale_lp = FALSE)

      super$initialize(
        id = id,
        param_set = param_set,
        param_vals = param_vals,
        input = data.table(name = c("base", "pred"), train = "NULL", predict = "PredictionSurv"),
        output = data.table(name = "output", train = "NULL", predict = "PredictionSurv"),
        packages = c("mlr3proba", "distr6")
      )
    }
  ),

  private = list(
    # Training is a no-op: the state is reset to an empty list and `NULL` is
    # emitted on the single output channel.
    .train = function(inputs) {
      self$state = list()
      list(NULL)
    },

    # Composes a new `distr` from the baseline survival of `inputs$base` and the
    # linear predictor of `inputs$pred`.
    #
    # Returns a one-element list holding either `inputs$pred` unchanged (no `lp`
    # available, or a `distr` already exists and `overwrite` is `FALSE`) or a new
    # `PredictionSurv` that copies `row_ids`, `truth`, `crank`, `lp` and
    # `response` from `inputs$pred` and carries the composed `distr`.
    .predict = function(inputs) {
      base = inputs$base
      pred = inputs$pred

      # if no `lp` predictions, we return the survival prediction object unchanged
      if (is.null(pred$lp)) {
        return(list(pred))
      }

      # `overwrite` must be a single non-missing flag for the scalar test below
      overwrite = assert_logical(self$param_set$values$overwrite,
        any.missing = FALSE, len = 1L)
      # scalar condition => short-circuit `&&` instead of the elementwise `&`
      if ("distr" %in% pred$predict_types && !overwrite) {
        return(list(pred))
      }

      assert("distr" %in% base$predict_types)

      # check: targets are the same
      assert_true(identical(base$truth, pred$truth))

      form = self$param_set$values$form
      nr = length(pred$data$row_ids)

      # we need 'lp' predictions
      lp = pred$lp

      # build an `nr` x `nc` matrix whose rows all hold the same baseline survival
      # curve evaluated on the `nc` baseline time points
      if (inherits(base$data$distr, "Distribution")) {
        base = distr6::as.MixtureDistribution(base$distr)
        times = unlist(base[1L]$properties$support$elements)
        nc = length(times)
        survmat = matrix(1 - base$cdf(times), nrow = nr, ncol = nc, byrow = TRUE)
      } else {
        # average survival probability across observations (on the test set)
        avg_surv = colMeans(base$data$distr)
        times = as.numeric(names(avg_surv))
        nc = length(times)
        survmat = matrix(avg_surv, nrow = nr, ncol = nc, byrow = TRUE)
      }

      # row i of `timesmat` holds the time points, row i of `lpmat` repeats lp[i]
      timesmat = matrix(times, nrow = nr, ncol = nc, byrow = TRUE)
      lpmat = matrix(lp, nrow = nr, ncol = nc)

      # compose survival distribution
      if (form == "ph") {
        cdf = 1 - (survmat ^ exp(lpmat))
      } else if (form == "aft") {
        # add heuristic to keep the transformed t/exp(lp) time points within
        # the domain of S_0(t)
        if (self$param_set$values$scale_lp) {
          lp_min = min(lpmat)
          lp_width = max(lpmat) - lp_min
          if (isTRUE(lp_width == 0)) {
            # all `lp` values are identical (e.g. a single observation): min-max
            # scaling would be 0/0 = NaN and silently produce S(t) = 1 for every
            # time point; use 0 so that t/exp(lp) = t stays inside the domain
            lpmat[] = 0
          } else {
            lpmat = (lpmat - lp_min) / lp_width
          }
        }
        # calculate cdf = 1 - S_0(t) on the time points t/exp(lp)
        # `findInterval()` drops the matrix dims and returns the indices in
        # column-major order, which the `matrix()` call below restores
        mtc = findInterval(timesmat / exp(lpmat), times)
        # index 0 means "before the first baseline time point" => S_0 = 1, cdf = 0
        mtc[mtc == 0] = NA
        cdf = 1 - matrix(survmat[1L, mtc], nr, nc, FALSE)
        cdf[is.na(cdf)] = 0
      } else if (form == "po") {
        cdf = 1 - (survmat * ((exp(-lpmat) + ((1 - exp(-lpmat)) * survmat))^-1))
        # where the baseline survival is exactly 1 the composed cdf is forced to 0
        cdf[survmat == 1] = 0
      }

      distr = .surv_return(times = times, surv = 1 - cdf)$distr

      p = PredictionSurv$new(
        row_ids = pred$row_ids,
        truth = pred$truth,
        crank = pred$crank,
        lp = pred$lp,
        response = pred$response,
        distr = distr # overwrite only the distribution
      )

      list(p)
    }
  )
)

# Register the class under the key "distrcompose".
# NOTE(review): `register_pipeop()` is defined elsewhere; presumably it adds the
# class to `mlr3pipelines::mlr_pipeops` so that `po("distrcompose")` works - confirm.
register_pipeop("distrcompose", PipeOpDistrCompositor)
# Source: mlr-org/mlr3proba documentation built on April 12, 2025, 4:38 p.m.