Nothing
# Fits a small least-squares problem with the given optimizer and expects the
# loss to at least halve after 200 steps.
#
# `optim` is the optimizer constructor (anything `do.call()` accepts) and
# `defaults` is a list of arguments for it; its "params" entry is set here.
# Returns the optimizer's `state_dict()` so callers can inspect it further.
expect_optim_works <- function(optim, defaults) {
  # Synthetic regression data: targets are an exact linear map of the inputs.
  true_coef <- torch_randn(10, 1)
  inputs <- torch_randn(100, 10)
  targets <- torch_mm(inputs, true_coef)

  # Both tensors are handed to the optimizer, but only `w_fit` is used when
  # computing the loss below; `w_idle` never takes part in it.
  w_fit <- torch_randn(10, 1, requires_grad = TRUE)
  w_idle <- torch_randn(10, 1, requires_grad = TRUE)
  defaults[["params"]] <- list(w_fit, w_idle)
  optimizer <- do.call(optim, defaults)

  # Closure passed to `step()`: clears gradients, computes the mean squared
  # error, backpropagates it and returns the loss tensor.
  closure <- function() {
    optimizer$zero_grad()
    residual <- targets - torch_mm(inputs, w_fit)
    mse <- torch_mean(residual^2)
    mse$backward()
    mse
  }

  loss_before <- closure()
  for (iteration in seq_len(200)) {
    optimizer$step(closure)
  }

  # Evaluate the final loss first, then compare it with half the initial one.
  loss_after <- as_array(closure())
  loss_target <- as_array(loss_before) / 2
  expect_true(loss_after <= loss_target)

  optimizer$state_dict()
}
# Checks that an optimizer keeps a `step` counter in its per-parameter state
# and that the state survives a `state_dict()` / `load_state_dict()` round
# trip into a second optimizer.
#
# `opt_fn` is called with a single tensor and must return an optimizer built
# over that tensor.
expect_state_is_updated <- function(opt_fn) {
  x <- torch_tensor(1, requires_grad = TRUE)
  opt <- opt_fn(x)

  # Run one backward pass so the steps below have a gradient to work with.
  opt$zero_grad()
  y <- 2 * x
  y$backward()

  # Reads the `step` entry stored for the first parameter of the first group.
  current_step <- function() {
    param <- opt$param_groups[[1]]$params[[1]]
    as.numeric(opt$state$get(param)$step)
  }

  opt$step()
  expect_equal(current_step(), 1)
  opt$step()
  expect_equal(current_step(), 2)

  # Reload the saved state into an optimizer built over a fresh parameter.
  # The original created `x2` but passed `x` to `opt_fn()` again, leaving
  # `x2` unused.
  state <- opt$state_dict()
  x2 <- torch_tensor(1, requires_grad = TRUE)
  opt2 <- opt_fn(x2)
  opt2$load_state_dict(state)

  # Entries are compared by position in the state map.
  for (i in seq_along(opt$state$map)) {
    expect_equal(opt2$state$map[[i]], opt$state$map[[i]])
  }
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.