R/main.R

Defines functions WA_fit

Documented in WA_fit

# R/main.R ------------------------------------------------------------

#' While-Alive Regression (WA) for Composite Endpoints
#'
#' Fits the while-alive regression model targeting the while-alive loss rate
#' for composite endpoints with recurrent and terminal events. Time-varying
#' covariate effects are represented via user-chosen time bases (e.g., B-spline,
#' piecewise polynomial, interval-local). Robust inference supports
#' cluster-randomized trials (CRTs) via cluster-robust variance; if
#' \code{cluster = NULL}, IID (subject-as-cluster) variance is used.
#'
#' @param formula A \code{Surv(time, status) ~ RHS} formula. \code{time} and
#'   \code{status} must exist in \code{data}. The RHS contains baseline
#'   covariates (no explicit time-varying covariates here; time-variation is
#'   induced via the chosen basis).
#' @param data Long-format data frame with one row per \emph{event/checkpoint}
#'   per subject, containing \code{time}, \code{status}, \code{id}, optional
#'   \code{cluster}, and RHS covariates.
#' @param id Character scalar; subject ID column name.
#' @param cluster Optional character scalar; cluster column name for CRT-robust
#'   inference. If \code{NULL}, IID inference treats each subject as its own cluster.
#' @param knots Numeric vector (length \eqn{\ge 2}) specifying the basis
#'   boundaries and optional interior knots that define the time basis shape.
#' @param tau_grid Numeric vector of evaluation times used to stack the
#'   estimating equations. Independent of \code{knots}.
#' @param basis One of \code{"il","pl","bz","ns","ms","st","tl","tf"}:
#'   interval-local (\code{"il"}), piecewise polynomial (\code{"pl"}),
#'   B-spline (\code{"bz"}), natural spline (\code{"ns"}), M-spline
#'   (\code{"ms"}, requires \pkg{splines2}), step (\code{"st"}), truncated
#'   linear (\code{"tl"}), or time-fixed (\code{"tf"}).
#' @param degree Integer degree for bases that use it (e.g., \code{"bz"}, \code{"pl"}, \code{"ns"}, \code{"ms"}).
#' @param link Link function: \code{"log"} (default) or \code{"identity"}.
#' @param w_recur Numeric vector of weights for each recurrent event type. Its
#'   length must match the number of recurrent \code{status} codes in
#'   \code{data} (i.e., excluding \code{0} for censoring and the max code for terminal).
#' @param w_term Numeric scalar; weight for the terminal event.
#' @param ipcw IPCW method: \code{"km"} or \code{"cox"}.
#' @param ipcw_formula A one-sided formula specifying RHS covariates for the IPCW Cox model
#'   when \code{ipcw = "cox"} (e.g., \code{~ x1 + x2}). Ignored for \code{ipcw = "km"}.
#'
#' @details
#' The estimating equations solve \eqn{E[Z(t)\{L(t) - \mu_\beta(t)X_{\min}(t)\}V/G]=0}
#' over \code{tau_grid}, where \eqn{L(t)} is the weighted composite loss
#' (recurrent+terminal), \eqn{\mu_\beta(t)} the while-alive loss rate under the chosen
#' link, \eqn{X_{\min}(t) = \min(T, t)}, \eqn{V} the at-risk/terminal indicator, and
#' \eqn{G} the censoring survival modeled via \code{ipcw}.
#'
#' @return An object of class \code{"WA"} with elements:
#' \itemize{
#' \item \code{est}: named coefficient vector.
#' \item \code{vcov}: cluster-robust variance matrix.
#' \item \code{se}: standard errors.
#' \item \code{converged}: logical.
#' \item \code{basis}, \code{degree}, \code{link}, \code{Z_cols},
#'       \code{knots}, \code{tau_grid}, \code{id_var}, \code{cluster_var},
#'       \code{w_recur}, \code{w_term}, \code{status_codes}, \code{formula}.
#' }
#'
#' @examples
#' \donttest{
#' ex_dt <- crt_dt[crt_dt$cluster %in% c(1,2,3,4,7,10), ]
#' fit <- WA_fit(
#'   survival::Surv(time, status) ~ trt + Z1 + Z2,
#'   data     = ex_dt,
#'   id       = "id",
#'   cluster  = "cluster",
#'   knots    = seq(0, max(ex_dt$time, na.rm = TRUE), length.out = 6),
#'   tau_grid = seq(0, max(ex_dt$time, na.rm = TRUE), length.out = 6),
#'   basis    = "bz", degree = 1, link = "log",
#'   w_recur  = c(1, 1), w_term = 2,
#'   ipcw     = "km"
#' )
#' s <- summary(fit)
#' nd <- unique(ex_dt[, c("trt","Z1","Z2")])
#' plot(fit, newdata = nd,
#'      t_seq = seq(0, max(fit$tau_grid), length.out = 200),
#'      id = 1, mode = "wa", smooth = TRUE)
#' }
#'
#' @export
WA_fit <- function(formula,
                   data,
                   id,
                   cluster = NULL,
                   knots,
                   tau_grid,
                   basis   = c("il","pl","bz","ns","ms","st","tl","tf"),
                   degree  = 1,
                   link    = c("log","identity"),
                   w_recur,
                   w_term,
                   ipcw = c("km","cox"),
                   ipcw_formula = ~ 1) {

  # Resolve the enumerated arguments to a single validated choice each.
  basis <- match.arg(basis)
  link  <- match.arg(link)
  ipcw  <- match.arg(ipcw)

  # The formula parser supplies the time/status vectors, the RHS design
  # matrix and the names of its columns (the fields read just below).
  par <- .WA_parse_formula(formula, data)
  time_vec   <- par$time_vec
  status_vec <- par$status_vec
  Xmm        <- par$Xmm
  Z_cols     <- par$Z_cols

  if (length(Z_cols) == 0L) stop("No covariates specified on RHS.")
  if (!id %in% names(data)) stop("id column not found in data.")
  if (!is.null(cluster) && !cluster %in% names(data)) stop("cluster column not found in data.")

  # Assemble the working frame with internal column names (.time, .status,
  # .id, .cluster) followed by the covariates. `hold` is only a placeholder
  # that fixes the row count; it is removed again by the `[, -1]` below.
  Z_df <- as.data.frame(Xmm); names(Z_df) <- Z_cols
  dat <- data.frame(hold = seq_along(time_vec))
  dat$.time   <- time_vec
  dat$.status <- status_vec
  dat$.id     <- data[[id]]
  # Without a cluster column every subject acts as its own cluster.
  dat$.cluster <- if (!is.null(cluster)) data[[cluster]] else data[[id]]
  dat <- cbind(dat, Z_df)[, -1]

  # Drop every column that contains at least one NA. NOTE(review): this also
  # applies to .time/.status/.id/.cluster and to the covariates, which the
  # code below still refers to by name.
  dat <- dat %>% dplyr::select(dplyr::where(~ !any(is.na(.))))

  # Status coding: 0 = censoring, the largest observed code = terminal event,
  # every code in between = one recurrent event type.
  ustat <- sort(unique(dat$.status))
  if (!any(ustat == 0L)) stop("Status must include 0 for censoring.")
  s_max <- max(ustat, na.rm = TRUE)
  rec_types <- setdiff(ustat, c(0L, s_max))
  if (length(w_recur) != length(rec_types)) {
    stop("Length(w_recur) must equal # recurrent event types (status in {",
         paste(rec_types, collapse=","), "}).")
  }

  # One row per subject: first-row covariate values, the last observed time
  # (obs_T) and Delta = 1 when the record at that time carries the terminal
  # code, 0 otherwise.
  subj <- dat %>%
    dplyr::group_by(.data$.id, .data$.cluster) %>%
    dplyr::summarise(
      dplyr::across(dplyr::all_of(Z_cols), ~ dplyr::first(.x)),
      obs_T  = max(.data$.time, na.rm = TRUE),
      Delta  = { idx <- which.max(.data$.time); as.integer(.data$.status[idx] == s_max) },
      .groups = "drop"
    )

  # Cross every subject with every evaluation time tau.
  #   X_min_tau  = min(obs_T, tau)
  #   V_i_tau    = 1 if the terminal event occurred by tau, or the subject is
  #                still under observation after tau; 0 otherwise
  #   N_term_tau = 1 if the terminal event occurred by tau
  expanded <- subj %>%
    tidyr::crossing(tau = tau_grid) %>%
    dplyr::mutate(
      X_min_tau = pmin(.data$obs_T, .data$tau),
      V_i_tau   = as.integer((.data$obs_T <= .data$tau) & (.data$Delta == 1L)) + as.integer(.data$obs_T > .data$tau),
      N_term_tau= as.integer((.data$obs_T <= .data$tau) & (.data$Delta == 1L))
    )

  # Censoring survival G evaluated at X_min_tau (the IPCW denominator).
  if (ipcw == "km") {
    # Kaplan-Meier for the censoring time: "event" here means Delta == 0.
    km_c <- survival::survfit(survival::Surv(time = subj$obs_T, event = 1 - subj$Delta) ~ 1)
    # Evaluate the KM curve as a step function: 1 before the first KM time,
    # the last KM value beyond the final time.
    expanded$G_X_min_tau <- stats::approx(
      x = km_c$time, y = km_c$surv,
      xout = expanded$X_min_tau,
      method = "constant", f = 0,
      yleft = 1, yright = utils::tail(km_c$surv, 1)
    )$y
  } else {
    # Cox model for the censoring hazard with the RHS taken from ipcw_formula.
    rhs <- paste(deparse(ipcw_formula[[2]]), collapse = "")
    cfit <- survival::coxph(stats::as.formula(paste0("Surv(obs_T, I(1-Delta)) ~ ", rhs)),
                            data = subj, ties = "breslow", x = TRUE, y = TRUE, model = FALSE)
    # Uncentered baseline cumulative hazard, linearly interpolated between
    # its time points; rule = 2 extends the end values outside that range.
    bh <- survival::basehaz(cfit, centered = FALSE)
    Lambda0 <- stats::approxfun(bh$time, bh$hazard, method = "linear", rule = 2)

    # Subject-level linear predictor; 0 when the Cox model has no coefficients.
    mm <- stats::model.matrix(stats::as.formula(paste0("~ -1 + ", rhs)), data = subj)
    beta_c <- stats::coef(cfit)
    if (length(beta_c)) mm <- mm[, names(beta_c), drop = FALSE]
    lp <- if (length(beta_c)) drop(mm %*% beta_c) else 0

    subj$lp_cens <- lp
    expanded <- dplyr::left_join(expanded, subj[, c(".id", "lp_cens")], by = c(".id" = ".id"))
    # G = exp(-Lambda0(t) * exp(lp)); non-finite values are replaced by 0.
    expanded$G_X_min_tau <- exp(-Lambda0(expanded$X_min_tau) * exp(expanded$lp_cens))
    expanded$G_X_min_tau[!is.finite(expanded$G_X_min_tau)] <- 0
  }

  # Event records only (recurrent types and the terminal code).
  recs <- dat %>%
    dplyr::filter(.data$.status %in% c(rec_types, s_max)) %>%
    dplyr::select(dplyr::all_of(c(".id", ".status", ".time")))

  grid <- expanded %>% dplyr::select(dplyr::all_of(c(".id", "tau", "X_min_tau")))

  # Count, per subject / tau / status code, the events with time <= X_min_tau.
  # Subject-tau pairs without any such event do not appear in `counts`.
  counts <- recs %>%
    dplyr::right_join(grid, by = ".id", relationship = "many-to-many") %>%
    dplyr::filter(.data$.time <= .data$X_min_tau) %>%
    dplyr::group_by(.data$.id, .data$tau, .data$X_min_tau, .data$.status) %>%
    dplyr::summarise(n = dplyr::n(), .groups = "drop")

  # Spread the counts into one N_<status> column per code and attach them.
  if (nrow(counts)) {
    counts_wide <- tidyr::pivot_wider(
      counts,
      id_cols = c(".id", "tau", "X_min_tau"),
      names_from = ".status",
      values_from = "n",
      values_fill = 0L,
      names_prefix = "N_"
    )
    expanded <- dplyr::left_join(expanded, counts_wide,
                                 by = c(".id","tau","X_min_tau"))
  }

  # Make sure every recurrent type has a count column, even if never observed.
  for (s in rec_types) {
    nm <- paste0("N_", s)
    if (!nm %in% names(expanded)) expanded[[nm]] <- 0L
  }

  # Weighted composite loss L = sum_j w_recur[j] * N_j + w_term * N_term.
  L_recur <- 0
  for (j in seq_along(rec_types)) {
    L_recur <- L_recur + w_recur[j] * expanded[[paste0("N_", rec_types[j])]]
  }
  expanded$L <- L_recur + w_term * expanded$N_term_tau
  # Rows left unmatched by the join above carry NA counts; their loss is 0.
  expanded$L <- ifelse(is.na(expanded$L), 0, expanded$L)

  # Add the covariate-by-time-basis design columns; `cov_pattern` is the
  # regex used below to pick those columns out of `expanded` by name.
  tb <- .WA_design_Z(expanded, Z_cols = Z_cols, knots = knots,
                     basis = basis, degree = degree, include_intercept = FALSE)
  expanded <- tb$data
  cov_pat  <- tb$cov_pattern

  # One coefficient per design column, all started at zero.
  p <- length(grep(cov_pat, names(expanded)))
  beta_init <- rep(0, p)

  # Solve the stacked estimating equations with Broyden's method.
  sol <- nleqslv::nleqslv(
    x  = beta_init,
    fn = .WA_ee,
    data = expanded,
    cov_pattern = cov_pat,
    L_col="L", Dmin_col="X_min_tau", V_col="V_i_tau", G_col="G_X_min_tau",
    link = link,
    method = "Broyden"
  )
  # Termination codes above 3 are reported as a possible non-convergence.
  if (sol$termcd > 3) warning("WA_fit: nleqslv may not have converged (termcd=", sol$termcd, ").")
  beta_hat <- sol$x

  # Variance of the estimates; the cluster column is only passed on when the
  # caller supplied one.
  V <- .WA_var(
    data = expanded, beta = beta_hat, cov_pattern = cov_pat,
    L_col="L", Dmin_col="X_min_tau", V_col="V_i_tau", G_col="G_X_min_tau",
    id_col = ".id", cluster_col = if (is.null(cluster)) NULL else ".cluster",
    link = link
  )

  # Label coefficients and variance matrix with the design column names.
  beta_names <- grep(cov_pat, names(expanded), value = TRUE)
  names(beta_hat) <- beta_names
  dimnames(V)     <- list(beta_names, beta_names)

  # Negative diagonal entries are truncated at 0 before taking the root.
  structure(list(
    call        = match.call(),
    est         = beta_hat,
    vcov        = V,
    se          = sqrt(pmax(diag(V), 0)),
    converged   = (sol$termcd <= 3),
    basis       = basis,
    degree      = degree,
    link        = link,
    Z_cols      = Z_cols,
    knots       = knots,
    tau_grid    = tau_grid,
    id_var      = id,
    cluster_var = cluster,
    w_recur     = w_recur,
    w_term      = w_term,
    status_codes= list(recurrent = rec_types, terminal = s_max, censor = 0),
    formula     = formula
  ), class = "WA")
}


