# Setup ----
# Fix the RNG seed so that random steps later in the script (for example the
# fold assignment done by cv.glmnet()) are repeatable between runs.
set.seed(1234)

library(glmnet)
library(LFBCR)
library(tidyverse)
library(reticulate)

# NumPy, loaded through reticulate, is what reads the `.npy` arrays below.
np <- import("numpy")

# min_theme() is not defined in this file; presumably it comes from LFBCR.
theme_set(min_theme())

# `params` must already exist when this script runs (e.g. the parameter list of
# an R Markdown document). Only `input_dir` is read by this script, so it is
# extracted explicitly instead of attach()-ing the whole list onto the search
# path, where its entries could silently mask or be masked by other objects.
# `[[` is used rather than `$` to avoid partial name matching.
input_dir <- params[["input_dir"]]
stopifnot(is.character(input_dir), length(input_dir) == 1L)
# Data loading ----
# list.files() interprets `pattern` as a regular expression matched against
# file names, not as a shell glob. The patterns are therefore written as real
# regexes: the dot is escaped and the extension is anchored to the end of the
# name, so that e.g. "x_best.npy.bak" or "Xy_csv" are not picked up.
layer_prefix <- "_best\\.npy$"
z_paths <- list.files(
  input_dir, layer_prefix,
  recursive = TRUE, full.names = TRUE
)
xy_paths <- list.files(
  input_dir, "Xy\\.csv$",
  recursive = TRUE, full.names = TRUE
)
subset_paths <- list.files(
  input_dir, "subset",
  recursive = TRUE, full.names = TRUE
)

# Fail early with a readable message instead of an obscure error from
# read_csv(NA) or an out-of-bounds subscript further down.
if (length(z_paths) == 0) {
  stop("No '*_best.npy' files found under ", input_dir, call. = FALSE)
}
if (length(xy_paths) == 0) {
  stop("No 'Xy.csv' file found under ", input_dir, call. = FALSE)
}
if (length(subset_paths) == 0) {
  stop("No '*subset*' file found under ", input_dir, call. = FALSE)
}

# One array per matched file; drop() removes any length-one dimensions.
Zb <- map(z_paths, ~ drop(np$load(.)))

# If several files match, only the first of each kind is used.
Xy <- read_csv(xy_paths[1])

# Index table: the first column of the subset file becomes `ix`, and the
# columns of `Xy` are attached via a natural join on the shared column names.
# The first column is renamed by position because its auto-generated name
# depends on the readr version ("X1" before readr 2.0, "...1" afterwards).
ix <- read_csv(subset_paths[1]) %>%
  rename(ix = 1) %>%
  left_join(Xy)

# Filled in by the loop over `z_paths` below.
performances <- list()

# Per-file regression ----
# For every loaded array: attach the index table, fit a cross-validated glmnet
# model of `y` on the "V*" columns of the training rows, and record the mean
# squared error on both the training and the test rows.

# Mean squared error of `fit`'s predictions for design matrix `x` against `y`.
# predict() is called with its defaults, exactly as a bare predict(fit, x);
# the result is flattened to a plain vector before subtracting.
prediction_mse <- function(fit, x, y) {
  y_hat <- as.numeric(predict(fit, x))
  mean((y_hat - y)^2)
}

