Neural networks are all the rage these days. After reading a book on how to use gradient descent to fit a simple neural network via backpropagation, I decided to code it up by hand with the ultimate goal of having it classify some MNIST digits.
set.seed(1337)
library(rprojroot)
library(ArtisanalMachineLearning)
data(iris)
Preprocess the data into a binary response dataset. Also change the response to be {-1, 1} due to the nature of the tanh function.
iris = iris[1:100,]
iris$Species = as.numeric(iris$Species) - 2
iris$Species[iris$Species == 0] = 1
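As a quick illustration of why the response is recoded: tanh maps every real input into the interval (-1, 1), so targets of -1 and 1 sit at the two ends of what a tanh output unit can produce. The sketch below uses only base R, and the variable names are made up for the example.

```r
# tanh squashes any real number into (-1, 1)
x = c(-10, -1, 0, 1, 10)
squashed = tanh(x)
print(round(squashed, 4))

# Recode a two-level factor to {-1, 1}, mirroring the iris step above
species = factor(c("setosa", "versicolor", "setosa"))
coded = ifelse(as.numeric(species) == 1, -1, 1)
print(coded)
```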
Split the data into training and validation sets.
validation_set = rbind(iris[1:10,], iris[91:100,])
validation_set$Species = NULL
validation_response = c(iris$Species[1:10], iris$Species[91:100])
data = iris[11:90,]
data$Species = NULL
response = iris$Species[11:90]
Train a network.
network = aml_neural_network(sizes=c(4, 15, 15, 4, 1), learning_rate=.01, data=data, response=response, epochs=100, verbose=FALSE)
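To give a rough idea of what fitting a tanh network by gradient descent and backpropagation involves, here is a minimal base R sketch with a single hidden layer, squared-error loss, and full-batch updates on toy data. It is not the ArtisanalMachineLearning implementation; the data, layer size, learning rate, and all names below are assumptions chosen for illustration.

```r
# Minimal sketch: one hidden layer, tanh activations, squared-error loss,
# full-batch gradient descent. Hypothetical names; NOT the package's code.
set.seed(1)

# Toy data: two well-separated clusters labelled -1 and 1
n = 40
x = rbind(matrix(rnorm(n, mean = -2), ncol = 2),
          matrix(rnorm(n, mean = 2), ncol = 2))
y = c(rep(-1, n / 2), rep(1, n / 2))

hidden = 5
W1 = matrix(rnorm(2 * hidden, sd = 0.1), nrow = 2)  # input -> hidden (2 x 5)
b1 = rep(0, hidden)
W2 = matrix(rnorm(hidden, sd = 0.1), ncol = 1)      # hidden -> output (5 x 1)
b2 = 0
learning_rate = 0.05

# Forward pass: returns hidden activations and the network output
forward = function(x) {
  a1 = tanh(sweep(x %*% W1, 2, b1, "+"))  # n x 5
  out = tanh(a1 %*% W2 + b2)              # n x 1
  list(a1 = a1, out = out)
}

for (epoch in 1:500) {
  f = forward(x)
  # Backward pass: chain rule, using d/dz tanh(z) = 1 - tanh(z)^2
  delta2 = (f$out - y) * (1 - f$out^2)        # n x 1
  delta1 = (delta2 %*% t(W2)) * (1 - f$a1^2)  # n x 5
  # Gradient step, averaged over observations
  W2 = W2 - learning_rate * t(f$a1) %*% delta2 / nrow(x)
  b2 = b2 - learning_rate * mean(delta2)
  W1 = W1 - learning_rate * t(x) %*% delta1 / nrow(x)
  b1 = b1 - learning_rate * colMeans(delta1)
}

# Classify by the sign of the output, as in the rest of this vignette
accuracy = mean((forward(x)$out < 0) == (y < 0))
print(paste("Toy training accuracy:", accuracy))
```

The backward pass reuses the activations saved in the forward pass, which is the main bookkeeping trick in backpropagation: each layer's error signal is the next layer's error signal pushed back through the weights and scaled by the local tanh derivative.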
Inspect training and validation accuracy by naively classifying all predictions less than 0 to the -1 level and 1 otherwise.
preds = predict(network, data)
training_accuracy = sum((preds < 0) == (response < 0)) / length(preds)
print(paste("Training accuracy:", training_accuracy))
validation = predict(network, validation_set)
validation_accuracy = sum((validation < 0) == (validation_response < 0)) / length(validation)
print(paste("Validation accuracy:", validation_accuracy))
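The same sign-based accuracy calculation is repeated several times in this vignette, so it could be wrapped in a small function. The helper below, `sign_accuracy`, is a hypothetical name and not part of the package; the toy vectors are invented to show the arithmetic.

```r
# Hypothetical helper: fraction of predictions whose sign matches the response.
# A prediction of exactly 0 is treated as the 1 level, as in the code above.
sign_accuracy = function(predictions, response) {
  mean((predictions < 0) == (response < 0))
}

toy_preds = c(-0.9, 0.2, -0.1, 0.8)
toy_response = c(-1, 1, 1, 1)
print(sign_accuracy(toy_preds, toy_response))  # 3 of 4 signs agree: 0.75
```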
I have written the smartest algorithm of all time. OR This is a highly separable very small data set and the network picked up on that immediately (see the k means vignette for proof of this).
Let's read in the MNIST handwritten digits data and define a few helper functions. Code was largely robbed from here.
library(readr)
library(dplyr)
library(tidyr)
library(ggplot2)
library(viridis)

process_digit_data <- function(data_row){
  data_row %>%
    gather(pixel, value) %>%
    tidyr::extract(pixel, "pixel", "(\\d+)", convert = TRUE) %>%
    mutate(pixel = pixel - 2,
           x = pixel %% 28,
           y = 28 - pixel %/% 28)
}

plot_digit <- function(data_row){
  image_df = process_digit_data(data_row)
  image = ggplot(image_df, aes(x = x, y = y, fill = value)) +
    geom_raster() +
    coord_equal() +
    scale_fill_viridis()
  image
}

mnist_raw <- read_csv("https://pjreddie.com/media/files/mnist_train.csv", col_names = FALSE, col_types = cols())
response = unlist(mnist_raw[,1])
data = mnist_raw
data[,1] = NULL
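The index arithmetic inside process_digit_data is easier to follow on a few concrete values. The sketch below applies the same two expressions to zero-based pixel indices on the 28 x 28 grid; the `coords` data frame is only for illustration.

```r
# Zero-based pixel index -> (x, y) position on the 28 x 28 grid,
# using the same arithmetic as process_digit_data
pixel = c(0, 27, 28, 783)
coords = data.frame(
  pixel = pixel,
  x = pixel %% 28,       # position within the image row
  y = 28 - pixel %/% 28  # flipped so the first image row is drawn at the top
)
print(coords)
```

The first pixel lands at (0, 28) and the last at (27, 1), so the image is drawn the right way up even though ggplot's y axis increases upward.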
Preprocess data to only the digits 0 and 7 and transform the response to be {-1, 1} due to the nature of the tanh function. Also split into training and test sets.
digits_to_keep = c(0,7)
data_sub = data[response %in% digits_to_keep,]
response_sub = response[response %in% digits_to_keep]
response_transformed = sapply(response_sub, function(x){ifelse(x == 0, -1, 1)})
training_indicators = sample(1:length(response_transformed), 10000)
test_indicator = !(1:nrow(data_sub) %in% training_indicators)
training_data = data_sub[training_indicators,]
training_response = response_transformed[training_indicators]
test_data = data_sub[test_indicator,]
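The split above mixes two kinds of index: row numbers for the training set and a logical mask for the test set. A small sketch on made-up sizes shows that the two pieces are complementary and do not overlap (the names `train_idx` and `is_test` are illustrative only).

```r
set.seed(42)
n_obs = 20
train_idx = sample(1:n_obs, 12)      # row numbers drawn without replacement
is_test = !(1:n_obs %in% train_idx)  # TRUE for every row that was not drawn

n_test = sum(is_test)                # rows left over for testing
overlap = intersect(train_idx, which(is_test))
print(n_test)
print(length(overlap))               # 0: no row is in both sets
```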
If you're not familiar with the MNIST data set, it's a dataset of handwritten digit images. Here are a few examples.
set.seed(999999)
plot_digit(training_data[sample(1:nrow(training_data), 1),])
plot_digit(training_data[sample(1:nrow(training_data), 1),])
plot_digit(training_data[sample(1:nrow(training_data), 1),])
plot_digit(training_data[sample(1:nrow(training_data), 1),])
plot_digit(training_data[sample(1:nrow(training_data), 1),])
The average of all 0 and 7 images looks like the following.
# Plot average 0 image
averages_frame = as.data.frame(t(apply((training_data[training_response == -1,]),MARGIN=2,FUN=mean)))
plot_digit(averages_frame)
# Plot average 7 image
averages_frame = as.data.frame(t(apply((training_data[training_response == 1,]),MARGIN=2,FUN=mean)))
plot_digit(averages_frame)
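The per-pixel average computed with apply(..., MARGIN=2, FUN=mean) is a column mean, so base R's colMeans should give the same numbers. A quick check on a made-up two-pixel data frame:

```r
# Toy "images": three observations of two pixels
toy = data.frame(p1 = c(0, 255, 128), p2 = c(10, 20, 30))

via_apply = apply(toy, MARGIN = 2, FUN = mean)
via_colmeans = colMeans(toy)

print(via_apply)
print(all.equal(via_apply, via_colmeans))
```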
Now, let's train a small neural network and inspect it. Warning! This takes some time to run.
small_network = aml_neural_network(c(784, 50, 25, 1), learning_rate = .01, data=training_data, response=training_response, epochs=10, verbose = FALSE)
# Do a cooking show trick and bring out an already baked net
small_network = readRDS(file.path(find_root('DESCRIPTION'), 'data/small_net.rds'))
Accuracy rates using a naive classification split at 0
# Training accuracy (fraction of predictions with the correct sign)
preds = predict(small_network, training_data)
training_accuracy = sum((training_response < 0) == (preds < 0)) / length(preds)
print(paste("Training accuracy:", training_accuracy))
# Test accuracy
test_preds = predict(small_network, test_data)
testing_accuracy = sum((response_transformed[test_indicator] < 0) == (test_preds < 0)) / length(test_preds)
print(paste("Testing accuracy:", testing_accuracy))
Train a big neural network and inspect it. Warning! This takes a long time to run.
big_network = aml_neural_network(c(784, 100, 100, 50, 25, 1), learning_rate = .01, data=training_data, response=training_response, epochs=100, verbose = FALSE)
# Do a cooking show trick and bring out an already baked net
big_network = readRDS(file.path(find_root('DESCRIPTION'), 'data/big_net.rds'))
Accuracy rates using a naive classification split at 0
# Training accuracy (fraction of predictions with the correct sign)
preds = predict(big_network, training_data)
training_accuracy = sum((training_response < 0) == (preds < 0)) / length(preds)
print(paste("Training accuracy:", training_accuracy))
# Test accuracy
test_preds = predict(big_network, test_data)
testing_accuracy = sum((response_transformed[test_indicator] < 0) == (test_preds < 0)) / length(test_preds)
print(paste("Testing accuracy:", testing_accuracy))
median_obs = which(preds == median(preds)) # Many observations are equal to the median
plot_digit(training_data[sample(median_obs, 1),])
plot_digit(training_data[sample(median_obs, 1),])
plot_digit(training_data[sample(median_obs, 1),])
plot_digit(training_data[sample(median_obs, 1),])
median_obs = which(test_preds == median(test_preds)) # Many observations are equal to the median
plot_digit(test_data[sample(median_obs, 1),])
plot_digit(test_data[sample(median_obs, 1),])
plot_digit(test_data[sample(median_obs, 1),])
plot_digit(test_data[sample(median_obs, 1),])
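One caveat with sample(median_obs, 1): when base R's sample() is given a single positive number, it draws from 1 up to that number rather than returning the number itself. That is harmless here because many observations tie at the median, but if only one did, the plot could show an unrelated row. A defensive sketch follows; `safe_sample_one` is a hypothetical helper, not part of any package.

```r
# Hypothetical helper: draw one element by position, so a length-one
# vector returns its only element instead of a draw from 1:value
safe_sample_one = function(candidates) {
  candidates[sample.int(length(candidates), 1)]
}

single = c(57)
print(safe_sample_one(single))        # always 57
print(safe_sample_one(c(3, 9, 12)))   # one of 3, 9, 12
```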
plot_digit(training_data[which.min(preds),])
plot_digit(training_data[which.max(preds),])
plot_digit(test_data[which.min(test_preds),])
plot_digit(test_data[which.max(test_preds),])
Image processing with neural networks is very hard. This very quick analysis is clearly picking up on some signal, but you can see the network struggle with digits that were written thicker than others, digits that were not centered in the grid, a 7 written with a cross through the middle, etc. It's possible that training larger networks and/or training for more epochs would help in this situation, but that kind of intense tuning would be a frivolous use of time. More sophisticated preprocessing techniques and/or other types of neural networks, such as convolutional neural networks, are necessary to tackle this problem appropriately.
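As one example of the kind of preprocessing that could address off-center digits, the sketch below shifts a square image so that its intensity-weighted center of mass lands on the middle of the grid. This was not used anywhere in this analysis; `center_image` is a hypothetical helper and the 5 x 5 image is a toy stand-in for a 28 x 28 digit.

```r
# Hypothetical helper: shift a square image matrix so its intensity-weighted
# center of mass sits at the middle of the grid. Pixels pushed off the edge
# are dropped and vacated pixels are set to 0.
center_image = function(img) {
  n = nrow(img)
  total = sum(img)
  if (total == 0) return(img)            # blank image: nothing to center
  row_com = sum(row(img) * img) / total  # center of mass, row direction
  col_com = sum(col(img) * img) / total  # center of mass, column direction
  dr = round((n + 1) / 2 - row_com)      # whole-pixel shift needed in rows
  dc = round((n + 1) / 2 - col_com)      # whole-pixel shift needed in columns
  src_r = 1:n - dr                       # source row for each target row
  src_c = 1:n - dc                       # source column for each target column
  ok_r = src_r >= 1 & src_r <= n
  ok_c = src_c >= 1 & src_c <= n
  out = matrix(0, n, n)
  out[(1:n)[ok_r], (1:n)[ok_c]] = img[src_r[ok_r], src_c[ok_c]]
  out
}

img = matrix(0, 5, 5)
img[1, 1] = 1                            # single lit pixel in the corner
centered = center_image(img)
print(which(centered == 1, arr.ind = TRUE))
```

To apply something like this to the data above, each row of pixel values would first need to be reshaped into a 28 x 28 matrix and flattened again afterwards.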