# R/align_topics.R

# Defines functions .check_align_input setup_edges compute_ordering_cost consecutive_weights order_topics topics_list align_topics

# Documented in align_topics setup_edges

#' Align topics from distinct LDA models
#'
#' This function takes a list of LDA models and returns an object of class
#' \code{alignment}. Each element in the models list must be itself a named
#' list, corresponding to the mixed memberships (\code{$gamma}) and topics
#' (\code{$beta}). The resulting alignment object can be plotted using `plot`
#' and its weights can be extracted using the `weights` accessor function. See
#' the documentation for class \code{alignment} for further details.
#'
#' @param models (required) A list of LDA models object. Each list component
#' must be a list with two named entries, $gamma (containing mixed memberships)
#' and $beta (containing topic parameters in log space). See
#' \code{run_lda_models} for details.
#' @param method (required) Either \code{product} or \code{transport}, giving
#' two types of alignment strategies, using inner products between gamma vectors
#' or optimal transport between gamma-beta pairs, respectively. Matching is
#' case-insensitive and unambiguous abbreviations are accepted. Defaults to
#' \code{product}.
#' @param ... (optional) Further keyword arguments passed to the weight
#' function. For example, passing \code{reg = 10} when using the
#' \code{transport} method will set the regularization level to 10 in the
#' Sinkhorn optimal transport algorithm.
#' @return An object of class \code{alignment} providing the weights between
#' every pair of topics across every pair of models.
#'
#' @details
#'
#' After topics are aligned, they are re-ordered such that topics connected
#' by high weights are ranked similarly within their respective models.
#'
#' Topic paths (sets of topics connected by high weights across
#' models) are then identified and alignment diagnostics (topic refinement and
#' coherence scores) are computed. These variables are added to the
#' \code{topics} container of the returned \code{alignment}.
#'
#'
#' @seealso alignment
#' @examples
#' library(purrr)
#' data <- rmultinom(10, 20, rep(0.1, 20))
#' lda_params <- setNames(map(1:5, ~ list(k = .)), 1:5)
#' lda_models <- run_lda_models(data, lda_params)
#'
#' alignment <- align_topics(lda_models)
#' alignment
#' plot(alignment)
#'
#' plot(alignment, color_by = "refinement")
#' alignment <- align_topics(lda_models, method = "transport")
#' plot(alignment)
#' plot_beta(alignment)
#'
#' topics(alignment)
#' weights(alignment)
#' models(alignment)
#' @importFrom purrr map
#' @export
align_topics <- function(
  models,
  method = "product",
  ...
) {

  # 1. Check input and initialize key objects. The validator lower-cases and
  # partially matches `method`, then returns the canonical name. Dispatch on
  # that value; comparing the raw argument would send e.g. "Product" (which
  # passes validation) to the transport branch.
  method <- .check_align_input(models, method)
  weight_fun <- if (method == "product") product_weights else transport_weights
  if (is.null(names(models))) {
    names(models) <- seq_along(models)
  }

  # 2. One row per topic per model (labels, masses, proportions).
  topics <- topics_list(models)

  # 3. Compute weights between every pair of models. Betas are stored in log
  # space, so they are exponentiated before reaching the weight function.
  weights <-
    align_graph(
      edges = setup_edges("all", names(models)),
      gamma_hats = map(models, ~ .$gamma),
      beta_hats = map(models, ~ exp(.$beta)),
      weight_fun = weight_fun, ...
    )

  aligned_topics <- new("alignment", topics = topics, weights = weights, models = models)

  # 4. re-order the topics, identify the paths and compute summary diagnostics
  aligned_topics %>%
    order_topics() %>%
    add_paths() %>%
    add_summaries()
}

