R/surv_return.R

Defines functions get_mortality .ext_surv_mat .surv_return

Documented in get_mortality .surv_return

#' @title Get Survival Predict Types
#'
#' @description Internal helper function to easily return the correct survival predict types.
#'
#' @param times (`numeric()`) \cr Vector of survival times.
#' If `NULL`, the column names of `surv` are used.
#' @param surv (`matrix()|array()`)\cr Matrix or array of predicted survival probabilities, rows (1st dimension) are observations, columns (2nd dimension) are times and in the case of an array there should be one more dimension.
#' Number of columns should be equal to length of `times`.
#' In case a `numeric()` (double or integer) vector is provided, it is converted to a single row (one observation) matrix.
#' @param crank (`numeric()`)\cr Relative risk/continuous ranking.
#' Higher value is associated with higher risk.
#' If `NULL` then either set as `-response` if available or `lp` if available (this assumes that the `lp` prediction comes from a PH type model - in case of an AFT model the user should provide `-lp`).
#' In case neither `response` nor `lp` are provided, then `crank` is calculated as the sum of the cumulative hazard function (**expected mortality**) derived from the predicted survival function (`surv`), see [get_mortality].
#' In case `surv` is a 3d array, we use the `which.curve` parameter to decide which survival matrix (index in the 3rd dimension) will be chosen for the calculation of `crank`.
#' @param lp (`numeric()`)\cr Predicted linear predictor, used to impute `crank` if `NULL`.
#' @param response (`numeric()`)\cr Predicted survival time, passed through function without
#' modification.
#' @param which.curve Which curve (3rd dimension) should the `crank` be calculated for, in case `surv` is an `array`?
#' If between (0,1) it is taken as the quantile of the curves, otherwise if greater than or equal to 1 it is taken as the curve index.
#' It can also be 'mean' and the survival probabilities are averaged across the 3rd dimension.
#' Default value (`NULL`) is the **0.5 quantile** which is the median across the 3rd dimension of the survival array.
#'
#' @return A `list()` with elements `distr` (the survival matrix/array with its
#' column names set to `times`, or `NULL` if no `surv` was given), `crank`, `lp`
#' and `response`.
#'
#' @references
#' `r format_bib("sonabend_2022")`
#'
#' @examples
#' n = 10 # number of observations
#' k = 50 # time points
#'
#' # Create the matrix with random values between 0 and 1
#' mat = matrix(runif(n * k, min = 0, max = 1), nrow = n, ncol = k)
#'
#' # transform it to a survival matrix
#' surv_mat = t(apply(mat, 1L, function(row) sort(row, decreasing = TRUE)))
#'
#' # crank is expected mortality, distr is the survival matrix
#' .surv_return(times = 1:k, surv = surv_mat)
#'
#' # if crank is set, it's not overwritten
#' .surv_return(times = 1:k, surv = surv_mat, crank = rnorm(n))
#'
#' # lp = crank
#' .surv_return(lp = rnorm(n))
#'
#' # if response is set and no crank, crank = -response
#' .surv_return(response = sample(1:100, n))
#'
#' # if both are set, they are not overwritten
#' .surv_return(crank = rnorm(n), response = sample(1:100, n))
#'
#' @export
.surv_return = function(times = NULL, surv = NULL, crank = NULL, lp = NULL,
  response = NULL, which.curve = NULL) {

  if (!is.null(surv)) {
    if (is.numeric(surv) && is.null(dim(surv))) {
      # a plain (double or integer) vector is one observation: convert it to a
      # 1-row matrix, keeping any names as the column (time point) names
      surv = matrix(surv, nrow = 1L, dimnames = list(NULL, names(surv)))
    }
    if (is.array(surv) && !is.matrix(surv) && length(dim(surv)) != 3L) {
      stop("3D survival arrays supported only")
    }
    # time points default to the column names of `surv`
    if (is.null(times)) {
      times = colnames(surv)
    }
    if (length(times) != ncol(surv)) {
      stop("'times' must have the same length as the 2nd dimension (columns of 'surv')")
    }
    colnames(surv) = times
  }

  if (is.null(crank)) {
    if (!is.null(response)) {
      # low survival time = high risk
      # high crank = high risk
      crank = -response
    } else if (!is.null(lp)) {
      # assumes PH-type lp where high value = high risk
      crank = lp
    } else if (!is.null(surv)) {
      # expected mortality; a 3D array is first reduced to a single survival
      # matrix according to `which.curve`
      surv_mat = if (is.matrix(surv)) surv else .ext_surv_mat(surv, which.curve)
      crank = get_mortality(surv_mat)
    }
  }

  # TODO: pass the 'which.curve' parameter in PredictionSurv
  list(
    distr = surv, # matrix or array
    crank = crank,
    lp = lp,
    response = response
  )
}

