R/CKNNRLD.R

Defines functions CKNNRLD

Documented in CKNNRLD

#' Cluster-based KNN Regression for Longitudinal Data (CKNNRLD)
#'
#' This function implements a clustering-based KNN regression method designed
#' for longitudinal datasets. It first clusters the training trajectories `y`
#' using longitudinal clustering (via latrend) and then performs KNN regression
#' (via `KNNRLD()`) within the cluster assigned to each new observation.
#'
#' Each row of `xnew` is assigned to the non-empty cluster whose predictor
#' centroid (column means of `x` over that cluster's training subjects) is
#' nearest in Euclidean distance. When `xnew` is identical to `x`, the training
#' cluster assignments are reused instead.
#'
#' @param xnew A matrix or data frame of predictor values for the new
#'   observations, one row per observation and `ncol(x)` columns. A plain
#'   vector is read row-wise as consecutive groups of `ncol(x)` values (so a
#'   vector of length `ncol(x)` is a single observation).
#' @param y A matrix or data frame of longitudinal responses (subjects x timepoints).
#' @param x A matrix or data frame of predictors for training data, one row per
#'   row of `y`.
#' @param k Number of nearest neighbors to use. Can be a scalar (same k for all
#'   clusters) or a vector with at least `c` entries, where `k[i]` is used for
#'   cluster `i`. Within each cluster, k is capped at the cluster size.
#' @param c Number of clusters for longitudinal clustering.
#' @param cluster_method Clustering method to use. Currently supports "kml" (default).
#'
#' @return A data frame with one row per observation in `xnew`. Columns:
#'   `cluster` (the assigned cluster index, 1 to `c`), then Y1, Y2, ..., Yd,
#'   the predicted responses (where d = number of timepoints, `ncol(y)`).
#'
#' @examples
#' \donttest{
#' set.seed(123)
#' n <- 30
#' n_time <- 3
#' d <- 2
#' x <- matrix(runif(n * d), nrow = n)
#' y <- matrix(rnorm(n * n_time), nrow = n)
#' train_idx <- sample(1:n, 20)
#' test_idx <- setdiff(1:n, train_idx)
#' result <- CKNNRLD(
#'   x = x[train_idx, ],
#'   y = y[train_idx, ],
#'   xnew = x[test_idx, ],
#'   k = 3,
#'   c = 2
#' )
#' head(result)
#' }
#'
#' @importFrom latrend lcMethodKML latrend trajectoryAssignments
#' @importFrom Rfast dista colmeans
#'
#' @export
CKNNRLD <- function(xnew, y, x, k = 5, c = 4, cluster_method = "kml") {

  y <- as.matrix(y)
  x <- as.matrix(x)

  if (nrow(y) != nrow(x)) {
    stop("Number of rows in y and x must be equal", call. = FALSE)
  }

  n_subjects <- nrow(y)
  n_timepoints <- ncol(y)
  n_predictors <- ncol(x)

  # as.matrix() would turn a bare vector into a single *column*; with several
  # predictors a vector has to be read as row-wise observations instead.
  if (is.null(dim(xnew)) && n_predictors > 1) {
    if (length(xnew) %% n_predictors != 0) {
      stop("Length of xnew must be a multiple of ncol(x)", call. = FALSE)
    }
    xnew <- matrix(xnew, ncol = n_predictors, byrow = TRUE)
  }
  xnew <- as.matrix(xnew)
  if (ncol(xnew) != n_predictors) {
    stop("xnew must have the same number of columns as x", call. = FALSE)
  }

  # A per-cluster k is indexed by cluster number, so it needs an entry for
  # every cluster; otherwise k[i] would silently be NA.
  if (length(k) > 1 && length(k) < c) {
    stop("k must be a scalar or have at least c entries (one per cluster)",
         call. = FALSE)
  }

  # Long format ordered by subject, then time. t(y) flattened column-major is
  # subject-major, so row (j - 1) * n_timepoints + tp holds y[j, tp] and is
  # labelled with subject j and time tp.
  long_data <- data.frame(
    id = rep(seq_len(n_subjects), each = n_timepoints),
    time = rep(seq_len(n_timepoints), times = n_subjects),
    value = as.vector(t(y))
  )

  if (cluster_method == "kml") {
    method_obj <- lcMethodKML(
      response = "value",
      time = "time",
      id = "id",
      nClusters = c,
      save.data = FALSE
    )
  } else {
    stop(paste("Method", cluster_method, "is not implemented. Use 'kml'."),
         call. = FALSE)
  }

  model <- latrend(method_obj, data = long_data, verbose = FALSE)
  # latrend reports memberships as a factor of cluster labels; its integer
  # codes (1..c) are used as cluster indices throughout.
  cluster_assignments <- as.integer(trajectoryAssignments(model))

  if (length(cluster_assignments) != n_subjects) {
    stop("Cluster assignment length doesn't match number of rows in y",
         call. = FALSE)
  }

  # Training rows belonging to each cluster; a cluster may end up empty.
  members <- lapply(seq_len(c), function(i) which(cluster_assignments == i))

  if (identical(xnew, x)) {
    # Predicting on the training data itself: keep the fitted memberships.
    xnew_cluster <- cluster_assignments
  } else if (nrow(xnew) == 0) {
    xnew_cluster <- integer(0)
  } else {
    # Predictor centroid of each non-empty cluster. Empty clusters are left
    # out so no observation is routed to a cluster without training data.
    usable <- which(lengths(members) > 0)
    centers <- do.call(rbind, lapply(
      members[usable],
      function(rows) colMeans(x[rows, , drop = FALSE])
    ))
    nearest <- Rfast::dista(xnew, centers, type = "euclidean", k = 1,
                            index = TRUE, trans = TRUE)
    # Map positions within `usable` back to the original cluster indices.
    xnew_cluster <- usable[as.vector(nearest)]
  }

  final_predictions <- matrix(NA_real_, nrow = nrow(xnew), ncol = n_timepoints)
  for (i in seq_len(c)) {
    new_rows <- which(xnew_cluster == i)
    if (length(new_rows) == 0) {
      next
    }
    train_rows <- members[[i]]
    k_i <- if (length(k) > 1) k[i] else k

    knn_result <- KNNRLD(
      xnew = xnew[new_rows, , drop = FALSE],
      y = y[train_rows, , drop = FALSE],
      x = x[train_rows, , drop = FALSE],
      k = min(k_i, length(train_rows))
    )
    # as.matrix() keeps the row assignment numeric even if KNNRLD's first
    # element is a data frame rather than a matrix.
    final_predictions[new_rows, ] <- as.matrix(knn_result[[1]])
  }

  result <- data.frame(cluster = xnew_cluster, final_predictions)
  colnames(result)[-1] <- paste0("Y", seq_len(n_timepoints))

  result
}

Try the CKNNRLD package in your browser

Any scripts or data that you put into this service are public.

CKNNRLD documentation built on May 29, 2026, 1:06 a.m.