R/cv-grpreg.R

Defines functions cvf cv.grpreg

Documented in cv.grpreg

#' Cross-validation for grpreg/grpsurv
#' 
#' Performs k-fold cross validation for penalized regression models with
#' grouped covariates over a grid of values for the regularization parameter
#' lambda.
#' 
#' The function calls \code{grpreg}/\code{grpsurv} \code{nfolds} times, each
#' time leaving out 1/\code{nfolds} of the data.  The cross-validation error is
#' based on the deviance;
#' [see here for more details](https://pbreheny.github.io/grpreg/articles/web/models.html).
#' 
#' For Gaussian and Poisson responses, the folds are chosen according to simple
#' random sampling.  For binomial responses, the numbers for each outcome class
#' are balanced across the folds; i.e., the number of outcomes in which
#' \code{y} is equal to 1 is the same for each fold, or possibly off by 1 if
#' the numbers do not divide evenly.  This approach is used for Cox regression
#' as well to balance the amount of censoring across each fold.
#' 
#' For Cox models, \code{cv.grpsurv} uses the approach of calculating the full
#' Cox partial likelihood using the cross-validated set of linear predictors.
#' Other approaches to cross-validation for the Cox regression model have been
#' proposed in the literature; the strengths and weaknesses of the various
#' methods for penalized regression in the Cox model are the subject of current
#' research.  A simple approximation to the standard error is provided,
#' although an option to bootstrap the standard error (\code{se='bootstrap'})
#' is also available.
#' 
#' As in \code{grpreg}, seemingly unrelated regressions/multitask learning can
#' be carried out by setting \code{y} to be a matrix, in which case groups are
#' set up automatically (see \code{\link{grpreg}} for details), and
#' cross-validation is carried out with respect to rows of \code{y}.  As
#' mentioned in the details there, it is recommended to standardize the
#' responses prior to fitting.
#' 
#' @aliases cv.grpreg cv.grpsurv
#' @param X The design matrix, as in \code{grpreg}/\code{grpsurv}.
#' @param y The response vector (or matrix), as in \code{grpreg}/\code{grpsurv}.
#' @param group The grouping vector, as in \code{grpreg}/\code{grpsurv}.
#' @param ... Additional arguments to \code{grpreg}/\code{grpsurv}.
#' @param nfolds The number of cross-validation folds.  Default is 10.
#' @param seed You may set the seed of the random number generator in order to
#' obtain reproducible results.
#' @param fold Which fold each observation belongs to.  By default the
#' observations are randomly assigned.
#' @param returnY Should \code{cv.grpreg}/\code{cv.grpsurv} return the fitted
#' values from the cross-validation folds?  Default is FALSE; if TRUE, this
#' will return a matrix in which the element for row i, column j is the fitted
#' value for observation i from the fold in which observation i was excluded
#' from the fit, at the jth value of lambda.  NOTE: For \code{cv.grpsurv}, the
#' rows of \code{Y} are ordered by time on study, and therefore will not
#' correspond to the original order of observations passed to \code{cv.grpsurv}.
#' @param trace If set to TRUE, cv.grpreg will inform the user of its progress
#' by announcing the beginning of each CV fold.  Default is FALSE.
#' @param se For \code{cv.grpsurv}, the method by which the cross-validation
#' standard error (CVSE) is calculated.  The 'quick' approach is based on a
#' rough approximation, but can be calculated more or less instantly.  The
#' 'bootstrap' approach is more accurate, but requires additional computing
#' time.
#' @return An object with S3 class \code{"cv.grpreg"} containing:
#' \item{cve}{The error for each value of \code{lambda}, averaged across the
#' cross-validation folds.} \item{cvse}{The estimated standard error associated
#' with each value of \code{cve}.} \item{lambda}{The sequence of
#' regularization parameter values along which the cross-validation error was
#' calculated.} \item{fit}{The fitted \code{grpreg} object for the whole data.}
#' \item{fold}{The fold assignments for cross-validation for each observation;
#' note that for \code{cv.grpsurv}, these are in terms of the ordered
#' observations, not the original observations.} \item{min}{The index of
#' \code{lambda} corresponding to \code{lambda.min}.} \item{lambda.min}{The
#' value of \code{lambda} with the minimum cross-validation error.}
#' \item{null.dev}{The deviance for the intercept-only model.} \item{pe}{If
#' \code{family="binomial"}, the cross-validation prediction error for each
#' value of \code{lambda}.}
#' @author Patrick Breheny
#' @seealso \code{\link{grpreg}}, \code{\link{plot.cv.grpreg}},
#' \code{\link{summary.cv.grpreg}}, \code{\link{predict.cv.grpreg}}
#' 
#' @examples
#' \dontshow{set.seed(1)}
#' data(Birthwt)
#' X <- Birthwt$X
#' y <- Birthwt$bwt
#' group <- Birthwt$group
#' 
#' cvfit <- cv.grpreg(X, y, group)
#' plot(cvfit)
#' summary(cvfit)
#' coef(cvfit) ## Beta at minimum CVE
#' 
#' cvfit <- cv.grpreg(X, y, group, penalty="gel")
#' plot(cvfit)
#' summary(cvfit)
#' 
#' @export cv.grpreg

