Nothing
# Chunk options applied to every knitr chunk in this document.
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
# Attach rTorch, then bind the Python modules and builtins used below.
library(rTorch)

torch <- import("torch")
Variable <- import("torch.autograd")$Variable
np <- import("numpy")
optim <- import("torch.optim")
py <- import_builtins()
# make it reproducible
torch$manual_seed(42L)

# Inputs: 101 evenly spaced points between -1 and 1.
X <- torch$linspace(-1L, 1L, 101L)

# Targets: the line 2 * X plus random noise scaled by 0.33.
y1 <- 2 * X                           # X$mul(2)
y2 <- torch$randn(X$size()) * 0.33
Y <- y1 + y2                          # y1$add(y2)

# Print the sum of the targets.
Y$sum()
# Problem dimensions.
n_examples <- torch_size(X)   # size of X as reported by torch_size()
n_features <- 1L
n_classes <- 1L

# Optimizer settings.
learning_rate <- 0.01
momentum <- 0.9

# Training schedule.
batch_size <- 10L
epochs <- 100                 # original value for epochs = 100
neurons <- 512L
#' Build a single-layer linear model.
#'
#' @param input_dim Number of input features given to `torch$nn$Linear`.
#' @param output_dim Number of outputs given to `torch$nn$Linear`.
#' @return A `torch$nn$Sequential` container holding one linear layer without
#'   bias, registered under the name "linear".
build_model <- function(input_dim, output_dim) {
  model <- torch$nn$Sequential()
  model$add_module(
    "linear",
    torch$nn$Linear(input_dim, output_dim, bias = FALSE)
  )
  model
}

#' Run one optimisation step on a mini-batch.
#'
#' @param model Module whose `forward()` produces the predictions.
#' @param loss Loss module; called as `loss$forward(prediction, target)`.
#' @param optimizer Optimizer providing `zero_grad()` and `step()`.
#' @param x Tensor of inputs; reshaped to one column with `len(x)` rows.
#' @param y Tensor of targets with as many elements as the prediction.
#' @return The loss of this step, taken from `output$data` (still a Python
#'   object; the caller converts it with `$numpy()`).
train <- function(model, loss, optimizer, x, y) {
  # Wrap the inputs without gradient tracking.
  x <- Variable(x, requires_grad = FALSE)
  y <- Variable(y, requires_grad = FALSE)

  # reset gradient
  optimizer$zero_grad()

  # forward
  fx <- model$forward(x$view(py$len(x), 1L))
  # Give the target the prediction's shape. A 1-D target compared against an
  # (n, 1) prediction would otherwise be broadcast to (n, n) inside the loss,
  # averaging over all pairs instead of matching examples.
  output <- loss$forward(fx, y$reshape_as(fx))

  # backward
  output$backward()

  # update parameters
  optimizer$step()

  output$data$index(0L)
}

model <- build_model(n_features, n_classes)
# `reduction = "mean"` replaces the deprecated `size_average = TRUE`.
loss <- torch$nn$MSELoss(reduction = "mean")
optimizer <- optim$SGD(
  model$parameters(),
  lr = learning_rate,
  momentum = momentum
)
# Mini-batch training: every epoch walks over consecutive slices of X and Y,
# performs one optimisation step per slice and prints the mean batch cost.
# Only whole batches are used; leftover examples are not visited.
num_batches <- n_examples %/% batch_size
batch_len <- as.integer(batch_size)

# seq_len() yields an empty sequence for a zero count, whereas seq(1, 0)
# would iterate over c(1, 0) and request a slice with a negative start.
for (i in seq_len(epochs)) {
  ccost <- 0.0
  for (k in seq_len(num_batches)) {
    # index in Python start at [0]
    start <- as.integer((k - 1) * batch_size)
    cost <- train(
      model, loss, optimizer,
      X$narrow(-1L, start, batch_len),
      Y$narrow(-1L, start, batch_len)
    )
    ccost <- ccost + cost$numpy()  # because we don't have `+` func
  }
  cat(sprintf("Epoch = %3d, cost = %s \n", i, ccost / num_batches))
}

# Report the learned weight: the first tensor yielded by model$parameters().
model_param <- model$parameters()
w <- iter_next(model$parameters())$data
cat(sprintf("w = %.3f", w$numpy()))
# Epoch = 1, cost = 0.103987852856517
# Epoch = 100, cost = 0.103722278773785
# w = 1.968
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.