R/calc_km_pi.R

Defines functions summary.survparamsim.kmpi print.survparamsim.kmpi plot_km_pi calc_km_pi

Documented in calc_km_pi plot_km_pi print.survparamsim.kmpi summary.survparamsim.kmpi

#' Generate Kaplan-Meier curves with prediction intervals from parametric bootstrap simulation
#'
#' @export
#' @param sim A `survparamsim` class object generated by [surv_param_sim()] function.
#' @param trt An optional string to specify which column defines treatment status.
#' You will have survival curves with different colors in [plot_km_pi()] function.
#' @param group Optional string(s) to specify grouping variable(s).
#' You will have faceted survival curves for these variables in [plot_km_pi()] function.
#' @param pi.range Prediction interval for simulated survival curves, given as a
#' single number between 0 and 1 (e.g. 0.95 for a 95 percent interval).
#' @param calc.obs A logical to specify whether KM estimates will be performed
#' for the observed data. Needs to be set as FALSE if survival information in the `newdata` is dummy.
#' @param simtimelast An optional numeric to specify last simulation time for survival curve.
#' If NULL (default), the last observation time in the `newdata` will be used.
#' @param trt.assign Specify which of the categories of `trt` need to be considered as control group.
#' See details below if you have more than two categories. Only applicable if you will use
#' [extract_medsurv_delta_pi()] to extract delta of median survival times.
#'
#' @details
#' If your `trt` has more than two categories/levels and want to specify which one to use as a
#' reference group, you can convert the column into a factor in the `newdata` input for
#' [surv_param_sim()]. The first level will be used as a reference group.
#'
#' @return A `survparamsim.kmpi` object: a list holding the settings used
#' (`calc.obs`, `pi.range`, `group.syms`, `trt.syms`, `group.trt.syms`,
#' `trt.assign`, `simtimelast`, `t.last`, `censor.dur`), the observed K-M
#' results (`obs.km`, `obs.median.time`; both `NULL` when `calc.obs = FALSE`),
#' the simulated K-M results (`sim.km`, `sim.km.quantile`, `sim.median.time`),
#' and `median.pi`, the prediction intervals of median survival times.
#'
calc_km_pi <- function(sim, trt = NULL, group = NULL, pi.range = 0.95,
                       calc.obs = TRUE, simtimelast = NULL,
                       trt.assign = c("default", "reverse")) {

  # Use nest_legacy()/unnest_legacy() with tidyr == 1.0.0 for a speed issue
  # See https://github.com/tidyverse/tidyr/issues/751
  # A plain `if` is used for this scalar choice (rather than vectorised
  # `ifelse()`); the legacy functions are only looked up when selected.
  tidyr.version   <- utils::packageVersion("tidyr")
  use.legacy.nest <- tidyr.version == "1.0.0"
  nest2   <- if (use.legacy.nest) tidyr::nest_legacy   else tidyr::nest
  unnest2 <- if (use.legacy.nest) tidyr::unnest_legacy else tidyr::unnest

  trt.assign <- match.arg(trt.assign)

  # Fail early with a clear message instead of propagating NA/invalid
  # probabilities into stats::quantile() further down
  if (!is.numeric(pi.range) || length(pi.range) != 1 || is.na(pi.range) ||
      pi.range < 0 || pi.range > 1) {
    stop("`pi.range` must be a single number between 0 and 1")
  }
  # Probabilities bounding the central prediction interval
  pi.low  <- 0.5 - pi.range / 2
  pi.high <- 0.5 + pi.range / 2

  if (methods::is(sim, "survparamsim_pre_resampled")) {
    if (sim$newdata.orig.missing && calc.obs) {
      warning("Original observed data not provided in `surv_param_sim_pre_resampled()` and KM will not be estimated for the observed data. Specify `calc.obs = FALSE` to avoid this warning.")
      calc.obs <- FALSE
    }
  }

  if (length(trt) > 1) stop("`trt` can only take one string")

  # This needs to be kept as syms - rlang::sym() fails with trt=NULL
  trt.syms   <- rlang::syms(trt)
  group.syms <- rlang::syms(group)
  # This is needed to handle when the same variable is used for both `group` and `trt`
  group.trt.syms <- rlang::syms(unique(c(group, trt)))

  if (length(trt.syms) + length(group.syms) > length(group.trt.syms)) {
    warning(paste("Use of the same variable for `group` and `trt` is discouraged.",
                  "If you need a colored & faceted plot, please consider assigning",
                  "your variable to `trt`, and on the plot generated from `plot_km_pi()`,",
                  "apply `facet_wrap()` or `facet_grid()`"))
  }

  ## Time grid for output. When `simtimelast` extends beyond the last observed
  ## time, the number of points grows proportionally (100 per observed span).
  if (is.null(simtimelast)) {
    t.out <- seq(0, sim$t.last.orig.new, length.out = 100)
  } else {
    t.out <- seq(0, simtimelast,
                 length.out = round(100 * max(simtimelast / sim$t.last.orig.new, 1)))
  }


  # Evaluate a survfit step function on the common grid `t.out`,
  # prepending (time 0, surv 1) when the fit does not start at time 0
  approx_km <- function(x) {
    timetmp <- x$time
    survtmp <- x$surv
    if (min(timetmp) != 0) {
      timetmp <- c(0, timetmp)
      survtmp <- c(1, survtmp)
    }
    surv <- stats::approx(timetmp, survtmp, xout = t.out,
                          method = "constant", rule = 2)$y
    data.frame(time = t.out,
               surv = surv)
  }
  # Extract the observed K-M curve at its own event times,
  # prepending (time 0, surv 1, no censoring)
  extract_km_obs <- function(x) {
    data.frame(time = c(0, x$time),
               surv = c(1, x$surv),
               cnsr = c(0, x$n.censor))
  }


  if (calc.obs) {
    # Fit K-M curve to observed data
    obs.grouped <-
      sim$newdata.nona.obs %>%
      dplyr::group_by(!!!group.trt.syms)

    # Ungrouped data needs `data = everything()` with the tidyr >= 1.0.0
    # nest(). That form must NOT be used with nest_legacy() (tidyr 1.0.0),
    # whose first formal is itself named `data`; nest_legacy() nests all
    # columns of ungrouped data by default.
    if (length(dplyr::group_vars(obs.grouped)) == 0 &&
        !use.legacy.nest && tidyr.version >= "1.0.0") {
      obs.nested <-
        obs.grouped %>%
        nest2(data = dplyr::everything())
    } else {
      obs.nested <- nest2(obs.grouped)
    }


    ## K-M formula: the response (survival term) of the survreg model ~ 1.
    ## Named `km.formula` to avoid shadowing stats::formula().
    km.formula <-
      paste(attributes(stats::formula(sim$survreg))$variables, "~1")[2] %>%
      stats::as.formula()


    ## Calc median and KM curve
    obs.km.nested <-
      obs.nested %>%
      dplyr::mutate(kmfit = purrr::map(data, function(x) survival::survfit(km.formula, data = x))) %>%
      dplyr::mutate(median = purrr::map_dbl(kmfit, function(x) summary(x)$table["median"]),
                    n      = purrr::map_dbl(kmfit, function(x) summary(x)$table["records"]),
                    km     = purrr::map(kmfit, extract_km_obs))

    obs.km <-
      obs.km.nested %>%
      dplyr::select(-data, -kmfit) %>%
      unnest2(km) %>%
      dplyr::ungroup() %>%
      dplyr::filter(!is.na(surv)) %>%
      dplyr::select(-median)

    obs.median.time <-
      obs.km.nested %>%
      dplyr::select(!!!group.trt.syms, median, n)

  } else {
    obs.km <- NULL
    obs.median.time <- NULL
  }


  # Calculate percentiles for simulated data

  ## First nest data - KM fit will be done for each nested data

  newdata.group <-
    sim$newdata.nona.sim %>%
    dplyr::select(subj.sim, !!!group.trt.syms)

  sim.grouped <-
    sim$sim %>%
    dplyr::left_join(newdata.group, by = "subj.sim") %>%
    dplyr::group_by(rep, !!!group.trt.syms)

  sim.nested <- nest2(sim.grouped)

  sim.km <-
    sim.nested %>%
    # Fit each nested data to KM
    dplyr::mutate(kmfit =
                    purrr::map(data, function(x) survival::survfit(Surv(time, event)~1, data = x))) %>%
    # Calc median and KM curve
    dplyr::mutate(median = purrr::map_dbl(kmfit, function(x) summary(x)$table["median"]),
                  n      = purrr::map_dbl(kmfit, function(x) summary(x)$table["records"]),
                  km     = purrr::map(kmfit, approx_km)) %>%
    dplyr::arrange(rep, !!!group.trt.syms)


  ## Calc quantile for survival curves at each output time
  sim.km.quantile <-
    sim.km %>%
    dplyr::select(-data, -kmfit) %>%
    unnest2(km) %>%
    dplyr::group_by(!!!group.trt.syms, time) %>%
    nest2() %>%
    dplyr::mutate(quantiles = purrr::map(data, function(x)
      dplyr::summarize(x,
                       pi_low  = stats::quantile(surv, probs = pi.low),
                       pi_med  = stats::quantile(surv, probs = 0.5),
                       pi_high = stats::quantile(surv, probs = pi.high)))) %>%
    unnest2(quantiles) %>%
    dplyr::ungroup() %>%
    dplyr::select(-data)


  ## Calc quantiles for median survival time
  sim.median.time <-
    sim.km %>%
    dplyr::select(rep, !!!group.trt.syms, median, n) %>%
    dplyr::ungroup()

  quantiles <-
    tibble::tibble(description = c("pi_low", "pi_med", "pi_high"),
                   quantile = c(pi.low, 0.5, pi.high))

  # Simulations that never reached the median (NA) are excluded here;
  # a warning about them is issued below
  sim.median.pi <-
    sim.median.time %>%
    dplyr::group_by(!!!group.trt.syms) %>%
    dplyr::summarize(pi_low  = as.numeric(stats::quantile(median, probs = pi.low, na.rm = TRUE)),
                     pi_med  = as.numeric(stats::quantile(median, probs = 0.5, na.rm = TRUE)),
                     pi_high = as.numeric(stats::quantile(median, probs = pi.high, na.rm = TRUE)),
                     n_min = min(n),
                     n_max = max(n),
                     n     = min(n)) %>%
    dplyr::ungroup() %>%
    tidyr::gather(description, median, pi_low:pi_high) %>%
    dplyr::left_join(quantiles, by = "description")


  # Check NA in median time and give warning
  median.time.na.detail <-
    sim.median.time %>%
    dplyr::mutate(is.median.na = is.na(median)) %>%
    dplyr::group_by(!!!group.trt.syms) %>%
    dplyr::summarize(N.median.NA = sum(is.median.na),
                     N.all = dplyr::n()) %>%
    dplyr::ungroup()

  median.time.na.overall <-
    median.time.na.detail %>%
    dplyr::summarize(N.median.NA = sum(N.median.NA),
                     N.all = sum(N.all))

  if (median.time.na.overall$N.median.NA > 0) {
    warning(paste0(median.time.na.overall$N.median.NA, " of ", median.time.na.overall$N.all,
                   " simulations (#rep * #trt * #group) did not reach median survival time and",
                   " these are not included for prediction interval calculation. You may",
                   " want to delay the `censor.dur` in simulation."))
  }


  # Report N only when it is identical across replications within each group
  if (identical(sim.median.pi$n_min, sim.median.pi$n_max)) {
    sim.median.pi <-
      sim.median.pi %>%
      dplyr::select(-n_min, -n_max)
  } else {
    warning("N of subjects are not consistent across simulation replications, either from unstratified resampling or presence of NA in covariates.",
            " In case of former, consider stratified resampling e.g. by `strat.resample` in `surv_param_sim_resample()`")
    sim.median.pi <-
      sim.median.pi %>%
      dplyr::mutate(n = NA)
  }

  if (calc.obs) {
    obs.median <-
      obs.median.time %>%
      dplyr::mutate(description = "obs")

    median.pi <-
      dplyr::bind_rows(sim.median.pi, obs.median) %>%
      dplyr::arrange(!!!group.trt.syms)

  } else {
    median.pi <-
      sim.median.pi %>%
      dplyr::arrange(!!!group.trt.syms)
  }


  # Output
  out <- list()

  out$calc.obs <- calc.obs
  out$pi.range <- pi.range

  out$group.syms     <- group.syms
  out$trt.syms       <- trt.syms
  out$group.trt.syms <- group.trt.syms
  out$trt.assign     <- trt.assign

  out$simtimelast <- simtimelast
  out$t.last      <- sim$t.last.orig.new
  out$censor.dur  <- sim$censor.dur

  out$obs.km          <- obs.km
  out$obs.median.time <- obs.median.time

  out$sim.km          <- sim.km
  out$sim.km.quantile <- sim.km.quantile
  out$sim.median.time <- sim.median.time

  out$median.pi <- median.pi

  structure(out, class = c("survparamsim.kmpi"))
}


