R/cv.coxkl_enet.R

#' Cross-validation for the CoxKL model with an elastic-net or lasso penalty
#'
#' Performs cross-validation for the high-dimensional Cox model with a
#' Kullback–Leibler (KL) penalty. For each candidate \code{eta} (the weight on
#' external information), a \code{lambda} path (either provided or generated) is
#' evaluated under the user-specified cross-validation criterion, and the best
#' \code{lambda} is selected per \code{eta}.
#'
#' @param z Numeric matrix of covariates with rows representing individuals and
#'   columns representing predictors.
#' @param delta Numeric vector of event indicators (1 = event, 0 = censored).
#' @param time Numeric vector of observed times (event or censoring).
#' @param stratum Optional factor or numeric vector indicating strata. If \code{NULL},
#'   all observations are treated as a single stratum and a warning is issued.
#' @param RS Optional numeric vector or matrix of external risk scores. If not provided,
#'   \code{beta} must be supplied.
#' @param beta Optional numeric vector of external coefficients (length equal to
#'   \code{ncol(z)}). If not provided, \code{RS} must be supplied.
#' @param etas Numeric vector of candidate \code{eta} values to be evaluated.
#' @param alpha Elastic-net mixing parameter in \eqn{(0,1]}. Default = \code{1}
#'   (lasso penalty).
#' @param lambda Optional numeric scalar or vector of penalty parameters. If \code{NULL},
#'   a decreasing path is generated using \code{nlambda} and \code{lambda.min.ratio}.
#' @param nlambda Integer number of lambda values to generate when \code{lambda} is
#'   \code{NULL}. Default \code{100}.
#' @param lambda.min.ratio Ratio of the smallest to the largest lambda when
#'   generating a sequence (when \code{lambda} is \code{NULL}). Default \code{0.05} when \code{n < p},
#'   otherwise \code{1e-3}.
#' @param nfolds Integer; number of cross-validation folds. Default = \code{5}.
#' @param cv.criteria Character string specifying the cross-validation criterion. Choices are:
#'   \itemize{
#'     \item \code{"V&VH"} (default): Verweij and van Houwelingen cross-validated partial log-likelihood loss.
#'     \item \code{"LinPred"}: loss based on cross-validated linear predictors.
#'     \item \code{"CIndex_pooled"}: pool all held-out predictions and compute one overall C-index.
#'     \item \code{"CIndex_foldaverage"}: average C-index across folds.
#'   }
#' @param c_index_stratum Optional stratum vector. Used only when
#'   \code{cv.criteria} is \code{"CIndex_pooled"} or \code{"CIndex_foldaverage"}, to compute a
#'   stratified C-index while the fitted model is non-stratified. If \code{stratum} is also
#'   supplied, the two must be identical. Default \code{NULL}.
#' @param message Logical; whether to print progress messages. Default = \code{FALSE}.
#' @param seed Optional integer random seed for fold assignment.
#' @param ... Additional arguments passed to \code{\link{coxkl_enet}}.
#'
#' @details
#' Data are sorted by \code{stratum} and \code{time}. External information must be supplied through
#' either \code{RS} or \code{beta}; if only \code{beta} is given (with length \code{ncol(z)}), the risk
#' score is computed as \code{RS = z \%*\% beta}. \code{alpha} must lie in \eqn{(0,1]}.
#' 
#' For each candidate \code{eta}, a decreasing \code{lambda} path is used (generated from \code{nlambda}/\code{lambda.min.ratio}
#' if \code{lambda = NULL}); CV folds are created by \code{get_fold}. Each fold fits \code{\link{coxkl_enet}}
#' on the training split (full \code{lambda} path) and evaluates the chosen criterion on the test split.
#' 
#' Fold results are aggregated according to \code{cv.criteria}:
#' \itemize{
#'   \item \code{"V&VH"}: sums \code{pl(full) - pl(train)} across folds (reported as a loss via \code{Loss = -2 * score / n}).
#'   \item \code{"LinPred"}: aggregates test-fold linear predictors and evaluates the partial log-likelihood on the full data (reported as \code{Loss = -2 * score / n}).
#'   \item \code{"CIndex_pooled"}: pools comparable-pair numerators/denominators across folds to compute one C-index.
#'   \item \code{"CIndex_foldaverage"}: averages the per-fold stratified C-index.
#' }
#' The best \code{lambda} is selected per \code{eta} (min loss / max C-index), and the function returns full results,
#' the per-\code{eta} optimum, corresponding coefficients, and an external baseline from \code{RS}.
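#'
#' The two C-index aggregations can give different answers when folds contribute
#' unequal numbers of comparable pairs. A toy illustration with hypothetical
#' per-fold counts (made-up numbers, not package output):
#' \preformatted{
#' numer <- c(40, 9)           # C-index numerator per fold (hypothetical)
#' denom <- c(50, 10)          # comparable pairs per fold (hypothetical)
#' sum(numer) / sum(denom)     # pooled: 49 / 60, about 0.817
#' mean(numer / denom)         # fold average: (0.8 + 0.9) / 2 = 0.85
#' }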
#'
#' @return An object of class \code{"cv.coxkl_enet"}:
#' \describe{
#'   \item{\code{integrated_stat.full_results}}{Data frame with columns
#'     \code{eta}, \code{lambda}, and one criterion column holding the aggregated CV
#'     result for each \code{lambda}: \code{Loss} (\code{= -2 * score / n}) for the
#'     loss criteria, or \code{CIndex_pooled} / \code{CIndex_foldaverage} for the
#'     C-index criteria.}
#'   \item{\code{integrated_stat.best_per_eta}}{Data frame with the best
#'     \code{lambda} (per \code{eta}) according to the chosen \code{cv.criteria}
#'     (minimizing loss or maximizing C-index).}
#'   \item{\code{integrated_stat.betahat_best}}{Matrix of coefficient vectors
#'     (columns) corresponding to the best \code{lambda} for each \code{eta}.}
#'   \item{\code{external_stat}}{Scalar baseline statistic computed from the external
#'     risk score \code{RS} under the same \code{cv.criteria}. For the loss criteria this
#'     is \code{-2 * score}, not divided by \code{n}.}
#'   \item{\code{criteria}}{The evaluation criterion used (as provided in \code{cv.criteria}).}
#'   \item{\code{alpha}}{The elastic-net mixing parameter used.}
#'   \item{\code{nfolds}}{Number of folds.}
#' }
#'   
#' @examples
#' data(ExampleData_highdim) 
#' 
#' train_dat_highdim <- ExampleData_highdim$train
#' beta_external_highdim <- ExampleData_highdim$beta_external
#' 
#' etas <- generate_eta(method = "exponential", n = 10, max_eta = 100)
#' 
#' cv_res <- cv.coxkl_enet(z = train_dat_highdim$z,
#'                         delta = train_dat_highdim$status,
#'                         time = train_dat_highdim$time,
#'                         stratum = NULL,
#'                         RS = NULL,
#'                         beta = beta_external_highdim,
#'                         etas = etas,
#'                         alpha = 1.0)
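#'
#' # A short follow-up sketch (assumes the default "V&VH" criterion, so the
#' # criterion column is named Loss, and that the candidate etas are unique so
#' # rows of best_per_eta line up with columns of betahat_best).
#' best <- cv_res$integrated_stat.best_per_eta
#' i_best <- which.min(best$Loss)        # eta with the smallest cross-validated loss
#' best[i_best, c("eta", "lambda", "Loss")]
#' beta_hat <- cv_res$integrated_stat.betahat_best[, i_best]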
#'    
#' @importFrom utils txtProgressBar setTxtProgressBar
#' @export
cv.coxkl_enet <- function(z, delta, time, stratum = NULL, RS = NULL, beta = NULL,
                          etas, alpha = 1.0,
                          lambda = NULL, nlambda = 100, lambda.min.ratio = ifelse(n < p, 0.05, 1e-03),
                          nfolds = 5, 
                          cv.criteria = c("V&VH", "LinPred", "CIndex_pooled", "CIndex_foldaverage"),
                          c_index_stratum = NULL,
                          message = FALSE, seed = NULL, ...) {
  
  ## Input check & data preparation
  if (is.null(etas)) stop("etas must be provided.", call. = FALSE)
  etas <- sort(etas)
  cv.criteria <- match.arg(cv.criteria, choices = c("V&VH", "LinPred", "CIndex_pooled", "CIndex_foldaverage"))
  
  if (alpha > 1 || alpha <= 0) stop("alpha must be in (0,1]", call. = FALSE)
  
  if (is.null(RS) && is.null(beta)) {
    stop("No external information is provided. Either RS or beta must be provided.", call. = FALSE)
  } else if (is.null(RS) && !is.null(beta)) {
    if (length(beta) != ncol(z)) stop("beta dimension mismatch with z.", call. = FALSE)
    RS <- as.matrix(z) %*% as.matrix(beta)
  } else {
    RS <- as.matrix(RS)
  }
  
  if (is.null(stratum)) {
    warning("Stratum not provided. Treating all data as one stratum.", call. = FALSE)
    stratum <- rep(1, nrow(z))
  } else {
    if (!is.null(c_index_stratum) && !identical(stratum, c_index_stratum)) {
      stop("Provided 'c_index_stratum' not identical to 'stratum'!", call. = FALSE)
    }
    stratum <- match(stratum, unique(stratum))
  }
  
  ## Sort data
  time_order <- order(stratum, time)
  time <- as.numeric(time[time_order])
  stratum <- as.numeric(stratum[time_order])
  z <- as.matrix(z)[time_order, , drop = FALSE]
  delta <- as.numeric(delta[time_order])
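  RS <- RS[time_order, , drop = FALSE]
  ## Keep the C-index strata aligned with the sorted data; they are indexed by fold below
  if (!is.null(c_index_stratum)) c_index_stratum <- c_index_stratum[time_order]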
  n <- nrow(z)
  p <- ncol(z)
  n.each_stratum_full <- as.numeric(table(stratum))
  
  ## CV folds
  if (!is.null(seed)) set.seed(seed)
  folds <- get_fold(nfolds = nfolds, delta = delta, stratum = stratum)
  
  results_list <- list()
  n_eta <- length(etas)
  
  ## External baseline accumulators
  if (cv.criteria == "V&VH") {
    pl_full_RS <- pl_cal_theta(as.vector(RS), delta, n.each_stratum_full)
    ext_vvh_per_fold <- numeric(nfolds)
  } else if (cv.criteria == "CIndex_pooled") {
    ext_numer <- numeric(nfolds); ext_denom <- numeric(nfolds)
  } else if (cv.criteria == "CIndex_foldaverage") {
    ext_c_per_fold <- numeric(nfolds)
  }
  
  if (message) {
    cat("Cross-validation over etas sequence:\n")
    pb_eta <- txtProgressBar(min = 0, max = n_eta, style = 3, width = 30)
  }
  
  fit_beta_list <- vector("list", length(etas))  
  lambda_list <- vector("list", length(etas))  
  
  for (ei in seq_along(etas)) {
    eta <- etas[ei]
    
    ## Determine lambda path (on full data) & estimation on full data
    if (is.null(lambda)) {
      fit0 <- coxkl_enet(z = z, delta = delta, time = time, stratum = stratum,
                            RS = RS, eta = eta, alpha = alpha,
                            lambda = NULL, nlambda = nlambda, lambda.min.ratio = lambda.min.ratio,
                            data_sorted = TRUE, message = FALSE, ...)
      lambda_seq <- as.vector(fit0$lambda)
    } else {
      lambda_seq <- sort(lambda, decreasing = TRUE)
      fit0 <- coxkl_enet(z = z, delta = delta, time = time, stratum = stratum,
                            RS = RS, eta = eta, alpha = alpha,
                            lambda = lambda_seq, data_sorted = TRUE, message = FALSE, ...)
    }
    
    fit_beta_list[[ei]] <- fit0$beta 
    lambda_list[[ei]]   <- lambda_seq 
    
    L <- length(lambda_seq)
    
    ## Accumulators
    if (cv.criteria == "V&VH") {
      vvh_sum <- rep(0, L)
    } else if (cv.criteria == "LinPred") {
      Y <- matrix(NA_real_, nrow = n, ncol = L)
      colnames(Y) <- round(lambda_seq, 6)
    } else if (cv.criteria == "CIndex_pooled") {
      numer <- rep(0, L); denom <- rep(0, L)
    } else if (cv.criteria == "CIndex_foldaverage") {
      csum <- rep(0, L); cnt <- rep(0, L)
    }
    
    ## Loop over folds
    for (f in seq_len(nfolds)) {
      train_idx <- which(folds != f)
      test_idx  <- which(folds == f)
      
      ## External baseline per fold (computed once)
      if (ei == 1) {
        if (cv.criteria == "V&VH") {
          n.each_stratum_train <- as.numeric(table(stratum[train_idx]))
          pl_train_RS <- pl_cal_theta(as.vector(RS[train_idx]), delta[train_idx], n.each_stratum_train)
          ext_vvh_per_fold[f] <- pl_full_RS - pl_train_RS
        } else if (cv.criteria == "CIndex_pooled" || cv.criteria == "CIndex_foldaverage") {
          if (is.null(c_index_stratum)) stratum_test <- stratum[test_idx] else stratum_test <- c_index_stratum[test_idx]
          cstat_ext <- c_stat_stratcox(time[test_idx], as.vector(RS[test_idx]), stratum_test, delta[test_idx])
          if (cv.criteria == "CIndex_pooled") {
            ext_numer[f] <- cstat_ext$numer
            ext_denom[f] <- cstat_ext$denom
          } else {
            ext_c_per_fold[f] <- cstat_ext$c_statistic
          }
        }
      }
      
      ## Fit full lambda path on training fold
      fit_f <- coxkl_enet(z = z[train_idx, , drop = FALSE],
                             delta = delta[train_idx],
                             time = time[train_idx],
                             stratum = stratum[train_idx],
                             RS = RS[train_idx, , drop = FALSE],
                             eta = eta, alpha = alpha, lambda = lambda_seq,
                             data_sorted = TRUE, message = FALSE, ...)
      
      beta_mat <- fit_f$beta
      ## drop = FALSE keeps the matrix shape even for a single row or column
      LP_train <- z[train_idx, , drop = FALSE] %*% beta_mat
      LP_all   <- z %*% beta_mat
      LP_test  <- z[test_idx, , drop = FALSE] %*% beta_mat
      
      ## Compute fold-wise CV metric
      if (cv.criteria == "V&VH") {
        n.each_stratum_train <- as.numeric(table(stratum[train_idx]))
        n.each_stratum_all   <- as.numeric(table(stratum))
        pl_all <- apply(LP_all,  2, function(col) pl_cal_theta(col, delta,        n.each_stratum_all))
        pl_tr  <- apply(LP_train,2, function(col) pl_cal_theta(col, delta[train_idx], n.each_stratum_train))
        vvh_sum <- vvh_sum + (pl_all - pl_tr)
      } else if (cv.criteria == "LinPred") {
        Y[test_idx, ] <- LP_test
      } else {
        if (is.null(c_index_stratum)) stratum_test <- stratum[test_idx] else stratum_test <- c_index_stratum[test_idx]
        numer_vec <- denom_vec <- cstat_vec <- rep(NA_real_, ncol(LP_test))
        for (j in seq_len(ncol(LP_test))) {
          cstat_j <- c_stat_stratcox(time[test_idx], LP_test[, j], stratum_test, delta[test_idx])
          numer_vec[j] <- cstat_j$numer
          denom_vec[j] <- cstat_j$denom
          cstat_vec[j] <- cstat_j$c_statistic
        }
        if (cv.criteria == "CIndex_pooled") {
          numer <- numer + numer_vec
          denom <- denom + denom_vec
        } else {
          csum <- csum + cstat_vec
          cnt  <- cnt  + 1
        }
      }
    }
    
    ## Aggregate folds for this eta
    if (cv.criteria == "V&VH") {
      cve_eta <- vvh_sum
    } else if (cv.criteria == "LinPred") {
      Lmat <- loss.coxkl_highdim(delta, Y, stratum, total = FALSE)
      cve_eta <- colSums(Lmat)
    } else if (cv.criteria == "CIndex_pooled") {
      cve_eta <- numer / denom
    } else {
      cve_eta <- csum / pmax(cnt, 1)
    }
    
    results_list[[ei]] <- data.frame(
      eta = eta,
      lambda = lambda_seq,
      score = cve_eta,
      stringsAsFactors = FALSE
    )
    if (message) setTxtProgressBar(pb_eta, ei)
  }
  
  if (message) close(pb_eta)
  
  results_df <- do.call(rbind, results_list)
  
  ## Best per eta
  if (cv.criteria %in% c("V&VH", "LinPred")) {
    results_df$Loss <- -2 * results_df$score  / n
    results_df$score <- NULL
    best_per_eta <- do.call(rbind, lapply(split(results_df, results_df$eta), function(df) df[which.min(df$Loss), , drop = FALSE]))
  } else if (cv.criteria == "CIndex_pooled") {
    results_df$CIndex_pooled <- results_df$score; results_df$score <- NULL
    best_per_eta <- do.call(rbind, lapply(split(results_df, results_df$eta), function(df) df[which.max(df$CIndex_pooled), , drop = FALSE]))
  } else {
    results_df$CIndex_foldaverage <- results_df$score; results_df$score <- NULL
    best_per_eta <- do.call(rbind, lapply(split(results_df, results_df$eta), function(df) df[which.max(df$CIndex_foldaverage), , drop = FALSE]))
  }
  rownames(best_per_eta) <- NULL
  
  beta_best_mat <- sapply(seq_along(etas), function(i) {
    beta_mat   <- fit_beta_list[[i]]
    lambda_seq <- lambda_list[[i]]
    
    idx <- which(abs(lambda_seq - best_per_eta$lambda[i]) < 1e-12)
    if (length(idx) != 1) {
      idx <- which.min(abs(lambda_seq - best_per_eta$lambda[i]))
    }
    beta_mat[, idx]
  })
  
  colnames(beta_best_mat) <- etas
  
  ## External baseline
  external_stat <- switch(cv.criteria,
                          "V&VH" = -2 * sum(ext_vvh_per_fold),
                          "LinPred" = -2 * pl_cal_theta(as.vector(RS), delta, n.each_stratum_full),
                          "CIndex_pooled" = sum(ext_numer) / sum(ext_denom),
                          "CIndex_foldaverage" = mean(ext_c_per_fold)
  )
  
  structure(
    list(
      integrated_stat.full_results = results_df,
      integrated_stat.best_per_eta = best_per_eta,
      integrated_stat.betahat_best = beta_best_mat,
      external_stat = external_stat,
      criteria = cv.criteria,
      alpha = alpha,
      nfolds = nfolds
    ),
    class = "cv.coxkl_enet"
  )
}

survkl documentation built on April 22, 2026, 1:08 a.m.