R/cvLM.R

Defines functions cvLM.glm cvLM.lm cvLM.formula cvLM .cvLM_fit

Documented in cvLM cvLM.formula cvLM.glm cvLM.lm

## cvLM.R: Fast cross-validation for linear and ridge regression models using RcppArmadillo
##
## This file is part of the cvLM package.

# Internal workhorse shared by every cvLM method.
#
# Validates the tuning arguments, optionally drops the intercept column,
# and then calls the compiled routine `cv.lm.rcpp()` once per requested
# number of folds.
#
# Arguments:
#   y           Response, as produced by stats::model.response().
#   X           Design matrix, as produced by stats::model.matrix().
#   K.vals      Requested numbers of folds (checked against nrow(X)).
#   lambda      Non-negative ridge shrinkage parameter.
#   generalized Logical scalar; TRUE requests generalized CV.
#   seed        Integer scalar seed forwarded to the C++ routine.
#   n.threads   Requested number of threads (validated by
#               .assert_valid_threads()).
#   tol         Non-negative threshold for the complete orthogonal
#               decomposition.
#   center      Logical scalar; whether the data are centered.
#   mt          Terms object; only its "intercept" attribute is read here.
#
# Returns:
#   A data.frame with columns `K` (folds used), `CV` (the value returned by
#   cv.lm.rcpp() for that K) and `seed`.
.cvLM_fit <- function(
  y,
  X,
  K.vals,
  lambda,
  generalized,
  seed,
  n.threads,
  tol,
  center,
  mt
) {
  # --- Confirm validity of arguments

  # Number of folds
  K.vals <- .assert_valid_kvals(K.vals, nrow(X))

  # Shrinkage parameter
  lambda <- .assert_double_scalar(lambda, "lambda", nonneg = TRUE)

  # Generalized boolean
  .assert_logical_scalar(generalized, "generalized")

  # Seed
  seed <- .assert_integer_scalar(seed, "seed", nonneg = FALSE)

  # Number of threads (-1 -> defaultNumThreads)
  n.threads <- .assert_valid_threads(n.threads)

  # Threshold for complete orthogonal decomposition
  tol <- .assert_double_scalar(tol, "tol", nonneg = TRUE)

  # Whether to center the data - affecting whether the intercept term is
  # penalized or not in the case of ridge regression (can also provide
  # different numbers for underdetermined OLS cases)
  .assert_logical_scalar(center, "center")

  # Drop the intercept term if we're centering the data
  if (center && attr(mt, "intercept") == 1L) {
    X <- .drop_intercept(X)
  }

  # Check for valid regression data before passing to C++
  .assert_valid_data(y, X)

  # If generalized, K doesn't matter so just set it to look like LOOCV since
  # it's an LOOCV shortcut
  if (generalized) {
    K.vals <- nrow(X)
  }

  # GCV and LOOCV aren't multithreaded
  if (all(K.vals == nrow(X))) {
    n.threads <- 1L
  }

  # Try to prevent oversubscription: pin BLAS to one thread for the duration
  # of this call and restore the previous setting on exit
  if (n.threads > 1L) {
    if (requireNamespace("RhpcBLASctl", quietly = TRUE)) {
      old.blas.threads <- RhpcBLASctl::blas_get_num_procs()
      RhpcBLASctl::blas_set_num_threads(1L)
      on.exit(RhpcBLASctl::blas_set_num_threads(old.blas.threads), add = TRUE)
    } else {
      # Each fragment ends with a space so the pasted message reads correctly
      warning(
        "Parallel execution requested, but 'RhpcBLASctl' is not installed. ",
        "Performance may be degraded if using a multithreaded BLAS ",
        "implementation. Install 'RhpcBLASctl' or use n.threads = 1 to ",
        "silence this warning.",
        call. = FALSE
      )
    }
  }

  # Pass off to C++ (never ask for more threads than there are folds)
  cvs <- vapply(
    K.vals,
    function(K) {
      cv.lm.rcpp(
        X = X,
        y = y,
        k0 = K,
        lambda = lambda,
        generalized = generalized,
        seed = seed,
        nThreads = min(K, n.threads),
        tolerance = tol,
        center = center
      )
    },
    numeric(1L),
    USE.NAMES = FALSE
  )

  data.frame(K = K.vals, CV = cvs, seed = seed)
}

# S3 generic: dispatches on the class of `object` to one of the cvLM.* methods
cvLM <- function(object, ...) {
  UseMethod("cvLM")
}

