#' CAVI for the PLVM
#'
#' Coordinate Ascent Variational Inference for the Phylogenetic Latent Variable Model
#'
#' Starting from the supplied model, each iteration updates, in order: the
#' within-taxon and heritable amplitudes (and the derived phylogenetic GP
#' parameterisation) for every latent dimension, the ARD precision, the trait
#' precision, the ordinal cut-off points, the auxiliary traits, the loading,
#' the individual-specific latent traits and the taxon-specific latent traits.
#' The ELBO is recorded after every iteration. Iteration stops when the
#' absolute change in the ELBO between successive iterations falls below `tol`
#' or after `max_iter` iterations, whichever comes first.
#'
#' @param plvm_list A list defining the PLVM generated by `initialise_plvm`.
#' @param tol A positive real-valued scalar. The absolute change in the ELBO between
#'   successive iterations below which the procedure is deemed to have converged.
#' @param max_iter A positive integer. The maximum number of iterations that the optimisation procedure will run.
#' @param n_samples A positive integer. The number of samples drawn to estimate expected auxiliary trait values.
#' @param random_seed The seed for random number generation.
#' @param progress_bar Logical. Should a progress bar be displayed?
#' @inheritParams initialise_plvm
#'
#' @seealso initialise_plvm
#'
#' @return A list with two elements: `model`, the optimised PLVM, and `elbo`, a
#'   numeric vector whose first element is `-Inf`, whose second element is the
#'   ELBO of the initial model and whose remaining elements are the ELBO after
#'   each completed iteration.
#' @export
cavi_plvm <- function(
    plvm_list,
    tol = 1e-6, max_iter = 1000,
    n_samples = 1000, random_seed = NULL,
    progress_bar = TRUE,
    perform_checks = TRUE
) {
  mod <- plvm_list
  N <- nrow(mod$manifest_trait_df)
  S <- length(mod$phy$tip.label)
  L <- ncol(mod$loading_expectation)
  D_prime <- sum(lengths(mod$metadata$auxiliary_trait_index))
  ord_index <- which(mod$metadata$trait_type == "ord")

  # The leading -Inf means the first recorded difference is always infinite.
  elbo <- c(
    -Inf,
    compute_plvm_elbo(
      mod,
      n_samples = n_samples, random_seed = random_seed,
      perform_checks = perform_checks
    )
  )

  if (progress_bar) {
    pb <- utils::txtProgressBar(min = 0, max = max_iter, style = 3)
    # Close the bar even if one of the updates below throws.
    on.exit(close(pb), add = TRUE)
  }

  i <- 1
  while (i <= max_iter) {
    # Phylogenetic GP (Satisfied that this section works)
    for (l in seq_len(L)) {
      # Within-taxon
      tmp_wta <- stats::optim(
        par = mod$within_taxon_amplitude[l], fn = within_taxon_amplitude_objective,
        i = l,
        individual_specific_latent_trait_expectation = mod$individual_specific_latent_trait_expectation,
        taxon_id = mod$manifest_trait_df[, mod$id_label], phy = mod$phy,
        terminal_taxon_specific_latent_trait_expectation =
          mod$taxon_specific_latent_trait_expectation[seq_len(S), , drop = FALSE],
        individual_specific_latent_trait_covariance = mod$individual_specific_latent_trait_covariance,
        individual_specific_latent_trait_outer_product_expectation =
          mod$individual_specific_latent_trait_outer_product_expectation,
        terminal_taxon_latent_trait_outer_product_expectation =
          mod$taxon_specific_latent_trait_outer_product_expectation[, , seq_len(S), drop = FALSE],
        within_taxon_amplitude = mod$within_taxon_amplitude,
        perform_checks = FALSE,
        method = "Brent", lower = 0, upper = 1
      )
      mod$within_taxon_amplitude[l] <- tmp_wta$par
      # Heritable
      tmp_ha <- stats::optim(
        par = mod$heritable_amplitude[l], fn = heritable_amplitude_objective,
        i = l,
        heritable_amplitude = mod$heritable_amplitude, length_scale = mod$length_scale,
        taxon_specific_latent_trait_expectation = mod$taxon_specific_latent_trait_expectation,
        taxon_specific_latent_trait_outer_product_expectation =
          mod$taxon_specific_latent_trait_outer_product_expectation,
        taxon_specific_latent_trait_covariance = mod$taxon_specific_latent_trait_covariance,
        phy = mod$phy,
        phylogenetic_gp = mod$phylogenetic_GP,
        perform_checks = FALSE,
        method = "Brent", lower = 0, upper = 1
      )
      mod$heritable_amplitude[l] <- tmp_ha$par
      # The environmental amplitude is set so that the squared heritable and
      # environmental amplitudes sum to one.
      mod$phylogenetic_GP[, , l] <- reparameterise_phylogenetic_ou(
        phy = mod$phy,
        heritable_amplitude = mod$heritable_amplitude[l],
        length_scale = mod$length_scale,
        environmental_amplitude = sqrt(1 - mod$heritable_amplitude[l]^2),
        perform_checks = FALSE
      )
    }

    ## ARD
    mod$ard_precision <- update_loading_ard_precision(
      loading_col_outer_product_expectation = mod$loading_col_outer_product_expectation,
      inv_loading_prior_correlation = mod$inv_loading_prior_correlation,
      ard_shape = 0, ard_rate = 0
    )

    # Precision
    mod$precision <- update_precision(
      precision = mod$precision,
      metadata = mod$metadata,
      auxiliary_traits = mod$auxiliary_traits,
      loading_expectation = mod$loading_expectation,
      latent_trait_expectation = mod$individual_specific_latent_trait_expectation,
      loading_outer_expectation = mod$loading_row_outer_product_expectation,
      latent_trait_outer_expectation = mod$individual_specific_latent_trait_outer_product_expectation
    )
    mod$precision_vector <- map_precision_to_auxiliary_traits(
      precision = mod$precision,
      auxiliary_trait_index = mod$metadata$auxiliary_trait_index,
      perform_checks = FALSE
    )

    # ordinal trait cut off points
    for (j in ord_index) {
      n_levels <- mod$metadata$trait_levels[j]
      # `3:n_levels` would count backwards for fewer than three levels; the
      # loop starts at index 3, so such traits have nothing to optimise here.
      if (n_levels < 3) next
      for (k in 3:n_levels) {
        # Each cut-off is searched between its lower neighbour and the smaller
        # of its upper neighbour and its current value plus 10.
        tmp_co <- stats::optim(
          par = mod$metadata$cut_off_points[[j]][k], fn = ordinal_trait_cut_off_objective,
          i = k,
          y = mod$manifest_trait_df[, mod$metadata$manifest_trait_index[[j]]],
          cut_off_points = mod$metadata$cut_off_points[[j]],
          loading_expectation = mod$loading_expectation[mod$metadata$auxiliary_trait_index[[j]], ],
          latent_trait_expectation = mod$individual_specific_latent_trait_expectation,
          loading_outer_expectation =
            mod$loading_row_outer_product_expectation[, , mod$metadata$auxiliary_trait_index[[j]]],
          latent_trait_outer_expectation = mod$individual_specific_latent_trait_outer_product_expectation,
          perform_checks = FALSE,
          method = "Brent",
          lower = mod$metadata$cut_off_points[[j]][k - 1],
          upper = min(
            mod$metadata$cut_off_points[[j]][k + 1],
            mod$metadata$cut_off_points[[j]][k] + 10
          )
        )
        mod$metadata$cut_off_points[[j]][k] <- tmp_co$par
      }
    }

    # auxiliary traits
    mod$auxiliary_traits <- update_discrete_auxiliary_traits(
      manifest_trait_df = mod$manifest_trait_df,
      metadata = mod$metadata,
      auxiliary_traits = mod$auxiliary_traits,
      loading_expectation = mod$loading_expectation,
      latent_trait_expectation = mod$individual_specific_latent_trait_expectation,
      n_samples = n_samples, random_seed = random_seed,
      perform_checks = FALSE
    )

    # Loading
    mod$loading_row_precision <- simplify2array(compute_loading_row_precision_list(
      total_individual_specific_latent_trait_outer_product_expectation =
        apply(mod$individual_specific_latent_trait_outer_product_expectation, c(1, 2), sum),
      precision_vector = mod$precision_vector,
      ard_precision = mod$ard_precision,
      scaled_conditional_row_variance_vector = mod$scaled_conditional_loading_row_variance_vector,
      perform_checks = FALSE
    ))
    # Re-impose the 3-d shape, which simplify2array() does not guarantee.
    mod$loading_row_precision <- array(mod$loading_row_precision, dim = c(L, L, D_prime))
    mod$loading_row_covariance <- simplify2array(lapply(
      seq_len(D_prime),
      function(d) {
        chol2inv(chol(as.matrix(mod$loading_row_precision[, , d])))
      }
    ))
    mod$loading_row_covariance <- array(mod$loading_row_covariance, dim = c(L, L, D_prime))
    mod$loading_expectation <- compute_loading_expectation(
      current_loading_expectation = mod$loading_expectation,
      loading_row_precision = mod$loading_row_precision,
      auxiliary_traits = mod$auxiliary_traits,
      latent_trait_expectation = mod$individual_specific_latent_trait_expectation,
      precision_vector = mod$precision_vector,
      ard_precision = mod$ard_precision,
      scaled_conditional_row_variance_vector = mod$scaled_conditional_loading_row_variance_vector,
      loading_row_conditional_mean_weight = mod$loading_row_conditional_mean_weight,
      perform_checks = FALSE
    )
    mod$loading_row_outer_product_expectation <- simplify2array(lapply(
      seq_len(D_prime),
      function(d) {
        gaussian_outer_product_expectation(
          expected_value = mod$loading_expectation[d, ],
          covariance_matrix = mod$loading_row_covariance[, , d],
          perform_checks = FALSE
        )
      }
    ))
    mod$loading_row_outer_product_expectation <- array(
      mod$loading_row_outer_product_expectation,
      dim = c(L, L, D_prime)
    )
    mod$loading_col_outer_product_expectation <- simplify2array(lapply(
      seq_len(L),
      function(l) {
        gaussian_outer_product_expectation(
          expected_value = mod$loading_expectation[, l],
          # `nrow` is given explicitly because diag() of a length-one vector
          # would otherwise build an identity matrix of that size.
          covariance_matrix = diag(mod$loading_row_covariance[l, l, ], nrow = D_prime),
          perform_checks = FALSE
        )
      }
    ))
    mod$loading_col_outer_product_expectation <- array(
      mod$loading_col_outer_product_expectation,
      dim = c(D_prime, D_prime, L)
    )

    # Individual-Specific Latent Traits
    # NOTE(review): this is the only update in the loop that runs with
    # perform_checks = TRUE; kept as in the original -- confirm it is intended.
    mod$individual_specific_latent_trait_precision <- compute_individual_specific_latent_trait_precision(
      precision_vector = mod$precision_vector,
      loading_outer_product_expectation = mod$loading_row_outer_product_expectation,
      within_taxon_amplitude = mod$within_taxon_amplitude,
      perform_checks = TRUE
    )
    mod$individual_specific_latent_trait_covariance <- chol2inv(
      chol(mod$individual_specific_latent_trait_precision)
    )
    mod$individual_specific_latent_trait_expectation <- t(sapply(
      seq_len(N), function(n) {
        taxon_ind <- which(mod$phy$tip.label == mod$manifest_trait_df[n, mod$id_label])
        compute_individual_specific_latent_trait_expectation(
          auxiliary_trait = mod$auxiliary_traits[n, ],
          loading = mod$loading_expectation,
          taxon_specific_latent_trait = mod$taxon_specific_latent_trait_expectation[taxon_ind, ],
          precision_vector = mod$precision_vector,
          individual_specific_latent_trait_precision = mod$individual_specific_latent_trait_precision,
          within_taxon_amplitude = mod$within_taxon_amplitude,
          perform_checks = FALSE
        )
      }
    ))
    # With a single latent dimension sapply() yields a plain vector, so the
    # transpose above is 1 x N; flip it back to N x 1.
    if (L == 1) {
      mod$individual_specific_latent_trait_expectation <- t(mod$individual_specific_latent_trait_expectation)
    }
    mod$individual_specific_latent_trait_outer_product_expectation <- simplify2array(lapply(
      seq_len(N),
      function(n) {
        gaussian_outer_product_expectation(
          expected_value = t(mod$individual_specific_latent_trait_expectation[n, , drop = FALSE]),
          covariance_matrix = mod$individual_specific_latent_trait_covariance,
          perform_checks = FALSE
        )
      }
    ))
    mod$individual_specific_latent_trait_outer_product_expectation <- array(
      mod$individual_specific_latent_trait_outer_product_expectation,
      dim = c(L, L, N)
    )

    # Taxon-specific latent traits
    # Terminal taxa (rows 1..S) are updated first.
    for (s in seq_len(S)) {
      desc_ind <- mod$manifest_trait_df[, mod$id_label] == mod$phy$tip.label[s]
      anc_ind <- phangorn::Ancestors(mod$phy, s, type = "parent")
      mod$taxon_specific_latent_trait_precision[s, ] <- compute_terminal_taxon_specific_latent_trait_precision(
        N_s = sum(desc_ind),
        within_taxon_amplitude = mod$within_taxon_amplitude,
        conditional_standard_deviation = mod$phylogenetic_GP[s, "sd", ],
        perform_checks = FALSE
      )
      mod$taxon_specific_latent_trait_expectation[s, ] <- compute_terminal_taxon_specific_latent_trait_expectation(
        individual_specific_latent_traits =
          mod$individual_specific_latent_trait_expectation[desc_ind, , drop = FALSE],
        within_taxon_amplitude = mod$within_taxon_amplitude,
        parent_taxon_latent_trait = mod$taxon_specific_latent_trait_expectation[anc_ind, ],
        conditional_expectation_weight = mod$phylogenetic_GP[s, "weight", ],
        conditional_standard_deviation = mod$phylogenetic_GP[s, "sd", ],
        latent_trait_precision = mod$taxon_specific_latent_trait_precision[s, ],
        perform_checks = FALSE
      )
      mod$taxon_specific_latent_trait_outer_product_expectation[, , s] <- gaussian_outer_product_expectation(
        expected_value = mod$taxon_specific_latent_trait_expectation[s, ],
        covariance_matrix = diag(L) * (1 / mod$taxon_specific_latent_trait_precision[s, ]),
        perform_checks = FALSE
      )
    }
    # Internal taxa are then visited in the order their edges appear in the
    # post-order traversal of the tree.
    for (s in unique(mod$phy$edge[ape::postorder(mod$phy), 1])) {
      ch <- mod$phy$edge[mod$phy$edge[, 1] == s, 2]
      anc_ind <- phangorn::Ancestors(mod$phy, s, type = "parent")
      # A parent index of 0 means there is no parent; a zero vector is used in
      # place of the parent's latent trait.
      if (anc_ind == 0) {
        anc_lt <- rep(0, L)
      } else {
        anc_lt <- mod$taxon_specific_latent_trait_expectation[anc_ind, ]
      }
      mod$taxon_specific_latent_trait_precision[s, ] <- compute_internal_taxon_specific_latent_trait_precision(
        child_taxa_conditional_expectation_weights = as.matrix(mod$phylogenetic_GP[ch, "weight", ]),
        child_taxa_conditional_standard_deviations = as.matrix(mod$phylogenetic_GP[ch, "sd", ]),
        conditional_standard_deviation = mod$phylogenetic_GP[s, "sd", ],
        perform_checks = FALSE
      )
      mod$taxon_specific_latent_trait_expectation[s, ] <- compute_internal_taxon_specific_latent_trait_expectation(
        child_taxa_latent_traits = mod$taxon_specific_latent_trait_expectation[ch, ],
        child_taxa_conditional_expectation_weights = as.matrix(mod$phylogenetic_GP[ch, "weight", ]),
        child_taxa_conditional_standard_deviations = as.matrix(mod$phylogenetic_GP[ch, "sd", ]),
        parent_taxon_latent_trait = anc_lt,
        conditional_expectation_weight = mod$phylogenetic_GP[s, "weight", ],
        conditional_standard_deviation = mod$phylogenetic_GP[s, "sd", ],
        latent_trait_precision = mod$taxon_specific_latent_trait_precision[s, ],
        perform_checks = FALSE
      )
      mod$taxon_specific_latent_trait_outer_product_expectation[, , s] <- gaussian_outer_product_expectation(
        expected_value = mod$taxon_specific_latent_trait_expectation[s, ],
        covariance_matrix = diag(L) * (1 / mod$taxon_specific_latent_trait_precision[s, ]),
        perform_checks = FALSE
      )
    }

    # ELBO
    previous_elbo <- elbo[length(elbo)]
    current_elbo <- compute_plvm_elbo(
      mod,
      n_samples = n_samples, random_seed = random_seed,
      perform_checks = FALSE
    )
    elbo <- c(elbo, current_elbo)
    if (progress_bar) utils::setTxtProgressBar(pb, i)

    # Convergence: stop once the ELBO has changed by less than `tol`.
    # Non-finite or non-scalar differences never count as convergence.
    elbo_change <- current_elbo - previous_elbo
    if (length(elbo_change) == 1 && is.finite(elbo_change) && abs(elbo_change) < tol) {
      break
    }
    i <- i + 1
  }

  list(
    model = mod,
    elbo = elbo
  )
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.