R/cluster_data.R

Defines functions .print_covariate_profiles .run_covariate_analysis .compute_cluster_profiles .resolve_covariates plot.net_clustering summary.net_clustering print.net_clustering build_clusters .extract_sequence_data .dissimilarity_matrix_stringdist .dissimilarity_matrix_r .jaccard_matrix_tdm .cosine_matrix_tdm .qgram_matrix_tdm .build_qgram_tdm .hamming_matrix_outer .jw_dist_r .jaro_dist_r .jaccard_dist_r .cosine_dist_r .qgram_dist_r .get_qgram_r .lcs_dist_r .dl_dist_r .osa_dist_r .levenshtein_dist_r .hamming_dist_r .encode_sequences

Documented in build_clusters plot.net_clustering print.net_clustering summary.net_clustering

# ---- Sequence Clustering ----

# Available metrics and methods (private constants)
# Distance metric identifiers accepted by build_clusters(); names follow the
# stringdist package convention ("lv" = Levenshtein, "dl" =
# Damerau-Levenshtein, "osa" = optimal string alignment, "jw" = Jaro-Winkler).
.clustering_metrics <- c(
  "hamming", "osa", "lv", "dl", "lcs", "qgram", "cosine", "jaccard", "jw"
)
# Clustering algorithms: "pam" dispatches to cluster::pam(); every other
# entry is passed through as the `method` argument of stats::hclust().
.clustering_methods <- c(
  "pam", "ward.D2", "ward.D", "complete", "average",
  "single", "mcquitty", "median", "centroid"
)

# ==============================================================================
# 1. Sequence Encoding
# ==============================================================================

#' Encode sequences as integer matrix
#'
#' Maps each observed state to its index in the sorted state alphabet;
#' symbols listed in na_syms (and pre-existing NAs) become NA.
#'
#' @param data Data frame or matrix of sequences (rows = sequences, cols = time)
#' @param na_syms Character vector of symbols to treat as NA
#' @return List with int_mat (integer codes, NA = missing), len (position of
#'   the last observed state per row, 0 when fully missing), states (sorted
#'   state alphabet), n_states (alphabet size)
#' @noRd
.encode_sequences <- function(data, na_syms = c("*", "%")) {
  chr <- as.matrix(data)
  storage.mode(chr) <- "character"
  chr[chr %in% na_syms] <- NA
  observed <- !is.na(chr)
  # Effective length = index of the last non-missing entry (0 when none)
  eff_len <- vapply(seq_len(nrow(chr)), function(r) {
    hits <- which(observed[r, ])
    if (length(hits) > 0L) hits[length(hits)] else 0L
  }, integer(1L))
  alphabet <- sort(unique(chr[observed]))
  coded <- matrix(match(chr, alphabet), nrow = nrow(chr), ncol = ncol(chr))
  list(int_mat = coded, len = eff_len, states = alphabet,
       n_states = length(alphabet))
}

# ==============================================================================
# 2. Distance Functions (R fallback)
# ==============================================================================

#' Hamming distance: (optionally position-weighted) count of mismatches.
#' `weights` may be a scalar (unweighted) or one weight per position.
#' @noRd
.hamming_dist_r <- function(x, y, n, m, weights, ...) {
  mismatch <- x != y
  sum(weights * mismatch)
}

#' Levenshtein edit distance between two integer-encoded sequences.
#' Standard DP, kept as two rolling rows instead of the full (n+1)x(m+1)
#' matrix; only the first n / m entries of x / y are read.
#' @noRd
.levenshtein_dist_r <- function(x, y, n, m, ...) {
  prev <- 0:m
  for (i in seq_len(n)) {
    curr <- integer(m + 1L)
    curr[1L] <- i
    for (j in seq_len(m)) {
      subst <- prev[j] + as.integer(x[i] != y[j])
      curr[j + 1L] <- min(prev[j + 1L] + 1L, curr[j] + 1L, subst)
    }
    prev <- curr
  }
  prev[m + 1L]
}

#' Optimal string alignment distance (Levenshtein + adjacent transposition).
#' Rolling three-row DP; unlike full Damerau-Levenshtein, a transposed pair
#' may not be edited again (the OSA restriction).
#' @noRd
.osa_dist_r <- function(x, y, n, m, ...) {
  prev2 <- integer(0)  # row i - 2; only read once i > 1
  prev <- 0:m
  for (i in seq_len(n)) {
    curr <- integer(m + 1L)
    curr[1L] <- i
    for (j in seq_len(m)) {
      step <- min(
        prev[j + 1L] + 1L,
        curr[j] + 1L,
        prev[j] + as.integer(x[i] != y[j])
      )
      swapped <- i > 1L && j > 1L && x[i] == y[j - 1L] && x[i - 1L] == y[j]
      if (swapped) {
        step <- min(step, prev2[j - 1L] + 1L)
      }
      curr[j + 1L] <- step
    }
    prev2 <- prev
    prev <- curr
  }
  prev[m + 1L]
}

#' Damerau-Levenshtein distance (unrestricted transpositions)
#'
#' Full DL distance via the Lowrance-Wagner algorithm: unlike OSA, material
#' between a transposed pair may still be edited. Uses the textbook
#' "matrix indexed from -1" formulation, stored here with a +2 index shift
#' (d[i + 2, j + 2] corresponds to prefix lengths i and j).
#'
#' @param x,y Integer-encoded sequences (only the first n / m entries used).
#' @param n,m Effective lengths of x and y.
#' @return Integer edit distance.
#' @noRd
.dl_dist_r <- function(x, y, n, m, ...) {
  if (n == 0L) return(m)
  if (m == 0L) return(n)
  x_sub <- x[seq_len(n)]
  y_sub <- y[seq_len(m)]
  # 'last' is the "da" array: for each symbol, the most recent row of x
  # where it occurred (0 = not seen yet). Alphabet covers both sequences,
  # so every lookup below is guaranteed to hit.
  alphabet <- unique(c(x_sub, y_sub))
  last <- stats::setNames(rep(0L, length(alphabet)), as.character(alphabet))
  d <- matrix(0L, n + 2L, m + 2L)
  maxdist <- n + m  # sentinel larger than any attainable distance
  d[1L, ] <- maxdist
  d[, 1L] <- maxdist
  for (i in 0:n) d[i + 2L, 2L] <- i
  for (j in 0:m) d[2L, j + 2L] <- j
  for (i in seq_len(n)) {
    db <- 0L  # last column of y (in this row) where x[i] matched y[j]
    for (j in seq_len(m)) {
      k <- last[[as.character(y[j])]]  # last row of x containing y[j]
      if (is.null(k)) k <- 0L # nocov
      l <- db
      if (x[i] == y[j]) {
        cost <- 0L
        db <- j
      } else {
        cost <- 1L
      }
      # Candidates: deletion, insertion, substitution/match, and a
      # transposition spanning back to positions (k, l).
      d[i + 2L, j + 2L] <- min(
        d[i + 1L, j + 2L] + 1L,
        d[i + 2L, j + 1L] + 1L,
        d[i + 1L, j + 1L] + cost,
        d[k + 1L, l + 1L] + (i - k - 1L) + (j - l - 1L) + 1L
      )
    }
    last[[as.character(x[i])]] <- i
  }
  d[n + 2L, m + 2L]
}

#' LCS distance: n + m - 2 * |longest common subsequence|.
#' Standard LCS-length DP, kept as two rolling rows.
#' @noRd
.lcs_dist_r <- function(x, y, n, m, ...) {
  prev <- integer(m + 1L)
  for (i in seq_len(n)) {
    curr <- integer(m + 1L)
    for (j in seq_len(m)) {
      curr[j + 1L] <- if (x[i] == y[j]) {
        prev[j] + 1L
      } else {
        max(prev[j + 1L], curr[j])
      }
    }
    prev <- curr
  }
  n + m - 2L * prev[m + 1L]
}

#' Extract q-gram frequency profile from integer sequence
#'
#' Returns a named integer vector of q-gram counts; names are the q symbols
#' joined with "\x01" (a separator that cannot occur inside a symbol).
#' Sequences shorter than q yield an empty profile.
#' @noRd
.get_qgram_r <- function(x, n, q) {
  if (n < q) return(integer()) # nocov start
  starts <- seq_len(n - q + 1L)
  grams <- vapply(starts, function(s) {
    paste0(x[s:(s + q - 1L)], collapse = "\x01")
  }, character(1L))
  tab <- table(grams)
  stats::setNames(as.integer(tab), names(tab)) # nocov end
}

#' Q-gram distance: L1 distance between q-gram count profiles.
#' x and y are accepted for interface uniformity but unused; qx/qy are
#' named count vectors from .get_qgram_r().
#' @noRd
.qgram_dist_r <- function(x, y, qx, qy, ...) {
  grams <- union(names(qx), names(qy))
  # Align a profile to the joint gram vocabulary, filling absences with 0
  aligned <- function(profile) {
    v <- profile[grams]
    v[is.na(v)] <- 0L
    v
  }
  sum(abs(aligned(qx) - aligned(qy)))
}

