This notebook reproduces Figure 5, assuming that projections have been generated using bootstrap.Rmd (for all panels in the figure) and saved as CSVs to the derived_dir directory. For example, for a panel of VAE projections for a model with $K = 64$, you should see files like vae-k64-50_parametric.csv and vae-k64-50_nonparametric.csv.

The block below loads relevant libraries and defines a function that is used to match models of comparable complexity across the CNN, VAE, and RCF. E.g., a model with $K = 128$ for the CNN or VAE corresponds to a model of size $K = 1024$ in the RCF.

# Package setup for the notebook.
library(LFBCR) # presumably provides min_theme() and plot_overlay_combined() -- confirm
library(tidyverse) # pipes plus the dplyr / purrr / readr / stringr / ggplot2 calls used below
library(reticulate) # NOTE(review): no reticulate call appears in the code shown here
# Make min_theme() the default ggplot2 theme for every figure drawn afterwards.
theme_set(min_theme())

# Links K's if they correspond to the same overall model complexity.
#
# Args:
#   k: a single model size; one of 32, 64, or 128.
# Returns:
#   A length-2 numeric vector holding k and the K it is paired with
#   (32 -> 256, 64 -> 512, 128 -> 1024).
# Raises:
#   An error when k is not a single, non-missing value in {32, 64, 128}.
expand_k <- function(k) {
  stopifnot(length(k) == 1, !is.na(k))

  # switch() dispatches on strings, so look the size up by its printed value;
  # the trailing NULL is the fall-through for unsupported sizes.
  k_ <- switch(
    as.character(k),
    "32" = c(32, 256),
    "64" = c(64, 512),
    "128" = c(128, 1024),
    NULL
  )

  # Fail loudly on unsupported sizes rather than returning whatever `k_`
  # happens to be visible in an enclosing scope.
  if (is.null(k_)) {
    stop(
      "expand_k(): unsupported k = ", k, "; expected 32, 64, or 128",
      call. = FALSE
    )
  }

  k_
}

The block below reads in projections from each of the subsets and then parses the filenames to figure out which type of subset it was (e.g., is it from a VAE or a CNN?). This is used to arrange the projections onto separate facets.

# Locate the projection CSVs. list.files() takes a regular expression (not a
# glob), so anchor on the literal ".csv" extension.
paths <- list.files(params$derived_dir, "\\.csv$", full.names = TRUE)
if (length(paths) == 0) {
  stop("No .csv projections found in ", params$derived_dir, call. = FALSE)
}

# Stack every file into one data frame and recover the run settings from each
# file name, which is expected to look like <model>-k<K>-<s>_<bootstrap>.csv.
coordinates <- map_dfr(paths, ~ read_csv(.), .id = "path") %>%
  mutate(
    # .id stores each file's position in `paths`; swap it for the actual path
    path = paths[as.integer(path)],
    # Parse only the file name so directory names cannot be mistaken for fields
    model = str_extract(basename(path), "[A-Za-z_]+"),
    k = as.integer(str_extract(basename(path), "(?<=k)[0-9]+")),
    s = as.integer(str_extract(basename(path), "(?<=-)[0-9]+(?=_)")),
    # "_" is matched on purpose: the separator is captured here, stripped below
    bootstrap = str_extract(basename(path), "[A-Za-z_]+\\.csv$"),
    bootstrap = str_remove(bootstrap, "\\.csv$"),
    bootstrap = str_remove(bootstrap, "_"),
    bootstrap = factor(
      bootstrap,
      levels = c("parametric", "nonparametric", "compromise")
    ),
    # Add a small uniform jitter to the nonparametric RCF coordinates
    # (presumably to separate overlapping points -- confirm). %in% keeps the
    # test FALSE, not NA, when bootstrap is missing, so other rows are untouched.
    X1 = ifelse(
      bootstrap %in% "nonparametric" & str_detect(model, "rcf"),
      X1 + runif(n(), -0.01, 0.01),
      X1
    ),
    X2 = ifelse(
      bootstrap %in% "nonparametric" & str_detect(model, "rcf"),
      X2 + runif(n(), -0.01, 0.01),
      X2
    )
  )

The block below generates a set of projection images and places them in the figure_dir directory. We first filter on s and k and facet by model ~ bootstrap, like in Figure 5a. Then, we filter by model and bootstrap and facet by k ~ s, like in Figure 5b.

# Figure 5a-style panels: one image per (s, k) setting, faceted by
# model ~ bootstrap and written to figure_dir as "<s>_<k>_coordinates.png".
for (s_ in c(15, 50, 90)) {
  for (k_ in c(32, 64, 128)) {
    # expand_k() adds the K paired with k_, so matched models share a figure
    cur_data <- coordinates %>%
      filter(s == s_, k %in% expand_k(k_)) %>%
      mutate(model = factor(model, levels = c("cnn", "vae", "rcf")))
    if (nrow(cur_data) > 0) { # check in case current model / data subset is not yet run
      p <- plot_overlay_combined(cur_data) +
        labs(x = "Dimension 1", y = "Dimension 2") +
        facet_wrap(model ~ bootstrap, scales = "free", ncol = 3)
      # Pass the plot explicitly instead of relying on ggplot2's last_plot()
      ggsave(
        file.path(params$figure_dir, str_c(s_, "_", k_, "_", "coordinates.png")),
        plot = p,
        dpi = 600,
        height = 6,
        width = 4
      )
    }
  }
}

# Figure 5b-style panels: one image per (model, bootstrap) pair, faceted by
# k ~ s and written to figure_dir as "<model>_<bootstrap>_coordinates.png".
for (m in c("vae", "cnn", "rcf")) {
  for (bt in c("nonparametric", "parametric", "compromise")) {
    cur_data <- coordinates %>%
      filter(bootstrap == bt, model == m)
    if (nrow(cur_data) > 0) { # check in case current model / data subset is not yet run
      p <- plot_overlay_combined(cur_data) +
        labs(x = "Dimension 1", y = "Dimension 2") +
        facet_wrap(k ~ s, ncol = 3, scales = "free")
      # Pass the plot explicitly instead of relying on ggplot2's last_plot()
      ggsave(
        file.path(params$figure_dir, str_c(m, "_", bt, "_", "coordinates.png")),
        plot = p,
        dpi = 600,
        height = 6,
        width = 4
      )
    }
  }
}


krisrs1128/MSLF documentation built on March 21, 2023, 10:21 a.m.