# Nothing
# R/main.R ------------------------------------------------------------
#' While-Alive Regression (WA) for Composite Endpoints
#'
#' Fits the while-alive regression model targeting the while-alive loss rate
#' for composite endpoints with recurrent and terminal events. Time-varying
#' covariate effects are represented via user-chosen time bases (e.g., B-spline,
#' piecewise polynomial, interval-local). Robust inference supports
#' cluster-randomized trials (CRTs) via cluster-robust variance; if
#' \code{cluster = NULL}, IID (subject-as-cluster) variance is used.
#'
#' @param formula A \code{Surv(time, status) ~ RHS} formula. \code{time} and
#' \code{status} must exist in \code{data}. The RHS contains baseline
#' covariates (no explicit time-varying covariates here; time-variation is
#' induced via the chosen basis).
#' @param data Long-format data frame with one row per \emph{event/checkpoint}
#' per subject, containing \code{time}, \code{status}, \code{id}, optional
#' \code{cluster}, and RHS covariates.
#' @param id Character scalar; subject ID column name.
#' @param cluster Optional character scalar; cluster column name for CRT-robust
#' inference. If \code{NULL}, IID inference treats each subject as its own cluster.
#' @param knots Numeric vector (length \eqn{\ge 2}) specifying the basis
#' boundaries and optional interior knots that define the time basis shape.
#' @param tau_grid Numeric vector of evaluation times used to stack the
#' estimating equations. Independent of \code{knots}.
#' @param basis One of \code{"il","pl","bz","ns","ms","st","tl","tf"}:
#' interval-local (\code{"il"}), piecewise polynomial (\code{"pl"}),
#' B-spline (\code{"bz"}), natural spline (\code{"ns"}), M-spline
#' (\code{"ms"}, requires \pkg{splines2}), step (\code{"st"}), truncated
#' linear (\code{"tl"}), or time-fixed (\code{"tf"}).
#' @param degree Integer degree for bases that use it (e.g., \code{"bz"}, \code{"pl"}, \code{"ns"}, \code{"ms"}).
#' @param link Link function: \code{"log"} (default) or \code{"identity"}.
#' @param w_recur Numeric vector of weights for each recurrent event type. Its
#' length must match the number of recurrent \code{status} codes in
#' \code{data} (i.e., excluding \code{0} for censoring and the max code for terminal).
#' @param w_term Numeric scalar; weight for the terminal event.
#' @param ipcw IPCW method: \code{"km"} or \code{"cox"}.
#' @param ipcw_formula A one-sided formula specifying RHS covariates for the IPCW Cox model
#' when \code{ipcw = "cox"} (e.g., \code{~ x1 + x2}). Ignored for \code{ipcw = "km"}.
#'
#' @details
#' The estimating equations solve \eqn{E[Z(t)\{L(t) - \mu_\beta(t)X_{\min}(t)\}V/G]=0}
#' over \code{tau_grid}, where \eqn{L(t)} is the weighted composite loss
#' (recurrent+terminal), \eqn{\mu_\beta(t)} the while-alive loss rate under the chosen
#' link, \eqn{X_{\min}(t) = \min(T, t)}, \eqn{V} the at-risk/terminal indicator, and
#' \eqn{G} the censoring survival modeled via \code{ipcw}.
#'
#' @return An object of class \code{"WA"} with elements:
#' \itemize{
#' \item \code{call}: the matched call.
#' \item \code{est}: named coefficient vector.
#' \item \code{vcov}: cluster-robust variance matrix.
#' \item \code{se}: standard errors.
#' \item \code{converged}: logical.
#' \item \code{basis}, \code{degree}, \code{link}, \code{Z_cols},
#' \code{knots}, \code{tau_grid}, \code{id_var}, \code{cluster_var},
#' \code{w_recur}, \code{w_term}, \code{status_codes}, \code{formula}.
#' }
#'
#' @examples
#' \donttest{
#' ex_dt <- crt_dt[crt_dt$cluster %in% c(1,2,3,4,7,10), ]
#' fit <- WA_fit(
#' survival::Surv(time, status) ~ trt + Z1 + Z2,
#' data = ex_dt,
#' id = "id",
#' cluster = "cluster",
#' knots = seq(0, max(ex_dt$time, na.rm = TRUE), length.out = 6),
#' tau_grid = seq(0, max(ex_dt$time, na.rm = TRUE), length.out = 6),
#' basis = "bz", degree = 1, link = "log",
#' w_recur = c(1, 1), w_term = 2,
#' ipcw = "km"
#' )
#' s <- summary(fit)
#' nd <- unique(ex_dt[, c("trt","Z1","Z2")])
#' plot(fit, newdata = nd,
#' t_seq = seq(0, max(fit$tau_grid), length.out = 200),
#' id = 1, mode = "wa", smooth = TRUE)
#' }
#'
#' @export
WA_fit <- function(formula,
                   data,
                   id,
                   cluster = NULL,
                   knots,
                   tau_grid,
                   basis = c("il","pl","bz","ns","ms","st","tl","tf"),
                   degree = 1,
                   link = c("log","identity"),
                   w_recur,
                   w_term,
                   ipcw = c("km","cox"),
                   ipcw_formula = ~ 1) {
  # Collapse the enumerated arguments to a single validated choice.
  basis <- match.arg(basis)
  link <- match.arg(link)
  ipcw <- match.arg(ipcw)

  # ---- 1. Parse the formula and check the bookkeeping columns ----------
  parsed <- .WA_parse_formula(formula, data)
  time_vec <- parsed$time_vec
  status_vec <- parsed$status_vec
  Xmm <- parsed$Xmm
  Z_cols <- parsed$Z_cols
  if (length(Z_cols) == 0L) stop("No covariates specified on RHS.")
  if (!id %in% names(data)) stop("id column not found in data.")
  if (!is.null(cluster) && !cluster %in% names(data)) stop("cluster column not found in data.")

  # ---- 2. Long working frame: .time, .status, .id, .cluster, covariates -
  Z_df <- as.data.frame(Xmm)
  names(Z_df) <- Z_cols
  # `hold` only fixes the number of rows; it is removed again by `[, -1]`.
  long <- data.frame(hold = seq_along(time_vec))
  long$.time <- time_vec
  long$.status <- status_vec
  long$.id <- data[[id]]
  # Without a cluster column every subject acts as its own cluster.
  long$.cluster <- if (is.null(cluster)) data[[id]] else data[[cluster]]
  long <- cbind(long, Z_df)[, -1]
  # Drop every column that contains at least one NA.
  long <- dplyr::select(long, dplyr::where(~ !any(is.na(.))))

  # ---- 3. Status codes: 0 = censoring, largest code = terminal ----------
  status_seen <- sort(unique(long$.status))
  if (!any(status_seen == 0L)) stop("Status must include 0 for censoring.")
  s_max <- max(status_seen, na.rm = TRUE)
  rec_types <- setdiff(status_seen, c(0L, s_max))
  if (length(w_recur) != length(rec_types)) {
    stop("Length(w_recur) must equal # recurrent event types (status in {",
         paste(rec_types, collapse=","), "}).")
  }

  # ---- 4. One row per subject: covariates, follow-up end, terminal flag -
  subj <- long %>%
    dplyr::group_by(.data$.id, .data$.cluster) %>%
    dplyr::summarise(
      dplyr::across(dplyr::all_of(Z_cols), ~ dplyr::first(.x)),
      obs_T = max(.data$.time, na.rm = TRUE),
      # Delta = 1 when the row with the latest time carries the terminal code.
      Delta = as.integer(.data$.status[which.max(.data$.time)] == s_max),
      .groups = "drop"
    )

  # ---- 5. Subject x tau grid with min(T, tau) and the indicators --------
  expanded <- subj %>%
    tidyr::crossing(tau = tau_grid) %>%
    dplyr::mutate(
      X_min_tau = pmin(.data$obs_T, .data$tau),
      V_i_tau = as.integer((.data$obs_T <= .data$tau) & (.data$Delta == 1L)) +
        as.integer(.data$obs_T > .data$tau),
      N_term_tau = as.integer((.data$obs_T <= .data$tau) & (.data$Delta == 1L))
    )

  # ---- 6. Censoring survival G evaluated at min(T, tau) -----------------
  if (ipcw == "km") {
    # Kaplan-Meier for the censoring time, read off as a step function.
    cens_km <- survival::survfit(
      survival::Surv(time = subj$obs_T, event = 1 - subj$Delta) ~ 1
    )
    expanded$G_X_min_tau <- stats::approx(
      x = cens_km$time, y = cens_km$surv,
      xout = expanded$X_min_tau,
      method = "constant", f = 0,
      yleft = 1, yright = utils::tail(cens_km$surv, 1)
    )$y
  } else {
    # Cox model for the censoring hazard on the user-supplied RHS.
    rhs <- paste(deparse(ipcw_formula[[2]]), collapse = "")
    cens_cox <- survival::coxph(
      stats::as.formula(paste0("Surv(obs_T, I(1-Delta)) ~ ", rhs)),
      data = subj, ties = "breslow", x = TRUE, y = TRUE, model = FALSE
    )
    base_haz <- survival::basehaz(cens_cox, centered = FALSE)
    Lambda0 <- stats::approxfun(base_haz$time, base_haz$hazard,
                                method = "linear", rule = 2)
    cens_mm <- stats::model.matrix(stats::as.formula(paste0("~ -1 + ", rhs)),
                                   data = subj)
    gamma_hat <- stats::coef(cens_cox)
    if (length(gamma_hat)) cens_mm <- cens_mm[, names(gamma_hat), drop = FALSE]
    # With no censoring covariates the linear predictor is identically 0.
    subj$lp_cens <- if (length(gamma_hat)) drop(cens_mm %*% gamma_hat) else 0
    expanded <- dplyr::left_join(expanded, subj[, c(".id", "lp_cens")],
                                 by = c(".id" = ".id"))
    expanded$G_X_min_tau <- exp(-Lambda0(expanded$X_min_tau) * exp(expanded$lp_cens))
    expanded$G_X_min_tau[!is.finite(expanded$G_X_min_tau)] <- 0
  }

  # ---- 7. Event counts per subject, tau and status code -----------------
  event_rows <- long %>%
    dplyr::filter(.data$.status %in% c(rec_types, s_max)) %>%
    dplyr::select(dplyr::all_of(c(".id", ".status", ".time")))
  id_tau <- dplyr::select(expanded, dplyr::all_of(c(".id", "tau", "X_min_tau")))
  event_counts <- event_rows %>%
    dplyr::right_join(id_tau, by = ".id", relationship = "many-to-many") %>%
    dplyr::filter(.data$.time <= .data$X_min_tau) %>%
    dplyr::group_by(.data$.id, .data$tau, .data$X_min_tau, .data$.status) %>%
    dplyr::summarise(n = dplyr::n(), .groups = "drop")
  if (nrow(event_counts)) {
    counts_wide <- tidyr::pivot_wider(
      event_counts,
      id_cols = c(".id", "tau", "X_min_tau"),
      names_from = ".status",
      values_from = "n",
      values_fill = 0L,
      names_prefix = "N_"
    )
    expanded <- dplyr::left_join(expanded, counts_wide,
                                 by = c(".id","tau","X_min_tau"))
  }
  # Recurrent types that were never counted still need an N_<code> column.
  for (code in rec_types) {
    count_col <- paste0("N_", code)
    if (!count_col %in% names(expanded)) expanded[[count_col]] <- 0L
  }

  # ---- 8. Weighted composite loss L(tau) --------------------------------
  loss_recur <- 0
  for (j in seq_along(rec_types)) {
    loss_recur <- loss_recur + w_recur[j] * expanded[[paste0("N_", rec_types[j])]]
  }
  expanded$L <- loss_recur + w_term * expanded$N_term_tau
  # Subject/tau cells without any counted event come out NA after the join.
  expanded$L <- ifelse(is.na(expanded$L), 0, expanded$L)

  # ---- 9. Time-basis design, estimating-equation solve, variance --------
  design <- .WA_design_Z(expanded, Z_cols = Z_cols, knots = knots,
                         basis = basis, degree = degree, include_intercept = FALSE)
  expanded <- design$data
  cov_pat <- design$cov_pattern
  beta_names <- grep(cov_pat, names(expanded), value = TRUE)
  sol <- nleqslv::nleqslv(
    x = rep(0, length(beta_names)),
    fn = .WA_ee,
    data = expanded,
    cov_pattern = cov_pat,
    L_col="L", Dmin_col="X_min_tau", V_col="V_i_tau", G_col="G_X_min_tau",
    link = link,
    method = "Broyden"
  )
  if (sol$termcd > 3) warning("WA_fit: nleqslv may not have converged (termcd=", sol$termcd, ").")
  beta_hat <- sol$x
  V <- .WA_var(
    data = expanded, beta = beta_hat, cov_pattern = cov_pat,
    L_col="L", Dmin_col="X_min_tau", V_col="V_i_tau", G_col="G_X_min_tau",
    id_col = ".id", cluster_col = if (is.null(cluster)) NULL else ".cluster",
    link = link
  )
  names(beta_hat) <- beta_names
  dimnames(V) <- list(beta_names, beta_names)

  # ---- 10. Assemble the fitted object -----------------------------------
  structure(list(
    call = match.call(),
    est = beta_hat,
    vcov = V,
    se = sqrt(pmax(diag(V), 0)),
    converged = (sol$termcd <= 3),
    basis = basis,
    degree = degree,
    link = link,
    Z_cols = Z_cols,
    knots = knots,
    tau_grid = tau_grid,
    id_var = id,
    cluster_var = cluster,
    w_recur = w_recur,
    w_term = w_term,
    status_codes = list(recurrent = rec_types, terminal = s_max, censor = 0),
    formula = formula
  ), class = "WA")
}
# R/summary-print.R ----------------------------------------------------
#' @export
print.WA <- function(x, ...) {
  # Compact console header for a fitted WA object; returns `x` invisibly.
  n_coef <- length(x$est)
  did_converge <- isTRUE(x$converged)
  cat("While-Alive Regression (WA)\n")
  cat(" basis:", x$basis, " degree:", x$degree, " link:", x$link, "\n")
  cat(" #coef:", n_coef, " converged:", did_converge, "\n")
  invisible(x)
}
#' Summarize a WA object
#'
#' @param object A \code{"WA"} object from \code{\link{WA_fit}}.
#' @param ... Unused.
#'
#' @return An object of class \code{"summary.WA"} containing configuration and a
#' coefficient table with estimates, standard errors, and z-scores.
#'
#' @examples
#' \donttest{
#' ex_dt <- crt_dt[crt_dt$cluster %in% c(1,2,3,4,7,10), ]
#' fit <- WA_fit(survival::Surv(time, status) ~ trt + Z1 + Z2,
#' data = ex_dt, id="id", cluster="cluster",
#' knots=seq(0, max(ex_dt$time), length.out=6),
#' tau_grid=seq(0, max(ex_dt$time), length.out=6),
#' basis="bz", degree=1, link="log",
#' w_recur=c(1,1), w_term=2, ipcw="km")
#' summary(fit)
#' }
#' @export
summary.WA <- function(object, ...) {
  est <- object$est
  se <- object$se
  # z-scores are only defined where the standard error is strictly positive.
  z_score <- ifelse(se > 0, est / se, NA_real_)
  coef_table <- data.frame(
    coef = est,
    se = se,
    z = z_score,
    row.names = names(est)
  )
  # Carry the fit configuration along so the print method can report it.
  structure(
    list(
      basis = object$basis,
      degree = object$degree,
      link = object$link,
      knots = object$knots,
      tau_grid = object$tau_grid,
      status_codes = object$status_codes,
      w_recur = object$w_recur,
      w_term = object$w_term,
      coef = coef_table
    ),
    class = "summary.WA"
  )
}
#' @export
print.summary.WA <- function(x, ...) {
  # Pre-format the pieces that are collapsed into a single string.
  knots_txt <- paste(format(x$knots), collapse = ", ")
  recur_codes_txt <- paste(x$status_codes$recurrent, collapse=",")
  recur_wts_txt <- paste(x$w_recur, collapse=",")

  cat("While-Alive Regression summary\n")
  cat(" basis:", x$basis, " degree:", x$degree, " link:", x$link, "\n")
  cat(" knots:", knots_txt, "\n")
  cat(" \u03C4-grid length:", length(x$tau_grid), "\n")
  cat(" status codes: recurrent={", recur_codes_txt, "}, terminal=",
      x$status_codes$terminal, ", censor=0\n", sep = "")
  cat(" weights: recur=", recur_wts_txt, " terminal=", x$w_term, "\n\n", sep="")
  # Coefficient table (estimate, standard error, z-score).
  print(x$coef)
  invisible(x)
}
# R/predict.R ----------------------------------------------------------
#' Predict while-alive loss rates
#'
#' @param object A \code{"WA"} object.
#' @param newdata Data frame with columns matching the RHS of the fitted model.
#' Predictions are computed for the rows of \code{newdata}.
#' @param t_seq Numeric vector of times at which to evaluate predictions.
#' @param level Confidence level for pointwise intervals (default 0.95).
#' @param ... Unused.
#'
#' @return A data frame with columns \code{id} (row index in \code{newdata}),
#' \code{t}, \code{mu} (predicted while-alive rate), and CI columns \code{lb}, \code{ub}.
#'
#' @examples
#' \donttest{
#' ex_dt <- crt_dt[crt_dt$cluster %in% c(1,2,3,4,7,10), ]
#' fit <- WA_fit(survival::Surv(time, status) ~ trt + Z1 + Z2,
#' data = ex_dt, id="id", cluster="cluster",
#' knots=seq(0, max(ex_dt$time), length.out=6),
#' tau_grid=seq(0, max(ex_dt$time), length.out=6),
#' basis="bz", degree=1, link="log",
#' w_recur=c(1,1), w_term=2, ipcw="km")
#' nd <- unique(ex_dt[, c("trt","Z1","Z2")])
#' pred <- predict(fit, newdata = nd, t_seq = seq(0, max(fit$tau_grid), by = 0.2))
#' head(pred)
#' }
#' @export
predict.WA <- function(object, newdata, t_seq, level = 0.95, ...) {
  stopifnot(!is.null(object$est), !is.null(object$vcov))
  V <- object$vcov
  lf <- .WA_link(object$link)

  # Rebuild the covariate design from the model formula; the intercept
  # column is removed because the fit is parameterised without it.
  tf <- stats::terms(object$formula, data = newdata)
  Xmm <- stats::model.matrix(stats::delete.response(tf), data = newdata)
  if (ncol(Xmm) > 0 && colnames(Xmm)[1] == "(Intercept)")
    Xmm <- Xmm[, -1, drop = FALSE]
  Z_cols <- colnames(Xmm)
  if (length(Z_cols) == 0L) stop("newdata lacks covariates specified in the model formula.")
  Z_df <- as.data.frame(Xmm)

  # With no evaluation times, return an empty frame of the documented shape
  # instead of the NULL that do.call(rbind, list()) would produce.
  if (length(t_seq) == 0L) {
    return(data.frame(id = integer(0), t = numeric(0), mu = numeric(0),
                      lb = numeric(0), ub = numeric(0)))
  }

  zcrit <- stats::qnorm(0.5 + level/2)
  rows <- vector("list", length(t_seq))
  for (i in seq_along(t_seq)) {
    # Design matrix at time t_seq[i]; columns are matched to coefficients by name.
    X <- .WA_design_at_t(Z_df, t_seq[i], Z_cols, object$knots, object$basis, object$degree)
    beta_ord <- object$est[colnames(X)]
    eta <- as.vector(X %*% beta_ord)
    mu <- lf$linkinv(eta)

    # Pointwise variance of the linear predictor: diag(X V X').
    VX <- X %*% V[colnames(X), colnames(X), drop = FALSE]
    var_eta <- rowSums(VX * X)
    se_eta <- sqrt(pmax(var_eta, 0))

    if (object$link == "log") {
      # Interval built on the log scale and mapped back by exponentiation.
      lb <- mu * exp(-zcrit * se_eta)
      ub <- mu * exp( zcrit * se_eta)
    } else {
      # Identity link: Wald interval on the rate scale, floored at 0.
      lb <- pmax(mu - zcrit*se_eta, 0)
      ub <- pmax(mu + zcrit*se_eta, 0)
    }
    rows[[i]] <- data.frame(id = seq_len(nrow(newdata)), t = t_seq[i],
                            mu = mu, lb = lb, ub = ub)
  }
  do.call(rbind, rows)
}
# R/plot.R -------------------------------------------------------------
#' Plot while-alive trajectory or a covariate's time-varying effect
#'
#' @param x A \code{"WA"} object.
#' @param newdata Data used to rebuild the RHS design (same columns as in the model).
#' @param t_seq Times to plot over (numeric vector).
#' @param id Row index of \code{newdata} to use for the while-alive trajectory (mode = "wa").
#' @param mode \code{"wa"} to plot the while-alive loss rate, or \code{"cov"} to
#' plot a specific covariate's time-varying effect.
#' @param covariate Character; covariate name (must appear on RHS) when \code{mode="cov"}.
#' @param ylab_wa Y-axis label for while-alive plot.
#' @param ylab_cov Y-axis label for covariate-effect plot; default
#' \code{"Effect of <covariate> on \u03B7(t)"}.
#' @param xlab X-axis label.
#' @param level Confidence level for ribbons (default 0.95).
#' @param smooth Logical; if \code{TRUE}, apply LOESS smoothing to the displayed curve/CI.
#' @param span LOESS span used when \code{smooth=TRUE}.
#' @param ... Unused.
#'
#' @return A \pkg{ggplot2} object.
#'
#' @examples
#' \donttest{
#' ex_dt <- crt_dt[crt_dt$cluster %in% c(1,2,3,4,7,10), ]
#' fit <- WA_fit(survival::Surv(time, status) ~ trt + Z1 + Z2,
#' data = ex_dt, id="id", cluster="cluster",
#' knots=seq(0, max(ex_dt$time), length.out=6),
#' tau_grid=seq(0, max(ex_dt$time), length.out=6),
#' basis="bz", degree=1, link="log",
#' w_recur=c(1,1), w_term=2, ipcw="km")
#' nd <- unique(ex_dt[, c("trt","Z1","Z2")])
#' plot(fit, newdata = nd,
#' t_seq = seq(0, max(fit$tau_grid), length.out = 200),
#' id = 1, mode = "wa", smooth = TRUE)
#' }
#' @export
plot.WA <- function(x, newdata, t_seq, id = 1,
                    mode = c("wa","cov"),
                    covariate = NULL,
                    ylab_wa = "While-alive loss rate",
                    ylab_cov = NULL,
                    xlab = "Time",
                    level = 0.95,
                    smooth = FALSE,
                    span = 0.30,
                    ...) {
  mode <- match.arg(mode)

  # Rebuild the RHS design only to learn the covariate column names.
  tf <- stats::terms(x$formula, data = newdata)
  Xmm <- stats::model.matrix(stats::delete.response(tf), data = newdata)
  if (ncol(Xmm) > 0 && colnames(Xmm)[1] == "(Intercept)")
    Xmm <- Xmm[, -1, drop = FALSE]
  Z_cols <- colnames(Xmm)
  if (length(Z_cols) == 0L) stop("newdata lacks covariates specified in the model formula.")
  if (id < 1 || id > nrow(newdata)) stop("`id` out of range for `newdata`.")

  # Optionally replace the curve and both CI limits by their LOESS fits in t.
  maybe_smooth <- function(df, y, ymin, ymax) {
    if (!smooth) return(df)
    lo <- stats::loess(stats::as.formula(paste0(y, "~ t")), data = df, span = span)
    lol <- stats::loess(stats::as.formula(paste0(ymin, "~ t")), data = df, span = span)
    lou <- stats::loess(stats::as.formula(paste0(ymax, "~ t")), data = df, span = span)
    df[[y]] <- stats::predict(lo)
    df[[ymin]] <- stats::predict(lol)
    df[[ymax]] <- stats::predict(lou)
    df
  }

  # ---- mode = "wa": predicted while-alive rate for one row of newdata ----
  if (mode == "wa") {
    d <- predict.WA(x, newdata = newdata, t_seq = t_seq, level = level)
    d <- d[d$id == id, , drop = FALSE]
    d <- maybe_smooth(d, y = "mu", ymin = "lb", ymax = "ub")
    return(
      ggplot2::ggplot(d, ggplot2::aes(x = .data$t, y = .data$mu)) +
        ggplot2::geom_line(linewidth = 0.6) +
        ggplot2::geom_ribbon(ggplot2::aes(ymin = .data$lb, ymax = .data$ub), alpha = 0.2) +
        ggplot2::labs(x = xlab, y = ylab_wa) +
        ggplot2::theme_minimal(base_size = 12) +
        ggplot2::theme(panel.grid.minor = ggplot2::element_blank())
    )
  }

  # ---- mode = "cov": time-varying effect of a single covariate -----------
  if (is.null(covariate) || !(covariate %in% Z_cols))
    stop("Provide a valid `covariate` name present in the model RHS.")
  if (is.null(ylab_cov)) ylab_cov <- paste0("Effect of ", covariate, " on \u03B7(t)")

  # Select coefficients named "<covariate>_t0_seg<k>". The covariate name is
  # matched literally, not as a regular expression, so names containing
  # metacharacters such as "factor(trt)2" or "I(x^2)" are handled correctly.
  est_names <- names(x$est)
  if (is.null(est_names)) est_names <- character(0)
  prefix <- paste0(covariate, "_t0_seg")
  suffix <- substring(est_names, nchar(prefix) + 1L)
  want <- est_names[startsWith(est_names, prefix) & grepl("^[0-9]+$", suffix)]
  if (!length(want)) stop("No time-basis coefficients found for ", covariate)
  seg <- as.integer(substring(want, nchar(prefix) + 1L))
  ord <- order(seg); want <- want[ord]; seg <- seg[ord]

  # Evaluate the time basis on t_seq and line its columns up with the
  # segment indices found in the coefficient names.
  build_basis_match <- function(inc_int) {
    B <- .WA_time_basis(
      t = t_seq,
      knots = x$knots,
      basis = x$basis,
      degree = x$degree,
      include_intercept = inc_int
    )
    colnames(B) <- sprintf("t0_seg%d", seq_len(ncol(B)))
    need_names <- sprintf("t0_seg%d", seg)
    idx <- match(need_names, colnames(B))
    list(B = B, idx = idx, ok = !any(is.na(idx)))
  }
  # Try the basis without an intercept column first, then with one.
  try1 <- build_basis_match(FALSE)
  if (try1$ok) {
    Bfull <- try1$B; idx <- try1$idx
  } else {
    try2 <- build_basis_match(TRUE)
    if (!try2$ok) stop("Basis/coef mismatch. Check basis/degree/knots vs fit.")
    Bfull <- try2$B; idx <- try2$idx
  }
  Bz <- Bfull[, idx, drop = FALSE]

  # Effect curve B(t) %*% b with pointwise Wald limits from diag(B V B').
  b <- as.numeric(x$est[want])
  V <- as.matrix(x$vcov[want, want, drop = FALSE])
  beta_hat <- as.vector(Bz %*% b)
  var_hat <- rowSums((Bz %*% V) * Bz)
  var_hat <- pmax(var_hat, 0)
  zcrit <- stats::qnorm(0.5 + level/2)
  lb <- beta_hat - zcrit * sqrt(var_hat)
  ub <- beta_hat + zcrit * sqrt(var_hat)
  df <- data.frame(t = t_seq, eff = beta_hat, lb = lb, ub = ub)
  df <- maybe_smooth(df, y = "eff", ymin = "lb", ymax = "ub")
  ggplot2::ggplot(df, ggplot2::aes(x = .data$t, y = .data$eff)) +
    ggplot2::geom_line(linewidth = 0.6) +
    ggplot2::geom_ribbon(ggplot2::aes(ymin = .data$lb, ymax = .data$ub), alpha = 0.2) +
    ggplot2::labs(x = xlab, y = ylab_cov) +
    ggplot2::theme_minimal(base_size = 12) +
    ggplot2::theme(panel.grid.minor = ggplot2::element_blank(),
                   legend.position = "none")
}
# R/WA_cv.R ------------------------------------------------------------
#' K-fold cross-validation for WA configuration selection
#'
#' Runs K-fold CV over a grid of basis types, degrees, interior-knot counts,
#' and link functions. For each configuration, fits the model on K-1 folds and
#' accumulates the prediction error (PE) on the held-out fold using
#' \code{WA_PE()} (IPCW computed on the training subjects).
#'
#' @param formula A \code{Surv(time, status) ~ RHS} formula; see \code{\link{WA_fit}}.
#' @param data Long-format data frame; see \code{\link{WA_fit}}.
#' @param id Character scalar; subject ID column name; see \code{\link{WA_fit}}.
#' @param cluster Optional character scalar; cluster column name; see \code{\link{WA_fit}}.
#' @param basis_set Character vector of candidate bases.
#' @param degree_vec Integer vector of candidate degrees.
#' @param n_int_vec Integer vector of interior-knot counts; 0 means boundaries only.
#' @param knot_scheme \code{"equidist"} or \code{"quantile"} to construct interior knots.
#' @param link_set Character vector of candidate links (subset of \code{c("log","identity")}).
#' @param time_range Optional numeric length-2 vector \code{c(tmin, tmax)}. If \code{NULL},
#' inferred from \code{data}.
#' @param tau_grid Optional numeric vector; if \code{NULL}, a default dense grid over
#' \code{time_range} is created.
#' @param w_recur Numeric vector of recurrent-event weights; see \code{\link{WA_fit}}.
#' @param w_term Numeric scalar; terminal-event weight; see \code{\link{WA_fit}}.
#' @param ipcw IPCW method (\code{"cox"} or \code{"km"}) for PE computation.
#' @param ipcw_formula One-sided RHS formula for IPCW Cox model (if \code{ipcw="cox"}).
#' @param K Number of folds.
#' @param seed RNG seed for fold assignment.
#' @param verbose Logical; show a text progress bar and per-fold messages.
#'
#' @return A data frame with columns: \code{basis}, \code{degree}, \code{n_int},
#' \code{link}, and aggregated \code{PE}. Lower \code{PE} is better.
#'
#' @export
WA_cv <- function(formula,
                  data,
                  id,
                  cluster = NULL,
                  basis_set = c("il","pl","bz"),
                  degree_vec = 1:2,
                  n_int_vec = c(0, 2, 4),
                  knot_scheme = c("equidist","quantile"),
                  link_set = c("log"),
                  time_range = NULL,
                  tau_grid = NULL,
                  w_recur,
                  w_term,
                  ipcw = c("cox","km"),
                  ipcw_formula = ~ 1,
                  K = 5, seed = 1L,
                  verbose = TRUE) {
  knot_scheme <- match.arg(knot_scheme)
  ipcw <- match.arg(ipcw)

  # ---- Time range and evaluation grid ------------------------------------
  par_cv <- .WA_parse_formula(formula, data)
  t_obs <- par_cv$time_vec
  if (is.null(time_range)) {
    tmin <- 0
    tmax <- max(t_obs, na.rm = TRUE)
  } else {
    tmin <- time_range[1]; tmax <- time_range[2]
  }
  if (is.null(tau_grid)) {
    # Default dense grid, stopping just short of tmax.
    tau_grid <- seq(tmin, tmax - 1e-6, length.out = 150)
  }

  # ---- Fold assignment: by cluster when available, otherwise by subject --
  set.seed(seed)
  unit_col <- if (!is.null(cluster) && cluster %in% names(data)) cluster else id
  keys <- unique(data[[unit_col]])
  if (K < 2 || K > length(keys)) {
    stop("`K` must be at least 2 and no larger than the number of ",
         "sampling units (", length(keys), ").", call. = FALSE)
  }
  fold_ids <- sample(rep_len(seq_len(K), length(keys)))
  names(fold_ids) <- as.character(keys)
  data$.fold <- fold_ids[as.character(data[[unit_col]])]

  # Boundary knots plus `n_int` interior knots (equidistant or at quantiles
  # of the observed times).
  make_knots <- function(n_int) {
    if (n_int <= 0) return(c(tmin, tmax))
    if (knot_scheme == "equidist") {
      inter <- seq(tmin, tmax, length.out = n_int + 2L)[-c(1, n_int + 2L)]
    } else {
      qs <- seq(0, 1, length.out = n_int + 2L)[-c(1, n_int + 2L)]
      inter <- as.numeric(stats::quantile(t_obs, probs = qs, na.rm = TRUE))
    }
    sort(unique(c(tmin, inter, tmax)))
  }

  # ---- Per-fold pieces that do not depend on the configuration -----------
  # The train/test split and the IPCW fit on the training subjects are the
  # same for every basis/degree/knot/link combination, so build them once.
  fold_parts <- lapply(seq_len(K), function(k) {
    train <- data[data$.fold != k, , drop = FALSE]
    test <- data[data$.fold == k, , drop = FALSE]
    par_tr <- .WA_parse_formula(formula, train)
    status_tr <- as.integer(par_tr$status_vec)
    # Terminal code = largest status code in the training data (as in
    # WA_fit). It must be computed here, outside the grouped summarise():
    # inside it, max(.status) would be the per-subject maximum, which flags
    # a subject whose only row is censored (status 0) as a terminal event.
    s_max_tr <- max(status_tr, na.rm = TRUE)
    # Covariates come from the parsed model matrix, mirroring WA_fit, so that
    # factor/logical terms (whose design columns are not raw data columns)
    # are available under the names in Z_cols.
    Z_tr <- as.data.frame(par_tr$Xmm)
    names(Z_tr) <- par_tr$Z_cols
    tr_df <- dplyr::bind_cols(
      tibble::tibble(
        .id = train[[id]],
        .time = par_tr$time_vec,
        .status = status_tr
      ),
      Z_tr
    ) %>%
      dplyr::group_by(.data$.id) %>%
      dplyr::summarise(
        dplyr::across(dplyr::all_of(par_tr$Z_cols), ~ dplyr::first(.x)),
        obs_T = max(.data$.time, na.rm = TRUE),
        Delta = as.integer(.data$.status[which.max(.data$.time)] == s_max_tr),
        .groups = "drop"
      )
    list(
      train = train,
      test = test,
      ipcw_fit = .WA_ipcw_fit(tr_df, method = ipcw, ipcw_formula = ipcw_formula)
    )
  })

  # ---- Progress reporting -------------------------------------------------
  n_cfg <- length(basis_set) * length(degree_vec) * length(n_int_vec) * length(link_set)
  total_steps <- n_cfg * K
  step <- 0L
  if (verbose) {
    pb <- utils::txtProgressBar(min = 0, max = total_steps, style = 3)
    on.exit(try(close(pb), silent = TRUE), add = TRUE)
  }

  # ---- Grid search: accumulate held-out prediction error per config ------
  results <- vector("list", n_cfg)
  idx <- 1L
  for (bs in basis_set) {
    for (deg in degree_vec) {
      for (nk in n_int_vec) {
        knots <- make_knots(nk)
        for (lnk in link_set) {
          if (verbose) {
            cat(sprintf(
              "\n[Config] basis=%s, degree=%s, n_int=%s, link=%s\n",
              bs, deg, nk, lnk
            ))
            utils::flush.console()
          }
          PE_sum <- 0
          for (k in seq_len(K)) {
            if (verbose) {
              cat(sprintf(" - Fold %d/%d ... ", k, K))
              utils::flush.console()
            }
            part <- fold_parts[[k]]
            fit_k <- WA_fit(
              formula = formula,
              data = part$train,
              id = id,
              cluster = cluster,
              knots = knots,
              tau_grid = tau_grid,
              basis = bs,
              degree = deg,
              link = lnk,
              w_recur = w_recur,
              w_term = w_term,
              ipcw = ipcw,
              ipcw_formula = ipcw_formula
            )
            PE_k <- WA_PE(
              fit = fit_k, formula = formula, data_test = part$test, id = id,
              w_recur = w_recur, w_term = w_term,
              ipcw_fit = part$ipcw_fit, tau_grid = tau_grid
            )
            PE_sum <- PE_sum + PE_k
            if (verbose) {
              cat("done.\n")
              utils::flush.console()
              step <- step + 1L
              utils::setTxtProgressBar(pb, step)
            }
          }
          results[[idx]] <- data.frame(
            basis = bs, degree = deg, n_int = nk, link = lnk,
            PE = PE_sum, stringsAsFactors = FALSE
          )
          idx <- idx + 1L
        }
      }
    }
  }
  if (verbose) cat("\nAll folds complete.\n")

  # One row per configuration, best (smallest PE) first.
  out <- do.call(rbind, results)
  rownames(out) <- NULL
  out[order(out$PE), ]
}
#' Prediction error of a fitted WA model on held-out data
#'
#' For every subject in \code{data_test}, integrates the weighted squared
#' residual \eqn{(V/G)\{L(\tau) - \mu(\tau)\min(T,\tau)\}^2} over
#' \code{tau_grid} with the trapezoidal rule and returns the sum over subjects.
#'
#' @param fit A \code{"WA"} object; supplies the coefficients, basis settings,
#'   link and (when present) the status codes used at fitting time.
#' @param formula Model formula, parsed against \code{data_test}.
#' @param data_test Long-format held-out data.
#' @param id Character scalar; subject ID column name.
#' @param w_recur,w_term Recurrent-event and terminal-event weights.
#' @param ipcw_fit Censoring model passed on to \code{.WA_ipcw_predict_G()}.
#' @param tau_grid Increasing evaluation times; defaults to the fit's grid.
#'
#' @return A numeric scalar (subjects whose integral is \code{NA} are skipped).
#'
#' @keywords internal
#' @noRd
WA_PE <- function(fit,
                  formula,
                  data_test,
                  id,
                  w_recur, w_term,
                  ipcw_fit,
                  tau_grid = fit$tau_grid) {
  par_pe <- .WA_parse_formula(formula, data_test)
  Xmm <- par_pe$Xmm
  Z_cols <- par_pe$Z_cols
  time_name <- par_pe$time_name
  status_name <- par_pe$status_name
  if (length(Z_cols) == 0L) stop("No covariates on RHS for TEST data.")

  test_df <- dplyr::bind_cols(
    tibble::tibble(
      .id = data_test[[id]],
      .time = data_test[[time_name]],
      .status = as.integer(data_test[[status_name]])
    ),
    as.data.frame(Xmm)
  )

  # Prefer the status codes recorded by WA_fit so that `w_recur` stays aligned
  # with the recurrent types of the fit even when this test set happens to
  # lack one of the codes; otherwise derive them from the test data.
  codes <- fit$status_codes
  if (!is.null(codes) && !is.null(codes$terminal)) {
    s_max <- codes$terminal
    rec_types <- codes$recurrent
  } else {
    s_max <- max(test_df$.status, na.rm = TRUE)
    rec_types <- setdiff(sort(unique(test_df$.status)), c(0L, s_max))
  }

  # One row per subject: covariates, end of follow-up, terminal indicator.
  subj <- test_df %>%
    dplyr::group_by(.data$.id) %>%
    dplyr::summarise(
      dplyr::across(dplyr::all_of(Z_cols), ~ dplyr::first(.x)),
      obs_T = max(.data$.time, na.rm = TRUE),
      Delta = as.integer(.data$.status[which.max(.data$.time)] == s_max),
      .groups = "drop"
    )
  uid <- subj$.id
  int_each <- numeric(length(uid)); names(int_each) <- as.character(uid)
  eps <- .Machine$double.eps
  scalar_prev <- NULL; tau_prev <- NULL

  recs <- test_df %>%
    dplyr::filter(.data$.status %in% c(rec_types, s_max)) %>%
    dplyr::select(dplyr::all_of(c(".id", ".status", ".time")))

  # Loop invariants: link functions and the subject-level covariate frame.
  lf <- .WA_link(fit$link)
  Xnew <- as.data.frame(subj[, Z_cols, drop = FALSE])

  for (tau in tau_grid) {
    X_min_tau <- pmin(subj$obs_T, tau)
    V_i_tau <- as.integer((subj$obs_T <= tau) & (subj$Delta == 1L)) +
      as.integer(subj$obs_T > tau)

    # Event counts by status code up to min(T, tau).
    grid <- tibble::tibble(.id = uid, tau = tau, X_min_tau = X_min_tau)
    counts <- recs %>%
      dplyr::right_join(grid, by = ".id", relationship = "many-to-many") %>%
      dplyr::filter(.data$.time <= .data$X_min_tau) %>%
      dplyr::group_by(.data$.id, .data$tau, .data$X_min_tau, .data$.status) %>%
      dplyr::summarise(n = dplyr::n(), .groups = "drop")
    if (nrow(counts)) {
      counts_wide <- tidyr::pivot_wider(
        counts,
        id_cols = c(".id", "tau", "X_min_tau"),
        names_from = ".status",
        values_from = "n",
        values_fill = 0L,
        names_prefix = "N_"
      )
    } else {
      counts_wide <- grid
    }

    tmp <- subj %>%
      dplyr::mutate(tau = tau, X_min_tau = X_min_tau, V_i_tau = V_i_tau) %>%
      dplyr::left_join(counts_wide, by = c(".id","tau","X_min_tau"))
    for (s in rec_types) {
      nm <- paste0("N_", s)
      if (!nm %in% names(tmp)) tmp[[nm]] <- 0L
      # Subjects with no counted event by this tau are unmatched in the join
      # and come out NA; they contributed zero events. Without this the NA
      # propagates into the running integral and the subject is silently
      # dropped by the final na.rm sum.
      tmp[[nm]][is.na(tmp[[nm]])] <- 0L
    }
    tmp$N_term_tau <- as.integer((tmp$obs_T <= tmp$tau) & (tmp$Delta == 1L))

    # Weighted composite loss.
    L_recur <- 0
    for (j in seq_along(rec_types)) {
      L_recur <- L_recur + w_recur[j] * tmp[[paste0("N_", rec_types[j])]]
    }
    tmp$L <- L_recur + w_term * tmp$N_term_tau

    tmp$G_X_min_tau <- .WA_ipcw_predict_G(ipcw_fit, tmp$X_min_tau, newdata = tmp)

    # Model-based while-alive rate at tau.
    X <- .WA_design_at_t(Xnew, tau, Z_cols, fit$knots, fit$basis, fit$degree)
    beta_ord <- fit$est[colnames(X)]
    eta <- as.vector(X %*% beta_ord)
    mu <- lf$linkinv(eta)

    resid <- tmp$L - mu * tmp$X_min_tau
    # Guard the inverse-probability weight against division by zero.
    fac <- tmp$V_i_tau / pmax(tmp$G_X_min_tau, eps)
    scalar_now <- fac * (resid^2)

    # Trapezoidal accumulation; the first panel runs from time 0 to the first
    # tau with the integrand taken as 0 at time 0.
    if (is.null(scalar_prev)) {
      dt <- tau
      int_each <- int_each + (scalar_now) * dt / 2
    } else {
      dt <- tau - tau_prev
      int_each <- int_each + (scalar_now + scalar_prev) * dt / 2
    }
    scalar_prev <- scalar_now; tau_prev <- tau
  }
  sum(int_each, na.rm = TRUE)
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.