R/surv-roc_curve_survival.R

Defines functions autoplot.roc_survival_df roc_curve_survival_impl_one roc_curve_survival_impl roc_curve_survival_vec roc_curve_survival.data.frame roc_curve_survival

Documented in roc_curve_survival roc_curve_survival.data.frame

#' Time-Dependent ROC surve for Censored Data
#'
#' Compute the ROC survival curve using predicted survival probabilities that
#' corresponds to different time points.
#'
#' @family survival curve metrics
#' @templateVar fn roc_curve_survival
#'
#' @inheritParams brier_survival
#'
#' @details
#'
#' This formulation takes survival probability predictions at one or more
#' specific _evaluation times_ and, for each time, computes the ROC curve. To
#' account for censoring, inverse probability of censoring weights (IPCW) are
#' used in the calculations. See equation 7 of section 4.3 in Blanche _at al_
#' (2013) for the details.
#'
#' The column passed to `...` should be a list column with one element per
#' independent experiential unit (e.g. patient). The list column should contain
#' data frames with several columns:
#'
#'  - `.eval_time`: The time that the prediction is made.
#'  - `.pred_survival`: The predicted probability of survival up to `.eval_time`
#'  - `.weight_censored`: The case weight for the inverse probability of censoring.
#'
#' The last column can be produced using [parsnip::.censoring_weights_graf()].
#' This corresponds to the weighting scheme of  Graf _et al_ (1999). The
#' internal data set `lung_surv` shows an example of the format.
#'
#' This method automatically groups by the `.eval_time` argument.
#'
#' @return
#' A tibble with class `roc_survival_df`, `grouped_roc_survival_df` having
#' columns `.threshold`, `sensitivity`, `specificity`, and `.eval_time`.
#'
#' @seealso
#' Compute the area under the ROC survival curve with [roc_auc_survival()].
#'
#' @author Emil Hvitfeldt
#'
#' @references
#'
#' Blanche, P., Dartigues, J.-F. and Jacqmin-Gadda, H. (2013), Review and
#' comparison of ROC curve estimators for a time-dependent outcome with
#' marker-dependent censoring. _Biom. J._, 55: 687-704.
#'
#' Graf, E., Schmoor, C., Sauerbrei, W. and Schumacher, M. (1999), Assessment
#' and comparison of prognostic classification schemes for survival data.
#' _Statist. Med._, 18: 2529-2545.
#'
#' @examples
#' result <- roc_curve_survival(
#'   lung_surv,
#'   truth = surv_obj,
#'   .pred
#' )
#' result
#'
#' # ---------------------------------------------------------------------------
#' # `autoplot()`
#'
#' # Visualize the curve using ggplot2 manually
#' library(ggplot2)
#' library(dplyr)
#' result %>%
#'   mutate(.eval_time = format(.eval_time)) %>%
#'   ggplot(aes(
#'     x = 1 - specificity, y = sensitivity,
#'     group = .eval_time, col = .eval_time
#'   )) +
#'   geom_step(direction = "hv") +
#'   geom_abline(lty = 3) +
#'   coord_equal() +
#'   theme_bw()
#'
#' # Or use autoplot
#' autoplot(result)
#' @keywords internal
#' @export
roc_curve_survival <- function(data, ...) {
  UseMethod("roc_curve_survival")
}

#' @export
#' @rdname roc_curve_survival
roc_curve_survival.data.frame <- function(data,
                                          truth,
                                          ...,
                                          na_rm = TRUE,
                                          case_weights = NULL) {
  result <- curve_survival_metric_summarizer(
    name = "roc_curve_survival",
    fn = roc_curve_survival_vec,
    data = data,
    truth = !!enquo(truth),
    ...,
    na_rm = na_rm,
    case_weights = !!enquo(case_weights)
  )

  curve_finalize(result, data, "roc_survival_df", "grouped_roc_survival_df")
}