#' Plot Kaplan-Meier curves with prediction intervals from parametric bootstrap simulation
#'
#' Simulated prediction intervals are drawn as shaded ribbons, optionally
#' overlaid with the observed K-M curves (censored times marked with `|`).
#' Curves are colored by the `trt` variable and faceted by the `group`
#' variable(s) given to \code{\link{calc_km_pi}}.
#'
#' @export
#' @param km.pi An output from \code{\link{calc_km_pi}} function.
#' @param show.obs A logical specifying whether to show observed K-M curve on the plot.
#'   This will have no effect if `calc.obs` was set to `FALSE` in \code{\link{calc_km_pi}}.
#' @param trunc.sim.censor A logical specifying whether to truncate the simulated
#' curve at the last time of `censor.dur` specified in \code{\link{surv_param_sim}}.
#' @return A `ggplot` object.
#'
plot_km_pi <- function(km.pi, show.obs = TRUE, trunc.sim.censor = TRUE) {
  # TODO (carried over from the original docs): think about how this should
  # be applied for subgroups.

  obs.km <- km.pi$obs.km
  sim.km.quantile.plot <- extract_km_pi(km.pi, trunc.sim.censor = trunc.sim.censor)

  group.syms <- km.pi$group.syms
  trt.syms   <- km.pi$trt.syms
  has.trt    <- length(trt.syms) > 0


  # Plot
  ## Generate ggplot object with aes specified using simulated data.
  ## Color/fill aesthetics are set globally so observed and simulated
  ## layers share the same treatment mapping.
  if (has.trt) {
    g <-
      ggplot2::ggplot(sim.km.quantile.plot,
                      ggplot2::aes(time, color = factor(!!!trt.syms),
                                   fill = factor(!!!trt.syms)))
    color.lab <- as.character(trt.syms[[1]])
  } else {
    g <-
      ggplot2::ggplot(sim.km.quantile.plot,
                      ggplot2::aes(time))
  }

  ## Observed K-M curves, with censored observations marked by ticks.
  ## Scalar `&&` is used for the scalar condition.
  if (km.pi$calc.obs && show.obs) {
    g <-
      g +
      ggplot2::geom_step(data = obs.km,
                         ggplot2::aes(y = surv), size = 1) +
      ggplot2::geom_point(data = dplyr::filter(obs.km, cnsr > 0),
                          ggplot2::aes(y = surv), shape = "|", size = 3)
  }

  ## Simulated prediction intervals
  g <-
    g + ggplot2::geom_ribbon(ggplot2::aes(ymin = pi_low, ymax = pi_high),
                             alpha = 0.4)
  if (has.trt) {
    # Label the legend with the treatment column name instead of `factor(...)`
    g <- g + ggplot2::labs(color = color.lab, fill = color.lab)
  }


  # Facet fig based on group: a grid for exactly two variables, wrap otherwise
  if (length(group.syms) == 1 || length(group.syms) >= 3) {
    g <- g + ggplot2::facet_wrap(ggplot2::vars(!!!group.syms),
                                 labeller = ggplot2::label_both)
  } else if (length(group.syms) == 2) {
    g <- g + ggplot2::facet_grid(ggplot2::vars(!!group.syms[[1]]),
                                 ggplot2::vars(!!group.syms[[2]]),
                                 labeller = ggplot2::label_both)
  }

  g
}





