Nothing
#' Cluster-based KNN Regression for Longitudinal Data (CKNNRLD)
#'
#' This function implements a clustering-based KNN regression method designed
#' for longitudinal datasets. It first clusters the training subjects on their
#' response trajectories using longitudinal clustering (via latrend), assigns
#' every row of `xnew` to the cluster whose predictor centroid is nearest, and
#' then performs KNN regression (via `KNNRLD`) within each selected cluster.
#'
#' @details
#' If `xnew` is identical to `x`, the training cluster assignments are reused
#' instead of the nearest-centroid rule. Clusters that receive no training
#' subjects are ignored when assigning rows of `xnew`. Within a cluster, the
#' number of neighbours is capped at the number of training subjects in that
#' cluster.
#'
#' @param xnew A matrix of predictor values for test data (new observations).
#' Must have the same number of columns as `x`; pass a single observation as
#' a one-row matrix.
#' @param y A matrix or data frame of longitudinal responses
#' (subjects x timepoints).
#' @param x A matrix or data frame of predictors for training data, with one
#' row per row of `y`.
#' @param k Number of nearest neighbors to use. Can be a scalar (same k for
#' all clusters) or a vector with at least `c` entries (entry `i` is used for
#' cluster `i`).
#' @param c Number of clusters for longitudinal clustering.
#' @param cluster_method Clustering method to use. Currently supports "kml"
#' (default).
#'
#' @return A data frame with predicted values and cluster assignment for each
#' observation in `xnew`.
#' Columns: cluster, Y1, Y2, ..., Yd (where d = number of timepoints).
#'
#' @examples
#' \donttest{
#' set.seed(123)
#' n <- 30
#' n_time <- 3
#' d <- 2
#' x <- matrix(runif(n * d), nrow = n)
#' y <- matrix(rnorm(n * n_time), nrow = n)
#' train_idx <- sample(1:n, 20)
#' test_idx <- setdiff(1:n, train_idx)
#' result <- CKNNRLD(
#'   x = x[train_idx, ],
#'   y = y[train_idx, ],
#'   xnew = x[test_idx, ],
#'   k = 3,
#'   c = 2
#' )
#' head(result)
#' }
#'
#' @importFrom latrend lcMethodKML latrend trajectoryAssignments
#' @importFrom Rfast dista colmeans
#'
#' @export
CKNNRLD <- function(xnew, y, x, k = 5, c = 4, cluster_method = "kml") {
  y <- as.matrix(y)
  x <- as.matrix(x)
  xnew <- as.matrix(xnew)

  # ---- Input validation -----------------------------------------------------
  if (nrow(y) != nrow(x)) {
    stop("Number of rows in y and x must be equal")
  }
  if (ncol(xnew) != ncol(x)) {
    stop("Number of columns in xnew and x must be equal")
  }
  if (length(k) > 1 && length(k) < c) {
    # Otherwise k[i] would be NA for the later clusters.
    stop("When k is a vector it must supply one value per cluster")
  }
  if (cluster_method != "kml") {
    stop(paste("Method", cluster_method, "is not implemented. Use 'kml'."))
  }

  n_subjects <- nrow(y)
  n_timepoints <- ncol(y)
  n_predictors <- ncol(x)

  # ---- Longitudinal clustering of the training responses --------------------
  # Long format in column-major order of y: all subjects at time 1, then all
  # subjects at time 2, ... so that row (t - 1) * n_subjects + s holds y[s, t].
  long_data <- data.frame(
    id = rep(seq_len(n_subjects), times = n_timepoints),
    time = rep(seq_len(n_timepoints), each = n_subjects),
    value = as.vector(y)
  )
  method_obj <- lcMethodKML(
    response = "value",
    time = "time",
    id = "id",
    nClusters = c,
    save.data = FALSE
  )
  model <- latrend(method_obj, data = long_data, verbose = FALSE)

  # Work with the integer codes of the assignment factor (1, ..., c).
  cluster_assignments <- as.integer(trajectoryAssignments(model))
  if (length(cluster_assignments) != n_subjects) {
    stop("Cluster assignment length doesn't match number of rows in y")
  }

  # ---- Predictor-space centroid of every cluster ----------------------------
  cluster_sizes <- tabulate(cluster_assignments, nbins = c)
  non_empty <- which(cluster_sizes > 0)
  cluster_centers <- matrix(NA_real_, nrow = c, ncol = n_predictors)
  for (i in non_empty) {
    members <- cluster_assignments == i
    cluster_centers[i, ] <- colMeans(x[members, , drop = FALSE])
  }

  # ---- Assign each new observation to a cluster -----------------------------
  if (!identical(xnew, x)) {
    if (length(non_empty) == 1) {
      new_clusters <- rep(non_empty, nrow(xnew))
    } else {
      # Only non-empty clusters are candidates; map the returned positions
      # back to the original cluster labels.
      nearest <- Rfast::dista(
        xnew,
        cluster_centers[non_empty, , drop = FALSE],
        trans = TRUE,
        k = 1,
        index = TRUE
      )
      new_clusters <- non_empty[as.vector(nearest)]
    }
  } else {
    # Predicting on the training data itself: reuse the fitted assignments.
    new_clusters <- cluster_assignments
  }

  # ---- KNN regression within each cluster -----------------------------------
  final_predictions <- matrix(
    NA_real_,
    nrow = nrow(xnew),
    ncol = n_timepoints
  )
  for (i in seq_len(c)) {
    cluster_rows <- which(new_clusters == i)
    if (length(cluster_rows) == 0) {
      next
    }
    members <- cluster_assignments == i
    k_i <- if (length(k) > 1) k[i] else k
    # Never ask for more neighbours than the cluster has training subjects.
    k_used <- min(k_i, sum(members))
    knn_result <- KNNRLD(
      xnew = xnew[cluster_rows, , drop = FALSE],
      y = y[members, , drop = FALSE],
      x = x[members, , drop = FALSE],
      k = k_used
    )
    # The first element of the KNNRLD result holds the prediction matrix.
    final_predictions[cluster_rows, ] <- knn_result[[1]]
  }

  result <- data.frame(
    cluster = as.numeric(new_clusters),
    final_predictions
  )
  colnames(result)[-1] <- paste0("Y", seq_len(n_timepoints))
  result
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.