require(foreach) require(inspre) require(dplyr) require(ggplot2)
# Methods whose estimated graphs are scored, in the order used for reporting.
all_methods <- c(
  "inspre", "inspre_DAG", "GIES", "LINGAM",
  "notears", "dotears", "golem", "igsp"
) # , "PC")

# Placeholder metrics for a (method, iteration) pair that has no timing file.
# `mae` is listed because the summaries later in this script average it; without
# it the column would be missing whenever every run in the input is absent.
null_metric <- list(
  precision = NA, recall = NA, F1 = NA, rmse = NA, mae = NA,
  shd = NA, weight_acc = NA, time = NA
)

#' Collect per-method metrics for every result directory matching a glob.
#'
#' The directory's base name is split on "_" and fields 3 to 8 are read as
#' `graph`, `DAG`, `p`, `v`, `beta` and `C`. For each iteration the true graph
#' `G_true_<iter>.txt` is read, and for each entry of `all_methods` the
#' estimate `G_<method>_<iter>.txt` is scored with `inspre::calc_metrics()` if
#' `time_<method>_<iter>.txt` exists; otherwise a row of `null_metric` is used.
#'
#' @param pattern Glob passed to `Sys.glob()` to select result directories.
#' @param n_iter Number of iterations per directory; files are indexed
#'   `1..n_iter`.
#' @return The row-bound metrics with added columns `method`, `time`, `iter`,
#'   `graph`, `DAG`, `confounding`, `density`, `effects`, `instruments`, `p`,
#'   `v`, `beta` and `C`.
get_metrics_all <- function(pattern, n_iter) {
  foreach::foreach(
    dir_name = Sys.glob(pattern), .combine = dplyr::bind_rows
  ) %do% {
    # Parse the directory's own name, so that underscores in a parent
    # directory of `pattern` cannot shift the field positions.
    setting <- strsplit(basename(dir_name), "_")[[1]]
    graph <- setting[3]
    DAG <- setting[4]
    p <- setting[5]
    v <- setting[6]
    beta <- setting[7]
    C <- setting[8]

    # Third argument handed to calc_metrics(); it depends only on the
    # directory's `v` field, so compute it once instead of per method.
    edge_threshold <- if_else(v == "03", 0.3 / 4, 0.15 / 4)

    # seq_len() yields no iterations for n_iter = 0, whereas 1:n_iter would
    # try to read iterations 1 and 0.
    result <- foreach::foreach(
      iter = seq_len(n_iter), .combine = dplyr::bind_rows
    ) %do% {
      G_true <- read.table(
        file.path(dir_name, paste0("G_true_", iter, ".txt"))
      )

      method_metrics <- foreach::foreach(
        method = all_methods, .combine = dplyr::bind_rows
      ) %do% {
        time_file <- file.path(
          dir_name, paste0("time_", method, "_", iter, ".txt")
        )
        if (file.exists(time_file)) {
          G_hat <- read.table(
            file.path(dir_name, paste0("G_", method, "_", iter, ".txt"))
          )
          metrics <- inspre::calc_metrics(G_hat, G_true, edge_threshold)
          metrics$method <- method
          metrics$time <- read.table(time_file)[[1]]
        } else {
          # The timing file doubles as the "run finished" marker; a missing
          # one is recorded as an all-NA row (time is already NA).
          metrics <- null_metric
          metrics$method <- method
        }
        metrics
      }

      method_metrics$iter <- iter
      method_metrics
    }

    if (is.null(result)) {
      # No iterations were requested: contribute nothing for this directory.
      NULL
    } else {
      # NOTE(review): `C` is a character field, so `C > 0` is a string
      # comparison against "0" -- confirm this is the intended test.
      result %>%
        dplyr::mutate(
          graph = graph,
          DAG = DAG,
          confounding = if_else(C > 0, "True", "False"),
          density = if_else(p %in% c("008", "014", "01"), "High", "Low"),
          effects = if_else(v == "03", "Large", "Small"),
          instruments = if_else(startsWith(beta, "2"), "Strong", "Weak"),
          p = p,
          v = v,
          beta = beta,
          C = C
        )
    }
  }
}
# Gather the metrics of every "_rslurm_*" result directory (10 iterations
# each) found in the working directory.
res_all <- get_metrics_all("_rslurm_*", 10)

