---
title: "MNIST investigation"
author: "Eric Bridgeford"
date: "December 14, 2017"
output: html_document
---


require(ggplot2)
require(lol)
require(reshape2)
require(Rmisc)
require(randomForest)
require(gridExtra)
require(latex2exp)

MNIST

In this notebook, we investigate the performance of our Labelled High Dimensionality, Low Sample Size (LHDLSS) algorithms in the context of the popular MNIST database. The MNIST database consists of 60,000 training images and 10,000 testing images, each of which is a handwriting sample of a single numeric character from 0 through 9. Below, we will only use subsets of the training examples for our investigations.

Investigation

We will begin by loading in the MNIST dataset. The loading code is borrowed from Brendan O'Connor's MNIST loading functions (gist.github.com/39760):

# Load the MNIST digit recognition dataset into R
# http://yann.lecun.com/exdb/mnist/
# assume you have all 4 files and gunzip'd them
# creates train$n, train$x, train$y  and test$n, test$x, test$y
# e.g. train$x is a 60000 x 784 matrix, each row is one digit (28x28)
# call:  show_digit(train$x[5,])   to see a digit.
# brendan o'connor - gist.github.com/39760 - anyall.org

load_mnist <- function(sourcedir='.') {
  load_image_file <- function(filename) {
    ret = list()
    f = file(filename,'rb')
    readBin(f,'integer',n=1,size=4,endian='big')
    ret$n = readBin(f,'integer',n=1,size=4,endian='big')
    nrow = readBin(f,'integer',n=1,size=4,endian='big')
    ncol = readBin(f,'integer',n=1,size=4,endian='big')
    x = readBin(f,'integer',n=ret$n*nrow*ncol,size=1,signed=F)
    ret$x = matrix(x, ncol=nrow*ncol, byrow=T)
    close(f)
    ret
  }
  load_label_file <- function(filename) {
    f = file(filename,'rb')
    readBin(f,'integer',n=1,size=4,endian='big')
    n = readBin(f,'integer',n=1,size=4,endian='big')
    y = readBin(f,'integer',n=n,size=1,signed=F)
    close(f)
    y
  }
  train <- list()
  test <- list()

  train$x <- load_image_file(paste(sourcedir, 'mnist/train-images-idx3-ubyte', sep='/'))
  test$x <- load_image_file(paste(sourcedir, 'mnist/t10k-images-idx3-ubyte', sep='/'))

  train$y <- load_label_file(paste(sourcedir, 'mnist/train-labels-idx1-ubyte', sep='/'))
  test$y <- load_label_file(paste(sourcedir, 'mnist/t10k-labels-idx1-ubyte', sep='/'))
  return(list(X.train=train$x, X.test=test$x, Y.train=train$y, Y.test=test$y))
}

show_digit <- function(png, title="",xlabel="Pixel", ylabel="Pixel", legend.name="metric", legend.show=TRUE,
                                 font.size=12, limits=NULL) {
  mtx <- matrix(png, nrow=28)[,28:1]
  dm <- reshape2::melt(mtx)
  if (is.null(limits)) {
    limits <- c(min(mtx), max(mtx))
  }
  colnames(dm) <- c("x", "y", "value")
  jet.colors <- colorRampPalette(c("#00007F", "blue", "#007FFF", "cyan", "#7FFF7F", "yellow", "#FF7F00", "red", "#7F0000"))
  sqplot <- ggplot(dm, aes(x=x, y=y, fill=value)) +
    geom_tile() +
    scale_fill_gradientn(colours=jet.colors(7), name=legend.name, limits=limits) +
    xlab(xlabel) +
    ylab(ylabel) +
    ggtitle(title)
  if (legend.show) {
    sqplot <- sqplot +
      theme(text=element_text(size=font.size))
  } else {
    sqplot <- sqplot +
      theme(text=element_text(size=font.size, legend.position="none"))
  }
  return(sqplot)
}

result <- load_mnist('/home/eric/Downloads/')