#' Cosine distance between q-gram count profiles (1 - cosine similarity).
#' A zero-length profile on either side yields distance 1.
#' @noRd
.cosine_dist_r <- function(x, y, qx, qy, ...) {
  grams <- union(names(qx), names(qy))
  a <- qx[grams]
  a[is.na(a)] <- 0L
  b <- qy[grams]
  b[is.na(b)] <- 0L
  denom <- sqrt(sum(a^2)) * sqrt(sum(b^2))
  if (denom == 0) return(1)
  1 - sum(a * b) / denom
}

#' Jaccard distance on q-gram presence sets: 1 - |A ∩ B| / |A ∪ B|.
#' Two empty profiles are defined as distance 0 (identical empty sets).
#' @noRd
.jaccard_dist_r <- function(qx, qy, ...) {
  a <- names(qx)
  b <- names(qy)
  if (length(a) == 0L && length(b) == 0L) return(0)
  shared <- length(intersect(a, b))
  total <- length(union(a, b))
  1 - shared / total
}

#' Jaro distance between two integer-encoded sequences
#'
#' Distance = 1 - Jaro similarity. Symbols match when equal and at most
#' floor(max(n, m) / 2) - 1 positions apart; transpositions are matched
#' symbols whose relative order differs between the two sequences.
#'
#' @param x,y Integer-encoded sequences (only the first n / m entries used).
#' @param n,m Effective lengths.
#' @return Numeric distance in [0, 1].
#' @noRd
.jaro_dist_r <- function(x, y, n, m, ...) {
  if (n == 0L && m == 0L) return(0) # nocov start
  if (n == 0L || m == 0L) return(1) # nocov end
  match_dist <- max(floor(max(n, m) / 2) - 1L, 0L)
  x_match <- rep(FALSE, n)
  y_match <- rep(FALSE, m)
  matches <- 0L
  # Greedy left-to-right matching within the allowed window
  for (i in seq_len(n)) {
    start <- max(1L, i - match_dist)
    end <- min(m, i + match_dist)
    if (start > end) next # nocov
    for (j in start:end) {
      if (!y_match[j] && x[i] == y[j]) {
        x_match[i] <- TRUE
        y_match[j] <- TRUE
        matches <- matches + 1L
        break
      }
    }
  }
  if (matches == 0L) return(1) # nocov
  # Count half-transpositions: matched symbols compared in match order
  trans <- 0L
  k <- 1L
  for (i in seq_len(n)) {
    if (x_match[i]) {
      while (k <= m && !y_match[k]) k <- k + 1L
      if (k <= m && x[i] != y[k]) trans <- trans + 1L
      k <- k + 1L
    }
  }
  trans <- trans / 2L
  jaro_sim <- (matches / n + matches / m + (matches - trans) / matches) / 3
  1 - jaro_sim
}

#' Jaro-Winkler distance: Jaro distance with a common-prefix bonus
#'
#' @param x,y Integer-encoded sequences.
#' @param n,m Effective lengths.
#' @param p Winkler prefix scaling factor, clamped to [0, 0.25] (the
#'   documented contract; previously only the upper bound was enforced,
#'   so a negative p could push the similarity outside [0, 1]).
#' @param max_l Maximum prefix length rewarded.
#' @return Numeric distance in [0, 1].
#' @noRd
.jw_dist_r <- function(x, y, n, m, p = 0.1, max_l = 4L, ...) {
  jaro_sim <- 1 - .jaro_dist_r(x, y, n, m)
  if (jaro_sim == 0) return(1) # nocov
  # Length of the common prefix, capped at max_l
  pref <- min(n, m, max_l)
  l <- 0L
  for (k in seq_len(pref)) {
    if (x[k] == y[k]) {
      l <- l + 1L
    } else {
      break
    }
  }
  # Clamp p into the valid Winkler range [0, 0.25] (both bounds)
  p <- min(max(p, 0), 0.25)
  jw <- jaro_sim + l * p * (1 - jaro_sim)
  1 - jw
}

# ==============================================================================
# 3. Vectorized Distance Matrix (hamming, qgram-family)
# ==============================================================================

#' Hamming distance matrix via column-wise outer() — no pair loop
#'
#' Accumulates, per time position, an n x n mismatch indicator matrix
#' scaled by that position's weight. A scalar weight of 1 reproduces the
#' unweighted count of mismatches.
#'
#' @param mat_s Integer matrix with NA replaced by a sentinel code.
#' @param weights Scalar 1 (unweighted) or one weight per column.
#' @return A dist object.
#' @noRd
.hamming_matrix_outer <- function(mat_s, weights) {
  n_seq <- nrow(mat_s)
  n_pos <- ncol(mat_s)
  w <- rep_len(weights, n_pos)
  acc <- matrix(0, n_seq, n_seq)
  for (pos in seq_len(n_pos)) {
    acc <- acc + outer(mat_s[, pos], mat_s[, pos], "!=") * w[pos]
  }
  stats::as.dist(acc)
}

#' Build term-document matrix of q-gram counts (rows = seqs, cols = qgrams)
#'
#' Each row holds the q-gram counts of one sequence (truncated to its
#' effective length); sequences shorter than q contribute an all-zero row.
#' Q-gram names join the q symbols with "\x01".
#'
#' @param mat_s Integer matrix with NA replaced by a sentinel code.
#' @param len Effective length of each sequence.
#' @param q Q-gram size.
#' @return Integer matrix, one column per distinct q-gram.
#' @noRd
.build_qgram_tdm <- function(mat_s, len, q) {
  n <- nrow(mat_s)
  # Per-sequence count tables (NULL when the sequence is shorter than q)
  profiles <- lapply(seq_len(n), function(i) {
    if (len[i] < q) return(NULL)
    starts <- seq_len(len[i] - q + 1L)
    grams <- vapply(starts, function(s) {
      paste0(mat_s[i, s:(s + q - 1L)], collapse = "\x01")
    }, character(1L))
    table(grams)
  })
  vocab <- Reduce(union, lapply(profiles, names), character(0L))
  tdm <- matrix(0L, nrow = n, ncol = length(vocab),
                dimnames = list(NULL, vocab))
  for (i in seq_len(n)) {
    prof <- profiles[[i]]
    if (!is.null(prof)) {
      tdm[i, names(prof)] <- as.integer(prof)
    }
  }
  tdm
}

#' Q-gram distance matrix via term-document matrix
#'
#' The q-gram distance between two sequences is the L1 (Manhattan) distance
#' between their q-gram count vectors, so the whole matrix is delegated to
#' stats::dist(), which computes it in C instead of an R-level per-column
#' outer() loop.
#'
#' @param tdm Integer matrix of q-gram counts (rows = sequences).
#' @return A dist object of pairwise q-gram distances.
#' @noRd
.qgram_matrix_tdm <- function(tdm) {
  # With zero q-gram columns stats::dist() would return NA (no usable
  # coordinates); the previous hand-rolled loop yielded 0, so preserve that.
  if (ncol(tdm) == 0L) {
    return(stats::as.dist(matrix(0, nrow(tdm), nrow(tdm))))
  }
  stats::dist(tdm, method = "manhattan")
}

#' Cosine distance matrix via matrix multiply
#'
#' Pairwise cosine distances between the rows of a q-gram count matrix.
#' Rows with no q-grams (zero vectors) get distance 1 to every other row
#' and 0 to themselves.
#'
#' @param tdm Integer matrix of q-gram counts (rows = sequences).
#' @return A dist object.
#' @noRd
.cosine_matrix_tdm <- function(tdm) {
  sq_norms <- rowSums(tdm^2)
  norms_safe <- sqrt(sq_norms)
  norms_safe[norms_safe == 0] <- 1  # placeholder; zero rows patched below
  sim <- tcrossprod(tdm) / outer(norms_safe, norms_safe)
  sim <- pmin(pmax(sim, 0), 1)  # clamp for numerical safety
  d <- 1 - sim
  diag(d) <- 0
  empty <- which(sq_norms == 0)
  if (length(empty) > 0L) {
    d[empty, ] <- 1
    d[, empty] <- 1
    d[cbind(empty, empty)] <- 0
  }
  stats::as.dist(d)
}

#' Jaccard distance matrix via set intersection/union on qgram presence
#'
#' Distance = 1 - |intersection| / |union| over the sets of q-grams present
#' in each row; two rows with no q-grams at all get distance 0.
#'
#' @param tdm Integer matrix of q-gram counts (rows = sequences).
#' @return A dist object.
#' @noRd
.jaccard_matrix_tdm <- function(tdm) {
  present <- tdm > 0L  # logical presence matrix; arithmetic coerces to 0/1
  shared <- tcrossprod(present)
  set_sizes <- rowSums(present)
  total <- outer(set_sizes, set_sizes, "+") - shared
  d <- matrix(0, nrow(tdm), nrow(tdm))
  ok <- total != 0  # leave 0/0 pairs (both empty) at distance 0
  d[ok] <- 1 - shared[ok] / total[ok]
  diag(d) <- 0
  stats::as.dist(d)
}