cv.grpreg <- function(X, y, group=1:ncol(X), ..., nfolds=10, seed, fold, returnY=FALSE, trace=FALSE) {
  # Overview: fit the model to the complete data, assign every observation to
  # a fold (unless `fold` is supplied), refit with each fold held out via
  # cvf(), and summarize the held-out loss at every value of lambda.  The
  # result is a list with S3 class "cv.grpreg".

  # Complete data fit.  returnX is forced on so that the processed design
  # (fit$XG) is available for the fold fits below.
  fit.args <- list(...)
  fit.args$X <- X
  fit.args$y <- y
  if (!inherits(X, "expandedMatrix")) fit.args$group <- group
  fit.args$returnX <- TRUE
  if ('penalty' %in% names(fit.args) && fit.args$penalty == 'gBridge') {
    fit <- do.call("gBridge", fit.args)
  } else {
    fit <- do.call("grpreg", fit.args)
  }

  # Get y, standardized X
  XG <- fit$XG
  X <- XG$X
  m <- attr(fit$y, "m")
  # For gaussian fits the stored mean is subtracted here and added back to the
  # returned fitted values at the end.
  y <- if (fit$family=="gaussian") fit$y - attr(fit$y, "mean") else fit$y
  # Only keep XG on the returned fit if the caller asked for it via `...`
  returnX <- list(...)$returnX
  if (is.null(returnX) || !returnX) fit$XG <- NULL

  # Set up folds
  if (!missing(seed)) {
    # Put the caller's RNG state back on exit.  If there was no RNG state
    # before this call, remove the one created by set.seed() rather than
    # leaving a NULL .Random.seed behind (which R rejects with a warning).
    had_seed <- exists(".Random.seed", envir=globalenv(), inherits=FALSE)
    original_seed <- if (had_seed) get(".Random.seed", envir=globalenv(), inherits=FALSE) else NULL
    on.exit({
      if (had_seed) {
        assign(".Random.seed", original_seed, envir=globalenv())
      } else if (exists(".Random.seed", envir=globalenv(), inherits=FALSE)) {
        rm(".Random.seed", envir=globalenv())
      }
    }, add=TRUE)
    set.seed(seed)
  }
  n <- length(y)
  if (missing(fold)) {
    if (m > 1) {
      # Assign folds in blocks of m consecutive elements of y so that each
      # block shares one fold label.
      nn <- n/m
      fold_ <- sample(1:nn %% (nfolds))
      fold_[fold_==0] <- nfolds
      fold <- rep(fold_, each=m)
    } else if (fit$family=="binomial") {
      # Balance the two outcome classes across folds.  The y==0 labels
      # continue the cycle where the y==1 labels stop.
      n1 <- sum(y==1)
      n0 <- sum(y==0)
      fold1 <- seq_len(n1) %% nfolds
      fold0 <- (n1 + seq_len(n0)) %% nfolds
      fold1[fold1==0] <- nfolds
      fold0[fold0==0] <- nfolds
      fold <- integer(n)
      # Permute by index: sample(x) on a length-one x would draw from 1:x
      # instead of returning x.  For longer vectors this is the same draw.
      fold[y==1] <- fold1[sample.int(length(fold1))]
      fold[y==0] <- fold0[sample.int(length(fold0))]
    } else {
      fold <- sample(1:n %% nfolds)
      fold[fold==0] <- nfolds
    }
  } else {
    # With user-supplied folds, the number of folds is the largest label
    nfolds <- max(fold)
  }

  # Do cross-validation
  E <- Y <- matrix(NA, nrow=length(y), ncol=length(fit$lambda))
  if (fit$family=="binomial") PE <- E
  cv.args <- list(...)
  cv.args$lambda <- fit$lambda
  cv.args$group <- XG$g
  cv.args$group.multiplier <- XG$m
  cv.args$warn <- FALSE
  for (i in seq_len(nfolds)) {
    if (trace) cat("Starting CV fold #", i, sep="","\n")
    res <- cvf(i, X, y, fold, cv.args)
    # A fold fit may cover fewer lambda values than the full fit; only the
    # first res$nl columns are filled and the rest stay NA.
    cols <- seq_len(res$nl)
    Y[fold==i, cols] <- res$yhat
    E[fold==i, cols] <- res$loss
    if (fit$family=="binomial") PE[fold==i, cols] <- res$pe
  }

  # Eliminate saturated lambda values, if any.  drop=FALSE keeps E, Y and PE
  # as matrices even when a single lambda value remains.
  ind <- which(apply(is.finite(E), 2, all))
  E <- E[, ind, drop=FALSE]
  Y <- Y[, ind, drop=FALSE]
  lambda <- fit$lambda[ind]

  # Return
  cve <- apply(E, 2, mean)
  cvse <- apply(E, 2, sd) / sqrt(n)
  min <- which.min(cve)
  null.dev <- calcNullDev(X, y, group=XG$g, family=fit$family)

  val <- list(cve=cve, cvse=cvse, lambda=lambda, fit=fit, fold=fold, min=min, lambda.min=lambda[min], null.dev=null.dev)
  if (fit$family=="binomial") val$pe <- apply(PE[, ind, drop=FALSE], 2, mean)
  if (returnY) {
    if (fit$family=="gaussian") {
      val$Y <- Y + attr(y, "mean")
    } else {
      val$Y <- Y
    }
  }
  structure(val, class="cv.grpreg")
}
# Fit the model with fold `i` held out and score it on the held-out rows.
#
#   i        Label of the fold to hold out.
#   X, y     Design matrix and response, indexed in the same order as `fold`.
#   fold     Fold label for each row of X / element of y.
#   cv.args  Named list of further arguments passed to the fitting function.
#
# Returns a list with:
#   loss  value of loss.grpreg() for the held-out response and predictions
#   pe    for binomial fits, (yhat < 0.5) == y; with a 0/1 response this is
#         TRUE where the 0.5-threshold prediction is wrong.  NULL otherwise.
#   nl    number of lambda values in the fold's fit
#   yhat  response-scale predictions, one row per held-out observation
cvf <- function(i, X, y, fold, cv.args) {
  in_train <- fold != i
  held_out <- fold == i

  # Train on everything outside fold i
  cv.args$X <- X[in_train, , drop=FALSE]
  cv.args$y <- y[in_train]
  use_bridge <- 'penalty' %in% names(cv.args) && cv.args$penalty == 'gBridge'
  fitter <- if (use_bridge) "gBridge" else "grpreg"
  fit.i <- do.call(fitter, cv.args)
  fam <- fit.i$family

  # Predict for the held-out rows and evaluate
  X_test <- X[held_out, , drop=FALSE]
  y_test <- y[held_out]
  pred <- predict(fit.i, X_test, type="response")
  yhat <- matrix(pred, length(y_test))
  loss <- loss.grpreg(y_test, yhat, fam)

  pe <- NULL
  if (fam=="binomial") {
    pe <- (yhat < 0.5) == y_test
  }

  list(loss=loss, pe=pe, nl=length(fit.i$lambda), yhat=yhat)
}
pbreheny/grpreg documentation built on April 3, 2024, 3:53 p.m.