# R/adamw.R

# Borrow unexported helpers from torch: the per-parameter state accessors used to
# store the optimizer's moment estimates, and the condition helper used for the
# input validation below.
state <- torch:::state
`state<-` <- torch:::`state<-`
value_error <- torch:::value_error

#' Implements the AdamW algorithm.
#'
#' It has been proposed in [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101).
#'
#' @param params (iterable): iterable of parameters to optimize or dicts defining
#'   parameter groups
#' @param lr (float, optional): learning rate (default: 1e-3)
#' @param betas (`Tuple[float, float]`, optional): coefficients used for computing
#'   running averages of the gradient and its square (default: `c(0.9, 0.999)`)
#' @param eps (float, optional): term added to the denominator to improve
#'   numerical stability (default: 1e-8)
#' @param weight_decay (float, optional): weight decay coefficient, applied in
#'   decoupled form as in AdamW rather than as an L2 penalty on the gradient (default: 0)
#' @param amsgrad (boolean, optional): whether to use the AMSGrad variant of this
#'   algorithm from the paper [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
#'   (default: FALSE)
#'
#' @examples
#' \dontrun{
#' optimizer <- optim_adamw(model$parameters, lr = 0.1)
#' optimizer$zero_grad()
#' loss_fn(model(input), target)$backward()
#' optimizer$step()
#' }
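#'
#' # A fuller, self-contained sketch (illustrative only; the toy model, data and
#' # hyperparameters below are assumptions, not part of the original example):
#' \dontrun{
#' library(torch)
#' x <- torch_randn(100, 10)
#' y <- torch_randn(100, 1)
#' model <- nn_linear(10, 1)
#' optimizer <- optim_adamw(model$parameters, lr = 1e-3, weight_decay = 1e-2)
#' for (epoch in 1:10) {
#'   optimizer$zero_grad()
#'   loss <- nnf_mse_loss(model(x), y)
#'   loss$backward()
#'   optimizer$step()
#' }
#' }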
#'
#' @export
optim_adamw <- torch::optimizer(
  "optim_adamw",
  initialize = function(params, lr=1e-3, betas=c(0.9, 0.999), eps=1e-8,
                        weight_decay=0, amsgrad=FALSE) {

    if (lr < 0)
      value_error("Invalid learning rate: {lr}")

    if (eps < 0)
      value_error("Invalid eps: {eps}")

    if (betas[[1]] < 0 || betas[[1]] > 1)
      value_error("Invalid beta parameter at index 1")

    if (betas[[2]] < 0 || betas[[2]] > 1)
      value_error("Invalid beta parameter at index 2")

    if (weight_decay < 0)
      value_error("Invalid weight decay value: {weight_decay}")

    defaults <- list(lr=lr, betas=betas, eps = eps, weight_decay = weight_decay,
                     amsgrad = amsgrad)

    super$initialize(params, defaults)
  },

  step = function(closure = NULL) {
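    # step_helper() (from torch's optimizer class) evaluates the optional closure
    # and applies loop_fun to every parameter that has a gradient; `g` and `p` are
    # the group and parameter indices, which this update rule does not need.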
    loop_fun <- function(group, param, g, p) {

      grad <- param$grad

      # Note: sparse gradients are not supported by this implementation.
      amsgrad <- group$amsgrad

      # state initialization
      if (length(state(param)) == 0) {
        state(param) <- list()
        state(param)[["step"]] <- 0
        state(param)[["exp_avg"]] <- torch::torch_zeros_like(param, memory_format = torch::torch_preserve_format())
        state(param)[["exp_avg_sq"]] <- torch::torch_zeros_like(param, memory_format = torch::torch_preserve_format())
        if (amsgrad) {
          state(param)[["max_exp_avg_sq"]] <- torch::torch_zeros_like(param, memory_format = torch::torch_preserve_format())
        }
      }

      exp_avg <- state(param)[["exp_avg"]]
      exp_avg_sq <- state(param)[["exp_avg_sq"]]
      if (amsgrad) {
        max_exp_avg_sq <- state(param)[["max_exp_avg_sq"]]
      }
      beta1 <- group$betas[[1]]
      beta2 <- group$betas[[2]]

      state(param)[["step"]] <- state(param)[["step"]] + 1

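      # Decoupled weight decay: shrink the parameter directly by lr * weight_decay
      # before the Adam-style update, rather than adding an L2 term to the gradient.
      # This decoupling is what distinguishes AdamW from plain Adam.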
      param$mul_(1 - group$lr * group$weight_decay)
      bias_correction1 <- 1 - beta1 ^ state(param)[["step"]]
      bias_correction2 <- 1 - beta2 ^ state(param)[["step"]]

      # Decay the first and second moment running average coefficient
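      # Equivalently:
      #   exp_avg    <- beta1 * exp_avg    + (1 - beta1) * grad
      #   exp_avg_sq <- beta2 * exp_avg_sq + (1 - beta2) * grad^2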
      exp_avg$mul_(beta1)$add_(grad, alpha=1 - beta1)
      exp_avg_sq$mul_(beta2)$addcmul_(grad, grad, value=1 - beta2)

      if (amsgrad) {

        # Maintains the maximum of all 2nd moment running avg. till now
        max_exp_avg_sq$set_data(max_exp_avg_sq$max(other = exp_avg_sq))
        # Use the max. for normalizing running avg. of gradient
        denom <- (max_exp_avg_sq$sqrt() / sqrt(bias_correction2))$add_(group$eps)
      } else {
        denom <- (exp_avg_sq$sqrt() / sqrt(bias_correction2))$add_(group$eps)
      }

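      # Parameter update: param <- param - (lr / bias_correction1) * exp_avg / denom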
      step_size <- group$lr / bias_correction1
      param$addcdiv_(exp_avg, denom, value=-step_size)
    }
    private$step_helper(closure, loop_fun)
  }
)