R/proto.R

Defines functions proto

Documented in proto

#' Extract the Prototype from Each Variable Group
#'
#' This function extracts one prototype variable from each variable group.
#' Every variable is scored by the absolute inner product between its
#' centered, unit-norm column and the centered response. Within a fixed
#' response this score is proportional to the absolute correlation with
#' \code{y}. The prototype of each group is then chosen from these scores
#' according to \code{type}.
#'
#' @param X Predictor matrix.
#' @param y Response matrix with one column (a numeric vector also works).
#' @param groups A group index vector containing the
#' group number each variable belongs to. For example:
#' \code{c(1, 1, 1, 1, 1, 2, 2, 2, ...)}.
#' Its length must equal \code{ncol(X)}.
#' Variable groups can be generated by the Fisher optimal
#' partition algorithm implemented in \code{\link{FOP}}.
#' @param type The rule for extracting the prototype.
#' Possible options are \code{"max"} (the default) and \code{"median"}.
#' \code{"max"} picks the variable with the largest score (the first one
#' on ties). \code{"median"} picks the variable whose score is the median
#' of its group. For groups with an even number of variables, this is the
#' upper of the two middle values.
#' @param mu The mean value of \code{y} for standardization.
#' Default is \code{NULL}, which uses the sample mean
#' of \code{y}.
#'
#' @return An integer vector of column indices of \code{X}: the prototype
#' (variable index) extracted from each group (cluster), ordered by the
#' sorted group labels.
#'
#' @export proto

proto <- function(X, y, groups, type = c("max", "median"), mu = NULL) {
  # resolve the default c("max", "median") to "max" and reject unknown rules
  type <- match.arg(type)

  # check inputs
  if (length(y) != nrow(X)) {
    stop("X and y should have the same number of rows")
  }
  if (length(groups) != ncol(X)) {
    stop("groups should have one entry per column of X")
  }

  # center each column of X, then scale it to unit Euclidean norm
  X.tilde <- apply(X, 2L, function(col) col - mean(col))
  X.tilde <- apply(X.tilde, 2L, function(col) col / sqrt(sum(col^2)))

  # center y on mu if supplied, otherwise on its sample mean
  y.tilde <- if (is.null(mu)) y - mean(y) else y - mu

  # score every variable at once: |<x_j, y>| with x_j of unit norm
  abs.cor.all <- abs(drop(t(X.tilde) %*% y.tilde))

  # Sorted so the usual labels 1..K are visited in order; non-consecutive
  # labels (e.g. c(1, 1, 3, 3)) are handled as well.
  group.labels <- sort(unique(groups))
  col.ind <- integer(length(group.labels))

  for (i in seq_along(group.labels)) {
    var.index <- which(groups == group.labels[i])
    abs.cor <- abs.cor.all[var.index]

    if (type == "max") {
      # largest score; [1L] yields NA (not a zero-length value) if all NaN
      pos <- which.max(abs.cor)[1L]
    } else {
      # For an even-sized group, drop the smallest score first so the
      # median is an actual element (the upper middle value) and can be
      # matched exactly with `==`.
      keep <- seq_along(abs.cor)
      if (length(keep) %% 2L == 0L) {
        keep <- keep[-which.min(abs.cor)]
      }
      med.cor <- median(abs.cor[keep])
      # translate the match back to a position in the full group;
      # take the first match when the median value is tied
      pos <- keep[which(abs.cor[keep] == med.cor)[1L]]
    }

    col.ind[i] <- var.index[pos]
  }

  col.ind
}
road2stat/OHPL documentation built on Feb. 4, 2023, 6:24 a.m.