# R/temporal_forest.R

#' Temporal Forest for Longitudinal Feature Selection
#'
#' @description
#' The main user-facing function for the `TemporalForest` package. It performs the
#' complete three-stage algorithm to select a top set of features from
#' high-dimensional longitudinal data.
#'
#' @details
#' The function executes a three-stage process:
#' \enumerate{
#'   \item **Time-Aware Module Construction:** Builds a consensus network across time points to identify modules of stably co-correlated features.
#'   \item **Within-Module Screening:** Uses bootstrapped mixed-effects model trees (`glmertree`) to screen for important predictors within each module.
#'   \item **Stability Selection:** Performs a final stability selection step on the surviving features to yield a reproducible final set.
#' }
#'
#' **Unbalanced Panels:** The algorithm accepts unbalanced panel data (i.e., subjects with missing time points). The consensus TOM is built from the time-point matrices that are supplied, and the mixed-effects models accommodate subjects with different numbers of observations.
#'
#' **Outcome Family:** The current version is designed for **Gaussian (continuous) outcomes**, as it relies on `glmertree::lmertree`. Support for other outcome families is not yet implemented.
#'
#' **Reproducibility (Determinism):** For reproducible results, it is recommended to set a seed using `set.seed()` before running. The algorithm has both stochastic and deterministic components:
#' \itemize{
#'   \item \strong{Stochastic} (depends on `set.seed()`): The bootstrap resampling of subjects in both the screening and selection stages.
#'   \item \strong{Deterministic} (does not depend on `set.seed()`): The network construction process (correlation, adjacency, and TOM calculation).
#' }
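#'
#' The package does not export its resampling routine. The sketch below assumes a subject-level (cluster) bootstrap, in which subjects are drawn with replacement and all of a drawn subject's rows are kept. The helper `resample_subjects()` is hypothetical and only illustrates why calling `set.seed()` makes this stochastic stage repeatable.
#'
#' ```r
#' # Hypothetical helper (not part of the package): one subject-level bootstrap draw
#' resample_subjects <- function(id) {
#'   subjects <- unique(id)
#'   # Draw subjects with replacement; indexing avoids the sample() length-1 pitfall
#'   drawn <- subjects[sample.int(length(subjects), replace = TRUE)]
#'   # Keep every row of each drawn subject, in draw order
#'   unlist(lapply(drawn, function(s) which(id == s)), use.names = FALSE)
#' }
#'
#' id <- rep(1:5, times = 3)   # 5 subjects x 3 time points
#' set.seed(42); rows_a <- resample_subjects(id)
#' set.seed(42); rows_b <- resample_subjects(id)
#' identical(rows_a, rows_b)   # TRUE: same seed, same resample
#' ```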
#'
#' @param X A list of numeric matrices, one for each time point. The rows of each
#'   matrix are subjects and the columns are predictors. Always required: even when
#'   `dissimilarity_matrix` is provided (so Stage 1 is skipped), `X` supplies the
#'   predictor values used in Stages 2 and 3.
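#' @param Y A numeric vector for the longitudinal outcome, with one entry per row of
#'   the row-stacked predictor matrices in `X`.
#' @param id A vector of subject identifiers, aligned row by row with `Y`.
#' @param time A vector of time point indicators, aligned row by row with `Y`.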
#' @param dissimilarity_matrix An optional pre-computed dissimilarity matrix (e.g., `1 - TOM`).
#'   If provided, the network construction step (Stage 1) is skipped. The matrix must be
#'   square with predictor names as rownames and colnames. Defaults to `NULL`.
#' @param n_features_to_select The number of top features to return in the final selection.
#'   Passed to the `number_selected_final` argument of the internal function
#'   `TemporalTree_time()`. Defaults to 10.
#' @param min_module_size The minimum number of features in a module. Passed to the
#'   `minClusterSize` argument of `TemporalTree_time()`. Defaults to 4.
#' @param n_boot_screen The number of bootstrap repetitions for the initial screening
#'   stage within modules. Defaults to 50.
#' @param keep_fraction_screen The proportion of features to keep from each module
#'   during the screening stage. Defaults to 0.25.
#' @param n_boot_select The number of bootstrap repetitions for the final stability
#'   selection stage. Defaults to 100.
#' @param alpha_screen The significance level for splitting in the screening stage trees.
#'   Defaults to 0.2.
#' @param alpha_select The significance level for splitting in the selection stage trees.
#'   Defaults to 0.05.
#'
#' @section Input contract:
#' - **X**: list of numeric matrices, one per time point; **columns (names and
#'   order) must be identical across all time points**. The function does not
#'   reorder or reconcile columns.
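#' - **Row order / binding rule**: `X` is stacked internally by row-binding the
#'   matrices in list order. This produces **time-major × subject-minor** order:
#'   all rows of `X[[1]]` first, then all rows of `X[[2]]`, and so on. `Y`, `id`,
#'   and `time` must already follow this same row order. The function does **not**
#'   re-order subjects or time.
#' - **Y, id, time**: vectors of equal length, matching the total number of rows
#'   across `X`. `id` and `time` may be integer/character/factor. `time` is coerced
#'   to a numeric sequence via `as.numeric(as.factor(time))`, which yields the rank
#'   of each value among the sorted unique levels. As a result, unequal spacing
#'   between time points is not preserved.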
#' - **Missing values**: this function does **not** perform NA filtering or
#'   imputation. Pre-clean the data (e.g., `keep <- complete.cases(Y, id, time)`)
#'   and apply the same filtering to the corresponding rows of `X`.
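#'
#' The toy example below uses made-up values. It row-binds a list of per-time matrices the same way the function does, so the resulting row order can be inspected directly. The last two lines show two consequences of the `as.numeric(as.factor(time))` recoding.
#'
#' ```r
#' # Two time points, three subjects; columns identical across time points
#' X <- list(
#'   matrix(1:6, nrow = 3, dimnames = list(NULL, c("V1", "V2"))),
#'   matrix(7:12, nrow = 3, dimnames = list(NULL, c("V1", "V2")))
#' )
#' stacked <- do.call(rbind, lapply(X, as.data.frame))
#' # Rows come out time-major: all subjects at time 1, then all at time 2
#' id   <- rep(1:3, times = length(X))
#' time <- rep(seq_along(X), each = 3)
#' stopifnot(nrow(stacked) == length(id))
#'
#' # The time recoding keeps only the order of the values
#' as.numeric(as.factor(c(0, 6, 24)))        # 1 2 3: spacing is lost
#' as.numeric(as.factor(c("2", "10", "2")))  # 2 1 2: character levels sort as text
#' ```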
#'
#' @section Unbalanced panels:
#' Missing time points per subject are allowed **provided the user supplies
#' `X`, `Y`, `id`, `time` that already align under the binding rule above**.
#' Stage 1 builds a TOM at the feature level for each available time-point
#' matrix; the **consensus TOM** is the element-wise minimum across time points.
#' Subject-level missingness at a given time does not prevent feature-wise
#' similarity from being computed at other times. This function does not perform
#' any subject-level alignment across time.
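#'
#' The package itself calls `WGCNA::adjacency()` and `WGCNA::TOMsimilarity()`. The base-R sketch below instead uses the textbook TOM formula on a signed adjacency `((1 + cor) / 2)^power`, and its numbers may differ in detail from WGCNA's implementation. It is meant only to show the consensus step: an element-wise minimum over the per-time TOMs. The helper `tom_sketch()` is hypothetical.
#'
#' ```r
#' # Hypothetical helper: TOM of a signed, nonnegative adjacency matrix
#' tom_sketch <- function(x, power = 6) {
#'   adj <- ((1 + stats::cor(x)) / 2)^power   # signed adjacency in [0, 1]
#'   diag(adj) <- 0
#'   k <- rowSums(adj)                         # connectivity of each feature
#'   shared <- crossprod(adj)                  # shared neighbours (adj is symmetric)
#'   tom <- (shared + adj) / (outer(k, k, pmin) + 1 - adj)
#'   diag(tom) <- 1
#'   tom
#' }
#'
#' set.seed(1)
#' X <- replicate(3, matrix(rnorm(40 * 5), 40, 5), simplify = FALSE)
#' toms <- lapply(X, tom_sketch)
#' consensus <- Reduce(pmin, toms)   # element-wise minimum across time points
#' all.equal(consensus, apply(simplify2array(toms), c(1, 2), min))  # TRUE
#' ```
#'
#' `Reduce(pmin, ...)` gives the same result as the array-plus-`apply()` form and avoids allocating the full p x p x T array.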
#'
#' @section Outcome family:
#' Current version targets **Gaussian** outcomes via `glmertree::lmertree`.
#' Other families (e.g., binomial/Poisson) are not supported in this version.
#'
#' @section Stability selection and thresholds:
#' Final selection is **top-K** by bootstrap frequency (K = `n_features_to_select`).
#' A probability cutoff (e.g., `pi_thr`) is **not** used and selection
#' probabilities are **not returned** in the current API.
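#'
#' The exact tabulation and tie-breaking inside `TemporalTree_time()` are not documented here. The following is a minimal sketch of top-K selection by bootstrap frequency. It assumes that each bootstrap yields a set of selected features and that ties are broken alphabetically; both are assumptions, and the input list is made up.
#'
#' ```r
#' # Hypothetical per-bootstrap selections (one character vector per resample)
#' boot_selections <- list(c("V1", "V2"), c("V1", "V3"), c("V2", "V1"), c("V1"))
#' freq <- table(unlist(boot_selections))   # times each feature was selected
#' K <- 2
#' # Rank by frequency (descending); break ties alphabetically for determinism
#' ranked <- names(freq)[order(-as.vector(freq), names(freq))]
#' top_k <- head(ranked, K)
#' top_k   # "V1" "V2"
#' ```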
#'
#' @section Reproducibility (determinism):
#' - **Stochastic** (affected by `set.seed()`): bootstrap resampling of subjects,
#'   and therefore the trees fitted to each resample.
#' - **Deterministic**: correlation/adjacency/TOM and consensus-TOM given fixed inputs.
#'
#' @section Internal validation:
#' An internal helper \code{\link{check_temporal_consistency}} is called
#' automatically at the start (whenever \code{dissimilarity_matrix} is \code{NULL}).
#' It throws an error if column names across time points are not identical
#' (names and order).
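#'
#' The rule being enforced is that every time point must have identical column names in identical order. The sketch below is a hypothetical stand-in (`same_columns()`), not the package helper, and it does not reproduce that helper's error messages.
#'
#' ```r
#' # Hypothetical stand-in for the consistency rule
#' same_columns <- function(X) {
#'   ref <- colnames(X[[1]])
#'   !is.null(ref) &&
#'     all(vapply(X, function(m) identical(colnames(m), ref), logical(1)))
#' }
#'
#' a <- matrix(0, 2, 2, dimnames = list(NULL, c("V1", "V2")))
#' b <- matrix(0, 2, 2, dimnames = list(NULL, c("V2", "V1")))   # same names, new order
#' same_columns(list(a, a))   # TRUE
#' same_columns(list(a, b))   # FALSE: order differs
#' ```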
#'
#' @return An object of class `TemporalForest` with:
#' \itemize{
#'   \item \code{top_features} (\strong{character}): the K selected features in
#'         descending stability order.
#'   \item \code{candidate_features} (\strong{character}): all features that
#'         entered the final (second-stage) selection.
#' }
#'
#' @note The current API does not expose selection probabilities, module labels,
#' or a parameter snapshot; these may be added in a future version.
#'
#' @seealso \code{\link{select_soft_power}}, \code{\link{calculate_fs_metrics_cv}},
#'   \code{\link{calculate_pred_metrics_cv}}, \code{\link{check_temporal_consistency}}
#'
#' @author Sisi Shao, Jason H. Moore, Christina M. Ramirez
#' @references Shao, S., Moore, J.H., Ramirez, C.M. (2025). Network-Guided
#'   Temporal Forests for Feature Selection in High-Dimensional Longitudinal Data.
#'   *Journal of Statistical Software*.
#' @export
#' @importFrom WGCNA adjacency TOMsimilarity labels2colors
#' @importFrom dynamicTreeCut cutreeDynamic
#' @importFrom stats as.dist hclust as.formula
#' @examples
#' \donttest{
#' # Tiny demo (skips Stage 1 via precomputed A); V1, V2, V3 carry the signal
#' set.seed(11)
#' n_subjects <- 60; n_timepoints <- 2; p <- 20
#' X <- replicate(n_timepoints, matrix(rnorm(n_subjects * p), n_subjects, p), simplify = FALSE)
#' colnames(X[[1]]) <- colnames(X[[2]]) <- paste0("V", 1:p)
#' # Stacked rows are time-major, so id, time and Y follow the same order
#' X_long <- do.call(rbind, X)
#' id   <- rep(seq_len(n_subjects), times = n_timepoints)
#' time <- rep(seq_len(n_timepoints), each = n_subjects)
#' u <- rnorm(n_subjects, 0, 0.7)
#' eps <- rnorm(length(id), 0, 0.08)
#' Y <- 4*X_long[,"V1"] + 3.5*X_long[,"V2"] + 3.2*X_long[,"V3"] + u[id] + eps
#' A <- 1 - abs(stats::cor(X_long)); diag(A) <- 0
#' dimnames(A) <- list(colnames(X[[1]]), colnames(X[[1]]))
#' fit <- temporal_forest(
#'   X, Y, id, time,
#'   dissimilarity_matrix = A,
#'   n_features_to_select = 3,
#'   n_boot_screen = 6, n_boot_select = 18,
#'   keep_fraction_screen = 1, min_module_size = 2,
#'   alpha_screen = 0.5, alpha_select = 0.6
#' )
#' print(fit$top_features)
#' }
temporal_forest <- function(X = NULL, Y, id, time,
                            dissimilarity_matrix = NULL,
                            n_features_to_select = 10,
                            min_module_size = 4,
                            n_boot_screen = 50,
                            keep_fraction_screen = 0.25,
                            n_boot_select = 100,
                            alpha_screen = 0.2,
                            alpha_select = 0.05) {
    
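    # --- Input Validation ---
    # 'X' is always required: Stages 2-3 fit trees on the raw predictors, even
    # when a precomputed dissimilarity matrix replaces Stage 1.
    if (is.null(X)) {
        stop("'X' must be provided; it is needed for screening and selection ",
             "even when 'dissimilarity_matrix' is supplied.")
    }
    if (is.null(dissimilarity_matrix)) {
        check_temporal_consistency(X)
    }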
    
    predictor_names <- if (!is.null(X)) colnames(X[[1]]) else colnames(dissimilarity_matrix)
    if (is.null(predictor_names)) stop("Predictors must have column names.")
    
    # --- STAGE 1: Time-Aware Module Construction (or skip) ---
    if (!is.null(dissimilarity_matrix)) {
        message("Step 1/3: Skipped. Using provided dissimilarity matrix.")
        A_combined_dissim <- dissimilarity_matrix
    } else {
        message("Step 1/3: Constructing consensus network from raw data 'X'...")
        n_predictors <- length(predictor_names)
        toms_array <- array(0, dim = c(n_predictors, n_predictors, length(X)),
                            dimnames = list(predictor_names, predictor_names, NULL))
        
        for (i in seq_along(X)) {
            # Fall back to a soft-thresholding power of 6 if automatic selection fails
            soft_power <- tryCatch(select_soft_power(as.matrix(X[[i]])), error = function(e) 6)
            adj_matrix <- WGCNA::adjacency(as.matrix(X[[i]]), power = soft_power, type = "signed")
            toms_array[,,i] <- WGCNA::TOMsimilarity(adj_matrix, TOMType = "signed")
        }
        # Consensus TOM: element-wise minimum across time points
        consensus_tom <- apply(toms_array, c(1, 2), min)
        A_combined_dissim <- 1 - consensus_tom
    }
    
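    # --- Data Formatting ---
    message("Step 2/3: Formatting data for analysis...")
    # rbind() stacks the per-time matrices time-major (all rows of X[[1]], then
    # X[[2]], ...); Y, id and time are matched to these rows by position.
    all_X_combined <- do.call(rbind, lapply(X, as.data.frame))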
    long_df <- data.frame(
        y = Y,
        patient = id,
        time_numeric = as.numeric(as.factor(time))
    )
    long_df <- cbind(long_df, all_X_combined)
    
    # --- STAGE 2 & 3: Call the internal algorithm ---
    message("Step 3/3: Running stability selection...")
    
    internal_results <- TemporalTree_time(
        data = long_df,
        A_combined = A_combined_dissim,
        var_select = predictor_names,
        cluster = "patient",
        fixed_regress = "time_numeric",
        number_selected_final = n_features_to_select,
        minClusterSize = min_module_size,
        n_boot_screen = n_boot_screen,
        keep_fraction_screen = keep_fraction_screen,
        n_boot_select = n_boot_select,
        alpha_screen = alpha_screen,
        alpha_select = alpha_select
    )
    
    # --- Finalize and Return Output ---
    message("Done.")
    
    results <- list(
        top_features = internal_results$final_selection,
        candidate_features = internal_results$second_stage_splitters
    )
    
    class(results) <- "TemporalForest"
    return(results)
}

#' Print Method for TemporalForest Objects
#'
#' @param x An object of class `TemporalForest`.
#' @param ... Further arguments (currently ignored).
#' @return Invisibly returns the input object `x`.
#' @export
print.TemporalForest <- function(x, ...) {
    cat("--- Temporal Forest Results ---\n\n")
    n_selected <- length(x$top_features)
    cat(sprintf("Top %d feature(s) selected:\n", n_selected))
    if (n_selected > 0) {
        cat(paste(" ", x$top_features, collapse = "\n"), "\n\n")
    } else {
        cat("  (No features were selected)\n\n")
    }
    
    n_candidates <- length(x$candidate_features)
    cat(sprintf("%d feature(s) were candidates in the final stage.\n", n_candidates))
    
    invisible(x)
}

# TemporalForest documentation built on Dec. 23, 2025, 1:06 a.m.