# ==============================================================================
# 4. Pairwise Distance Matrix (DP distances — pair loop unavoidable)
# ==============================================================================

#' Compute pairwise dissimilarity matrix using pure R
#'
#' Dispatches on the metric: hamming and the q-gram family use vectorized
#' matrix paths; the DP edit distances run an explicit pair loop.
#'
#' @param enc Encoding list from .encode_sequences() (int_mat, len,
#'   states, n_states).
#' @param dissimilarity One of the metrics in .clustering_metrics.
#' @param lambda Exponential decay rate for weighted Hamming (0 = unweighted).
#' @param q Q-gram size for qgram/cosine/jaccard.
#' @param p Winkler prefix penalty for "jw".
#' @return A dist object of pairwise distances.
#' @noRd
.dissimilarity_matrix_r <- function(enc, dissimilarity, lambda,
                                    q = 2L, p = 0.1) {
  int_mat <- enc$int_mat
  len <- enc$len
  n <- nrow(int_mat)
  k <- ncol(int_mat)
  sentinel <- enc$n_states + 1L

  # Full-width matrix with sentinel for NA (so != comparisons never see NA)
  mat_s <- int_mat
  mat_s[is.na(mat_s)] <- sentinel

  # --- Vectorized fast paths ---
  if (dissimilarity == "hamming") {
    weights <- 1
    if (lambda > 0) {
      # Exponential decay over time positions, normalized so the first
      # (heaviest) position has weight 1
      weights <- exp(-lambda * seq(0, k - 1))
      weights <- weights / max(weights)
    }
    return(.hamming_matrix_outer(mat_s, weights))
  }

  if (dissimilarity %in% c("qgram", "cosine", "jaccard")) {
    tdm <- .build_qgram_tdm(mat_s, len, q)
    if (dissimilarity == "qgram") return(.qgram_matrix_tdm(tdm))
    if (dissimilarity == "cosine") return(.cosine_matrix_tdm(tdm))
    return(.jaccard_matrix_tdm(tdm))
  }

  # --- Pair loop for DP distances (osa, lv, dl, lcs, jw) ---
  d_fun <- switch(dissimilarity,
    osa = .osa_dist_r,
    lv = .levenshtein_dist_r,
    dl = .dl_dist_r,
    lcs = .lcs_dist_r,
    jw = .jw_dist_r
  )

  d <- matrix(0, n, n)
  pairs <- combn(n, 2L)

  # combn() emits pairs (1,2), (1,3), ..., (2,3), ... which is exactly the
  # column-major order of lower.tri(), so the fill below lines up.
  dists <- vapply(seq_len(ncol(pairs)), function(idx) {
    i <- pairs[1L, idx]
    j <- pairs[2L, idx]
    if (dissimilarity == "jw") {
      d_fun(x = mat_s[i, ], y = mat_s[j, ],
            n = len[i], m = len[j], p = p)
    } else {
      # Full-width rows are passed, but each DP only reads the first
      # len[i] / len[j] entries
      d_fun(x = mat_s[i, ], y = mat_s[j, ],
            n = len[i], m = len[j])
    }
  }, numeric(1L))

  d[lower.tri(d)] <- dists
  d <- d + t(d)
  stats::as.dist(d)
}

# ==============================================================================
# 4b. stringdist Fast Path
# ==============================================================================

#' Compute pairwise dissimilarity matrix using stringdist (C-level)
#'
#' Maps each integer state to a single character ([a-zA-Z0-9]; the caller
#' guarantees n_states < 62 so one symbol remains for the NA sentinel) and
#' delegates to stringdist::stringdistmatrix.
#'
#' @param enc Encoding list from .encode_sequences().
#' @param dissimilarity One of the metrics in .clustering_metrics.
#' @param lambda Decay rate for weighted Hamming; > 0 forces the pure-R
#'   fallback because stringdist has no positional weights.
#' @param q Q-gram size for qgram/cosine/jaccard.
#' @param p Winkler prefix penalty for "jw".
#' @return A dist object of pairwise distances.
#' @noRd
.dissimilarity_matrix_stringdist <- function(enc, dissimilarity, lambda,
                                             q = 2L, p = 0.1) {
  int_mat <- enc$int_mat
  len <- enc$len
  n_states <- enc$n_states
  k <- ncol(int_mat)

  # Clamp the Winkler penalty to the documented range [0, 0.25]:
  # stringdist errors for p outside [0, 0.25], whereas the R fallback
  # clamps, so without this the two code paths diverged.
  p <- min(max(p, 0), 0.25)

  base_chars <- c(letters, LETTERS, as.character(0:9))
  sentinel_char <- base_chars[n_states + 1L]

  # Build character matrix (sentinel for NA)
  chr_mat <- matrix(sentinel_char, nrow = nrow(int_mat), ncol = k)
  obs <- !is.na(int_mat)
  chr_mat[obs] <- base_chars[int_mat[obs]]

  if (dissimilarity == "hamming") {
    # Weighted hamming not supported by stringdist — fall back to R
    if (lambda > 0) {
      return(.dissimilarity_matrix_r(enc, dissimilarity, lambda, q, p)) # nocov
    }
    # Full width strings (same length, NAs are sentinel chars)
    strings <- vapply(seq_len(nrow(chr_mat)), function(i) {
      paste0(chr_mat[i, ], collapse = "")
    }, character(1L))
    stringdist::stringdistmatrix(strings, method = "hamming")
  } else {
    # Truncate to effective length per sequence
    strings <- vapply(seq_len(nrow(chr_mat)), function(i) {
      if (len[i] == 0L) return("") # nocov
      paste0(chr_mat[i, seq_len(len[i])], collapse = "")
    }, character(1L))
    if (dissimilarity == "jw") {
      stringdist::stringdistmatrix(strings, method = "jw", p = p)
    } else if (dissimilarity %in% c("qgram", "cosine", "jaccard")) {
      stringdist::stringdistmatrix(strings, method = dissimilarity, q = q)
    } else {
      stringdist::stringdistmatrix(strings, method = dissimilarity)
    }
  }
}

# ==============================================================================
# 5. Input Extraction
# ==============================================================================

#' Extract wide-format sequence data from various object types
#'
#' Supports: data.frame, matrix, netobject, tna, cograph_network.
#' For tna objects, decodes integer-encoded data using stored labels.
#' Rejects association-method netobjects (cor/pcor/glasso/ising) since
#' those contain numeric observations, not categorical sequences.
#' @param data Input object
#' @return A data frame of wide-format sequences
#' @noRd
.extract_sequence_data <- function(data) {
  # Shared coercion so every branch returns the same kind of data frame
  as_seq_df <- function(x) as.data.frame(x, stringsAsFactors = FALSE)

  # --- netobject (from Nestimate) ---
  if (inherits(data, "netobject")) {
    seq_methods <- c("relative", "frequency", "co_occurrence", "attention")
    bad_method <- !is.null(data$method) && !(data$method %in% seq_methods)
    if (bad_method) {
      stop(sprintf(
        "build_clusters() requires sequence data, but this netobject uses method '%s'. ",
        data$method
      ), "Pass the original sequence data frame instead.", call. = FALSE)
    }
    if (is.null(data$data)) { # nocov start
      stop("netobject has no $data. Pass the original sequence data.", call. = FALSE)
    } # nocov end
    return(as_seq_df(data$data))
  }

  # --- tna model (integer-encoded with $labels) ---
  if (inherits(data, "tna")) {
    if (is.null(data$data) || is.null(data$labels)) { # nocov start
      stop("tna object missing $data or $labels.", call. = FALSE)
    } # nocov end
    # Decode integer matrix back to state names
    decoded <- matrix(data$labels[data$data], nrow = nrow(data$data),
                      ncol = ncol(data$data))
    colnames(decoded) <- colnames(data$data)
    return(as_seq_df(decoded))
  }

  # --- cograph_network ---
  if (inherits(data, "cograph_network")) {
    if (is.null(data$data)) { # nocov start
      stop("cograph_network has no $data. Pass the original sequence data.",
           call. = FALSE)
    } # nocov end
    return(as_seq_df(data$data))
  }

  # --- raw data.frame or matrix ---
  if (is.data.frame(data) || is.matrix(data)) {
    return(as_seq_df(data))
  }

  stop("Unsupported input type: ", paste(class(data), collapse = ", "),
       ". Expected data.frame, matrix, netobject, tna, or cograph_network.",
       call. = FALSE)
}