# Build the topic table for a named list of models: one row per topic with its
# model (`m`), index (`k`), label (`k_label`, from the rownames of beta or
# "1".."K" when absent), total mixed membership (`mass`, the column sums of
# gamma) and that mass as a within-model proportion (`prop`).
#' @importFrom magrittr %>%
#' @importFrom purrr imap_dfr
#' @importFrom dplyr tibble mutate
topics_list <- function(models) {
  imap_dfr(models, ~ {
    n_k <- nrow(.x$beta)
    k_labels <- rownames(.x$beta)
    if (is.null(k_labels)) {
      k_labels <- as.character(seq_len(n_k))
    }

    tibble(m = .y, k = seq_len(n_k), k_label = k_labels, mass = colSums(.x$gamma)) %>%
      mutate(prop = mass / sum(mass))
  }) %>%
    mutate(
      # keep models and labels in their order of first appearance
      m = factor(m, levels = names(models)),
      k_label = factor(k_label, levels = unique(k_label))
    )
}

order_topics <- function(aligned_topics) {
  # Derive per-model topic permutations from the weights linking each model to
  # the next one, then apply the same permutations to models, topics and
  # weights so the three stay consistent.
  chain <- mutate(consecutive_weights(aligned_topics), k_init = k_next)
  chain <- backward_ordering(forward_ordering(chain))
  perms <- topic_ordering(ungroup(chain))

  reordered_models <- reorder_models(aligned_topics@models, perms)
  reordered_topics <- reorder_topics(aligned_topics@topics, perms)
  reordered_weights <- reorder_weights(aligned_topics@weights, perms)

  new(
    "alignment",
    topics = reordered_topics,
    weights = reordered_weights,
    models = reordered_models
  )
}

consecutive_weights <- function(aligned_topics) {
  # Pair each model with the one following it in list order, then attach the
  # alignment weights recorded for that (m, m_next) pair.
  ids <- names(aligned_topics@models)
  ids <- factor(ids, levels = ids)

  adjacent <- tibble(m = head(ids, -1), m_next = tail(ids, -1))
  left_join(adjacent, aligned_topics@weights, by = c("m", "m_next"))
}

compute_ordering_cost <- function(weights) {
  # Rescale topic indices to (0, 1] within each model, then sum the
  # weight-scaled vertical distance between every linked pair of topics.
  scaled <- weights %>%
    group_by(m) %>%
    mutate(y = k / max(k), y_next = k_next / max(k_next)) %>%
    ungroup()

  sum(abs(scaled$y_next - scaled$y) * scaled$weight)
}

#' Edgelists for Default Alignments
#'
#' This is a helper function for setting up edges that can be used by
#' align_graph. It implements two types of comparisons, 'consecutive' and 'all'.
#' It returns a data frame specifying which pairs of models to compare.
#'
#' @param comparisons Either a string describing the type of model comparisons
#'   to compute ("consecutive" or "all"), or a data frame with columns $from
#'   and $to, which is returned unchanged.
#' @param model_names The names of the models to compare. The resulting edge
#'   list will refer to models by these names.
#' @return A data frame with columns $from and $to, one row per pair of models
#'   to align. Any other value of \code{comparisons} is returned as is.
#' @importFrom magrittr set_colnames %>%
#' @importFrom dplyr filter
#' @importFrom tibble tibble as_tibble
#' @importFrom utils combn head tail
#' @export
setup_edges <- function(comparisons, model_names) {
  # A ready-made edge list passes straight through. This must be checked before
  # the string comparisons: `data.frame == "x"` yields a matrix, which `if`
  # rejects.
  if (is.data.frame(comparisons)) {
    return(comparisons)
  }

  is_choice <- function(choice) {
    is.character(comparisons) && length(comparisons) == 1L &&
      !is.na(comparisons) && comparisons == choice
  }

  if (is_choice("consecutive")) {
    return(tibble(from = head(model_names, -1), to = tail(model_names, -1)))
  }

  if (is_choice("all")) {
    # combn() returns a 2 x n_pairs matrix, one pair per column.
    pairs <- combn(model_names, 2)
    edges <- tibble(from = pairs[1, ], to = pairs[2, ])
    # drop self-comparisons (possible only if model_names has duplicates)
    return(filter(edges, from != to))
  }

  comparisons
}

