# tests/testthat/test-grid.search.R

# --- Global helpers and shared data

gen.data <- function(n, p, sparse.p, noise.sd, rank.def = FALSE) {
  # Simulate a linear-regression data set with n rows and p features.
  #
  # Features come from the grid {-10.00, -9.99, ..., 9.99}: large random
  # integers are reduced modulo 2000 and rescaled. The coefficients are p
  # distinct values from {0.01, ..., 5.00}, and `sparse.p` of them are then
  # zeroed. Sampling is without replacement, so p must not exceed 500 and
  # sparse.p must not exceed p.
  #
  # When `rank.def` is TRUE and p > 2, the last feature is overwritten with
  # the sum of the first two, which makes the design rank-deficient.
  # Gaussian noise with standard deviation `noise.sd` is added to the linear
  # predictor.
  #
  # Returns a data.frame whose first column is the response `y`, followed by
  # the p feature columns.
  #
  # The RNG calls below must keep this order so that a given seed always
  # reproduces the same data.
  X <- matrix(
    (sample.int(1e7, n * p, replace = TRUE) %% 2000) / 100 - 10,
    nrow = n,
    ncol = p
  )

  coefs <- sample.int(500, p) / 100
  coefs[sample.int(p, sparse.p)] <- 0

  if (rank.def && p > 2) {
    X[, p] <- X[, 1] + X[, 2]
  }

  response <- as.numeric(X %*% coefs + rnorm(n, sd = noise.sd))
  cbind(data.frame(y = response), X)
}

muffle <- function(expr) {
  # Evaluate `expr` and silence only the warnings whose message contains
  # "K has been changed". Every other warning is left to propagate normally.
  # A calling handler (not tryCatch) is used so evaluation resumes after the
  # muffled warning and the value of `expr` is returned.
  mentions.k.change <- function(cond) {
    grepl("K has been changed", cond$message, fixed = TRUE)
  }

  handle.warning <- function(cond) {
    if (!mentions.k.change(cond)) {
      return(invisible(NULL))
    }
    invokeRestart("muffleWarning")
  }

  withCallingHandlers(expr, warning = handle.warning)
}

# --- Scenario data sets
#
# Naming convention: df.<shape>.<expected shrinkage>
#   *.min - expected to select a lambda at or near zero (if not exactly zero)
#   *.mid - expected to select a lambda in the middle of the search range
#   *.max - expected to select max.lambda itself
#
# Every data set is drawn from the same seeded RNG stream, so the order in
# which they are created below determines their contents.

set.seed(213)

# Narrow designs (n > p)
df.narrow.min <- gen.data(n = 183, p = 5, sparse.p = 0, noise.sd = 1e-4)
df.narrow.mid <- gen.data(n = 126, p = 20, sparse.p = 10, noise.sd = 10)
df.narrow.max <- gen.data(n = 181, p = 32, sparse.p = 30, noise.sd = 1e4)

# Wide designs (p > n)
df.wide.min <- gen.data(n = 51, p = 103, sparse.p = 1, noise.sd = 1e-3)
df.wide.mid <- gen.data(n = 75, p = 85, sparse.p = 35, noise.sd = 25)
df.wide.max <- gen.data(n = 56, p = 181, sparse.p = 176, noise.sd = 1e4)

# Rank-deficient design: the last column is the sum of the first two
df.rd.mid <- gen.data(
  n = 132, p = 4, sparse.p = 2, noise.sd = 3, rank.def = TRUE
)

# --- Test parameters

# Fold counts to try; the tests replace NA with nrow(data), i.e. LOOCV
K.vals <- c(2, 5, 10, NA)
# Seed forwarded to grid.search() / cvLM() so fold assignments match
seed <- 73568569
max.lambda <- 100
precision <- 0.5

scenarios <- list(
  df.narrow.min, df.narrow.mid, df.narrow.max,
  df.wide.min, df.wide.mid, df.wide.max,
  df.rd.mid
)