# ==============================================================================
# 6. Main Function
# ==============================================================================

#' Cluster Sequences by Dissimilarity
#'
#' @description
#' Clusters wide-format sequences using pairwise string dissimilarity and
#' either PAM (Partitioning Around Medoids) or hierarchical clustering.
#' Supports 9 distance metrics including temporal weighting for Hamming
#' distance. When the \pkg{stringdist} package is available, uses C-level
#' distance computation for 100-1000x speedup on edit distances.
#'
#' @param data Input data. Accepts multiple formats:
#' \describe{
#'   \item{data.frame / matrix}{Wide-format sequences (rows = sequences,
#'     columns = time points, values = state names).}
#'   \item{netobject}{A network object from \code{\link{build_network}}.
#'     Extracts the stored sequence data. Only valid for sequence-based
#'     methods (relative, frequency, co_occurrence, attention).}
#'   \item{tna}{A tna model from the tna package. Decodes the
#'     integer-encoded sequence data using stored labels.}
#'   \item{cograph_network}{A cograph network object. Extracts the
#'     stored sequence data.}
#' }
#' @param k Integer. Number of clusters (must be between 2 and
#'   \code{nrow(data) - 1}).
#' @param dissimilarity Character. Distance metric. One of \code{"hamming"},
#'   \code{"osa"} (optimal string alignment), \code{"lv"} (Levenshtein),
#'   \code{"dl"} (Damerau-Levenshtein), \code{"lcs"} (longest common
#'   subsequence), \code{"qgram"}, \code{"cosine"}, \code{"jaccard"},
#'   \code{"jw"} (Jaro-Winkler). Default: \code{"hamming"}.
#' @param method Character. Clustering method. \code{"pam"} for Partitioning
#'   Around Medoids, or a hierarchical method: \code{"ward.D2"},
#'   \code{"ward.D"}, \code{"complete"}, \code{"average"}, \code{"single"},
#'   \code{"mcquitty"}, \code{"median"}, \code{"centroid"}.
#'   Default: \code{"pam"}.
#' @param na_syms Character vector. Symbols treated as missing values.
#'   Default: \code{c("*", "\%")}.
#' @param weighted Logical. Apply exponential decay weighting to Hamming
#'   distance positions? Only valid when \code{dissimilarity = "hamming"}.
#'   Default: \code{FALSE}.
#' @param lambda Numeric. Decay rate for weighted Hamming. Higher values
#'   weight earlier positions more strongly. Default: 1.
#' @param seed Integer or NULL. Random seed for reproducibility. Default:
#'   \code{NULL}.
#' @param q Integer. Size of q-grams for \code{"qgram"}, \code{"cosine"},
#'   and \code{"jaccard"} distances. Default: \code{2L}.
#' @param p Numeric. Winkler prefix penalty for Jaro-Winkler distance
#'   (clamped to 0--0.25). Default: \code{0.1}.
#' @param covariates Optional. Post-hoc covariate analysis of cluster
#'   membership via multinomial logistic regression. Accepts:
#'   \describe{
#'     \item{formula}{\code{~ Age + Gender}}
#'     \item{character vector}{\code{c("Age", "Gender")}}
#'     \item{string}{\code{"Age + Gender"}}
#'     \item{data.frame}{All columns used as covariates}
#'     \item{NULL}{No covariate analysis (default)}
#'   }
#'   Covariates are looked up in \code{netobject$metadata} or
#'   non-sequence columns of the input data. For \code{tna} and
#'   \code{cograph_network} inputs, pass covariates as a data.frame.
#'   Results stored in \code{$covariates}. Requires the \pkg{nnet}
#'   package.
#' @param ... Additional arguments (currently unused).
#'
#' @return An object of class \code{"net_clustering"} containing:
#' \describe{
#'   \item{data}{The original input data.}
#'   \item{k}{Number of clusters.}
#'   \item{assignments}{Named integer vector of cluster assignments.}
#'   \item{silhouette}{Overall average silhouette width.}
#'   \item{sizes}{Named integer vector of cluster sizes.}
#'   \item{method}{Clustering method used.}
#'   \item{dissimilarity}{Distance metric used.}
#'   \item{distance}{The computed dissimilarity matrix (\code{dist} object).}
#'   \item{medoids}{Integer vector of medoid row indices (PAM only; NULL
#'     for hierarchical methods).}
#'   \item{seed}{Seed used (or NULL).}
#'   \item{weighted}{Logical, whether weighted Hamming was used.}
#'   \item{lambda}{Lambda value used (0 if not weighted).}
#'   \item{covariates}{Post-hoc covariate analysis results, or NULL when
#'     no covariates were supplied.}
#'   \item{network_method}{Method of the input netobject, or NULL.}
#'   \item{build_args}{Build arguments of the input netobject, or NULL.}
#' }
#'
#' @examples
#' seqs <- data.frame(V1 = c("A","B","C","A","B"), V2 = c("B","C","A","B","A"),
#'                    V3 = c("C","A","B","C","B"))
#' cl <- build_clusters(seqs, k = 2)
#' cl
#' \donttest{
#' seqs <- data.frame(
#'   V1 = sample(LETTERS[1:3], 20, TRUE), V2 = sample(LETTERS[1:3], 20, TRUE),
#'   V3 = sample(LETTERS[1:3], 20, TRUE), V4 = sample(LETTERS[1:3], 20, TRUE)
#' )
#' cl <- build_clusters(seqs, k = 2)
#' print(cl)
#' summary(cl)
#' }
#'
#' @importFrom utils combn
#' @export
build_clusters <- function(data, k, dissimilarity = "hamming", method = "pam",
                         na_syms = c("*", "%"), weighted = FALSE, lambda = 1,
                         seed = NULL, q = 2L, p = 0.1,
                         covariates = NULL, ...) {
  # --- Extract sequence data from supported objects ---
  # Keep the original object: covariate resolution may need $metadata, and
  # the returned object records netobject provenance ($method, $build_args).
  raw_data <- data
  data <- .extract_sequence_data(data)

  # --- Input validation ---
  stopifnot(
    is.numeric(k), length(k) == 1L,
    is.character(dissimilarity), length(dissimilarity) == 1L,
    is.character(method), length(method) == 1L,
    is.character(na_syms),
    is.logical(weighted), length(weighted) == 1L,
    is.numeric(lambda), length(lambda) == 1L,
    is.numeric(q), length(q) == 1L,
    is.numeric(p), length(p) == 1L
  )
  k <- as.integer(k)
  q <- as.integer(q)
  n <- nrow(data)
  # k must leave at least two groups worth of structure: 2 <= k <= n - 1
  stopifnot(k >= 2L, k <= n - 1L)

  dissimilarity <- match.arg(dissimilarity, .clustering_metrics)
  method <- match.arg(method, .clustering_methods)

  if (weighted && dissimilarity != "hamming") {
    stop("Weighting is only supported for Hamming distance.", call. = FALSE)
  }
  # lambda = 0 is the internal signal for "unweighted" throughout
  lambda <- if (weighted) lambda else 0

  # NOTE(review): seeds the global RNG (affects cluster::pam initialization);
  # the caller's RNG state is not restored afterwards.
  if (!is.null(seed)) set.seed(seed)

  # --- Encode sequences ---
  enc <- .encode_sequences(data, na_syms)

  # --- Compute distance matrix ---
  # Use the C-level stringdist path when available; it encodes states as
  # single characters, so it needs n_states + 1 <= 62 symbols (incl. the
  # NA sentinel), and it cannot do position-weighted Hamming.
  use_stringdist <- requireNamespace("stringdist", quietly = TRUE) &&
    enc$n_states < 62L &&  # need room for sentinel char
    !(dissimilarity == "hamming" && lambda > 0)  # weighted hamming not in stringdist

  if (use_stringdist) {
    dist_mat <- .dissimilarity_matrix_stringdist(
      enc, dissimilarity, lambda, q, p
    )
  } else {
    dist_mat <- .dissimilarity_matrix_r(enc, dissimilarity, lambda, q, p)
  }

  # --- Cluster ---
  medoids <- NULL
  if (method == "pam") {
    clust_result <- cluster::pam(dist_mat, diss = TRUE, k = k)
    assignments <- clust_result$clustering
    silhouette_score <- clust_result$silinfo$avg.width
    medoids <- clust_result$id.med
  } else {
    # Hierarchical: cut the tree at k, then compute silhouette separately
    hc <- stats::hclust(dist_mat, method = method)
    assignments <- stats::cutree(hc, k = k)
    sil <- cluster::silhouette(assignments, dist = dist_mat)
    silhouette_score <- mean(sil[, 3L])
  }

  # c() drops the table class, leaving a named integer vector of sizes
  sizes <- c(table(assignments))

  # --- Post-hoc covariate analysis ---
  # Descriptive only: covariates never influence the cluster assignments
  cov_result <- NULL
  cov_resolved <- .resolve_covariates(covariates, raw_data, n)
  if (!is.null(cov_resolved)) {
    cov_result <- .run_covariate_analysis(
      assignments, cov_resolved$cov_df, cov_resolved$rhs, k
    )
  }

  structure(
    list(
      data = data,
      k = k,
      assignments = assignments,
      silhouette = silhouette_score,
      sizes = sizes,
      method = method,
      dissimilarity = dissimilarity,
      distance = dist_mat,
      medoids = medoids,
      seed = seed,
      weighted = weighted,
      lambda = lambda,
      covariates = cov_result,
      network_method = if (inherits(raw_data, "netobject")) raw_data$method else NULL,
      build_args     = if (inherits(raw_data, "netobject")) raw_data$build_args else NULL
    ),
    class = "net_clustering"
  )
}


