# R/optim-lbfgs.R

# Defines functions .new_zeros_tensor .strong_wolfe .cubic_interpolate

#' @include optim.R
NULL

# Compute minimum of interpolating polynomial based on function and derivative values
# ported from https://github.com/torch/optim/blob/master/polyinterp.lua
.cubic_interpolate <-
  function(x1, f1, g1, x2, f2, g2, bounds = NULL) {
    # Fits a cubic through two points (x1, f1) and (x2, f2) with slopes g1 and
    # g2, and returns the position of its minimum clamped to an interval.
    #
    # Arguments:
    #   x1, x2  - the two abscissae (step sizes in the line search)
    #   f1, f2  - function values at x1 and x2
    #   g1, g2  - (directional) derivatives at x1 and x2
    #   bounds  - optional c(lower, upper) used for clamping; when NULL the
    #             interval spanned by x1 and x2 is used
    #
    # Returns the clamped minimiser, passed through as.numeric(). If the cubic
    # has no real minimiser (negative or undefined discriminant), it returns
    # the midpoint of the bounds.

    # Compute bounds of interpolation area
    if (!is.null(bounds)) {
      xmin_bound <- bounds[1]
      xmax_bound <- bounds[2]
    } else if (x1 <= x2) {
      xmin_bound <- x1
      xmax_bound <- x2
    } else {
      xmin_bound <- x2
      xmax_bound <- x1
    }

    # Code for most common case: cubic interpolation of 2 points
    #   w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
    d1 <- g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square <- d1^2 - g1 * g2

    # isTRUE() makes an undefined discriminant (NA/NaN, e.g. x1 == x2 with
    # plain numeric inputs) take the midpoint fallback instead of making `if`
    # fail. This matches the PyTorch reference, where `nan >= 0` is False.
    if (isTRUE(as.logical(d2_square >= 0))) {
      d2 <- sqrt(d2_square)
      min_pos <- if (x1 <= x2) {
        x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
      } else {
        x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
      }
      as.numeric(min(max(min_pos, xmin_bound), xmax_bound))
    } else {
      as.numeric((xmin_bound + xmax_bound) / 2)
    }
  }


# ported from https://github.com/torch/optim/blob/master/lswolfe.lua
#
# Parameters:
# - obj_func            a function (the objective) that takes as inputs the point of evaluation,
#                         the step size, and the descent direction, and returns f(x + t*d) and df/dx there
# - x                   initial point / starting location
# - t                   initial step size
# - d                   descent direction
# - f                   initial function value
# - g                   gradient at initial location
# - gtd                 directional derivative at starting location
# - c1                  sufficient decrease parameter
# - c2                  curvature parameter
# - tolerance_change    minimum allowable step length
# - max_ls              maximum nb of iterations
#
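# Return values (as an unnamed list, in this order):
# - f                   function value at x+t*d (wrapped in a scalar tensor)
# - g                   gradient value at x+t*d
# - t                   the step length
# - ls_func_evals       the number of function evaluations
# (the new point x+t*d itself is not returned; callers apply the step themselves)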
#
#
### Rationale ###
# 1 (sufficient decrease / Armijo condition):
#    the function value should decrease by at least a fraction of the decrease predicted by the
#    directional derivative (as a steepest-descent step would)
# 2 (curvature condition):
#    the gradient should not be too steep, or else we could hope for a stronger decrease
# 3 (strong Wolfe modification):
#    the gradient should not be too positive either
#
.strong_wolfe <- function(obj_func,
                          x,
                          t,
                          d,
                          f,
                          g,
                          gtd,
                          c1 = 1e-4,
                          c2 = 0.9,
                          tolerance_change = 1e-9,
                          max_ls = 25) {
  # Line search along direction `d` from point `x` for a step length `t`
  # satisfying the strong Wolfe conditions (sufficient decrease controlled by
  # c1, curvature controlled by c2). It runs in two phases:
  #   1. a bracketing phase that grows the step until it either satisfies the
  #      conditions or brackets an interval containing such a step;
  #   2. a zoom phase that shrinks the bracket by cubic interpolation.
  # At most `max_ls` iterations are spent over both phases combined.
  #
  # Returns list(f, g, t, ls_func_evals): the function value at the accepted
  # step (wrapped with torch_tensor()), the gradient there, the accepted step
  # length, and the number of calls made to `obj_func`.
  #
  # NOTE(review): conditions are written as `if (as.logical(...))`. If a
  # comparison ends up NA (e.g. a plain numeric NaN loss combined via `&&`),
  # `if` will error rather than treating the condition as unmet as PyTorch
  # does -- confirm whether NaN losses need to be tolerated here.

  # Largest absolute component of `d`; multiplies the bracket width below so
  # the "bracket too small" test is measured in parameter space.
  d_norm <- d$abs()$max()
  g <- g$clone(memory_format = torch_contiguous_format())
  # evaluate objective and gradient using initial step
  ret <- obj_func(x, t, d)
  f_new <- ret[[1]]
  g_new <- ret[[2]]
  ls_func_evals <- 1
  gtd_new <- g_new$dot(d)

  # initial phase: find a solution point, or
  # bracket an initial interval containing a point satisfying the strong Wolfe criteria
  t_prev <- 0
  f_prev <- f
  g_prev <- g
  gtd_prev <- gtd
  ls_iter <- 0
  done <- FALSE

  while (ls_iter < max_ls) {
    # sufficient decrease (Armijo) condition violated
    # construct interval between previous (smaller) step size and current
    if (as.logical(f_new > f + c1 * t * gtd) ||
      as.logical(ls_iter > 1 && f_new >= f_prev)) {
      bracket <- c(t_prev, t)
      bracket_f <- c(f_prev, f_new)
      bracket_g <- c(g_prev, g_new$clone(memory_format = torch_contiguous_format()))
      bracket_gtd <- c(gtd_prev, gtd_new)
      # cat("Initial search: sufficient decrease condition violated", "\n")
      break
    }

    # curvature condition satisfied (slope not too steep)
    # return current parameters, no further zoom stage
    if (as.logical(abs(gtd_new) <= -c2 * gtd)) {
      # single-point "bracket"; bracket_gtd is not needed because `done`
      # skips the zoom loop below
      bracket <- c(t)
      bracket_f <- c(f_new)
      bracket_g <- c(g_new)
      done <- TRUE
      # cat("Initial search: strong Wolfe condition satisfied", "\n")
      break
    }

    # curvature condition (strong Wolfe 2) violated (gradient positive)
    # construct interval between previous (smaller) step size and current
    if (as.logical(gtd_new >= 0)) {
      bracket <- c(t_prev, t)
      bracket_f <- c(f_prev, f_new)
      bracket_g <- c(g_prev, g_new$clone(memory_format = torch_contiguous_format()))
      bracket_gtd <- c(gtd_prev, gtd_new)
      # cat("Initial search: curvature condition violated (gradient positive)", "\n")
      break
    }

    # interpolate
    # extrapolate beyond the current step: the next trial step is clamped to
    # [t + 0.01 * (t - t_prev), 10 * t]
    min_step <- t + 0.01 * (t - t_prev)
    max_step <- t * 10
    tmp <- t
    t <- .cubic_interpolate(t_prev,
      f_prev,
      gtd_prev,
      t,
      f_new,
      gtd_new,
      bounds = c(min_step, max_step)
    )

    # next step
    t_prev <- tmp
    f_prev <- f_new
    g_prev <- g_new$clone(memory_format = torch_contiguous_format())
    gtd_prev <- gtd_new
    ret <- obj_func(x, t, d)
    f_new <- ret[[1]]
    g_new <- ret[[2]]
    ls_func_evals <- ls_func_evals + 1
    gtd_new <- g_new$dot(d)
    ls_iter <- ls_iter + 1
  }

  # reached max number of iterations?
  # (the zoom loop below will not run since ls_iter == max_ls, so
  # bracket_gtd is not required here)
  if (ls_iter == max_ls) {
    bracket <- c(0, t)
    bracket_f <- c(f, f_new)
    bracket_g <- c(g, g_new)
  }

  # zoom phase: we now have a point satisfying the criteria, or
  # a bracket around it. We refine the bracket until we find the
  # exact point satisfying the criteria
  insuf_progress <- FALSE
  # find high and low points in bracket
  # (indexing with length() also handles the single-point bracket)
  if (bracket_f[[1]] <= bracket_f[[length(bracket_f)]]) {
    low_pos <- 1
    high_pos <- 2
  } else {
    low_pos <- 2
    high_pos <- 1
  }

  while ((done != TRUE) && (ls_iter < max_ls)) {
    # line-search bracket is so small
    if (as.logical(abs(bracket[[2]] - bracket[[1]]) * d_norm < tolerance_change)) {
      # cat("Zoom phase: bracket too small", "\n")
      break
    }


    # compute new trial value
    t <-
      .cubic_interpolate(
        bracket[[1]],
        bracket_f[[1]],
        bracket_gtd[[1]],
        bracket[[2]],
        bracket_f[[2]],
        bracket_gtd[[2]]
      )
    # test that we are making sufficient progress:
    # in case t is too close to boundary, we mark that we are making insufficient progress,
    # and if we have made insufficient progress in the last step,
    # or t is at one of the boundaries,
    # we will move t to a position which is `0.1 * len(bracket)` away from the nearest boundary point.
    eps <- 0.1 * (max(bracket) - min(bracket))
    if (min(max(bracket) - t, t - min(bracket)) < eps) {
      # interpolation close to boundary
      if ((insuf_progress == TRUE) ||
        (t >= max(bracket)) || (t <= min(bracket))) {
        # evaluate at 0.1 away from boundary
        if (abs(t - max(bracket)) < abs(t - min(bracket))) {
          t <- max(bracket) - eps
        } else {
          t <- min(bracket) + eps
        }
        insuf_progress <- FALSE
      } else {
        insuf_progress <- TRUE
      }
    } else {
      insuf_progress <- FALSE
    }

    # Evaluate new point
    ret <- obj_func(x, t, d)
    f_new <- ret[[1]]
    g_new <- ret[[2]]
    ls_func_evals <- ls_func_evals + 1
    gtd_new <- g_new$dot(d)
    ls_iter <- ls_iter + 1

    # Armijo condition violated or not lower than lowest point
    if (as.logical(f_new > (f + c1 * t * gtd)) ||
      as.logical(f_new >= bracket_f[[low_pos]])) {
      # cat("Zoom phase: sufficient decrease condition violated", "\n")
      # the new point replaces the high end; then re-derive which end is low
      bracket[[high_pos]] <- t
      bracket_f[[high_pos]] <- f_new
      bracket_g[[high_pos]] <- g_new
      bracket_gtd[[high_pos]] <- gtd_new
      if (bracket_f[[1]] <= bracket_f[[2]]) {
        low_pos <- 1
        high_pos <- 2
      } else {
        low_pos <- 2
        high_pos <- 1
      }
    } else {
      # Wolfe conditions satisfied
      if (as.logical(abs(gtd_new) <= -c2 * gtd)) {
        done <- TRUE
        # cat("Zoom phase: strong Wolfe condition satisfied", "\n")
      } else if (as.logical(gtd_new * (bracket[[high_pos]] - bracket[[low_pos]]) >= 0)) {
        # old high becomes new low
        # (the slope at the new point points toward the old high end, so the
        # old low end is moved into the high slot to keep a valid bracket)
        bracket[[high_pos]] <- bracket[[low_pos]]
        bracket_f[[high_pos]] <- bracket_f[[low_pos]]
        bracket_g[[high_pos]] <- bracket_g[[low_pos]]
        bracket_gtd[[high_pos]] <- bracket_gtd[[low_pos]]
      }

      # new point becomes new low
      bracket[[low_pos]] <- t
      bracket_f[[low_pos]] <- f_new
      bracket_g[[low_pos]] <- g_new
      bracket_gtd[[low_pos]] <- gtd_new
    }
  }

  # the best point found is always kept at low_pos
  t <- bracket[[low_pos]]
  f_new <- torch_tensor(bracket_f[[low_pos]])
  g_new <- bracket_g[[low_pos]]
  list(f_new, g_new, t, ls_func_evals)
}


#' LBFGS optimizer
#'
#'
#' Implements L-BFGS algorithm, heavily inspired by
#' [minFunc](https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html)
#' 
#' This optimizer is different from the others in that in `optimizer$step()`,
#' it needs to be passed a closure that (1) calculates the loss, (2) calls 
#' `backward()` on it, and (3) returns it. See example below.
#'
#' @section Warning:
#'
#' This optimizer doesn't support per-parameter options and parameter
#' groups (there can be only one).
#'
#' Right now all parameters have to be on a single device. This will be
#' improved in the future.
#'
#' @note
#' This is a very memory intensive optimizer (it requires additional
#' `param_bytes * (history_size + 1)` bytes). If it doesn't fit in memory
#' try reducing the history size, or use a different algorithm.
#'
#' @param lr (float): learning rate (default: 1)
#' @param max_iter (int): maximal number of iterations per optimization step
#'   (default: 20)
#' @param max_eval (int): maximal number of function evaluations per optimization
#'   step (default: max_iter * 1.25).
#' @param tolerance_grad (float): termination tolerance on first order optimality
#'   (default: 1e-7).
#' @param tolerance_change (float): termination tolerance on function
#'   value/parameter changes (default: 1e-9).
#' @param history_size (int): update history size (default: 100).
#' @param line_search_fn (str): either `"strong_wolfe"` or `NULL` (default: `NULL`).
#' @inheritParams optim_sgd
#'
#' @includeRmd man/rmd/optim-note.Rmd note
#' 
#' @examples
#' a <- 1
#' b <- 5
#' rosenbrock <- function(x) {
#'   x1 <- x[1]
#'   x2 <- x[2]
#'   (a - x1)^2 + b * (x2 - x1^2)^2
#' }
#'  
#' x <- torch_tensor(c(-1, 1), requires_grad = TRUE)
#' 
#' optimizer <- optim_lbfgs(x)
#' calc_loss <- function() {
#'   optimizer$zero_grad()
#'   value <- rosenbrock(x)
#'   value$backward()
#'   value
#' }
#'   
#' num_iterations <- 2
#' for (i in 1:num_iterations) {
#'   optimizer$step(calc_loss)
#' }
#'     
#' rosenbrock(x)
#' 
#' @export
optim_lbfgs <- optimizer(
  "optim_lbfgs",
  # Stores the hyperparameters as the defaults of the single parameter group
  # and caches that group's parameter list. When `max_eval` is NULL it
  # becomes as.integer(max_iter * 5 / 4).
  initialize = function(params,
                        lr = 1,
                        max_iter = 20,
                        max_eval = NULL,
                        tolerance_grad = 1e-7,
                        tolerance_change = 1e-9,
                        history_size = 100,
                        line_search_fn = NULL) {
    if (is.null(max_eval)) {
      max_eval <- as.integer(max_iter * 5 / 4)
    }

    defaults <- list(
      lr = lr,
      max_iter = max_iter,
      max_eval = max_eval,
      tolerance_grad = tolerance_grad,
      tolerance_change = tolerance_change,
      history_size = history_size,
      line_search_fn = line_search_fn
    )

    super$initialize(params, defaults)

    if (length(self$param_groups) != 1) {
      value_error(
        "LBFGS doesn't support per-parameter options ",
        "(parameter groups)"
      )
    }

    private$.params <- self$param_groups[[1]][["params"]]
    private$.numel_cache <- NULL
  },
  # Performs up to `max_iter` L-BFGS iterations.
  # `closure` must zero the gradients, compute the loss, call backward() on
  # it and return it. It is re-evaluated with gradients enabled as often as
  # the iteration / line search needs. Returns the loss tensor from the first
  # closure evaluation of this call.
  step = function(closure) {
    with_no_grad({
      closure_ <- function() {
        with_enable_grad({
          closure()
        })
      }

      group <- self$param_groups[[1]]
      lr <- group[["lr"]]
      max_iter <- group[["max_iter"]]
      max_eval <- group[["max_eval"]]
      tolerance_grad <- group[["tolerance_grad"]]
      tolerance_change <- group[["tolerance_change"]]
      line_search_fn <- group[["line_search_fn"]]
      history_size <- group[["history_size"]]

      # NOTE: LBFGS has only global state, but we register it as state for
      # the first param, because this helps with casting in load_state_dict
      if (is.null(state(private$.params[[1]]))) {
        state(private$.params[[1]]) <- new.env(parent = emptyenv())
        state(private$.params[[1]])[["func_evals"]] <- 0
        state(private$.params[[1]])[["n_iter"]] <- 0
      }
      state <- state(private$.params[[1]])

      # evaluate initial f(x) and df/dx
      orig_loss <- closure_()
      loss <- orig_loss$item()
      current_evals <- 1
      state[["func_evals"]] <- state[["func_evals"]] + 1

      flat_grad <- private$.gather_flat_grad()
      opt_cond <- (flat_grad$abs()$max() <= tolerance_grad)$item()

      # already optimal: nothing to do
      if (opt_cond) {
        return(orig_loss)
      }

      # tensors cached in state (for tracing)
      d <- state[["d"]]
      t <- state[["t"]]
      old_dirs <- state[["old_dirs"]]
      old_stps <- state[["old_stps"]]
      ro <- state[["ro"]]
      H_diag <- state[["H_diag"]]
      prev_flat_grad <- state[["prev_flat_grad"]]
      prev_loss <- state[["prev_loss"]]

      n_iter <- 0
      # optimize for a max of max_iter iterations
      while (n_iter < max_iter) {
        # keep track of nb of iterations
        n_iter <- n_iter + 1
        state[["n_iter"]] <- state[["n_iter"]] + 1

        ############################################################
        # compute gradient descent direction
        ############################################################

        if (state[["n_iter"]] == 1) {
          d <- flat_grad$neg()
          old_dirs <- list()
          old_stps <- list()
          ro <- list()
          H_diag <- 1
        } else {
          # do lbfgs update (update memory)
          y <- flat_grad$sub(prev_flat_grad)
          s <- d$mul(t)
          ys <- y$dot(s) # y*s

          # only update the memory when the curvature pair is sufficiently
          # positive, which keeps the inverse-Hessian approximation positive
          # definite
          if (!is.na(ys$item()) && ys$item() > 1e-10) {
            # updating memory
            if (length(old_dirs) == history_size) {
              # shift history by one (limited-memory)
              old_dirs <- old_dirs[-1]
              old_stps <- old_stps[-1]
              ro <- ro[-1]
            }

            # store new direction/step
            old_dirs[[length(old_dirs) + 1]] <- y
            old_stps[[length(old_stps) + 1]] <- s
            ro[[length(ro) + 1]] <- 1. / ys

            # update scale of initial Hessian approximation
            H_diag <- ys / y$dot(y) # (y*y)
          }

          # compute the approximate (L-BFGS) inverse Hessian
          # multiplied by the gradient
          num_old <- length(old_dirs)

          if (is.null(state[["al"]])) {
            state[["al"]] <- vector(mode = "list", length = history_size)
          }

          al <- state[["al"]]

          # iteration in L-BFGS loop collapsed to use just one buffer
          q <- flat_grad$neg()
          if (num_old >= 1) {
            for (i in seq(num_old, 1, by = -1)) {
              al[[i]] <- old_stps[[i]]$dot(q) * ro[[i]]
              q$add_(old_dirs[[i]], alpha = -al[[i]])
            }
          }

          # multiply by initial Hessian
          # r/d is the final direction
          d <- r <- torch_mul(q, H_diag)
          for (i in seq_len(num_old)) {
            be_i <- old_dirs[[i]]$dot(r) * ro[[i]]
            r$add_(old_stps[[i]], alpha = al[[i]] - be_i)
          }
        }

        if (is.null(prev_flat_grad) || is_undefined_tensor(prev_flat_grad)) {
          prev_flat_grad <- flat_grad$clone(memory_format = torch_contiguous_format())
        } else {
          prev_flat_grad$copy_(flat_grad)
        }
        prev_loss <- loss

        ############################################################
        # compute step length
        ############################################################
        # reset initial guess for step size
        if (state[["n_iter"]] == 1) {
          t <- min(1., 1. / flat_grad$abs()$sum()$item()) * lr
        } else {
          t <- lr
        }

        # directional derivative
        gtd <- flat_grad$dot(d) # g * d

        # directional derivative is below tolerance
        if (!is.na(gtd$item()) && gtd$item() > (-tolerance_change)) {
          break
        }

        # optional line search: user function
        ls_func_evals <- 0

        if (!is.null(line_search_fn)) {
          if (line_search_fn != "strong_wolfe") {
            value_error("only strong_wolfe is supported")
          } else {
            x_init <- private$.clone_param()

            obj_func <- function(x, t, d) {
              private$.directional_evaluate(closure_, x, t, d)
            }

            ret <- .strong_wolfe(obj_func, x_init, t, d, loss, flat_grad, gtd)

            loss <- ret[[1]]$item()
            flat_grad <- ret[[2]]
            t <- ret[[3]]
            ls_func_evals <- ret[[4]]

            private$.add_grad(t, d)
            opt_cond <- flat_grad$abs()$max()$item() <= tolerance_grad
          }
        } else {
          # no line search, simply move with fixed-step
          private$.add_grad(t, d)
          if (n_iter != max_iter) {
            # re-evaluate function only if not in last iteration
            # the reason we do this: in a stochastic setting,
            # no use to re-evaluate that function here
            loss <- closure_()$item()
            flat_grad <- private$.gather_flat_grad()
            opt_cond <- flat_grad$abs()$max()$item() <= tolerance_grad
            ls_func_evals <- 1
          }
        }

        # update func eval
        current_evals <- current_evals + ls_func_evals
        state[["func_evals"]] <- state[["func_evals"]] + ls_func_evals

        ############################################################
        # check conditions
        ############################################################
        if (n_iter == max_iter) {
          break
        }

        if (current_evals >= max_eval) {
          break
        }

        # optimal condition
        if (!is.na(opt_cond) && opt_cond) {
          break
        }

        # lack of progress
        d_ml <- d$mul(t)$abs()$max()$item()
        if (!is.na(d_ml) && d_ml <= tolerance_change) {
          break
        }

        if (!is.na(loss) && abs(loss - prev_loss) < tolerance_change) {
          break
        }
      }

      state[["d"]] <- d
      state[["t"]] <- t
      state[["old_dirs"]] <- old_dirs
      state[["old_stps"]] <- old_stps
      state[["ro"]] <- ro
      state[["H_diag"]] <- H_diag
      state[["prev_flat_grad"]] <- prev_flat_grad
      state[["prev_loss"]] <- prev_loss
    })

    orig_loss
  },
  private = list(
    # Total number of scalar elements across all parameters.
    # Computed once and cached in `.numel_cache`.
    .numel = function() {
      if (is.null(private$.numel_cache)) {
        private$.numel_cache <- sum(vapply(
          private$.params,
          function(p) as.numeric(p$numel()),
          numeric(1)
        ))
      }
      private$.numel_cache
    },
    # Concatenates every parameter's gradient into a single 1-d tensor,
    # in the order of `.params`. A missing gradient contributes zeros.
    .gather_flat_grad = function() {
      views <- list()
      for (i in seq_along(private$.params)) {
        p <- private$.params[[i]]
        if (is_undefined_tensor(p$grad)) {
          view <- .new_zeros_tensor(p)
        } else {
          view <- p$grad$view(-1)
        }
        views[[i]] <- view
      }
      torch_cat(views, dim = 1)
    },
    # In place: p <- p + step_size * (matching slice of `update`) for every
    # parameter, where `update` is a flat vector laid out like
    # `.gather_flat_grad()`.
    .add_grad = function(step_size, update) {
      offset <- 1
      for (p in private$.params) {
        numel <- p$numel()
        # skip empty parameters: offset:(offset - 1) would be a reversed,
        # two-element range rather than an empty one
        if (numel > 0) {
          p$add_(update[offset:(offset + numel - 1)]$view_as(p), alpha = step_size)
        }
        offset <- offset + numel
      }
      # offsets are 1-based, so after the loop `offset` is one past the last
      # element consumed; every element of every parameter must be covered
      stopifnot(offset - 1 == private$.numel())
    },
    # Returns contiguous copies of all parameters (used as the line-search
    # starting point).
    .clone_param = function() {
      lapply(private$.params, function(p) p$clone(memory_format = torch_contiguous_format()))
    },
    # Copies `params_data` back into the parameters, in place.
    .set_param = function(params_data) {
      for (i in seq_along(private$.params)) {
        private$.params[[i]]$copy_(params_data[[i]])
      }
    },
    # Evaluates loss and flat gradient at x + t * d, then restores the
    # parameters to `x`. Returns list(loss as numeric, flat gradient).
    .directional_evaluate = function(closure, x, t, d) {
      private$.add_grad(t, d)
      loss <- closure()$item()
      flat_grad <- private$.gather_flat_grad()
      private$.set_param(x)
      list(loss, flat_grad)
    }
  )
)

.new_zeros_tensor <- function(p) {
  # Builds a 1-d tensor of zeros with one entry per element of `p`, on the
  # same device and with the same dtype as `p`.
  n_elements <- p$numel()
  torch_zeros(n_elements, device = p$device, dtype = p$dtype)
}

# Try the torch package in your browser
#
# Any scripts or data that you put into this service are public.
#
# torch documentation built on June 7, 2023, 6:19 p.m.