# R/summary-print.R ----------------------------------------------------

#' @export
print.WA <- function(x, ...) {
  # Compact three-line banner: title, basis configuration, fit status.
  n_coef <- length(x$est)
  has_converged <- isTRUE(x$converged)

  cat("While-Alive Regression (WA)", "\n", sep = "")
  cat("  basis:", x$basis, " degree:", x$degree, " link:", x$link, "\n")
  cat("  #coef:", n_coef, " converged:", has_converged, "\n")

  # Return the object invisibly so the method can sit in a pipeline.
  invisible(x)
}

#' Summarize a WA object
#'
#' @param object A \code{"WA"} object from \code{\link{WA_fit}}.
#' @param ... Unused.
#'
#' @return An object of class \code{"summary.WA"} containing configuration and a
#'   coefficient table with estimates, standard errors, and z-scores.
#'
#' @examples
#' \donttest{
#' ex_dt <- crt_dt[crt_dt$cluster %in% c(1,2,3,4,7,10), ]
#' fit <- WA_fit(survival::Surv(time, status) ~ trt + Z1 + Z2,
#'               data = ex_dt, id="id", cluster="cluster",
#'               knots=seq(0, max(ex_dt$time), length.out=6),
#'               tau_grid=seq(0, max(ex_dt$time), length.out=6),
#'               basis="bz", degree=1, link="log",
#'               w_recur=c(1,1), w_term=2, ipcw="km")
#' summary(fit)
#' }
#' @export
summary.WA <- function(object, ...) {
  est <- object$est
  se  <- object$se

  # z-score is only defined where the standard error is strictly positive.
  z_score <- ifelse(se > 0, est / se, NA_real_)

  coef_table <- data.frame(
    coef = est,
    se   = se,
    z    = z_score,
    row.names = names(est)
  )

  # Carry the fit configuration along with the coefficient table.
  structure(
    list(
      basis        = object$basis,
      degree       = object$degree,
      link         = object$link,
      knots        = object$knots,
      tau_grid     = object$tau_grid,
      status_codes = object$status_codes,
      w_recur      = object$w_recur,
      w_term       = object$w_term,
      coef         = coef_table
    ),
    class = "summary.WA"
  )
}

