
Defines functions readMNIST provideMNIST

Documented in provideMNIST readMNIST

# Copyright (C) 2013-2016 Martin Drees
# Copyright (C) 2015-2016 Johannes Rueckert
# This file is part of darch.
# darch is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# darch is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with darch. If not, see <http://www.gnu.org/licenses/>.

#' Function for generating .RData files of the MNIST Database
#'
#' This function reads the MNIST database, randomizes the order of the samples
#' and saves the result in the files "train.RData" (training data) and
#' "test.RData" (test data) inside \code{folder}.
#'
#' @details
#' The saved training file contains the variables \code{trainData} and
#' \code{trainLabels}, the saved test file contains \code{testData} and
#' \code{testLabels}. Pixel values are scaled to the range [0, 1] and labels
#' are stored as one-hot rows with 10 columns (column 1 is digit 0).
#'
#' To run the function, the files "train-images-idx3-ubyte",
#' "train-labels-idx1-ubyte", "t10k-images-idx3-ubyte", and
#' "t10k-labels-idx1-ubyte" (or their ".gz" compressed counterparts) have to
#' be in the folder given by the parameter \code{folder}. A trailing slash on
#' the folder name is optional.
#'
#' @param folder The location of the MNIST-Database files.
#' @return \code{NULL}, invisibly. The function is called for its side effect
#'   of writing "train.RData" and "test.RData" into \code{folder}.
#' @keywords internal
#' @export
readMNIST <- function(folder) {
  futile.logger::flog.info("Loading the MNIST data set.")

  # Accept folder names with or without a trailing path separator. An empty
  # string keeps meaning "current working directory".
  if (nzchar(folder) && !grepl("[/\\\\]$", folder)) {
    folder <- paste0(folder, "/")
  }

  numClasses <- 10L

  # Reads one big-endian 32-bit header integer from an open connection.
  readHeaderInt <- function(con) {
    readBin(con, "integer", n = 1L, size = 4L, endian = "big")
  }

  # Reads the images and labels from the two files given by dataName and
  # labelName, puts them together in one matrix (label in the last column,
  # shifted to 1..10) and returns that matrix sorted by label.
  loadData <- function(dataName, labelName) {
    fileFunction <- file

    # Switch to gzfile function if necessary
    if (file.exists(paste0(dataName, ".gz"))) {
      dataName <- paste0(dataName, ".gz")
      labelName <- paste0(labelName, ".gz")
      fileFunction <- gzfile
    }

    # Read the data. IDX3 header: magic number (2051), number of images,
    # image rows, image columns.
    dataCon <- fileFunction(dataName, "rb")
    on.exit(close(dataCon), add = TRUE)
    if (!identical(readHeaderInt(dataCon), 2051L)) {
      stop("'", dataName, "' is not an MNIST image file.", call. = FALSE)
    }
    rows <- readHeaderInt(dataCon)
    numRow <- readHeaderInt(dataCon)
    numCol <- readHeaderInt(dataCon)
    columns <- numRow * numCol
    buffer <- readBin(dataCon, "integer", n = rows * columns, size = 1L,
                      signed = FALSE)
    if (length(buffer) != rows * columns) {
      stop("'", dataName, "' is truncated.", call. = FALSE)
    }
    data <- matrix(buffer, nrow = rows, byrow = TRUE) / 255

    # Read the labels. IDX1 header: magic number (2049), number of labels.
    labelCon <- fileFunction(labelName, "rb")
    on.exit(close(labelCon), add = TRUE)
    if (!identical(readHeaderInt(labelCon), 2049L)) {
      stop("'", labelName, "' is not an MNIST label file.", call. = FALSE)
    }
    num <- readHeaderInt(labelCon)
    # Shift digits 0..9 to 1..10 so they can be used as column indices.
    labels <- readBin(labelCon, "integer", n = num, size = 1L,
                      signed = FALSE) + 1
    if (length(labels) != rows) {
      stop("Number of labels does not match number of images.",
           call. = FALSE)
    }

    # Put data and labels together and sort by the label column
    sortedData <- cbind(data, labels)
    sortedData[order(sortedData[, columns + 1]), , drop = FALSE]
  }

  # Brings the label-sorted data matrix into the random order given by
  # `random` and drops the label column.
  generateData <- function(data, random, dims) {
    randomData <- cbind(data, random)
    randomData <- randomData[order(randomData[, dims[2] + 1]), , drop = FALSE]
    randomData[, 1:(dims[2] - 1), drop = FALSE]
  }

  # Generates a label matrix with rows of kind c(1,0,0,0,0,0,0,0,0,0) for the
  # label-sorted data (counts[i] consecutive rows for class i) and brings it
  # into the same random order as the data.
  generateLabels <- function(counts, random, rows) {
    stopifnot(sum(counts) == rows)
    rlabels <- matrix(0, nrow = rows, ncol = numClasses)
    classOfRow <- rep(seq_len(numClasses), times = counts)
    rlabels[cbind(seq_len(rows), classOfRow)] <- 1
    for (i in seq_len(numClasses)) {
      futile.logger::flog.info(
        paste0("class ", (i - 1), " = ", counts[i], " images"))
    }
    rlabels[order(random), , drop = FALSE]
  }

  # Loads one image/label file pair and returns the shuffled data matrix and
  # the matching one-hot label matrix.
  prepareSet <- function(imagesName, labelsName) {
    sorted <- loadData(paste0(folder, imagesName), paste0(folder, labelsName))
    dims <- dim(sorted)
    random <- sample.int(dims[1])
    # tabulate() reports a zero for classes without samples, unlike table().
    counts <- tabulate(sorted[, dims[2]], nbins = numClasses)
    futile.logger::flog.info("Generating randomized data set and label matrix")
    list(data = generateData(sorted, random, dims),
         labels = generateLabels(counts, random, dims[1]))
  }

  futile.logger::flog.info("Loading train set with 60000 images.")
  trainSet <- prepareSet("train-images-idx3-ubyte", "train-labels-idx1-ubyte")
  trainData <- trainSet$data
  trainLabels <- trainSet$labels
  futile.logger::flog.info("Saving the train data (filename=train)")
  save(trainData, trainLabels, file = paste0(folder, "train.RData"),
       precheck = TRUE, compress = TRUE)

  futile.logger::flog.info("Loading test set with 10000 images.")
  testSet <- prepareSet("t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte")
  testData <- testSet$data
  testLabels <- testSet$labels
  futile.logger::flog.info("Saving the test data (filename=test)")
  save(testData, testLabels, file = paste0(folder, "test.RData"),
       precheck = TRUE, compress = TRUE)

  invisible(NULL)
}

