R/KaplanMeier.R

Defines functions plotKaplanMeier

Documented in plotKaplanMeier

# @file KaplanMeier.R
#
# Copyright 2021 Observational Health Data Sciences and Informatics
#
# This file is part of CohortMethod
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


#' Plot the Kaplan-Meier curve
#'
#' @description
#' \code{plotKaplanMeier} creates the Kaplan-Meier (KM) survival plot. Based (partially) on recommendations
#' in Pocock et al (2002).
#'
#' When variable-sized strata are detected, an adjusted KM plot is computed to account for stratified data,
#' as described in Galimberti et al (2002), using the closed form variance estimator described in Xie et al
#' (2005).
#'
#' @param population            A population object generated by \code{createStudyPopulation},
#'                              potentially filtered by other functions.
#' @param censorMarks           Whether or not to include censor marks in the plot.
#' @param confidenceIntervals   Plot 95 percent confidence intervals? Default is TRUE, as recommended by Pocock et al.
#' @param includeZero           Should the y axis include zero, or only go down to the lowest observed
#'                              survival? The default is FALSE, as recommended by Pocock et al.
#' @param dataTable             Should the numbers at risk be shown in a table? Default is TRUE, as recommended by Pocock et al.
#' @param dataCutoff            Fraction of the data (number censored) after which the graph will not
#'                              be shown. The default is 90 percent as recommended by Pocock et al.
#' @param targetLabel           A label to use for the target cohort.
#' @param comparatorLabel       A label to use for the comparator cohort.
#' @param title                 The main title of the plot.
#' @param fileName              Name of the file where the plot should be saved, for example
#'                              'plot.png'. See the function \code{ggsave} in the ggplot2 package for
#'                              supported file formats.
#'
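#' @return
#' A ggplot object when \code{dataTable = FALSE}. When \code{dataTable = TRUE}, the plot and the
#' numbers-at-risk table are combined with \code{gridExtra::grid.arrange} and that combined object is
#' returned instead. Use the \code{\link[ggplot2]{ggsave}} function to save to file in a different format.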
#'
#' @references
#' Pocock SJ, Clayton TC, Altman DG. (2002) Survival plots of time-to-event outcomes in clinical trials:
#' good practice and pitfalls, Lancet, 359:1686-89.
#'
#' Galimberti S, Sasieni P, Valsecchi MG (2002) A weighted Kaplan-Meier estimator for matched data with
#' application to the comparison of chemotherapy and bone-marrow transplant in leukaemia. Statistics in
#' Medicine, 21(24):3847-64.
#'
#' Xie J, Liu C. (2005) Adjusted Kaplan-Meier estimator and log-rank test with inverse probability of treatment
#' weighting for survival data. Statistics in Medicine, 26(10):2276.
#'
#' @export
plotKaplanMeier <- function(population,
                            censorMarks = FALSE,
                            confidenceIntervals = TRUE,
                            includeZero = FALSE,
                            dataTable = TRUE,
                            dataCutoff = 0.90,
                            targetLabel = "Treated",
                            comparatorLabel = "Comparator",
                            title,
                            fileName = NULL) {
  # Binary event indicator: 1 when at least one outcome was observed, 0 otherwise.
  population$y <- 0
  population$y[population$outcomeCount != 0] <- 1

  # No strata, or exactly two subjects per stratum on average (e.g. 1-on-1 matching):
  # a plain KM fit per treatment arm is used.
  if (is.null(population$stratumId) || length(unique(population$stratumId)) == nrow(population) / 2) {
    sv <- survival::survfit(survival::Surv(survivalTime, y) ~ treatment, population, conf.int = TRUE)
    data <- data.frame(time = sv$time,
                       n.censor = sv$n.censor,
                       s = sv$surv,
                       strata = summary(sv, censored = TRUE)$strata,
                       upper = sv$upper,
                       lower = sv$lower)
    levels(data$strata)[levels(data$strata) == "treatment=0"] <- comparatorLabel
    levels(data$strata)[levels(data$strata) == "treatment=1"] <- targetLabel
  } else {
    ParallelLogger::logInfo("Variable size strata detected so using adjusted KM for stratified data")
    population$stratumSizeT <- 1
    strataSizesT <- aggregate(stratumSizeT ~ stratumId, population[population$treatment == 1, ], sum)
    if (max(strataSizesT$stratumSizeT) == 1) {
      # Variable ratio matching (one target per stratum): use the propensity score to
      # compute IPTW-style comparator weights, ps / (1 - ps), per stratum.
      if (is.null(population$propensityScore)) {
        stop("Variable ratio matching detected, but no propensity score found")
      }
      weights <- aggregate(propensityScore ~ stratumId, population, mean)
      weights$weight <- weights$propensityScore / (1 - weights$propensityScore)
    } else {
      # Stratification: infer the treatment odds from the subject counts per stratum.
      strataSizesC <- aggregate(stratumSizeT ~ stratumId, population[population$treatment == 0, ], sum)
      colnames(strataSizesC)[2] <- "stratumSizeC"
      weights <- merge(strataSizesT, strataSizesC)
      weights$weight <- weights$stratumSizeT / weights$stratumSizeC
    }
    # merge() defaults to an inner join, so only strata present in `weights` are kept.
    # Target subjects always carry weight 1; only the comparator arm is reweighted.
    population <- merge(population, weights[, c("stratumId", "weight")])
    population$weight[population$treatment == 1] <- 1

    idx <- population$treatment == 1
    survTarget <- adjustedKm(weight = population$weight[idx],
                             time = population$survivalTime[idx],
                             y = population$y[idx])
    survTarget$strata <- targetLabel
    idx <- population$treatment == 0
    survComparator <- adjustedKm(weight = population$weight[idx],
                                 time = population$survivalTime[idx],
                                 y = population$y[idx])
    survComparator$strata <- comparatorLabel

    if (censorMarks) {
      # Number censored at each time point within ONE arm:
      # (subjects whose follow-up ends at that time) - (events at that time).
      # Both counts must come from the same arm.
      addCensorData <- function(surv, treatment) {
        armPopulation <- population[population$treatment == treatment, ]
        censorData <- aggregate(rowId ~ survivalTime, armPopulation, length)
        colnames(censorData) <- c("time", "censored")
        eventData <- aggregate(y ~ survivalTime, armPopulation, sum)
        colnames(eventData) <- c("time", "events")
        surv <- merge(surv, censorData)
        surv <- merge(surv, eventData, all.x = TRUE)
        surv$n.censor <- surv$censored - surv$events
        return(surv)
      }
      survTarget <- addCensorData(survTarget, 1)
      survComparator <- addCensorData(survComparator, 0)
    }

    data <- rbind(survTarget, survComparator)
    # 95% bounds on the log(-log S) scale, approximating se(log S) by sqrt(var) / s.
    # NOTE(review): assumes `var` from adjustedKm is the variance of S itself -- confirm.
    data$upper <- data$s^exp(qnorm(1 - 0.025) / log(data$s) * sqrt(data$var) / data$s)
    data$lower <- data$s^exp(qnorm(0.025) / log(data$s) * sqrt(data$var) / data$s)
    # log(s) -> 0 as s -> 1 makes the bound numerically unstable; pin it to s there.
    data$lower[data$s > 0.9999] <- data$s[data$s > 0.9999]
  }
  data$strata <- factor(data$strata, levels = c(targetLabel, comparatorLabel))

  # Truncate the x axis at the dataCutoff quantile of follow-up time.
  cutoff <- quantile(population$survivalTime, dataCutoff)
  xLabel <- "Time in days"
  yLabel <- "Survival probability"
  xlims <- c(-cutoff / 40, cutoff)

  if (cutoff <= 300) {
    xBreaks <- seq(0, cutoff, by = 50)
  } else if (cutoff <= 600) {
    xBreaks <- seq(0, cutoff, by = 100)
  } else {
    xBreaks <- seq(0, cutoff, by = 250)
  }

  data <- data[data$time <= cutoff, ]
  # Bounds can be NA (e.g. where s reaches 0), so ignore NAs when picking the lower limit.
  if (includeZero) {
    ylims <- c(0, 1)
  } else if (confidenceIntervals) {
    ylims <- c(min(data$lower, na.rm = TRUE), 1)
  } else {
    ylims <- c(min(data$s, na.rm = TRUE), 1)
  }
  plot <- ggplot2::ggplot(data, ggplot2::aes(x = .data$time,
                                             y = .data$s,
                                             color = .data$strata,
                                             fill = .data$strata,
                                             ymin = .data$lower,
                                             ymax = .data$upper))

  if (confidenceIntervals) {
    plot <- plot + ggplot2::geom_ribbon(color = rgb(0, 0, 0, alpha = 0))
  }

  plot <- plot +
    ggplot2::geom_step(size = 1) +
    ggplot2::scale_color_manual(values = c(rgb(0.8, 0, 0, alpha = 0.8),
                                           rgb(0, 0, 0.8, alpha = 0.8))) +
    ggplot2::scale_fill_manual(values = c(rgb(0.8, 0, 0, alpha = 0.3),
                                          rgb(0, 0, 0.8, alpha = 0.3))) +
    ggplot2::scale_x_continuous(xLabel, limits = xlims, breaks = xBreaks) +
    ggplot2::scale_y_continuous(yLabel, limits = ylims) +
    ggplot2::theme(legend.title = ggplot2::element_blank(),
                   legend.position = "top",
                   plot.title = ggplot2::element_text(hjust = 0.5))

  if (censorMarks) {
    # Plain logical indexing: the rlang `.data` pronoun only works inside a data mask,
    # and subset() evaluates its condition eagerly. NA rows are dropped, as subset() would.
    censoredPoints <- data[!is.na(data$n.censor) & data$n.censor >= 1, , drop = FALSE]
    plot <- plot + ggplot2::geom_point(data = censoredPoints,
                                       ggplot2::aes(x = .data$time, y = .data$s),
                                       shape = "|",
                                       size = 3)
  }
  if (!missing(title) && !is.null(title)) {
    plot <- plot + ggplot2::ggtitle(title)
  }
  if (dataTable) {
    # Numbers at risk at each x-axis break: subjects whose follow-up lasts at least that long.
    # integer(1) keeps the counts integer so formatC renders them as whole numbers.
    countAtRisk <- function(treatment) {
      vapply(xBreaks, function(xBreak) {
        sum(population$treatment == treatment & population$survivalTime >= xBreak)
      }, integer(1))
    }
    targetAtRisk <- countAtRisk(1)
    comparatorAtRisk <- countAtRisk(0)
    labels <- data.frame(x = c(0, xBreaks, xBreaks),
                         y = as.factor(c("Number at risk", rep(targetLabel, length(xBreaks)), rep(comparatorLabel, length(xBreaks)))),
                         label = c("", formatC(targetAtRisk, big.mark = ","), formatC(comparatorAtRisk, big.mark = ",")))
    labels$y <- factor(labels$y, levels = c(comparatorLabel, targetLabel, "Number at risk"))
    tablePlot <- ggplot2::ggplot(labels, ggplot2::aes(x = .data$x, y = .data$y, label = .data$label)) +
      ggplot2::geom_text(size = 3.5, vjust = 0.5) +
      ggplot2::scale_x_continuous(xLabel, limits = xlims, breaks = xBreaks) +
      ggplot2::theme(panel.grid.major = ggplot2::element_blank(),
                     panel.grid.minor = ggplot2::element_blank(),
                     legend.position = "none",
                     panel.border = ggplot2::element_blank(),
                     panel.background = ggplot2::element_blank(),
                     axis.text.x = ggplot2::element_text(color = "white"),
                     axis.title.x = ggplot2::element_text(color = "white"),
                     axis.title.y = ggplot2::element_blank(),
                     axis.ticks = ggplot2::element_line(color = "white"))
    # Give both grobs the same left-hand column widths so the table columns
    # line up with the x axis of the survival plot.
    grobs <- lapply(list(plot, tablePlot), ggplot2::ggplotGrob)
    widths <- lapply(grobs, function(grob) grob$widths[2:5])
    maxWidth <- do.call(grid::unit.pmax, widths)
    for (i in seq_along(grobs)) {
      grobs[[i]]$widths[2:5] <- as.list(maxWidth)
    }
    plot <- gridExtra::grid.arrange(grobs[[1]], grobs[[2]], heights = c(400, 100))
  }

  if (!is.null(fileName)) {
    ggplot2::ggsave(fileName, plot, width = 7, height = 5, dpi = 400)
  }
  return(plot)
}
escott12/CohortMethod documentation built on Dec. 20, 2021, 6:37 a.m.