Nothing
# @file KaplanMeier.R
#
# Copyright 2026 Observational Health Data Sciences and Informatics
#
# This file is part of CohortMethod
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#' Plot the Kaplan-Meier curve
#'
#' @description
#' \code{plotKaplanMeier} creates the Kaplan-Meier (KM) survival plot. Based (partially) on recommendations
#' in Pocock et al (2002).
#'
#' When variable-sized strata are detected, an adjusted KM plot is computed to account for stratified data,
#' as described in Galimberti et al (2002), using the closed form variance estimator described in Xie et al
#' (2005).
#'
#' @param population A population object generated by \code{createStudyPopulation},
#' potentially filtered by other functions.
#' @param censorMarks Whether or not to include censor marks in the plot.
#' @param confidenceIntervals Plot 95 percent confidence intervals? Default is TRUE, as recommended by Pocock et al.
#' @param includeZero Should the y axis include zero, or only go down to the lowest observed
#' survival? The default is FALSE, as recommended by Pocock et al.
#' @param dataTable Should the numbers at risk be shown in a table? Default is TRUE, as recommended by Pocock et al.
#' @param dataCutoff Fraction of the data after which the graph will not be shown: the curves are cut
#' at this quantile of the survival times. The default is 90 percent as recommended by Pocock et al.
#' @param targetLabel A label to use for the target cohort.
#' @param comparatorLabel A label to use for the comparator cohort.
#' @param title The main title of the plot.
#' @param fileName Name of the file where the plot should be saved, for example
#' 'plot.png'. See the function \code{ggsave} in the ggplot2 package for
#' supported file formats.
#'
#' @return
#' When \code{dataTable = FALSE}, a ggplot object. When \code{dataTable = TRUE}, the curves and the
#' numbers-at-risk table are combined with \code{gridExtra::grid.arrange()}, which draws the figure and
#' returns it as a grob (gtable). Either can be saved to file with the \code{\link[ggplot2]{ggsave}} function.
#' Returns \code{NULL} (with a warning) when \code{population} is empty.
#'
#' @references
#' Pocock SJ, Clayton TC, Altman DG. (2002) Survival plots of time-to-event outcomes in clinical trials:
#' good practice and pitfalls, Lancet, 359:1686-89.
#'
#' Galimberti S, Sasieni P, Valsecchi MG (2002) A weighted Kaplan-Meier estimator for matched data with
#' application to the comparison of chemotherapy and bone-marrow transplant in leukaemia. Statistics in
#' Medicine, 21(24):3847-64.
#'
#' Xie J, Liu C. (2005) Adjusted Kaplan-Meier estimator and log-rank test with inverse probability of treatment
#' weighting for survival data. Statistics in Medicine, 26(10):2276.
#'
#' @export
plotKaplanMeier <- function(population,
                            censorMarks = FALSE,
                            confidenceIntervals = TRUE,
                            includeZero = FALSE,
                            dataTable = TRUE,
                            dataCutoff = 0.90,
                            targetLabel = "Treated",
                            comparatorLabel = "Comparator",
                            title = NULL,
                            fileName = NULL) {
  errorMessages <- checkmate::makeAssertCollection()
  checkmate::assertDataFrame(population, add = errorMessages)
  checkmate::assertLogical(censorMarks, len = 1, add = errorMessages)
  checkmate::assertLogical(confidenceIntervals, len = 1, add = errorMessages)
  checkmate::assertLogical(includeZero, len = 1, add = errorMessages)
  checkmate::assertLogical(dataTable, len = 1, add = errorMessages)
  checkmate::assertNumber(dataCutoff, lower = 0, upper = 1, add = errorMessages)
  checkmate::assertCharacter(targetLabel, len = 1, add = errorMessages)
  checkmate::assertCharacter(comparatorLabel, len = 1, add = errorMessages)
  checkmate::assertCharacter(title, len = 1, null.ok = TRUE, add = errorMessages)
  checkmate::assertCharacter(fileName, len = 1, null.ok = TRUE, add = errorMessages)
  checkmate::reportAssertions(collection = errorMessages)
  if (nrow(population) == 0) {
    warning("Population is empty. Cannot plot KM curves.")
    return(NULL)
  }

  data <- prepareKaplanMeier(population, dataCutoff)
  vizData <- data |>
    mutate(treatment = factor(
      if_else(.data$treatment == 1, targetLabel, comparatorLabel),
      levels = c(targetLabel, comparatorLabel)
    ))

  # The time points that carry a number at risk also serve as the x-axis breaks.
  xBreaks <- unique(data$time[!is.na(data$nAtRisk)])
  cutoff <- max(data$time)
  xLabel <- "Time in days"
  yLabel <- "Survival probability"
  # Leave a small margin to the left of t = 0.
  xlims <- c(-cutoff / 40, cutoff)
  if (includeZero) {
    ylims <- c(0, 1)
  } else if (confidenceIntervals) {
    # Lower bounds can be NA (e.g. where survival reaches 0), so ignore those.
    ylims <- c(min(data$lower, na.rm = TRUE), 1)
  } else {
    # The curve column is `survival`: `data$surv` is NULL on a tibble (no partial
    # matching), which made min() return Inf and broke the y scale.
    ylims <- c(min(data$survival, na.rm = TRUE), 1)
  }

  plot <- ggplot2::ggplot(vizData, ggplot2::aes(
    x = .data$time,
    y = .data$survival,
    group = .data$treatment,
    color = .data$treatment,
    fill = .data$treatment,
    ymin = .data$lower,
    ymax = .data$upper
  ))
  if (confidenceIntervals) {
    # Fully transparent outline: only the filled band is visible.
    plot <- plot + ggplot2::geom_ribbon(color = rgb(0, 0, 0, alpha = 0))
  }
  plot <- plot +
    ggplot2::geom_step(linewidth = 1) +
    ggplot2::scale_color_manual(values = c(
      rgb(0.8, 0, 0, alpha = 0.8),
      rgb(0, 0, 0.8, alpha = 0.8)
    )) +
    ggplot2::scale_fill_manual(values = c(
      rgb(0.8, 0, 0, alpha = 0.3),
      rgb(0, 0, 0.8, alpha = 0.3)
    )) +
    ggplot2::scale_x_continuous(xLabel, limits = xlims, breaks = xBreaks) +
    ggplot2::scale_y_continuous(yLabel, limits = ylims) +
    ggplot2::theme(
      legend.title = ggplot2::element_blank(),
      legend.position = "top",
      plot.title = ggplot2::element_text(hjust = 0.5)
    )
  if (censorMarks) {
    plot <- plot + ggplot2::geom_point(
      data = vizData |> filter(.data$nCensor >= 1),
      ggplot2::aes(x = .data$time, y = .data$survival),
      shape = "|",
      size = 3
    )
  }
  if (!is.null(title)) {
    plot <- plot + ggplot2::ggtitle(title)
  }

  if (dataTable) {
    # One label per (arm, break time), plus an empty-labelled header row.
    labels <- bind_rows(
      tibble(
        x = 0,
        y = "Number at risk",
        label = ""
      ),
      data |>
        filter(!is.na(.data$nAtRisk)) |>
        transmute(
          x = .data$time,
          y = if_else(.data$treatment == 1, targetLabel, comparatorLabel),
          label = formatC(.data$nAtRisk, big.mark = ",")
        )
    ) |>
      mutate(y = factor(.data$y, levels = c(comparatorLabel, targetLabel, "Number at risk")))
    # Same x scale as the curves. The x-axis text, title and ticks are drawn in
    # white, i.e. hidden on the blank background.
    riskTable <- ggplot2::ggplot(labels, ggplot2::aes(x = .data$x, y = .data$y, label = .data$label)) +
      ggplot2::geom_text(size = 3.5, vjust = 0.5) +
      ggplot2::scale_x_continuous(xLabel, limits = xlims, breaks = xBreaks) +
      ggplot2::theme(
        panel.grid.major = ggplot2::element_blank(),
        panel.grid.minor = ggplot2::element_blank(),
        legend.position = "none",
        panel.border = ggplot2::element_blank(),
        panel.background = ggplot2::element_blank(),
        axis.text.x = ggplot2::element_text(color = "white"),
        axis.title.x = ggplot2::element_text(color = "white"),
        axis.title.y = ggplot2::element_blank(),
        axis.ticks = ggplot2::element_line(color = "white")
      )
    grobs <- lapply(list(plot, riskTable), ggplot2::ggplotGrob)
    # Give both grobs the widest of their layout columns 2:5 so the two panels
    # share the same horizontal extent.
    maxWidth <- do.call(grid::unit.pmax, lapply(grobs, function(grob) grob$widths[2:5]))
    for (i in seq_along(grobs)) {
      grobs[[i]]$widths[2:5] <- as.list(maxWidth)
    }
    # grid.arrange() draws the stacked figure and returns it as a gtable.
    plot <- gridExtra::grid.arrange(grobs[[1]], grobs[[2]], heights = c(400, 100))
  }
  if (!is.null(fileName)) {
    ggplot2::ggsave(fileName, plot, width = 7, height = 5, dpi = 400)
  }
  return(plot)
}
# Prepares the data behind plotKaplanMeier().
#
# Curves: an unadjusted KM fit (survival::survfit) when there is no stratumId
# column or when the number of strata is half the number of rows (1-on-1
# matching). Otherwise an adjusted KM is computed by adjustedKm() per arm. In
# that case target subjects get weight 1 and comparator subjects get weight
# (#target / #comparator) in their stratum, and CIs are built from the returned
# `s` and `var` columns.
#
# Table: numbers at risk per arm at regular time breaks (every 50, 100 or 250
# days) from 0 up to the `dataCutoff` quantile of survivalTime.
#
# Arguments:
#   population: data frame with treatment (1 = target, 0 = comparator),
#               survivalTime and outcomeCount columns, and optionally stratumId.
#   dataCutoff: quantile of survivalTime beyond which curve points are dropped.
#
# Returns a tibble with one row per (treatment, time) and at least the columns
# survival, lower, upper, nCensor and nAtRisk. nAtRisk is NA except at the
# table time breaks.
prepareKaplanMeier <- function(population, dataCutoff = 0.90) {
  # Prepare curve data -----------------------------------------------------------------------------
  population <- population |>
    mutate(y = if_else(.data$outcomeCount == 0, 0, 1))
  if (!"stratumId" %in% colnames(population) ||
      length(unique(population$stratumId)) == nrow(population) / 2) {
    message("No strata or 1-on-1 matching detected. Using unadjusted KM.")
    sv <- survival::survfit(survival::Surv(survivalTime, y) ~ treatment, population, conf.int = TRUE)
    if (all(population$treatment == 1) || all(population$treatment == 0)) {
      # The strata property disappears when there is only one stratum:
      treatment <- population$treatment[1]
    } else {
      # sv$strata holds the number of points of each curve, in the same order as
      # sv$time, with names such as "treatment=1".
      treatment <- rep(if_else(names(sv$strata) == "treatment=1", 1, 0), sv$strata)
    }
    data <- tibble(
      time = sv$time,
      nCensor = sv$n.censor,
      survival = sv$surv,
      treatment = treatment,
      upper = sv$upper,
      lower = sv$lower
    )
  } else {
    message("Variable size strata detected. Using adjusted KM for stratified data.")
    strataSizes <- population |>
      group_by(.data$stratumId) |>
      summarise(
        stratumSizeT = sum(.data$treatment == 1),
        stratumSizeC = sum(.data$treatment == 0)
      )
    weights <- strataSizes |>
      mutate(weight = .data$stratumSizeT / .data$stratumSizeC)
    population <- population |>
      inner_join(
        weights |>
          select("stratumId", "weight"),
        by = join_by("stratumId")
      ) |>
      mutate(weight = if_else(.data$treatment == 1, 1, .data$weight))
    idx <- population$treatment == 1
    survTarget <- adjustedKm(
      weight = population$weight[idx],
      time = population$survivalTime[idx],
      y = population$y[idx]
    )
    survTarget <- survTarget |>
      mutate(treatment = 1)
    idx <- population$treatment == 0
    survComparator <- adjustedKm(
      weight = population$weight[idx],
      time = population$survivalTime[idx],
      y = population$y[idx]
    )
    survComparator <- survComparator |>
      mutate(treatment = 0)
    # CI computed on the log(-log(S)) scale from the variance returned by adjustedKm().
    data <- bind_rows(survTarget, survComparator) |>
      mutate(
        upper = .data$s^exp(qnorm(1 - 0.025) / log(.data$s) * sqrt(.data$var) / .data$s),
        lower = .data$s^exp(qnorm(0.025) / log(.data$s) * sqrt(.data$var) / .data$s)
      ) |>
      mutate(lower = if_else(.data$s > 0.9999, .data$s, .data$lower)) |>
      select(-"var") |>
      rename(survival = "s")
    # Add censor data
    data <- data |>
      inner_join(
        population |>
          group_by(.data$treatment, .data$survivalTime) |>
          summarise(nCensor = n() - sum(.data$y), .groups = "drop") |>
          rename(time = "survivalTime"),
        by = join_by("treatment", "time")
      )
  }

  # Prepare table data -----------------------------------------------------------------------------
  cutoff <- quantile(population$survivalTime, dataCutoff)
  if (cutoff <= 300) {
    timeBreaks <- seq(0, cutoff, by = 50)
  } else if (cutoff <= 600) {
    timeBreaks <- seq(0, cutoff, by = 100)
  } else {
    timeBreaks <- seq(0, cutoff, by = 250)
  }
  dataTable <- tibble(
    time = rep(timeBreaks, 2),
    treatment = rep(c(1, 0), each = length(timeBreaks))
  )
  # Number at risk: subjects in the arm whose survival time is at least the break
  # time. Rows with NA treatment/survivalTime are not counted (as with filter()).
  subjectTreatment <- population$treatment
  subjectTime <- population$survivalTime
  dataTable$nAtRisk <- vapply(
    seq_len(nrow(dataTable)),
    function(i) {
      sum(subjectTreatment == dataTable$treatment[i] & subjectTime >= dataTable$time[i],
          na.rm = TRUE)
    },
    integer(1)
  )

  # Combine ----------------------------------------------------------------------------------------
  combined <- data |>
    filter(.data$time <= cutoff) |>
    full_join(dataTable, by = join_by("treatment", "time")) |>
    mutate(nCensor = if_else(is.na(.data$nCensor), 0, .data$nCensor))
  # Rows that only exist in the at-risk table have no curve values. Repeatedly take
  # the preceding value (by time, within arm, default 1) until no survival is
  # missing. Each pass fills one more row of a run of consecutive missing rows.
  while (sum(is.na(combined$survival)) > 0) {
    combined <- combined |>
      group_by(.data$treatment) |>
      mutate(
        survival = if_else(is.na(.data$survival),
                           lag(.data$survival, order_by = .data$time, default = 1),
                           .data$survival),
        lower = if_else(is.na(.data$lower),
                        lag(.data$lower, order_by = .data$time, default = 1),
                        .data$lower),
        upper = if_else(is.na(.data$upper),
                        lag(.data$upper, order_by = .data$time, default = 1),
                        .data$upper)
      ) |>
      ungroup()
  }
  return(combined)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.