# Chunk defaults for this document: code is shown but not evaluated, and
# warnings and messages are kept out of the rendered output.
knitr::opts_chunk$set(
  eval = FALSE,
  warning = FALSE,
  message = FALSE
)
Import necessary libraries:
library(cntk) library(magrittr)
Define input dimensions:
# Shape of one image as the model sees it: 1 colour channel (gray), 28 x 28.
input_dim_model <- c(1, 28, 28)

# The same image flattened; used by the readers to treat the input as a vector.
input_dim <- 28 * 28

# Number of label classes.
num_output_classes <- 10
Create reader function:
# Build a minibatch source over a CNTK text-format (CTF) file.
#
# Args:
#   path: path of the file handed to CTFDeserializer().
#   is_training: single TRUE/FALSE. When TRUE the source is randomized and
#     repeats indefinitely; when FALSE it is not randomized and makes exactly
#     one sweep over the data.
#   input_dim: shape of the dense "features" stream.
#   num_label_classes: shape of the dense "labels" stream.
#
# Returns: the MinibatchSource built on top of the deserializer.
create_reader <- function(path, is_training, input_dim, num_label_classes) {
  streams <- StreamDefs(
    features = StreamDef(
      field = "features",
      shape = input_dim,
      is_sparse = FALSE
    ),
    labels = StreamDef(
      field = "labels",
      shape = num_label_classes,
      is_sparse = FALSE
    )
  )
  ctf <- CTFDeserializer(path, streams)

  # `is_training` is a scalar flag, so branch with if/else instead of the
  # vectorised ifelse(), which is meant for element-wise selection.
  max_sweeps <- if (is_training) IO_INFINITELY_REPEAT else 1

  MinibatchSource(ctf, randomize = is_training, max_sweeps = max_sweeps)
}
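Download the MNIST data first; the code below expects the CNTK text-format training and test files in the Examples/Image/DataSets/MNIST directory.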
# Directory that is expected to hold the MNIST files in CNTK text format.
data_dir <- file.path("../../../", "Examples", "Image", "DataSets", "MNIST")

# `train_file` and `test_file` are only defined when BOTH data files are
# present; both paths are used later, so checking the training file alone
# would let a missing test file go unnoticed until the reader is built.
if (all(file.exists(file.path(
  data_dir,
  c("Train-28x28_cntk_text.txt", "Test-28x28_cntk_text.txt")
)))) {
  train_file <- file.path(data_dir, "Train-28x28_cntk_text.txt")
  test_file <- file.path(data_dir, "Test-28x28_cntk_text.txt")
} else {
  message("Download MNIST data first!")
}
# Input variable for the images, shaped like `input_dim_model`.
x <- op_input_variable(input_dim_model)

# Input variable for the labels, one entry per output class.
y <- op_input_variable(num_output_classes)
# Build the network: scale the input by 1/255, apply two 5x5 convolutions with
# stride 2, padding and ReLU (8 filters, then 16), and finish with a dense
# layer of `num_output_classes` units that has no activation.
#
# Args:
#   features: the input the network is applied to.
#
# Returns: the output of the final dense layer ("classify").
create_model <- function(features) {
  # presumably maps raw 0-255 pixel values into [0, 1] -- confirm against data
  scaled <- op_element_times(1/255, features)

  first_conv_layer <- Convolution2D(
    filter_shape = c(5, 5),
    num_filters = 8,
    strides = c(2, 2),
    pad = TRUE,
    name = "first_conv",
    activation = op_relu
  )
  first_conv_out <- first_conv_layer(scaled)

  second_conv_layer <- Convolution2D(
    filter_shape = c(5, 5),
    num_filters = 16,
    strides = c(2, 2),
    pad = TRUE,
    name = "conv",
    activation = op_relu
  )
  second_conv_out <- second_conv_layer(first_conv_out)

  classifier <- Dense(num_output_classes, activation = NULL, name = "classify")
  classifier(second_conv_out)
}

# Instantiate the model on the image input variable.
z <- create_model(x)

# Inspect the model: layers are looked up by the names given above.
sprintf(
  "Output shape of the first convolution layer: %s",
  paste0(z$first_conv$shape, collapse = ", ")
)
sprintf(
  "Bias value of the last dense layer: %s",
  paste0(z$classify$b$value, collapse = ", ")
)

visualize_network(z)
# Build the training criterion for a model.
#
# Args:
#   model: the model output the criterion is computed on.
#   labels: the label input the model output is compared against.
#
# Returns: an unnamed list of two elements, the cross-entropy-with-softmax
#   loss followed by the classification error.
create_criterion_function <- function(model, labels) {
  loss <- loss_cross_entropy_with_softmax(model, labels)
  errs <- classification_error(model, labels)
  list(loss, errs)
}

# Report training progress every `frequency` minibatches.
#
# Args:
#   trainer: object exposing `previous_minibatch_loss_average` and
#     `previous_minibatch_evaluation_average` via `$`.
#   mb: index of the current minibatch.
#   frequency: progress is sampled when `mb` is a multiple of this value.
#   verbose: when truthy, the sampled values are written to standard output.
#
# Returns: list(mb, training_loss, eval_error); the last two are NA for
#   minibatches that are not sampled.
print_training_progress <- function(trainer, mb, frequency, verbose = 1) {
  training_loss <- NA
  eval_error <- NA

  if (mb %% frequency == 0) {
    training_loss <- trainer$previous_minibatch_loss_average
    eval_error <- trainer$previous_minibatch_evaluation_average
    if (verbose) {
      # A bare sprintf() inside a function only builds the string and throws
      # it away, so it has to be written out explicitly. cat() is used rather
      # than message() so the line goes to standard output.
      cat(
        sprintf(
          "Minibatch: %s, Loss: %s, Error %s",
          mb, training_loss, eval_error * 100
        ),
        "\n",
        sep = ""
      )
    }
  }

  list(mb, training_loss, eval_error)
}
# Train a model with SGD on minibatches from `train_reader`, then evaluate it
# on minibatches from `test_reader`.
#
# Args:
#   train_reader: minibatch source for training; must expose
#     `streams$labels` and `streams$features`.
#   test_reader: minibatch source for evaluation, with the same streams.
#   model_func: the model to train and evaluate.
#   num_sweeps_to_train_with: number of passes over the (assumed 60,000
#     sample) training set used to size the training loop.
#
# Returns: a character string holding the average test error in percent.
#
# Uses the file-level input variables `x` and `y`, plus
# create_criterion_function() and print_training_progress().
train_test <- function(train_reader, test_reader, model_func,
                       num_sweeps_to_train_with = 10) {
  model <- model_func
  loss_error <- create_criterion_function(model, y)

  learning_rate <- 0.2
  lr_schedule <- learning_rate_schedule(learning_rate, UnitType("minibatch"))

  # Train the model that was passed in rather than the global `z`, so the
  # `model_func` argument is actually honoured.
  learner <- learner_sgd(model$parameters, lr_schedule)
  trainer <- Trainer(model, loss_error, learner)

  minibatch_size <- 64
  num_samples_per_sweep <- 6 * 10^4
  num_minibatches_to_train <- ceiling(
    (num_samples_per_sweep * num_sweeps_to_train_with) / minibatch_size
  )

  input_map <- dict(
    "y" = train_reader$streams$labels,
    "x" = train_reader$streams$features
  )

  training_progress_output_freq <- 500
  start <- Sys.time()

  # The counter is 0-based, so `seq_len(n) - 1` gives exactly `n` iterations;
  # `0:n` would train one minibatch too many.
  for (i in seq_len(num_minibatches_to_train) - 1) {
    data <- train_reader %>%
      next_minibatch(minibatch_size, input_map = input_map)
    trainer %>% train_minibatch(data)
    print_training_progress(
      trainer, i, training_progress_output_freq,
      verbose = 1
    )
  }

  # Written out explicitly: a bare sprintf() in the middle of a function body
  # is discarded. format() keeps the time unit of the difftime.
  cat(sprintf("Training took: %s\n", format(Sys.time() - start)))

  # NOTE(review): built with dict() like the training map so that both maps
  # are keyed the same way; this was a plain list() before -- confirm.
  test_input_map <- dict(
    "y" = test_reader$streams$labels,
    "x" = test_reader$streams$features
  )

  test_minibatch_size <- 512
  num_samples <- 10^4
  # Whole number of minibatches; the same count drives the loop and the
  # average below, so the divisor matches the number of errors summed.
  num_minibatches_to_test <- ceiling(num_samples / test_minibatch_size)

  test_result <- 0
  for (i in seq_len(num_minibatches_to_test)) {
    data <- test_reader %>%
      next_minibatch(test_minibatch_size, input_map = test_input_map)
    eval_error <- trainer %>% test_minibatch(data)
    test_result <- test_result + eval_error
  }

  # Unweighted mean of the per-minibatch errors, expressed in percent.
  sprintf(
    "Average test error %s: ",
    test_result * 100 / num_minibatches_to_test
  )
}
# Create the training and test readers from the file-level paths and
# dimensions, then train and evaluate the global model `z`.
#
# Returns: whatever train_test() returns.
do_train_test <- function() {
  training_source <- create_reader(
    path = train_file,
    is_training = TRUE,
    input_dim = input_dim,
    num_label_classes = num_output_classes
  )
  evaluation_source <- create_reader(
    path = test_file,
    is_training = FALSE,
    input_dim = input_dim,
    num_label_classes = num_output_classes
  )

  train_test(training_source, evaluation_source, z)
}

do_train_test()
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.