Discriminant_analysis_examples

knitr::opts_chunk$set(
 fig.width  = 5 ,
 fig.height = 3.5,
 fig.align  = 'center'
)
oldpar <- list(mar = par()$mar, mfrow = par()$mfrow)

Introduction

This vignette visualizes classification results from discriminant analysis, using tools from the classmap package.

library("classmap")
library("ggplot2")
library("gridExtra")

Iris data

As a first small example, we consider the Iris data. We first load the data and inspect it.

data(iris)
X <- iris[, 1:4]
y <- iris[, 5]
is.factor(y)
table(y)

pairs(X, col = as.numeric(y) + 1, pch = 19)

Now we carry out quadratic discriminant analysis and inspect the output. Note that we can also do linear discriminant analysis by choosing rule = "LDA".

vcr.train <- vcr.da.train(X, y, rule = "QDA")
names(vcr.train)

We now inspect the output in detail. First look at the prediction as integer, the prediction as label, the alternative label as integer and the alternative label:

vcr.train$predint 
vcr.train$pred[c(1:10, 51:60, 101:110)]
vcr.train$altint  
vcr.train$altlab[c(1:10, 51:60, 101:110)]

The Probability of Alternative Class (PAC) of each object is found in the $PAC element of the output:

vcr.train$PAC[1:3] 
summary(vcr.train$PAC)
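
For intuition, the PAC can be sketched from posterior probabilities as PAC(i) = p(i, alt) / (p(i, given) + p(i, alt)), where alt is the class with the highest posterior among the classes other than the given one. The sketch below uses MASS::qda as a stand-in for the classmap internals, which is an assumption, so its values need not match vcr.train$PAC exactly. The names post, p_given, p_alt and PAC_sketch are illustrative only.

```r
library(MASS)  # recommended package, used here as a stand-in (assumption)

data(iris)
fit  <- qda(Species ~ ., data = iris)
post <- predict(fit, iris)$posterior      # n x 3 posterior probabilities
n    <- nrow(post)
yint <- as.integer(iris$Species)          # given class as integer

# Posterior probability of the given class
p_given <- post[cbind(seq_len(n), yint)]

# Best alternative: highest posterior among the other classes.
# Posteriors are >= 0, so -1 safely excludes the given class.
post_other <- post
post_other[cbind(seq_len(n), yint)] <- -1
altint <- apply(post_other, 1, which.max)
p_alt  <- post[cbind(seq_len(n), altint)]

PAC_sketch <- p_alt / (p_given + p_alt)
summary(PAC_sketch)
```

Under this definition, a value above 0.5 means that the best alternative class has a higher posterior than the given class.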

The $fig element of the output contains the distance from case i to class g. Let's look at it for the first 5 objects:

vcr.train$fig[1:5, ]

From the fig, the farness of each object can be computed. The farness of an object i is the f(i, g) to its own class:

vcr.train$farness[1:5]
summary(vcr.train$farness)

The "overall farness" of an object is defined as the lowest f(i, g) it has to any class g (including its own):

summary(vcr.train$ofarness)
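
The relation between farness and overall farness can be illustrated on a small made-up matrix. The values in fig_toy and yint_toy below are hypothetical and only serve to show the row-wise minimum; they are not taken from the iris analysis.

```r
# Hypothetical f(i, g) values: 4 cases (rows) by 3 classes (columns)
fig_toy <- matrix(c(0.10,  0.95,  0.99,
                    0.97,  0.30,  0.85,
                    0.995, 0.992, 0.999,
                    0.60,  0.20,  0.90),
                  nrow = 4, byrow = TRUE)
yint_toy <- c(1, 2, 3, 1)  # hypothetical given class of each case

# Farness: f(i, g) evaluated at the case's own class
farness_toy <- fig_toy[cbind(seq_len(nrow(fig_toy)), yint_toy)]

# Overall farness: lowest f(i, g) over all classes g
ofarness_toy <- apply(fig_toy, 1, min)

cbind(farness_toy, ofarness_toy)
```

By construction the overall farness never exceeds the farness. Case 4 is closer to class 2 than to its own class, so its two values differ.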

Objects with ofarness > cutoff are flagged as "outliers". These can be included in a separate column in the confusion matrix. This confusion matrix can be computed using confmat.vcr, which also returns the accuracy.

To illustrate this we choose a rather low cutoff:

confmat.vcr(vcr.train, cutoff = 0.98)

With the default cutoff = 0.99 no objects are flagged in this example:

confmat.vcr(vcr.train)

Note that the accuracy is computed before any objects are flagged, so it does not depend on the cutoff.
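
A minimal base R sketch of this behaviour follows. The function confmat_sketch, the column name "outl" and the toy vectors are assumptions made for illustration; the exact layout of the confmat.vcr() output may differ.

```r
classes <- c("a", "b", "c")                        # hypothetical class labels
given <- factor(c("a", "a", "b", "b", "c", "c"), levels = classes)
pred  <- factor(c("a", "b", "b", "b", "c", "a"), levels = classes)
ofar  <- c(0.50, 0.985, 0.20, 0.995, 0.70, 0.40)   # hypothetical overall farness

confmat_sketch <- function(given, pred, ofar, cutoff = 0.99) {
  # Accuracy uses the raw predictions, so the cutoff plays no role here
  accuracy <- mean(pred == given)
  # Cases with overall farness above the cutoff go to an extra column
  shown <- factor(ifelse(ofar > cutoff, "outl", as.character(pred)),
                  levels = c(levels(pred), "outl"))
  list(table = table(given = given, predicted = shown),
       accuracy = accuracy)
}

res99 <- confmat_sketch(given, pred, ofar)                 # flags case 4
res98 <- confmat_sketch(given, pred, ofar, cutoff = 0.98)  # flags cases 2 and 4
res98$table
```

Lowering the cutoff moves more cases into the extra column, while the accuracy stays the same.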

The confusion matrix can also be constructed showing class numbers instead of labels. This option can be useful for long level names.

confmat.vcr(vcr.train, showClassNumbers = TRUE)

A stacked mosaic plot made with the stackedplot() function can be used to visualize the confusion matrix. The outliers, if there are any, appear as grey areas on top.

cols <- c("red", "darkgreen", "blue")
stackedplot(vcr.train, classCols = cols, separSize = 1.5,
            minSize = 1, showLegend = TRUE)

stackedplot(vcr.train, classCols = cols, separSize = 1.5,
            minSize = 1, showLegend = TRUE, cutoff = 0.98)

The default stacked mosaic plot has no legend:

stplot <- stackedplot(vcr.train, classCols = cols, 
                     separSize = 1.5, minSize = 1,
                     main = "QDA on iris data")
stplot

We also make the silhouette plot using the silplot() function:

# pdf("Iris_QDA_silhouettes.pdf", width=5.0, height=4.6)
silplot(vcr.train, classCols = cols, 
        main = "Silhouette plot of QDA on iris data")      
# dev.off()

We now make the class maps based on the vcr object. This can be done using the classmap() function. We make a separate class map for each of the three classes. We see that class 1 is a very tight class (low PAC, no high farness). Class 2 is not so tight, and has two points which are predicted as virginica. Class 3 has one point predicted as versicolor.

classmap(vcr.train, 1, classCols = cols)
classmap(vcr.train, 2, classCols = cols)
# With the default cutoff no farness values stand out:
classmap(vcr.train, 3, classCols = cols)
# With a lower cutoff:
classmap(vcr.train, 3, classCols = cols, cutoff = 0.98)
# Now one point is to the right of the vertical line.
# It also has a black border, meaning that it is flagged
# as an outlier, in the sense that its farness to _all_
# classes is above 0.98.

To illustrate the use of new data we create a fake dataset which is a subset of the training data, where not all classes occur, and ynew has NA's.

Xnew <- X[c(1:50, 101:150), ]
ynew <- y[c(1:50, 101:150)]
ynew[c(1:10, 51:60)] <- NA
pairs(X, col = as.numeric(y) + 1, pch = 19) # 3 colors
pairs(Xnew, col = as.numeric(ynew) + 1, pch = 19) # only red and blue
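
A quick base R check of the new labels shows what "not all classes occur" means here: subsetting a factor keeps all of its levels, so versicolor remains a level with zero cases. The snippet rebuilds ynew from iris so that it runs on its own.

```r
data(iris)
y    <- iris[, 5]
ynew <- y[c(1:50, 101:150)]
ynew[c(1:10, 51:60)] <- NA

levels(ynew)                  # all three levels are retained
table(ynew, useNA = "ifany")  # versicolor has count 0; 20 labels are NA
```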

Now we build the vcr object for the new data, using the output of vcr.da.train() on the training data.

vcr.test <- vcr.da.newdata(Xnew, ynew, vcr.train)

Inspect some of the output to confirm that it corresponds with what we would expect:

plot(vcr.test$predint, vcr.train$predint[c(1:50, 101:150)]); abline(0, 1)
plot(vcr.test$altint, vcr.train$altint[c(1:50, 101:150)]); abline(0, 1)
plot(vcr.test$PAC, vcr.train$PAC[c(1:50, 101:150)]); abline(0, 1)
vcr.test$farness 
plot(vcr.test$farness, vcr.train$farness[c(1:50, 101:150)]); abline(0, 1)
plot(vcr.test$fig, vcr.train$fig[c(1:50, 101:150), ]); abline(0, 1)
vcr.test$ofarness 
plot(vcr.test$ofarness, vcr.train$ofarness[c(1:50, 101:150)]); abline(0, 1)

The confusion matrix for the test data, as for the training data, can be constructed by the confmat.vcr() function. A cutoff of 0.98 flags three outliers in this example.

confmat.vcr(vcr.test)
confmat.vcr(vcr.test, cutoff = 0.98)

Also the stacked mosaic plot can be constructed on the test data:

stplot # to compare with:
stackedplot(vcr.test, classCols = cols, separSize = 1.5, minSize = 1)

We now make the silhouette plot on the test data:

#pdf("Iris_test_QDA_silhouettes.pdf", width=5.0, height=4.3)
silplot(vcr.test, classCols = cols, 
        main = "Silhouette plot of QDA on iris subset") 
#dev.off()

Finally, we construct the class maps for the test data. We compare the class map of the training data with that of the test data for each class.

classmap(vcr.train, 1, classCols = cols)
classmap(vcr.test, 1, classCols = cols) 
classmap(vcr.train, 2, classCols = cols)
classmap(vcr.test, 2, classCols = cols)
classmap(vcr.train, 3, classCols = cols)
classmap(vcr.test, 3, classCols = cols) 

Floral buds data (shown in paper)

We now analyze the floral buds data, which was also used as an illustration in the paper. First load and inspect the data.

data(data_floralbuds)
X <- as.matrix(data_floralbuds[, 1:6])
y <- data_floralbuds$y
dim(X) # 550  6
length(y) # 550
table(y)
# branch     bud  scales support 
#     49     363      94      44 

# Pairs plot
cols <- c("saddlebrown", "orange", "olivedrab4", "royalblue3")
pairs(X, gap = 0, col = cols[as.numeric(y)]) # hard to separate visually

Now we perform quadratic discriminant analysis:

vcr.obj <- vcr.da.train(X, y)

Construct the confusion matrix without and with outliers shown:

confmat <- confmat.vcr(vcr.obj, showOutliers = FALSE)

confmat.vcr(vcr.obj) 

Construct the stacked mosaic plot:

stackedplot(vcr.obj, classCols = cols, separSize = 0.6,
            minSize = 1.5,  main = "stacked plot of QDA on floral buds")

# Version in paper:
# pdf("Floralbuds_QDA_stackplot_without_outliers.pdf",
#     width=5, height=4.3)
# stackedplot(vcr.obj, classCols = cols, separSize = 0.6,
#             minSize = 1.5, showOutliers = FALSE,
#             htitle = "given class", vtitle = "predicted class")
# dev.off()

Now make the silhouette plot:

#pdf("Floralbuds_QDA_silhouettes.pdf", width=5.0, height=4.3)
silplot(vcr.obj, classCols = cols,
        main = "Silhouette plot of QDA on floral bud data")      
#dev.off()

The quasi residual plot can be made with the qresplot() function. We illustrate this below by making the quasi residual plot against the sum of the variables. A correlation test confirms that the images with higher sums are significantly easier to classify:

PAC <- vcr.obj$PAC
feat <- rowSums(X); xlab = "rowSums(X)"
# pdf("Floralbuds_QDA_quasi_residual_plot.pdf", width=5, height=4.8)
qresplot(PAC, feat, xlab = xlab, plotErrorBars = TRUE, fac = 2, 
         main = "Floral buds: quasi residual plot")
# dev.off()

cor.test(feat, PAC, method = "spearman") 
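
As a rough illustration of what such a plot summarizes, the sketch below bins a feature into quantile-based intervals and computes the average PAC per bin, with error bars of fac standard errors. The data are simulated, the names feat_sim and PAC_sim are hypothetical, and the binning scheme is an assumption; qresplot() may form its intervals differently.

```r
set.seed(1)
n        <- 200
feat_sim <- runif(n)                               # simulated feature
PAC_sim  <- pmin(1, pmax(0, 0.6 - 0.5 * feat_sim +
                            rnorm(n, sd = 0.15)))  # simulated PAC in [0, 1]

nbins  <- 5
fac    <- 2
breaks <- quantile(feat_sim, probs = seq(0, 1, length.out = nbins + 1))
bin    <- cut(feat_sim, breaks = breaks, include.lowest = TRUE)

# Average PAC and its standard error within each bin
avgPAC <- tapply(PAC_sim, bin, mean)
sePAC  <- tapply(PAC_sim, bin, function(v) sd(v) / sqrt(length(v)))
cbind(lower = avgPAC - fac * sePAC, mean = avgPAC,
      upper = avgPAC + fac * sePAC)

# Rank-based association between the feature and PAC
rho <- cor.test(feat_sim, PAC_sim, method = "spearman",
                exact = FALSE)$estimate
rho
```

A negative Spearman correlation, as in this simulated example, indicates that cases with a higher feature value tend to have a lower PAC.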

Construct the class maps, as shown in the paper:

labels <- c("branch", "bud", "scale", "support")

# classmap of class "bud"
#
# To identify the points that stand out:
# classmap(vcr.obj, 2, classCols = cols, identify = TRUE)
# Press "Esc" to get out.
#
# pdf("Floralbuds_QDA_classmap_bud.pdf", width=7, height=7)
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.obj, 2, classCols = cols,
         main = "predictions of buds",
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
         cex.main = 1.5) 
# For marking points:
indstomark <- c(294, 70, 69, 152, 204) # from identify = TRUE above
labs  <- letters[seq_len(5)]
xvals <- coords[indstomark, 1] +
  c(0, 0.10, 0.14, 0.10, 0.08) # visual finetuning
yvals <- coords[indstomark, 2] +
  c(0.04, 0.04, 0, -0.03, +0.04)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("topleft", fill = cols[1:4], legend = labels, 
       cex = 1, ncol = 1, bg = "white")
# dev.off()
par(oldpar)

All class maps:

#
# pdf(file = "Floralbuds_all_class_maps.pdf", width = 7, height = 7)
par(mfrow = c(2, 2))
par(mar = c(3.3, 3.2, 2.7, 1.0))
classmap(vcr.obj, 1, classCols = cols,
         main = "predictions of branches")
legend("topright", fill = cols, legend = labels,
       cex = 1, ncol = 1, bg = "white")
#
par(mar = c(3.3, 0.5, 2.7, 0.3))
classmap(vcr.obj, 2, classCols = cols,
         main = "predictions of buds")
labs  <- letters[seq_len(5)]
xvals <- coords[indstomark, 1] +
  c(0, 0.10, 0.14, 0.10, 0.08) # visual finetuning
yvals <- coords[indstomark, 2] +
  c(0.04, 0.04, 0, -0.03, 0.04)
# xvals <- c( 1.75, 1.68, 1.25, 3.25, 4.00)
# yvals <- c(0.045, 0.92, 0.54, 0.97, 0.045)
text(x = xvals, y = yvals, labels = labs, cex = 1.0)
legend("topleft", fill = cols, legend = labels,
       cex = 1, ncol = 1, bg = "white")
#
par(mar = c(3.3, 3.2, 2.7, 1.0))
classmap(vcr.obj, 3, classCols = cols,
         main = "predictions of scales")
legend("left", fill = cols, legend = labels,
       cex = 1, ncol = 1, bg = "white")
# 
par(mar = c(3.3, 0.5, 2.7, 0.3))
classmap(vcr.obj, 4, classCols = cols,
         main = "predictions of supports")
legend("topright", fill = cols, legend = labels,
       cex = 1, ncol = 1, bg = "white")
# dev.off()
par(oldpar)

MNIST data

We now analyze the MNIST data, originally from the website of Yann LeCun. As the link on his website is currently down, we use a different source. Note that downloading the data may take a minute or two, depending on the speed of the internet connection.

mnist_url <- "https://wis.kuleuven.be/statdatascience/robust/data/mnist-rdata"
url.exists <- suppressWarnings(try(open.connection(url(mnist_url), open = "rt", timeout = 2),  silent = TRUE)[1], classes = "warning")

if (is.null(url.exists)) {load(url(mnist_url))} else {
  print(paste("The data source ", mnist_url, "is not active at the moment. The example can nevertheless be reproduced by downloading the mnist data from another source, formatting the training data to dimensions 60000 x 28 x 28, and running the code below."))
}
close(url(mnist_url))
X_train <- mnist$train$x
y_train <- as.factor(mnist$train$y)

head(y_train)
# Levels: 0 1 2 3 4 5 6 7 8 9
dim(X_train) # 60000    28    28
length(y_train) # 60000

We now inspect the data by plotting a few images:

plotImage = function(tempImage) {
  tdm = reshape2::melt(apply((tempImage), 2, rev))
  p = ggplot(tdm, aes(x = Var2, y = Var1, fill = (value))) +
    geom_raster() +
    guides(color = "none", size = "none", fill = "none") +
    theme(axis.title.x = element_blank(),
          axis.title.y = element_blank(),
          axis.text.x = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks.x = element_blank(),
          axis.ticks.y = element_blank()) +
    scale_fill_gradient(low = "white", high = "black")
  p
}

plotImage(X_train[1, , ])
plotImage(X_train[2, , ])
plotImage(X_train[3, , ])

We now unfold the array containing the data to a matrix, and inspect some sample images as well as the average image per digit:

# Change the dimensions of X for the sequel:
dim(X_train) <- c(60000, 28 * 28)
dim(X_train) # 60000    784

# Sampled digit images:
set.seed(123)
sampledigits <- list()
for (i in 0:9) {
  digit <- i
  idx <- sample(which(y_train == digit), size = 1)
  tempImage <- matrix(unlist(X_train[idx, ]), 28, 28)
  sampledigits[[i + 1]] <- plotImage(tempImage) 
}
psampledigits <- grid.arrange(grobs = sampledigits, ncol = 5)
# ggsave("MNIST_sampled_images.pdf", plot = psampledigits,
#        width = 10, height = 1)


# Averaged digit images:
meanPlots <- list()
for (j in 0:9) {
  m.out <- colMeans(X_train[which(y_train == j), ])
  dim(m.out) <- c(28, 28)
  meanPlots[[j + 1]] <- plotImage(m.out) 
}
meanplot <- grid.arrange(grobs = meanPlots, ncol = 5)
# ggsave("MNIST_averaged_images.pdf", plot = meanplot,
#        width = 10, height = 1)
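
The unfolding step above relies on R storing arrays in column-major order: resetting dim() turns each image into one row, and matrix(row, 28, 28) folds that row back into the original image. The small check below uses a hypothetical array of three 4 x 4 "images" in place of the MNIST data.

```r
set.seed(1)
arr  <- array(runif(3 * 4 * 4), dim = c(3, 4, 4))  # 3 hypothetical 4 x 4 images
img2 <- arr[2, , ]                                 # keep a copy of image 2

mat <- arr
dim(mat) <- c(3, 4 * 4)             # unfold: one row per image
restored <- matrix(mat[2, ], 4, 4)  # fold row 2 back into a 4 x 4 matrix

identical(restored, img2)
```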

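Before performing discriminant analysis, we reduce the dimension of the data by PCA, keeping the first 50 components. These are computed here by a truncated SVD of the data matrix, without centering.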

library(svd)
ptm <- proc.time()
svd.out <- svd::propack.svd(X_train, neig = 50)
(proc.time() - ptm)[3]
loadings <- svd.out$v
rm(svd.out)
dataProj <- as.matrix(X_train %*% loadings)
dim(dataProj)
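
The projection step can be checked on a small example. The sketch below uses base R's svd() in place of svd::propack.svd(), which is a substitution made so that it runs without extra packages, and the matrices X_toy and Xnew_toy are hypothetical. It shows that the scores X V equal U D, and that new data are projected with the loadings obtained from the training data.

```r
set.seed(1)
X_toy    <- matrix(rnorm(100 * 10), 100, 10)  # hypothetical training matrix
Xnew_toy <- matrix(rnorm(20 * 10), 20, 10)    # hypothetical new data

k  <- 3
sv <- svd(X_toy, nu = k, nv = k)  # keep the first k singular vectors
loadings_toy <- sv$v              # 10 x k matrix of loadings

scores_train <- X_toy %*% loadings_toy
scores_new   <- Xnew_toy %*% loadings_toy  # same loadings, no refitting

# The training scores coincide with U D
all.equal(scores_train, sv$u %*% diag(sv$d[1:k]))
dim(scores_new)
```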

Now we perform quadratic discriminant analysis (the default rule), which takes roughly 5 seconds.

vcr.train <- vcr.da.train(X = dataProj, y_train)

We compute the confusion matrix and make the stacked mosaic plot:

confmat.vcr(vcr.train, showOutliers = FALSE)

cols <- c("red3", "darkorange", "gold2", "darkolivegreen3",
         "darkolivegreen4", "cadetblue3", "deepskyblue4", 
         "darkslateblue", "darkorchid3", "deeppink4")

# stacked plot in paper:
# pdf("MNIST_stackplot_with_outliers.pdf", width=5, height=4.3)
stackedplot(vcr.train, classCols = cols, separSize = 0.6,
            minSize = 1.5, htitle = "given class",
            vtitle = "predicted class",
            main = "Stacked plot of QDA on MNIST training data")
# dev.off()

The silhouette plot:

# pdf("MNIST_QDA_silhouettes.pdf", width=5.0, height=4.6)
silplot(vcr.train, classCols = cols,
        main = "Silhouette plot of QDA on MNIST training data")      
# dev.off()

Now we make the class maps.

wnq <- function(string, qwrite=TRUE) { # auxiliary function
  # writes a line without quotes
  if (qwrite) write(noquote(string), file = "", ncolumns = 100)
}

showdigit <- function(digit, i, plotIt = TRUE) {
  idx = which(y_train == digit)[i]
  # wnq(paste("Estimated digit: ", as.numeric(vcr.train$pred[idx]), sep=""))
  tempImage <- matrix(unlist(X_train[idx, ]), 28, 28)
  if (plotIt) {plot(plotImage(tempImage))}
  return(plotImage(tempImage))
}
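
Note that the index i in showdigit() counts cases within the given class, not rows of the full data set, and that class number digit + 1 corresponds to the digit because the factor levels start at "0". The toy labels below are hypothetical and only illustrate this bookkeeping.

```r
y_toy <- factor(c(3, 0, 1, 0, 2, 0, 1))  # hypothetical labels
digit <- 0

in_class <- which(y_toy == digit)  # global positions of the 0s
i <- 3                             # third case within class "0"
in_class[i]                        # its position in the full data

match(digit, levels(y_toy))        # class number of this digit
```

Here the third "0" sits at position 6 of the full vector, and digit 0 has class number 1.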

Class map of digit 0, shown in paper:

digit <- 0
#
# To identify outliers:
# classmap(vcr.train, digit+1, classCols = cols, identify = TRUE)
# Press "Esc" to get out.
#
# pdf(paste0("MNIST_classmap_digit", digit, ".pdf"), width = 7, height = 7)
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.train, digit + 1, classCols = cols,
         main = paste0("predictions of digit ",digit),
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5, 
         cex.main = 1.5)
indstomark <- c(4000, 3964, 5891, 2485, 822, 
               2280, 2504, 3906, 5869, 1034) # from identify = TRUE
labs  <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
  c(-0.04, -0.01, 0, -0.11, 0.06,
    0.07, 0.06, 0.10, 0.06, 0.09)
yvals <- coords[indstomark, 2] +
  c(-0.03, -0.03, -0.03, 0.022, -0.025, 
    -0.025, -0.035, -0.025, 0.03, 0.03)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("left", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
# dev.off()
par(oldpar)
pred <- vcr.train$pred # needed for discussion plots
tempPreds <- (pred[which(y_train == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit(digit, idx, plotIt = FALSE)
  tempplot <- arrangeGrob(tempplot, 
    bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
  discussionPlots[[i]] = tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, ncol = 5)
# ggsave(paste0("MNIST_discussionplot_digit", digit, ".pdf"),
#        plot = discussionPlot, width = 5,
#        height = (length(indstomark) %/% 5 +
#                    (length(indstomark) %% 5 > 0)))

Class map of digit 1, shown in paper:

digit <- 1
# pdf(paste0("MNIST_classmap_digit", digit, ".pdf"), width = 7, height = 7)
par(mar = c(3.6, 3.5, 2.4, 3.5))
classmap(vcr.train, digit + 1, classCols = cols,
         main = paste0("predictions of digit ", digit),
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5, 
         cex.main = 1.5)
legend("left", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
# dev.off()
par(oldpar)
# indices of the 1s predicted as 2 (takes a while):
#
indstomark <- which(vcr.train$predint[which(y_train == digit)] == 3)
length(indstomark) # 104
labs  <- letters[1:length(indstomark)]
pred <- vcr.train$pred # needed for discussion plots
tempPreds    <- (pred[which(y_train == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit(digit, idx, FALSE)
  tempplot <- arrangeGrob(tempplot, 
        bottom = paste0("\"", tempPreds[i], "\""))
  discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, 
                                         ncol = 8)
# ggsave(paste0("MNIST_discussionplot_digit", digit, "predictedAs2b.pdf"),
#        plot = discussionPlot, width = 10,
#        height = (length(indstomark) %/% 10 +
#                    (length(indstomark) %% 10 > 0)))

# The digits 1 predicted as a 2 are mostly ones written with
# a horizontal line at the bottom.

Class map of digit 2:

digit <- 2
# To identify outliers:
# classmap(vcr.train, digit + 1, classCols = cols, identify = TRUE)
# Press "Esc" to get out.
#
# pdf(paste0("MNIST_classmap_digit", digit,".pdf"), width = 7, height = 7)
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.train, digit + 1, classCols = cols,
                  main = paste0("predictions of digit ", digit),
                  cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
                  cex.main = 1.5)
indstomark <- c(3164, 5434, 2319 , 4224, 3682, 
               2642, 4920, 1233, 3741, 3993) # from identify = TRUE
labs  <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
  c(0, 0.08, 0, 0, 0, 0, 0, 0, 0, 0)
yvals <- coords[indstomark, 2] +
  c(-0.03, -0.03, -0.03, -0.03, -0.03, 
    -0.03, -0.03, -0.03, 0.03, 0.03)  
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("right", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
# dev.off()
par(oldpar)
pred <- vcr.train$pred # needed for discussion plots
tempPreds    <- (pred[which(y_train == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit(digit, idx, FALSE)
  tempplot <- arrangeGrob(tempplot, 
        bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
  discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, 
                                         ncol = 5)
# ggsave(paste0("MNIST_discussionplot_digit", digit, ".pdf"),
#        plot = discussionPlot, width = 5,
#        height = (length(indstomark) %/% 5 +
#                    (length(indstomark) %% 5 > 0)))

Now we analyze the MNIST test data. First load and inspect the data, and project it onto the PCA subspace extracted from the training data.

X_test <- mnist$test$x
y_test <- as.factor(mnist$test$y)

head(y_test)
#
dim(X_test) # 10000    28    28
length(y_test) # 10000

plotImage(X_test[1, , ])
plotImage(X_test[2, , ])
plotImage(X_test[3, , ])

dim(X_test) <- c(10000, 28 * 28)
dim(X_test) # 10000  784

dataProj_test <- as.matrix(X_test %*% loadings)

Now prepare the VCR object:

vcr.test <- vcr.da.newdata(Xnew = dataProj_test,
                           ynew = y_test,
                           vcr.da.train.out = vcr.train)

Build the confusion matrix and plot a stacked mosaic plot of the classification performance on the test data:

confmat.vcr(vcr.test, showOutliers = FALSE, showClassNumbers = TRUE)

# In supplementary material:
# pdf("MNISTtest_stackplot_with_outliers.pdf", width = 5, height = 4.3)
stackedplot(vcr.test, classCols = cols, separSize = 0.6,
            main = "Stacked plot of QDA on MNIST test data",
            minSize = 1.5)
# dev.off()

Silhouette plot:

#pdf("MNIST_test_QDA_silhouettes.pdf", width = 5.0, height = 4.6)
silplot(vcr.test, classCols = cols,
        main = "Silhouette plot of QDA on MNIST test data")      
#dev.off()

Now we can construct the class maps on the test data. First for digit 0:

showdigit_test <- function(digit, i, plotIt = TRUE) {
  idx = which(y_test == digit)[i]
  # wnq(paste("Estimated digit: ", as.numeric(vcr.test$pred[idx]), sep = ""))
  tempImage <- matrix(unlist(X_test[idx, ]), 28, 28)
  if (plotIt) {plot(plotImage(tempImage))}
  return(plotImage(tempImage))
}

digit <- 0
# classmap(vcr.test, digit+1, classCols = cols, identify = TRUE)
# pdf(paste0("MNISTtest_classmap_digit", digit,".pdf"))
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.test, digit + 1, classCols = cols,
         main = paste0("predictions of digit ", digit),
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5, 
         cex.main = 1.5)
indstomark <- c(140, 630, 241, 967, 189,
               377, 78, 943, 64, 354)
labs  <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
  c(0.08, 0.07, -0.07, 0.06, 0,
    0.04, 0.05, 0.09, -0.04, 0.09)
yvals <- coords[indstomark, 2] +
  c(-0.025, -0.03, -0.024, -0.025, -0.03, 
    -0.03, -0.03, 0.022, 0.035, 0.03)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("left", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
# dev.off()
par(oldpar)
pred <- vcr.test$pred # needed for discussion plots
tempPreds <- (pred[which(y_test == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit_test(digit, idx, FALSE)
  tempplot <- arrangeGrob(tempplot, 
      bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
  discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, 
                                         ncol = 5)
# ggsave(paste0("MNISTtest_discussionplot_digit", digit, ".pdf"),
#        plot = discussionPlot, width = 5,
#        height = (length(indstomark) %/% 5 +
#                    (length(indstomark) %% 5 > 0)))

Now for digit 3:

digit <- 3
# classmap(vcr.test, digit + 1, classCols = cols, identify = TRUE)
# pdf(paste0("MNISTtest_classmap_digit", digit, ".pdf"))
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.test, digit + 1, classCols = cols,
         main = paste0("predictions of digit ", digit),
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5, 
         cex.main = 1.5)
indstomark <- c(883, 659, 262, 60, 310,
               832, 223, 784, 835, 289)
labs  <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
  c(-0.01, 0.08, -0.10, 0.06, 0.07, 
    0.06, 0.03, 0.11, 0.02, 0.06)
yvals <- coords[indstomark, 2] +
  c(0.035, 0.033, -0.017, -0.022, -0.025, 
    -0.025, -0.033, -0.022, 0.035, 0.038)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("right", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
# dev.off()
par(oldpar)
pred <- vcr.test$pred # needed for discussion plots
tempPreds    <- (pred[which(y_test == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit_test(digit, idx, FALSE)
  tempplot <- arrangeGrob(tempplot, 
    bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
  discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, 
                                         ncol = 5)
# ggsave(paste0("MNISTtest_discussionplot_digit", digit, ".pdf"),
#        plot = discussionPlot, width = 5,
#        height = (length(indstomark) %/% 5 +
#                    (length(indstomark) %% 5 > 0)))




classmap documentation built on April 23, 2023, 5:09 p.m.