R/estimateSurvival.R


#' Kaplan–Meier survival by simple strata (gender and/or age groups)
#'
#' @description
#' Compute Kaplan–Meier (KM) survival curves overall and, optionally, by simple
#' strata derived from an existing `gender` column and/or an `age_group` built
#' from list-of-range breaks using `age_years`. Results are returned per stratum
#' plus an `"overall"` entry. Additionally, log-rank tests (overall and
#' pairwise) are computed when strata are specified.
#'
#' @details
#' - Input follow-up `time` is supplied in **days** and internally rescaled to
#'   the requested `timeScale` for reporting (`"days"`, `"weeks"`, `"months"`,
#'   `"years"`).
#' - If `"age_group"` is included in `strata`, you must provide an `age_years`
#'   column. Age-group labels are generated from `ageBreaks` (e.g., `0-18`,
#'   `19-45`, `46-65`, `66+`), where each element is a numeric range
#'   `c(min_age, max_age)`. Use `Inf` for open-ended upper bounds.
#' - Stratification is **simple**: groups are created from observed levels of
#'   `gender` and/or derived `age_group`. If both are requested, they are
#'   handled **separately** (gender OR age_group), not jointly.
#'   A one-sample KM curve is fit for each non-empty group, plus an `"overall"`
#'   curve for the full data.
#' - Confidence intervals are controlled by `confType` and `confInt` and are
#'   passed to `survival::survfit()`.
#' - The `model` argument controls which survival estimator is fitted:
#'   - `"km"`: non-parametric Kaplan–Meier estimate via [survival::survfit()].
#'   - `"cox"`: Cox PH model via [survival::coxph()] + [survival::survfit()].
#'     Without covariates the Breslow baseline hazard is used.  When covariates
#'     are provided, the survival curve is evaluated at the covariate means.
#'   - `"weibull"`, `"exponential"`, `"lognormal"`, `"loglogistic"`: AFT
#'     parametric models via [survival::survreg()].  S(t) is evaluated
#'     analytically at observed event times.  Pointwise CIs are not available
#'     for parametric models (`lower`/`upper` are `NA`).
#' - `covariates` is used only for Cox and parametric models.
#' - `probs` controls quantile extraction and defaults to
#'   `c(0.75, 0.5, 0.25)` (q75, median, q25). `times` is reserved for future
#'   use and is currently ignored.
#' - When strata are specified, a **log-rank test** is performed to compare
#'   survival curves across groups within each stratifier (gender and/or
#'   age_group). The overall test and pairwise tests are included in the
#'   returned object as data frames.
#'
#' @section Returned object:
#' A list of class `singleEventSurvival`. Elements include:
#' - Per-stratum entries named like `"gender=F"`, `"gender=M"`,
#'   `"age_group=18-44"`, etc., and an `"overall"` element.
#'
#' Each stratum element contains:
#' - `data`: a data frame with KM step data:
#'   `time`, `n_risk`, `n_event`, `n_censor`, `survival`, `std_err`,
#'   `lower`, `upper` (`NA` when no interval is available), and derived
#'   `hazard`, `cum_hazard`, `cum_event`, `cum_censor`.
#' - `summary`: a list with `n`, `events`, `censored`, `medianSurvival`,
#'   `medianSurvivalLowerCI`, `medianSurvivalUpperCI`, `q25Survival`,
#'   `q75Survival`, `meanSurvival`, and `timeScale`.
#'
#' Additionally, if `gender` is in `strata`, a `logrank_test_gender` element
#' is included; if `age_group` is in `strata`, a `logrank_test_age_group`
#' element is included. Each contains:
#' - `testType`: `"overall"` or `"pairwise"`
#' - `stratum1`, `stratum2`: labels of compared strata
#' - `chisq`: chi-square test statistic
#' - `df`: degrees of freedom
#' - `pvalue`: p-value for the test
#'
#' @param survivalData A `data.frame` with required columns:
#'   - `subject_id` (unique id)
#'   - `time` (numeric follow-up in **days**; finite)
#'   - `status` (0/1; 1 = event)
#'   Optional columns for stratification/age grouping: `gender`, `age_years`.
#'   Additional columns may be present but are currently unused.
#' @param timeScale One of `"days"`, `"weeks"`, `"months"`, or `"years"`. Used
#'   only to scale the reported time axis; input `time` is assumed to be days.
#' @param model Survival estimator to fit. One of `"km"` (Kaplan–Meier),
#'   `"cox"` (Cox PH, baseline hazard), `"weibull"`, `"exponential"`,
#'   `"lognormal"`, `"loglogistic"` (AFT parametric models via
#'   [survival::survreg()]).
#' @param covariates Optional character vector of covariate column names used
#'   in Cox and parametric models. Ignored for `model = "km"`.
#' @param strata Optional character vector of stratifying variables. Allowed:
#'   `"gender"`, `"age_group"`. If both are supplied, they are applied
#'   independently (gender OR age_group).
#' @param ageBreaks A list of numeric length-2 vectors defining age ranges for
#'   auto-stratification, e.g.
#'   `list(c(0, 18), c(19, 45), c(46, 65), c(66, Inf))` -> `0-18`, `19-45`,
#'   `46-65`, `66+`. Used only if `"age_group"` is in `strata`.
#' @param times Reserved for future enhancements; currently unused.
#' @param probs Numeric vector of probabilities used to extract quantiles from
#'   KM curves. Default is `c(0.75, 0.5, 0.25)`.
#' @param confInt Numeric confidence level for KM intervals (e.g., `0.95`);
#'   passed as `conf.int` to `survival::survfit()`.
#' @param confType Character string for KM CI type, one of `"log"`, `"log-log"`,
#'   `"plain"`, `"arcsin"`, `"none"`; passed as `conf.type` to
#'   `survival::survfit()`.
#'
#' @return A list of class `singleEventSurvival`. See **Returned object**.
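#'
#' @examples
#' # A minimal sketch on made-up data. The `toy` data frame is hypothetical and
#' # only illustrates the documented input columns; it is not package data.
#' toy <- data.frame(
#'   subject_id = 1:8,
#'   time       = c(30, 60, 90, 120, 150, 200, 250, 400),  # follow-up in days
#'   status     = c(1, 0, 1, 1, 0, 1, 1, 0),               # 1 = event
#'   gender     = rep(c("F", "M"), times = 4)
#' )
#' fit <- singleEventSurvival(toy, timeScale = "months", strata = "gender")
#' names(fit)              # per-stratum entries, "overall", and the log-rank element
#' fit$overall$summary$n   # number of subjects in the overall fit
#' head(fit$overall$data)  # step data on the requested time scale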
#'
#' @export
singleEventSurvival <- function(
    survivalData,
    timeScale = "days",
    model = "km",
    covariates = NULL,
    strata = NULL,
    ageBreaks = list(c(0, 18), c(19, 45), c(46, 65), c(66, Inf)),
    times = NULL,
    probs = c(0.75, 0.5, 0.25),
    confInt = 0.95,
    confType = "log") {

  # ---- validation ---------------------------------------------------------
  if (!is.data.frame(survivalData)) stop("`survivalData` must be a data.frame.")
  required <- c("subject_id", "time", "status")
  if (!all(required %in% names(survivalData))) {
    stop("`survivalData` must contain: subject_id, time, status.")
  }
  if (!is.numeric(survivalData$time) || any(!is.finite(survivalData$time))) {
    stop("`time` must be finite numerics (in days).")
  }
  if (!all(survivalData$status %in% c(0, 1))) {
    stop("`status` must be 0 (censored) or 1 (event).")
  }

  timeScale <- match.arg(timeScale, c("days", "weeks", "months", "years"))
  model <- match.arg(model, c("km", "cox", "weibull", "exponential", "lognormal", "loglogistic"))
  confType <- match.arg(confType, c("log", "log-log", "plain", "arcsin", "none"))

  # ---- validate ageBreaks if age_group in strata --------------------------
  if (!is.null(strata) && "age_group" %in% strata) {
    if (!is.list(ageBreaks) || length(ageBreaks) == 0) {
      stop("`ageBreaks` must be a non-empty list of numeric vectors when using age_group stratification.")
    }
    for (i in seq_along(ageBreaks)) {
      range_i <- ageBreaks[[i]]
      if (!is.numeric(range_i) || length(range_i) != 2) {
        stop(sprintf("`ageBreaks[[%d]]` must be a numeric vector of length 2 (c(min, max)).", i))
      }
      if (!is.finite(range_i[1])) {
        stop(sprintf("`ageBreaks[[%d]][1]` (min age) must be finite.", i))
      }
      if (range_i[1] > range_i[2] && is.finite(range_i[2])) {
        stop(sprintf("`ageBreaks[[%d]]`: min age (%g) must be <= max age (%g).", i, range_i[1], range_i[2]))
      }
    }
  }

  # ---- scale & clean ------------------------------------------------------
  scaleFactor <- c(days = 1, weeks = 7, months = 30.4375, years = 365.25)[[timeScale]]
  covars <- if (is.null(covariates)) character() else as.character(covariates)

  df             <- survivalData
  df$time_scaled <- df$time / scaleFactor
  df$status      <- as.integer(df$status)

  # ---- derive age_group if needed -----------------------------------------
  if (!is.null(strata) && "age_group" %in% strata && !is.null(ageBreaks)) {
    if (!"age_years" %in% names(df)) {
      stop("`age_years` column required when using age_group stratification.")
    }
    assign_age_group <- function(age, breaks_list) {
      for (range_i in breaks_list) {
        min_age  <- range_i[1]
        max_age  <- range_i[2]
        # Missing ages match no range (and would otherwise make `if ()` fail)
        in_range <- !is.na(age) && age >= min_age && (is.infinite(max_age) || age <= max_age)
        if (in_range) {
          return(if (is.infinite(max_age)) paste0(min_age, "+") else paste0(min_age, "-", max_age))
        }
      }
      NA_character_
    }
    df$age_group <- vapply(df$age_years, function(a) assign_age_group(a, ageBreaks), character(1))
  }

  # Normalize strata: we only support gender and age_group, each independently
  strata <- intersect(if (is.null(strata)) character() else strata, c("gender", "age_group"))
  if (length(strata) == 0L) strata <- NULL

  if (model %in% c("km", "cox")) {

    # ---- KM / Cox per-stratum (gender and/or age_group handled separately) -
    survFitModels <- list()

    if (!is.null(strata)) {

      # Gender strata
      if ("gender" %in% strata && "gender" %in% names(df)) {
        gender_levels <- sort(stats::na.omit(unique(df$gender)))
        for (g in gender_levels) {
          .df <- df[!is.na(df$gender) & df$gender == g, ]
          if (nrow(.df) > 0L) {
            survFitModels[[paste0("gender=", g)]] <-
              .fit_survfit_for_group(.df, model, covars, confType, confInt)
          }
        }
      }

      # Age-group strata
      if ("age_group" %in% strata && "age_group" %in% names(df)) {
        age_levels <- sort(stats::na.omit(unique(df$age_group)))
        for (a in age_levels) {
          .df <- df[!is.na(df$age_group) & df$age_group == a, ]
          if (nrow(.df) > 0L) {
            survFitModels[[paste0("age_group=", a)]] <-
              .fit_survfit_for_group(.df, model, covars, confType, confInt)
          }
        }
      }
    }

    # Overall fit (no stratification)
    survFitModels[["overall"]] <- .fit_survfit_for_group(df, model, covars, confType, confInt)

    # ---- summarize KM / Cox fits ------------------------------------------
    kmSum               <- lapply(survFitModels, summary)
    confidenceIntervals <- lapply(kmSum, function(x) x$table)
    kmSumTable          <- lapply(survFitModels, stats::quantile, probs = probs)

    survivalList <- stats::setNames(
      lapply(seq_along(survFitModels), function(i) {
        km_obj <- kmSum[[i]]
        q_obj  <- kmSumTable[[i]]
        ci_obj <- confidenceIntervals[[i]]

        total_n      <- if (length(km_obj$n) > 0L) as.integer(km_obj$n[1]) else 0L
        total_events <- sum(km_obj$n.event, na.rm = TRUE)

        # Handle all-censored or empty data (no event times)
        if (is.null(km_obj$time) || length(km_obj$time) == 0L) {
          return(list(
            data = data.frame(
              time = numeric(0), n_risk = integer(0), n_event = integer(0),
              n_censor = integer(0), lower = numeric(0), upper = numeric(0),
              survival = numeric(0), std_err = numeric(0),
              cum_event = integer(0), cum_censor = integer(0),
              hazard = numeric(0), cum_hazard = numeric(0),
              stringsAsFactors = FALSE
            ),
            summary = list(
              n = total_n, events = 0L, censored = total_n,
              medianSurvival = NA_real_, medianSurvivalLowerCI = NA_real_,
              medianSurvivalUpperCI = NA_real_, q25Survival = NA_real_,
              q75Survival = NA_real_, meanSurvival = NA_real_,
              timeScale = timeScale
            )
          ))
        }

        q_quantile <- if (is.list(q_obj) && !is.null(q_obj$quantile)) {
          q_obj$quantile
        } else {
          q_obj
        }
        median_surv <- if (length(q_quantile) >= 2L) q_quantile[[2]] else NA_real_
        q75_surv    <- if (length(q_quantile) >= 1L) q_quantile[[1]] else NA_real_
        q25_surv    <- if (length(q_quantile) >= 3L) q_quantile[[3]] else NA_real_

        lower_ci <- if (length(ci_obj) >= 8L) ci_obj[[8]] else NA_real_
        upper_ci <- if (length(ci_obj) >= 9L) ci_obj[[9]] else NA_real_

        mean_surv <- if (any(!is.na(km_obj$surv))) {
          sum(diff(c(0, km_obj$time)) * utils::head(c(1, km_obj$surv), -1), na.rm = TRUE)
        } else {
          NA_real_
        }

        n_rows <- length(km_obj$time)
        # summary.survfit() names this component `n.censor`
        n_censor_col <- if (!is.null(km_obj$n.censor) && length(km_obj$n.censor) == n_rows) {
          km_obj$n.censor
        } else {
          rep(NA_integer_, n_rows)
        }
        lower_col <- if (!is.null(km_obj$lower) && length(km_obj$lower) == n_rows) {
          km_obj$lower
        } else {
          rep(NA_real_, n_rows)
        }
        upper_col <- if (!is.null(km_obj$upper) && length(km_obj$upper) == n_rows) {
          km_obj$upper
        } else {
          rep(NA_real_, n_rows)
        }
        std_err_col <- if (!is.null(km_obj$std.err) && length(km_obj$std.err) == n_rows) {
          km_obj$std.err
        } else {
          rep(NA_real_, n_rows)
        }

        dat <- data.frame(
          time     = km_obj$time,
          n_risk   = km_obj$n.risk,
          n_event  = km_obj$n.event,
          n_censor = n_censor_col,
          lower    = lower_col,
          upper    = upper_col,
          survival = km_obj$surv,
          std_err  = std_err_col,
          stringsAsFactors = FALSE
        )
        dat <- dat[order(dat$time), ]

        n_event0  <- ifelse(is.na(dat$n_event),  0L, dat$n_event)
        n_censor0 <- ifelse(is.na(dat$n_censor), 0L, dat$n_censor)
        dat$cum_event  <- cumsum(n_event0)
        dat$cum_censor <- cumsum(n_censor0)
        dat$hazard     <- ifelse(dat$n_risk > 0L, n_event0 / dat$n_risk, NA_real_)
        dat$cum_hazard <- -log(pmax(dat$survival, .Machine$double.eps))

        list(
          data = dat,
          summary = list(
            n                     = total_n,
            events                = total_events,
            censored              = total_n - total_events,
            medianSurvival        = if (!is.na(median_surv)) round(median_surv, 1) else NA_real_,
            medianSurvivalLowerCI = if (!is.na(lower_ci)) round(lower_ci, 1) else NA_real_,
            medianSurvivalUpperCI = if (!is.na(upper_ci)) round(upper_ci, 1) else NA_real_,
            q25Survival           = q25_surv,
            q75Survival           = q75_surv,
            meanSurvival          = mean_surv,
            timeScale             = timeScale
          )
        )
      }),
      names(kmSum)
    )

  } else {

    # ---- parametric models: weibull, exponential, lognormal, loglogistic ---
    survivalList <- .build_parametric_survivalList(
      df        = df,
      strata    = strata,
      dist      = model,
      covariates = covars,
      probs     = probs,
      timeScale = timeScale
    )

  }

  # ---- log-rank tests per stratifier (gender OR age_group) ----------------
  if (!is.null(strata)) {

    if ("gender" %in% strata && "gender" %in% names(df)) {
      lr_gender <- .logrank_from_df(df, "gender")
      if (!is.null(lr_gender)) {
        # Keep both the overall and the pairwise rows, as documented
        survivalList$logrank_test_gender <- lr_gender
      }
    }

    if ("age_group" %in% strata && "age_group" %in% names(df)) {
      lr_age <- .logrank_from_df(df, "age_group")
      if (!is.null(lr_age)) {
        # Keep both the overall and the pairwise rows, as documented
        survivalList$logrank_test_age_group <- lr_age
      }
    }
  }

  class(survivalList) <- c("singleEventSurvival", "list")
  survivalList
}
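
# Illustrative sketch, not part of the package API: the `meanSurvival` value
# computed above can be read as the area under a right-continuous survival
# step function, truncated at the last reported time (a restricted mean).
# The helper name `.example_step_mean` and its toy inputs are hypothetical.
.example_step_mean <- function(time, surv) {
  stopifnot(length(time) == length(surv), !is.unsorted(time))
  widths  <- diff(c(0, time))              # interval lengths between steps
  heights <- utils::head(c(1, surv), -1)   # S(t) keeps its previous value on each interval
  sum(widths * heights)
}
# Usage: steps at t = 1, 2, 4 with S = 0.8, 0.5, 0.2 give
# 1 * 1 + 1 * 0.8 + 2 * 0.5 = 2.8.
# .example_step_mean(c(1, 2, 4), c(0.8, 0.5, 0.2))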

# -------------------------------------------------------------------------
# helpers -----------------------------------------------------------------
# -------------------------------------------------------------------------

#' @keywords internal
#' @noRd
.validate_age_groups <- function(age_groups) {
  if (!is.list(age_groups) || length(age_groups) == 0L) {
    stop("`age_groups` must be a non-empty list, e.g. list(c(0,17), c(18,44), c(45,Inf)).")
  }
  ag <- lapply(seq_along(age_groups), function(i) {
    v <- age_groups[[i]]
    if (!is.numeric(v) || length(v) != 2L || any(is.na(v))) {
      stop(sprintf("age_groups[[%d]] must be numeric length-2.", i))
    }
    if (v[1] > v[2]) stop(sprintf("age_groups[[%d]] has lower > upper.", i))
    v
  })
  lowers <- vapply(ag, `[`, numeric(1), 1L)
  uppers <- vapply(ag, `[`, numeric(1), 2L)
  ord    <- order(lowers, uppers)
  lowers <- lowers[ord]
  uppers <- uppers[ord]
  if (length(lowers) > 1L && any(lowers[-1] <= uppers[-length(uppers)])) {
    stop("`age_groups` must be non-overlapping and in increasing order.")
  }
  list(lowers = lowers, uppers = uppers)
}

#' @keywords internal
#' @noRd
.make_age_labels <- function(lowers, uppers) {
  fmt_num <- function(x) prettyNum(x, preserve.width = "none", drop0trailing = TRUE, scientific = FALSE)
  vapply(seq_along(lowers), function(i) {
    lo <- lowers[i]
    up <- uppers[i]
    if (is.infinite(up)) paste0(fmt_num(lo), "+") else paste0(fmt_num(lo), "-", fmt_num(up))
  }, character(1))
}
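
# Illustrative sketch, not part of the package API: one way to map a vector of
# ages onto inclusive `c(min, max)` ranges and label them in the `0-18` / `66+`
# style described in the documentation. The helper name `.example_label_ages`
# is hypothetical. Ages that are missing or fall in no range come back as NA,
# and the first matching range wins.
.example_label_ages <- function(age, breaks_list) {
  out <- rep(NA_character_, length(age))
  for (rng in breaks_list) {
    label <- if (is.infinite(rng[2])) paste0(rng[1], "+") else paste0(rng[1], "-", rng[2])
    hit   <- !is.na(age) & is.na(out) & age >= rng[1] & age <= rng[2]
    out[hit] <- label
  }
  out
}
# Usage:
# .example_label_ages(c(5, 19, 70, NA), list(c(0, 18), c(19, 45), c(66, Inf)))
# returns c("0-18", "19-45", "66+", NA)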

#' Internal log-rank tests helper used by singleEventSurvival
#'
#' @param df Data frame with columns: time_scaled, status, and a single strata var
#' @param strataVar Character scalar: name of strata variable present in `df`
#'
#' @return A data frame with overall and pairwise log-rank test results, or NULL
#'   if insufficient data.
#' @keywords internal
#' @noRd
.logrank_from_df <- function(df, strataVar) {

  if (!strataVar %in% names(df)) {
    return(NULL)
  }

  # Need at least one event
  if (nrow(df) == 0L || sum(df$status) == 0L) {
    return(NULL)
  }

  strata_col <- strataVar

  # Unique strata levels
  unique_strata <- unique(df[[strata_col]])
  unique_strata <- unique_strata[!is.na(unique_strata)]

  if (length(unique_strata) < 2L) {
    return(NULL)
  }

  # Remove strata with no events
  grp_events         <- tapply(df$status, df[[strata_col]], sum, na.rm = TRUE)
  strata_with_events <- names(grp_events[grp_events > 0L])

  if (length(strata_with_events) < 2L) {
    return(NULL)
  }

  df <- df[df[[strata_col]] %in% strata_with_events, ]

  results_list <- list()

  # Overall log-rank test
  tryCatch({
    surv_obj <- survival::Surv(
      time = df$time_scaled,
      event = df$status
    )

    formula_str <- paste("surv_obj ~", strata_col)
    survdiff_result <- survival::survdiff(
      stats::as.formula(formula_str),
      data = df
    )

    chi_sq <- survdiff_result$chisq
    df_chisq <- length(survdiff_result$n) - 1L
    p_value <- stats::pchisq(chi_sq, df = df_chisq, lower.tail = FALSE)

    overall_result <- data.frame(
      testType = "overall",
      stratum1 = "all",
      stratum2 = "all",
      chisq    = chi_sq,
      df       = df_chisq,
      pvalue   = p_value,
      stringsAsFactors = FALSE
    )

    results_list[[length(results_list) + 1L]] <- overall_result
  }, error = function(e) {
    if (interactive()) {
      message("Could not calculate overall log-rank test: ", e$message)
    }
  })

  # Pairwise comparisons
  strata_combinations <- utils::combn(strata_with_events, 2L, simplify = FALSE)

  for (combo in strata_combinations) {
    stratum1 <- combo[1]
    stratum2 <- combo[2]

    pair_data <- df[df[[strata_col]] %in% c(stratum1, stratum2), ]

    if (sum(pair_data$status) == 0L) next

    tryCatch({
      surv_obj <- survival::Surv(
        time = pair_data$time_scaled,
        event = pair_data$status
      )

      formula_str <- paste("surv_obj ~", strata_col)
      survdiff_result <- survival::survdiff(
        stats::as.formula(formula_str),
        data = pair_data
      )

      chi_sq <- survdiff_result$chisq
      df_chisq <- 1L
      p_value <- stats::pchisq(chi_sq, df = df_chisq, lower.tail = FALSE)

      pairwise_result <- data.frame(
        testType = "pairwise",
        stratum1 = stratum1,
        stratum2 = stratum2,
        chisq    = chi_sq,
        df       = df_chisq,
        pvalue   = p_value,
        stringsAsFactors = FALSE
      )

      results_list[[length(results_list) + 1L]] <- pairwise_result
    }, error = function(e) {
      if (interactive()) {
        message(sprintf("Could not calculate pairwise test for %s vs %s: %s",
                        stratum1, stratum2, e$message))
      }
    })
  }

  if (length(results_list) > 0L) do.call(rbind, results_list) else NULL
}
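
# Illustrative sketch, not part of the package API: the overall log-rank test
# above comes down to `survival::survdiff()` plus a chi-square tail
# probability with (number of groups - 1) degrees of freedom. The helper name
# `.example_logrank` and the toy data in the usage note are hypothetical.
.example_logrank <- function(time, status, group) {
  d   <- data.frame(time = time, status = status, group = group)
  fit <- survival::survdiff(survival::Surv(time, status) ~ group, data = d)
  dof <- length(fit$n) - 1L   # one fewer than the number of compared groups
  data.frame(
    chisq  = fit$chisq,
    df     = dof,
    pvalue = stats::pchisq(fit$chisq, df = dof, lower.tail = FALSE)
  )
}
# Usage:
# .example_logrank(
#   time   = c(5, 8, 12, 20, 3, 6, 9, 15),
#   status = c(1, 1, 0, 1, 1, 1, 1, 0),
#   group  = rep(c("A", "B"), each = 4)
# )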

# ---- new model helpers --------------------------------------------------

#' Fit a KM or Cox survfit object for a single group data frame.
#'
#' For `model = "km"` uses [survival::survfit()]; for `model = "cox"` uses
#' [survival::coxph()] + [survival::survfit()]. The returned object is a
#' `survfit` whose `summary()` and `quantile()` are compatible with the
#' existing KM summarisation pipeline.
#'
#' @param df Data frame with `time_scaled` and `status`. For Cox with
#'   covariates, all covariate columns must be present.
#' @param model `"km"` or `"cox"`.
#' @param covariates Character vector of covariate column names (used for Cox).
#' @param confType Confidence interval type passed to `conf.type`.
#' @param confInt Confidence level passed to `conf.int`.
#' @return A `survfit` object.
#' @keywords internal
#' @noRd
.fit_survfit_for_group <- function(df, model, covariates, confType, confInt) {
  if (model == "km") {
    # survfit() is provided by the survival package
    return(survival::survfit(
      survival::Surv(time_scaled, status) ~ 1,
      data = df,
      conf.type = confType,
      conf.int = confInt
    ))
  }

  # Cox model
  if (length(covariates) == 0L) {
    fml <- stats::as.formula("survival::Surv(time_scaled, status) ~ 1")
  } else {
    fml <- stats::as.formula(paste(
      "survival::Surv(time_scaled, status) ~",
      paste(covariates, collapse = " + ")
    ))
  }

  cox_fit <- tryCatch(
    survival::coxph(fml, data = df),
    error = function(e) stop(paste0("Cox model failed: ", e$message))
  )

  if (length(covariates) == 0L) {
    survival::survfit(cox_fit, conf.type = confType, conf.int = confInt)
  } else {
    mean_cov <- as.data.frame(lapply(df[covariates], mean, na.rm = TRUE))
    survival::survfit(cox_fit, newdata = mean_cov, conf.type = confType, conf.int = confInt)
  }
}


#' Build a named list of parametric model results for all strata + overall.
#'
#' Loops over gender / age_group strata (same logic as the KM path) and calls
#' [.fit_parametric_for_group()] for each subset, then for the full data.
#'
#' @param df Scaled data frame with `time_scaled`, `status`, and optional
#'   `gender` / `age_group` columns.
#' @param strata Character vector of active strata (`"gender"`, `"age_group"`),
#'   or `NULL`.
#' @param dist Distribution name: one of `"weibull"`, `"exponential"`,
#'   `"lognormal"`, `"loglogistic"`.
#' @param covariates Character vector of covariate column names (may be empty).
#' @param probs Numeric vector of probability quantiles to extract.
#' @param timeScale Reported time scale label stored in summary.
#' @return Named list of `list(data, summary)` entries, same structure as the
#'   KM summarisation pipeline produces.
#' @keywords internal
#' @noRd
.build_parametric_survivalList <- function(df, strata, dist, covariates, probs, timeScale) {
  result <- list()

  if (!is.null(strata)) {
    if ("gender" %in% strata && "gender" %in% names(df)) {
      gender_levels <- sort(stats::na.omit(unique(df$gender)))
      for (g in gender_levels) {
        .df <- df[!is.na(df$gender) & df$gender == g, ]
        if (nrow(.df) > 0L) {
          result[[paste0("gender=", g)]] <- .fit_parametric_for_group(
            .df, dist, covariates, probs, timeScale
          )
        }
      }
    }

    if ("age_group" %in% strata && "age_group" %in% names(df)) {
      age_levels <- sort(stats::na.omit(unique(df$age_group)))
      for (a in age_levels) {
        .df <- df[!is.na(df$age_group) & df$age_group == a, ]
        if (nrow(.df) > 0L) {
          result[[paste0("age_group=", a)]] <- .fit_parametric_for_group(
            .df, dist, covariates, probs, timeScale
          )
        }
      }
    }
  }

  result[["overall"]] <- .fit_parametric_for_group(df, dist, covariates, probs, timeScale)
  result
}


#' Fit a parametric survival model for a single group and return standardised results.
#'
#' Uses [survival::survreg()] to fit a Weibull, exponential, log-normal, or
#' log-logistic AFT model and evaluates the survival function at every unique
#' observed event time.  Confidence intervals are not available for the
#' pointwise survival estimates (set to `NA`); use the `summary` element for
#' the median and quartile survival times derived from the fitted distribution.
#'
#' @param df Data frame with `time_scaled` (numeric, scaled follow-up) and
#'   `status` (0/1). Covariate columns must be present if `covariates` is
#'   non-empty.
#' @param dist One of `"weibull"`, `"exponential"`, `"lognormal"`,
#'   `"loglogistic"`. Passed as `dist` to [survival::survreg()].
#' @param covariates Character vector of covariate names (intercept-only model
#'   when empty).
#' @param probs Numeric vector of survival probabilities used to extract
#'   quantile times, e.g. `c(0.75, 0.5, 0.25)`.
#' @param timeScale Character label stored in the `summary$timeScale` field.
#' @return A `list(data = <data.frame>, summary = <list>)` with the same structure
#'   as the KM summarisation pipeline.
#' @keywords internal
#' @noRd
.fit_parametric_for_group <- function(df, dist, covariates, probs, timeScale) {

  if (length(covariates) == 0L) {
    fml <- stats::as.formula("survival::Surv(time_scaled, status) ~ 1")
  } else {
    fml <- stats::as.formula(paste(
      "survival::Surv(time_scaled, status) ~",
      paste(covariates, collapse = " + ")
    ))
  }

  model_fit <- tryCatch(
    survival::survreg(fml, data = df, dist = dist),
    error = function(e) stop(paste0(dist, " model failed: ", e$message))
  )

  mu    <- as.numeric(stats::coef(model_fit)[1])
  sigma <- model_fit$scale

  # Survival function S(t) by distribution
  surv_fn <- switch(dist,
    weibull     = function(t) stats::pweibull(t, shape = 1 / sigma, scale = exp(mu), lower.tail = FALSE),
    exponential = function(t) stats::pexp(t, rate = exp(-mu), lower.tail = FALSE),
    lognormal   = function(t) stats::plnorm(t, meanlog = mu, sdlog = sigma, lower.tail = FALSE),
    loglogistic = function(t) 1 / (1 + (t / exp(mu))^(1 / sigma))
  )

  event_times <- sort(unique(df$time_scaled[df$status == 1L]))

  # Edge case: no events
  if (length(event_times) == 0L) {
    return(list(
      data = data.frame(
        time = numeric(0), n_risk = integer(0), n_event = integer(0),
        n_censor = integer(0), lower = numeric(0), upper = numeric(0),
        survival = numeric(0), std_err = numeric(0), hazard = numeric(0),
        cum_hazard = numeric(0), cum_event = integer(0), cum_censor = integer(0),
        stringsAsFactors = FALSE
      ),
      summary = list(
        n = nrow(df), events = 0L, censored = nrow(df),
        medianSurvival = NA_real_, medianSurvivalLowerCI = NA_real_,
        medianSurvivalUpperCI = NA_real_, q25Survival = NA_real_,
        q75Survival = NA_real_, meanSurvival = NA_real_, timeScale = timeScale
      )
    ))
  }

  # Risk table at each unique event time
  n_risk_vec   <- vapply(event_times, function(t) sum(df$time_scaled >= t), integer(1))
  n_event_vec  <- vapply(event_times, function(t) sum(df$time_scaled == t & df$status == 1L), integer(1))
  n_censor_vec <- vapply(event_times, function(t) sum(df$time_scaled == t & df$status == 0L), integer(1))

  surv_vec    <- pmax(pmin(surv_fn(event_times), 1), 0)
  cum_haz_vec <- -log(pmax(surv_vec, .Machine$double.eps))

  # Hazard approximated as -dS / (S * dt) (discrete)
  dt      <- c(event_times[1], diff(event_times))
  ds      <- -diff(c(1, surv_vec))
  haz_vec <- pmax(ds / pmax(utils::head(c(1, surv_vec), -1) * dt, .Machine$double.eps), 0)

  # Quantile survival times from parametric S(t) evaluated at all observed times
  all_times <- sort(unique(df$time_scaled))
  q_times <- vapply(probs, function(q) {
    idx <- which(surv_fn(all_times) <= q)
    if (length(idx) > 0L) all_times[idx[1L]] else NA_real_
  }, numeric(1))

  # probs defaults to c(0.75, 0.5, 0.25) -> q75, median, q25
  q75_surv    <- if (length(q_times) >= 1L) q_times[1L] else NA_real_
  median_surv <- if (length(q_times) >= 2L) q_times[2L] else NA_real_
  q25_surv    <- if (length(q_times) >= 3L) q_times[3L] else NA_real_

  # Restricted mean survival (area under S(t) over the observed time range)
  t_grid    <- c(0, all_times)
  s_grid    <- c(1, pmax(pmin(surv_fn(all_times), 1), 0))
  mean_surv <- sum(diff(t_grid) * utils::head(s_grid, -1))

  list(
    data = data.frame(
      time       = event_times,
      n_risk     = n_risk_vec,
      n_event    = n_event_vec,
      n_censor   = n_censor_vec,
      lower      = NA_real_,
      upper      = NA_real_,
      survival   = surv_vec,
      std_err    = NA_real_,
      hazard     = haz_vec,
      cum_hazard = cum_haz_vec,
      cum_event  = as.integer(cumsum(n_event_vec)),
      cum_censor = as.integer(cumsum(n_censor_vec)),
      stringsAsFactors = FALSE
    ),
    summary = list(
      n                     = nrow(df),
      events                = sum(df$status),
      censored              = sum(df$status == 0L),
      medianSurvival        = if (!is.na(median_surv)) round(median_surv, 1) else NA_real_,
      medianSurvivalLowerCI = NA_real_,
      medianSurvivalUpperCI = NA_real_,
      q25Survival           = q25_surv,
      q75Survival           = q75_surv,
      meanSurvival          = mean_surv,
      timeScale             = timeScale
    )
  )
}
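
# Illustrative sketch, not part of the package API: the closed-form S(t) used
# above, written in terms of the `survreg()` intercept `mu` and scale `sigma`
# (log(T) = mu + sigma * W). The helper name `.example_aft_survival` is
# hypothetical and covers only the intercept-only case.
.example_aft_survival <- function(t, mu, sigma, dist) {
  switch(dist,
    weibull     = exp(-(t / exp(mu))^(1 / sigma)),
    exponential = exp(-t / exp(mu)),   # Weibull with sigma fixed at 1
    lognormal   = stats::pnorm((log(t) - mu) / sigma, lower.tail = FALSE),
    loglogistic = 1 / (1 + (t / exp(mu))^(1 / sigma)),
    stop("unsupported distribution: ", dist)
  )
}
# Usage: with sigma = 1 the Weibull curve coincides with the exponential one,
# and for the log-normal and log-logistic forms S(exp(mu)) is 0.5, so exp(mu)
# is the median survival time.
# .example_aft_survival(c(1, 5, 10), mu = log(5), sigma = 1, dist = "weibull")
# .example_aft_survival(5, mu = log(5), sigma = 0.7, dist = "lognormal")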


OdysseusSurvivalModule documentation built on April 3, 2026, 5:06 p.m.