# This file is part of the `locus` R package:
# https://github.com/hruffieux/locus
#
# Internal functions gathering the variational updates for the core algorithms.
# Besides improving code readability via modular programming, the main purpose
# is to avoid copy-and-paste programming, as most of these updates (or slightly
# modified versions) are used more than once in the different core algorithms.
# For this reason, we choose to create functions for most variational updates,
# even for those consisting of very basic operations.
# Note that we don't modularize the body of the core for loops for performance
# reasons.
#####################
## alpha's updates ##
#####################
# Second moment of alpha under the variational posterior: variance + mean^2.
# alpha_vb:      matrix of variational means.
# sig2_alpha_vb: variances; with sweep = TRUE they are added row-wise (one
#                value per row of alpha_vb), otherwise by plain arithmetic.
update_m2_alpha_ <- function(alpha_vb, sig2_alpha_vb, sweep = FALSE) {
  alpha_sq <- alpha_vb^2
  if (!sweep) {
    return(sig2_alpha_vb + alpha_sq)
  }
  sweep(alpha_sq, 1, sig2_alpha_vb, `+`)
}
# Variational variance of alpha: 1 / (c * (n - 1 + zeta2_inv_vb) [x tau_vb]).
# n - 1 is the squared norm of a scaled column of Z. When `intercept` is TRUE,
# the first column of Z is the unscaled intercept, whose squared norm is n.
# If tau_vb is supplied, the result is the outer product (one column per
# entry of tau_vb); otherwise it is a vector. c is the annealing factor.
update_sig2_alpha_vb_ <- function(n, zeta2_inv_vb, tau_vb = NULL, intercept = FALSE, c = 1) {
  precision <- n - 1 + zeta2_inv_vb
  if (intercept) {
    precision[1] <- precision[1] + 1
  }
  scaled_precision <- if (is.null(tau_vb)) {
    c * precision
  } else {
    c * tcrossprod(precision, as.matrix(tau_vb))
  }
  1 / scaled_precision
}
# Variational variance of alpha for the logistic model:
# 1 / (2 * t(Z^2) %*% psi_vb + zeta2_inv_vb), where zeta2_inv_vb is added
# row-wise (one value per column of Z).
update_sig2_alpha_logit_vb_ <- function(Z, psi_vb, zeta2_inv_vb) {
  lik_precision <- 2 * crossprod(Z^2, psi_vb)
  total_precision <- sweep(lik_precision, 1, zeta2_inv_vb, `+`)
  1 / total_precision
}
# Linear predictor contribution of the covariates Z: Z %*% alpha_vb.
update_mat_z_mu_ <- function(Z, alpha_vb) {
  Z %*% alpha_vb
}
####################
## beta's updates ##
####################
# Variational mean of beta: inclusion probability times slab mean (elementwise).
update_beta_vb_ <- function(gam_vb, mu_beta_vb) {
  gam_vb * mu_beta_vb
}
# Group-wise variational means of beta.
# list_mu_beta_vb: list with one matrix of slab means per group g.
# gam_vb:          matrix whose row g holds group g's inclusion probabilities,
#                  one per column of list_mu_beta_vb[[g]].
# Returns a list of the same length, where column k of element g is scaled by
# gam_vb[g, k]. An empty input list yields an empty list.
update_g_beta_vb_ <- function(list_mu_beta_vb, gam_vb) {
  # seq_along() instead of 1:G so that zero groups give list() rather than
  # iterating over indices 1 and 0.
  lapply(seq_along(list_mu_beta_vb), function(g) {
    sweep(list_mu_beta_vb[[g]], 2, gam_vb[g, ], `*`)
  })
}
# Second moment of beta under the spike-and-slab variational posterior:
# gam_vb * (mu_beta_vb^2 + sig2_beta_vb).
# With sweep = TRUE, sig2_beta_vb holds one variance per column of mu_beta_vb;
# otherwise it is combined by plain (recycling) arithmetic.
update_m2_beta_ <- function(gam_vb, mu_beta_vb, sig2_beta_vb, sweep = FALSE) {
  mu_sq <- mu_beta_vb^2
  slab_m2 <- if (sweep) {
    sweep(mu_sq, 2, sig2_beta_vb, `+`)
  } else {
    mu_sq + sig2_beta_vb
  }
  slab_m2 * gam_vb
}
# Variational variance of the beta slab: 1 / (c * (n - 1 + sig2_inv_vb) [* tau_vb]).
# c is the annealing factor; tau_vb, when given, multiplies the precision.
update_sig2_beta_vb_ <- function(n, sig2_inv_vb, tau_vb = NULL, c = 1) {
  precision <- c * (n - 1 + sig2_inv_vb)
  if (!is.null(tau_vb)) {
    precision <- precision * tau_vb
  }
  1 / precision
}
# Variational variance of the beta slab for the logistic model:
# 1 / (2 * t(X^2) %*% psi_vb + sig2_inv_vb).
update_sig2_beta_logit_vb_ <- function(X, psi_vb, sig2_inv_vb) {
  weighted_sq <- crossprod(X^2, psi_vb)
  1 / (2 * weighted_sq + sig2_inv_vb)
}
# Linear predictor contribution of the candidate predictors X: X %*% beta_vb.
update_mat_x_m1_ <- function(X, beta_vb) {
  X %*% beta_vb
}
# Sum over groups of the linear predictors list_X[[g]] %*% list_beta_vb[[g]].
# Iterates over the groups of list_X; returns NULL when there are no groups
# (Reduce() of an empty list).
update_g_mat_x_m1_ <- function(list_X, list_beta_vb) {
  # seq_along() instead of 1:G so that zero groups do not index 1 and 0.
  per_group <- lapply(seq_along(list_X), function(g) list_X[[g]] %*% list_beta_vb[[g]])
  Reduce(`+`, per_group)
}
# Group-wise expectation of beta_gk^T beta_gk, one vector (over k) per group g.
# gam_vb:              matrix, row g = inclusion probabilities of group g.
# list_mu_beta_vb:     list, element g = slab means (one column per k).
# list_sig2_beta_star: list, element g = slab covariance before division by
#                      tau_vb (NOT its inverse).
# tau_vb:              vector with one entry per k.
update_g_m1_btb_ <- function(gam_vb, list_mu_beta_vb, list_sig2_beta_star, tau_vb) { ## not list_sig2_beta_star_inv!
  lapply(seq_along(list_mu_beta_vb), function(g) {
    gam_g <- gam_vb[g, ]
    # colSums(A^2) = diag(crossprod(A)); entry k is tr(mu_gk mu_gk^T) = sum(mu_gk^2).
    mu_sq_norms <- colSums(list_mu_beta_vb[[g]]^2)
    tr_sig2 <- sum(diag(as.matrix(list_sig2_beta_star[[g]])))
    # The former per-k sapply computed exactly (1 - gam_gk) * mu_sq_norms[k].
    gam_g^2 * mu_sq_norms +
      gam_g * (tr_sig2 / tau_vb + (1 - gam_g) * mu_sq_norms)
  })
}
# Group-wise expectation of beta_gk^T X_g^T X_g beta_gk, one vector (over k)
# per group g. Arguments as in update_g_m1_btb_, plus list_X (one design
# matrix per group).
update_g_m1_btXtXb_ <- function(list_X, gam_vb, list_mu_beta_vb, list_sig2_beta_star, tau_vb) {
  lapply(seq_along(list_mu_beta_vb), function(g) {
    X_g <- list_X[[g]]
    gam_g <- gam_vb[g, ]
    # Entry k: ||X_g mu_gk||^2 = mu_gk^T X_g^T X_g mu_gk
    #        = sum(crossprod(X_g) * tcrossprod(mu_gk)).
    # Computing it once for all k avoids rebuilding the p x p crossprod
    # inside a per-k loop.
    x_mu_sq_norms <- colSums((X_g %*% list_mu_beta_vb[[g]])^2)
    # tr(X_g^T X_g Sigma_g) via tr(AB^T) = sum_ij A_ij B_ij (B symmetric).
    tr_xtx_sig2 <- sum(crossprod(X_g) * list_sig2_beta_star[[g]])
    gam_g^2 * x_mu_sq_norms +
      gam_g * (tr_xtx_sig2 / tau_vb + (1 - gam_g) * x_mu_sq_norms)
  })
}
########################
## c0 and c's updates ##
########################
# Variational variance 1 / (c * (d + 1 / s02)), with annealing factor c.
update_sig2_c0_vb_ <- function(d, s02, c = 1) {
  precision <- d + (1 / s02)
  1 / (c * precision)
}
###################
## chi's updates ##
###################
# Variational parameter chi: square root of the expected squared linear
# predictor X beta + Z alpha. The variance parts use X^2 %*% (m2 - mean^2)
# and Z^2 %*% sig2, i.e. they assume independent coordinates.
update_chi_vb_ <- function(X, Z, beta_vb, m2_beta, mat_x_m1, mat_z_mu, sig2_alpha_vb) {
  X_sq <- X^2
  expected_sq <- X_sq %*% m2_beta + mat_x_m1^2 - X_sq %*% beta_vb^2 +
    Z^2 %*% sig2_alpha_vb + mat_z_mu^2 + 2 * mat_x_m1 * mat_z_mu
  sqrt(expected_sq)
}
# Variational quantity psi = (sigmoid(chi) - 1/2) / (2 * chi) for the
# logistic model.
# Uses the identity sigmoid(x) - 1/2 = tanh(x / 2) / 2, so that
# psi = tanh(chi / 2) / (4 * chi). This avoids the cancellation in
# sigmoid(chi) - 1/2 for small chi and the NaN (log(0) - log(0)) at chi = 0,
# where the continuous limit 1/8 is used. Dimensions of chi_vb are kept.
update_psi_logit_vb_ <- function(chi_vb) {
  psi_vb <- tanh(chi_vb / 2) / (4 * chi_vb)
  psi_vb[which(chi_vb == 0)] <- 1 / 8
  psi_vb
}
#####################
## omega's updates ##
#####################
# First shape parameter of omega's variational distribution, with annealing
# factor c: c * (a + rs_gam) - c + 1. Also exported under the alias a_vb.
update_a_vb <- function(a, rs_gam, c = 1) {
  shape <- a + rs_gam
  c * shape - c + 1
}
a_vb <- update_a_vb
# Second shape parameter of omega's variational distribution, with annealing
# factor c: c * (b - rs_gam + d) - c + 1. Also exported under the alias b_vb.
update_b_vb <- function(b, d, rs_gam, c = 1) {
  shape <- b - rs_gam + d
  c * shape - c + 1
}
b_vb <- update_b_vb
# Expected log(omega): digamma of the annealed first shape parameter minus
# digam_sum (supplied by the caller).
update_log_om_vb <- function(a, digam_sum, rs_gam, c = 1) {
  shape <- c * (a + rs_gam) - c + 1
  digamma(shape) - digam_sum
}
# Expected log(1 - omega): digamma of the annealed second shape parameter
# minus digam_sum (supplied by the caller).
update_log_1_min_om_vb <- function(b, d, digam_sum, rs_gam, c = 1) {
  shape <- c * (b - rs_gam + d) - c + 1
  digamma(shape) - digam_sum
}
#####################
## sigma's updates ##
#####################
# Annealed shape parameter for sigma: c * (lambda + sum_gam / 2) - c + 1.
update_lambda_vb_ <- function(lambda, sum_gam, c = 1) {
  shape <- lambda + sum_gam / 2
  c * shape - c + 1
}
# Group version of the sigma shape update: each group's rs_gam entry is
# weighted by its group size.
update_g_lambda_vb_ <- function(lambda, g_sizes, rs_gam) {
  weighted_total <- sum(g_sizes * rs_gam)
  lambda + weighted_total / 2
}
# Annealed rate parameter for sigma:
# c * (nu + sum_k tau_vb[k] * colSums(m2_beta)[k] / 2), returned as a plain
# numeric (the 1 x 1 crossprod result is dropped to a vector).
update_nu_vb_ <- function(nu, m2_beta, tau_vb, c = 1) {
  col_m2 <- colSums(m2_beta)
  rate <- nu + crossprod(tau_vb, col_m2) / 2
  c * as.numeric(rate)
}
# Group version of the sigma rate update: the per-group expectations in
# list_m1_btb are summed, then weighted by tau_vb.
update_g_nu_vb_ <- function(nu, list_m1_btb, tau_vb) {
  total_btb <- Reduce(`+`, list_m1_btb)
  nu + sum(tau_vb * total_btb) / 2
}
# Sigma rate update for the binary-response models: nu + sum(m2_beta) / 2.
update_nu_bin_vb_ <- function(nu, m2_beta) {
  total_m2 <- sum(m2_beta)
  nu + total_m2 / 2
}
# Expected log precision: digamma(shape) - log(rate).
update_log_sig2_inv_vb_ <- function(lambda_vb, nu_vb) {
  digamma(lambda_vb) - log(nu_vb)
}
###################
## tau's updates ##
###################
# Annealed shape parameter for tau, one entry per column of gam_vb:
# c * (eta + n / 2 + colSums(gam_vb) / 2) - c + 1.
update_eta_vb_ <- function(n, eta, gam_vb, c = 1) {
  shape <- eta + n / 2 + colSums(gam_vb) / 2
  c * shape - c + 1
}
# Group version of the tau shape update: the inclusion probabilities in
# gam_vb are weighted by the group sizes.
update_g_eta_vb_ <- function(n, eta, g_sizes, gam_vb) {
  size_weighted_gam <- as.numeric(crossprod(gam_vb, g_sizes))
  eta + n / 2 + size_weighted_gam / 2
}
# Annealed shape parameter for tau when q additional covariates Z are
# included: the same as update_eta_vb_ plus q / 2.
update_eta_z_vb_ <- function(n, q, eta, gam_vb, c = 1) {
  shape <- eta + n / 2 + colSums(gam_vb) / 2 + q / 2
  c * shape - c + 1
}
# Annealed rate parameter for tau, one entry per column of Y.
# The (n - 1) factors correspond to the squared norm of the scaled columns of X.
update_kappa_vb_ <- function(Y, kappa, mat_x_m1, beta_vb, m2_beta, sig2_inv_vb, c = 1) {
  n_minus_1 <- nrow(Y) - 1
  expected_ss <- colSums(Y^2) - 2 * colSums(Y * mat_x_m1) +
    (n_minus_1 + sig2_inv_vb) * colSums(m2_beta) +
    colSums(mat_x_m1^2) - n_minus_1 * colSums(beta_vb^2)
  c * (kappa + expected_ss / 2)
}
# Group version of the tau rate update, one entry per column of Y.
# list_m1_btXtXb / list_m1_btb: per-group expectations (as returned by
# update_g_m1_btXtXb_ / update_g_m1_btb_); they are summed across groups.
update_g_kappa_vb_ <- function(Y, list_X, kappa, list_beta_vb, list_m1_btb,
                               list_m1_btXtXb, mat_x_m1, sig2_inv_vb) {
  # Sum over groups of ||X_g beta_g||^2 per response; seq_along() avoids
  # iterating over indices 1 and 0 when there are no groups.
  sum_x_beta_sq <- Reduce(`+`, lapply(seq_along(list_beta_vb), function(g) {
    colSums((list_X[[g]] %*% list_beta_vb[[g]])^2)
  }))
  # avoid using do.call() as can trigger node stack overflow
  kappa + (colSums(Y^2) - 2 * colSums(Y * mat_x_m1) +
             Reduce(`+`, list_m1_btXtXb) +
             sig2_inv_vb * Reduce(`+`, list_m1_btb) +
             colSums(mat_x_m1^2) -
             sum_x_beta_sq) / 2
}
# Annealed rate parameter for tau when covariates Z are included.
# The (n - 1) factors correspond to scaled columns. When `intercept` is TRUE,
# the first column of Z is the unscaled intercept (squared norm n), so its
# variance term m2_alpha[1, ] - alpha_vb[1, ]^2 is added once more.
# Note: the crossprod() term makes the result a one-column matrix.
update_kappa_z_vb_ <- function(Y, Z, kappa, alpha_vb, beta_vb, m2_alpha,
                               m2_beta, mat_x_m1, mat_z_mu, sig2_inv_vb,
                               zeta2_inv_vb, intercept = FALSE, c = 1) {
  n_minus_1 <- nrow(Y) - 1
  expected_ss <- colSums(Y^2) - 2 * colSums(Y * (mat_x_m1 + mat_z_mu)) +
    (n_minus_1 + sig2_inv_vb) * colSums(m2_beta) +
    colSums(mat_x_m1^2) - n_minus_1 * colSums(beta_vb^2) +
    n_minus_1 * colSums(m2_alpha) +
    crossprod(m2_alpha, zeta2_inv_vb) +
    colSums(mat_z_mu^2) - n_minus_1 * colSums(alpha_vb^2) +
    2 * colSums(mat_x_m1 * mat_z_mu)
  kappa_vb <- c * (kappa + expected_ss / 2)
  if (!intercept) {
    return(kappa_vb)
  }
  kappa_vb + c * (m2_alpha[1, ] - alpha_vb[1, ]^2) / 2
}
# Expected log tau: digamma(shape) - log(rate).
update_log_tau_vb_ <- function(eta_vb, kappa_vb) {
  digamma(eta_vb) - log(kappa_vb)
}
#####################
## theta's updates ##
#####################
# Variational mean of theta, one entry per row of W.
# W:             matrix; its row sums enter the update.
# m0:            prior mean, one entry per row of W.
# S0_inv, sig2_theta_vb:
#                with vec_fac_st = NULL, presumably scalars standing for
#                scaled identity matrices (as produced by update_sig2_theta_vb_
#                when list_struct is NULL); otherwise lists with one matrix
#                per block, in the order of unique(vec_fac_st).
# vec_fac_st:    block label of each row of W, or NULL (no block structure).
# mat_add:       term subtracted inside the update: row sums when is_mat is
#                TRUE, otherwise its total sum.
# c:             annealing factor (only supported without block structure).
# Block results are written back to the rows they belong to, so the output
# stays aligned with W even when block members are not contiguous.
update_theta_vb_ <- function(W, m0, S0_inv, sig2_theta_vb, vec_fac_st,
                             mat_add = 0, is_mat = FALSE, c = 1) {
  if (is.null(vec_fac_st)) {
    add_term <- if (is_mat) rowSums(mat_add) else sum(mat_add)
    return(c * sig2_theta_vb * (rowSums(W) + S0_inv * m0 - add_term))
  }

  if (c != 1)
    stop("Annealing not implemented when Sigma_0 is not the identity matrix.")

  bl_ids <- unique(vec_fac_st)
  theta_vb <- numeric(length(vec_fac_st))
  for (bl in seq_along(bl_ids)) {
    in_bl <- vec_fac_st == bl_ids[bl]
    # mat_add = sweep(mat_v_mu, 1, theta_vb, `-`) when is_mat is TRUE
    add_term <- if (is_mat) rowSums(mat_add[in_bl, , drop = FALSE]) else sum(mat_add)
    theta_vb[in_bl] <- as.numeric(
      sig2_theta_vb[[bl]] %*% (rowSums(W[in_bl, , drop = FALSE]) +
                                 S0_inv[[bl]] %*% m0[in_bl] -
                                 add_term)
    )
  }
  theta_vb
}
# Prior precision, variational covariance and log-determinant terms for theta.
# Without block structure (list_struct = NULL), S0_inv and sig2_theta_vb are
# scalars standing for scaled identity matrices of size p.
# With block structure, list_struct$vec_fac_st gives the block of each column
# of X. Per block, the prior precision is the inverse of the (nearPD-
# regularized) correlation matrix of X divided by s02, and the covariance is
# solve(S0_inv_bl + d * I). Blocks are processed with parallel::mclapply()
# on list_struct$n_cpus cores.
# Returns a named list (S0_inv, sig2_theta_vb, vec_sum_log_det_theta, vec_fac_st).
update_sig2_theta_vb_ <- function(d, p, list_struct, s02, X = NULL, c = 1) {
  if (is.null(list_struct)) {
    vec_fac_st <- NULL
    S0_inv <- 1 / s02 # stands for a diagonal matrix of size p with this value on the (constant) diagonal
    sig2_theta_vb <- as.numeric(update_sig2_c0_vb_(d, s02, c = c)) # idem
    vec_sum_log_det_theta <- - p * (log(s02) + log(d + S0_inv))
  } else {
    if (c != 1)
      stop("Annealing not implemented when Sigma_0 is not the identity matrix.")
    if (is.null(X))
      stop("X must be passed to update_sig2_theta_vb_.")
    vec_fac_st <- list_struct$vec_fac_st
    n_cpus <- list_struct$n_cpus
    S0_inv <- parallel::mclapply(unique(vec_fac_st), function(bl) {
      corX <- cor(X[, vec_fac_st == bl, drop = FALSE])
      corX <- as.matrix(Matrix::nearPD(corX, corr = TRUE, do2eigen = TRUE)$mat) # regularization in case of non-positive definiteness.
      as.matrix(solve(corX) / s02)
    }, mc.cores = n_cpus)
    # With mc.cores > 1, mclapply() returns failed elements as "try-error"
    # objects instead of stopping; fail here with a clear message.
    if (any(vapply(S0_inv, inherits, logical(1), what = "try-error")))
      stop("Computation of the block-wise prior precision matrices failed.")
    # mclapply() always returns a list, so the block covariances are matrices.
    sig2_theta_vb <- parallel::mclapply(S0_inv, function(mat) {
      as.matrix(solve(mat + diag(d, nrow(mat))))
    }, mc.cores = n_cpus)
    vec_sum_log_det_theta <- log_det(S0_inv) + log_det(sig2_theta_vb) # vec_sum_log_det_theta[bl] = log(det(S0_inv_bl)) + log(det(sig2_theta_vb_bl))
  }
  create_named_list_(S0_inv, sig2_theta_vb, vec_sum_log_det_theta, vec_fac_st)
}
#################
## W's updates ##
#################
# Expected latent variable for the probit model: the linear predictor
# Z alpha + X beta shifted by the inverse Mills ratio term computed by
# inv_mills_ratio_matrix_() from Y and that predictor.
update_W_probit_ <- function(Y, mat_z_mu, mat_x_m1) {
  lin_pred <- mat_z_mu + mat_x_m1
  lin_pred + inv_mills_ratio_matrix_(Y, lin_pred)
}
# Expected auxiliary variable W for the structured (probit-link) prior on the
# inclusion probabilities. Row j is computed as
# theta_vb[j] + imr0[j] + gam_vb[j, ] * (imr1[j] - imr0[j]),
# where imr0 / imr1 are inverse Mills ratios from inv_mills_ratio_() for
# outcomes 0 / 1, evaluated on the log-probabilities of pnorm(theta_vb).
update_W_struct_ <- function(gam_vb, theta_vb) {
  log_p <- pnorm(theta_vb, log.p = TRUE)
  log_1_min_p <- pnorm(theta_vb, log.p = TRUE, lower.tail = FALSE)
  imr_zero <- inv_mills_ratio_(0, theta_vb, log_1_min_p, log_p)
  imr_one <- inv_mills_ratio_(1, theta_vb, log_1_min_p, log_p)
  scaled_gam <- sweep(gam_vb, 1, imr_one - imr_zero, `*`)
  sweep(scaled_gam, 1, theta_vb + imr_zero, `+`)
}
####################
## zeta's updates ##
####################
# Annealed shape parameter for zeta: c * (phi + d / 2) - c + 1.
update_phi_z_vb_ <- function(phi, d, c = 1) {
  shape <- phi + d / 2
  c * shape - c + 1
}
# Annealed rate parameter for zeta: c * (xi + (m2_alpha %*% tau_vb) / 2),
# one row per row of m2_alpha (returned as a one-column matrix).
update_xi_z_vb_ <- function(xi, tau_vb, m2_alpha, c = 1) {
  weighted_m2 <- m2_alpha %*% tau_vb
  c * (xi + weighted_m2 / 2)
}
# Zeta rate update for the binary-response models: xi + rowSums(m2_alpha) / 2.
update_xi_bin_vb_ <- function(xi, m2_alpha) {
  row_m2 <- rowSums(m2_alpha)
  xi + row_m2 / 2
}
# Expected log precision of alpha: digamma(shape) - log(rate).
update_log_zeta2_inv_vb_ <- function(phi_vb, xi_vb) {
  digamma(phi_vb) - log(xi_vb)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.