# Order the methods for tables and plot legends. The levels come from
# `all_methods` so this ordering cannot drift from the list of methods that
# were actually scored.
res_all$method <- factor(res_all$method, levels = all_methods)
# Mean of each metric per simulation setting and method, ignoring missing runs.
res_table <- res_all %>%
  group_by(graph, DAG, confounding, density, effects, instruments, method) %>%
  summarize(
    precision = mean(precision, na.rm = TRUE),
    recall = mean(recall, na.rm = TRUE),
    F1 = mean(F1, na.rm = TRUE),
    # Was `na.rrm=T`: the misspelled argument was silently ignored, so a
    # single missing run turned the whole group's MAE into NA.
    mae = mean(mae, na.rm = TRUE),
    shd = mean(shd, na.rm = TRUE),
    time = mean(time, na.rm = TRUE),
    # Same grouping as the default (drop only `method`), stated explicitly.
    .groups = "drop_last"
  )

# Blank the MAE of GIES and igsp.
# NOTE(review): presumably because these methods do not estimate edge
# weights -- confirm.
# NA_real_ keeps both branches of if_else() the same type as `mae`.
res_table <- res_table %>%
  dplyr::mutate(mae = if_else(method %in% c("GIES", "igsp"), NA_real_, mae))
# Print plain (non-scientific) numbers with four significant digits.
options(scipen = 999)
options(digits = 4)

#' Per-method mean metrics for one confounding / DAG scenario.
#'
#' @param results Metrics table with columns `confounding`, `DAG`, `method`,
#'   `precision`, `recall`, `F1`, `mae`, `shd` and `time`.
#' @param confounding_value Value of `confounding` to keep ("True"/"False").
#' @param dag_value Value of `DAG` to keep ("TRUE"/"FALSE").
#' @return One row per method; precision, recall and F1 are scaled by 100,
#'   all means ignore NA.
summarize_by_method <- function(results, confounding_value, dag_value) {
  results %>%
    filter(confounding == confounding_value, DAG == dag_value) %>%
    group_by(method) %>%
    summarize(
      precision = 100 * mean(precision, na.rm = TRUE),
      recall = 100 * mean(recall, na.rm = TRUE),
      F1 = 100 * mean(F1, na.rm = TRUE),
      mae = mean(mae, na.rm = TRUE),
      shd = mean(shd, na.rm = TRUE),
      time = mean(time, na.rm = TRUE)
    )
}

# The four scenario tables, in the original order.
summarize_by_method(res_all, "False", "FALSE")
summarize_by_method(res_all, "True", "FALSE")
summarize_by_method(res_all, "False", "TRUE")
summarize_by_method(res_all, "True", "TRUE")
# Strip labels for the facet columns (instruments) and rows (graph).
facet1_names <- c(Strong = "Strong instruments", Weak = "Weak instruments")
facet2_names <- c(random = "Random graphs", scalefree = "Scalefree graphs")

# SHD against run time (log scale) per method, for the rows with
# DAG == "FALSE", confounding == "True", high density and large effects,
# with one panel per graph type x instrument strength.
ggplot(
  data = dplyr::filter(
    res_all,
    DAG == "FALSE",
    confounding == "True",
    density == "High",
    effects == "Large"
  ),
  mapping = aes(x = shd, y = time, colour = method)
) +
  geom_point() +
  facet_grid(
    graph ~ instruments,
    labeller = labeller(instruments = facet1_names, graph = facet2_names)
  ) +
  scale_y_log10(labels = scales::comma) +
  scale_color_brewer(palette = "Paired") +
  labs(
    title = "High density graphs with cycles, confounding and large edge weights",
    x = "Structural hamming distance",
    y = "Time (S)",
    color = "Method"
  ) +
  theme_bw(base_size = 9) +
  theme(
    plot.title = element_text(hjust = 0.5),
    # axis.title = element_text(size=9),
    # axis.text = element_text(size=8),
    legend.title = element_text(size = 9),
    legend.spacing = unit(0, "cm"),
    legend.key.size = unit(0.35, "cm")
  )
