# ---- Setup -----------------------------------------------------------------

# Seed the RNG before anything else so that the random steps later in the
# script (e.g. the fold assignment done by cv.glmnet) are repeatable.
set.seed(1234)

# Packages, attached in this order.
library(glmnet)
library(LFBCR)
library(tidyverse)
library(reticulate)

# Handle to Python's numpy module (via reticulate); used to read .npy files.
np <- import("numpy")

# Default ggplot2 theme for every plot drawn below.
# NOTE(review): min_theme() is presumably exported by LFBCR -- confirm.
theme_set(min_theme())
# ---- Load learned features and sample metadata -----------------------------

# `params` must already exist when this script runs.
# NOTE(review): presumably the parameter list of a parameterised R Markdown
# report -- confirm. Only `input_dir` is read from it, so that field is
# extracted explicitly rather than attach()-ing the whole list onto the
# search path.
input_dir <- params$input_dir

# list.files() treats `pattern` as a regular expression, not a shell glob, so
# the wildcard pattern is translated with glob2rx() before use.
layer_prefix <- glob2rx("*_best.npy")
z_paths <- list.files(
  input_dir, layer_prefix,
  recursive = TRUE, full.names = TRUE
)
if (length(z_paths) == 0) {
  stop("No '*_best.npy' files found under ", input_dir, call. = FALSE)
}

# One array per file; drop() removes any length-one dimensions.
Zb <- map(z_paths, function(path) drop(np$load(path)))

# Response / covariate table: the first file whose name ends in "Xy.csv".
xy_paths <- list.files(
  input_dir, "Xy\\.csv$",
  recursive = TRUE, full.names = TRUE
)
if (length(xy_paths) == 0) {
  stop("No 'Xy.csv' file found under ", input_dir, call. = FALSE)
}
Xy <- read_csv(xy_paths[1])

# Sample index table: the first file whose name contains "subset".
subset_paths <- list.files(
  input_dir, glob2rx("*subset*"),
  recursive = TRUE, full.names = TRUE
)
if (length(subset_paths) == 0) {
  stop("No '*subset*' file found under ", input_dir, call. = FALSE)
}

# The first column is renamed to `ix` by position: readr calls an unnamed
# first column `X1` (readr < 2) or `...1` (readr >= 2), so renaming by name
# would depend on the installed version.
# NOTE(review): left_join() has no `by`, so it joins on every column the two
# tables share -- confirm that is the intended key.
ix <- subset_paths[[1]] %>%
  read_csv() %>%
  rename(ix = 1) %>%
  left_join(Xy)
# ---- Fit a cross-validated glmnet on each feature set and score it ---------

# Pull the feature matrix (columns whose names start with "V") and the
# response `y` for the rows of `features` belonging to `split_name`.
extract_split <- function(features, split_name) {
  rows <- filter(features, split == split_name)
  list(
    x = as.matrix(select(rows, starts_with("V"))),
    y = pull(rows, y)
  )
}

# Mean squared error of `fit` on one split. predict() is called without `s`,
# so the package default penalty is used (lambda.1se per the glmnet docs).
mean_squared_error <- function(fit, xy) {
  mean((predict(fit, xy$x) - xy$y)^2)
}

# Sample ids are the same for every feature file, so look them up once.
sample_ids <- ix$ix

performances <- vector("list", length(z_paths))
for (i in seq_along(z_paths)) {
  # Tag each feature row with its sample id, then bring in `split`, `y` and
  # the remaining metadata. The join key is stated explicitly so the join
  # cannot silently pick up additional shared columns.
  # NOTE(review): assumes the rows of each array are in the same order as
  # the rows of `ix` -- confirm against the code that wrote the .npy files.
  features <- as.data.frame(Zb[[i]]) %>%
    mutate(ix = sample_ids) %>%
    left_join(ix, by = "ix")

  train <- extract_split(features, "train")
  test <- extract_split(features, "test")

  # Fit on the training rows only; the test rows are used purely for scoring.
  fit <- cv.glmnet(train$x, train$y)

  # The mse values are left unnamed: names on this vector would become the
  # data frame's row names.
  performances[[i]] <- data.frame(
    path = z_paths[i],
    split = c("train", "test"),
    mse = c(mean_squared_error(fit, train), mean_squared_error(fit, test))
  )
}
# ---- Tidy the per-run results and build the baseline table -----------------

# Stack the per-file results into one data frame with default row names.
performances <- bind_rows(performances)
rownames(performances) <- NULL

# Model codes as they appear in the paths, and the labels shown in the plot.
model_levels <- c("cnn", "vae", "rcf")
model_labels <- c("CNN", "VAE", "RCF")

# Recover run metadata from each file path:
#   k     - digits following the first "k" + digits match, as an integer
#   b     - the first all-digit path component
#   model - the letters following "tnbc_"
performances <- performances %>%
  mutate(
    k = str_extract(path, "k[0-9]+"),
    k = as.integer(str_remove(k, "k")),
    b = str_extract(path, "/[0-9]+/") %>% str_remove_all("/"),
    # Letters only: the range `[A-z]` would also match "_" and other
    # punctuation lying between "Z" and "a", swallowing a trailing "_k...".
    model = str_extract(path, "tnbc_[A-Za-z]+") %>% str_remove("tnbc_"),
    model = factor(model, levels = model_levels, labels = model_labels),
    split = factor(split, levels = c("train", "test"))
  )

# Reference errors, repeated once per model so they appear in every facet.
# The two values come from notebooks/tnbc_baseline.ipynb.
baseline <- data.frame(
  split = rep(c("train", "test"), times = length(model_labels)),
  mse = rep(
    c(1.0976247318043064, 1.9567960844047563),
    times = length(model_labels)
  ),
  k = NA,
  b = NA
)
baseline$model <- factor(
  rep(model_labels, each = 2),
  levels = model_labels
)
# ---- Plot train vs. test error for every run, one facet per model ----------

# Coloured points and faint lines: one line per (b, k) run joining its train
# and test MSE. Thick black line and points: the baseline values.
baseline_plot <- ggplot(
  performances,
  aes(x = split, y = mse, colour = as.factor(k))
) +
  geom_point(alpha = 0.8) +
  geom_line(aes(group = interaction(b, k)), alpha = 0.2) +
  geom_line(data = baseline, colour = "black", group = 1, size = 1.5) +
  geom_point(data = baseline, colour = "black", size = 3) +
  facet_wrap(~ model) +
  scale_colour_brewer(palette = "Set1") +
  labs(
    x = "Data Split",
    y = "Mean Squared Error",
    colour = "K"
  ) +
  theme(
    strip.text = element_text(size = 14),
    legend.position = "right"
  )

# Display the figure, then write the same figure to disk.
baseline_plot
ggsave(
  "~/Desktop/tnbc_baseline.png",
  plot = baseline_plot,
  dpi = 200,
  width = 8,
  height = 3.7
)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.