# Helper function to extract a single survival matrix (observations x time
# points) from a 3D survival array (observations x time points x curves).
#
# `which.curve` selects how the 3rd dimension is reduced:
# - `NULL` (default): the median (0.5 quantile) across curves, per cell;
# - `"mean"`: the mean across curves, per cell;
# - a number in (0, 1): that quantile across curves, per cell;
# - a whole number >= 1: the curve with that index in the 3rd dimension.
# The result is always a matrix with nrow(arr) rows and ncol(arr) columns that
# keeps the row/column dimnames of `arr`.
.ext_surv_mat = function(arr, which.curve) {
  n_obs = nrow(arr)
  n_times = ncol(arr)
  mat_dimnames = dimnames(arr)[1:2]

  # reduce the 3rd dimension cell-wise with `fun`; wrapping in array() keeps
  # the (n_obs x n_times) shape and the original dimnames
  collapse_curves = function(fun, ...) {
    array(apply(arr, 1:2, fun, ...), c(n_obs, n_times), mat_dimnames)
  }

  # if NULL return the 'median' curve (default)
  if (is.null(which.curve)) {
    return(collapse_curves(stats::quantile, 0.5))
  }

  # which.curve must be length 1, non-missing and either 'mean' or > 0
  ok = length(which.curve) == 1L && !is.na(which.curve) &&
    ((is.character(which.curve) && which.curve == "mean") ||
      (is.numeric(which.curve) && which.curve > 0))
  if (!ok) {
    stop("'which.curve' has to be a numeric between (0,1) or the index of the 3rd dimension or 'mean'")
  }

  if (is.numeric(which.curve) && which.curve > dim(arr)[3L]) {
    stopf(
      "Length is %s on third dimension but curve '%s' requested, change 'which.curve' parameter.",
      dim(arr)[3L],
      which.curve
    )
  }

  # after validation the only accepted character value is 'mean'
  if (is.character(which.curve)) {
    return(collapse_curves(mean))
  }

  # curve chosen based on quantile
  if (which.curve < 1) {
    return(collapse_curves(stats::quantile, which.curve))
  }

  # curve chosen based on index: reject fractional indices instead of letting
  # `[` silently truncate them
  if (which.curve != round(which.curve)) {
    stopf("'which.curve' index must be a whole number, but '%s' was given.", which.curve)
  }
  array(arr[, , which.curve], c(n_obs, n_times), mat_dimnames)
}

#' @title Calculate the expected mortality risks from a survival matrix
#'
#' @description Many methods can be used to reduce a discrete survival
#' distribution prediction (i.e. matrix) to a relative risk / ranking
#' prediction, see Sonabend et al. (2022).
#'
#' This function computes a relative risk score by summing, per observation,
#' the predicted cumulative hazard \eqn{H(t) = -\log S(t)} over all time points.
#' This score is also called **ensemble/expected mortality**.
#' It can be loosely interpreted as the expected number of deaths for
#' patients with similar characteristics (see Ishwaran et al. (2008)), and it
#' makes no model or survival distribution assumptions.
#' Cells with an infinite cumulative hazard (e.g. where \eqn{S(t) = 0}) are
#' ignored, i.e. they contribute 0 to the sum.
#'
#' @param x (`matrix()`) \cr A survival matrix where rows are the
#' (predicted) observations and columns the time-points.
#' For more details, see [assert_surv_matrix].
#'
#' @return a `numeric` vector of the mortality risk scores, one per row of the
#' input survival matrix.
#'
#' @references
#' `r format_bib("sonabend_2022", "ishwaran_2008")`
#'
#' @examples
#' n = 10 # number of observations
#' k = 50 # time points
#'
#' # Create the matrix with random values between 0 and 1
#' mat = matrix(runif(n * k, min = 0, max = 1), nrow = n, ncol = k)
#'
#' # transform it to a survival matrix
#' surv_mat = t(apply(mat, 1L, function(row) sort(row, decreasing = TRUE)))
#' colnames(surv_mat) = 1:k # time points
#'
#' # get mortality scores (the larger, the more risk)
#' mort = get_mortality(surv_mat)
#' mort
#'
#' @export
get_mortality = function(x) {
  assert_surv_matrix(x)

  # per-cell cumulative hazard: H(t) = -log(S(t))
  hazard = -log(x)

  # infinite entries (S(t) = 0 gives -log(0) = Inf) are counted as 0 before
  # summing each row (observation) over the time points
  rowSums(replace(hazard, is.infinite(hazard), 0))
}
mlr-org/mlr3proba documentation built on April 12, 2025, 4:38 p.m.