# ==============================================================================
# 6b. S3 Methods
# ==============================================================================

#' Print Method for net_clustering
#'
#' @param x A \code{net_clustering} object.
#' @param ... Additional arguments (ignored).
#'
#' @return The input object, invisibly.
#'
#' @examples
#' seqs <- data.frame(V1 = c("A","B","C","A","B"), V2 = c("B","C","A","B","A"),
#'                    V3 = c("C","A","B","C","B"))
#' cl <- build_clusters(seqs, k = 2)
#' print(cl)
#' \donttest{
#' set.seed(1)
#' seqs <- data.frame(
#'   V1 = sample(c("A","B","C"), 20, TRUE),
#'   V2 = sample(c("A","B","C"), 20, TRUE),
#'   V3 = sample(c("A","B","C"), 20, TRUE)
#' )
#' cl <- build_clusters(seqs, k = 2)
#' print(cl)
#' }
#'
#' @export
print.net_clustering <- function(x, ...) {
  # Compact one-screen overview; all values come straight from the list
  # built in build_clusters().
  cat("Sequence Clustering\n")
  cat("  Method:       ", x$method, "\n")
  cat("  Dissimilarity:", x$dissimilarity,
      if (x$weighted) sprintf("(weighted, lambda = %g)", x$lambda) else "",
      "\n")
  cat("  Clusters:     ", x$k, "\n")
  cat("  Silhouette:   ", round(x$silhouette, 4), "\n")
  cat("  Cluster sizes:", paste(x$sizes, collapse = ", "), "\n")
  # Medoids only exist for PAM (NULL for hierarchical methods)
  if (!is.null(x$medoids)) {
    cat("  Medoids:      ", paste(x$medoids, collapse = ", "), "\n")
  }
  if (!is.null(x$covariates)) {
    # Predictor names from the multinomial fit, minus the intercept row
    cov_names <- setdiff(
      unique(x$covariates$coefficients$variable), "(Intercept)"
    )
    cat("  Covariates:   ",
        paste(cov_names, collapse = ", "),
        sprintf("(post-hoc, %d predictors)", length(cov_names)), "\n")
  }
  invisible(x)
}

#' Summary Method for net_clustering
#'
#' @param object A \code{net_clustering} object.
#' @param ... Additional arguments (ignored).
#'
#' @return The input object, invisibly.
#'
#' @examples
#' seqs <- data.frame(V1 = c("A","B","C","A","B"), V2 = c("B","C","A","B","A"),
#'                    V3 = c("C","A","B","C","B"))
#' cl <- build_clusters(seqs, k = 2)
#' summary(cl)
#' \donttest{
#' set.seed(1)
#' seqs <- data.frame(
#'   V1 = sample(c("A","B","C"), 20, TRUE),
#'   V2 = sample(c("A","B","C"), 20, TRUE),
#'   V3 = sample(c("A","B","C"), 20, TRUE)
#' )
#' cl <- build_clusters(seqs, k = 2)
#' summary(cl)
#' }
#'
#' @export
summary.net_clustering <- function(object, ...) {
  dist_mat <- as.matrix(object$distance)
  k <- object$k
  assignments <- object$assignments

  # Per-cluster size and mean within-cluster pairwise distance
  # (singleton clusters get 0 since they have no pairs)
  cluster_stats <- vapply(seq_len(k), function(cl) {
    members <- which(assignments == cl)
    sz <- length(members)
    if (sz > 1L) {
      within <- dist_mat[members, members]
      mean_within <- mean(within[lower.tri(within)])
    } else {
      mean_within <- 0
    }
    c(size = sz, mean_within_dist = mean_within)
  }, numeric(2L))

  # vapply returns a 2 x k matrix; transpose to one row per cluster
  cluster_stats <- as.data.frame(t(cluster_stats))
  cluster_stats <- cbind(cluster = seq_len(k), cluster_stats)
  rownames(cluster_stats) <- NULL

  cat("Sequence Clustering Summary\n")
  cat("  Method:       ", object$method, "\n")
  cat("  Dissimilarity:", object$dissimilarity, "\n")
  cat("  Silhouette:   ", round(object$silhouette, 4), "\n\n")
  cat("Per-cluster statistics:\n")
  print(cluster_stats, row.names = FALSE)

  # --- Covariate analysis output ---
  # When covariates were analyzed, return both tables invisibly
  if (!is.null(object$covariates)) {
    cov <- object$covariates
    cat("\n")
    .print_covariate_profiles(cov)
    cat("\nNote: Covariates are post-hoc and do not influence cluster assignments.\n")

    result <- list(cluster_stats = cluster_stats, covariates = cov)
    return(invisible(result))
  }

  invisible(cluster_stats)
}

#' Plot Sequence Clustering Results
#'
#' @param x A \code{net_clustering} object.
#' @param type Character. Plot type: \code{"silhouette"} (per-observation
#'   silhouette bars), \code{"mds"} (2D MDS projection),
#'   \code{"heatmap"} (distance matrix heatmap ordered by cluster), or
#'   \code{"predictors"} (covariate effect forest plot; requires that
#'   \code{build_clusters()} was run with covariates).
#'   Default: \code{"silhouette"}.
#' @param ... Additional arguments (currently unused).
#' @return A \code{ggplot} object (invisibly).
#'
#' @examples
#' seqs <- data.frame(V1 = c("A","B","C","A","B"), V2 = c("B","C","A","B","A"),
#'                    V3 = c("C","A","B","C","B"))
#' cl <- build_clusters(seqs, k = 2)
#' plot(cl, type = "silhouette")
#' \donttest{
#' set.seed(1)
#' seqs <- data.frame(
#'   V1 = sample(c("A","B","C"), 20, TRUE),
#'   V2 = sample(c("A","B","C"), 20, TRUE),
#'   V3 = sample(c("A","B","C"), 20, TRUE)
#' )
#' cl <- build_clusters(seqs, k = 2)
#' plot(cl, type = "silhouette")
#' }
#'
#' @import ggplot2
#' @export
plot.net_clustering <- function(x, type = c("silhouette", "mds", "heatmap",
                                             "predictors"), ...) {
  type <- match.arg(type)

  # "predictors" forest plot requires that build_clusters() was run with
  # covariates; delegates to the covariate plotting helper.
  if (type == "predictors") {
    if (is.null(x$covariates)) {
      stop("No covariate analysis found. Run build_clusters() with covariates.",
           call. = FALSE)
    }
    return(.plot_covariate_forest(
      x$covariates$coefficients,
      sprintf("Covariate Effects on Cluster Membership (ref: Cluster %s)",
              x$covariates$fit$reference_cluster)
    ))
  }

  if (type == "silhouette") {
    # Recompute per-observation silhouette widths from the stored distances
    sil <- cluster::silhouette(x$assignments, dist = x$distance)
    sil_df <- data.frame(
      obs = seq_len(nrow(sil)),
      cluster = factor(sil[, 1L]),
      width = sil[, 3L]
    )
    # Order by cluster then width within cluster
    sil_df <- sil_df[order(sil_df$cluster, -sil_df$width), ]
    sil_df$rank <- seq_len(nrow(sil_df))

    p <- ggplot(sil_df, aes(x = .data$rank, y = .data$width,
                             fill = .data$cluster)) +
      geom_col(width = 1) +
      coord_flip() +
      labs(x = "Observation", y = "Silhouette Width",
           title = sprintf("Silhouette Plot (avg = %.3f)", x$silhouette),
           fill = "Cluster") +
      theme_minimal() +
      # Dashed reference line at the overall average silhouette
      geom_hline(yintercept = x$silhouette, linetype = "dashed", alpha = 0.5)

  } else if (type == "mds") {
    # Classical MDS projection of the distance matrix into 2 dimensions
    mds <- stats::cmdscale(x$distance, k = 2L)
    mds_df <- data.frame(
      dim1 = mds[, 1L],
      dim2 = mds[, 2L],
      cluster = factor(x$assignments)
    )

    p <- ggplot(mds_df, aes(x = .data$dim1, y = .data$dim2,
                             colour = .data$cluster)) +
      geom_point(size = 2.5) +
      labs(x = "MDS Dimension 1", y = "MDS Dimension 2",
           title = sprintf("MDS Projection (k = %d)", x$k),
           colour = "Cluster") +
      theme_minimal()

  } else {
    # Heatmap — order by cluster assignment
    ord <- order(x$assignments)
    dist_mat <- as.matrix(x$distance)[ord, ord]
    n <- nrow(dist_mat)

    # Build long-format data frame (one row per cell of the matrix);
    # col levels reversed so row 1 appears at the top of the tile grid
    idx <- expand.grid(row = seq_len(n), col = seq_len(n))
    idx$distance <- as.vector(dist_mat)
    idx$row <- factor(idx$row, levels = seq_len(n))
    idx$col <- factor(idx$col, levels = rev(seq_len(n)))

    p <- ggplot(idx, aes(x = .data$row, y = .data$col,
                          fill = .data$distance)) +
      geom_tile() +
      scale_fill_gradient(low = "white", high = "steelblue") +
      labs(x = "Sequence", y = "Sequence",
           title = "Distance Matrix (ordered by cluster)",
           fill = "Distance") +
      theme_minimal() +
      theme(axis.text = element_blank(), axis.ticks = element_blank())
  }

  print(p)
  invisible(p)
}