# Formula method
#
# Builds the model frame from `object` (a formula) together with `data`,
# `subset` and `na.action` in the same way lm() does, derives the design
# matrix and response from it, and hands everything to .cvLM_fit().
# Signals an error when the formula describes an empty model.
cvLM.formula <- function(
  object, data, subset, na.action,
  K.vals = 10L,
  lambda = 0,
  generalized = FALSE,
  seed = 1L,
  n.threads = 1L,
  tol = 1e-7,
  center = TRUE,
  ...
) {
  # --- Turn our own call into a stats::model.frame() call (as lm() does)

  # Keep the function slot plus whichever model-frame arguments were supplied
  frame.call <- match.call(expand.dots = FALSE)
  wanted <- c("object", "data", "subset", "na.action")
  frame.call <- frame.call[c(1L, match(wanted, names(frame.call), 0L))]

  # model.frame() calls its first argument `formula`, not `object`
  names(frame.call)[names(frame.call) == "object"] <- "formula"
  frame.call$drop.unused.levels <- TRUE
  frame.call[[1L]] <- quote(stats::model.frame)

  # Evaluate in the caller's frame so data/subset expressions resolve there
  frame <- eval(frame.call, parent.frame())
  model.terms <- attr(frame, "terms")

  if (stats::is.empty.model(model.terms)) {
    stop("Empty model specified.", call. = FALSE)
  }

  # --- Regression data

  design <- stats::model.matrix(model.terms, frame)
  response <- stats::model.response(frame, "double")

  .cvLM_fit(
    y = response,
    X = design,
    K.vals = K.vals,
    lambda = lambda,
    generalized = generalized,
    seed = seed,
    n.threads = n.threads,
    tol = tol,
    center = center,
    mt = model.terms
  )
}

# lm method
#
# Cross-validates the model described by a fitted `lm` object. The design
# matrix and response are rebuilt from the object's model frame; when `data`
# is supplied the frame is re-evaluated on that data, otherwise the frame of
# the fitted object itself is used.
#
# Weights and offsets stored on the fitted object are not supported: a
# warning is raised and they are ignored. Constant weights do not trigger the
# warning.
cvLM.lm <- function(
  object,
  data,
  K.vals = 10L,
  lambda = 0,
  generalized = FALSE,
  seed = 1L,
  n.threads = 1L,
  tol = 1e-7,
  center = TRUE,
  ...
) {
  # Raise warning for unsupported lm features (weights and offset)
  if (!is.null(object$weights) && length(unique(object$weights)) > 1L) {
    warning(
      "cvLM does not currently support weighted least squares. Weights will be ignored.",
      call. = FALSE
    )
  }

  if (!is.null(object$offset)) {
    warning(
      "cvLM does not currently support offsets. Offset will be ignored.",
      call. = FALSE
    )
  }

  # --- Extract data

  # Forwarding a missing `data` to model.frame() fails as soon as the method
  # forces its dots, so only pass it along when the caller supplied it
  mf <- if (missing(data)) {
    stats::model.frame(object)
  } else {
    stats::model.frame(object, data = data)
  }
  mt <- attr(mf, "terms")
  X <- stats::model.matrix(mt, mf)
  y <- stats::model.response(mf, "double")

  .cvLM_fit(
    y = y,
    X = X,
    K.vals = K.vals,
    lambda = lambda,
    generalized = generalized,
    seed = seed,
    n.threads = n.threads,
    tol = tol,
    center = center,
    mt = mt
  )
}

# glm method
#
# Accepts a fitted `glm` object only when .is_lm() reports TRUE for it; in
# that case the work is delegated to the next method in the class chain via
# NextMethod(). Any other glm is rejected with an error.
cvLM.glm <- function(
  object,
  data,
  K.vals = 10L,
  lambda = 0,
  generalized = FALSE,
  seed = 1L,
  n.threads = 1L,
  tol = 1e-7,
  center = TRUE,
  ...
) {
  # Supported case first: hand over to the next method with the same arguments
  if (.is_lm(object)) {
    return(NextMethod("cvLM"))
  }

  stop(
    "cvLM only performs cross-validation for linear and ridge regression models.",
    call. = FALSE
  )
}

Try the cvLM package in your browser

Any scripts or data that you put into this service are public.

cvLM documentation built on Feb. 3, 2026, 5:06 p.m.