plot.mnist: Nice plot of the MNIST digits with dimensionality reduction


View source: R/nice_digits.R

Description

Plots MNIST digits (left), projections on a 2-D space (prediction, center) and their reconstructions (right).

Usage

plot.mnist(x = mnist$test$x, label = mnist$test$y + 1, model = prcomp(mnist$train$x),
           predictions = predict(model, x)[,1:2], reconstructions = tcrossprod(predictions, model$rotation[,1:2]),
           show.reconstructions = TRUE, highlight.digits = c(11, 3, 2, 91, 20, 188, 92, 1, 111, 13),
           digits.col = c("#FF0000FF", "#0000FFFF", "#964B00FF", "#FF00FFFF", "#00AAAAFF", "#00EE00FF", "#000000FF", "#000000FF", "#AAAAAAFF", "#FF9900FF"),
           digits.alphas.reverse = c(FALSE, TRUE)[c(1, 1, 1, 1, 1, 1, 2, 1, 1, 1)],
           pch = 21,
           pch.col = c("black", "white")[c(1, 2, 1, 1, 1, 1, 1, 2, 1, 1)],
           pch.bg = c("#FF0000FF", "#0000FFFF", "#964B00FF", "#FF00FFFF", "#00AAAAFF", "#00EE00FF", "#FFFFFFFF", "#000000FF", "#AAAAAAFF", "#FF9900FF"),
           cex = 1.5,
           cex.axis = .5,
           cex.lab = .5,
           cex.highlight = 2.5,
           xlab = "Node 1",
           ylab = "Node 2",
           xlim = NULL,
           ylim = NULL,
           ncol = 10,
           ...)

Arguments

x

An N x 784 matrix of data from which to draw the MNIST digits.

label

A vector of length N with the labels corresponding to the rows of x.

model

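An optional model used to compute predictions and reconstructions. If supplied, you probably need to supply reconstructions as well, because the default reconstructions expression assumes a prcomp-style model with a rotation element. Ignored when both predictions and reconstructions are provided.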

predictions

An N x 2 matrix containing the predictions.

reconstructions

An N x 784 matrix with the reconstructions of the digits.

show.reconstructions

Whether to plot a column of reconstructions to the right.

highlight.digits

The indices of the P digits to draw in the margin, ideally 10 of them.

digits.col

A vector of P colors for the highlight.digits.

digits.alphas.reverse

A logical vector of length P: TRUE when the digit has to be plotted white-on-color, FALSE when it has to be plotted color-on-white.

pch, pch.col, pch.bg

Plotting character and associated col and bg properties for the central plot.

cex, cex.axis, cex.lab, cex.highlight

Character expansion factors for the center points, the center plot axis, the center plot labels and the highlight digit points, respectively.

xlab, ylab

Labels for the x and y axes; see par.

xlim, ylim

Plot limits; see par.

col

A vector of colors to represent the digit. See image for more details.

ncol

Number of layout columns assigned to the central predictions plot. See “Layout” section below.

...

Further arguments passed to image.
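
The per-digit arguments (highlight.digits, digits.col, digits.alphas.reverse, pch.col, pch.bg) are all vectors of length P given in the same order. The following is a minimal sketch of how such vectors could be built. It assumes a synthetic label vector rather than the MNIST data, and it assumes one wants to highlight the first occurrence of each class; neither choice comes from the package itself.

```r
# Synthetic labels in 1..10, standing in for mnist$test$y + 1 (assumption).
set.seed(1)
label <- c(sample(1:10), sample(1:10, 90, replace = TRUE))

# Index of the first example of each class, in class order.
highlight <- match(1:10, label)

# The c(a, b)[index] idiom from Usage expands a two-value set to length P.
alphas.reverse <- c(FALSE, TRUE)[c(1, 1, 1, 1, 1, 1, 2, 1, 1, 1)]
pch.col <- c("black", "white")[c(1, 2, 1, 1, 1, 1, 1, 2, 1, 1)]

# Each highlighted index points at a digit of the expected class.
stopifnot(all(label[highlight] == 1:10))
```

Such vectors could then be passed as highlight.digits, digits.alphas.reverse and pch.col, together with the matching x and label.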

Details

plot.mnist plots model predictions surrounded by a selection of the original digits (x) and, if requested (with show.reconstructions = TRUE), their reconstructions. The highlight.digits argument is a numeric index of which digits to highlight, in order.
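
As an illustration of the shapes involved, here is a sketch of the default PCA path from Usage (predict followed by tcrossprod). It runs on random matrices standing in for the MNIST data, which is an assumption made only to keep the example self-contained. The last step, adding the column means back, is not part of the Usage defaults; it is shown because prcomp centers the data before projecting.

```r
# Synthetic stand-ins for the N x 784 digit matrices (assumption: random data).
set.seed(42)
train <- matrix(runif(200 * 784), nrow = 200)
test <- matrix(runif(50 * 784), nrow = 50)

pca <- prcomp(train)                       # centers the columns by default
predictions <- predict(pca, test)[, 1:2]   # N x 2 projections
# Map the 2-D projections back to pixel space: N x 784.
reconstructions <- tcrossprod(predictions, pca$rotation[, 1:2])

# Optional: add the training column means back to return to the pixel scale.
reconstructions.uncentered <- sweep(reconstructions, 2, pca$center, "+")

dim(predictions)
dim(reconstructions)
```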

Layout

plot.mnist uses a graphical layout. The original digits x will always use exactly one column. If requested, the reconstructions will use a second one. The ncol argument controls how many columns are assigned to the predictions themselves and should be adjusted to the width of the graphical device. For instance, if ncol = 1 the projections will use 1/2 of the plot without reconstructions, or 1/3 with them.

With a square device and 10 highlighted digits, ncol = 8 will make the digits and reconstructions square. To have the predictions displayed in a square too, you should open a device with a 12:10 ratio and keep the default ncol = 10. Graphical devices with a 16:10 ratio should use ncol = 14, which leaves the predictions slightly stretched. If more or fewer than 10 digits are highlighted, ncol should be adjusted too. See the Examples section for a demo.
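
The ratios quoted above follow from requiring each digit cell to be square: with P digits stacked in one column, a cell is height / P tall and width / (ncol + side columns) wide. A small sketch of that arithmetic follows. Note that square.ncol is a hypothetical helper, not part of the package, and it ignores any margins the actual layout may add.

```r
# Hypothetical helper (not in the package): number of prediction columns
# that makes the digit cells square for a given device size.
square.ncol <- function(width, height, n.digits = 10,
                        show.reconstructions = TRUE) {
  # One column for the digits, plus one for the reconstructions if shown.
  side.columns <- if (show.reconstructions) 2 else 1
  n.digits * width / height - side.columns
}

square.ncol(1, 1)     # square device: 8
square.ncol(16, 10)   # 16:10 device: 14
square.ncol(16, 9)    # 16:9 device: about 15.78, so 16 is a fair approximation
square.ncol(12, 10)   # 12:10 device: 10, the default
```

With a different number of highlighted digits, the n.digits argument changes the result accordingly.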

Value

None

Author(s)

Xavier Robin

See Also

mnist

Examples

data(mnist)

# DeepBeliefNet
# library(DeepLearning)
# plot.mnist(predictions = predict(trained.mnist.dbn, mnist$test$x), reconstructions = reconstruct(trained.mnist.dbn, mnist$test$x))

# With PCA
pca <- prcomp(mnist$train$x)
plot.mnist(predictions = predict(pca, mnist$test$x), reconstructions = tcrossprod(predict(pca, mnist$test$x)[,1:2], pca$rotation[,1:2]))

# Reconstruct with all 784 components of the prediction
plot.mnist(predictions = predict(pca, mnist$test$x), reconstructions = tcrossprod(predict(pca, mnist$test$x), pca$rotation))

# Note that it also works with no arguments: PCA is performed by default
plot.mnist()

## Not run: 
# Change the number of columns. Digits will be square in a square plot device
plot.mnist(ncol = 8)
# Digits will be square in a 16:10 plot
plot.mnist(ncol = 14)
# With a 16:9 device we'd need ncol = 15.7778; ncol = 16 is a fair approximation
plot.mnist(ncol = 16)

## End(Not run)

# With a different number of digits (here 3: 0, 1 and 2) one should adjust ncol to something reasonable.
# Note that with ncol = 1, projections will be very stretched. The following code reverses the x and y from PCA
# to have more space for the 1st component.
mnist.02 <- mnist
mnist.02$train$x <- mnist$train$x[mnist$train$y < 3,]
mnist.02$train$y <- mnist$train$y[mnist$train$y < 3]
mnist.02$test$x <- mnist$test$x[mnist$test$y < 3,]
mnist.02$test$y <- mnist$test$y[mnist$test$y < 3]
pca <- prcomp(mnist.02$train$x)
plot.mnist(predictions = predict(pca, mnist.02$test$x)[,2:1], reconstructions = tcrossprod(predict(pca, mnist.02$test$x)[,1:2], pca$rotation[,1:2]),
           x = mnist.02$test$x, label = mnist.02$test$y + 1, highlight.digits = c(3, 2, 1),
           ncol = 1)

xrobin/mnist documentation built on Dec. 14, 2020, 11:28 a.m.