knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)

Introduction {#intro}

The moretrees package fits Multi-Outcome Regression with Tree-structured Shrinkage (MOReTreeS) models to case-crossover and matched case-control data when being a 'case' can mean experiencing one of multiple related outcomes. MOReTreeS models require that these outcomes are related according to a known tree. For example, the outcomes might be multiple distinct cardiovascular diseases that can be labelled using a hierarchical, tree-like classification system, as in the worked example in Section 3 below.

The MOReTreeS model simultaneously performs three core functions:

  1. discovers de novo groups of outcomes that have a similar association with the exposure conditional on matching factors and covariates, with both the number and composition of the groups estimated from the data;
  2. estimates a group-specific measure of association with the exposure for each outcome group;
  3. performs information sharing across the outcomes by shrinking the coefficient estimates for different outcomes towards each other, with the degree of shrinkage determined by how closely related those outcomes are on the tree.

The model is fit using variational inference (VI), a fast alternative to Markov chain Monte Carlo (MCMC) for approximating posterior distributions of Bayesian models.

This vignette is divided into three sections. Section 2 provides details of the MOReTreeS model implemented by the package. Section 3 lays out the functionality of the moretrees package in a readable form and guides potential users through an example of fitting the model using the moretrees() function. Finally, Section 4 explains the VI algorithm used to fit the model. Section 4 is not essential reading for most users, but rather aims to provide sufficient detail about the algorithm that someone with knowledge of VI would be able to write their own code to fit the model.

Model details {#model}

We begin by giving details of the model implemented by the moretrees package. The model is very similar to the one presented in @thomas2020estimating, with extensions to allow for: (1) a multivariate exposure; (2) control for covariates in regression analyses; and (3) more flexible hyperparameter specification. MOReTreeS priors can be used with a range of different outcome distributions or likelihoods; the moretrees() function implements the conditional logistic likelihood used in conditional logistic regression (CLR) analyses, the standard method for both matched case-control studies and case-crossover studies. We first define the CLR likelihood used in our study, and then the MOReTreeS prior.

Likelihood

Let $\mathcal{V}$ be the set of outcomes under study. For outcome $v \in \mathcal{V}$, let $Y_{i,j}^{(v)}$ be a binary variable indicating whether a hospitalization for disease $v$ occurred on day $j$ for individual $i$, where without loss of generality we assume $j = 1$ represents the case day and $j = 2$ represents the control day. Let $\mathbf{x}_{i,j}^{(v)}$ be the exposure vector of length $k$ and $\mathbf{w}_{i,j}^{(v)}$ be a vector of covariates of length $m$. For notational convenience, we let $\mathbf{x}_i^{(v)} = \mathbf{x}_{i,1}^{(v)} - \mathbf{x}_{i,2}^{(v)}$ and $\mathbf{w}_i^{(v)} = \mathbf{w}_{i,1}^{(v)} - \mathbf{w}_{i,2}^{(v)}$.

The conditional logistic likelihood for each matched pair is shown below [@maclure1991case].

\begin{equation} \tag{1} P\left(Y_{i,1}^{(v)} = 1, Y_{i,2}^{(v)} = 0 \middle\vert Y_{i,1}^{(v)} + Y_{i,2}^{(v)} = 1, \boldsymbol{\beta}_v, \boldsymbol{\theta}_v \right) = \textrm{expit} \left( \boldsymbol{\beta}_v^T \mathbf{x}_{i}^{(v)} + \boldsymbol{\theta}_v^T \mathbf{w}_i^{(v)} \right), \end{equation}

where $\boldsymbol{\beta}_v$ is the coefficient vector for the effect of the exposure on outcome $v$, $\boldsymbol{\theta}_v$ is the vector of covariate effects for outcome $v$, and $\textrm{expit}(y) = 1 / (1+ e^{-y})$ is the logistic (sigmoid) function, i.e. the inverse of the logit. The full data likelihood is obtained by multiplying over all matched pairs.
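As a concrete illustration of Equation 1, the short sketch below evaluates the pair-level probability for a single matched pair in R. The function clr_pair_prob() and its inputs are hypothetical and purely illustrative; moretrees() computes the full likelihood internally.

# Minimal sketch (not package code): probability that day 1 is the case day,
# given exactly one event in the matched pair (Equation 1).
clr_pair_prob <- function(x_case, x_control, w_case, w_control, beta, theta) {
  x_diff <- x_case - x_control    # exposure difference, length k
  w_diff <- w_case - w_control    # covariate difference, length m
  plogis(sum(beta * x_diff) + sum(theta * w_diff))  # expit of the linear predictor
}

# Toy example with k = 2 exposures and m = 2 covariates
clr_pair_prob(x_case = c(1.2, 0.3), x_control = c(0.8, 0.1),
              w_case = c(0, 1), w_control = c(1, 1),
              beta = c(0.5, -0.2), theta = c(0.1, 0.0))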

Prior

The MOReTreeS spike and slab prior enables data-driven discovery of groups of outcomes that have a similar relationship with the exposure [@thomas2020estimating]. Let $\mathcal{V}$ be the set of nodes on the tree of outcome relationships, and let $\mathcal{V}_L \subset \mathcal{V}$ be the set of leaves. Internal nodes $u \in \mathcal{V}$ represent outcome categories and leaf nodes $v \in \mathcal{V}_L$ represent individual outcomes. Let $a(v)$ be the set of ancestor nodes of $v \in \mathcal{V}_L$, including $v$ itself. Finally, let $l_u \in \lbrace 1, \dots, L \rbrace$ indicate which hyperparameters apply to variables at node $u \in \mathcal{V}$ on the tree. We explain below what is meant by this.

The prior is:

\begin{align} \boldsymbol{\beta}_v & = \sum_{u \in a(v)} s_u \boldsymbol{\gamma}_u \\ \boldsymbol{\theta}_v & = \sum_{u \in a(v)} \boldsymbol{\zeta}_u \\ s_u & \sim Bern(\rho_{l_u}) \\ \boldsymbol{\gamma}_u & \sim \mathcal{MVN}\left(\boldsymbol{0},\tau_{l_u} I_k\right) \\ \boldsymbol{\zeta}_u & \sim \mathcal{MVN} \left( \boldsymbol{0}, \omega_{l_u} I_m \right) \\ \rho_{l} & \sim Beta(a_l, b_l). \end{align}

Here the 'node parameters' $s_u \boldsymbol{\gamma}_u$ and $\boldsymbol{\zeta}_u$ can be interpreted as the difference in, respectively, the exposure effect and the covariate effects for outcome category $u$ compared to its parent category. $\rho_l$, $\tau_l$, and $\omega_l$ for $l = 1, \dots, L$ are hyperparameters representing, respectively, the prior probability of inclusion for node parameters $\boldsymbol{\gamma}_u$, the prior variance of all elements of $\boldsymbol{\gamma}_u$, and the prior variance of all elements of $\boldsymbol{\zeta}_u$, for all $u$ such that $l_u = l$. The $l_u$ are fixed values from 1 through $L$ that specify which nodes on the tree share common hyperparameters $\rho_l$, $\tau_l$, and $\omega_l$. The default used by moretrees() is to have two sets of hyperparameters: one for all internal nodes ($l_u = 1$ for $u \in \mathcal{V} \backslash \mathcal{V}_L$), and one for all leaf nodes ($l_u = 2$ for $u \in \mathcal{V}_L$). $\tau_l$ and $\omega_l$ are estimated via approximate empirical Bayes as discussed below in Section 4.4; we place a conditionally conjugate beta hyperprior on $\rho_l$.
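For concreteness, the sketch below shows one way this default assignment of hyperparameter levels could be computed for a toy tree using the igraph package (introduced in Section 3). The node labels are made up and the snippet is purely illustrative; moretrees() handles this assignment internally.

library(igraph)
# Toy tree with hypothetical labels: leaves get l_u = 2, internal nodes l_u = 1
edges <- rbind(c("root", "A"), c("root", "B"),
               c("A", "A1"), c("A", "A2"))
tr_toy <- graph_from_edgelist(edges, directed = TRUE)
is_leaf <- degree(tr_toy, mode = "out") == 0
l_u <- ifelse(is_leaf, 2L, 1L)
names(l_u) <- names(V(tr_toy))
l_u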

The above hierarchical model leads to the formation of groups of outcomes with the same exposure effects by fusing (setting equal) coefficients for different outcomes. This fusion is achieved as follows. The coefficients $\boldsymbol{\beta}_{v_1}$ and $\boldsymbol{\beta}_{v_2}$ for any two outcomes $v_1, v_2 \in \mathcal{V}_L$ will be exactly equal if and only if $s_u = 0$ for all $u \in d(v_1, v_2)$, where $d(v_1, v_2) = (a(v_1) \cup a(v_2)) \backslash (a(v_1) \cap a(v_2))$ is the set of nodes that are ancestors of either $v_1$ or $v_2$ but not both. The result is that the prior probability that $\boldsymbol{\beta}_{v_1}$ and $\boldsymbol{\beta}_{v_2}$ are equal, meaning $v_1$ and $v_2$ are grouped together, depends both on the set of nodes $d(v_1, v_2)$ that 'separate' the two leaf nodes $v_1$ and $v_2$ on the tree, and on the hyperparameters $\rho_l$. Specifically, conditional on the $\rho_l$,

$$P\left(\boldsymbol{\beta}_{v_1} = \boldsymbol{\beta}_{v_2} \middle\vert \rho_1, \dots, \rho_L \right) = \prod_{l=1}^L (1 - \rho_l)^{\left\vert d(v_1, v_2) \cap \lbrace u: l_u = l \rbrace \right\vert}.$$ Let $n(v_1, v_2, l) = \left\vert d(v_1, v_2) \cap \lbrace u: l_u = l \rbrace \right\vert$. Integrating over the parameters $\rho_l$, we find:

$$P\left(\boldsymbol{\beta}_{v_1} = \boldsymbol{\beta}_{v_2}\right) = \prod_{l = 1}^L \left(\prod_{r = 0}^{n(v_1,v_2,l)} \dfrac{b_l + r}{a_l + b_l + r}\right)^{\mathbf{1}_{n > 0}\left(n(v_1,v_2,l)\right)}. $$ This expression can be derived using the higher moments of the Beta distribution.
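To make the first (conditional) expression above concrete, the sketch below computes $P(\boldsymbol{\beta}_{v_1} = \boldsymbol{\beta}_{v_2} \vert \rho_1, \dots, \rho_L)$ from the ancestor sets of two leaves, their hyperparameter levels, and values of the $\rho_l$. The function and the toy inputs are hypothetical illustrations, not package code.

# Sketch: prior probability that two outcomes are fused, conditional on rho.
# a_v1, a_v2: ancestor node names for each leaf (including the leaf itself);
# l_u: named vector of hyperparameter levels; rho: vector of rho_l, indexed by l.
prob_fused_given_rho <- function(a_v1, a_v2, l_u, rho) {
  d <- union(setdiff(a_v1, a_v2), setdiff(a_v2, a_v1))  # nodes separating v1 and v2
  prod((1 - rho)[l_u[d]])
}

# Toy example: v1 and v2 are separated by two leaf nodes (level 2)
prob_fused_given_rho(a_v1 = c("root", "A", "A1"), a_v2 = c("root", "A", "A2"),
                     l_u = c(root = 1, A = 1, A1 = 2, A2 = 2),
                     rho = c(0.5, 0.8))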

Posterior estimates

moretrees obtains posterior estimates for the exposure effects $\boldsymbol{\beta}_v$ in the same manner as described by @thomas2020estimating. Briefly, the moretrees() function returns estimates from the \emph{median probability model} that includes only those node variables with marginal posterior inclusion probability greater than 0.5 [@barbieri2004optimal]. As discussed in the previous section, a node is included in the model if $s_u = 1$. Thus, a node will be included in our final model if the posterior probability that $s_u = 1$ is greater than 0.5.
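The selection rule is simple to state in code. The sketch below is illustrative only; the vector of posterior inclusion probabilities is hypothetical, and moretrees() applies this rule internally when constructing its returned estimates.

# Sketch: the median probability model keeps nodes with posterior
# inclusion probability P(s_u = 1 | data) greater than 0.5.
post_incl_prob <- c(node1 = 0.92, node2 = 0.31, node3 = 0.67)  # toy values
included <- post_incl_prob > 0.5
names(post_incl_prob)[included]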

moretrees(): functionality and usage {#code}

Functionality

The moretrees() function fits MOReTreeS models and supports the following functionality:

  1. Fitting MOReTreeS models to pair matched case-control or case-crossover data where each case is associated with exactly one outcome among a set of related outcomes;
  2. Multiple exposure types: moretrees() supports any exposure type by supplying appropriate design matrices for the cases and controls;
  3. Control for confounding in the regression model: confounders that could not be matched on can be controlled for in the regression model; information about covariate effect sizes is shared across the outcomes;
  4. Hyperparameter selection: the user need not specify the values of the hyperparameters, which are instead estimated via an approximate empirical Bayes scheme, as detailed in Section 4.4.

Worked example

We illustrate use of the moretrees() function with a small simulated case-crossover dataset, moretreesExampleData, which has $n = 1000$ case-crossover pairs, $p = 11$ outcomes, an exposure of dimension $k = 2$, and $m = 2$ covariates. The data are stored as a list with the following elements:

  1. Xcase: an n by k matrix of exposure values for cases;
  2. Xcontrol: the corresponding exposure matrix for matched controls (recall that only one control per case is allowed);
  3. Wcase: an n by m matrix of covariate values for cases;
  4. Wcontrol: the corresponding covariate matrix for matched controls;
  5. outcomes: a character vector of length n indicating which outcome was experienced by each case.

The elements of outcomes are character strings representing cardiovascular diseases in category 7.4 (diseases of arteries, arterioles, and capillaries) in the multilevel Clinical Classifications Software (CCS) hierarchical disease classification system [@healthcare2018clinical].

First, we load the edge list for our outcome tree and create an igraph object, tr, to represent the tree. The edge list is a matrix with two columns: the left column represents parent nodes or categories, and the right column represents children or subcategories of the parents. There is one row for every parent-child pair. The set of names of the leaf nodes in tr should be equal to the set of unique values in the outcomes vector. We then plot the tree using the ggtree package from Bioconductor.

library(moretrees)
data("moretreesExampleEdgelist")
library(igraph)
tr <- graph_from_edgelist(moretreesExampleEdgelist, directed = TRUE)
# Plot tree
# If needed, first install ggtree
# if (!requireNamespace("BiocManager", quietly = TRUE))
# install.packages("BiocManager")
# BiocManager::install("ggtree")
library(ggtree)
library(ggplot2)
ggtree(tr, ladderize = FALSE, layout = "slanted") + 
  geom_tiplab(geom = "label") + 
  geom_nodelab(geom = "label") +
  theme(plot.margin = unit(c(0, 1.5, 0, 0.2), "cm")) +
  coord_cartesian(clip = "off") + 
  scale_y_reverse()

Next, we load the data, check that the unique values in outcomes are the same as the leaves of the tree (moretrees() will throw an error if this isn't the case, but it's good practice to double-check beforehand), and print the number of matched pairs per outcome.

data("moretreesExampleData")
Xcase <- moretreesExampleData$Xcase
Xcontrol <- moretreesExampleData$Xcontrol
Wcase <- moretreesExampleData$Wcase
Wcontrol <- moretreesExampleData$Wcontrol
outcomes <- moretreesExampleData$outcomes
leaves <- names(V(tr)[degree(tr, mode = "out") == 0])
setequal(leaves, unique(outcomes))
knitr::kable(table(outcomes))

We see that there are between 68 and 118 matched pairs per outcome. Note that each case is associated with exactly one outcome; the model assumes cases do not experience multiple outcomes.

Model without covariate adjustment

We now fit the MOReTreeS model without adjustment for covariates.

mod1 <- moretrees::moretrees(Xcase = Xcase, Xcontrol = Xcontrol,
                 outcomes = outcomes, tr = tr,
                 nrestarts = 1, print_freq = 500)

We see that the model ran for at least 2000 iterations before converging at a tolerance of tol = 1E-8 (the default). The option print_freq determines how frequently the algorithm's progress will be printed.

We now print the estimated odds ratios (or rate ratios, for a case-crossover model) for the groups discovered by MOReTreeS. The table also shows which outcomes belong to each group and the number of observations associated with outcomes in each group.

print(mod1, digits = 2)

We can also plot the discovered groups on the outcome tree:

plot(mod1, layout = "slanted", horizontal = FALSE)

Finally, for comparison, we inspect the group-specific estimates obtained by running a classical CLR model on the data for each discovered group separately. Here, we use a more compact printing method to save space.

print(mod1, compact = TRUE, digits = 2, 
      coeff_type = "clr", print_outcomes = FALSE)

Model with covariate adjustment

Next, we fit the MOReTreeS model with covariate adjustment. For parameters that were estimated in the model without covariate adjustment, we use the final model values as initial values in the adjusted model. This approach may be useful for speeding up models with covariates, which can be slower to converge than unadjusted models.

vi_params_init <- mod1$mod$vi_params[c("prob", "mu", "Sigma",
                                       "Sigma_inv", "Sigma_det",
                                       "tau_t", "a_t", "b_t")]
hyperparams_init <- mod1$mod$hyperparams[c("eta", "tau")]
mod2 <- moretrees::moretrees(Xcase = Xcase, Xcontrol = Xcontrol,
                  Wcase = Wcase, Wcontrol = Wcontrol,
                  vi_params_init = vi_params_init,
                  hyperparams_init = hyperparams_init,
                  outcomes = outcomes, tr = tr,
                  nrestarts = 1, print_freq = 1000)

The warning indicates that the default maximum number of iterations, max_iter = 5000, was reached before the model converged. We therefore increase max_iter and run the model again, using as new starting values the final values from the run which failed to converge.

vi_params_init <- mod2$mod$vi_params
hyperparams_init <- mod2$mod$hyperparams
mod2 <- moretrees::moretrees(Xcase = Xcase, Xcontrol = Xcontrol,
                  Wcase = Wcase, Wcontrol = Wcontrol,
                  vi_params_init = vi_params_init,
                  hyperparams_init = hyperparams_init,
                  outcomes = outcomes, tr = tr,
                  nrestarts = 1, print_freq = 1000,
                  max_iter = 10000)

The model has now converged. Next, we print the estimated odds ratios for the discovered groups adjusted for the covariates in W. We suppress printing of the outcomes associated with each group, since in this example these happen to be the same as for the model with no covariates.

print(mod2, compact = TRUE, digits = 2, print_outcomes = FALSE)

Using random restarts

Note that in the example above we used only nrestarts = 1 random restart for illustration purposes. This means the VI algorithm was run from a single, non-random set of initial values. In practice, however, we recommend using nrestarts > 1. This causes moretrees() to run multiple versions of the VI algorithm from different random initial values. These restarts may converge to different local optima; if so, the 'best' restart (the one with the highest value of the objective function) is returned. Here, we use the initial values from mod2 with some randomness added for each restart.

The unevaluated code chunk below illustrates how to register a parallel backend for running three random restarts on three cores. When nrestarts > 1, moretrees() attempts to run the restarts on multiple cores by default. To prevent this behaviour and run the restarts on one core, use the option parallel = FALSE. The progress of restart i will be logged to a text file restart_logs/restart_i_log.txt, which is deleted at the end of the run. moretrees() returns the results from the 'best' restart and, by default, stores the results from the other restarts in the list item mod_restarts. For large models, preserving the restarts can lead to memory problems; the restarts can be discarded using the option keep_restarts = FALSE.

nrestarts <- 3
doParallel::registerDoParallel(cores = nrestarts)
set.seed(345083)
log_dir <- "restart_logs"
dir.create(log_dir)
mod3 <- moretrees::moretrees(Xcase = Xcase, Xcontrol = Xcontrol,
                  Wcase = Wcase, Wcontrol = Wcontrol,
                  outcomes = outcomes, tr = tr,
                  vi_params_init = mod2$mod$vi_params,
                  hyperparams_init = mod2$mod$hyperparams,
                  nrestarts = nrestarts, print_freq = 500,
                  log_restarts = TRUE, log_dir = log_dir)
unlink(log_dir, recursive = TRUE)
# Print the final ELBO for all restarts
print(c(mod3$mod$hyperparams$ELBO,
  sapply(mod3$mod_restarts, function(mod) mod$hyperparams$ELBO)),
  digits = 8)

Details of variational inference algorithm {#algo}

VI is a method for approximating the posterior distribution of a Bayesian model. This is achieved by minimizing the "distance" between a pre-specified family of probability distributions $q$ and the true posterior $\pi$, where distance is typically defined as the Kullback-Leibler divergence $K(q || \pi)$. Using this technique, VI finds a deterministic approximation for $\pi$, in contrast to MCMC which attempts to sample from $\pi$. A detailed introduction to VI is beyond the scope of this vignette; for an accessible introduction for statisticians, see @blei2017variational.

This section of the vignette is not required reading. Rather, it is intended for readers with some knowledge of variational Bayes who may wish to verify the details of the VI algorithm used by moretrees.

Preliminaries

We take a mean field VI approach and assume that the variational approximation $q$ factorizes as $q\left(\boldsymbol{\gamma},\mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) = \prod_{u \in \mathcal{V}} q \left( \boldsymbol{\theta}_u \right) q\left(\boldsymbol{\gamma}_u, s_u \right) \prod_{l=1}^L q(\rho_l)$, where $\boldsymbol{\gamma}$, $\mathbf{s}$, $\boldsymbol{\theta}$, and $\boldsymbol{\rho}$ are vectors containing all $\boldsymbol{\gamma}_u$, $s_u$, $\boldsymbol{\theta}_u$, and $\rho_l$ respectively. Rather than minimizing $K(q || \pi)$, which depends on the intractable posterior $\pi$, we maximize a quantity known as the evidence lower bound (ELBO), $\mathcal{E}(q)$, which is equal to $- K(q || \pi)$ up to an additive constant that does not depend on $q$:

\begin{equation} \tag{2} \mathcal{E}(q) = E_{q} \left[\log f \left(\boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) \right] - E_{q} \left[\log q \left(\boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) \right] \end{equation} where $f \left(\boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right)$ is the joint density of data and parameters, and $E_q$ denotes expectation taken with respect to $q$.

The moretrees package uses a co-ordinate ascent variational inference (CAVI) algorithm to maximize $\mathcal{E}(q)$ in $q$. However, attempting to maximize $\mathcal{E}(q)$ directly would lead to intractable updates due to the logistic form of the likelihood in Equation 1. We therefore replace this likelihood in the expression for $\mathcal{E}(q)$ with a quadratic approximation that leads to tractable updates in the CAVI algorithm. This approach to VI with a logistic likelihood was first proposed by @jaakkola2000bayesian. Here, we state the result; see @bishop2016pattern [Section 10.6] for a detailed derivation.

First, the joint density of data and parameters is:

\begin{align} f \left(\boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) & = \prod_{v \in \mathcal{V}_L} \prod_{i = 1}^{n_v} \left(1 + \exp \left[-\sum_{u \in a(v)} \left( s_u \boldsymbol{\gamma}_u^T \mathbf{x}_{i}^{(v)} + \boldsymbol{\theta}_u^T \mathbf{w}_i^{(v)} \right) \right] \right)^{-1} \\ & \times \prod_{u \in \mathcal{V}} \rho_{l_u}^{s_u} (1-\rho_{l_u})^{1-s_u} \dfrac{1}{\left(2\pi\tau_{l_u}\right)^{k/2}} \exp\left[-\dfrac{\boldsymbol{\gamma}_u^T\boldsymbol{\gamma}_u}{2\tau_{l_u}}\right] \\ & \times \prod_{u \in \mathcal{V}} \dfrac{1}{\left(2\pi\omega_{l_u}\right)^{m/2}} \exp\left[-\dfrac{\boldsymbol{\theta}_u^T\boldsymbol{\theta}_u}{2\omega_{l_u}}\right] \\ & \times \prod_{l=1}^L \dfrac{\rho_l^{a_l -1} (1 - \rho_l)^{b_l - 1}}{\mathcal{B}(a_l, b_l)}. \end{align}

Next, let

$$r_i^{(v)} = \sum_{u \in a(v)} \left( s_u \boldsymbol{\gamma}_u^T \mathbf{x}_{i}^{(v)} + \boldsymbol{\theta}_u^T \mathbf{w}_i^{(v)} \right).$$

Then a lower bound for the likelihood contribution from a single matched pair is [@jaakkola2000bayesian]: \begin{align} \left(1 + \exp(-r_i^{(v)}) \right)^{-1} & \geq \sigma(\eta_i^{(v)}) \exp \left[ r_i^{(v)} - \left(r_i^{(v)} + \eta_i^{(v)} \right) /2 - g(\eta_i^{(v)}) \left(r_i^{(v)2} - \eta_i^{(v)2} \right) \right] \\ & := h\left(\eta_i^{(v)}, r_i^{(v)} \right), \end{align}

where $\sigma(x) = (1 + e^{-x})^{-1}$ and $g(x) = \dfrac{1}{2x}\left(\sigma(x) - 1/2\right)$. This inequality holds for all values of the introduced variational parameter $\eta_i^{(v)}$.
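The sketch below defines $\sigma(\cdot)$ and $g(\cdot)$ in R and checks numerically that the bound lies below the likelihood contribution and is tight when $\eta_i^{(v)} = r_i^{(v)}$; it is an illustration only, not code from the package.

# Sketch: the sigmoid sigma(), the function g(), and the lower bound h(eta, r)
sigma_fun <- function(x) 1 / (1 + exp(-x))
g_fun <- function(x) (sigma_fun(x) - 1 / 2) / (2 * x)
jj_bound <- function(r, eta) {
  sigma_fun(eta) * exp(r - (r + eta) / 2 - g_fun(eta) * (r^2 - eta^2))
}

r <- 0.7
sigma_fun(r) - jj_bound(r, eta = 2.0)  # positive: the bound lies below sigma(r)
sigma_fun(r) - jj_bound(r, eta = r)    # essentially zero: the bound is tight at eta = r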

Let $h\left(\boldsymbol{\eta}, \boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) = \prod_{v \in \mathcal{V}_L} \prod_{i=1}^{n_v} h\left(\eta_i^{(v)}, r_i^{(v)} \right)$. We get the following approximation to the joint density of data and parameters:

\begin{equation} f \left(\boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) \geq h\left(\boldsymbol{\eta}, \boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) \pi \left(\boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) \end{equation}

where $\pi \left(\boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right)$ is the prior. We ensure this lower bound is as tight as possible by maximizing it in $\boldsymbol{\eta}$ at every step.

Replacing the joint density of data and parameters in Equation 2 with the above lower bound gives the following approximation to the ELBO:

\begin{align} \mathcal{E}^*(q) & := E_{q} \left[\log h\left(\boldsymbol{\eta}, \boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) \pi \left(\boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) \right] - E_{q} \left[\log q \left(\boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) \right] \leq \mathcal{E}(q). \end{align}

Co-ordinate ascent variational inference algorithm

The CAVI algorithm involves maximizing $\mathcal{E}^*(q)$ in each of the factors $q\left(\boldsymbol{\gamma}_u, s_u \right)$, $q \left( \boldsymbol{\theta}_u \right)$, and $q(\rho_l)$ in turn. In what follows we use the subscript $t$ or superscript $(t)$ to denote the next iteration or update in the CAVI algorithm; the absence of a subscript indicates that the last or most recent value of that parameter should be used.
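Schematically, the algorithm cycles through the factor updates and stops when the change in the objective falls below a tolerance. The generic loop below is a sketch of this structure only, with placeholder update and ELBO functions supplied by the caller; it is not the implementation used in the package.

# Generic coordinate ascent skeleton (illustrative only).
# `updates` is a list of functions mapping the current approximation q to an
# updated q; `elbo` evaluates the objective for q.
cavi <- function(q, updates, elbo, tol = 1e-8, max_iter = 5000) {
  elbo_old <- -Inf
  for (t in seq_len(max_iter)) {
    for (f in updates) q <- f(q)    # update each factor in turn
    elbo_new <- elbo(q)
    if (abs(elbo_new - elbo_old) < tol) break  # converged
    elbo_old <- elbo_new
  }
  q
}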

The updates for each factor in the CAVI algorithm can be obtained using the following well-known result (see @blei2017variational for a derivation). As an example, consider the factor $q\left(\boldsymbol{\gamma}_u, s_u \right)$. Let $q_{-\left(\boldsymbol{\gamma}_u, s_u \right)}$ be the variational distribution marginalized over $(\boldsymbol{\gamma}_u, s_u)$. Then the $t^{th}$ coordinate ascent update for $q(\boldsymbol{\gamma}_u, s_u)$ is given by:

$$ q_t \left( \boldsymbol{\gamma}_u, s_u \right) \propto \exp \left( E_{q_{-\left(\boldsymbol{\gamma}_u, s_u \right)}} \left[ \log f \left(\boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) \right] \right).$$ However, here we replace the joint density $f$ with its quadratic approximation to get

$$q_t \left( \boldsymbol{\gamma}_u, s_u \right) \propto \exp \left( E_{q_{-\left(\boldsymbol{\gamma}_u, s_u \right)}} \left[ \log h\left(\boldsymbol{\eta}, \boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) \pi \left(\boldsymbol{\gamma}, \mathbf{s}, \boldsymbol{\theta}, \boldsymbol{\rho} \right) \right] \right).$$

The equivalent results hold for all other factors $q(\rho_l)$, $q(\boldsymbol{\theta}_u)$, and $q(\boldsymbol{\gamma}_v, s_v)$ for $v \neq u$.

Below, we give the CAVI updates for each factor. These updates can be straightforwardly derived from the above formula. In the moretrees package, these updates are coded in the internal function update_vi_params_logistic_moretrees().

Update for $q(\boldsymbol{\gamma}_u, s_u)$

$$ q_{t}\left( \boldsymbol{\gamma}_u, s_u \right) = \mathcal{MVN}\left(\boldsymbol{\gamma}_u \middle\vert s_u \boldsymbol{\mu}_u^{(t)}, s_u \Sigma_u^{(t)} + (1-s_u) \tilde{\tau}_u^{(t)} I_{k} \right) p_u^{(t)s_u} \left(1-p_u^{(t)}\right)^{1-s_u}$$

where

\begin{align} \left(\Sigma_u^{(t)}\right)^{-1} & = \dfrac{1}{\tau_{l_u}} I_{k} + 2 \sum_{v: u \in a(v)} \sum_{i = 1}^{n_v} g\left(\eta_i^{(v)}\right) \mathbf{x}_i^{(v)} \mathbf{x}_i^{(v)T} \\ \tilde{\tau}_u^{(t)} & = \tau_{l_u} \\ \boldsymbol{\mu}_u^{(t)} & = \Sigma_u^{(t)} \sum_{v: u \in a(v)} \sum_{i = 1}^{n_v} \mathbf{x}_i^{(v)} \left(\dfrac{1}{2} - 2 g \left(\eta_i^{(v)}\right)\left( \sum_{u^\prime \in a(v)} E_{q} \left[ \boldsymbol{\theta}_{u^\prime} \right]^T \mathbf{w}_i^{(v)} + \sum_{u^\prime \in a(v) \backslash \lbrace u \rbrace} E_{q} \left[ \boldsymbol{\gamma}_{u^\prime} s_{u^\prime} \right]^T \mathbf{x}_i^{(v)} \right) \right) \\ \log \left(\dfrac{p_u^{(t)}}{ 1- p_u^{(t)}} \right) & = E_{q} \left[ \log \rho_{l_u} \right] - E_{q} \left[ \log (1-\rho_{l_u})\right] + \dfrac{\boldsymbol{\mu}_u^{(t)T} \left(\Sigma_u^{(t)}\right)^{-1} \boldsymbol{\mu}_u^{(t)}}{2} + \dfrac{1}{2} \left( \log \left\vert \Sigma_u^{(t)} \right\vert - k \log \tilde{\tau}_u^{(t)} \right). \end{align}

Update for $q(\boldsymbol{\theta}_u)$

$$q_t\left( \boldsymbol{\theta}_u \right) = \mathcal{MVN}\left(\boldsymbol{\delta}_u^{(t)}, \Omega_u^{(t)} \right) $$ where

\begin{align} \left(\Omega_u^{(t)}\right)^{-1} & = \dfrac{1}{\omega_{l_u}} I_m + 2 \sum_{v: u \in a(v)} \sum_{i=1}^{n_v} g\left(\eta_i^{(v)}\right) \mathbf{w}_i^{(v)} \mathbf{w}_i^{(v)T} \\ \boldsymbol{\delta}_u^{(t)} & = \Omega_u^{(t)} \sum_{v: u \in a(v)} \sum_{i=1}^{n_v} \mathbf{w}_i^{(v)} \left( \dfrac{1}{2} - 2 g\left(\eta_i^{(v)}\right) \left( \sum_{u^\prime \in a(v) \backslash \lbrace u \rbrace} E_{q} \left[ \boldsymbol{\theta}_{u^\prime} \right]^T \mathbf{w}_i^{(v)} + \sum_{u^\prime \in a(v)} E_{q} \left[ \boldsymbol{\gamma}_{u^\prime} s_{u^\prime} \right]^T \mathbf{x}_i^{(v)} \right) \right). \end{align}

Update for $q(\rho_l)$

$$q_{t}(\rho_l) = Beta \left(\tilde{a}_l^{(t)}, \tilde{b}_l^{(t)} \right)$$

where

\begin{align} \tilde{a}_l^{(t)} & = a_l + \sum_{u: l_u = l} E_{q} \left[ s_u \right] \\ \tilde{b}_l^{(t)} & = b_l + \sum_{u: l_u = l} E_{q} \left[ 1 - s_u \right]. \end{align}
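This update only requires the current expected inclusion indicators. A minimal sketch, with hypothetical inputs:

# Sketch: Beta update for q(rho_l). E_s holds E_q[s_u] for all nodes u with l_u = l.
update_q_rho <- function(a_l, b_l, E_s) {
  c(a_tilde = a_l + sum(E_s),
    b_tilde = b_l + sum(1 - E_s))
}

update_q_rho(a_l = 1, b_l = 1, E_s = c(0.9, 0.1, 0.5))  # toy values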

Objective function (evidence lower bound)

At the end of every iteration $t$ we must compute the objective function $\mathcal{E}^*(q_t)$ for the current value of our variational approximation $q_t$. We iterate through parameter updates until convergence, i.e., until $\left\vert \mathcal{E}^*(q_t) - \mathcal{E}^*(q_{t-1}) \right\vert <$ tol.

\begin{align} \mathcal{E}^*(q) & = \sum_{v \in \mathcal{V}_L} \sum_{i = 1}^{n_v} \left( \dfrac{1}{2}E_q \left[ r_i^{(v)} \right] + \log \sigma\left(\eta_i^{(v)}\right) - \eta_i^{(v)} / 2 + g\left(\eta_i^{(v)}\right) \left(\eta_i^{(v)2} - E_q \left[r_i^{(v)2} \right] \right) \right) \tag{1} \\ & - \sum_{u \in \mathcal{V}}\left( \dfrac{E_q \left[\boldsymbol{\gamma}_u^T\boldsymbol{\gamma}_u \right]}{2\tau_{l_u}} + \dfrac{k}{2} \log(2 \pi \tau_{l_u})\right) \tag{2} \\ & + \sum_{u \in \mathcal{V}} \left( E_q \left[ \log \rho_{l_u} \right] E_q \left[ s_u \right] + E_q \left[ \log \left(1 - \rho_{l_u} \right) \right] \left(1-E_q \left[s_u\right] \right) \right) \tag{3} \\ & - \sum_{u \in \mathcal{V}} \left( \dfrac{E_q \left[\boldsymbol{\theta}_u^T \boldsymbol{\theta}_u \right]}{2 \omega_{l_u}} + \dfrac{m}{2} \log (2 \pi \omega_{l_u}) \right) \tag{4} \\ & + \sum_{l = 1}^L \left((a_l - 1) E_q \left[\log \rho_l \right] + (b_l - 1) E_q \left[\log (1 - \rho_l) \right] - \log \mathcal{B} (a_l, b_l) \right) \tag{5} \\ & + \dfrac{1}{2} \sum_{u \in \mathcal{V}} \left( E_q\left[s_u \left(\boldsymbol{\gamma}_u - \boldsymbol{\mu}_u \right)^T \Sigma_u^{-1} \left(\boldsymbol{\gamma}_u - \boldsymbol{\mu}_u \right) \right] + E_q\left[ s_u \right] \left( \log \left\vert \Sigma_u \right\vert + k \log (2\pi)\right) \right) \tag{6} \\ & + \dfrac{1}{2} \sum_{u \in \mathcal{V}} \left( \dfrac{1}{\tilde{\tau}_u} E_q\left[(1-s_u) \boldsymbol{\gamma}_u^T \boldsymbol{\gamma}_u \right] + k \log (2\pi \tilde{\tau}_u) E_q\left[ 1-s_u \right] \right) \tag{7} \\ & - \sum_{u \in \mathcal{V}} \left( E_q[s_u] \log(p_u) + E_q[1-s_u] \log \left( 1- p_u \right) \right) \tag{8} \\ & + \dfrac{1}{2} \sum_{u \in \mathcal{V}} \left( E_q \left[ \left(\boldsymbol{\theta}_u - \boldsymbol{\delta}_u \right)^T \Omega_u^{-1} \left(\boldsymbol{\theta}_u - \boldsymbol{\delta}_u \right) \right] + \log \left\vert \Omega_u \right\vert + m \log(2 \pi) \right) \tag{9} \\ & - \sum_{l = 1}^L \left((\tilde{a}_l - 1) E_q \left[\log \rho_l \right] + (\tilde{b}_l - 1) E_q \left[\log (1 - \rho_l) \right] - \log \mathcal{B} (\tilde{a}_l, \tilde{b}_l) \right), \tag{10} \end{align}

where $\mathcal{B}(\cdot,\cdot)$ is the Beta function. The expectations with respect to $q$ can be computed using the known expressions for each factor of $q$ given in the previous sections (recall that $r_i^{(v)}$ depends on the parameters $\boldsymbol{\gamma}$, $\mathbf{s}$, and $\boldsymbol{\theta}$). The numbered terms above correspond exactly to the numbering used in the internal function update_hyperparams_logistic_moretrees(), which both computes $\mathcal{E}^*(q)$ and performs the hyperparameter updates described in the next section.

Hyperparameter selection

The hyperparameters $\tau_l$ and $\omega_l$ for $l = 1, \dots, L$ are selected via approximate empirical Bayes. Specifically, every update_hyper_freq iterations, we maximize $\mathcal{E}^*(q)$, which can be viewed as an approximation to the marginal likelihood, in each hyperparameter. This gives the following hyperparameter updates:

\begin{align} \tau_l^{(t)} & = \dfrac{\sum_{u: l_u = l} E_{q}\left[ \boldsymbol{\gamma}_u^T \boldsymbol{\gamma}_u \right]}{\sum_{u: l_u = l} k} \\ \omega_l^{(t)} & = \dfrac{\sum_{u: l_u = l} E_{q} \left[ \boldsymbol{\theta}_u^T \boldsymbol{\theta}_u \right]}{\sum_{u: l_u = l} m}. \end{align}
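In code, each update is an average of expected squared norms over the nodes sharing level $l$. The sketch below uses hypothetical inputs and is not the package implementation.

# Sketch: empirical Bayes updates for tau_l and omega_l.
# E_gamma_sq and E_theta_sq hold E_q[gamma_u' gamma_u] and E_q[theta_u' theta_u]
# for the nodes u with l_u = l; k and m are the exposure and covariate dimensions.
update_tau   <- function(E_gamma_sq, k) sum(E_gamma_sq) / (length(E_gamma_sq) * k)
update_omega <- function(E_theta_sq, m) sum(E_theta_sq) / (length(E_theta_sq) * m)

update_tau(E_gamma_sq = c(0.8, 1.2, 0.5), k = 2)  # toy values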

This convenient technique has been used elsewhere [@blei2003latent; @khan2010variational]. It is particularly useful in the case of estimating the variance parameters $\tau_l$ and $\omega_l$, since in the absence of prior information, specifying good non-informative priors for variance parameters, such as those recommended by @gelman2006prior, would require breaking the conditional conjugacy that leads to closed-form updates in the VI algorithm.

Ideally, one would find the optimal VI parameters after every hyperparameter update by allowing the algorithm to converge while keeping the hyperparameters fixed. The hyperparameters would then be updated again, and this process would be repeated until convergence of both the VI parameters and hyperparameters. However, in practice this can be too time consuming; as an alternative, we set a maximum of update_hyper_freq VI updates between hyperparameter updates. The default value is update_hyper_freq = 50; if the algorithm is converging slowly, either increasing or decreasing this value may help.

We also allow a different convergence tolerance, tol_hyper, when comparing $\mathcal{E}^*(q)$ between consecutive hyperparameter updates. The algorithm is considered to have converged when the difference in $\mathcal{E}^*(q)$ between consecutive hyperparameter updates is less than tol_hyper \emph{and} the difference between the last two variational parameter updates is less than tol. If our aim in selecting hyperparameters is merely to identify 'good enough' values, it may be acceptable to set tol_hyper $>$ tol. Because convergence of the hyperparameters can be slow, this also allows us to speed up convergence of the algorithm overall.
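The unevaluated sketch below shows how these convergence controls might be passed to moretrees(); the argument names follow those described above, and the values (other than tol, whose default is stated earlier) are illustrative rather than defaults. Consult the package documentation for the authoritative argument list.

# Not evaluated: adjusting convergence controls (illustrative values)
mod_tuned <- moretrees::moretrees(Xcase = Xcase, Xcontrol = Xcontrol,
                                  outcomes = outcomes, tr = tr,
                                  tol = 1e-8, tol_hyper = 1e-4,
                                  update_hyper_freq = 50, max_iter = 10000)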

References


