# Alias torch's internal optimizer state accessors so that `state(param)` and
# `state(param) <- value` can be used below.
state <- torch:::state
`state<-` <- torch:::`state<-`
#' Implements the AdamW algorithm.
#'
#' AdamW was proposed in [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101).
#' Unlike Adam, the weight decay is decoupled from the gradient-based update
#' instead of being added to the loss as an L2 penalty.
#'
#' @param params (iterable): iterable of parameters to optimize or lists defining
#' parameter groups
#' @param lr (float, optional): learning rate (default: 1e-3)
#' @param betas (`Tuple[float, float]`, optional): coefficients used for computing
#' running averages of the gradient and its square (default: (0.9, 0.999))
#' @param eps (float, optional): term added to the denominator to improve
#' numerical stability (default: 1e-8)
#' @param weight_decay (float, optional): weight decay coefficient, applied as
#' decoupled weight decay rather than as an L2 penalty (default: 0)
#' @param amsgrad (boolean, optional): whether to use the AMSGrad variant of this
#' algorithm from the paper [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
#' (default: FALSE)
#'
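#' @details
#' A sketch of the update applied to each parameter \eqn{\theta} at step \eqn{t}
#' (here \eqn{\alpha} is `lr`, \eqn{\lambda} is `weight_decay`, \eqn{\epsilon}
#' is `eps`, \eqn{g_t} is the gradient, and \eqn{\hat{m}_t}, \eqn{\hat{v}_t}
#' are the bias-corrected running moments):
#' \deqn{m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t}
#' \deqn{v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2}
#' \deqn{\theta_t = (1 - \alpha \lambda) \theta_{t-1} - \alpha \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)}
#' When `amsgrad = TRUE`, the running maximum of \eqn{v_t} is used in the
#' denominator instead of \eqn{v_t} itself.
#'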
#' @examples
#' \dontrun{
#' optimizer <- optim_adamw(model$parameters, lr = 0.1)
#' optimizer$zero_grad()
#' loss_fn(model(input), target)$backward()
#' optimizer$step()
#' }
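#'
#' # A more complete sketch (hypothetical data and model, for illustration
#' # only): fit a one-feature linear model with AdamW.
#' \dontrun{
#' library(torch)
#' x <- torch_randn(100, 1)
#' y <- 2 * x + 1 + 0.1 * torch_randn(100, 1)
#' model <- nn_linear(1, 1)
#' optimizer <- optim_adamw(model$parameters, lr = 0.1, weight_decay = 0.01)
#' for (epoch in 1:200) {
#'   optimizer$zero_grad()
#'   loss <- nnf_mse_loss(model(x), y)
#'   loss$backward()
#'   optimizer$step()
#' }
#' }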
#'
#' @export
optim_adamw <- torch::optimizer(
  "optim_adamw",
  initialize = function(params, lr = 1e-3, betas = c(0.9, 0.999), eps = 1e-8,
                        weight_decay = 0, amsgrad = FALSE) {
    if (lr < 0)
      value_error("Invalid learning rate: {lr}")
    if (eps < 0)
      value_error("Invalid eps: {eps}")
    if (betas[[1]] < 0 || betas[[1]] > 1)
      value_error("Invalid beta parameter at index 1")
    if (betas[[2]] < 0 || betas[[2]] > 1)
      value_error("Invalid beta parameter at index 2")
    if (weight_decay < 0)
      value_error("Invalid weight decay value: {weight_decay}")

    defaults <- list(lr = lr, betas = betas, eps = eps,
                     weight_decay = weight_decay, amsgrad = amsgrad)
    super$initialize(params, defaults)
  },
  step = function(closure = NULL) {
    loop_fun <- function(group, param, g, p) {
      grad <- param$grad
      # if (grad$is_sparse) {
      #   runtime_error("Adam does not support sparse gradients, please consider",
      #                 "SparseAdam instead")
      # }
      amsgrad <- group$amsgrad

      # state initialization
      if (length(state(param)) == 0) {
        state(param) <- list()
        state(param)[["step"]] <- 0
        state(param)[["exp_avg"]] <- torch::torch_zeros_like(
          param, memory_format = torch::torch_preserve_format())
        state(param)[["exp_avg_sq"]] <- torch::torch_zeros_like(
          param, memory_format = torch::torch_preserve_format())
        if (amsgrad) {
          state(param)[["max_exp_avg_sq"]] <- torch::torch_zeros_like(
            param, memory_format = torch::torch_preserve_format())
        }
      }

      exp_avg <- state(param)[["exp_avg"]]
      exp_avg_sq <- state(param)[["exp_avg_sq"]]
      if (amsgrad) {
        max_exp_avg_sq <- state(param)[["max_exp_avg_sq"]]
      }
      beta1 <- group$betas[[1]]
      beta2 <- group$betas[[2]]

      state(param)[["step"]] <- state(param)[["step"]] + 1

      # decoupled weight decay: shrink the parameter directly instead of
      # adding an L2 term to the gradient
      param$mul_(1 - group$lr * group$weight_decay)

      bias_correction1 <- 1 - beta1 ^ state(param)[["step"]]
      bias_correction2 <- 1 - beta2 ^ state(param)[["step"]]

      # Decay the first and second moment running average coefficient
      exp_avg$mul_(beta1)$add_(grad, alpha = 1 - beta1)
      exp_avg_sq$mul_(beta2)$addcmul_(grad, grad, value = 1 - beta2)

      if (amsgrad) {
        # Maintains the maximum of all 2nd moment running avg. till now
        max_exp_avg_sq$set_data(max_exp_avg_sq$max(other = exp_avg_sq))
        # Use the max. for normalizing running avg. of gradient
        denom <- (max_exp_avg_sq$sqrt() / sqrt(bias_correction2))$add_(group$eps)
      } else {
        denom <- (exp_avg_sq$sqrt() / sqrt(bias_correction2))$add_(group$eps)
      }

      step_size <- group$lr / bias_correction1
      param$addcdiv_(exp_avg, denom, value = -step_size)
    }
    private$step_helper(closure, loop_fun)
  }
)