# ==============================================================================
# 7. Covariate Analysis (post-hoc)
# ==============================================================================

#' Resolve covariates into a data.frame
#'
#' Handles five input forms: formula, character vector, string, data.frame,
#' NULL. Looks up column names in metadata or raw data as needed.
#' @param covariates User-supplied covariates argument
#' @param raw_data Original input to build_clusters (before extraction)
#' @param n Number of sequences
#' @return A list with \code{cov_df} (data.frame) and \code{rhs} (character
#'   right-hand side of the model formula), or NULL if covariates is NULL.
#' @noRd
.resolve_covariates <- function(covariates, raw_data, n) {
  if (is.null(covariates)) return(NULL)

  # Covariate analysis fits a multinomial model downstream; fail early if the
  # optional dependency is missing.
  if (!requireNamespace("nnet", quietly = TRUE)) {
    stop("The 'nnet' package is required for covariate analysis. ", # nocov
         "Install it with install.packages('nnet').", call. = FALSE) # nocov
  }

  # --- Covariates supplied as a data.frame: use all columns ---
  if (is.data.frame(covariates)) {
    if (nrow(covariates) != n) {
      stop(sprintf(
        "Covariate data.frame has %d rows but sequence data has %d rows.",
        nrow(covariates), n
      ), call. = FALSE)
    }
    rhs <- paste(names(covariates), collapse = " + ")
    return(list(cov_df = covariates, rhs = rhs))
  }

  # --- Parse into column names ---
  if (inherits(covariates, "formula")) {
    # Take the RHS term only: for `y ~ x` this is covariates[[3]], for `~ x`
    # it is covariates[[2]]. Using all.vars() on the whole formula would also
    # pull in a LHS variable and wrongly require it as a covariate column.
    rhs_term <- covariates[[length(covariates)]]
    # deparse() can split long expressions into multiple lines; collapse so
    # the rhs text is always a single string.
    rhs_text <- paste(deparse(rhs_term), collapse = " ")
    cov_names <- all.vars(rhs_term)
  } else if (is.character(covariates) && length(covariates) == 1L &&
             grepl("[+*:]", covariates)) {
    # String with formula operators, e.g. "Age + Gender" or "Age*Gender"
    rhs_text <- covariates
    cov_names <- all.vars(stats::as.formula(paste("~", covariates)))
  } else if (is.character(covariates)) {
    # Character vector like c("Age", "Gender")
    cov_names <- covariates
    rhs_text <- paste(cov_names, collapse = " + ")
  } else {
    stop("covariates must be a formula, character vector, string, or data.frame.",
         call. = FALSE)
  }

  if (length(cov_names) == 0L) {
    stop("No covariate variables specified.", call. = FALSE)
  }

  # --- Look up columns in metadata or raw data ---
  cov_df <- NULL

  # Try netobject: metadata first, then $data columns
  if (inherits(raw_data, "netobject")) {
    # Combine metadata and data columns as available sources; metadata wins
    # when a name appears in both.
    avail_df <- NULL
    avail_names <- character(0L)
    if (!is.null(raw_data$metadata)) {
      avail_df <- raw_data$metadata
      avail_names <- names(raw_data$metadata)
    }
    if (!is.null(raw_data$data) && is.data.frame(raw_data$data)) {
      extra <- setdiff(names(raw_data$data), avail_names)
      if (length(extra) > 0L) {
        avail_names <- c(avail_names, extra)
        avail_df <- if (is.null(avail_df)) {
          raw_data$data[, extra, drop = FALSE]
        } else {
          cbind(avail_df, raw_data$data[, extra, drop = FALSE])
        }
      }
    }
    if (all(cov_names %in% avail_names)) {
      cov_df <- avail_df[, cov_names, drop = FALSE]
    } else {
      missing <- setdiff(cov_names, avail_names)
      stop(sprintf(
        "Covariate columns not found in netobject: %s. Available: %s",
        paste(missing, collapse = ", "),
        paste(avail_names, collapse = ", ")
      ), call. = FALSE)
    }
  }

  # Try tna / cograph_network: require data.frame covariates
  if (is.null(cov_df) &&
      (inherits(raw_data, "tna") || inherits(raw_data, "cograph_network"))) {
    stop("Column-name covariates are not supported for tna/cograph_network ",
         "inputs. Pass covariates as a data.frame instead.", call. = FALSE)
  }

  # Try raw data.frame columns
  if (is.null(cov_df) && (is.data.frame(raw_data) || is.matrix(raw_data))) {
    avail <- colnames(raw_data)
    if (all(cov_names %in% avail)) {
      cov_df <- as.data.frame(raw_data[, cov_names, drop = FALSE],
                               stringsAsFactors = FALSE)
    } else {
      missing <- setdiff(cov_names, avail)
      stop(sprintf(
        "Covariate columns not found: %s. Available: %s",
        paste(missing, collapse = ", "),
        paste(avail, collapse = ", ")
      ), call. = FALSE)
    }
  }

  if (is.null(cov_df)) {
    stop("Could not resolve covariate columns from the input data.", # nocov start
         call. = FALSE) # nocov end
  }

  # Final row-count check covers netobject/data.frame lookup paths.
  if (nrow(cov_df) != n) { # nocov start
    stop(sprintf(
      "Covariate data has %d rows but sequence data has %d rows.",
      nrow(cov_df), n
    ), call. = FALSE) # nocov end
  }

  list(cov_df = cov_df, rhs = rhs_text)
}

#' Compute per-cluster descriptive profiles
#'
#' Summarises covariates within each cluster: mean/sd/median for numeric
#' columns, level counts and percentages for non-numeric columns.
#'
#' @param cov_df data.frame of covariates, one row per sequence.
#' @param assignments Integer vector of cluster memberships (values 1..k).
#' @param k Integer. Number of clusters.
#' @return List with \code{numeric} and \code{categorical} data.frames (either
#'   may be NULL when no columns of that type exist). Factor columns yield one
#'   row per declared level, including zero-count levels (table() semantics).
#' @noRd
.compute_cluster_profiles <- function(cov_df, assignments, k) {
  n_total <- length(assignments)
  num_cols <- vapply(cov_df, is.numeric, logical(1L))
  cat_cols <- !num_cols

  # --- Numeric profiles: one row per cluster x variable ---
  numeric_profile <- NULL
  if (any(num_cols)) {
    num_names <- names(cov_df)[num_cols]
    per_cluster <- lapply(seq_len(k), function(cl) {
      mask <- assignments == cl
      cl_n <- sum(mask)
      cl_pct <- round(100 * cl_n / n_total, 1)
      do.call(rbind, lapply(num_names, function(vname) {
        vals <- cov_df[[vname]][mask]
        # Drop NAs; an all-NA cell yields NaN/NA summaries, which is kept.
        vals <- vals[!is.na(vals)]
        data.frame(
          cluster = cl, n = cl_n, pct = cl_pct, variable = vname,
          mean = mean(vals), sd = stats::sd(vals),
          median = stats::median(vals),
          stringsAsFactors = FALSE, row.names = NULL
        )
      }))
    })
    numeric_profile <- do.call(rbind, per_cluster)
  }

  # --- Categorical profiles: one row per cluster x variable x level ---
  # Note: the previous implementation pre-counted rows via unique() observed
  # values, which undercounts for factors (table() also reports zero-count
  # levels), silently defeating the preallocation. Building each block as a
  # vectorized data.frame avoids manual counting entirely.
  categorical_profile <- NULL
  if (any(cat_cols)) {
    cat_names <- names(cov_df)[cat_cols]
    per_cluster <- lapply(seq_len(k), function(cl) {
      mask <- assignments == cl
      cl_n <- sum(mask)
      do.call(rbind, lapply(cat_names, function(vname) {
        tab <- table(cov_df[[vname]][mask], useNA = "no")
        if (length(tab) == 0L) return(NULL)
        counts <- as.integer(tab)
        data.frame(
          cluster = cl, n = cl_n, variable = vname,
          level = names(tab), count = counts,
          pct = round(100 * counts / cl_n, 1),
          stringsAsFactors = FALSE, row.names = NULL
        )
      }))
    })
    categorical_profile <- do.call(rbind, per_cluster)
  }

  list(numeric = numeric_profile, categorical = categorical_profile)
}

