#' Compute largest magnitude eigenvalue of Hessian and
#' corresponding eigenvector
#'
#' @description
#' Runs the power method on the Hessian \code{H} of the model's log density
#' on the unconstrained scale, optionally preconditioned as
#' \code{L^T * H * L}.
#' \itemize{
#' \item If \code{udraws} is a vector, the output is the eigenvalue/
#' eigenvector pair computed at the point in parameter
#' space corresponding to \code{udraws}.
#' \item If \code{udraws} is a matrix, each row is interpreted as a place
#' to compute an eigenvalue/eigenvector pair and the output
#' is a list of eigenvalue/eigenvector pairs.
#' }
#'
#' @details
#' The Hessian is never formed explicitly. Each Hessian-vector product is
#' approximated by a central finite difference of
#' \code{rstan::grad_log_prob} with step size
#' \code{.Machine$double.eps^(1/3) * max(1, max(abs(u)))} and is then
#' premultiplied by \code{t(L)}. The iteration starts from a random vector
#' drawn with \code{runif}, so results depend on the RNG state; set a seed
#' for reproducibility. The power method converges to the eigenvalue of
#' largest magnitude when that eigenvalue is well separated in magnitude
#' from the others.
#'
#' @export
#' @param fit An rstan stanfit.
#' @param udraws Place(s) to compute Hessian(s) at: a numeric vector of
#'   unconstrained parameter values, or a matrix with one such point per
#'   row. The number of values per point must equal
#'   \code{rstan::get_num_upars(fit)}.
#' @param L An optional square preconditioner matrix of size \code{N} by
#'   \code{N}, where \code{N} is the number of unconstrained parameters.
#'   If this is supplied compute the eigenvalue/eigenvector pairs of
#'   \code{L^T * H * L}; otherwise the identity matrix is used.
#' @param max_iterations Maximum number of power method iterations
#'   to use on any output. Must be at least 1.
#' @param tol Tolerance checked on the eigenvalue: iteration stops once the
#'   change in the estimate is at most
#'   \code{max(tol, tol * abs(previous estimate))}. Must be positive.
#' @return For a single point, a list with elements \code{e} (the
#'   eigenvalue estimate, a numeric scalar) and \code{v} (the corresponding
#'   unit-norm eigenvector estimate). For a matrix of points, a list of
#'   such lists, one per row.
rstan_power_method <- function(fit, udraws, L = NULL,
                               max_iterations = 200, tol = 1e-5) {
  if (!inherits(fit, "stanfit")) {
    stop("fit should be a stanfit object (from rstan)", call. = FALSE)
  }
  # Normalize to one point per row; matrix() keeps only values and row count.
  if (is.null(nrow(udraws))) {
    udraws <- matrix(udraws, nrow = 1)
  } else {
    udraws <- matrix(udraws, nrow = nrow(udraws))
  }
  N <- ncol(udraws)
  n_upars <- rstan::get_num_upars(fit)
  if (N != n_upars) {
    stop(
      paste(
        "Supplied upars has", N,
        "parameters but model has", n_upars,
        "parameters"
      ),
      call. = FALSE
    )
  }
  if (is.null(L)) {
    L <- diag(N)
  } else if (length(dim(L)) != 2 || nrow(L) != N || ncol(L) != N) {
    # Checking dim() first gives the intended message (instead of a `||`
    # length error) when L is not two-dimensional.
    stop(
      paste("L (if supplied) must be a square matrix of size", N, "by", N),
      call. = FALSE
    )
  }
  if (max_iterations < 1) {
    stop("max_iterations must be greater than 0", call. = FALSE)
  }
  if (tol <= 0.0) {
    stop("tol must be greater than 0", call. = FALSE)
  }

  cbrt_epsilon <- .Machine$double.eps^(1 / 3)
  out <- vector("list", nrow(udraws))
  for (m in seq_len(nrow(udraws))) {
    # Start vector is drawn first, matching the original RNG usage order.
    v0 <- runif(N)
    u <- udraws[m, ]
    # Finite-difference step scales with the magnitude of the point.
    dx <- cbrt_epsilon * max(1, abs(u))
    hvp <- rpm_preconditioned_hvp(fit, u, L, dx)
    out[[m]] <- rpm_power_iteration(hvp, v0, max_iterations, tol)
  }
  if (length(out) == 1) {
    return(out[[1]])
  }
  out
}

# Build a function v -> t(L) %*% H %*% L %*% (v / |v|), where H is the
# Hessian of the log density of `fit` at the unconstrained point `u`,
# approximated by a central difference of gradients with step `dx`.
rpm_preconditioned_hvp <- function(fit, u, L, dx) {
  function(v) {
    v <- v / sqrt(sum(v^2))
    step <- dx * as.vector(L %*% v)
    grad_diff <- rstan::grad_log_prob(fit, u + step) -
      rstan::grad_log_prob(fit, u - step)
    # The central difference approximates H L v; premultiplying by t(L)
    # yields the documented L^T H L operator. as.vector() drops the
    # "log_prob" attribute attached by grad_log_prob.
    as.vector(crossprod(L, as.vector(grad_diff) / (2 * dx)))
  }
}

# Power iteration for the dominant eigenpair of the operator `apply_op`
# (a function mapping a vector to the operator applied to it), started
# from `v0`. Stops when the eigenvalue estimate changes by at most
# max(tol, tol * |previous estimate|) or after `max_iterations` operator
# applications. Returns list(e = eigenvalue estimate, v = unit eigenvector).
rpm_power_iteration <- function(apply_op, v0, max_iterations, tol) {
  v <- v0 / sqrt(sum(v0^2))
  Av <- apply_op(v)
  lambda <- 0
  iter <- 1
  repeat {
    # Rayleigh quotient of the current iterate.
    new_lambda <- sum(v * Av) / sum(v^2)
    converged <- abs(new_lambda - lambda) <= max(tol, tol * abs(lambda))
    lambda <- new_lambda
    # `>=` (not `==`) so a non-integer max_iterations still terminates with
    # `lambda` and `v` describing the same iterate.
    if (converged || iter >= max_iterations) {
      break
    }
    v <- Av / sqrt(sum(Av^2))
    Av <- apply_op(v)
    iter <- iter + 1
  }
  list(e = lambda, v = as.vector(v))
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.