# Validate the inputs of `align_topics()`.
#
# `models` must be a list of at least two plain lists, each holding exactly
# the entries `gamma` and `beta`. `method` is lower-cased and matched against
# "product" / "transport" with `match.arg()` (so abbreviations are accepted).
# The matched, canonical method name is returned so callers can dispatch on it.
#' @importFrom purrr map map_int
#' @importFrom stringr str_starts
.check_align_input <- function(
  models,
  method
) {
  # check model list input
  stopifnot(
    "`models` must be a list" = is.list(models),
    "`models` must contain at least two models to align" = length(models) >= 2
  )

  # inherits() instead of `class(.) == "list"`, which returns a vector (and
  # breaks the check) for objects carrying several classes.
  is_plain_list <- vapply(models, function(x) inherits(x, "list"), logical(1))
  stopifnot("each element of `models` must be a list" = all(is_plain_list))

  # Both entries are used downstream. A subset test on the names alone would
  # accept lists that lack one or both of them.
  has_gamma_beta <- vapply(
    models,
    function(x) setequal(names(x), c("gamma", "beta")),
    logical(1)
  )
  stopifnot(
    "each model must contain exactly the entries `gamma` and `beta`" = all(has_gamma_beta)
  )

  # check method used; the canonical name is the return value
  match.arg(tolower(method), c("product", "transport"))
}

#' Alignment between Pairs of Topics
#'
#' This provides a more generic interface to alignment between arbitrary pairs
#' of topics, compared to `align_topics`. Rather than requiring sequential or
#' all-vs-all comparisons, this function supports comparisons between any pairs
#' of models, as specified by the `edges` parameter. Any graph linking pairs of
#' models can be the starting point for an alignment.
#'
#' @param edges A data frame with two columns, $from and $to, giving the names
#' of the models to be aligned. These names must be the names of the lists in
#' `gamma_hats` and `beta_hats`. Character or factor columns are accepted;
#' factors are matched by their labels.
#' @param gamma_hats A named list of matrices, giving estimated mixed-membership
#' parameters across a collection of topic models. The names of this list must
#' correspond to the names of models to compare in `edges`.
#' @param beta_hats A named list of matrices, giving estimated topic parameters
#' across a collection of topic models. The names of this list must correspond
#' to the names of models to compare in `edges`.
#' @param weight_fun A function that returns a data.frame giving weights between
#' all pairs of topics between two models. The first argument must accept a list
#' of two gamma_hat matrices, the second argument must accept a list of two
#' beta_hat matrices. See `product_weights` or `transport_weights` for examples.
#' @param ... (optional) Further keyword arguments passed to the weight
#' function. For example, passing \code{reg = 10} when using the
#' \code{transport} method will set the regularization level to 10 in the
#' Sinkhorn optimal transport algorithm.
#' @return An ungrouped data frame with one row per pair of aligned topics,
#' with model columns \code{m}, \code{m_next}, topic columns \code{k},
#' \code{k_next}, and the weights computed by \code{postprocess_weights}.
#' @importFrom dplyr mutate ungroup
#' @importFrom magrittr %>%
#' @export
align_graph <- function(edges, gamma_hats, beta_hats, weight_fun, ...) {
  # Compare by label: c() of two factors is a factor (R >= 4.1), and indexing
  # a list with a factor uses its integer codes, selecting the wrong models.
  from <- as.character(edges$from)
  to <- as.character(edges$to)

  unknown <- setdiff(c(from, to), names(gamma_hats))
  if (length(unknown) > 0) {
    stop(
      "Models in `edges` not found in `gamma_hats`: ",
      paste(unknown, collapse = ", "),
      call. = FALSE
    )
  }

  weights <- vector("list", length(from))
  for (i in seq_along(from)) {
    pair <- c(from[i], to[i])
    weights[[i]] <-
      weight_fun(gamma_hats[pair], beta_hats[pair], ...) %>%
      mutate(m = pair[1], m_next = pair[2])
  }

  # weights are normalized by the number of documents (rows of gamma)
  postprocess_weights(weights, nrow(gamma_hats[[1]]), names(gamma_hats)) %>%
    ungroup()
}

