# Chunk defaults for the whole document: merge source and output into a single
# block, prefix printed output with "#>", and cache chunk results.
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  cache = TRUE
)

# Seed the random number generator from the `id` document parameter.
set.seed(params$id)
In this vignette, we simulate from two true LDA models, one with five topics and one null model with a single topic. We repeat each simulation 10 times (20 simulated data sets in total), generate alignments for each of the simulated data sets, and visualize the alignments.
The packages used in this vignette are given below.
# Packages are attached in this exact order; the order decides which package
# wins when two of them export the same name.
library(MCMCpack)
library(alto)
library(dplyr)
library(ggplot2)
library(purrr)
library(readr)
library(stringr)
library(colorspace)
library(gridExtra)
library(cowplot)

# Simulation helpers are fetched from the `main` branch of the topic_align
# repository, so the code that runs here is not pinned to a fixed revision.
source("https://raw.githubusercontent.com/krisrs1128/topic_align/main/simulations/simulation_functions.R")

# NOTE(review): my_theme() is not defined in this document; presumably it comes
# from the script sourced above and sets plotting defaults -- confirm.
my_theme()
The block below simulates data $x$
according to LDA models with the parameters
specified above. The topics are relatively sparse, with $\lambda_{\beta} = 0.1$
and $\lambda_{\gamma} = 0.5$. Each sample has 10,000 counts.
# NOTE(review): attach() exposes every entry of `params` on the search path.
# K, V and N below (and n_models in a later chunk) are presumably resolved
# through it -- confirm against the document parameters.
attach(params)

# Number of replicate data sets per model, and the simulation hyperparameters.
n_sims <- 10
lambdas <- list(beta = 0.1, gamma = .5, count = 1e4)

# The draws below are kept in this order because each one advances the random
# number stream seeded earlier.

# Non-null model: K topic rows over V dimensions, N rows of topic memberships.
betas_non_null <- rdirichlet(K, rep(lambdas$beta, V))
# Null model: one row per sample (N rows), each a flat Dirichlet draw ...
betas_null <- rdirichlet(N, rep(1, V))
gammas_non_null <- rdirichlet(N, rep(lambdas$gamma, K))
# ... paired with an N x N identity, so row i selects only the i-th beta row.
gammas_null <- diag(rep(1, N))

# n_sims independent count data sets per model, returned as plain lists.
counts_non_null <- lapply(seq_len(n_sims), function(i) {
  simulate_lda(betas_non_null, gammas_non_null, lambda = lambdas$count)
})
counts_null <- lapply(seq_len(n_sims), function(i) {
  simulate_lda(betas_null, gammas_null, lambda = lambdas$count)
})
Next, we run the LDA models and compute the alignment.
# One parameter list per candidate model, with k = 1, ..., n_models, named
# "K1", "K2", ... seq_len() is used so that a zero-length request yields an
# empty list instead of the c(1, 0) sequence that 1:n_models would produce.
lda_params <- map(seq_len(n_models), function(k) list(k = k))
names(lda_params) <- str_c("K", seq_len(n_models))

# Fit every model in `lda_params` to one count matrix, then align the fitted
# topics using the alignment method named in the document parameters.
fit_and_align <- function(counts) {
  counts %>%
    run_lda_models(lda_params, reset = TRUE) %>%
    align_topics(method = params$method)
}

# Non-null data sets are processed first, then the null ones.
alignments_non_null <- lapply(counts_non_null, fit_and_align)
alignments_null <- lapply(counts_null, fit_and_align)
# Default alignment plot for every simulated data set, non-null first.
plots_non_null <- lapply(alignments_non_null, plot)
plots_null <- lapply(alignments_null, plot)
plots <- c(plots_non_null, plots_null)

# Four-column panel grid, filled column by column and sized from the number of
# plots rather than a hard-coded 1:20. Unused cells are NA, which leaves them
# empty. The former `cols = 4` argument was dropped: it is not a grid.arrange()
# argument and had no effect once `grobs` and `layout_matrix` were supplied.
n_cols <- 4
n_rows <- ceiling(length(plots) / n_cols)
panel_ids <- c(seq_along(plots), rep(NA, n_rows * n_cols - length(plots)))
grid.arrange(
  grobs = plots,
  layout_matrix = matrix(panel_ids, nrow = n_rows, byrow = FALSE)
)
# Upper limit of the fill scale used for coherence.
max_coherence <- 1

# Fill scale shared by every panel and by the stand-alone legend, so that all
# of them map colour identically.
coherence_scale <- function() {
  scale_fill_continuous(
    name = "Coherence",
    limits = c(0, max_coherence),
    low = "brown1",
    high = "cornflowerblue"
  )
}

# Alignment plot coloured by coherence; the per-panel legend is hidden because
# a single shared legend is drawn underneath the grid.
plot_coherence <- function(a) {
  plot(a, color_by = "coherence") +
    coherence_scale() +
    theme(legend.position = "none")
}

plots_non_null <- lapply(alignments_non_null, plot_coherence)
plots_null <- lapply(alignments_null, plot_coherence)
plots <- c(plots_non_null, plots_null)
n_panels <- length(plots)

# Four columns of panels filled column by column, plus one extra bottom row
# reserved for the legend. Sizes are derived from the number of plots instead
# of the hard-coded 1:20 / 21; NA pads any unused cells.
n_cols <- 4
n_rows <- ceiling(n_panels / n_cols)
layout_matrix <- matrix(
  c(seq_len(n_panels), rep(NA, n_rows * n_cols - n_panels)),
  ncol = n_cols
)
layout_matrix <- rbind(layout_matrix, n_panels + 1)

# Throwaway plot whose only purpose is to provide the legend. It is built with
# ggplot() because qplot() is deprecated in current ggplot2 releases.
legend_source <- ggplot(
  data.frame(x = 1, y = 1, value = c(0, max_coherence)),
  aes(x = x, y = y, fill = value)
) +
  geom_point() +
  coherence_scale() +
  theme(legend.position = "bottom")

plots[[n_panels + 1]] <- get_legend(legend_source)
grid.arrange(grobs = plots, layout_matrix = layout_matrix)
# Upper limit of the fill scale used for refinement.
max_refinement <- 9

# Alignment plot coloured by refinement on a diverging scale centred at 1; the
# per-panel legend is hidden because one shared legend is drawn below the grid.
plot_refinement <- function(a) {
  plot(a, color_by = "refinement") +
    scale_fill_continuous_divergingx(
      name = "Refinement",
      mid = 1,
      palette = "Roma",
      p1 = .005,
      p3 = .005,
      limits = c(0, max_refinement)
    ) +
    theme(legend.position = "none")
}

plots_non_null <- lapply(alignments_non_null, plot_refinement)
plots_null <- lapply(alignments_null, plot_refinement)
plots <- c(plots_non_null, plots_null)

# Take the shared legend from the first panel with its legend switched back on,
# and append it right after the last panel instead of at a hard-coded slot 21.
legend_grob <- get_legend(plots[[1]] + theme(legend.position = "bottom"))
plots[[length(plots) + 1]] <- legend_grob

# `layout_matrix` is the panel layout built for the coherence figure.
grid.arrange(grobs = plots, layout_matrix = layout_matrix)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.