require(foreach)
require(inspre)
require(dplyr)
require(ggplot2)
# Methods whose output files are looked up in every simulation directory.
# The order also defines the legend / factor order used further down.
all_methods <- c("inspre", "inspre_DAG", "GIES", "LINGAM", "notears", "dotears", "golem", "igsp") #, "PC")

# Placeholder row used when a method produced no output for an iteration.
# It lists every metric column the summaries below consume (including `mae`),
# so the combined table has a stable schema even if a method never finished.
# Plain (logical) NA is used on purpose: it combines with any column type.
null_metric <- list(
  precision = NA,
  recall = NA,
  F1 = NA,
  rmse = NA,
  mae = NA,
  shd = NA,
  weight_acc = NA,
  time = NA
)
#' Collect accuracy and runtime metrics for every matching simulation directory.
#'
#' For each directory matched by `pattern`, the simulation setting is parsed
#' from the directory name (fields 3-8 after splitting on "_": graph, DAG, p,
#' v, beta, C). For each iteration `1..n_iter` the true graph
#' `G_true_<iter>.txt` is read, and for every method in `all_methods` the
#' estimate `G_<method>_<iter>.txt` is scored with `inspre::calc_metrics()`.
#' A method counts as "run" when its `time_<method>_<iter>.txt` file exists;
#' otherwise a row of NA metrics (`null_metric`) is emitted for it.
#'
#' @param pattern Glob pattern handed to `Sys.glob()` to find run directories.
#' @param n_iter Number of iterations per directory (a single number >= 1).
#' @return The row-bound metrics of all directories, iterations and methods,
#'   with the columns `method`, `time`, `iter`, `graph`, `DAG`, `confounding`,
#'   `density`, `effects`, `instruments`, `p`, `v`, `beta` and `C` added.
get_metrics_all <- function(pattern, n_iter){
  # `1:n_iter` would silently iterate over c(1, 0) for n_iter = 0, so reject
  # anything that is not a single positive number up front.
  stopifnot(is.numeric(n_iter), length(n_iter) == 1, !is.na(n_iter), n_iter >= 1)

  # Bind the operator locally so the function works without foreach attached.
  `%do%` <- foreach::`%do%`

  res_all <- foreach::foreach(dir_name = Sys.glob(pattern), .combine = dplyr::bind_rows) %do% {
    # Parse the directory's own name, not the full path: underscores in a
    # parent directory would otherwise shift every field.
    setting <- strsplit(basename(dir_name), "_")[[1]]
    graph <- setting[3]
    DAG <- setting[4]
    p <- setting[5]
    v <- setting[6]
    beta <- setting[7]
    C <- setting[8]

    # Third argument of calc_metrics(); it only depends on `v`, so compute it
    # once per directory.
    # NOTE(review): presumably a quarter of the simulated effect size
    # (v == "03" -> 0.3, otherwise 0.15) - confirm.
    edge_threshold <- dplyr::if_else(v == "03", 0.3/4, 0.15/4)

    result <- foreach::foreach(iter = seq_len(n_iter), .combine = dplyr::bind_rows) %do% {
      G_true <- read.table(file.path(dir_name, paste0("G_true_", iter, ".txt")))
      method_metrics <- foreach::foreach(method = all_methods, .combine = dplyr::bind_rows) %do% {
        time_file <- file.path(dir_name, paste0("time_", method, "_", iter, ".txt"))
        if (file.exists(time_file)) {
          G_hat <- read.table(file.path(dir_name, paste0("G_", method, "_", iter, ".txt")))
          metrics <- inspre::calc_metrics(G_hat, G_true, edge_threshold)
          metrics$method <- method
          metrics$time <- read.table(time_file)[[1]]
        } else {
          # No timing file: the method did not produce a result for this run.
          metrics <- null_metric
          metrics$method <- method
          metrics$time <- NA
        }
        metrics
      }
      method_metrics$iter <- iter
      method_metrics
    }

    # Attach the setting, both as readable labels and as the raw name fields.
    # NOTE(review): `C` is a character field, so `C > 0` is a string
    # comparison against "0"; confirm this matches how C is encoded in the
    # directory names.
    dplyr::mutate(
      result,
      graph = graph,
      DAG = DAG,
      confounding = dplyr::if_else(C > 0, "True", "False"),
      density = dplyr::if_else(p %in% c("008", "014", "01"), "High", "Low"),
      effects = dplyr::if_else(v == "03", "Large", "Small"),
      instruments = dplyr::if_else(startsWith(beta, "2"), "Strong", "Weak"),
      p = p, v = v, beta = beta, C = C
    )
  }
  res_all
}
# Score ten iterations of every `_rslurm_*` run directory found in the
# working directory.
res_all <- get_metrics_all("_rslurm_*", 10)

# Fix the method order (used for legends and table rows) to the canonical
# order in `all_methods`.
res_all$method <- factor(res_all$method, levels = all_methods)

# Per-setting, per-method averages over iterations. Runs in which a method
# produced no output carry NA metrics and are excluded from each mean.
res_table <- res_all %>%
  group_by(graph, DAG, confounding, density, effects, instruments, method) %>%
  summarize(
    precision = mean(precision, na.rm = TRUE),
    recall = mean(recall, na.rm = TRUE),
    F1 = mean(F1, na.rm = TRUE),
    # The argument must be spelled `na.rm`: mean() silently ignores unknown
    # arguments, which would leave NAs in and turn the whole mean into NA.
    mae = mean(mae, na.rm = TRUE),
    shd = mean(shd, na.rm = TRUE),
    time = mean(time, na.rm = TRUE)
  )

