# Linear Discriminant Analysis
library(Matrix)
source("covariance.R")
# LDA computations
# Quadratic form mu' Sigma^{-1} mu, where mu is row i of `means`.
# lda2 halves these values to get the constant term of each group's linear
# discriminant score. solve(Sigma, mu) solves the linear system directly
# instead of forming Sigma^{-1}.
di = function(i, means, Sigma)
{
    mu = means[i, ]
    # NOTE(review): whether solve() reuses a Cholesky factor cached on Sigma
    # depends on Sigma's class (e.g. a Matrix object) -- confirm if it
    # matters for speed.
    as.numeric(mu %*% solve(Sigma, mu))
}
# Fit a linear discriminant analysis (LDA) model.
#
# X0:     numeric matrix of training data, one row per observation.
# groups: class label for each row of X0.
#
# Returns a list of class "lda2" with elements
#   Sigma: pooled within-group covariance, computed by cov_Matrix_pkg on the
#          data after subtracting each row's group mean;
#   d:     d[j] = mu_j' Sigma^{-1} mu_j / 2 for each group mean mu_j;
#   means: Matrix of group means, one row per group, in the level order of
#          factor(groups).
# No prior term is stored, so classification corresponds to equal priors.
lda2 = function(X0, groups)
{
    # factor() drops unused levels, so every level has at least one
    # observation and the rows of `means` line up with as.integer(g).
    g = factor(groups)
    # Each row contains a group mean, in the level order of g.
    means = by(X0, g, colMeans)
    means = do.call(rbind, means)
    # LDA uses the pooled within-group covariance, as MASS::lda does, so
    # center every observation on its own group mean before taking the
    # covariance. The covariance of the uncentered data (total covariance)
    # also includes the spread between the group means and changes the
    # predictions.
    # NOTE(review): assumes cov_Matrix_pkg returns a column-centered
    # cross-product divided by a positive constant. That constant scales
    # every score in predict.lda2 equally, so the predicted class is
    # unaffected.
    Xc = as.matrix(X0) - means[as.integer(g), , drop = FALSE]
    Sigma = cov_Matrix_pkg(Xc)
    means = Matrix(means)
    # Cholesky decompositions are cached. Doing it here so it propagates into
    # the functions.
    chol(Sigma)
    # One quadratic form per group. The group count comes from the data, not
    # from a global variable.
    d = vapply(seq_len(nrow(means)), di, numeric(1), means = means,
               Sigma = Sigma)
    d = d / 2
    out = list(Sigma = Sigma, d = d, means = means)
    class(out) = "lda2"
    out
}
# Predict group membership for new observations with an "lda2" fit.
#
# fit: object returned by lda2 (elements Sigma, d, means).
# X:   matrix of new observations, one row per observation, with the same
#      columns as the training data.
# ...: ignored. It is accepted so the method stays compatible with the
#      predict() generic, which passes extra arguments through.
#
# Returns, for each row x of X, the index j of the row of fit$means with the
# largest discriminant score
#     mu_j' Sigma^{-1} x - d[j].
# This is a row index of fit$means, not the group label itself. The two
# coincide only when the labels are 1, 2, ..., K. Ties go to the first
# group, following which.max.
predict.lda2 = function(fit, X, ...)
{
    Sigma = fit$Sigma
    d = fit$d
    means = fit$means
    # A single solve for all observations: Sigma^{-1} X'.
    Sigma_inv_Xt = solve(Sigma, t(X))
    # Score matrix with one row per group. d has one entry per group, and
    # column-major recycling subtracts d[j] from every entry of row j.
    obj = means %*% Sigma_inv_Xt - d
    maxs = apply(obj, 2, which.max)
    maxs
}
# Testing data:
############################################################
# Compare lda2 with MASS::lda. The training set has k equally sized groups of
# standard normal data; the test set is independent data with the same
# number of columns.
library(MASS)
n = 10000       # training observations
n_test = 10000  # test observations
p = 50
k = 4
set.seed(891234)
X0 = matrix(rnorm(n * p), ncol = p)
colnames(X0) = paste0("X", seq_len(p))
groups = rep(1:k, length.out = n)
X = Matrix(rnorm(n_test * p), ncol = p)
# The formula-based lda fit predicts from a data frame whose columns carry
# the training column names.
Xd = as.data.frame(as.matrix(X))
colnames(Xd) = colnames(X0)
X0groups = data.frame(X0, groups)
fit = lda(groups ~ ., X0groups)
# The labels are 1:k, so the integer codes of the predicted factor equal the
# labels and can be compared directly with lda2's group indices.
p0 = as.integer(predict(fit, Xd)$class)
fit2 = lda2(X0, groups)
p1 = predict(fit2, X)
# Fraction of test observations on which the two classifiers agree.
mean(p0 == p1)
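# 1 in 10000 test predictions differs from MASS::lda; the cause has not been
# confirmed. One thing to check: MASS::lda estimates the pooled within-group
# covariance (each observation centered on its own group mean), so lda2
# must use the same estimate. The centering described in the MASS docs below
# should not be the cause: it shifts the test points and the group centroids
# by the same amount, so the predicted classes do not change.
# From the MASS::lda documentation:
#
#     This version centres the linear discriminants so that the weighted
#     mean (weighted by ‘prior’) of the group centroids is at the
#     origin.
#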
# Timings
############################################################
# Benchmarks and profiling, disabled by default. Change FALSE to TRUE, or run
# the lines interactively, to reproduce.
if (FALSE)
{
    library(microbenchmark)
    # Inside a braced block only the last value is auto-printed, so print
    # the benchmark results explicitly.
    print(microbenchmark(lda(groups ~ ., X0groups), times = 10L))
    print(microbenchmark(lda2(X0, groups), times = 10L))
    # Observed: lda2 is 2-3x faster than MASS::lda.
    # Where does lda2 spend its time?
    # - The covariance calculation: over 40%.
    # - `by`: 48%. That is inefficient, since the column means can be
    #   computed in place with a single pass through the data. data.table is
    #   probably much better at this.
    # - scale(): 19%, half the time of the covariance calculation. The slow
    #   part of scale() is sweep(); all we really need is to subtract the
    #   column means.
    Rprof("lda.out")
    replicate(100, lda2(X0, groups))
    Rprof(NULL)
    summaryRprof("lda.out")
}
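# NOTE(review): the two lines below are web-page residue (an embed-code
# notice), not R. They are commented out so the file can be sourced.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.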