# SHD against run time (log scale) per method: random graphs with
# DAG == "FALSE", confounding, high density, large effects and strong
# instruments.
ggplot(
  data = dplyr::filter(
    res_all,
    graph == "random",
    DAG == "FALSE",
    confounding == "True",
    density == "High",
    effects == "Large",
    instruments == "Strong"
  ),
  mapping = aes(x = shd, y = time, colour = method)
) +
  geom_point() +
  xlim(0, 900) +
  scale_y_log10(labels = scales::comma) +
  scale_color_brewer(palette = "Paired") +
  labs(
    title = "Random, confounded, cyclic graphs.\nHigh density, large effects, strong instruments",
    x = "Structural hamming distance",
    y = "Time (S)"
  ) +
  theme_classic() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 13),
    # Horizontal y-axis title, vertically centred.
    axis.title.y = element_text(angle = 0, vjust = 0.5)
  )
# SHD against run time (log scale) per method: random graphs with
# DAG == "FALSE", confounding, high density, large effects and weak
# instruments.
ggplot(
  data = dplyr::filter(
    res_all,
    graph == "random",
    DAG == "FALSE",
    confounding == "True",
    density == "High",
    effects == "Large",
    instruments == "Weak"
  ),
  mapping = aes(x = shd, y = time, colour = method)
) +
  geom_point() +
  xlim(0, 900) +
  scale_y_log10(labels = scales::comma) +
  scale_color_brewer(palette = "Paired") +
  labs(
    title = "Random, confounded, cyclic graphs.\nHigh density, large effects, weak instruments",
    x = "Structural hamming distance",
    y = "Time (S)"
  ) +
  theme_classic() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 13),
    # Horizontal y-axis title, vertically centred.
    axis.title.y = element_text(angle = 0, vjust = 0.5)
  )
# SHD against run time (log scale) per method: scalefree graphs with
# DAG == "FALSE", confounding, high density, large effects and strong
# instruments.
ggplot(
  data = dplyr::filter(
    res_all,
    graph == "scalefree",
    DAG == "FALSE",
    confounding == "True",
    density == "High",
    effects == "Large",
    instruments == "Strong"
  ),
  mapping = aes(x = shd, y = time, colour = method)
) +
  geom_point() +
  xlim(0, 900) +
  scale_y_log10(labels = scales::comma) +
  scale_color_brewer(palette = "Paired") +
  labs(
    title = "Scalefree, confounded, cyclic graphs.\nHigh density, large effects, strong instruments",
    x = "Structural hamming distance",
    y = "Time (S)"
  ) +
  theme_classic() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 13),
    # Horizontal y-axis title, vertically centred.
    axis.title.y = element_text(angle = 0, vjust = 0.5)
  )
# SHD against run time (log scale) per method: scalefree graphs with
# DAG == "TRUE", no confounding, low density, small effects and weak
# instruments.
# The title previously repeated the "confounded, cyclic, high density, large
# effects" wording of the other plots, which contradicted this filter; it now
# describes the rows that are actually drawn.
res_all %>%
  filter(
    graph == "scalefree",
    DAG == "TRUE",
    confounding == "False",
    density == "Low",
    effects == "Small",
    instruments == "Weak"
  ) %>%
  ggplot(aes(shd, time, color = method)) +
  geom_point() +
  scale_y_log10(labels = scales::comma) +
  labs(
    title = "Scalefree, unconfounded, acyclic graphs.\nLow density, small effects, weak instruments",
    x = "Structural hamming distance",
    y = "Time (S)"
  ) +
  xlim(0, 250) +
  theme_classic() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 13),
    # Horizontal y-axis title, vertically centred.
    axis.title.y = element_text(angle = 0, vjust = 0.5)
  ) +
  scale_colour_brewer(palette = "Paired")
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.