#' Product Weights between a Model Pair
#'
#' An alignment based on product weights sets the weight between topics k and k'
#' according to \eqn{\gamma_{k}^T\gamma_{k^\prime}}, where \eqn{\gamma_{k} \in
#' \mathbb{R}^n_{+}} provides the mixed membership assigned to topic \eqn{k}
#' across the \eqn{n} samples (and similarly for topic \eqn{k^\prime}). This
#' function computes these weights given a list of two \eqn{n \times K} gamma
#' matrices.
#'
#' @param gammas (required) A list of length two, containing the mixed
#' membership matrices (a \code{matrix} of dimension n-samples by k-topics) to
#' compare. The number of columns may be different, but the number of samples
#' must be equal.
#' @param ... (optional) Other keyword arguments. These are unused by the
#' \code{product_weights} alignment strategy, but are included for consistency
#' across weight functions.
#' @return A \code{data.frame} in long format (columns \code{k}, \code{k_next},
#' \code{weight}) giving the product similarity of each pair of topics across
#' the two input matrices.
#'
#' @examples
#' g1 <- matrix(runif(20 * 2), 20, 2)
#' g2 <- matrix(runif(20 * 4), 20, 4)
#' product_weights(list(g1, g2))
#'
#' @seealso align_graph
#' @importFrom purrr map
#' @importFrom magrittr %>%
#' @export
product_weights <- function(gammas, ...) {
  # crossprod(A, B) == t(A) %*% B without materializing the transpose
  products <- crossprod(gammas[[1]], gammas[[2]])
  # label topics 1..K on both axes; these become the k / k_next columns
  dimnames(products) <- map(gammas, ~ seq_len(ncol(.)))
  data.frame(products) %>%
    .lengthen_weights()
}

#' Transport Weights between Model Pair
#'
#' An alignment based on transport weights sets the weight between topics k and
#' k' according to an optimal transport problem with (1) costs set by the
#' distance (specifically, Jensen-Shannon Divergence) between \eqn{\beta_{k}}
#' and \eqn{\beta_{k^\prime}} and (2) masses defined by the total topic mixed
#' memberships \eqn{\sum_{i}\gamma_{ik}} and \eqn{\sum_{i}\gamma_{ik^\prime}}.
#' If topics have similar mixed membership weight and similar topic \eqn{\beta},
#' then they will be given high transport alignment weight.
#'
#' @param gammas (required) A list of length two, containing the mixed
#' membership matrices (a \code{matrix} of dimension n-samples by k-topics) to
#' compare. The number of columns may be different, but the number of samples
#' must be equal.
#' @param betas (required). A list of length two, containing the topic matrices
#' (a \code{matrix} of dimension k-topics by d-dimensions, on the probability
#' scale). The number of rows may be different, but the number of columns must
#' remain fixed.
#' @param reg (optional) How much regularization to use in the Sinkhorn optimal
#' transport algorithm? Defaults to 0.1.
#' @param ... (optional) Other keyword arguments, forwarded to
#' \code{T4transport::sinkhornD}.
#' @return A \code{data.frame} in long format (columns \code{k}, \code{k_next},
#' \code{weight}) giving the transported mass between each pair of topics
#' across the two models.
#'
#'
#' @examples
#' library(purrr)
#' data <- rmultinom(10, 20, rep(0.1, 20))
#' lda_params <- setNames(map(1:5, ~ list(k = .)), 1:5)
#' lda_models <- run_lda_models(data, lda_params)
#' gammas <- list(lda_models[[3]]$gamma, lda_models[[5]]$gamma)
#' # betas are stored in log space; transport costs need probabilities
#' betas <- list(exp(lda_models[[3]]$beta), exp(lda_models[[5]]$beta))
#' transport_weights(gammas, betas)
#'
#' @importFrom philentropy JSD
#' @importFrom T4transport sinkhornD
#' @importFrom purrr map
#' @export
transport_weights <- function(gammas, betas, reg = 0.1, ...) {
  # Divergences between all topics of both models; only the
  # (model-1 topics) x (model-2 topics) block is used as transport costs.
  betas_mat <- do.call(rbind, betas)
  costs <- suppressMessages(JSD(betas_mat))
  ix <- seq_len(nrow(betas[[1]]))

  # topic masses: total mixed membership of each topic across samples
  a <- colSums(gammas[[1]])
  b <- colSums(gammas[[2]])

  plan <- sinkhornD(
    costs[ix, -ix, drop = FALSE], wx = a, wy = b, lambda = reg, ...
  )$plan

  # Convert the probability plan into a general measure: rescale each row i to
  # sum to a[i], then each column j by b[j] / (its current sum). The column
  # step needs sweep(): multiplying a matrix by a length-ncol vector recycles
  # it down the columns, which would scramble the per-column factors.
  plan <- a * plan / rowSums(plan)
  plan <- sweep(plan, 2, b / colSums(plan), "*")

  if (any(is.na(plan))) {
    plan <- matrix(0, nrow(betas[[1]]), nrow(betas[[2]]))
    warning("OT diverged, considering increasing regularization.\n")
  }

  dimnames(plan) <- map(gammas, ~ seq_len(ncol(.)))
  data.frame(plan) %>%
    .lengthen_weights()
}

