context("nn_module")
source("utils.R")
# Fetch the pickled MNIST archive and unpack it to `mnist.pkl` in the working
# directory (the compressed `mnist.pkl.gz` is kept alongside it). The
# 'fit nn_module MNIST' test reads `mnist.pkl` from this location.
test_succeeds("download MNIST pickler", {
  # Request an explicit binary transfer: a text-mode write would corrupt the
  # gzip stream on Windows, so don't rely on extension-based auto-detection.
  download.file(
    "https://github.com/henry090/fastai/raw/master/files/mnist.pkl.gz",
    "mnist.pkl.gz",
    mode = "wb"
  )
  # gunzip() throws when the destination already exists unless
  # overwrite = TRUE; without it, re-running the suite in the same
  # directory fails on the leftover mnist.pkl.
  R.utils::gunzip("mnist.pkl.gz", remove = FALSE, overwrite = TRUE)
})
# Smoke test: train a two-layer MLP declared with nn_module() on the first 500
# training / validation rows of MNIST, using fit_one_cycle for 1 epoch at
# lr 1e-2. Reads `mnist.pkl` from the working directory, which is the file the
# 'download MNIST pickler' test unpacks.
test_succeeds('fit nn_module MNIST', {
  rows <- seq_len(500L)
  cols <- seq_len(784L)

  mnist <- reticulate::py_load_object('mnist.pkl', encoding = 'latin-1')
  train_split <- mnist[[1]]
  valid_split <- mnist[[2]]

  # Pixels: first 500 rows, 784 columns; labels: coerced to integer, then
  # truncated to the same 500 rows.
  x_train <- train_split[[1]][rows, cols]
  x_valid <- valid_split[[1]][rows, cols]
  y_train <- as.integer(train_split[[2]])[rows]
  y_valid <- as.integer(valid_split[[2]])[rows]

  # Render the first training image as a 28 x 28 grayscale picture.
  first_image <- array_reshape(x_train[1, ], c(28, 28))
  plot(show_image(first_image, cmap = 'gray'))

  tensor_dataset <- torch()$utils$data$TensorDataset
  # Kept as a double (not 32L) so the same value reaches Python as before.
  batch_size <- 32
  train_ds <- tensor_dataset(tensor(x_train), tensor(y_train))
  valid_ds <- tensor_dataset(tensor(x_valid), tensor(y_valid))
  train_dl <- TfmdDL(train_ds, bs = batch_size, shuffle = TRUE)
  valid_dl <- TfmdDL(valid_ds, bs = 2 * batch_size)
  loaders <- Data_Loaders(train_dl, valid_dl)

  # Pull one batch and touch its shapes to make sure the loaders yield tensors.
  batch <- one_batch(loaders)
  xb <- batch[[1]]
  yb <- batch[[2]]
  xb$shape
  yb$shape

  # `nn` and `Functional` must keep these exact names: the nn_module()
  # definition below refers to them from its enclosing environment.
  nn <- nn()
  Functional <- torch()$nn$functional
  model <- nn_module(function(self) {
    self$lin1 = nn$Linear(784L, 50L, bias=TRUE)
    self$lin2 = nn$Linear(50L, 10L, bias=TRUE)
    forward = function(y) {
      x = self$lin1(y)
      x = Functional$relu(x)
      self$lin2(x)
    }
  })

  learn <- Learner(
    loaders,
    model,
    loss_func = nn$CrossEntropyLoss(),
    metrics = accuracy
  )
  summary(learn)
  fit_one_cycle(learn, 1, 1e-2)
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.