R/Tuning_KNNRLD.R

Defines functions KNNRLD.tune

Documented in KNNRLD.tune

#' Tune k in KNNRLD using Cross-Validation
#'
#' Finds the optimal number of neighbors for KNN regression in longitudinal data
#' using k-fold cross-validation. Evaluates k values from 2 to A.
#'
#' @param y Matrix of longitudinal outcomes (subjects x timepoints).
#' @param x Matrix of predictor variables (subjects x features). Must have the
#'   same number of rows as \code{y}.
#' @param nfolds Number of cross-validation folds (default = 10).
#' @param folds Optional list of pre-specified fold indices. If provided,
#'   nfolds and seed are ignored.
#' @param seed Optional random seed for reproducibility. Only used when the
#'   folds are generated internally.
#' @param A Maximum number of neighbors to evaluate (searches from 2 to A,
#'   default = 10). Must be at least 2. If A exceeds the number of subjects it
#'   is reduced to that number with a warning.
#' @param graph Logical; if TRUE, plots MSPE vs. k.
#'
#' @return A list containing:
#' \item{crit}{Mean squared prediction error (MSPE) for each k value, named
#'   \code{"k = 2"}, \code{"k = 3"}, ...}
#' \item{best_k}{Optimal number of neighbors (minimizes MSPE)}
#' \item{performance}{Minimum MSPE value}
#' \item{runtime}{Elapsed computation time of the cross-validation loop}
#'
#' @examples
#' \donttest{
#' set.seed(123)
#' n <- 30
#' n_time <- 3
#' d <- 2
#' x <- matrix(runif(n * d), nrow = n)
#' y <- matrix(rnorm(n * n_time), nrow = n)
#' tune_result <- KNNRLD.tune(
#'   y = y,
#'   x = x,
#'   nfolds = 3,
#'   A = 4
#' )
#' str(tune_result)
#' }
#'
#' @importFrom Directional makefolds
#' @importFrom Rfast colmeans
#' @importFrom graphics plot abline points
#'
#' @export
KNNRLD.tune <- function(y, x, nfolds = 10, folds = NULL, seed = NULL, A = 10, graph = FALSE) {

  y <- as.matrix(y)
  x <- as.matrix(x)
  n <- nrow(y)

  # ---- Input validation ----
  if (!is.numeric(A) || length(A) != 1L || is.na(A)) {
    stop("A must be a single non-missing numeric value")
  }
  if (A < 2) {
    stop("A must be at least 2")
  }
  # Rows of x and y are matched by position, so a mismatch would either fail
  # later with an opaque subscript error or silently ignore extra rows of x.
  if (nrow(x) != n) {
    stop(
      "x and y must have the same number of rows (x has ", nrow(x),
      ", y has ", n, ")"
    )
  }
  # With fewer than 2 subjects the cap below would push A under 2 and `2:A`
  # would count downwards, producing a meaningless k grid.
  if (n < 2) {
    stop("At least 2 subjects are required to tune k")
  }
  # NOTE(review): the cap uses the total number of subjects, not the size of
  # the training part of each fold -- confirm how KNNRLD behaves when k is
  # larger than the number of training rows.
  if (A > n) {
    warning(paste("A (", A, ") exceeds number of samples (", n, "). Reducing to", n))
    A <- n
  }

  # ---- Fold construction ----
  # Folds are only generated (and `seed` only used) when none were supplied.
  if (is.null(folds)) {
    ina <- seq_len(n)
    folds <- Directional::makefolds(ina, nfolds = nfolds, stratified = FALSE, seed = seed)
  }

  # The effective number of folds is whatever the fold list contains.
  nfolds <- length(folds)
  k_values <- 2:A
  n_k <- length(k_values)

  # One row per fold, one column per candidate k.
  fold_errors <- matrix(nrow = nfolds, ncol = n_k)

  runtime <- proc.time()

  # ---- Cross-validation ----
  for (fold_idx in seq_len(nfolds)) {
    test_indices <- folds[[fold_idx]]
    train_indices <- setdiff(seq_len(n), test_indices)

    # drop = FALSE keeps single-row / single-column subsets as matrices.
    y_test <- y[test_indices, , drop = FALSE]
    y_train <- y[train_indices, , drop = FALSE]
    x_test <- x[test_indices, , drop = FALSE]
    x_train <- x[train_indices, , drop = FALSE]

    # All candidate k values are passed in one call; the result is indexed
    # below by the position of k in `k_values`.
    predictions <- KNNRLD(xnew = x_test, y = y_train, x = x_train, k = k_values)

    for (j in seq_len(n_k)) {
      fold_errors[fold_idx, j] <- mean((predictions[[j]] - y_test)^2, na.rm = TRUE)
    }
  }

  runtime <- proc.time() - runtime

  # ---- Aggregate over folds and pick the best k ----
  mspe <- Rfast::colmeans(fold_errors)
  names(mspe) <- paste0("k = ", k_values)

  best_k_idx <- which.min(mspe)
  best_k <- k_values[best_k_idx]
  best_performance <- mspe[best_k_idx]

  if (isTRUE(graph)) {
    graphics::plot(
      k_values, mspe,
      xlab = "Number of Nearest Neighbors (k)",
      ylab = "Mean Squared Prediction Error (MSPE)",
      type = "b",
      pch = 16,
      cex.axis = 1.2,
      cex.lab = 1.2,
      col = "darkgreen",
      lwd = 2,
      main = "KNN Tuning for Longitudinal Data"
    )
    graphics::abline(v = k_values, col = "lightgrey", lty = 2)
    graphics::abline(h = pretty(mspe), lty = 2, col = "lightgrey")
    # Highlight the selected k.
    graphics::points(best_k, best_performance, col = "red", pch = 19, cex = 1.5)
  }

  list(
    crit = mspe,
    best_k = best_k,
    performance = best_performance,
    # Third element of proc.time() is the elapsed (wall-clock) time.
    runtime = runtime[3]
  )
}

Try the CKNNRLD package in your browser

Any scripts or data that you put into this service are public.

CKNNRLD documentation built on May 29, 2026, 1:06 a.m.