R/cavi_plvm.R

Defines functions cavi_plvm

Documented in cavi_plvm

#' CAVI for the PLVM
#'
#' Coordinate Ascent Variational Inference for the Phylogenetic Latent Variable Model
#'
#' Starting from the supplied model, each iteration updates in turn the
#' phylogenetic GP hyperparameters, the ARD precision, the trait precisions,
#' the ordinal cut-off points, the auxiliary traits, the loading, the
#' individual-specific latent traits and the taxon-specific latent traits, and
#' then re-evaluates the ELBO. Iteration stops when the ELBO has converged (see
#' `tol`) or after `max_iter` iterations, whichever comes first.
#'
#' @param plvm_list A list defining the PLVM generated by `initialise_plvm`.
#' @param tol A non-negative real-valued scalar. Optimisation stops once the
#'   absolute change in the ELBO between successive iterations is less than
#'   `tol`. Setting `tol = 0` disables the convergence check so that exactly
#'   `max_iter` iterations are run.
#' @param max_iter A positive integer. The maximum number of iterations that the optimisation procedure will run.
#' @param n_samples A positive integer. The number of samples drawn to estimate expected auxiliary trait values.
#' @param random_seed The seed for random number generation.
#' @param progress_bar Logical. Should a progress bar be displayed?
#' @inheritParams initialise_plvm
#'
#' @seealso initialise_plvm
#'
#' @return A list with elements `model`, the optimised PLVM, and `elbo`, a
#'   numeric vector whose first entry is `-Inf`, followed by the ELBO of the
#'   initial model and the ELBO after each completed iteration.
#' @export
cavi_plvm <- function(
  plvm_list,
  tol = 1e-6, max_iter = 1000,
  n_samples = 1000, random_seed = NULL,
  progress_bar = TRUE,
  perform_checks = TRUE
){
  # Problem dimensions: individuals, terminal taxa, latent traits and
  # auxiliary traits.
  N <- nrow(plvm_list$manifest_trait_df)
  S <- length(plvm_list$phy$tip.label)
  L <- ncol(plvm_list$loading_expectation)
  D_prime <- sum(lengths(plvm_list$metadata$auxiliary_trait_index))
  ord_index <- which(plvm_list$metadata$trait_type == "ord")
  mod <- plvm_list
  # The leading -Inf is a sentinel so that the ELBO of the initial model always
  # has a predecessor to be compared against.
  elbo <- c(
    -Inf,
    compute_plvm_elbo(
      mod,
      n_samples = n_samples, random_seed = random_seed,
      perform_checks = perform_checks
    )
  )
  i <- 1
  if (progress_bar) {
    pb <- utils::txtProgressBar(min = 0, max = max_iter, style = 3)
    # Close the bar however the function exits, including on error.
    on.exit(close(pb), add = TRUE)
  }
  while (i <= max_iter) {
    #  Phylogenetic GP (Satisfied that this section works)
    for (l in seq_len(L)) {
      # Within-taxon
      # `i = l` is forwarded by optim to the objective and selects the latent
      # trait whose amplitude is optimised over [0, 1].
      tmp_wta  <-  stats::optim(
        par = mod$within_taxon_amplitude[l], fn = within_taxon_amplitude_objective,
        i = l,
        individual_specific_latent_trait_expectation = mod$individual_specific_latent_trait_expectation,
        taxon_id = mod$manifest_trait_df[, mod$id_label], phy = mod$phy,
        terminal_taxon_specific_latent_trait_expectation = mod$taxon_specific_latent_trait_expectation[seq_len(S), , drop = FALSE],
        individual_specific_latent_trait_covariance = mod$individual_specific_latent_trait_covariance,
        individual_specific_latent_trait_outer_product_expectation = mod$individual_specific_latent_trait_outer_product_expectation,
        terminal_taxon_latent_trait_outer_product_expectation = mod$taxon_specific_latent_trait_outer_product_expectation[, , seq_len(S), drop = FALSE],
        within_taxon_amplitude = mod$within_taxon_amplitude,
        perform_checks = FALSE,
        method = "Brent", lower = 0, upper = 1
      )
      mod$within_taxon_amplitude[l] <- tmp_wta$par
      # Heritable
      tmp_ha <- stats::optim(
        par = mod$heritable_amplitude[l], fn = heritable_amplitude_objective,
        i = l,
        heritable_amplitude = mod$heritable_amplitude, length_scale = mod$length_scale,
        taxon_specific_latent_trait_expectation = mod$taxon_specific_latent_trait_expectation,
        taxon_specific_latent_trait_outer_product_expectation = mod$taxon_specific_latent_trait_outer_product_expectation,
        taxon_specific_latent_trait_covariance = mod$taxon_specific_latent_trait_covariance,
        phy = mod$phy,
        phylogenetic_gp = mod$phylogenetic_GP,
        perform_checks = FALSE,
        method = "Brent", lower = 0, upper = 1
      )
      mod$heritable_amplitude[l] <- tmp_ha$par
      # Refresh the GP slice for latent trait l; the environmental amplitude is
      # chosen so that the squared heritable and environmental amplitudes sum
      # to one.
      mod$phylogenetic_GP[, , l] <- reparameterise_phylogenetic_ou(
        phy = mod$phy,
        heritable_amplitude = mod$heritable_amplitude[l],
        length_scale = mod$length_scale,
        environmental_amplitude = sqrt(1 - mod$heritable_amplitude[l]^2),
        perform_checks = FALSE
      )
    }
    ## ARD
    mod$ard_precision <- update_loading_ard_precision(
      loading_col_outer_product_expectation = mod$loading_col_outer_product_expectation,
      inv_loading_prior_correlation = mod$inv_loading_prior_correlation,
      ard_shape = 0,  ard_rate = 0
    )
    # Precision
    mod$precision <- update_precision(
      precision = mod$precision,
      metadata = mod$metadata,
      auxiliary_traits = mod$auxiliary_traits,
      loading_expectation = mod$loading_expectation,
      latent_trait_expectation = mod$individual_specific_latent_trait_expectation,
      loading_outer_expectation = mod$loading_row_outer_product_expectation,
      latent_trait_outer_expectation = mod$individual_specific_latent_trait_outer_product_expectation
    )
    mod$precision_vector <- map_precision_to_auxiliary_traits(
      precision = mod$precision,
      auxiliary_trait_index = mod$metadata$auxiliary_trait_index,
      perform_checks = FALSE
    )
    # ordinal trait cut off points
    for (j in ord_index) {
      n_levels <- mod$metadata$trait_levels[j]
      # Only cut-off points from the third onwards are optimised. With fewer
      # than three levels there is nothing to update, and `3:n_levels` would
      # count backwards.
      if (n_levels < 3) next
      for (k in 3:n_levels) {
        # Each cut-off is searched between its lower neighbour and the smaller
        # of its upper neighbour and its current value plus 10, which keeps the
        # Brent search interval finite.
        tmp_co <- stats::optim(
          par = mod$metadata$cut_off_points[[j]][k], fn = ordinal_trait_cut_off_objective,
          i = k,
          y = mod$manifest_trait_df[, mod$metadata$manifest_trait_index[[j]]],
          cut_off_points = mod$metadata$cut_off_points[[j]],
          loading_expectation = mod$loading_expectation[mod$metadata$auxiliary_trait_index[[j]], ],
          latent_trait_expectation = mod$individual_specific_latent_trait_expectation,
          loading_outer_expectation = mod$loading_row_outer_product_expectation[, , mod$metadata$auxiliary_trait_index[[j]]],
          latent_trait_outer_expectation = mod$individual_specific_latent_trait_outer_product_expectation,
          perform_checks = FALSE,
          method = "Brent",
          lower = mod$metadata$cut_off_points[[j]][k-1],
          upper = min(mod$metadata$cut_off_points[[j]][k+1], mod$metadata$cut_off_points[[j]][k] + 10)
        )
        mod$metadata$cut_off_points[[j]][k] <- tmp_co$par
      }
    }
    # auxiliary traits
    mod$auxiliary_traits <-  update_discrete_auxiliary_traits(
      manifest_trait_df = mod$manifest_trait_df,
      metadata = mod$metadata,
      auxiliary_traits = mod$auxiliary_traits,
      loading_expectation = mod$loading_expectation,
      latent_trait_expectation = mod$individual_specific_latent_trait_expectation,
      n_samples = n_samples, random_seed = random_seed,
      perform_checks = FALSE
    )
    # Loading
    mod$loading_row_precision <- simplify2array(compute_loading_row_precision_list(
      total_individual_specific_latent_trait_outer_product_expectation =
        apply(mod$individual_specific_latent_trait_outer_product_expectation, c(1, 2), sum),
      precision_vector = mod$precision_vector,
      ard_precision = mod$ard_precision,
      scaled_conditional_row_variance_vector = mod$scaled_conditional_loading_row_variance_vector,
      perform_checks = FALSE
    ))
    # Explicitly restore the L x L x D_prime shape, which simplify2array does
    # not guarantee when a dimension has length one.
    mod$loading_row_precision <- array(mod$loading_row_precision, dim = c(L, L, D_prime))
    mod$loading_row_covariance <- simplify2array(lapply(
      seq_len(D_prime),
      function(d){
        chol2inv(chol(as.matrix(mod$loading_row_precision[, , d])))
      }
    ))
    mod$loading_row_covariance <- array(mod$loading_row_covariance, dim = c(L, L, D_prime))
    mod$loading_expectation <-  compute_loading_expectation(
      current_loading_expectation = mod$loading_expectation,
      loading_row_precision = mod$loading_row_precision,
      auxiliary_traits = mod$auxiliary_traits,
      latent_trait_expectation = mod$individual_specific_latent_trait_expectation,
      precision_vector = mod$precision_vector,
      ard_precision = mod$ard_precision,
      scaled_conditional_row_variance_vector = mod$scaled_conditional_loading_row_variance_vector,
      loading_row_conditional_mean_weight = mod$loading_row_conditional_mean_weight,
      perform_checks = FALSE
    )
    mod$loading_row_outer_product_expectation <- simplify2array(lapply(
      seq_len(D_prime),
      function(d){
        gaussian_outer_product_expectation(
          expected_value = mod$loading_expectation[d, ],
          covariance_matrix = mod$loading_row_covariance[, , d],
          perform_checks = FALSE
        )
      }
    ))
    mod$loading_row_outer_product_expectation <- array(mod$loading_row_outer_product_expectation, dim = c(L, L, D_prime))
    mod$loading_col_outer_product_expectation <- simplify2array(lapply(
      seq_len(L),
      function(l){
        gaussian_outer_product_expectation(
          expected_value = mod$loading_expectation[, l],
          # `nrow` is given explicitly because diag() of a length-one numeric
          # would otherwise build an identity matrix rather than a 1 x 1
          # diagonal matrix.
          covariance_matrix = diag(mod$loading_row_covariance[l, l, ], nrow = D_prime),
          perform_checks = FALSE
        )
      }
    ))
    mod$loading_col_outer_product_expectation <- array(mod$loading_col_outer_product_expectation, dim = c(D_prime, D_prime, L))
    # Individual-Specific Latent Traits
    # NOTE(review): checks are left enabled here, unlike the other calls inside
    # the loop - confirm whether this is intentional.
    mod$individual_specific_latent_trait_precision <- compute_individual_specific_latent_trait_precision(
      precision_vector = mod$precision_vector,
      loading_outer_product_expectation = mod$loading_row_outer_product_expectation,
      within_taxon_amplitude = mod$within_taxon_amplitude,
      perform_checks = TRUE
    )
    mod$individual_specific_latent_trait_covariance <- chol2inv(chol(mod$individual_specific_latent_trait_precision))
    mod$individual_specific_latent_trait_expectation <- t(sapply(
      seq_len(N), function(i) {
        # Row of the taxon-specific latent traits matching this individual's
        # taxon label.
        taxon_ind <- which(mod$phy$tip.label == mod$manifest_trait_df[i, mod$id_label])
        compute_individual_specific_latent_trait_expectation(
          auxiliary_trait = mod$auxiliary_traits[i, ],
          loading = mod$loading_expectation,
          taxon_specific_latent_trait = mod$taxon_specific_latent_trait_expectation[taxon_ind, ],
          precision_vector = mod$precision_vector,
          individual_specific_latent_trait_precision = mod$individual_specific_latent_trait_precision,
          within_taxon_amplitude = mod$within_taxon_amplitude,
          perform_checks = FALSE
        )
      }
    ))
    # With a single latent trait sapply returns a plain vector, so the result
    # above is 1 x N and has to be transposed back to N x 1.
    if (L == 1) mod$individual_specific_latent_trait_expectation <- t(mod$individual_specific_latent_trait_expectation)
    mod$individual_specific_latent_trait_outer_product_expectation <- simplify2array(lapply(
      seq_len(N),
      function(i){
        gaussian_outer_product_expectation(
          expected_value = t(mod$individual_specific_latent_trait_expectation[i, , drop = FALSE]),
          covariance_matrix = mod$individual_specific_latent_trait_covariance,
          perform_checks = FALSE
        )
      }
    ))
    mod$individual_specific_latent_trait_outer_product_expectation <- array(
      mod$individual_specific_latent_trait_outer_product_expectation,
      dim = c(L, L, N)
      )
    # Taxon-specific latent traits
    # Terminal taxa (rows 1 to S) are updated first.
    for (s in seq_len(S)) {
      desc_ind <- mod$manifest_trait_df[, mod$id_label] == mod$phy$tip.label[s]
      anc_ind <- phangorn::Ancestors(mod$phy, s, type = "parent")
      mod$taxon_specific_latent_trait_precision[s, ] <- compute_terminal_taxon_specific_latent_trait_precision(
        N_s = sum(desc_ind),
        within_taxon_amplitude = mod$within_taxon_amplitude,
        conditional_standard_deviation = mod$phylogenetic_GP[s, "sd", ],
        perform_checks = FALSE
      )
      mod$taxon_specific_latent_trait_expectation[s, ] <- compute_terminal_taxon_specific_latent_trait_expectation(
        individual_specific_latent_traits = mod$individual_specific_latent_trait_expectation[desc_ind, , drop = FALSE],
        within_taxon_amplitude = mod$within_taxon_amplitude,
        parent_taxon_latent_trait = mod$taxon_specific_latent_trait_expectation[anc_ind, ],
        conditional_expectation_weight = mod$phylogenetic_GP[s, "weight", ],
        conditional_standard_deviation = mod$phylogenetic_GP[s, "sd", ],
        latent_trait_precision = mod$taxon_specific_latent_trait_precision[s, ],
        perform_checks = FALSE
      )
      # The covariance is diagonal, with the reciprocal precisions on the
      # diagonal.
      mod$taxon_specific_latent_trait_outer_product_expectation[, , s] <- gaussian_outer_product_expectation(
        expected_value = mod$taxon_specific_latent_trait_expectation[s, ],
        covariance_matrix = diag(L) * (1 / mod$taxon_specific_latent_trait_precision[s, ]),
        perform_checks = FALSE
      )
    }
    # Internal taxa are then visited in the order their parent nodes appear in
    # a postorder listing of the edges.
    for (s in unique(mod$phy$edge[ape::postorder(mod$phy), 1])) {
      ch <- mod$phy$edge[mod$phy$edge[, 1] == s, 2]
      anc_ind <- phangorn::Ancestors(mod$phy, s, type = "parent")
      # A parent index of 0 means s has no parent; a zero vector is used as the
      # parent latent trait in that case.
      if (anc_ind == 0) {
        anc_lt <- rep(0, L)
      } else {
        anc_lt <- mod$taxon_specific_latent_trait_expectation[anc_ind, ]
      }
      mod$taxon_specific_latent_trait_precision[s, ] <- compute_internal_taxon_specific_latent_trait_precision(
        child_taxa_conditional_expectation_weights = as.matrix(mod$phylogenetic_GP[ch, "weight", ]),
        child_taxa_conditional_standard_deviations = as.matrix(mod$phylogenetic_GP[ch, "sd", ]),
        conditional_standard_deviation = mod$phylogenetic_GP[s, "sd", ],
        perform_checks = FALSE
      )
      mod$taxon_specific_latent_trait_expectation[s, ] <- compute_internal_taxon_specific_latent_trait_expectation(
        child_taxa_latent_traits = mod$taxon_specific_latent_trait_expectation[ch, ],
        child_taxa_conditional_expectation_weights = as.matrix(mod$phylogenetic_GP[ch, "weight", ]),
        child_taxa_conditional_standard_deviations = as.matrix(mod$phylogenetic_GP[ch, "sd", ]),
        parent_taxon_latent_trait = anc_lt,
        conditional_expectation_weight = mod$phylogenetic_GP[s, "weight", ],
        conditional_standard_deviation = mod$phylogenetic_GP[s, "sd", ],
        latent_trait_precision = mod$taxon_specific_latent_trait_precision[s, ],
        perform_checks = FALSE
      )
      mod$taxon_specific_latent_trait_outer_product_expectation[, , s] <- gaussian_outer_product_expectation(
        expected_value = mod$taxon_specific_latent_trait_expectation[s, ],
        covariance_matrix = diag(L) *(1 / mod$taxon_specific_latent_trait_precision[s, ]),
        perform_checks = FALSE
      )
    }
    # ELBO
    elbo <- c(
      elbo,
      compute_plvm_elbo(
        mod,
        n_samples = n_samples, random_seed = random_seed,
        perform_checks = FALSE
      )
    )
    # Converged once the ELBO has moved by less than `tol` since the previous
    # evaluation. isTRUE() treats a non-finite or missing difference as "not
    # converged" rather than raising an error.
    n_elbo <- length(elbo)
    converged <- isTRUE(abs(elbo[n_elbo] - elbo[n_elbo - 1]) < tol)
    if (progress_bar) utils::setTxtProgressBar(pb, i)
    i <- i + 1
    if (converged) break
  }

  list(
    model = mod,
    elbo = elbo
  )
}
jpmeagher/vbar documentation built on Nov. 22, 2022, 5:48 a.m.