performances <- vector("list", length(z_paths))
for (i in seq_along(z_paths)) {
  features <- as.data.frame(Zb[[i]])

  # Rows of the array are matched to the index table purely by position, so
  # the two must have the same number of rows.
  stopifnot(nrow(features) == nrow(ix))
  features$ix <- ix$ix

  # Join on the `ix` key explicitly rather than on whatever columns happen to
  # share a name.
  joined <- left_join(features, ix, by = "ix")

  # Split once and reuse, instead of re-filtering for x and y separately.
  train_rows <- filter(joined, split == "train")
  test_rows <- filter(joined, split == "test")

  x_train <- as.matrix(select(train_rows, starts_with("V")))
  y_train <- pull(train_rows, y)
  x_test <- as.matrix(select(test_rows, starts_with("V")))
  y_test <- pull(test_rows, y)

  fit <- cv.glmnet(x_train, y_train)

  # Two rows per file (train, test); `path` is recycled across both. The mse
  # vector is left unnamed so no partial row names leak into the data frame.
  performances[[i]] <- data.frame(
    path = z_paths[i],
    split = c("train", "test"),
    mse = c(
      prediction_mse(fit, x_train, y_train),
      prediction_mse(fit, x_test, y_test)
    )
  )
}
# Tidy the results ----
# Stack the per-file results into one table and derive plotting variables
# from pieces of each file path:
#   k     - the digits following the first "k<digits>" in the path
#   b     - the first path component that consists only of digits
#   model - the letters following "tnbc_", recoded to upper-case labels
performances <- bind_rows(performances)
rownames(performances) <- NULL
performances <- performances %>%
  mutate(
    k = as.integer(str_remove(str_extract(path, "k[0-9]+"), "k")),
    b = str_remove_all(str_extract(path, "/[0-9]+/"), "/"),
    # [A-Za-z], not [A-z]: the range A-z also covers "[", "\", "]", "^", "_"
    # and "`", so a name such as "tnbc_cnn_k64" would be captured as "cnn_k"
    # and turn into NA when converted to a factor.
    model = str_remove(str_extract(path, "tnbc_[A-Za-z]+"), "tnbc_"),
    # Values outside these three levels become NA.
    model = factor(
      model,
      levels = c("cnn", "vae", "rcf"),
      labels = c("CNN", "VAE", "RCF")
    ),
    # Fix the level order so "train" comes before "test".
    split = factor(split, levels = c("train", "test"))
  )

# Baseline reference ----
# Fixed train/test MSE values for the baseline, repeated once per model so
# the same reference can be drawn in every facet of the plot.
model_labels <- c("CNN", "VAE", "RCF")
baseline_mse <- c(
  train = 1.0976247318043064,
  test = 1.9567960844047563
) # from notebooks/tnbc_baseline.ipynb

baseline <- data.frame(
  split = names(baseline_mse),
  mse = unname(baseline_mse),
  k = NA,
  b = NA
)

# Stack the two rows three times (train, test, train, test, ...) and label
# each consecutive pair with one model.
baseline <- baseline[rep(seq_len(nrow(baseline)), times = length(model_labels)), ]
baseline$model <- factor(rep(model_labels, each = 2), levels = model_labels)
# Figure ----
# Train vs. test MSE for each fitted model, one facet per model type and one
# colour per value of k, with the baseline drawn in black on top.

# Per-fit layers: a point per (fit, split) and a faint line joining the train
# and test points that share the same `b` and `k`.
fit_layers <- list(
  geom_point(alpha = 0.8),
  geom_line(aes(group = interaction(b, k)), alpha = 0.2)
)

# Baseline layers, drawn from the separate `baseline` table.
baseline_layers <- list(
  geom_line(data = baseline, col = "black", group = 1, size = 1.5),
  geom_point(data = baseline, col = "black", size = 3)
)

performance_plot <-
  ggplot(performances, aes(x = split, y = mse, colour = as.factor(k))) +
  fit_layers +
  baseline_layers +
  facet_wrap(~ model) +
  scale_color_brewer(palette = "Set1") +
  labs(x = "Data Split", y = "Mean Squared Error", color = "K") +
  theme(
    strip.text = element_text(size = 14),
    legend.position = "right"
  )

# A bare top-level expression, so the plot is displayed wherever top-level
# values are auto-printed.
performance_plot

# Write the figure to disk; the plot is passed explicitly.
ggsave(
  "~/Desktop/tnbc_baseline.png",
  plot = performance_plot,
  dpi = 200,
  width = 8,
  height = 3.7
)


# krisrs1128/MSLF documentation built on March 21, 2023, 10:21 a.m.