# R/PipeOpPredRegrSurvPEM.R

#' @title PipeOpPredRegrSurvPEM
#' @name mlr_pipeops_trafopred_regrsurv_pem
#'
#' @description
#' Transform [PredictionRegr] to [PredictionSurv].
#' The predicted piece-wise constant hazards contained in [PredictionRegr] are transformed into survival probabilities and wrapped in a
#' [PredictionSurv] object.
#'
#'  We compute the survival probability from the predicted hazards using the following relation:
#'  \deqn{S(t | \mathbf{x}) = \exp \left( - \int_{0}^{t} \lambda(s | \mathbf{x}) \, ds \right) = \exp \left( - \sum_{j = 1}^{J} \lambda(j | \mathbf{x}) d_j\,  \right),}
#'  where \eqn{j = 1, \ldots, J} indexes the intervals up to and including the one ending at time \eqn{t}, and \eqn{d_j} is the duration of interval \eqn{j}.
#'
#'  For a more detailed description of PEM, refer to [pipeline_survtoregr_pem] or the referred article.
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
#' or with the associated sugar function [mlr3pipelines::po()]:
#' ```
#' PipeOpPredRegrSurvPEM$new()
#' mlr_pipeops$get("trafopred_regrsurv_pem")
#' po("trafopred_regrsurv_pem")
#' ```
#'
#' @section Input and Output Channels:
#' The input consists of a [PredictionRegr] and a [data.table][data.table::data.table]
#' containing the transformed data. The [PredictionRegr] is provided by the [mlr3::LearnerRegr],
#' while the [data.table] is generated by [PipeOpTaskSurvRegrPEM].
#' The output is the input [PredictionRegr] transformed to a [PredictionSurv].
#' Only works during prediction phase.
#'
#' @references
#' `r format_bib("bender_2018")`
#'
#' @seealso [pipeline_survtoregr_pem]
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpPredRegrSurvPEM = R6Class(
  "PipeOpPredRegrSurvPEM",
  inherit = mlr3pipelines::PipeOp,

  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #' @param id (character(1))\cr
    #' Identifier of the resulting object.
    initialize = function(id = "trafopred_regrsurv_pem") {
      # Two input channels (regression prediction + transformed data) and a
      # single output channel; nothing but the data.table is consumed
      # during training.
      super$initialize(
        id = id,
        input = data.table(
          name = c("input", "transformed_data"),
          train = c("NULL", "data.table"),
          predict = c("PredictionRegr", "data.table")
        ),
        output = data.table(
          name = "output",
          train = "NULL",
          predict = "PredictionSurv"
        )
      )
    }
  ),

  active = list(
    #' @field predict_type (`character(1)`)\cr
    #' Returns the active predict type of this PipeOp, which is `"crank"`.
    #' Read-only.
    predict_type = function(rhs) {
      assert_ro_binding(rhs)
      "crank"
    }
  ),

  private = list(
    # Converts the predicted hazards into a survival matrix with one row per
    # observation id and one column per interval, and wraps it together with
    # the observed time and status of each id in a PredictionSurv.
    # `input` is a list: element 1 is the regression prediction, element 2 the
    # transformed data with columns id, tend, obs_times, offset and pem_status.
    # Returns a list of length one holding the PredictionSurv.
    .predict = function(input) {
      pred = input[[1]] # predicted hazards provided by the regression learner
      data = input[[2]] # transformed data (several rows per observation id)
      assert_true(!is.null(pred$response))
      # exactly one predicted hazard is needed per row of the transformed data
      assert_true(length(pred$response) == nrow(data))

      ids = unique(data$id)
      n_rows = nrow(data)
      rows_per_id = n_rows / length(ids)

      # Row positions of every id, ordered like `ids` (order of first
      # appearance). Grouping once avoids rescanning the table for each id.
      rows_by_id = unname(split(seq_len(n_rows), match(data$id, ids)))
      if (any(lengths(rows_by_id) != rows_per_id)) {
        stop("Every id in the transformed data must have the same number of rows (intervals).")
      }

      # From theory, convert hazards to surv as exp(-cumsum(h(t) * exp(offset)))
      cumhaz_increment = pred$response * exp(data[["offset"]])
      surv_by_id = lapply(rows_by_id, function(rows) {
        exp(-cumsum(cumhaz_increment[rows]))
      })
      # Build the matrix with explicit dimensions: simplifying via vapply()
      # followed by t() yields a 1 x n_ids matrix when there is only a single
      # interval per id.
      surv = matrix(
        unlist(surv_by_id, use.names = FALSE),
        nrow = length(ids), ncol = rows_per_id, byrow = TRUE
      )

      unique_end_times = sort(unique(data$tend))
      # coerce to distribution and crank
      pred_list = .surv_return(times = unique_end_times, surv = surv)

      # The last row of every id holds its observed time and status; use the
      # same row index for both so they always refer to the same row.
      last_rows = vapply(rows_by_id, function(rows) rows[length(rows)], integer(1))
      real_tend = data$obs_times[last_rows]
      status = as.integer(as.character(data$pem_status[last_rows]))

      # create prediction object
      p = PredictionSurv$new(
        row_ids = ids,
        crank = pred_list$crank, distr = pred_list$distr,
        truth = Surv(real_tend, status))

      list(p)
    },

    # Nothing is learned here: an empty state is stored and the input is
    # handed back wrapped in a list.
    # NOTE(review): the declared train output type is "NULL" while this
    # returns the input list - confirm this is intended.
    .train = function(input) {
      self$state = list()
      list(input)
    }
  )
)
# Register the class under the key "trafopred_regrsurv_pem", the key the class
# documentation uses with mlr_pipeops$get() and po().
# NOTE(review): register_pipeop() is defined elsewhere; presumably it adds the
# entry to the mlr_pipeops dictionary - confirm.
register_pipeop("trafopred_regrsurv_pem", PipeOpPredRegrSurvPEM)
# mlr-org/mlr3proba documentation built on April 12, 2025, 4:38 p.m.