We visualize a few digits to ensure we have loaded the dataset properly:

nd <- 6
rdigs <- sample(result$X.train$n, nd, replace=FALSE)

plotlist <- list()
for (i in 1:length(rdigs)) {
  plotlist[[i]] <- show_digit(result$X.train$x[rdigs[i],], title=result$Y.train[rdigs[i]]) + theme(legend.position="none")
}
Rmisc::multiplot(plotlist=plotlist, cols=ceiling(sqrt(nd)))

Experiments

For the purposes of these experiments, we will use two 2-class setups and one 3-class setup. In the first 2-class setup, we will assess performance distinguishing the set of 3s from the set of 7s. These digits are relatively distinct from one another, so we expect fairly good performance. In the second 2-class setup, we will assess performance distinguishing the set of 1s from the set of 7s. The difference here is much smaller, particularly depending on how someone writes their 7s or 1s, so we expect more difficulty with classification. In the 3-class setup, we assess performance distinguishing the set of 3s, 1s, and 7s. For each algorithm, we will visualize the top 9 components of the projection matrix with $r=10$ on the first 2-class setup between the 3s and 7s, as well as 9 randomly chosen data examples projected onto the top 10 components. We subset the data below:

# 3s and 7s
set1 <- list()
set1$X <- result$X.train$x[result$Y.train == 3 | result$Y.train == 7,]
set1$Y <- result$Y.train[result$Y.train == 3 | result$Y.train == 7]

# 1s and 7s
set2 <- list()
set2$X <- result$X.train$x[result$Y.train == 1 | result$Y.train == 7,]
set2$Y <- result$Y.train[result$Y.train == 1 | result$Y.train == 7]

# 3s, 1s, and 7s
set3 <- list()
set3$X <- result$X.train$x[result$Y.train == 1 | result$Y.train == 7 | result$Y.train == 3,]
set3$Y <- result$Y.train[result$Y.train == 1 | result$Y.train == 7 | result$Y.train == 3]

PCA

r <- 10
nplot <- 9
pca <- lol.project.pca(set1$X, set1$Y, r=r)

plotlist <- list()
for (i in 1:nplot) {
  plotlist[[i]] <- show_digit(pca$A[,i], title=paste("Projection Column:", i),
                              limits=c(min(pca$A), max(pca$A))) + theme(legend.position="none")
}
Rmisc::multiplot(plotlist=plotlist, cols=ceiling(sqrt(nplot)))

As we can see, somewhat intuitively, our top 2 PCs clearly resemble the two characters in our subset: PC1 looks similar to a 3, and PC2 looks like a 3 in red (high values) with a 7 in blue (low values).

cPCA

r <- 10
nplot <- 9
cpca <- lol.project.cpca(set1$X, set1$Y, r=r)

plotlist <- list()
for (i in 1:nplot) {
  plotlist[[i]] <- show_digit(cpca$A[,i], title=paste("Projection Column:", i),
                              limits=c(min(cpca$A), max(cpca$A))) + theme(legend.position="none")
}
Rmisc::multiplot(plotlist=plotlist, cols=ceiling(sqrt(nplot)))

As we can see, here our top PCs appear to be less distinctly 3s or 7s, and look more like mergings of the two. For example, in the above images we can see a number of faint curling arms from the 3s with the 7s faintly overlaid, in PCs 3, 4, 5 and 6 in particular. This makes sense since we correct for the per-class mean here, meaning that we remove what makes each character look "uniquely" like itself on average. As a result, cPCA will not be very effective for differentiating the classes, as it will tend to project images of 7s similarly to images of 3s.

LOL

r <- 10
nplot <- 9
lol <- lol.project.lol(set1$X, set1$Y, r=r)

plotlist <- list()
for (i in 1:nplot) {
  plotlist[[i]] <- show_digit(pca$A[,i], title=paste("Projection Column:", i),
                              limits=c(min(lol$A), max(lol$A))) + theme(legend.position="none")
}
Rmisc::multiplot(plotlist=plotlist, cols=ceiling(sqrt(nplot)))