#' Run multinomial logistic regression on cluster assignments
#'
#' Fits \code{cluster ~ rhs} with \code{nnet::multinom} to test which
#' covariates predict cluster membership (post-hoc; does not affect the
#' clustering itself). Rows with any NA covariate are dropped with a warning.
#'
#' @param assignments Vector of cluster memberships, one per sequence.
#' @param cov_df data.frame of covariates (same number of rows), as produced
#'   by \code{.resolve_covariates}.
#' @param rhs Character right-hand side of the model formula.
#' @param k Integer. Number of clusters (forwarded to profile computation).
#' @return List with \code{profiles} (per-cluster descriptives),
#'   \code{coefficients} (long-format Wald table with odds ratios and 95% CIs),
#'   \code{fit} (AIC/BIC/deviance/McFadden R2/reference cluster), and
#'   \code{model} (the fitted \code{multinom} object).
#' @noRd
.run_covariate_analysis <- function(assignments, cov_df, rhs, k) {
  # --- Prepare data ---
  # Work on a copy; the original cov_df is reused for profiles below.
  fit_df <- cov_df

  # Coerce character to factor, unorder ordered factors
  # (ordered factors would otherwise get polynomial contrasts in the model).
  for (nm in names(fit_df)) {
    if (is.character(fit_df[[nm]])) {
      fit_df[[nm]] <- factor(fit_df[[nm]])
    } else if (is.ordered(fit_df[[nm]])) {
      fit_df[[nm]] <- factor(fit_df[[nm]], ordered = FALSE)
    }
  }

  # Handle NAs: subset to complete cases. Note `assignments` is subset in
  # lockstep and the filtered version is used for everything downstream.
  complete <- stats::complete.cases(fit_df)
  n_dropped <- sum(!complete)
  if (n_dropped > 0L) {
    warning(sprintf(
      "Dropped %d rows with NA covariates (%d remaining).",
      n_dropped, sum(complete)
    ), call. = FALSE)
    fit_df <- fit_df[complete, , drop = FALSE]
    assignments <- assignments[complete]
  }

  # Response: factor of cluster labels; its first level becomes the
  # multinomial reference category.
  fit_df$cluster <- factor(assignments)

  # Build formula inside the fitting environment
  fml <- stats::as.formula(paste("cluster ~", rhs))

  # Validate: no constant columns (a constant predictor makes the model
  # matrix rank-deficient).
  for (nm in setdiff(names(fit_df), "cluster")) {
    if (length(unique(fit_df[[nm]][!is.na(fit_df[[nm]])])) < 2L) {
      stop(sprintf("Covariate '%s' is constant. Remove it.", nm), call. = FALSE)
    }
  }

  # --- Fit multinomial logistic regression ---
  fit <- nnet::multinom(fml, data = fit_df, trace = FALSE)
  s <- summary(fit)

  # Warn if small clusters relative to parameter count.
  # n_params counts all coefficient cells across equations (for k = 2 the
  # summary returns a vector, so length() is the per-equation count).
  n_params <- length(s$coefficients)
  min_cl <- min(table(assignments))
  if (min_cl < n_params) {
    warning(sprintf(
      "Smallest cluster has %d observations but model has %d parameters. ",
      min_cl, n_params
    ), "Estimates may be unreliable.", call. = FALSE)
  }

  # --- Extract coefficients ---
  coefs <- s$coefficients
  ses <- s$standard.errors

  # Handle k=2 case: summary returns named vectors, not matrices.
  # Normalize to a 1-row matrix named after the non-reference cluster.
  if (is.null(dim(coefs))) {
    coefs <- matrix(coefs, nrow = 1L, dimnames = list(levels(fit_df$cluster)[2L],
                                                       names(coefs)))
    ses <- matrix(ses, nrow = 1L, dimnames = list(levels(fit_df$cluster)[2L],
                                                    names(ses)))
  }

  # Build output data.frame: one row per (cluster equation, term) with Wald
  # z-test, two-sided normal p-value, odds ratio, and 95% Wald CI on the
  # exponentiated scale.
  rows <- vector("list", nrow(coefs) * ncol(coefs))
  idx <- 0L
  for (i in seq_len(nrow(coefs))) {
    for (j in seq_len(ncol(coefs))) {
      idx <- idx + 1L
      est <- coefs[i, j]
      se <- ses[i, j]
      z_val <- est / se
      p_val <- 2 * (1 - stats::pnorm(abs(z_val)))
      or <- exp(est)
      ci_lo <- exp(est - 1.96 * se)
      ci_hi <- exp(est + 1.96 * se)
      # Conventional significance stars (including "." for p < 0.1).
      sig <- if (is.na(p_val)) "" else if (p_val < 0.001) "***" else
        if (p_val < 0.01) "**" else if (p_val < 0.05) "*" else
        if (p_val < 0.1) "." else ""
      rows[[idx]] <- data.frame(
        cluster = rownames(coefs)[i],
        variable = colnames(coefs)[j],
        estimate = est, std_error = se,
        odds_ratio = or, ci_lower = ci_lo, ci_upper = ci_hi,
        z = z_val, p = p_val, sig = sig,
        stringsAsFactors = FALSE, row.names = NULL
      )
    }
  }
  coef_df <- do.call(rbind, rows)

  # --- Model fit ---
  # McFadden pseudo-R2 = 1 - deviance(model) / deviance(intercept-only model).
  null_fit <- nnet::multinom(cluster ~ 1, data = fit_df, trace = FALSE) # nolint
  null_dev <- -2 * as.numeric(stats::logLik(null_fit))
  model_dev <- -2 * as.numeric(stats::logLik(fit))
  mcfadden_r2 <- 1 - model_dev / null_dev

  ref_cluster <- levels(fit_df$cluster)[1L]

  # --- Profiles (on complete-case data) ---
  # Use the same complete-case rows as the model so profiles and coefficients
  # describe the same sample.
  profiles_df <- if (n_dropped > 0L) cov_df[complete, , drop = FALSE] else cov_df
  profiles <- .compute_cluster_profiles(profiles_df, assignments, k)

  list(
    profiles = profiles,
    coefficients = coef_df,
    fit = list(
      aic = stats::AIC(fit),
      bic = stats::BIC(fit),
      deviance = model_dev,
      mcfadden_r2 = mcfadden_r2,
      reference_cluster = ref_cluster
    ),
    model = fit
  )
}

