R/KaplanMeier.R

Defines functions prepareKaplanMeier plotKaplanMeier

Documented in plotKaplanMeier

# @file KaplanMeier.R
#
# Copyright 2026 Observational Health Data Sciences and Informatics
#
# This file is part of CohortMethod
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#' Plot the Kaplan-Meier curve
#'
#' @description
#' \code{plotKaplanMeier} creates the Kaplan-Meier (KM) survival plot. Based (partially) on recommendations
#' in Pocock et al (2002).
#'
#' When variable-sized strata are detected, an adjusted KM plot is computed to account for stratified data,
#' as described in Galimberti et al (2002), using the closed form variance estimator described in Xie et al
#' (2005).
#'
#' @param population            A population object generated by \code{createStudyPopulation},
#'                              potentially filtered by other functions.
#' @param censorMarks           Whether or not to include censor marks in the plot.
#' @param confidenceIntervals   Plot 95 percent confidence intervals? Default is TRUE, as recommended by Pocock et al.
#' @param includeZero           Should the y axis include zero, or only go down to the lowest observed
#'                              survival? The default is FALSE, as recommended by Pocock et al.
#' @param dataTable             Should the numbers at risk be shown in a table? Default is TRUE, as recommended by Pocock et al.
#' @param dataCutoff            Fraction of the data (number censored) after which the graph will not
#'                              be shown. The default is 90 percent as recommended by Pocock et al.
#' @param targetLabel           A label to use for the target cohort.
#' @param comparatorLabel       A label to use for the comparator cohort.
#' @param title                 The main title of the plot.
#' @param fileName              Name of the file where the plot should be saved, for example
#'                              'plot.png'. See the function \code{ggsave} in the ggplot2 package for
#'                              supported file formats.
#'
#' @return
#' When \code{dataTable = FALSE}, a ggplot object. When \code{dataTable = TRUE}, the object returned by
#' \code{gridExtra::grid.arrange}, combining the plot with the numbers-at-risk table. Returns \code{NULL}
#' (with a warning) when \code{population} is empty. Use the \code{\link[ggplot2]{ggsave}} function to
#' save to file in a different format.
#'
#' @references
#' Pocock SJ, Clayton TC, Altman DG. (2002) Survival plots of time-to-event outcomes in clinical trials:
#' good practice and pitfalls, Lancet, 359:1686-89.
#'
#' Galimberti S, Sasieni P, Valsecchi MG (2002) A weighted Kaplan-Meier estimator for matched data with
#' application to the comparison of chemotherapy and bone-marrow transplant in leukaemia. Statistics in
#' Medicine, 21(24):3847-64.
#'
#' Xie J, Liu C. (2005) Adjusted Kaplan-Meier estimator and log-rank test with inverse probability of treatment
#' weighting for survival data. Statistics in Medicine, 26(10):2276.
#'
#' @export
plotKaplanMeier <- function(population,
                            censorMarks = FALSE,
                            confidenceIntervals = TRUE,
                            includeZero = FALSE,
                            dataTable = TRUE,
                            dataCutoff = 0.90,
                            targetLabel = "Treated",
                            comparatorLabel = "Comparator",
                            title = NULL,
                            fileName = NULL) {
  errorMessages <- checkmate::makeAssertCollection()
  checkmate::assertDataFrame(population, add = errorMessages)
  checkmate::assertLogical(censorMarks, len = 1, add = errorMessages)
  checkmate::assertLogical(confidenceIntervals, len = 1, add = errorMessages)
  checkmate::assertLogical(includeZero, len = 1, add = errorMessages)
  checkmate::assertLogical(dataTable, len = 1, add = errorMessages)
  checkmate::assertNumber(dataCutoff, lower = 0, upper = 1, add = errorMessages)
  checkmate::assertCharacter(targetLabel, len = 1, add = errorMessages)
  checkmate::assertCharacter(comparatorLabel, len = 1, add = errorMessages)
  checkmate::assertCharacter(title, len = 1, null.ok = TRUE, add = errorMessages)
  checkmate::assertCharacter(fileName, len = 1, null.ok = TRUE, add = errorMessages)
  checkmate::reportAssertions(collection = errorMessages)

  if (nrow(population) == 0) {
    warning("Population is empty. Cannot plot KM curves.")
    return(NULL)
  }

  data <- prepareKaplanMeier(population, dataCutoff)

  # Turn the 0/1 treatment indicator into a labelled factor, target first, so the
  # manual color/fill scales below map red to the target and blue to the comparator.
  vizData <- data |>
    mutate(treatment = factor(if_else(.data$treatment == 1, targetLabel, comparatorLabel),
                              levels = c(targetLabel, comparatorLabel)))

  # Only the rows that carry a number-at-risk correspond to the regular time breaks.
  xBreaks <- data |>
    filter(!is.na(.data$nAtRisk)) |>
    distinct(.data$time) |>
    pull()

  cutoff <- max(data$time)
  xLabel <- "Time in days"
  yLabel <- "Survival probability"
  # Small negative lower limit leaves room for the left-most table label.
  xlims <- c(-cutoff / 40, cutoff)

  if (includeZero) {
    ylims <- c(0, 1)
  } else if (confidenceIntervals) {
    ylims <- c(min(data$lower), 1)
  } else {
    # The column is named `survival`; `data$surv` does not exist on a tibble
    # (no partial matching), which made min() return Inf.
    ylims <- c(min(data$survival), 1)
  }

  plot <- ggplot2::ggplot(vizData, ggplot2::aes(
    x = .data$time,
    y = .data$survival,
    group = .data$treatment,
    color = .data$treatment,
    fill = .data$treatment,
    ymin = .data$lower,
    ymax = .data$upper
  ))

  if (confidenceIntervals) {
    # Fully transparent outline: only the filled band is visible.
    plot <- plot + ggplot2::geom_ribbon(color = rgb(0, 0, 0, alpha = 0))
  }

  plot <- plot +
    ggplot2::geom_step(linewidth = 1) +
    ggplot2::scale_color_manual(values = c(
      rgb(0.8, 0, 0, alpha = 0.8),
      rgb(0, 0, 0.8, alpha = 0.8)
    )) +
    ggplot2::scale_fill_manual(values = c(
      rgb(0.8, 0, 0, alpha = 0.3),
      rgb(0, 0, 0.8, alpha = 0.3)
    )) +
    ggplot2::scale_x_continuous(xLabel, limits = xlims, breaks = xBreaks) +
    ggplot2::scale_y_continuous(yLabel, limits = ylims) +
    ggplot2::theme(
      legend.title = ggplot2::element_blank(),
      legend.position = "top",
      plot.title = ggplot2::element_text(hjust = 0.5)
    )

  if (censorMarks) {
    plot <- plot + ggplot2::geom_point(
      data = vizData |> filter(.data$nCensor >= 1),
      ggplot2::aes(x = .data$time, y = .data$survival),
      shape = "|",
      size = 3
    )
  }
  if (!is.null(title)) {
    plot <- plot + ggplot2::ggtitle(title)
  }
  if (dataTable) {
    # One text label per (time break, treatment), plus an empty-label row that
    # only serves to create the "Number at risk" heading on the y axis.
    labels <- bind_rows(
      tibble(
        x = 0,
        y = "Number at risk",
        label = ""
      ),
      data |>
        filter(!is.na(.data$nAtRisk)) |>
        transmute(x = .data$time,
                  y = if_else(.data$treatment == 1, targetLabel, comparatorLabel),
                  label = formatC(.data$nAtRisk, big.mark = ","))
    ) |>
      mutate(y = factor(.data$y, levels = c(comparatorLabel, targetLabel, "Number at risk")))

    # Same x scale as the main plot, with axis elements drawn in white so the
    # table columns line up with the time breaks without showing a second axis.
    tablePlot <- ggplot2::ggplot(labels, ggplot2::aes(x = .data$x, y = .data$y, label = .data$label)) +
      ggplot2::geom_text(size = 3.5, vjust = 0.5) +
      ggplot2::scale_x_continuous(xLabel, limits = xlims, breaks = xBreaks) +
      ggplot2::theme(
        panel.grid.major = ggplot2::element_blank(),
        panel.grid.minor = ggplot2::element_blank(),
        legend.position = "none",
        panel.border = ggplot2::element_blank(),
        panel.background = ggplot2::element_blank(),
        axis.text.x = ggplot2::element_text(color = "white"),
        axis.title.x = ggplot2::element_text(color = "white"),
        axis.title.y = ggplot2::element_blank(),
        axis.ticks = ggplot2::element_line(color = "white")
      )

    # Give both grobs the same widths for layout columns 2:5 so the panels align.
    grobs <- lapply(list(plot, tablePlot), ggplot2::ggplotGrob)
    widths <- lapply(grobs, function(grob) grob$widths[2:5])
    maxWidth <- do.call(grid::unit.pmax, widths)
    for (i in seq_along(grobs)) {
      grobs[[i]]$widths[2:5] <- as.list(maxWidth)
    }
    plot <- gridExtra::grid.arrange(grobs[[1]], grobs[[2]], heights = c(400, 100))
  }

  if (!is.null(fileName)) {
    ggplot2::ggsave(fileName, plot, width = 7, height = 5, dpi = 400)
  }
  return(plot)
}