#' @export
print.summary.WA <- function(x, ...) {
  # Pre-format the pieces that need collapsing into a single string.
  knot_txt   <- paste(format(x$knots), collapse = ", ")
  recur_txt  <- paste(x$status_codes$recurrent, collapse=",")
  weight_txt <- paste(x$w_recur, collapse=",")
  n_tau      <- length(x$tau_grid)

  # Header with the model configuration.
  cat("While-Alive Regression summary\n")
  cat("  basis:", x$basis, " degree:", x$degree, " link:", x$link, "\n")
  cat("  knots:", knot_txt, "\n")
  cat("  \u03C4-grid length:", n_tau, "\n")
  cat("  status codes: recurrent={", recur_txt, "}, terminal=",
      x$status_codes$terminal, ", censor=0\n", sep = "")
  cat("  weights: recur=", weight_txt, " terminal=", x$w_term, "\n\n", sep="")

  # Coefficient table, then hand the object back invisibly.
  print(x$coef)
  invisible(x)
}


# R/predict.R ----------------------------------------------------------

#' Predict while-alive loss rates
#'
#' @param object A \code{"WA"} object.
#' @param newdata Data frame with columns matching the RHS of the fitted model.
#'   Predictions are computed for the rows of \code{newdata}.
#' @param t_seq Numeric vector of times at which to evaluate predictions.
#' @param level Confidence level for pointwise intervals (default 0.95).
#' @param ... Unused.
#'
#' @return A data frame with columns \code{id} (row index in \code{newdata}),
#'   \code{t}, \code{mu} (predicted while-alive rate), and CI columns \code{lb}, \code{ub}.
#'
#' @examples
#' \donttest{
#' ex_dt <- crt_dt[crt_dt$cluster %in% c(1,2,3,4,7,10), ]
#' fit <- WA_fit(survival::Surv(time, status) ~ trt + Z1 + Z2,
#'               data = ex_dt, id="id", cluster="cluster",
#'               knots=seq(0, max(ex_dt$time), length.out=6),
#'               tau_grid=seq(0, max(ex_dt$time), length.out=6),
#'               basis="bz", degree=1, link="log",
#'               w_recur=c(1,1), w_term=2, ipcw="km")
#' nd <- unique(ex_dt[, c("trt","Z1","Z2")])
#' pred <- predict(fit, newdata = nd, t_seq = seq(0, max(fit$tau_grid), by = 0.2))
#' head(pred)
#' }
#' @export
predict.WA <- function(object, newdata, t_seq, level = 0.95, ...) {
  stopifnot(!is.null(object$est), !is.null(object$vcov))
  vc       <- object$vcov
  link_fns <- .WA_link(object$link)

  # Rebuild the RHS design for `newdata` and drop the intercept column.
  rhs_terms <- stats::terms(object$formula, data = newdata)
  design <- stats::model.matrix(stats::delete.response(rhs_terms), data = newdata)
  if (ncol(design) > 0 && colnames(design)[1] == "(Intercept)")
    design <- design[, -1, drop = FALSE]
  Z_cols <- colnames(design)
  if (length(Z_cols) == 0L) stop("newdata lacks covariates specified in the model formula.")
  Z_df <- as.data.frame(design)

  # Two-sided normal quantile for the requested confidence level.
  zcrit <- stats::qnorm(0.5 + level/2)
  row_id <- seq_len(nrow(newdata))

  # Prediction and pointwise interval for all rows of `newdata` at one time.
  predict_at <- function(tt) {
    X    <- .WA_design_at_t(Z_df, tt, Z_cols, object$knots, object$basis, object$degree)
    cols <- colnames(X)

    eta <- as.vector(X %*% object$est[cols])
    mu  <- link_fns$linkinv(eta)

    # Row-wise quadratic form x' V x, truncated at 0 before the square root.
    XV      <- X %*% vc[cols, cols, drop = FALSE]
    se_eta  <- sqrt(pmax(rowSums(XV * X), 0))

    if (object$link == "log") {
      # Interval built on the linear-predictor scale, then exponentiated.
      lower <- mu * exp(-zcrit * se_eta)
      upper <- mu * exp( zcrit * se_eta)
    } else {
      # Identity link: symmetric interval with both bounds truncated at 0.
      lower <- pmax(mu - zcrit*se_eta, 0)
      upper <- pmax(mu + zcrit*se_eta, 0)
    }
    data.frame(id = row_id, t = tt, mu = mu, lb = lower, ub = upper)
  }

  per_time <- lapply(seq_along(t_seq), function(i) predict_at(t_seq[i]))
  do.call(rbind, per_time)
}


