#' Discriminant analysis with Trace Regularization (DTR)
#'
#' Finds a low-dimensional discriminant subspace that maximizes
#' the between-class scatter while controlling the within-class scatter.
#'
#' The projection directions are the `d` leading eigenvectors of
#' `(1 - alpha) * M - alpha * S`, where `M` and `S` are the between-class and
#' within-class scatter matrices of the preprocessed data. When `p > n`, the
#' scatter matrices are computed after projecting the data onto the first
#' `n - 1` right singular vectors of the preprocessed data, and the resulting
#' eigenvectors are mapped back to the original `p`-dimensional space.
#'
#' @param X numeric matrix of predictors, of dimension n x p.
#' @param Y class labels of length n; coerced with `as.factor()`.
#' @param preproc A pre-processing specification passed to
#'   `multivarious::prep()`. Default is `multivarious::center()`.
#' @param d integer, the dimension of the discriminant subspace. Must satisfy
#'   `1 <= d <= K - 1`, where K is the number of classes.
#' @param alpha numeric scalar, tuning parameter in [0,1] that controls the
#'   trade-off between between-class and within-class scatters. `alpha = 0`
#'   uses only the between-class scatter; `alpha = 1` uses only the (negated)
#'   within-class scatter.
#'
#' @return An S3 object created by `multivarious::discriminant_projector()`
#'   (with additional class "dtrda") holding the p x d projection matrix `v`,
#'   the n x d scores `s` of the training data, their column standard
#'   deviations `sdev`, the fitted pre-processor, and the class labels.
#'
#' @references
#' Ahn, J., Chung, H. C., & Jeon, Y. (2021). Trace Ratio Optimization for High-Dimensional Multi-Class Discrimination. Journal of Computational and Graphical Statistics, 30(1), 192-203. \doi{10.1080/10618600.2020.1807352}
#'
#' @examples
#' X <- matrix(rnorm(100 * 1000), 100, 1000)
#' y <- sample(1:3, 100, replace = TRUE)
#' fit <- dtrda(X, y, d = 2, alpha = 0.5)
#' # project data onto the discriminant subspace
#' Xp <- multivarious::project(fit, X)
#'
#' @export
dtrda <- function(X, Y, preproc = multivarious::center(), d = 2, alpha) {
  Y <- as.factor(Y)

  # Validate inputs up front so failures are reported before any heavy work.
  assertthat::assert_that(
    length(Y) == nrow(X),
    msg = "Y must contain one class label per row of X"
  )
  assertthat::assert_that(
    is.numeric(alpha), length(alpha) == 1, !is.na(alpha),
    alpha >= 0, alpha <= 1,
    msg = "alpha must be a single number in [0, 1]"
  )

  K <- length(unique(Y))
  # The message must be passed as `msg =`; an unnamed string would itself be
  # evaluated as an assertion and always fail.
  assertthat::assert_that(
    is.numeric(d), length(d) == 1, !is.na(d), d >= 1, d <= K - 1,
    msg = "d must be between 1 and the number of classes minus 1"
  )

  procres <- multivarious::prep(preproc)
  Xp <- init_transform(procres, X)
  n <- nrow(Xp)
  p <- ncol(Xp)

  high_dim <- p > n
  if (high_dim) {
    # Orthonormal basis of the row space of Xp: the right singular vectors
    # (p x (n - 1)). The left singular vectors `u` are n x n and would not be
    # conformable with Xp %*% P. Only n - 1 directions are kept, matching the
    # rank of column-centered data.
    P <- svd(Xp, nu = 0, nv = n - 1)$v
    # Coordinates of the data in the row-space basis (n x (n - 1)).
    Z <- Xp %*% P
  } else {
    Z <- Xp
  }

  assertthat::assert_that(
    d <= ncol(Z),
    msg = "d cannot exceed the number of available data dimensions"
  )

  # Scatter matrices in the (possibly reduced) space.
  M <- between_class_scatter(Z, Y)
  S <- within_class_scatter(Z, Y)

  # Regularized trace-difference matrix.
  B <- (1 - alpha) * M - alpha * S

  # B is generally indefinite, so the leading eigenvectors must be chosen by
  # algebraic (not absolute) eigenvalue. Forcing the symmetric solver
  # guarantees real eigenvalues in decreasing order.
  # NOTE(review): assumes the scatter helpers return symmetric matrices.
  eig <- eigen(B, symmetric = TRUE)
  U <- eig$vectors[, seq_len(d), drop = FALSE]

  # Map eigenvectors back to the original p-dimensional space if needed.
  V <- if (high_dim) P %*% U else U

  s <- Xp %*% V
  multivarious::discriminant_projector(
    v = V,
    s = s,
    sdev = apply(s, 2, stats::sd),
    preproc = procres,
    labels = Y,
    classes = "dtrda"
  )
}
# NOTE: the two lines below are website boilerplate captured with this listing, not R code.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.