#' Print covariate profiles and coefficient table (shared by build_clusters/mmm)
#'
#' Renders, in order: numeric cluster profiles (one wide row per cluster),
#' categorical profiles (one "variable=level" column per combination), and
#' the multinomial coefficient table with intercept rows suppressed.
#' @noRd
.print_covariate_profiles <- function(cov, header = "Post-hoc Covariate Analysis (does not influence cluster membership)") {
  cat(header, "\n\n")

  # Numeric profiles: one wide row per cluster
  np <- cov$profiles$numeric
  if (!is.null(np)) {
    cat("Cluster Profiles (numeric):\n")
    num_vars <- unique(np$variable)
    pieces <- list()
    for (cl in unique(np$cluster)) {
      chunk <- np[np$cluster == cl, , drop = FALSE]
      line <- data.frame(
        Cluster = cl,
        `N (%)` = sprintf("%d (%.0f%%)", chunk$n[1L], chunk$pct[1L]),
        stringsAsFactors = FALSE, check.names = FALSE
      )
      for (v in num_vars) {
        stats_row <- chunk[chunk$variable == v, ]
        line[[sprintf("%s Mean (SD)", v)]] <-
          sprintf("%.2f (%.2f)", stats_row$mean, stats_row$sd)
        line[[sprintf("%s Median", v)]] <- sprintf("%.2f", stats_row$median)
      }
      pieces[[length(pieces) + 1L]] <- line
    }
    print(do.call(rbind, pieces), row.names = FALSE, right = FALSE)
    cat("\n")
  }

  # Categorical profiles: same columns for every cluster so rbind() lines up
  cp <- cov$profiles$categorical
  if (!is.null(cp)) {
    cat("Cluster Profiles (categorical):\n")
    combo_cols <- unique(paste0(cp$variable, "=", cp$level))
    pieces <- list()
    for (cl in unique(cp$cluster)) {
      chunk <- cp[cp$cluster == cl, , drop = FALSE]
      line <- data.frame(
        Cluster = cl, N = chunk$n[1L],
        stringsAsFactors = FALSE, check.names = FALSE
      )
      for (combo in combo_cols) {
        parts <- strsplit(combo, "=", fixed = TRUE)[[1L]]
        var_name <- parts[1L]
        lev_name <- paste(parts[-1L], collapse = "=")  # levels may contain "="
        hit <- chunk[chunk$variable == var_name & chunk$level == lev_name, ]
        line[[sprintf("%s N(%%)", combo)]] <- if (nrow(hit) > 0L) {
          sprintf("%d (%.0f%%)", hit$count, hit$pct)
        } else {
          "0 (0%)"
        }
      }
      pieces[[length(pieces) + 1L]] <- line
    }
    print(do.call(rbind, pieces), row.names = FALSE, right = FALSE)
    cat("\n")
  }

  # Coefficient table (intercepts are not interpretable as odds ratios)
  cat(sprintf("Predictors of Membership (reference: Cluster %s):\n",
              cov$fit$reference_cluster))
  keep <- cov$coefficients$variable != "(Intercept)"
  coef_tab <- cov$coefficients[keep, , drop = FALSE]
  shown <- data.frame(
    Cluster = coef_tab$cluster,
    Variable = coef_tab$variable,
    OR = sprintf("%.2f", coef_tab$odds_ratio),
    `95% CI` = sprintf("[%.2f, %.2f]", coef_tab$ci_lower,
                        coef_tab$ci_upper),
    p = ifelse(coef_tab$p < 0.001, "<0.001",
               sprintf("%.3f", coef_tab$p)),
    Sig = coef_tab$sig,
    stringsAsFactors = FALSE, check.names = FALSE
  )
  print(shown, row.names = FALSE, right = FALSE)
  cat(sprintf("\nModel: AIC = %.1f | BIC = %.1f | McFadden R-squared = %.2f\n",
              cov$fit$aic, cov$fit$bic, cov$fit$mcfadden_r2))
}

#' Odds ratio forest plot (shared by build_clusters/mmm)
#'
#' One panel per cluster; points are odds ratios with 95% CI bars on a log
#' x-axis, coloured by p < 0.05. Prints the plot and returns it invisibly.
#' @noRd
.plot_covariate_forest <- function(coef_df, title) {
  # Intercepts are not interpretable as odds ratios; drop them up front.
  keep <- coef_df$variable != "(Intercept)"
  plot_df <- coef_df[keep, , drop = FALSE]
  plot_df$cluster <- factor(plot_df$cluster)
  plot_df$significant <- plot_df$p < 0.05

  forest <- ggplot(plot_df, aes(
    x = .data$odds_ratio, y = .data$variable,
    xmin = .data$ci_lower, xmax = .data$ci_upper,
    colour = .data$significant
  ))
  forest <- forest +
    geom_vline(xintercept = 1, linetype = "dashed", alpha = 0.5) +
    geom_pointrange(size = 0.6) +
    facet_wrap(~ .data$cluster, labeller = label_both) +
    scale_colour_manual(
      values = c("TRUE" = "#D62728", "FALSE" = "#7F7F7F"),
      labels = c("TRUE" = "p < 0.05", "FALSE" = "n.s."),
      name = NULL
    ) +
    scale_x_log10() +
    labs(x = "Odds Ratio (log scale)", y = NULL, title = title) +
    theme_minimal()

  print(forest)
  invisible(forest)
}

# ==============================================================================
# 8. build_network dispatch for net_clustering
# ==============================================================================

#' Build per-cluster networks from clustering results
#'
#' Called by \code{build_network()} when data is a \code{net_clustering}
#' object. Splits sequence data by cluster assignment and builds a network
#' for each cluster.
#'
#' @param x A \code{net_clustering} object.
#' @param method Network estimation method (default: "relative").
#' @param ... Passed to \code{build_network()}.
#' @return A \code{netobject_group}.
#' @noRd
.build_network_clustering <- function(x, method = "relative", ...) {
  # Merge stored build_args with caller's ...; caller takes precedence
  extra_args <- list(...)
  merged_args <- if (!is.null(x$build_args)) {
    modifyList(x$build_args, extra_args)
  } else {
    extra_args
  }

  # One network per cluster, built from that cluster's row subset
  nets <- vector("list", x$k)
  for (cl in seq_len(x$k)) {
    cluster_rows <- x$data[x$assignments == cl, , drop = FALSE]
    nets[[cl]] <- do.call(
      build_network,
      c(list(data = cluster_rows, method = method), merged_args)
    )
  }
  names(nets) <- paste("Cluster", seq_len(x$k))

  structure(
    nets,
    group_col = "cluster",
    clustering = x,
    class = "netobject_group"
  )
}

# ==============================================================================
# 9. cluster_network — one-shot clustering + network estimation
# ==============================================================================

#' Cluster data and build per-cluster networks in one step
#'
#' Combines sequence clustering and network estimation into a single call.
#' Clusters the data using the specified algorithm, then calls
#' \code{\link{build_network}} on each cluster subset.
#'
#' If \code{data} is a \code{netobject} and \code{method} is not provided in
#' \code{...}, the original network method is inherited automatically so the
#' per-cluster networks match the type of the input network.
#'
#' @param data Sequence data. Accepts a data frame, matrix, or
#'   \code{netobject}. See \code{\link{build_clusters}} for supported formats.
#' @param k Integer. Number of clusters.
#' @param cluster_by Character. Clustering algorithm passed to
#'   \code{\link{build_clusters}}'s \code{method} parameter (\code{"pam"},
#'   \code{"ward.D2"}, \code{"ward.D"}, \code{"complete"}, \code{"average"},
#'   \code{"single"}, \code{"mcquitty"}, \code{"median"}, \code{"centroid"}),
#'   or \code{"mmm"} for Mixed Markov Model clustering. Default: \code{"pam"}.
#' @param dissimilarity Character. Distance metric for sequence clustering
#'   (ignored when \code{cluster_by = "mmm"}). Default: \code{"hamming"}.
#' @param ... Passed directly to \code{\link{build_network}}. Use
#'   \code{method} to specify the network type; \code{threshold},
#'   \code{scaling}, and all other \code{build_network} arguments are
#'   supported.
#' @return A \code{netobject_group}.
#' @seealso \code{\link{build_clusters}}, \code{\link{cluster_mmm}},
#'   \code{\link{build_network}}
#' @examples
#' seqs <- data.frame(V1 = c("A","B","C","A","B"), V2 = c("B","C","A","B","A"),
#'                    V3 = c("C","A","B","C","B"))
#' grp <- cluster_network(seqs, k = 2)
#' grp
#' \donttest{
#' seqs <- data.frame(
#'   V1 = sample(LETTERS[1:4], 50, TRUE), V2 = sample(LETTERS[1:4], 50, TRUE),
#'   V3 = sample(LETTERS[1:4], 50, TRUE), V4 = sample(LETTERS[1:4], 50, TRUE)
#' )
#' # Default: PAM clustering, relative (transition) networks
#' grp <- cluster_network(seqs, k = 3)
#'
#' # Specify network method (cor requires numeric panel data)
#' \dontrun{
#' panel <- as.data.frame(matrix(rnorm(1500), nrow = 300, ncol = 5))
#' grp <- cluster_network(panel, k = 2, method = "cor")
#' }
#'
#' # MMM-based clustering
#' grp <- cluster_network(seqs, k = 2, cluster_by = "mmm")
#' }
#' @export
cluster_network <- function(data, k, cluster_by = "pam",
                             dissimilarity = "hamming", ...) {
  net_args <- list(...)

  # Normalize legacy cograph_network objects to netobject first
  if (inherits(data, "cograph_network") && !inherits(data, "netobject")) {
    data <- .as_netobject(data)
  }

  # Inherit build_args and method from an input netobject unless the caller
  # supplied them explicitly (caller's ... wins)
  if (inherits(data, "netobject")) {
    if (!is.null(data$build_args)) {
      net_args <- modifyList(data$build_args, net_args)
    }
    if (is.null(net_args$method)) {
      net_args$method <- data$method
    }
  }

  # Cluster with MMM or a distance-based algorithm, then estimate networks
  clustered <- if (identical(cluster_by, "mmm")) {
    build_mmm(data, k = k)
  } else {
    build_clusters(data, k = k, method = cluster_by,
                   dissimilarity = dissimilarity)
  }
  do.call(build_network, c(list(data = clustered), net_args))
}

# NOTE(review): the lines below are scraped website footer text, not R code;
# commented out so the file parses. Original text preserved:
# Try the Nestimate package in your browser
# Any scripts or data that you put into this service are public.
# Nestimate documentation built on April 20, 2026, 5:06 p.m.