# Blank out the weight error for GIES and igsp.
# NOTE(review): presumably because these methods estimate structure only -
# confirm. `NA_real_` keeps the replacement the same type as `mae`, which
# if_else() requires on older dplyr versions.
res_table <- res_table %>%
  dplyr::mutate(mae = if_else(method %in% c("GIES", "igsp"), NA_real_, mae))

# Print plain (non-scientific) numbers with four significant digits.
options(scipen = 999)
options(digits = 4)
# Average every metric per method for one confounding / DAG combination.
# Precision, recall and F1 are reported as percentages; runs without a
# result (NA metrics) are left out of each mean.
summarize_by_method <- function(metrics, confounded, acyclic) {
  metrics %>%
    filter(confounding == confounded, DAG == acyclic) %>%
    group_by(method) %>%
    summarize(
      precision = 100 * mean(precision, na.rm = TRUE),
      recall = 100 * mean(recall, na.rm = TRUE),
      F1 = 100 * mean(F1, na.rm = TRUE),
      mae = mean(mae, na.rm = TRUE),
      shd = mean(shd, na.rm = TRUE),
      time = mean(time, na.rm = TRUE)
    )
}

# Cyclic graphs, no confounding.
summarize_by_method(res_all, "False", "FALSE")

# Cyclic graphs, with confounding.
summarize_by_method(res_all, "True", "FALSE")

# Acyclic graphs (DAGs), no confounding.
summarize_by_method(res_all, "False", "TRUE")

# Acyclic graphs (DAGs), with confounding.
summarize_by_method(res_all, "True", "TRUE")
# Facet strip labels: raw column value -> text shown in the figure.
facet1_names <- c(Strong = "Strong instruments", Weak = "Weak instruments")
facet2_names <- c(random = "Random graphs", scalefree = "Scalefree graphs")

# Structural Hamming distance against runtime (log scale) for the cyclic,
# confounded, high-density, large-effect settings; one panel per
# graph type (rows) and instrument strength (columns).
res_all %>%
  filter(
    DAG == "FALSE",
    confounding == "True",
    density == "High",
    effects == "Large"
  ) %>%
  ggplot(aes(shd, time, color = method)) +
  geom_point() +
  scale_y_log10(labels = scales::comma) +
  labs(
    title = "High density graphs with cycles, confounding and large edge weights",
    x = "Structural hamming distance",
    y = "Time (S)",
    color = "Method"
  ) +
  theme_bw(base_size = 9) +
  theme(
    plot.title = element_text(hjust = 0.5),
    # Axis title / text size overrides (9 / 8) are intentionally disabled:
    # axis.title = element_text(size=9),
    # axis.text = element_text(size=8),
    legend.title = element_text(size = 9),
    legend.spacing = unit(0, "cm"),
    legend.key.size = unit(0.35, "cm")
  ) +
  scale_colour_brewer(palette = "Paired") +
  facet_grid(
    graph ~ instruments,
    labeller = labeller(instruments = facet1_names, graph = facet2_names)
  )
# Single-panel SHD-vs-runtime scatter for one graph type and instrument
# strength within the cyclic, confounded, high-density, large-effect runs.
# Runtime is drawn on a log scale and the x axis is fixed to 0-900.
plot_dense_cyclic_confounded <- function(graph_type, instrument_strength, plot_title) {
  panel_runs <- res_all %>%
    filter(
      graph == graph_type,
      DAG == "FALSE",
      confounding == "True",
      density == "High",
      effects == "Large",
      instruments == instrument_strength
    )

  ggplot(panel_runs, aes(shd, time, color = method)) +
    geom_point() +
    scale_y_log10(labels = scales::comma) +
    labs(title = plot_title, x = "Structural hamming distance", y = "Time (S)") +
    xlim(0, 900) +
    theme_classic() +
    theme(
      plot.title = element_text(hjust = 0.5, size = 13),
      # Keep the y-axis title horizontal and vertically centred.
      axis.title.y = element_text(angle = 0, vjust = 0.5)
    ) +
    scale_colour_brewer(palette = "Paired")
}

plot_dense_cyclic_confounded(
  "random", "Strong",
  "Random, confounded, cyclic graphs.\nHigh density, large effects, strong instruments"
)
plot_dense_cyclic_confounded(
  "random", "Weak",
  "Random, confounded, cyclic graphs.\nHigh density, large effects, weak instruments"
)
plot_dense_cyclic_confounded(
  "scalefree", "Strong",
  "Scalefree, confounded, cyclic graphs.\nHigh density, large effects, strong instruments"
)
# SHD against runtime (log scale) for scale-free DAGs without confounding,
# at low density, with small effects and weak instruments. The title is
# written to describe exactly the subset selected by the filter below.
res_all %>%
  filter(
    graph == "scalefree",
    DAG == "TRUE",
    confounding == "False",
    density == "Low",
    effects == "Small",
    instruments == "Weak"
  ) %>%
  ggplot(aes(shd, time, color = method)) +
  geom_point() +
  scale_y_log10(labels = scales::comma) +
  labs(
    title = "Scalefree, unconfounded, acyclic graphs.\nLow density, small effects, weak instruments",
    x = "Structural hamming distance",
    y = "Time (S)"
  ) +
  xlim(0, 250) +
  theme_classic() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 13),
    # Keep the y-axis title horizontal and vertically centred.
    axis.title.y = element_text(angle = 0, vjust = 0.5)
  ) +
  scale_colour_brewer(palette = "Paired")


# brielin/inspre documentation built on Dec. 3, 2023, 4:55 a.m.