As we can see with LOL, we get an immediately distinguishing character in the PC1 looking like a 3 and PC3 looking distinctly like a 7. Many of the PCs look less fused between the two characters; we can see that either the 7 or the 3 are hot or cold in each image; they do not tend to share coloring compared to cPCA.

CCA

r <- 10
nplot <- 9
cca <- lol.project.lrcca(set1$X, set1$Y, r=r)

plotlist <- list()
for (i in 1:nplot) {
  plotlist[[i]] <- show_digit(cca$A[,i], title=paste("Projection Column:", i),
                              limits=c(min(cca$A), max(cca$A))) + theme(legend.position="none")
}
Rmisc::multiplot(plotlist=plotlist, cols=ceiling(sqrt(nplot)))

I am less sure of the intuition here, but it seems like CCA is trying to find the most dispersed group that differentiates the 7 from the 3. This model appears to be less interpretable than, say, LOL.

Simulations

Below, we perform the aforementioned simulations with 2-fold cross-validation:

performance <- data.frame(set=c(), algorithm=c(), r=c(), class=c(), lhat=c())

algorithms <- c(lol.project.pca, lol.project.cpca, lol.project.lol, lol.project.lrcca)
algnames <- c("PCA", "cPCA", "LOL", "LRCCA")
k <- 2
setsX <- list(set1$X, set2$X, set3$X)
setsY <- list(set1$Y, set2$Y, set3$Y)
setname <- c("set 1", "set 2", "set 3")
rs <- c(10, 20, 50, 100)
classalgs <- c("lda", "rf")

for (j in 1:length(setsX)) {
  print(paste('Set:', j))
  setX <- setsX[[j]]
  setY <- setsY[[j]]
  for (i in 1:length(algorithms)) {
    print(paste('Alg:', algnames[i]))
    for (r in rs) {
      for (class in classalgs) {
        res <- suppressWarnings(lol.eval.xval(setX, setY, r=r, alg=algorithms[i][[1]], classifier=class, k=k))
          performance <- rbind(performance, data.frame(set=setname[j], algorithm=algnames[i], r=r, class=class, lhat=res$Lhat))

      }
    }
  }
}

saveRDS(performance, 'mnist.rds')
g_legend<-function(a.gplot){
  tmp <- ggplot_gtable(ggplot_build(a.gplot))
  leg <- which(sapply(tmp$grobs, function(x) x$name) == "guide-box")
  legend <- tmp$grobs[[leg]]
  return(legend)}

performance <- readRDS('./mnist.rds')

plot_example <- function(data, sim) {
  p1 <- ggplot(data[data$class =='lda',], aes(x=r, y=lhat, color=algorithm, group=algorithm)) +
    stat_summary(geom="line", fun.y="mean", size=2) +
    xlab("Dimensions Retained") +
    ylab(TeX("$\\hat{L}$")) +
    ggtitle(paste(sim, "Simulation, LDA Classifier")) +
    scale_color_discrete(name="Algorithm") +
    theme_bw()
  p2 <- ggplot(data[data$class=='rf',], aes(x=r, y=lhat, color=algorithm, group=algorithm)) +
    stat_summary(geom="line", fun.y="mean", size=2) +
    xlab("Dimensions Retained") +
    ylab(TeX("$\\hat{L}$")) +
    ggtitle(paste(sim, "Simulation, RF Classifier")) +
    theme_bw() +
    scale_color_discrete(name="Algorithm")

  my_legend <- g_legend(p1)
  p3 <- grid.arrange(arrangeGrob(p1 + theme(legend.position=NaN), p2 + theme(legend.position=NaN), nrow=2), my_legend, nrow=1, widths=c(.88, .12))

}

for (set in unique(performance$set)) {
  subset <- performance[performance$set == set,]
  plotlist <- list()
  undims <- unique(subset$r)
  plot_example(subset, set)
}


neurodata/lol documentation built on March 3, 2021, 1:46 a.m.