#' @include optim.R
NULL
# Compute minimum of interpolating polynomial based on function and derivative values
# ported from https://github.com/torch/optim/blob/master/polyinterp.lua
.cubic_interpolate <-
function(x1, f1, g1, x2, f2, g2, bounds = NULL) {
# Compute bounds of interpolation area
if (!is.null(bounds)) {
xmin_bound <- bounds[1]
xmax_bound <- bounds[2]
} else if (x1 <= x2) {
xmin_bound <- x1
xmax_bound <- x2
} else {
xmin_bound <- x2
xmax_bound <- x1
}
# Code for most common case: cubic interpolation of 2 points
# w/ function and derivative values for both
# Solution in this case (where x2 is the farthest point):
# d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
# d2 = sqrt(d1^2 - g1*g2);
# min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
# t_new = min(max(min_pos,xmin_bound),xmax_bound);
d1 <- g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
d2_square <- d1^2 - g1 * g2
if (as.logical(d2_square >= 0)) {
d2 <- sqrt(d2_square)
min_pos <- if (x1 <= x2) {
x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
} else {
x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
}
as.numeric(min(max(min_pos, xmin_bound), xmax_bound))
} else {
as.numeric((xmin_bound + xmax_bound) / 2)
}
}
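# Illustrative example: for f(x) = x^2 between x1 = 0 and x2 = 1 we have
# f(0) = 0, f'(0) = 0, f(1) = 1, f'(1) = 2; the cubic fit is exact here,
# so the interpolated minimizer equals the true one:
# .cubic_interpolate(0, 0, 0, 1, 1, 2)  # returns 0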
# ported from https://github.com/torch/optim/blob/master/lswolfe.lua
#
# Parameters:
# - obj_func a function (the objective) that takes as inputs the point of evaluation,
# the step size, and the descent direction, and returns f(x) and df/dx
# - x initial point / starting location
# - t initial step size
# - d descent direction
# - f initial function value
# - g gradient at the initial location
# - gtd directional derivative at the starting location
# - c1 sufficient decrease parameter
# - c2 curvature parameter
# - tolerance_change termination tolerance on the bracket size
# (scaled by the max-norm of the descent direction)
# - max_ls maximum number of line search iterations
#
# Return values (a list):
# - f function value at x + t*d
# - g gradient at x + t*d
# - t the accepted step length
# - ls_func_evals the number of function evaluations performed
#
#
### Rationale ###
# 1 (sufficient decrease / Armijo condition):
# the function value should decrease by at least a fraction of the decrease
# predicted by the directional derivative (as with steepest descent)
# 2 (curvature condition):
# the slope should not be too steep, or else we could hope for a stronger decrease
# 3 (strong Wolfe modification):
# the slope should not be too positive either
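# In symbols, with step size t, direction d, gradient g, gtd = g'd < 0,
# and 0 < c1 < c2 < 1:
#   (1)     f(x + t*d) <= f(x) + c1 * t * gtd       (Armijo / sufficient decrease)
#   (2)+(3) |grad f(x + t*d)' d| <= -c2 * gtd       (strong Wolfe curvature)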
.strong_wolfe <- function(obj_func,
x,
t,
d,
f,
g,
gtd,
c1 = 1e-4,
c2 = 0.9,
tolerance_change = 1e-9,
max_ls = 25) {
d_norm <- d$abs()$max()
g <- g$clone(memory_format = torch_contiguous_format())
# evaluate objective and gradient using initial step
ret <- obj_func(x, t, d)
f_new <- ret[[1]]
g_new <- ret[[2]]
ls_func_evals <- 1
gtd_new <- g_new$dot(d)
# initial phase: find a solution point, or
# bracket an initial interval containing a point satisfying the strong Wolfe criteria
t_prev <- 0
f_prev <- f
g_prev <- g
gtd_prev <- gtd
ls_iter <- 0
done <- FALSE
while (ls_iter < max_ls) {
# sufficient decrease (Armijo) condition violated
# construct interval between previous (smaller) step size and current
if (as.logical(f_new > f + c1 * t * gtd) ||
as.logical(ls_iter > 1 && f_new >= f_prev)) {
bracket <- c(t_prev, t)
bracket_f <- c(f_prev, f_new)
bracket_g <- c(g_prev, g_new$clone(memory_format = torch_contiguous_format()))
bracket_gtd <- c(gtd_prev, gtd_new)
# cat("Initial search: sufficient decrease condition violated", "\n")
break
}
# curvature condition satisfied (slope not too steep)
# return current parameters, no further zoom stage
if (as.logical(abs(gtd_new) <= -c2 * gtd)) {
bracket <- c(t)
bracket_f <- c(f_new)
bracket_g <- c(g_new)
done <- TRUE
# cat("Initial search: strong Wolfe condition satisfied", "\n")
break
}
# curvature condition (strong Wolfe 2) violated (gradient positive)
# construct interval between previous (smaller) step size and current
if (as.logical(gtd_new >= 0)) {
bracket <- c(t_prev, t)
bracket_f <- c(f_prev, f_new)
bracket_g <- c(g_prev, g_new$clone(memory_format = torch_contiguous_format()))
bracket_gtd <- c(gtd_prev, gtd_new)
# cat("Initial search: curvature condition violated (gradient positive)", "\n")
break
}
# interpolate
min_step <- t + 0.01 * (t - t_prev)
max_step <- t * 10
tmp <- t
t <- .cubic_interpolate(t_prev,
f_prev,
gtd_prev,
t,
f_new,
gtd_new,
bounds = c(min_step, max_step)
)
# next step
t_prev <- tmp
f_prev <- f_new
g_prev <- g_new$clone(memory_format = torch_contiguous_format())
gtd_prev <- gtd_new
ret <- obj_func(x, t, d)
f_new <- ret[[1]]
g_new <- ret[[2]]
ls_func_evals <- ls_func_evals + 1
gtd_new <- g_new$dot(d)
ls_iter <- ls_iter + 1
}
# reached max number of iterations?
if (ls_iter == max_ls) {
bracket <- c(0, t)
bracket_f <- c(f, f_new)
bracket_g <- c(g, g_new)
}
# zoom phase: we now have a point satisfying the criteria, or
# a bracket around it. We refine the bracket until we find the
# exact point satisfying the criteria
insuf_progress <- FALSE
# find high and low points in bracket
if (bracket_f[[1]] <= bracket_f[[length(bracket_f)]]) {
low_pos <- 1
high_pos <- 2
} else {
low_pos <- 2
high_pos <- 1
}
while ((done != TRUE) && (ls_iter < max_ls)) {
# line-search bracket is so small
if (as.logical(abs(bracket[[2]] - bracket[[1]]) * d_norm < tolerance_change)) {
# cat("Zoom phase: bracket too small", "\n")
break
}
# compute new trial value
t <-
.cubic_interpolate(
bracket[[1]],
bracket_f[[1]],
bracket_gtd[[1]],
bracket[[2]],
bracket_f[[2]],
bracket_gtd[[2]]
)
# test that we are making sufficient progress:
# in case t is too close to boundary, we mark that we are making insufficient progress,
# and if we have made insufficient progress in the last step,
# or t is at one of the boundaries,
# we will move t to a position which is `0.1 * len(bracket)` away from the nearest boundary point.
eps <- 0.1 * (max(bracket) - min(bracket))
if (min(max(bracket) - t, t - min(bracket)) < eps) {
# interpolation close to boundary
if ((insuf_progress == TRUE) ||
(t >= max(bracket)) || (t <= min(bracket))) {
# evaluate at 0.1 away from boundary
if (abs(t - max(bracket)) < abs(t - min(bracket))) {
t <- max(bracket) - eps
} else {
t <- min(bracket) + eps
}
insuf_progress <- FALSE
} else {
insuf_progress <- TRUE
}
} else {
insuf_progress <- FALSE
}
# Evaluate new point
ret <- obj_func(x, t, d)
f_new <- ret[[1]]
g_new <- ret[[2]]
ls_func_evals <- ls_func_evals + 1
gtd_new <- g_new$dot(d)
ls_iter <- ls_iter + 1
# Armijo condition violated or not lower than lowest point
if (as.logical(f_new > (f + c1 * t * gtd)) ||
as.logical(f_new >= bracket_f[[low_pos]])) {
# cat("Zoom phase: sufficient decrease condition violated", "\n")
bracket[[high_pos]] <- t
bracket_f[[high_pos]] <- f_new
bracket_g[[high_pos]] <- g_new
bracket_gtd[[high_pos]] <- gtd_new
if (bracket_f[[1]] <= bracket_f[[2]]) {
low_pos <- 1
high_pos <- 2
} else {
low_pos <- 2
high_pos <- 1
}
} else {
# Wolfe conditions satisfied
if (as.logical(abs(gtd_new) <= -c2 * gtd)) {
done <- TRUE
# cat("Zoom phase: strong Wolfe condition satisfied", "\n")
} else if (as.logical(gtd_new * (bracket[[high_pos]] - bracket[[low_pos]]) >= 0)) {
# old high becomes new low
bracket[[high_pos]] <- bracket[[low_pos]]
bracket_f[[high_pos]] <- bracket_f[[low_pos]]
bracket_g[[high_pos]] <- bracket_g[[low_pos]]
bracket_gtd[[high_pos]] <- bracket_gtd[[low_pos]]
}
# new point becomes new low
bracket[[low_pos]] <- t
bracket_f[[low_pos]] <- f_new
bracket_g[[low_pos]] <- g_new
bracket_gtd[[low_pos]] <- gtd_new
}
}
t <- bracket[[low_pos]]
f_new <- torch_tensor(bracket_f[[low_pos]])
g_new <- bracket_g[[low_pos]]
list(f_new, g_new, t, ls_func_evals)
}
#' LBFGS optimizer
#'
#'
#' Implements the L-BFGS algorithm, heavily inspired by
#' [minFunc](https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html)
#'
#' This optimizer is different from the others in that in `optimizer$step()`,
#' it needs to be passed a closure that (1) calculates the loss, (2) calls
#' `backward()` on it, and (3) returns it. See example below.
#'
#' @section Warning:
#'
#' This optimizer doesn't support per-parameter options and parameter
#' groups (there can be only one).
#'
#' Right now all parameters have to be on a single device. This will be
#' improved in the future.
#'
#' @note
#' This is a very memory intensive optimizer (it requires additional
#' `param_bytes * (history_size + 1)` bytes). If it doesn't fit in memory
#' try reducing the history size, or use a different algorithm.
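#'
#' For example, one million float32 parameters (4 MB) with the default
#' `history_size = 100` require roughly 4 MB * 101, i.e. about 400 MB,
#' of additional memory.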
#'
#' @param lr (float): learning rate (default: 1)
#' @param max_iter (int): maximal number of iterations per optimization step
#' (default: 20)
#' @param max_eval (int): maximal number of function evaluations per optimization
#' step (default: max_iter * 1.25).
#' @param tolerance_grad (float): termination tolerance on first order optimality
#' (default: 1e-7).
#' @param tolerance_change (float): termination tolerance on function
#' value/parameter changes (default: 1e-9).
#' @param history_size (int): update history size (default: 100).
#' @param line_search_fn (str): either 'strong_wolfe' or NULL (default: NULL).
#' @inheritParams optim_sgd
#'
#' @includeRmd man/rmd/optim-note.Rmd note
#'
#' @examples
#' a <- 1
#' b <- 5
#' rosenbrock <- function(x) {
#' x1 <- x[1]
#' x2 <- x[2]
#' (a - x1)^2 + b * (x2 - x1^2)^2
#' }
#'
#' x <- torch_tensor(c(-1, 1), requires_grad = TRUE)
#'
#' optimizer <- optim_lbfgs(x)
#' calc_loss <- function() {
#' optimizer$zero_grad()
#' value <- rosenbrock(x)
#' value$backward()
#' value
#' }
#'
#' num_iterations <- 2
#' for (i in 1:num_iterations) {
#' optimizer$step(calc_loss)
#' }
#'
#' rosenbrock(x)
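#'
#' # the same problem, now using the strong Wolfe line search
#' # ("strong_wolfe" is the only supported value of line_search_fn)
#' x <- torch_tensor(c(-1, 1), requires_grad = TRUE)
#' optimizer <- optim_lbfgs(x, line_search_fn = "strong_wolfe")
#' for (i in 1:num_iterations) {
#'   optimizer$step(calc_loss)
#' }
#' rosenbrock(x)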
#'
#' @export
optim_lbfgs <- optimizer(
"optim_lbfgs",
initialize = function(params,
lr = 1,
max_iter = 20,
max_eval = NULL,
tolerance_grad = 1e-7,
tolerance_change = 1e-9,
history_size = 100,
line_search_fn = NULL) {
if (is.null(max_eval)) {
max_eval <- as.integer(max_iter * 5 / 4)
}
defaults <- list(
lr = lr,
max_iter = max_iter,
max_eval = max_eval,
tolerance_grad = tolerance_grad,
tolerance_change = tolerance_change,
history_size = history_size,
line_search_fn = line_search_fn
)
super$initialize(params, defaults)
if (length(self$param_groups) != 1) {
value_error(
"LBFGS doesn't support per-parameter options ",
"(parameter groups)"
)
}
private$.params <- self$param_groups[[1]][["params"]]
private$.numel_cache <- NULL
},
step = function(closure) {
with_no_grad({
closure_ <- function() {
with_enable_grad({
closure()
})
}
group <- self$param_groups[[1]]
lr <- group[["lr"]]
max_iter <- group[["max_iter"]]
max_eval <- group[["max_eval"]]
tolerance_grad <- group[["tolerance_grad"]]
tolerance_change <- group[["tolerance_change"]]
line_search_fn <- group[["line_search_fn"]]
history_size <- group[["history_size"]]
# NOTE: LBFGS has only global state, but we register it as state for
# the first param, because this helps with casting in load_state_dict
if (is.null(state(private$.params[[1]]))) {
state(private$.params[[1]]) <- new.env(parent = emptyenv())
state(private$.params[[1]])[["func_evals"]] <- 0
state(private$.params[[1]])[["n_iter"]] <- 0
}
state <- state(private$.params[[1]])
# evaluate initial f(x) and df/dx
orig_loss <- closure_()
loss <- orig_loss$item()
current_evals <- 1
state[["func_evals"]] <- state[["func_evals"]] + 1
flat_grad <- private$.gather_flat_grad()
opt_cond <- (flat_grad$abs()$max() <= tolerance_grad)$item()
if (opt_cond) {
return(orig_loss)
}
# tensors cached in state (for tracing)
d <- state[["d"]]
t <- state[["t"]]
old_dirs <- state[["old_dirs"]]
old_stps <- state[["old_stps"]]
ro <- state[["ro"]]
H_diag <- state[["H_diag"]]
prev_flat_grad <- state[["prev_flat_grad"]]
prev_loss <- state[["prev_loss"]]
n_iter <- 0
# optimize for a max of max_iter iterations
while (n_iter < max_iter) {
# keep track of nb of iterations
n_iter <- n_iter + 1
state[["n_iter"]] <- state[["n_iter"]] + 1
############################################################
# compute gradient descent direction
############################################################
if (state[["n_iter"]] == 1) {
d <- flat_grad$neg()
old_dirs <- list()
old_stps <- list()
ro <- list()
H_diag <- 1
} else {
# do lbfgs update (update memory)
y <- flat_grad$sub(prev_flat_grad)
s <- d$mul(t)
ys <- y$dot(s) # y*s
if (!is.na(ys$item()) && ys$item() > 1e-10) {
# updating memory
if (length(old_dirs) == history_size) {
# shift history by one (limited-memory)
old_dirs <- old_dirs[-1]
old_stps <- old_stps[-1]
ro <- ro[-1]
}
# store new direction/step
old_dirs[[length(old_dirs) + 1]] <- y
old_stps[[length(old_stps) + 1]] <- s
ro[[length(ro) + 1]] <- 1. / ys
# update scale of initial Hessian approximation
H_diag <- ys / y$dot(y) # (y*y)
}
# compute the approximate (L-BFGS) inverse Hessian
# multiplied by the gradient
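# (the standard two-loop recursion, cf. Nocedal & Wright,
# "Numerical Optimization", Algorithm 7.4)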
num_old <- length(old_dirs)
if (is.null(state[["al"]])) {
state[["al"]] <- vector(mode = "list", length = history_size)
}
al <- state[["al"]]
# iteration in L-BFGS loop collapsed to use just one buffer
q <- flat_grad$neg()
if (num_old >= 1) {
for (i in seq(num_old, 1, by = -1)) {
al[[i]] <- old_stps[[i]]$dot(q) * ro[[i]]
q$add_(old_dirs[[i]], alpha = -al[[i]])
}
}
# multiply by initial Hessian
# r/d is the final direction
d <- r <- torch_mul(q, H_diag)
for (i in seq_len(num_old)) {
be_i <- old_dirs[[i]]$dot(r) * ro[[i]]
r$add_(old_stps[[i]], alpha = al[[i]] - be_i)
}
}
if (is.null(prev_flat_grad) || is_undefined_tensor(prev_flat_grad)) {
prev_flat_grad <- flat_grad$clone(memory_format = torch_contiguous_format())
} else {
prev_flat_grad$copy_(flat_grad)
}
prev_loss <- loss
############################################################
# compute step length
############################################################
# reset initial guess for step size
if (state[["n_iter"]] == 1) {
t <- min(1., 1. / flat_grad$abs()$sum()$item()) * lr
} else {
t <- lr
}
# directional derivative
gtd <- flat_grad$dot(d) # g * d
# directional derivative is below tolerance
if (!is.na(gtd$item()) && gtd$item() > (-tolerance_change)) {
break
}
# optional line search: user function
ls_func_evals <- 0
if (!is.null(line_search_fn)) {
if (line_search_fn != "strong_wolfe") {
value_error("only strong_wolfe is supported")
} else {
x_init <- private$.clone_param()
obj_func <- function(x, t, d) {
private$.directional_evaluate(closure_, x, t, d)
}
ret <- .strong_wolfe(obj_func, x_init, t, d, loss, flat_grad, gtd)
loss <- ret[[1]]$item()
flat_grad <- ret[[2]]
t <- ret[[3]]
ls_func_evals <- ret[[4]]
private$.add_grad(t, d)
opt_cond <- flat_grad$abs()$max()$item() <= tolerance_grad
}
} else {
# no line search, simply move with fixed-step
private$.add_grad(t, d)
if (n_iter != max_iter) {
# re-evaluate function only if not in last iteration
# the reason we do this: in a stochastic setting,
# no use to re-evaluate that function here
loss <- closure_()$item()
flat_grad <- private$.gather_flat_grad()
opt_cond <- flat_grad$abs()$max()$item() <= tolerance_grad
ls_func_evals <- 1
}
}
# update func eval
current_evals <- current_evals + ls_func_evals
state[["func_evals"]] <- state[["func_evals"]] + ls_func_evals
############################################################
# check conditions
############################################################
if (n_iter == max_iter) {
break
}
if (current_evals >= max_eval) {
break
}
# optimal condition
if (!is.na(opt_cond) && opt_cond) {
break
}
# lack of progress
d_ml <- d$mul(t)$abs()$max()$item()
if (!is.na(d_ml) && d_ml <= tolerance_change) {
break
}
if (!is.na(loss) && abs(loss - prev_loss) < tolerance_change) {
break
}
}
state[["d"]] <- d
state[["t"]] <- t
state[["old_dirs"]] <- old_dirs
state[["old_stps"]] <- old_stps
state[["ro"]] <- ro
state[["H_diag"]] <- H_diag
state[["prev_flat_grad"]] <- prev_flat_grad
state[["prev_loss"]] <- prev_loss
})
orig_loss
},
private = list(
.numel = function() {
if (is.null(private$.numel_cache)) {
private$.numel_cache <- sum(sapply(private$.params, function(p) p$numel()))
}
private$.numel_cache
},
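# concatenate all parameter gradients into a single flat (1-d) tensor;
# parameters without a gradient contribute zeros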
.gather_flat_grad = function() {
views <- list()
for (i in seq_along(private$.params)) {
p <- private$.params[[i]]
if (is_undefined_tensor(p$grad)) {
view <- .new_zeros_tensor(p)
} else {
view <- p$grad$view(-1)
}
views[[i]] <- view
}
torch_cat(views, dim = 1)
},
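# take a step: x <- x + step_size * update, scattering consecutive segments
# of the flat update vector back onto the individual parameters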
.add_grad = function(step_size, update) {
offset <- 1
for (p in private$.params) {
numel <- p$numel()
p$add_(update[offset:(offset + numel - 1)]$view_as(p), alpha = step_size)
offset <- offset + numel
}
# offset is 1-based, so after the loop it is one past the total element count
stopifnot(offset == private$.numel() + 1)
},
.clone_param = function() {
lapply(private$.params, function(p) p$clone(memory_format = torch_contiguous_format()))
},
.set_param = function(params_data) {
for (i in seq_along(private$.params)) {
private$.params[[i]]$copy_(params_data[[i]])
}
},
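# evaluate loss and flat gradient at x + t*d, then restore the original
# parameter values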
.directional_evaluate = function(closure, x, t, d) {
private$.add_grad(t, d)
loss <- closure()$item()
flat_grad <- private$.gather_flat_grad()
private$.set_param(x)
list(loss, flat_grad)
}
)
)
.new_zeros_tensor <- function(p) {
torch_zeros(p$numel(), dtype = p$dtype, device = p$device)
}