################################################################################
# Helper functions for reshaping and renaming
################################################################################
# Reshape a wide topic-by-topic weight table (rows = topics of the first
# model, columns = topics of the second) into long format with columns
# k, k_next and weight. data.frame() prefixes numeric column names with "X";
# that prefix is stripped from k_next.
#' @importFrom dplyr mutate
#' @importFrom tibble rownames_to_column
#' @importFrom tidyr pivot_longer
#' @importFrom magrittr %>%
#' @importFrom stringr str_replace
.lengthen_weights <- function(weights) {
  with_k <- rownames_to_column(weights, "k")
  long <- pivot_longer(with_k, -k, names_to = "k_next", values_to = "weight")
  mutate(long, k_next = str_replace(k_next, "X", ""))
}

# Combine per-pair weight tables and add normalized variants:
#  * document_mass: the raw weight returned by the weight function
#  * weight: document_mass divided by the number of documents
#  * bw_weight: weight normalized within each (m, m_next, k_next)
#  * fw_weight: weight normalized within each (m, m_next, k)
# Model columns become factors with levels `m_levels`; topic columns become
# integers. The result is returned grouped by (m, m_next, k).
#' @importFrom dplyr group_by bind_rows summarize mutate ungroup arrange across
#' everything select
#' @importFrom magrittr %>%
postprocess_weights <- function(weights, n_docs, m_levels) {
  bind_rows(weights) %>%
    mutate(
      document_mass = weight,
      weight = document_mass / n_docs
    ) %>%
    group_by(m, m_next, k_next) %>%
    mutate(bw_weight = weight / sum(weight)) %>%
    ungroup() %>%
    mutate(
      # pass `levels` through a lambda: extra args via across(...) are
      # deprecated in dplyr >= 1.1.0
      across(c("m", "m_next"), ~ factor(.x, levels = m_levels)),
      across(c("k", "k_next"), as.integer)
    ) %>%
    group_by(m, m_next, k) %>%
    mutate(
      fw_weight = weight / sum(weight)
    ) %>%
    arrange(m) %>%
    select(m, m_next, k, k_next, everything())
}

################################################################################
# Class construction and methods
################################################################################
# Print a one-line summary of an alignment (model and topic counts) followed
# by the first rows of its weights and a count of the rows not shown.
# Returns the object invisibly, as print/show methods conventionally do.
#' @importFrom utils head
print_alignment <- function(object) {
  cat(sprintf(
    "# An %s: %d models, %d topics:\n",
    class(object), n_models(object), n_topics(object)
  ))

  shown <- head(object@weights)
  print(shown)

  # only report hidden rows when there are any (avoids "-3 more rows")
  n_hidden <- nrow(object@weights) - nrow(shown)
  if (n_hidden > 0) {
    cat(sprintf("# ... with %s more rows\n", n_hidden))
  }

  invisible(object)
}