#' @rdname survparamsim-methods
#' @export
print.survparamsim.kmpi <- function(x, ...){
  # Print the settings stored by calc_km_pi() together with pointers to the
  # extractor/plot helpers. Returns `x` invisibly, as print methods should.
  trt <- as.character(x$trt.syms)
  group <- as.character(x$group.syms)

  # `as.character()` of an empty symbol list gives character(0), never NULL,
  # so test the length. Multiple group variables are all shown, comma-separated
  # (scalar `ifelse()` would have kept only the first one).
  format_setting <- function(v) {
    if (length(v) == 0) "(NULL)" else paste(v, collapse = ", ")
  }

  cat("---- Simulated and observed (if calculated) survival curves ----\n")
  cat("* Use `extract_medsurv_pi()` to extract prediction intervals of median survival times\n")
  cat("* Use `extract_km_pi()` to extract prediction intervals of K-M curves\n")
  cat("* Use `plot_km_pi()` to draw survival curves\n\n")
  cat("* Settings:\n")
  cat("    trt:", format_setting(trt), "\n", sep=" ")
  cat("    group:", format_setting(group), "\n", sep=" ")
  cat("    pi.range:", x$pi.range, "\n", sep=" ")
  cat("    calc.obs:", x$calc.obs, "\n", sep=" ")

  invisible(x)
}


#' @rdname survparamsim-methods
#' @export
summary.survparamsim.kmpi <- function(object, ...) {
  # Summarising a K-M prediction-interval object means extracting the
  # prediction intervals of median survival times; extra arguments are ignored.
  extract_medsurv_pi(object)
}


# print.summary.survparamsim.kmpi <- function(x, ...){
#   cat("Predicted and observed (if calculated) median event time: \n")
#   print(tibble::as_tibble(x))
# }

Try the survParamSim package in your browser

Any scripts or data that you put into this service are public.

survParamSim documentation built on June 3, 2022, 9:06 a.m.