# --- Run tests

test_that("grid.search matches brute-force cvLM sweep", {
  # Can take a long time to run
  skip_on_cran()

  # Simulate grid.search's lambda grid: 0, precision, 2 * precision, ...,
  # with max.lambda appended when the step does not land on it exactly.
  lambdas <- seq(0, max.lambda, by = precision)

  if (max.lambda != tail(lambdas, 1)) {
    lambdas <- c(lambdas, max.lambda)
  }

  for (i in seq_along(scenarios)) {
    data.set <- scenarios[[i]]

    for (K in K.vals) {
      # NA in K.vals stands for leave-one-out CV (one fold per row)
      if (is.na(K)) {
        K <- nrow(data.set)
      }

      # The generalized option is only exercised for LOOCV
      is.loocv <- K == nrow(data.set)
      generalized.opts <- if (is.loocv) c(FALSE, TRUE) else FALSE

      for (generalized in generalized.opts) {
        for (center in c(FALSE, TRUE)) {
          # Identifies the failing configuration in expectation output
          label <- paste0(
            "scenario ", i, ", K = ", K,
            ", generalized = ", generalized, ", center = ", center
          )

          common.args <- list(
            y ~ .,
            data = data.set,
            generalized = generalized,
            seed = seed,
            center = center
          )

          grid.res <- muffle(do.call(
            grid.search,
            c(
              common.args,
              list(K = K, max.lambda = max.lambda, precision = precision)
            )
          ))

          # Brute force: evaluate the CV error at every grid point
          manual.cvs <- vapply(
            lambdas,
            function(lambda) {
              muffle(do.call(
                cvLM,
                c(common.args, list(K.vals = K, lambda = lambda))
              ))$CV
            },
            numeric(1)
          )

          best.idx <- which.min(manual.cvs)

          expect_equal(grid.res$CV, manual.cvs[best.idx], info = label)
          expect_equal(grid.res$lambda, lambdas[best.idx], info = label)
        }
      }
    }
  }
})

test_that("grid.search results are agnostic to the number of threads", {
  # Skip multithreaded tests on CRAN
  skip_on_cran()

  # Use at least two threads so the multithreaded path is exercised even when
  # RcppParallel's default thread count is 1.
  multi.threads <- max(RcppParallel::defaultNumThreads(), 2L)

  for (i in seq_along(scenarios)) {
    data.set <- scenarios[[i]]

    for (K in K.vals) {
      # NA in K.vals stands for leave-one-out CV (one fold per row)
      if (is.na(K)) {
        K <- nrow(data.set)
      }

      # The generalized option is only exercised for LOOCV
      is.loocv <- K == nrow(data.set)
      generalized.opts <- if (is.loocv) c(FALSE, TRUE) else FALSE

      for (generalized in generalized.opts) {
        for (center in c(FALSE, TRUE)) {
          # Identifies the failing configuration in expectation output
          label <- paste0(
            "scenario ", i, ", K = ", K,
            ", generalized = ", generalized, ", center = ", center
          )

          common.args <- list(
            y ~ .,
            data = data.set,
            generalized = generalized,
            seed = seed,
            center = center,
            K = K,
            max.lambda = max.lambda,
            precision = precision
          )

          res.single <- muffle(do.call(
            grid.search,
            c(common.args, list(n.threads = 1L))
          ))

          res.multiple <- muffle(do.call(
            grid.search,
            c(common.args, list(n.threads = multi.threads))
          ))

          # Results may not be exactly identical because floating point
          # addition is not associative, so CV is compared with tolerance
          expect_equal(res.single$CV, res.multiple$CV, info = label)
          expect_identical(
            res.single$lambda,
            res.multiple$lambda,
            info = label
          )
        }
      }
    }
  }
})

# Try the cvLM package in your browser
#
# Any scripts or data that you put into this service are public.
#
# cvLM documentation built on Feb. 3, 2026, 5:06 p.m.