#' Alignment Class Definition
#'
#' The alignment class stores all the information associated with an
#' alignment across an ensemble of topic models. The available accessor methods
#' are:
#'
#' * \code{weights}: Extract weights between all pairs of topics within an
#'   alignment. Topic pairs with high alignment scores are more similar to one
#'   another, though the precise implementation will depend on the \code{method}
#'   used during \code{align_topics}. Note that only the weights are needed in
#'   order to compute stability, refinement, and key topics summaries.
#' * \code{models}: Extract the model parameters that were used in the original
#'   alignment. Note that the latent topics may have been reordered, to maximize
#'   the consistency across all models according to their alignment.
#' * \code{n_topics}: How many topics are there in total within the alignment
#'   object?
#' * \code{n_models}: How many models are there in total within the alignment
#'   object?
#'
#' @slot topics A data frame with one row per topic (see \code{topics}).
#' @slot weights A data frame of weights between pairs of topics (see
#'   \code{weights}).
#' @slot models The list of models that were aligned (see \code{models}).
#' @slot n_models,n_topics Numeric slots. The \code{n_models} and
#'   \code{n_topics} accessors compute their values from the weights and topics
#'   rather than reading these slots.
#'
#' @seealso align_topics
#' @import methods
#' @exportClass alignment
setClass(
  "alignment",
  slots = c(
    topics = "data.frame",
    weights = "data.frame",
    models = "list",
    n_models = "numeric",
    n_topics = "numeric"
  )
)

#' Show Method for Alignment Class
#'
#' Prints the number of models and topics, followed by the first rows of the
#' alignment weights.
#' @param object An alignment object output from \code{align_topics}.
#' @import methods
#' @export
setMethod(
  f = "show",
  signature = signature(object = "alignment"),
  definition = print_alignment
)

# NOTE(review): stats also exports `weights(object, ...)`; confirm this
# single-argument generic does not mask it for other classes.
setGeneric(name = "weights", def = function(x) standardGeneric("weights"))
#' Weights Accessor for Alignment Class
#'
#' Returns the data frame of weights between pairs of aligned topics.
#' @param x An alignment object output from \code{align_topics}.
#' @import methods
#' @export
setMethod(
  f = "weights",
  signature = signature(x = "alignment"),
  definition = function(x) slot(x, "weights")
)

setGeneric(name = "n_models", def = function(x) standardGeneric("n_models"))

#' Number of Models Method for Alignment Class
#'
#' Counts the models as the number of factor levels of the \code{m} column of
#' the alignment weights.
#' @param x An alignment object output from \code{align_topics}.
#' @import methods
#' @export
setMethod(
  f = "n_models",
  signature = signature(x = "alignment"),
  definition = function(x) nlevels(slot(x, "weights")$m)
)


setGeneric(name = "n_topics", def = function(x) standardGeneric("n_topics"))
#' Number of Topics Method for Alignment Class
#'
#' Counts the topics as the number of rows in the topics table.
#' @param x An alignment object output from \code{align_topics}.
#' @import methods
#' @export
setMethod(
  f = "n_topics",
  signature = signature(x = "alignment"),
  definition = function(x) nrow(slot(x, "topics"))
)

setGeneric(name = "models", def = function(x) standardGeneric("models"))
#' Extract Models underlying Alignment
#'
#' Returns the list of models stored in the alignment.
#' @param x An alignment object output from \code{align_topics}.
#' @import methods
#' @export
setMethod(
  f = "models",
  signature = signature(x = "alignment"),
  definition = function(x) slot(x, "models")
)


setGeneric(name = "topics", def = function(x) standardGeneric("topics"))
#' Extract List of Topics and their Summaries
#'
#' Returns the topics table stored in the alignment.
#' @param x An alignment object output from \code{align_topics}.
#' @import methods
#' @export
setMethod(
  f = "topics",
  signature = signature(x = "alignment"),
  definition = function(x) slot(x, "topics")
)
# lasy/alto documentation built on June 23, 2024, 6:45 a.m.