# R/plot.R -------------------------------------------------------------

#' Plot while-alive trajectory or a covariate's time-varying effect
#'
#' @param x A \code{"WA"} object.
#' @param newdata Data used to rebuild the RHS design (same columns as in the model).
#' @param t_seq Times to plot over (numeric vector).
#' @param id Row index of \code{newdata} to use for the while-alive trajectory (mode = "wa").
#' @param mode \code{"wa"} to plot the while-alive loss rate, or \code{"cov"} to
#'   plot a specific covariate's time-varying effect.
#' @param covariate Character; covariate name (must appear on RHS) when \code{mode="cov"}.
#' @param ylab_wa Y-axis label for while-alive plot.
#' @param ylab_cov Y-axis label for covariate-effect plot; default
#'   \code{"Effect of <covariate> on \u03B7(t)"}.
#' @param xlab X-axis label.
#' @param level Confidence level for ribbons (default 0.95).
#' @param smooth Logical; if \code{TRUE}, apply LOESS smoothing to the displayed curve/CI.
#' @param span LOESS span used when \code{smooth=TRUE}.
#' @param ... Unused.
#'
#' @return A \pkg{ggplot2} object.
#'
#' @examples
#' \donttest{
#' ex_dt <- crt_dt[crt_dt$cluster %in% c(1,2,3,4,7,10), ]
#' fit <- WA_fit(survival::Surv(time, status) ~ trt + Z1 + Z2,
#'               data = ex_dt, id="id", cluster="cluster",
#'               knots=seq(0, max(ex_dt$time), length.out=6),
#'               tau_grid=seq(0, max(ex_dt$time), length.out=6),
#'               basis="bz", degree=1, link="log",
#'               w_recur=c(1,1), w_term=2, ipcw="km")
#' nd <- unique(ex_dt[, c("trt","Z1","Z2")])
#' plot(fit, newdata = nd,
#'      t_seq = seq(0, max(fit$tau_grid), length.out = 200),
#'      id = 1, mode = "wa", smooth = TRUE)
#' }
#' @export
plot.WA <- function(x, newdata, t_seq, id = 1,
                    mode = c("wa","cov"),
                    covariate = NULL,
                    ylab_wa = "While-alive loss rate",
                    ylab_cov = NULL,
                    xlab = "Time",
                    level = 0.95,
                    smooth = FALSE,
                    span = 0.30,
                    ...) {
  mode <- match.arg(mode)

  # Rebuild the RHS design for `newdata` (intercept removed) to learn the
  # covariate column names the model uses.
  tf  <- stats::terms(x$formula, data = newdata)
  Xmm <- stats::model.matrix(stats::delete.response(tf), data = newdata)
  if (ncol(Xmm) > 0 && colnames(Xmm)[1] == "(Intercept)")
    Xmm <- Xmm[, -1, drop = FALSE]
  Z_cols <- colnames(Xmm)
  if (length(Z_cols) == 0L) stop("newdata lacks covariates specified in the model formula.")
  if (id < 1 || id > nrow(newdata)) stop("`id` out of range for `newdata`.")

  # Optionally replace the curve and both interval bounds by LOESS fits in t.
  maybe_smooth <- function(df, y, ymin, ymax) {
    if (!smooth) return(df)
    lo  <- stats::loess(stats::as.formula(paste0(y,    "~ t")), data = df, span = span)
    lol <- stats::loess(stats::as.formula(paste0(ymin, "~ t")), data = df, span = span)
    lou <- stats::loess(stats::as.formula(paste0(ymax, "~ t")), data = df, span = span)
    df[[y]]    <- stats::predict(lo)
    df[[ymin]] <- stats::predict(lol)
    df[[ymax]] <- stats::predict(lou)
    df
  }

  # Mode "wa": predicted while-alive loss rate for one row of `newdata`.
  if (mode == "wa") {
    d <- predict.WA(x, newdata = newdata, t_seq = t_seq, level = level)
    d <- d[d$id == id, , drop = FALSE]
    d <- maybe_smooth(d, y = "mu", ymin = "lb", ymax = "ub")
    return(
      ggplot2::ggplot(d, ggplot2::aes(x = .data$t, y = .data$mu)) +
        ggplot2::geom_line(linewidth = 0.6) +
        ggplot2::geom_ribbon(ggplot2::aes(ymin = .data$lb, ymax = .data$ub), alpha = 0.2) +
        ggplot2::labs(x = xlab, y = ylab_wa) +
        ggplot2::theme_minimal(base_size = 12) +
        ggplot2::theme(panel.grid.minor = ggplot2::element_blank())
    )
  }

  # Mode "cov": time-varying effect of a single covariate.
  if (is.null(covariate) || !(covariate %in% Z_cols))
    stop("Provide a valid `covariate` name present in the model RHS.")
  if (is.null(ylab_cov)) ylab_cov <- paste0("Effect of ", covariate, " on \u03B7(t)")

  # Select the coefficients named "<covariate>_t0_seg<k>". The covariate part
  # is matched literally rather than pasted into a regular expression:
  # model-matrix names such as "factor(trt)1" or "Z.1" contain regex
  # metacharacters and would otherwise fail to match (or match too much).
  coef_names <- names(x$est)
  if (is.null(coef_names)) coef_names <- character(0)
  prefix <- paste0(covariate, "_t0_seg")
  cand   <- coef_names[startsWith(coef_names, prefix)]
  suffix <- substring(cand, nchar(prefix) + 1L)
  is_seg <- grepl("^[0-9]+$", suffix)
  want   <- cand[is_seg]
  if (!length(want)) stop("No time-basis coefficients found for ", covariate)
  seg <- as.integer(suffix[is_seg])
  ord <- order(seg); want <- want[ord]; seg <- seg[ord]

  # Evaluate the time basis on t_seq and locate the columns that correspond
  # to the fitted segments; `ok` is FALSE if any segment has no column.
  build_basis_match <- function(inc_int) {
    B <- .WA_time_basis(
      t      = t_seq,
      knots  = x$knots,
      basis  = x$basis,
      degree = x$degree,
      include_intercept = inc_int
    )
    colnames(B) <- sprintf("t0_seg%d", seq_len(ncol(B)))
    need_names <- sprintf("t0_seg%d", seg)
    idx <- match(need_names, colnames(B))
    list(B = B, idx = idx, ok = !any(is.na(idx)))
  }

  # Try the basis without an intercept column first, then with one.
  try1 <- build_basis_match(FALSE)
  if (try1$ok) { Bfull <- try1$B; idx <- try1$idx } else {
    try2 <- build_basis_match(TRUE)
    if (!try2$ok) stop("Basis/coef mismatch. Check basis/degree/knots vs fit.")
    Bfull <- try2$B; idx <- try2$idx
  }

  Bz <- Bfull[, idx, drop = FALSE]
  b  <- as.numeric(x$est[want])
  V  <- as.matrix(x$vcov[want, want, drop = FALSE])

  # Effect curve B(t) b with pointwise variance diag(B V B'), truncated at 0.
  beta_hat <- as.vector(Bz %*% b)
  var_hat  <- rowSums((Bz %*% V) * Bz)
  var_hat  <- pmax(var_hat, 0)
  zcrit    <- stats::qnorm(0.5 + level/2)
  lb <- beta_hat - zcrit * sqrt(var_hat)
  ub <- beta_hat + zcrit * sqrt(var_hat)

  df <- data.frame(t = t_seq, eff = beta_hat, lb = lb, ub = ub)
  df <- maybe_smooth(df, y = "eff", ymin = "lb", ymax = "ub")

  ggplot2::ggplot(df, ggplot2::aes(x = .data$t, y = .data$eff)) +
    ggplot2::geom_line(linewidth = 0.6) +
    ggplot2::geom_ribbon(ggplot2::aes(ymin = .data$lb, ymax = .data$ub), alpha = 0.2) +
    ggplot2::labs(x = xlab, y = ylab_cov) +
    ggplot2::theme_minimal(base_size = 12) +
    ggplot2::theme(panel.grid.minor = ggplot2::element_blank(),
                   legend.position = "none")
}


