R/proto.R

Defines functions proto

Documented in proto

#' Extract the prototype from each variable group
#'
#' Extracts the prototype (a single representative variable) from each
#' variable group. Each column of `X` is centered and scaled to unit
#' Euclidean norm, `y` is centered, and the prototype is chosen from the
#' absolute inner products between them, i.e. the absolute correlations
#' with `y` up to a common positive factor.
#'
#' @param X Predictor matrix.
#' @param y Response matrix with one column.
#' @param groups A group index vector containing the
#'   group number each variable belongs to. For example:
#'   `c(1, 1, 1, 1, 1, 2, 2, 2, ...)`.
#'   Variable groups can be generated by the Fisher optimal
#'   partition algorithm implemented in [FOP()].
#' @param type The rule for extracting the prototype.
#'   Possible options are `"max"` (default) and `"median"`.
#'   `"max"` picks the variable with the largest absolute correlation
#'   with `y`. `"median"` picks the variable whose absolute correlation
#'   is the median of its group. For groups of even size, this is the
#'   upper of the two middle values.
#' @param mu The mean value of `y` for standardization.
#'   Default is `NULL`, which uses the sample mean of `y`.
#'
#' @return An integer vector of prototypes (column indices of `X`),
#'   one per group, ordered by sorted group label. An entry is `NA`
#'   when no correlation in that group can be computed (e.g. every
#'   column in the group is constant, giving `NaN` after scaling).
#'
#' @export proto
proto <- function(X, y, groups, type = c("max", "median"), mu = NULL) {
  # Resolve the default c("max", "median") to "max" and reject bad values
  type <- match.arg(type)
  X <- as.matrix(X)

  # Check inputs
  if (length(y) != nrow(X)) {
    stop("X and y should have the same number of rows")
  }
  if (length(groups) != ncol(X)) {
    stop("groups should have one entry per column of X")
  }

  # Standardize X: center each column, then scale it to unit Euclidean
  # norm. sweep() keeps the matrix shape even when X has a single row,
  # whereas apply() would collapse the result to a vector.
  X.tilde <- sweep(X, 2L, colMeans(X))
  X.tilde <- sweep(X.tilde, 2L, sqrt(colSums(X.tilde^2)), "/")

  # Center y by mu if supplied, otherwise by its sample mean
  y.tilde <- if (is.null(mu)) y - mean(y) else y - mu

  # Iterate over the actual (sorted) group labels rather than 1:K, so that
  # non-consecutive labels such as c(1, 1, 3, 3) are handled correctly
  unique.groups <- sort(unique(groups))

  vapply(unique.groups, function(g) {
    var.index <- which(groups == g)
    abs.cor <- abs(drop(
      crossprod(X.tilde[, var.index, drop = FALSE], y.tilde)
    ))

    pos <- if (type == "max") {
      # First variable with the maximum absolute correlation
      which.max(abs.cor)
    } else {
      # Median rule: sorted position k %/% 2 + 1 is the middle element
      # for odd k and the upper-middle one for even k (the latter equals
      # "drop the minimum, then take the median"). Positions refer to the
      # full abs.cor, so they map straight onto var.index. Ties resolve
      # to the first variable holding the median value.
      mid <- order(abs.cor)[length(abs.cor) %/% 2L + 1L]
      which(abs.cor == abs.cor[mid])
    }

    # pos[1L] is NA when nothing could be selected (all NaN), giving NA
    var.index[pos[1L]]
  }, integer(1L))
}
road2stat/ohpl documentation built on July 24, 2024, 6:41 p.m.