#' Provides MNIST data set in the given folder.
#'
#' This function will, if necessary and allowed, download the compressed MNIST
#' data set and save it to .RData files using \code{\link{readMNIST}}. If the
#' compressed MNIST archives are available, it will convert them into RData
#' files loadable from within R. If the RData files are already available,
#' nothing will be done.
#'
#' @param folder Folder name; a trailing slash is optional. The folder is
#'   created if it does not exist.
#' @param download Logical indicating whether download is allowed.
#' @return Boolean value indicating success or failure.
#' @examples
#' \dontrun{
#' provideMNIST("mnist/", download = TRUE)
#' }
#' @export
provideMNIST <- function(folder = "data/", download = FALSE) {
  # Accept folder names with or without a trailing path separator. An empty
  # string keeps meaning "current working directory".
  if (nzchar(folder) && !grepl("[/\\\\]$", folder)) {
    folder <- paste0(folder, "/")
  }

  # Existence checks on Windows do not accept trailing separators, so test and
  # create the directory using the path without them.
  folderPath <- sub("[/\\\\]+$", "", folder)
  if (nzchar(folderPath) && !dir.exists(folderPath)) {
    dir.create(folderPath, recursive = TRUE)
  }

  fileNameTrainImages <- "train-images-idx3-ubyte.gz"
  fileNameTrainLabels <- "train-labels-idx1-ubyte.gz"
  fileNameTestImages <- "t10k-images-idx3-ubyte.gz"
  fileNameTestLabels <- "t10k-labels-idx1-ubyte.gz"
  archiveNames <- c(fileNameTrainImages, fileNameTrainLabels,
                    fileNameTestImages, fileNameTestLabels)
  mnistUrl <- "http://yann.lecun.com/exdb/mnist/"

  if (file.exists(paste0(folder, "train.RData")) &&
        file.exists(paste0(folder, "test.RData"))) {
    futile.logger::flog.info(
      "MNIST data set already available, nothing left to do.")
    return(TRUE)
  }

  if (download && any(!file.exists(paste0(folder, archiveNames)))) {
    futile.logger::flog.info(
      "Compressed MNIST files not found, attempting to download...")

    # download.file() returns 0 on success; it may also signal an error, which
    # is mapped to a non-zero status here. mode = "wb" keeps the gzip archives
    # intact on Windows.
    statusCodes <- vapply(archiveNames, function(archiveName) {
      tryCatch(
        as.integer(utils::download.file(paste0(mnistUrl, archiveName),
                                        paste0(folder, archiveName),
                                        mode = "wb")),
        error = function(e) 1L)
    }, integer(1))

    if (any(statusCodes > 0)) {
      futile.logger::flog.error(paste("Error downloading MNIST files.",
        "Download manually from %s or try again."), mnistUrl)
      return(FALSE)
    }

    futile.logger::flog.info("Successfully downloaded compressed MNIST files.")
  } else {
    futile.logger::flog.info(
      "Compressed MNIST files found or download disabled, skipping download.")
  }

  # readMNIST() accepts either the compressed or the uncompressed files; report
  # failure instead of raising an error if neither variant is present.
  plainNames <- sub("\\.gz$", "", archiveNames)
  available <- file.exists(paste0(folder, archiveNames)) |
    file.exists(paste0(folder, plainNames))
  if (!all(available)) {
    futile.logger::flog.error(
      "MNIST files not found in the given folder, cannot create RData files.")
    return(FALSE)
  }

  readMNIST(folder)

  TRUE
}

Try the darch package in your browser

Any scripts or data that you put into this service are public.

darch documentation built on May 29, 2017, 8:14 p.m.