# Prepare the data behind the Kaplan-Meier plot.
#
# Builds, per treatment group, the survival curve with its confidence bounds and
# the number of censored subjects per time point, and joins this with the
# numbers at risk at regular time breaks.
#
# Arguments:
#   population   Data frame with (at least) the columns treatment, survivalTime
#                and outcomeCount; stratumId is used when present.
#   dataCutoff   Quantile of survivalTime used as the time cutoff: curve points
#                beyond it are dropped and time breaks stop at it.
#
# Returns a data frame with the columns time, treatment, survival, lower, upper,
# nCensor and nAtRisk. nAtRisk is NA on rows that are not a time break.
prepareKaplanMeier <- function(population, dataCutoff = 0.90) {
  # Prepare curve data -----------------------------------------------------------------------------
  # Collapse the outcome count to a 0/1 event indicator.
  population <- population |>
    mutate(y = if_else(.data$outcomeCount == 0, 0, 1))
  if (!"stratumId" %in% colnames(population) || length(unique(population$stratumId)) == nrow(population) / 2) {
    message("No strata or 1-on-1 matching detected. Using unadjusted KM.")
    sv <- survival::survfit(survival::Surv(survivalTime, y) ~ treatment, population, conf.int = TRUE)
    if (all(population$treatment == 1) || all(population$treatment == 0)) {
      # The strata property disappears when there is only one stratum:
      treatment <- population$treatment[1]
    } else {
      treatment <- if_else(summary(sv, censored = TRUE)$strata == "treatment=1", 1, 0)
    }
    data <- tibble(
      time = sv$time,
      nCensor = sv$n.censor,
      survival = sv$surv,
      treatment = treatment,
      upper = sv$upper,
      lower = sv$lower
    )
  } else {
    message("Variable size strata detected. Using adjusted KM for stratified data.")
    strataSizes <- population |>
      group_by(.data$stratumId) |>
      summarise(stratumSizeT = sum(.data$treatment == 1),
                stratumSizeC = sum(.data$treatment == 0))
    weights <- strataSizes |>
      mutate(weight = .data$stratumSizeT / .data$stratumSizeC)
    # Target subjects get weight 1; comparators get the target/comparator size
    # ratio of their stratum.
    population <- population |>
      inner_join(weights |>
                   select("stratumId", "weight"),
                 by = join_by("stratumId")) |>
      mutate(weight = if_else(.data$treatment == 1, 1, .data$weight))
    # NOTE(review): adjustedKm is defined elsewhere; from its use below it is
    # expected to return the columns time, s (survival) and var (variance).
    idx <- population$treatment == 1
    survTarget <- adjustedKm(
      weight = population$weight[idx],
      time = population$survivalTime[idx],
      y = population$y[idx]
    )
    survTarget <- survTarget |>
      mutate(treatment = 1)
    idx <- population$treatment == 0
    survComparator <- adjustedKm(
      weight = population$weight[idx],
      time = population$survivalTime[idx],
      y = population$y[idx]
    )
    survComparator <- survComparator |>
      mutate(treatment = 0)
    # 95% bounds computed from s and var; the lower bound is set to s itself
    # when s is (almost) 1, where log(s) is (almost) 0.
    data <- bind_rows(survTarget, survComparator) |>
      mutate(upper = .data$s^exp(qnorm(1 - 0.025) / log(.data$s) * sqrt(.data$var) / .data$s),
             lower = .data$s^exp(qnorm(0.025) / log(.data$s) * sqrt(.data$var) / .data$s)) |>
      mutate(lower = if_else(.data$s > 0.9999, .data$s, .data$lower)) |>
      select(-"var") |>
      rename(survival = "s")

    # Add censor data
    data <- data |>
      inner_join(
        population |>
          group_by(.data$treatment, .data$survivalTime) |>
          summarise(nCensor = n() - sum(.data$y), .groups = "drop") |>
          rename(time = "survivalTime"),
        by = join_by("treatment", "time")
      )
  }

  # Prepare table data -----------------------------------------------------------------------------
  # names = FALSE: keep the cutoff a plain number instead of carrying a "90%" name.
  cutoff <- quantile(population$survivalTime, dataCutoff, names = FALSE)
  if (cutoff <= 300) {
    timeBreaks <- seq(0, cutoff, by = 50)
  } else if (cutoff <= 600) {
    timeBreaks <- seq(0, cutoff, by = 100)
  } else {
    timeBreaks <- seq(0, cutoff, by = 250)
  }
  dataTable <- tibble(
    time = rep(timeBreaks, 2),
    treatment = rep(c(1, 0), each = length(timeBreaks))
  )
  # Number at risk: subjects of that treatment whose survival time reaches the break.
  dataTable$nAtRisk <- vapply(
    seq_len(nrow(dataTable)),
    function(i) {
      sum(
        population$treatment == dataTable$treatment[i] &
          population$survivalTime >= dataTable$time[i],
        na.rm = TRUE
      )
    },
    integer(1)
  )

  # Combine ----------------------------------------------------------------------------------------
  combined <- data |>
    filter(.data$time <= cutoff) |>
    full_join(dataTable, by = join_by("treatment", "time")) |>
    mutate(nCensor = if_else(is.na(.data$nCensor), 0, .data$nCensor))

  # Time breaks without a matching curve point have no survival estimate yet.
  # Carry the previous value (in time order, per treatment) forward, starting
  # from 1; repeat because each pass fills only the first NA of a run.
  while (sum(is.na(combined$survival)) > 0) {
    combined <- combined |>
      group_by(.data$treatment) |>
      mutate(survival = if_else(is.na(.data$survival),
                                lag(.data$survival, order_by = .data$time, default = 1),
                                .data$survival),
             lower = if_else(is.na(.data$lower),
                             lag(.data$lower, order_by = .data$time, default = 1),
                             .data$lower),
             upper = if_else(is.na(.data$upper),
                             lag(.data$upper, order_by = .data$time, default = 1),
                             .data$upper)) |>
      ungroup()
  }

  return(combined)
}

Try the CohortMethod package in your browser

Any scripts or data that you put into this service are public.

CohortMethod documentation built on March 21, 2026, 5:06 p.m.