# R/WA_cv.R ------------------------------------------------------------

#' K-fold cross-validation for WA configuration selection
#'
#' Runs K-fold CV over a grid of basis types, degrees, interior-knot counts,
#' and link functions. For each configuration, fits the model on K-1 folds and
#' accumulates the prediction error (PE) on the held-out fold using
#' \code{WA_PE()} (IPCW computed on the training subjects).
#'
#' @param formula A \code{Surv(time, status) ~ RHS} formula; see \code{\link{WA_fit}}.
#' @param data Long-format data frame; see \code{\link{WA_fit}}.
#' @param id Character scalar; subject ID column name; see \code{\link{WA_fit}}.
#' @param cluster Optional character scalar; cluster column name; see \code{\link{WA_fit}}.
#' @param basis_set Character vector of candidate bases.
#' @param degree_vec Integer vector of candidate degrees.
#' @param n_int_vec Integer vector of interior-knot counts; 0 means boundaries only.
#' @param knot_scheme \code{"equidist"} or \code{"quantile"} to construct interior knots.
#' @param link_set Character vector of candidate links (subset of \code{c("log","identity")}).
#' @param time_range Optional numeric length-2 vector \code{c(tmin, tmax)}. If \code{NULL},
#'   inferred from \code{data}.
#' @param tau_grid Optional numeric vector; if \code{NULL}, a default dense grid over
#'   \code{time_range} is created.
#' @param w_recur Numeric vector of recurrent-event weights; see \code{\link{WA_fit}}.
#' @param w_term Numeric scalar; terminal-event weight; see \code{\link{WA_fit}}.
#' @param ipcw IPCW method (\code{"cox"} or \code{"km"}) for PE computation.
#' @param ipcw_formula One-sided RHS formula for IPCW Cox model (if \code{ipcw="cox"}).
#' @param K Number of folds.
#' @param seed RNG seed for fold assignment.
#' @param verbose Logical; show a text progress bar and per-fold messages.
#'
#' @return A data frame with columns: \code{basis}, \code{degree}, \code{n_int},
#'   \code{link}, and aggregated \code{PE}. Lower \code{PE} is better.
#'
#' @export
WA_cv <- function(formula,
                  data,
                  id,
                  cluster = NULL,
                  basis_set   = c("il","pl","bz"),
                  degree_vec  = 1:2,
                  n_int_vec   = c(0, 2, 4),
                  knot_scheme = c("equidist","quantile"),
                  link_set    = c("log"),
                  time_range  = NULL,
                  tau_grid    = NULL,
                  w_recur,
                  w_term,
                  ipcw        = c("cox","km"),
                  ipcw_formula = ~ 1,
                  K = 5, seed = 1L,
                  verbose = TRUE) {

  knot_scheme <- match.arg(knot_scheme)
  ipcw <- match.arg(ipcw)

  # Observed times determine the default time range and quantile knots.
  par_cv <- .WA_parse_formula(formula, data)
  t_obs  <- par_cv$time_vec
  if (!id %in% names(data)) stop("id column not found in data.")

  if (is.null(time_range)) {
    tmin <- 0
    tmax <- max(t_obs, na.rm = TRUE)
  } else {
    tmin <- time_range[1]; tmax <- time_range[2]
  }
  if (is.null(tau_grid)) {
    # Default dense grid; it stops just short of tmax.
    tau_grid <- seq(tmin, tmax - 1e-6, length.out = 150)
  }

  # Assign folds at the cluster level when a cluster column is available,
  # otherwise at the subject level, so a unit is never split across folds.
  set.seed(seed)
  fold_col <- if (!is.null(cluster) && cluster %in% names(data)) cluster else id
  keys <- unique(data[[fold_col]])
  fold_ids <- sample(rep_len(seq_len(K), length(keys)))
  names(fold_ids) <- as.character(keys)
  data$.fold <- fold_ids[as.character(data[[fold_col]])]

  # Boundary knots plus `n_int` interior knots, equidistant or at quantiles
  # of the observed times.
  make_knots <- function(n_int) {
    if (n_int <= 0) return(c(tmin, tmax))
    if (knot_scheme == "equidist") {
      inter <- seq(tmin, tmax, length.out = n_int + 2L)[-c(1, n_int + 2L)]
    } else {
      qs <- seq(0, 1, length.out = n_int + 2L)[-c(1, n_int + 2L)]
      inter <- as.numeric(stats::quantile(t_obs, probs = qs, na.rm = TRUE))
    }
    sort(unique(c(tmin, inter, tmax)))
  }

  # Train/test split and training-set IPCW fit for one fold. None of this
  # depends on the basis/degree/knots/link configuration, so it is prepared
  # once per fold instead of once per configuration.
  prep_fold <- function(k) {
    train <- data[data$.fold != k, , drop = FALSE]
    test  <- data[data$.fold == k, , drop = FALSE]

    par_tr    <- .WA_parse_formula(formula, train)
    status_tr <- as.integer(par_tr$status_vec)
    # Terminal code = largest status code in the whole training set, as in
    # WA_fit. It must not be taken per subject inside the grouped summarise:
    # a subject with only a censoring row would then have max status 0 and
    # be flagged as having had the terminal event.
    term_code <- max(status_tr, na.rm = TRUE)

    tr_df <- train %>%
      dplyr::mutate(
        .id     = .data[[id]],
        .time   = par_tr$time_vec,
        .status = status_tr
      ) %>%
      dplyr::group_by(.data$.id) %>%
      dplyr::summarise(
        dplyr::across(dplyr::all_of(par_tr$Z_cols), ~ dplyr::first(.x)),
        obs_T = max(.data$.time, na.rm = TRUE),
        Delta = {
          i_last <- which.max(.data$.time)
          as.integer(.data$.status[i_last] == term_code)
        },
        .groups = "drop"
      )

    list(
      train    = train,
      test     = test,
      ipcw_fit = .WA_ipcw_fit(tr_df, method = ipcw, ipcw_formula = ipcw_formula)
    )
  }
  folds <- lapply(seq_len(K), prep_fold)

  # Progress bookkeeping (only active when verbose).
  n_cfg  <- length(basis_set) * length(degree_vec) * length(n_int_vec) * length(link_set)
  total_steps <- n_cfg * K
  step <- 0L
  if (verbose) {
    pb <- utils::txtProgressBar(min = 0, max = total_steps, style = 3)
    on.exit(try(close(pb), silent = TRUE), add = TRUE)
  }
  bump <- function(extra_msg = NULL) {
    if (verbose) {
      step <<- step + 1L
      utils::setTxtProgressBar(pb, step)
      if (!is.null(extra_msg)) {
        cat(sprintf("\n%s\n", extra_msg))
        utils::flush.console()
      }
    }
  }

  results <- vector("list", n_cfg); idx <- 1L

  for (bs in basis_set) {
    for (deg in degree_vec) {
      for (nk in n_int_vec) {
        knots <- make_knots(nk)
        for (lnk in link_set) {

          if (verbose) {
            cat(sprintf(
              "\n[Config] basis=%s, degree=%s, n_int=%s, link=%s\n",
              bs, deg, nk, lnk
            ))
            utils::flush.console()
          }

          # Prediction error accumulated over the K held-out folds.
          PE_sum <- 0

          for (k in seq_len(K)) {
            if (verbose) {
              cat(sprintf("  - Fold %d/%d ... ", k, K))
              utils::flush.console()
            }

            fold_k <- folds[[k]]

            fit_k <- WA_fit(
              formula    = formula,
              data       = fold_k$train,
              id         = id,
              cluster    = cluster,
              knots      = knots,
              tau_grid   = tau_grid,
              basis      = bs,
              degree     = deg,
              link       = lnk,
              w_recur    = w_recur,
              w_term     = w_term,
              ipcw       = ipcw,
              ipcw_formula = ipcw_formula
            )

            PE_k <- WA_PE(
              fit = fit_k, formula = formula, data_test = fold_k$test, id = id,
              w_recur = w_recur, w_term = w_term,
              ipcw_fit = fold_k$ipcw_fit, tau_grid = tau_grid
            )
            PE_sum <- PE_sum + PE_k

            if (verbose) {
              cat("done.\n")
              utils::flush.console()
            }
            bump()
          }

          results[[idx]] <- data.frame(
            basis = bs, degree = deg, n_int = nk, link = lnk,
            PE = PE_sum, stringsAsFactors = FALSE
          )
          idx <- idx + 1L
        }
      }
    }
  }

  if (verbose) cat("\nAll folds complete.\n")

  # One row per configuration, best (lowest PE) first.
  out <- do.call(rbind, results)
  rownames(out) <- NULL
  out[order(out$PE), ]
}