roc_curve_survival_vec <- function(truth,
                                   estimate,
                                   na_rm = TRUE,
                                   case_weights = NULL,
                                   ...) {
  check_dynamic_survival_metric(
    truth, estimate, case_weights
  )

  if (na_rm) {
    result <- yardstick_remove_missing(
      truth, seq_along(estimate), case_weights
    )

    truth <- result$truth
    estimate <- estimate[result$estimate]
    case_weights <- result$case_weights
  } else {
    any_missing <- yardstick_any_missing(
      truth, estimate, case_weights
    )
    if (any_missing) {
      return(NA_real_)
    }
  }

  roc_curve_survival_impl(truth = truth, estimate = estimate)
}

roc_curve_survival_impl <- function(truth,
                                    estimate) {
  event_time <- .extract_surv_time(truth)
  delta <- .extract_surv_status(truth)
  data <- dplyr::tibble(event_time, delta, estimate)
  data <- tidyr::unnest(data, cols = estimate)

  .eval_times <- unique(data$.eval_time)

  not_missing_pred_survival <- !is.na(data$.pred_survival)

  out <- list()
  for (i in seq_along(.eval_times)) {
    .eval_time_ind <- .eval_times[[i]] == data$.eval_time & not_missing_pred_survival

    res <- roc_curve_survival_impl_one(
      data$event_time[.eval_time_ind],
      data$delta[.eval_time_ind],
      data[.eval_time_ind, ]
    )

    res$.eval_time <- .eval_times[[i]]
    out[[i]] <- res
  }

  dplyr::bind_rows(out)
}

roc_curve_survival_impl_one <- function(event_time, delta, data) {
  res <- dplyr::tibble(.threshold = sort(unique(c(-Inf, data$.pred_survival, Inf))))

  obs_time_le_time <- event_time <= data$.eval_time
  obs_time_gt_time <- event_time > data$.eval_time
  n <- nrow(data)
  multiplier <- delta / (n * data$.weight_censored)

  sensitivity_denom <- sum(obs_time_le_time * multiplier, na.rm = TRUE)
  specificity_denom <- sum(obs_time_gt_time, na.rm = TRUE)

  data_df <- data.frame(
    le_time = obs_time_le_time,
    ge_time = obs_time_gt_time,
    multiplier = multiplier
  )
  data_split <- vec_split(data_df, data$.pred_survival)
  data_split <- data_split$val[order(data_split$key)]

  sensitivity <- vapply(
    data_split,
    function(x) sum(x$le_time * x$multiplier, na.rm = TRUE),
    FUN.VALUE = numeric(1)
  )
  sensitivity <- cumsum(sensitivity)
  sensitivity <- sensitivity / sensitivity_denom
  sensitivity <- dplyr::if_else(sensitivity > 1, 1, sensitivity)
  sensitivity <- dplyr::if_else(sensitivity < 0, 0, sensitivity)
  sensitivity <- c(0, sensitivity, 1)
  res$sensitivity <- sensitivity

  specificity <- vapply(
    data_split,
    function(x) sum(x$ge_time, na.rm = TRUE),
    FUN.VALUE = numeric(1)
  )
  specificity <- cumsum(specificity)
  specificity <- specificity / specificity_denom
  specificity <- dplyr::if_else(specificity > 1, 1, specificity)
  specificity <- dplyr::if_else(specificity < 0, 0, specificity)
  specificity <- c(0, specificity, 1)
  specificity <- 1 - specificity
  res$specificity <- specificity

  res
}

# Dynamically exported
autoplot.roc_survival_df <- function(object, ...) {
  `%+%` <- ggplot2::`%+%`
  object$.eval_time <- format(object$.eval_time)

  # Base chart
  roc_chart <- ggplot2::ggplot(data = object)

  # create aesthetic
  roc_aes <- ggplot2::aes(
    x = 1 - specificity,
    y = sensitivity,
    color = .eval_time,
    group = .eval_time
  )

  # build the graph
  roc_chart <- roc_chart %+%
    ggplot2::geom_step(mapping = roc_aes, direction = "hv") %+%
    ggplot2::geom_abline(lty = 3) %+%
    ggplot2::coord_equal() %+%
    ggplot2::theme_bw()

  roc_chart
}

Try the yardstick package in your browser

Any scripts or data that you put into this service are public.

yardstick documentation built on April 21, 2023, 9:08 a.m.