Nothing
#' @include optim.R
NULL
#' Implements Adam algorithm.
#'
#' It has been proposed in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980).
#'
#' @param params (iterable): iterable of parameters to optimize or dicts defining
#' parameter groups
#' @param lr (float, optional): learning rate (default: 1e-3)
#' @param betas (`Tuple[float, float]`, optional): coefficients used for computing
#' running averages of gradient and its square (default: (0.9, 0.999))
#' @param eps (float, optional): term added to the denominator to improve
#' numerical stability (default: 1e-8)
#' @param weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
#' @param amsgrad (boolean, optional): whether to use the AMSGrad variant of this
#' algorithm from the paper [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
#' (default: FALSE)
#'
#' @includeRmd man/rmd/optim-note.Rmd note
#'
#' @examples
#' \dontrun{
#' optimizer <- optim_adam(model$parameters(), lr = 0.1)
#' optimizer$zero_grad()
#' loss_fn(model(input), target)$backward()
#' optimizer$step()
#' }
#'
#' @export
optim_adam <- optimizer(
"optim_adam",
initialize = function(params, lr = 1e-3, betas = c(0.9, 0.999), eps = 1e-8,
weight_decay = 0, amsgrad = FALSE) {
if (lr < 0) {
value_error("Invalid learning rate: {lr}")
}
if (eps < 0) {
value_error("Invalid eps: {eps}")
}
if (betas[[1]] < 0 || betas[[1]] > 1) {
value_error("Invalid beta parameter at index 1")
}
if (betas[[2]] < 0 || betas[[2]] > 1) {
value_error("Invalid beta parameter at index 2")
}
if (weight_decay < 0) {
value_error("Invalid weight decay value: {weight_decay}")
}
defaults <- list(
lr = lr, betas = betas, eps = eps, weight_decay = weight_decay,
amsgrad = amsgrad
)
super$initialize(params, defaults)
},
step = function(closure = NULL) {
loop_fun <- function(group, param, g, p) {
grad <- param$grad
# if (grad$is_sparse) {
# runtime_error("Adam does not support sparse gradients, please consider",
# "SparseAdam instead")
# }
amsgrad <- group$amsgrad
# state initialization
if (length(state(param)) == 0) {
state(param) <- list()
state(param)[["step"]] <- 0
state(param)[["exp_avg"]] <- torch_zeros_like(param, memory_format = torch_preserve_format())
state(param)[["exp_avg_sq"]] <- torch_zeros_like(param, memory_format = torch_preserve_format())
if (amsgrad) {
state(param)[["max_exp_avg_sq"]] <- torch_zeros_like(param, memory_format = torch_preserve_format())
}
}
exp_avg <- state(param)[["exp_avg"]]
exp_avg_sq <- state(param)[["exp_avg_sq"]]
if (amsgrad) {
max_exp_avg_sq <- state(param)[["max_exp_avg_sq"]]
}
beta1 <- group$betas[[1]]
beta2 <- group$betas[[2]]
state(param)[["step"]] <- state(param)[["step"]] + 1
bias_correction1 <- 1 - beta1^state(param)[["step"]]
bias_correction2 <- 1 - beta2^state(param)[["step"]]
if (group$weight_decay != 0) {
grad$add_(param, alpha = group$weight_decay)
}
# Decay the first and second moment running average coefficient
exp_avg$mul_(beta1)$add_(grad, alpha = 1 - beta1)
exp_avg_sq$mul_(beta2)$addcmul_(grad, grad, value = 1 - beta2)
if (amsgrad) {
# Maintains the maximum of all 2nd moment running avg. till now
max_exp_avg_sq$set_data(max_exp_avg_sq$max(other = exp_avg_sq))
# Use the max. for normalizing running avg. of gradient
denom <- (max_exp_avg_sq$sqrt() / sqrt(bias_correction2))$add_(group$eps)
} else {
denom <- (exp_avg_sq$sqrt() / sqrt(bias_correction2))$add_(group$eps)
}
step_size <- group$lr / bias_correction1
param$addcdiv_(exp_avg, denom, value = -step_size)
}
private$step_helper(closure, loop_fun)
}
)
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.