#' Gradient-based marginal optimization
#'
#' Maximize the marginal posterior with respect to specified
#' parameters, with nuisance parameters marginalized out.
#'
#' @param file
#'   A character string or a connection that R supports specifying the
#'   Stan model specification in Stan's modeling language. Only used
#'   when \code{full_model} is not supplied.
#' @param local_file
#'   Optional. A character string or a connection specifying the Stan
#'   program of the conditional (local) model; it is compiled with
#'   \code{\link[rstan]{stan_model}}. When omitted, the model underlying
#'   \code{full_model} is used for the conditional inference as well.
#' @param full_model
#'   If provided, an object of class 'stanfit' that makes it unnecessary
#'   to pass 'file' or 'local_file'
#' @param data
#'   A named \code{list} providing the data for the model. The list is
#'   extended with \code{c()} (entries \code{GMO_FLAG} and
#'   \code{fixed_phi} / \code{phi} are appended), so the environment and
#'   character-vector forms accepted by \code{\link[rstan]{stan}} are
#'   not handled here.
#' @param method
#'   A character string naming the conditional inference:
#'   "laplace"
#' @param init
#'   A numeric vector of length the number of hyperparameters, or
#'   \code{"random"} (the default) to draw the starting values from a
#'   normal distribution with mean 0 and standard deviation 0.01.
#' @param draws
#'   A positive integer, number of draws to calculate stochastic gradient.
#'   Replaced by \code{N_sigma_points} when that argument is positive.
#' @param iter
#'   A positive integer, the maximum number of iterations.
#' @param inner_iter
#'   A positive integer, the number of iterations after each
#'   conditional inference. Forced to 1 when \code{N_sigma_points} is
#'   positive.
#' @param cond_iter
#'   A positive integer, the maximum number of iterations for the
#'   conditional inference. Default (\code{NA}) is to run until
#'   convergence.
#' @param eta
#'   Double, constant scale factor for learning rate.
#' @param tol
#'   Double, tolerance for signaling convergence.
#' @param seed
#'   The seed, a positive integer, for random number generation of
#'   Stan. Defaults to \code{1234L}. The value is coerced with
#'   \code{as.integer} before it is stored, so a number or a character
#'   string of digits such as \code{"12345"} may be given. See
#'   \code{\link[rstan]{stan}} for how Stan interprets the seed.
#' @param max_block_size
#'   The maximum size of the blocks for a block-diagonal approximation
#'   of the Hessian under Laplace approximation.
#'   Default = Inf (no approximation).
#' @param N_sigma_points
#'   A non-negative integer. When positive, sigma points are used in the
#'   conditional inference: \code{draws} is set to \code{N_sigma_points}
#'   and \code{inner_iter} is set to 1. The default, 0, disables this.
#'
#' @return
#'   An object of reference class \code{"gmo"}. Its fields include:
#'   \item{par}{a vector of optimized parameters}
#'   \item{cov}{estimated covariance matrix at \code{par}}
#'   \item{sims}{\code{draws * inner_iter} many samples from the last
#'     approximation to the conditional posterior, p(alpha | y, phi)}
#'
#' @import methods
#' @importFrom rstan stan optimizing vb sampling constrain_pars log_prob grad_log_prob stan_model get_stanmodel
#' @importFrom loo psislw
#'
gmo <- function(file, local_file, full_model, data,
                method=c("laplace"), init="random",
                draws=5L, iter=100L, inner_iter=10L, cond_iter=NA, eta=1,
                tol=1e-3, seed=1234L, max_block_size=Inf, N_sigma_points=0) {
  # Validate the method before any (slow) Stan compilation happens.
  method <- match.arg(method)

  if (missing(full_model)) {
    # chains = 0 yields a stanfit that carries the compiled model without
    # drawing any samples.
    full_model <- suppressMessages(
      stan(file, data = c(data, list(GMO_FLAG = FALSE, fixed_phi = double())),
           chains = 0, iter = 1))
  }
  #else stopifnot(is(full_model, "stanfit"))

  local_model <- if (!missing(local_file)) {
    stan_model(local_file)
  } else {
    get_stanmodel(full_model)
  }

  # Sigma points replace the random draws: one inner iteration and exactly
  # N_sigma_points evaluation points.
  use_sigma_points <- N_sigma_points > 0
  if (use_sigma_points) {
    inner_iter <- 1L
    draws <- N_sigma_points
  }

  # The reference class declares these fields as "integer"; coerce here so
  # that plain doubles such as iter = 50 or N_sigma_points = 5 are accepted
  # instead of failing on field assignment.
  obj <- GMO$new(
    calc_log_p="exact",
    full_model=full_model,
    local_model=local_model,
    data=data,
    method=method,
    init=init,
    draws=as.integer(draws),
    iter=as.integer(iter),
    inner_iter=as.integer(inner_iter),
    cond_iter=as.integer(cond_iter),
    eta=eta,
    tol=tol,
    seed=as.integer(seed),
    max_block_size=max_block_size,
    use_sigma_points=use_sigma_points
  )
  obj$run()
  return(obj)
}
#' Approximate gradient-based marginal optimization
#'
#' Maximize a lower bound to the marginal posterior with respect to
#' specified parameters, with nuisance parameters marginalized out.
#'
#' @param file
#'   A character string or a connection that R supports specifying the
#'   Stan model specification in Stan's modeling language. Only used
#'   when \code{full_model} is not supplied.
#' @param local_file
#'   Optional. A character string or a connection specifying the Stan
#'   program of the conditional (local) model; it is compiled with
#'   \code{\link[rstan]{stan_model}}. When omitted, the model underlying
#'   \code{full_model} is used for the conditional inference as well.
#' @param full_model
#'   If provided, an object of class 'stanfit' that makes it unnecessary
#'   to pass 'file' or 'local_file'
#' @param data
#'   A named \code{list} providing the data for the model. The list is
#'   extended with \code{c()} (entries \code{GMO_FLAG} and
#'   \code{fixed_phi} / \code{phi} are appended), so the environment and
#'   character-vector forms accepted by \code{\link[rstan]{stan}} are
#'   not handled here.
#' @param method
#'   A character string naming the conditional inference:
#'   "laplace", "vb", "sampling"
#' @param init
#'   A numeric vector of length the number of hyperparameters, or
#'   \code{"random"} (the default) to draw the starting values from a
#'   normal distribution with mean 0 and standard deviation 0.01.
#' @param draws
#'   A positive integer, number of draws to calculate stochastic gradient.
#' @param iter
#'   A positive integer, the maximum number of iterations.
#' @param inner_iter
#'   A positive integer, the number of iterations after each
#'   conditional inference.
#' @param cond_iter
#'   A positive integer, the maximum number of iterations for the
#'   conditional inference. Default (\code{NA}) is to run until
#'   convergence.
#' @param eta
#'   Double, constant scale factor for learning rate.
#' @param tol
#'   Double, tolerance for signaling convergence.
#' @param seed
#'   The seed, a positive integer, for random number generation of
#'   Stan. Defaults to \code{1234L}. The value is coerced with
#'   \code{as.integer} before it is stored, so a number or a character
#'   string of digits such as \code{"12345"} may be given. See
#'   \code{\link[rstan]{stan}} for how Stan interprets the seed.
#'
#' @return
#'   An object of reference class \code{"gmo"}. Its fields include:
#'   \item{par}{a vector of optimized parameters}
#'   \item{cov}{estimated covariance matrix at \code{par}}
#'   \item{sims}{\code{draws * inner_iter} many samples from the last
#'     approximation to the conditional posterior, p(alpha | y, phi)}
#'
#' @import methods
#' @importFrom rstan stan optimizing vb sampling constrain_pars log_prob grad_log_prob stan_model get_stanmodel extract
gmo_approx <- function(file, local_file, full_model, data,
                       method=c("laplace", "vb", "sampling"), init="random",
                       draws=5L, iter=100L, inner_iter=10L, cond_iter=NA, eta=1,
                       tol=1e-3, seed=1234L) {
  # Validate the method before any (slow) Stan compilation happens.
  method <- match.arg(method)

  if (missing(full_model)) {
    # chains = 0 yields a stanfit that carries the compiled model without
    # drawing any samples.
    full_model <- suppressMessages(
      stan(file, data = c(data, list(GMO_FLAG = FALSE, fixed_phi = double())),
           chains = 0, iter = 1))
  } else {
    stopifnot(is(full_model, "stanfit"))
  }

  local_model <- if (!missing(local_file)) {
    stan_model(local_file)
  } else {
    get_stanmodel(full_model)
  }

  # The reference class declares draws/iter/inner_iter/cond_iter/seed as
  # "integer" fields, hence the explicit coercions.
  #
  # max_block_size and use_sigma_points must be supplied as well: the
  # class initializer reads both, and leaving them out makes construction
  # fail with a "missing argument" error. Inf / FALSE select the plain
  # (no block approximation, no sigma points) behaviour.
  obj <- GMO$new(
    calc_log_p="approx",
    full_model=full_model,
    local_model=local_model,
    data=data,
    method=method,
    init=init,
    draws=as.integer(draws),
    iter=as.integer(iter),
    inner_iter=as.integer(inner_iter),
    cond_iter=as.integer(cond_iter),
    eta=eta,
    tol=tol,
    seed=as.integer(seed),
    max_block_size=Inf,
    use_sigma_points=FALSE
  )
  obj$run()
  return(obj)
}
# Reference class holding the state of one GMO run.
#
# The object alternates between (1) a conditional inference for the local
# parameters alpha given the current hyperparameters `par` (.cond_infer),
# and (2) `inner_iter` stochastic-gradient updates of `par`, each using
# `draws` samples of alpha (.sample / .calc_log_p). The method-specific
# pieces are stored as function-valued fields that are chosen once in
# initialize().
GMO <- setRefClass("gmo",
  fields=list(
    par="numeric",              # current hyperparameter values (phi)
    cov="matrix",               # covariance at par (currently a placeholder)
    sims="matrix",              # alpha samples from the last conditional fit
    g_alpha="stanfit",          # result of the last conditional inference
    alpha="list",               # initial values for the next conditional fit
    num_par="integer",          # number of hyperparameters
    num_alpha="integer",        # number of local parameters (set lazily)
    full_model="stanfit",
    two_models="logical",       # TRUE if local_model differs from full_model's
    local_model="stanmodel",
    data="ANY",
    method="character",
    draws="integer",
    iter="integer",
    inner_iter="integer",
    cond_iter="integer",
    eta="numeric",
    tol="numeric",
    seed="integer",
    max_block_size="numeric",
    use_sigma_points="logical",
    .cond_infer="function",     # function(data): runs the conditional fit
    .sample="function",         # function(m): m-th batch of alpha samples
    .calc_log_p="function",     # function(alpha_sims, m): list(fn, grad)
    .log_p="function"           # function(alpha_sims, m, g_flag): densities
  ), methods=list(
    # Store the settings and select the method-specific closures.
    #
    # calc_log_p is "exact" (importance-weighted estimate) or anything else
    # (plain Monte Carlo average). max_block_size and use_sigma_points are
    # only used by the "laplace" method and default to the settings that
    # disable the block approximation and sigma points.
    initialize = function(calc_log_p, full_model, local_model, data, method,
                          init, draws, iter, inner_iter, cond_iter, eta, tol,
                          seed, max_block_size=Inf, use_sigma_points=FALSE) {
      if (identical(init, "random")) {
        # Note that in stan data, par must be the parameter phi.
        # We could generalize this to count all parameters which
        # are in full_model but not the local model, although this
        # requires running g_alpha.
        num_par <<- as.integer(full_model@par_dims$phi)
        par <<- rnorm(num_par, 0, 0.01)
      } else {
        par <<- init
        num_par <<- length(init)
      }
      full_model <<- full_model
      local_model <<- local_model
      two_models <<- !identical(local_model, get_stanmodel(full_model))
      data <<- c(data, list(GMO_FLAG = TRUE))
      method <<- method
      draws <<- draws
      use_sigma_points <<- use_sigma_points
      max_block_size <<- max_block_size
      inner_iter <<- inner_iter
      iter <<- iter
      cond_iter <<- cond_iter
      eta <<- eta
      tol <<- tol
      seed <<- seed
      # The class attribute is set (rather than building a real list) so
      # that the value "random" can live in the list-typed field.
      alpha <<- structure("random", class="list")

      # Initialize functions that depend on static arguments.
      # This avoids having to run an if-else chain inside the function
      # itself, which would be called at each iteration that the
      # function is called.
      #
      # When implementing GMO in Stan's C++, this would be done
      # automatically using metaprogramming tricks.
      #
      # Console output of the Stan calls is silenced with sink(); the
      # diversion is removed with on.exit(sink()) so that it is undone
      # even on error and no unrelated connection is touched.
      if (method == "laplace") {
        if (is.na(cond_iter)) {
          .cond_infer <<- function(data) {
            sink(file="/dev/null")
            on.exit(sink(), add=TRUE)
            fit <- optimizing_new_method(local_model, data=data,
                                         seed=seed, init=0,
                                         as_vector=FALSE,
                                         max_block_size=max_block_size,
                                         draws=inner_iter*draws,
                                         constrained=FALSE,
                                         use_sigma_points=use_sigma_points)
            g_alpha <<- structure(fit, class="stanfit")
            alpha <<- g_alpha$par
          }
        } else {
          .cond_infer <<- function(data) {
            sink(file="/dev/null")
            on.exit(sink(), add=TRUE)
            fit <- optimizing_new_method(local_model, data=data,
                                         seed=seed, init=alpha,
                                         as_vector=FALSE,
                                         max_block_size=max_block_size,
                                         draws=inner_iter*draws,
                                         constrained=FALSE,
                                         iter=cond_iter,
                                         use_sigma_points=use_sigma_points)
            g_alpha <<- structure(fit, class="stanfit")
            alpha <<- g_alpha$par
          }
        }
        .sample <<- function(m) {
          rows <- (m-1)*draws + seq_len(draws)
          # drop=FALSE keeps a matrix even for a single draw or a single
          # local parameter.
          return(g_alpha$theta_tilde[rows, , drop=FALSE])
        }
        .log_p <<- function(alpha_sims, m, g_flag) {
          return(.log_p_laplace(alpha_sims, m, g_flag))
        }
      } else if (method == "vb") {
        if (is.na(cond_iter)) {
          .cond_infer <<- function(data) {
            sink(file="/dev/null")
            on.exit(sink(), add=TRUE)
            g_alpha <<- vb(local_model, data=data,
                           seed=seed, init=alpha,
                           output_samples=inner_iter*draws)
            alpha <<- structure("random", class="list")
            #alpha <- alpha_mean(g_alpha)
          }
        } else {
          .cond_infer <<- function(data) {
            sink(file="/dev/null")
            on.exit(sink(), add=TRUE)
            g_alpha <<- vb(local_model, data=data,
                           seed=seed, init=alpha,
                           output_samples=inner_iter*draws,
                           iter=cond_iter)
            #adapt_engaged=FALSE, # TODO use only the first time
            alpha <<- structure("random", class="list")
            #alpha <- alpha_mean(g_alpha)
          }
        }
        .sample <<- function(m) {
          # TODO avoid if-else chain within the function
          if (length(num_alpha) == 0) {
            pars <- get_stan_params(g_alpha)
            num_alpha <<- count_params(g_alpha, pars)
          }
          rows <- (m-1)*draws + seq_len(draws)
          all_sims <- matrix(
            unlist(attributes(g_alpha)$sim$samples[[1]][1:num_alpha]),
            ncol=num_alpha)
          # drop=FALSE keeps a matrix even for a single draw or a single
          # local parameter.
          return(all_sims[rows, , drop=FALSE])
        }
        .log_p <<- function(alpha_sims, m, g_flag) {
          return(.log_p_vb(alpha_sims, m, g_flag))
        }
      } else if (method == "sampling") {
        .cond_infer <<- function(data) {
          # For sampling, cond_iter is always equal to 2*inner_iter*draws.
          sink(file="/dev/null")
          on.exit(sink(), add=TRUE)
          g_alpha <<- sampling(local_model, data=data,
                               iter=2*inner_iter*draws, chains=1,
                               seed=seed, init=alpha)
          alpha <<- structure("random", class="list")
          #alpha <<- alpha_mean(g_alpha)
        }
        .sample <<- function(m) {
          if (length(num_alpha) == 0) {
            pars <- get_stan_params(g_alpha)
            num_alpha <<- count_params(g_alpha, pars)
          }
          rows <- (m-1)*draws + seq_len(draws)
          # extract(permuted=FALSE) is iterations x chains x parameters;
          # only one chain is run, so it is flattened to draws x num_alpha
          # while keeping the parameter names.
          sims_arr <- extract(g_alpha,
                              permuted=FALSE)[rows, 1, 1:num_alpha, drop=FALSE]
          return(matrix(sims_arr, nrow=draws,
                        dimnames=dimnames(sims_arr)[c(1, 3)]))
        }
        .log_p <<- function(alpha_sims, m, g_flag) {
          return(.log_p_sampling(alpha_sims, m, g_flag))
        }
      } else {
        stop("Conditional inference method not valid.")
      }

      if (calc_log_p == "exact") {
        .calc_log_p <<- function(alpha_sims, m) {
          density_sims <- .log_p(alpha_sims, m, g_flag=TRUE)
          # Raw log importance ratios for small batches, Pareto-smoothed
          # ones for very large batches.
          if (draws < 20000) {
            log_r <- density_sims$log_p - density_sims$log_g
          } else {
            log_r <- psislw(density_sims$log_p - density_sims$log_g)$lw_smooth
          }
          # Subtract the maximum before exponentiating for stability.
          max_log_r <- max(log_r)
          r <- exp(log_r - max_log_r)
          # Note that weighted.mean normalizes the importance ratios
          return(list(fn=max_log_r + log(mean(r)),
                      grad=apply(density_sims$grad_log_p, 2, weighted.mean, r)))
        }
      } else {
        .calc_log_p <<- function(alpha_sims, m) {
          # This outputs only the energy; we drop the entropy term for
          # computational reasons and because we assess the density
          # value only for convergence diagnostics.
          density_sims <- .log_p(alpha_sims, m, g_flag=FALSE)
          return(list(fn=mean(density_sims$log_p),
                      grad=apply(density_sims$grad_log_p, 2, mean)))
        }
      }
    },

    # Run the optimization for at most `iter` outer iterations. Updates
    # `par` in place and fills `cov` and `sims` on exit. Returns NULL
    # invisibly.
    run = function() {
      diagnostic <- Diagnostic$new(tol)
      opt <- Opt$new(eta)
      for (tee in seq_len(iter)) {
        print(sprintf("Iteration: %s", tee))
        print(par)
        # With a separate local model the hyperparameters are passed as
        # data entry `phi`, otherwise as `fixed_phi`.
        if (two_models) {
          .cond_infer(c(data, list(phi=par)))
        } else {
          .cond_infer(c(data, list(fixed_phi=par)))
        }
        for (m in seq_len(inner_iter)) {
          alpha_sims <- .sample(m)
          log_p <- .calc_log_p(alpha_sims, m)
          par <<- opt$update_params(par, log_p$grad, (tee-1)*inner_iter + m)
        }
        flags <- diagnostic$check_converge(par, log_p$fn, log_p$grad)
        if (sum(flags) > 0) {
          print("Optimization terminated normally:")
          print(.get_code_string(flags))
          cov <<- matrix(0, 1, 1) # est_covariance(par, log_p=log_p, alpha_sims=alpha_sims, m=m)
          sims <<- .collect_alpha_sims()
          return(invisible(NULL))
        }
      }
      print("Maximum number of iterations hit, may not be at an optima")
      cov <<- matrix(0, 1, 1) #est_covariance(par)
      sims <<- .collect_alpha_sims()
      return(invisible(NULL))
    },

    # Stack the `inner_iter` batches of the last conditional fit row-wise.
    .collect_alpha_sims = function() {
      alpha_sims <- .sample(1)
      if (inner_iter > 1) {
        for (m in 2:inner_iter) {
          alpha_sims <- rbind(alpha_sims, .sample(m))
        }
      }
      return(alpha_sims)
    },

    # Gradient with respect to the hyperparameters (first num_par entries)
    # at each draw, plus the log densities stored by the Laplace fit for
    # the m-th batch. log_g is NA when g_flag is FALSE.
    .log_p_laplace = function(alpha_sims, m, g_flag=TRUE) {
      grad_log_p_sims <- array(NA, c(draws, num_par))
      sink(file="/dev/null")
      on.exit(sink(), add=TRUE)
      for (s in seq_len(draws)) {
        grad_log_p_sims[s, ] <- grad_log_prob(full_model,
                                              c(par, alpha_sims[s, ]),
                                              adjust_transform=FALSE)[1:num_par]
      }
      rows <- (m-1)*draws + seq_len(draws)
      log_p_sims <- g_alpha$log_p[rows]
      if (g_flag) {
        log_g_sims <- g_alpha$log_g[rows]
      } else {
        log_g_sims <- NA
      }
      return(list(grad_log_p=grad_log_p_sims,
                  log_p=log_p_sims,
                  log_g=log_g_sims))
    },

    .log_p_vb = function(alpha_sims, m, g_flag=TRUE) {
      # vb()/sampling() cannot output unconstrained samples via
      # constrained=FALSE. Further, to unconstrain_pars() requires a
      # list of each alpha's parameter name and data structure.
      #
      # To address this, we constrain phi, calculate log_prob on
      # constrained space without adjusting ("log_prob") and with
      # adjusting ("log_prob_Jt").
      #
      # Then
      #   2*"log_prob" - "log_prob_Jt" = "log_prob" - Jt
      # is log_prob on the constrained space adjusting to the
      # minus log determinant of the Jacobian. This is the same as
      # log_prob on the unconstrained space.
      par_const <- constrain_pars(full_model,
                                  upars=c(par,
                                          rep(0, ncol(alpha_sims))))[[1]]
      grad_log_p_sims <- array(NA, c(draws, num_par))
      log_p_sims <- rep(NA, draws)
      sink(file="/dev/null")
      on.exit(sink(), add=TRUE)
      for (s in seq_len(draws)) {
        point <- c(par_const, alpha_sims[s, ])
        grad_log_p_sims[s, ] <-
          2*grad_log_prob(full_model, point,
                          adjust_transform=FALSE)[1:num_par] -
          grad_log_prob(full_model, point,
                        adjust_transform=TRUE)[1:num_par]
        log_p_sims[s] <-
          2*log_prob(full_model, point, adjust_transform=FALSE) -
          log_prob(full_model, point, adjust_transform=TRUE)
      }
      if (g_flag) {
        # TODO we require (preferably unconstrained)
        # log_g_sims = log_q outputted via variational.log_q
        # See aed4c42c4beebed0d43d1e1b868ef795a760fe9e in stan-dev/stan.
        # Ideally, we could also get (unconstrained)
        # log_p_sims = lp__ outputted via model_.template log_prob<false, false>
        # See ba4acfad356a85d2ebde79d82f6bcdced7034ca7 in stan-dev/stan.
        # NOTE(review): this takes lp__ for all stored samples rather
        # than only the m-th batch, so its length presumably differs from
        # log_p_sims when inner_iter > 1 -- confirm before relying on it.
        log_g_sims <- attributes(g_alpha)$sim$samples[[1]]$lp__
      } else {
        log_g_sims <- NA
      }
      return(list(grad_log_p=grad_log_p_sims,
                  log_p=log_p_sims,
                  log_g=log_g_sims))
    },

    # Currently delegates to the vb computation.
    .log_p_sampling = function(alpha_sims, m, g_flag=TRUE) {
      return(.log_p_vb(alpha_sims, m, g_flag))
      ## TODO
      ## using the awful harmonic mean estimator
      #grad_log_p_sims <- array(NA, c(draws, num_par))
      #log_p <- rep(NA, draws)
      #for (s in 1:draws) {
      #  grad_log_p_sims[s, ] <- grad_log_prob(full_model,
      #    c(par, alpha_sims[s, ]),
      #    adjust_transform=FALSE)[1:num_par]
      #  log_p_sims[s] <- log_prob(full_model,
      #    c(par, alpha_sims[s, ]),
      #    adjust_transform=FALSE)
      #}
    },

    # Message for the first convergence flag that is set; the fifth
    # message is the fallback when none of the first four is set.
    .get_code_string = function(flags) {
      if (flags[1]) {
        return("  Convergence detected: absolute parameter change was below tolerance")
      } else if (flags[2]) {
        return("  Convergence detected: absolute change in objective function was below tolerance")
      } else if (flags[3]) {
        return("  Convergence detected: relative change in objective function was below tolerance")
      } else if (flags[4]) {
        return("  Convergence detected: gradient norm is below tolerance")
      } else {
        return("  Convergence detected: relative gradient magnitude is below tolerance")
      }
    }
  )
)
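# NOTE: the two lines below are residue from a documentation web page,
# not part of the R source; they are kept as comments so the file parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.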