#' @keywords internal
#' @noRd
# Integrated, IPCW-weighted squared prediction error of a fitted WA model on
# held-out data.
#
# For every tau in `tau_grid` the per-subject quantity
#   V(tau) / G(min(T, tau)) * (L(tau) - mu(tau) * min(T, tau))^2
# is computed and integrated over tau with the trapezoidal rule (the first
# interval runs from 0 with a starting value of 0). The return value is the
# sum of the per-subject integrals.
#
# fit       : "WA" object from WA_fit (supplies est, knots, basis, degree,
#             link and the status coding).
# formula   : model formula, parsed against `data_test`.
# data_test : long-format held-out data.
# id        : name of the subject id column in `data_test`.
# w_recur, w_term : event weights, in the same order as used for the fit.
# ipcw_fit  : censoring model passed to .WA_ipcw_predict_G.
# tau_grid  : evaluation times; defaults to the grid stored on the fit.
WA_PE <- function(fit,
                  formula,
                  data_test,
                  id,
                  w_recur, w_term,
                  ipcw_fit,
                  tau_grid = fit$tau_grid) {

  par_pe      <- .WA_parse_formula(formula, data_test)
  Xmm         <- par_pe$Xmm
  Z_cols      <- par_pe$Z_cols
  time_name   <- par_pe$time_name
  status_name <- par_pe$status_name
  if (length(Z_cols) == 0L) stop("No covariates on RHS for TEST data.")

  test_df <- dplyr::bind_cols(
    tibble::tibble(
      .id     = data_test[[id]],
      .time   = data_test[[time_name]],
      .status = as.integer(data_test[[status_name]])
    ),
    as.data.frame(Xmm)
  )

  # Use the status coding stored on the fit so that `w_recur` lines up with
  # the same recurrent types as during fitting and the terminal code stays
  # correct even if the held-out fold lacks some event type. Fall back to
  # the codes seen in the test data when the fit does not carry them.
  codes <- fit$status_codes
  s_max <- if (!is.null(codes$terminal)) {
    codes$terminal
  } else {
    max(test_df$.status, na.rm = TRUE)
  }
  rec_types <- if (!is.null(codes$recurrent)) {
    codes$recurrent
  } else {
    setdiff(sort(unique(test_df$.status)), c(0L, s_max))
  }

  # One row per subject: covariates, last observed time, terminal indicator.
  subj <- test_df %>%
    dplyr::group_by(.data$.id) %>%
    dplyr::summarise(
      dplyr::across(dplyr::all_of(Z_cols), ~ dplyr::first(.x)),
      obs_T = max(.data$.time, na.rm = TRUE),
      Delta = { i_last <- which.max(.data$.time); as.integer(.data$.status[i_last] == s_max) },
      .groups = "drop"
    )

  uid <- subj$.id
  int_each <- numeric(length(uid)); names(int_each) <- as.character(uid)
  eps <- .Machine$double.eps
  scalar_prev <- NULL; tau_prev <- NULL

  # Event records only (recurrent types and the terminal code).
  recs <- test_df %>%
    dplyr::filter(.data$.status %in% c(rec_types, s_max)) %>%
    dplyr::select(dplyr::all_of(c(".id", ".status", ".time")))

  # Loop-invariant pieces of the prediction step.
  Xnew <- as.data.frame(subj[, Z_cols, drop = FALSE])
  lf   <- .WA_link(fit$link)

  for (tau in tau_grid) {
    X_min_tau <- pmin(subj$obs_T, tau)
    V_i_tau   <- as.integer((subj$obs_T <= tau) & (subj$Delta == 1L)) + as.integer(subj$obs_T > tau)

    grid <- tibble::tibble(.id = uid, tau = tau, X_min_tau = X_min_tau)

    # Events with time <= min(obs_T, tau), counted per subject and status.
    counts <- recs %>%
      dplyr::right_join(grid, by = ".id", relationship = "many-to-many") %>%
      dplyr::filter(.data$.time <= .data$X_min_tau) %>%
      dplyr::group_by(.data$.id, .data$tau, .data$X_min_tau, .data$.status) %>%
      dplyr::summarise(n = dplyr::n(), .groups = "drop")

    if (nrow(counts)) {
      counts_wide <- tidyr::pivot_wider(
        counts,
        id_cols = c(".id", "tau", "X_min_tau"),
        names_from = ".status",
        values_from = "n",
        values_fill = 0L,
        names_prefix = "N_"
      )
    } else {
      counts_wide <- grid
    }

    tmp <- subj %>%
      dplyr::mutate(tau = tau, X_min_tau = X_min_tau, V_i_tau = V_i_tau) %>%
      dplyr::left_join(counts_wide, by = c(".id","tau","X_min_tau")) %>%
      dplyr::mutate(N_term_tau = as.integer((.data$obs_T <= .data$tau) & (.data$Delta == 1L)))

    # Every recurrent type needs a count column. Subjects without any event
    # by tau come out of the left join with NA counts; those are zero events.
    # Leaving the NA in place would turn the subject's running integral into
    # NA for good, and the final na.rm sum would then silently drop the
    # subject from the prediction error.
    for (s in rec_types) {
      nm <- paste0("N_", s)
      if (!nm %in% names(tmp)) {
        tmp[[nm]] <- 0L
      } else {
        tmp[[nm]][is.na(tmp[[nm]])] <- 0L
      }
    }

    # Weighted composite loss at tau.
    L_recur <- 0
    for (j in seq_along(rec_types)) {
      L_recur <- L_recur + w_recur[j] * tmp[[paste0("N_", rec_types[j])]]
    }
    tmp$L <- L_recur + w_term * tmp$N_term_tau

    tmp$G_X_min_tau <- .WA_ipcw_predict_G(ipcw_fit, tmp$X_min_tau, newdata = tmp)

    # Model-based while-alive rate at tau for every test subject.
    X <- .WA_design_at_t(Xnew, tau, Z_cols, fit$knots, fit$basis, fit$degree)
    beta_ord <- fit$est[colnames(X)]
    eta <- as.vector(X %*% beta_ord)
    mu  <- lf$linkinv(eta)

    # IPCW-weighted squared residual; G is floored at machine epsilon to
    # avoid division by zero.
    resid <- tmp$L - mu * tmp$X_min_tau
    fac   <- tmp$V_i_tau / pmax(tmp$G_X_min_tau, eps)
    scalar_now <- fac * (resid^2)

    # Trapezoidal accumulation over tau.
    if (is.null(scalar_prev)) {
      dt <- tau
      int_each <- int_each + (scalar_now) * dt / 2
    } else {
      dt <- tau - tau_prev
      int_each <- int_each + (scalar_now + scalar_prev) * dt / 2
    }
    scalar_prev <- scalar_now; tau_prev <- tau
  }

  sum(int_each, na.rm = TRUE)
}

Try the WAreg package in your browser

Any scripts or data that you put into this service are public.

WAreg documentation built on March 6, 2026, 5:07 p.m.