# Echo chunk source code in the knitted output.
knitr::opts_chunk$set(echo = TRUE)
library(here)
library(tidyverse)
library(glue)
library(ggsci)
library(patchwork)
# Project-local plotting helpers; presumably where theme_nice() (used below)
# is defined -- TODO confirm.
source(here("R", "plot_utils.R"))
library(gridExtra)
library(ggpubr)
library(ggbeeswarm)

# Make the project theme the default for every ggplot built after this point.
theme_set(theme_nice())

Plot predictive performance

Support link here

# Fallback values: `sequencing` is picked up by the mutate() below only when a
# results file has no `sequencing` column of its own.
condition <- "crc"
sequencing <- "16s"

# Every prediction-results file written to output/pred.
files <- list.files(
    path = here("output", "pred"),
    pattern = "results",
    full.names = TRUE,
    include.dirs = FALSE
)

# Read each file, drop its first column, and force `sequencing` to character
# so that the tables can be row-bound.
df_list <- map(files, function(results_path) {
    results <- dplyr::select(read_csv(results_path), -1)
    mutate(results, sequencing = as.character(sequencing))
})

# Display labels used in the figures.
input_labels <- c(
    "picrust2" = "PICRUSt2",
    "trait" = "CBEA trait scores",
    "pathways" = "MetaCyc pathway abundance"
)
statistic_labels <- c("brier" = "Brier Score", "roc_auc" = "AUROC")
condition_labels <- c(
    "crc" = "Colorectal Cancer",
    "ibd" = "Inflammatory Bowel Disease"
)

# One row per (result, statistic), with human-readable labels.
plot_df <- bind_rows(df_list) %>%
    mutate(sequencing = if_else(sequencing == "16", "16s", sequencing)) %>%
    pivot_longer(
        c(brier, roc_auc),
        names_to = "statistic",
        values_to = "values"
    ) %>%
    mutate(input_type = recode(input_type, !!!input_labels)) %>%
    mutate(statistic = recode(statistic, !!!statistic_labels)) %>%
    mutate(condition = recode(condition, !!!condition_labels))


# Build the box-plus-jitter performance panel for one sequencing technology.
#
# perf_df  : long-format performance table (one row per result and statistic)
#            with columns sequencing, input_type, values, statistic, condition.
# seq_type : value of `sequencing` to keep.
# title    : plot title.
#
# Returns a ggplot object faceted by statistic (rows) and condition (columns).
plot_performance <- function(perf_df, seq_type, title) {
    perf_df %>%
        filter(sequencing == .env$seq_type) %>%
        ggplot(aes(x = input_type, y = values,
                   col = input_type, fill = input_type)) +
        facet_grid(statistic ~ condition, scales = "free") +
        # Use the same palette for outlines/points and fills so that one
        # input type has one colour.
        scale_fill_npg() +
        scale_color_npg() +
        # Every observation is drawn by geom_jitter(); suppress the boxplot's
        # own outlier points so they are not plotted twice.
        geom_boxplot(alpha = 0.5, outlier.shape = NA) +
        geom_jitter() +
        labs(x = "Input types", y = "Performance", title = title) +
        # "none" is the documented value (the original used "None").
        theme(legend.position = "none")
}

# The surrounding parentheses print each plot as well as assigning it.
(
    metabar_plt <- plot_performance(plot_df, "16s",
                                    "16S rRNA gene metabarcoding")
)

(
    wgs_plt <- plot_performance(plot_df, "wgs",
                                "Whole genome metagenomics")
)

# Stack the two panels vertically (patchwork).
combo_plot <- metabar_plt / wgs_plt


# Make sure the figure directory exists before writing; ggsave() otherwise
# fails on a fresh checkout where output/figures has not been created yet.
dir.create(here("output", "figures"), recursive = TRUE, showWarnings = FALSE)

# Save the combined performance figure as PNG and as EPS (via cairo).
ggsave(
    filename = here("output", "figures", "pred_performance.png"),
    plot = combo_plot,
    dpi = 300, width = 8, height = 8
)
ggsave(
    filename = here("output", "figures", "pred_performance.eps"),
    plot = combo_plot,
    dpi = 300, width = 8, height = 8, device = cairo_ps
)

Plot feature importances

# Pathway-ID -> description table from the PICRUSt2 repository.
# The file is tab-delimited and read without a header row, so the two columns
# are named here.
annotation <- read_delim(
    file = "https://github.com/picrust/picrust2/raw/master/picrust2/default_files/description_mapfiles/metacyc_pathways_info.txt.gz",
    delim = "\t",
    col_names = FALSE
) %>%
    set_names(c("pathway", "annotation"))

WGS feature importance

# Feature-importance tables written to output/pred.
imp_files <- list.files(here("output", "pred"), pattern = "fimportance",
                        full.names = TRUE)

# Decode each file name, expected to look like
# fimportance_<condition>_<sequencing>_<input>.csv, into one metadata row
# with human-readable labels.
imp_metadata <- map_dfr(imp_files, function(path) {
    d_vec <- basename(path) %>%
        str_remove("fimportance_") %>%
        # fixed(): match a literal ".csv"; as a regex the dot would match any
        # character.
        str_remove(fixed(".csv")) %>%
        str_split("_") %>%
        .[[1]] %>%
        recode(
            "ibd" = "Inflammatory bowel disease",
            "crc" = "Colorectal cancer",
            "picrust2" = "PICRUSt2 predicted abundances",
            "16s" = "16s rRNA metabarcoding",
            "wgs" = "Whole genome metagenomics",
            "trait" = "CBEA trait scores",
            "pathways" = "Pathway abundances"
        )
    tibble(
        condition = d_vec[1],
        seq = d_vec[2],
        input = d_vec[3]
    )
})

# Read every importance table once; each table supplies both the per-feature
# values and the model's performance value.
imp_tables <- map(imp_files, read_csv)

# Reshape one importance table to long format (one row per value and feature)
# and attach the pathway description where one exists.
#
# imp_df : table as read from a fimportance CSV; its first column and the
#          `performance` column are dropped, all remaining columns are features.
#
# Returns a tibble with columns pathway, value, annotation, full_name.
tidy_importance <- function(imp_df) {
    imp_df %>%
        select(-1) %>%
        select(-performance) %>%
        pivot_longer(everything(), names_to = "pathway") %>%
        left_join(annotation, by = "pathway") %>%
        # Append the description only when there is one. The previous
        # paste-then-strip("-NA") approach also removed genuine "-NA"
        # substrings, e.g. in "<ID>-PWY-NAD biosynthesis ...".
        mutate(full_name = if_else(
            is.na(annotation),
            pathway,
            paste(pathway, annotation, sep = "-")
        ))
}

plot_imp_df <- imp_metadata %>%
    mutate(
        path = imp_files,
        data = map(imp_tables, tidy_importance),
        # One performance value per file, kept numeric (map_chr() on a double
        # relies on coercion deprecated in purrr 1.0). Downstream
        # as.numeric() calls remain valid.
        performance = map_dbl(
            imp_tables,
            ~ as.numeric(unique(pull(.x, performance)))
        )
    ) %>%
    unnest(data) %>%
    dplyr::rename("features" = "full_name") %>%
    # Cosmetic clean-up of feature names for axis labels.
    mutate(features = str_replace(features, ";", "; ") %>%
               str_replace_all("_", " "))

wgs_imp <- plot_imp_df %>% filter(seq == "Whole genome metagenomics")

# All (input type, condition) combinations to plot.
# expand_grid() varies its last argument fastest, so `cond` is listed first to
# reproduce the panel order of the deprecated cross_df() (input type changes
# fastest); select() restores the original column order.
plot_grid <- expand_grid(
    cond = c("Inflammatory bowel disease", "Colorectal cancer"),
    ipt = c("CBEA trait scores", "Pathway abundances")
) %>%
    dplyr::select(ipt, cond)

# One horizontal bar chart per combination: mean importance of ten features.
plot_list <- pmap(plot_grid, function(ipt, cond, ...) {
    # .env$ makes explicit that cond/ipt are function arguments, not columns.
    panel_df <- wgs_imp %>%
        filter(condition == .env$cond, input == .env$ipt)
    perf <- panel_df %>% pull(performance) %>% unique() %>% as.numeric()

    # NOTE(review): the ascending sort followed by slice(1:10) keeps the ten
    # features with the LOWEST mean value -- confirm this is intended rather
    # than the ten highest.
    top_feats <- panel_df %>%
        group_by(features) %>%
        summarise(m = mean(value)) %>%
        arrange(m) %>%
        slice(1:10)

    # Wrap long names and fix the bar order to the sorted order above.
    wrapped <- str_wrap(top_feats$features, width = 30)
    top_feats <- top_feats %>%
        mutate(features = factor(wrapped, levels = wrapped))

    # Axis titles are blanked here; shared titles are added when the panels
    # are assembled.
    ggplot(top_feats, aes(x = features, y = m)) +
        geom_col(fill = pal_npg()(1)) +
        coord_flip() +
        labs(
            title = cond,
            subtitle = ipt,
            caption = glue("Test set AUROC: {perf}",
                           perf = round(perf, digits = 2))
        ) +
        theme(legend.position = "none", axis.title = element_blank())
})

# Combine the panels into one patchwork and convert it to a grob so that
# gridExtra can add shared axis titles around it.
gt <- plot_list %>%
    purrr::reduce(`+`) %>%
    patchwork::patchworkGrob()

# Shared axis titles (the individual panels have theirs blanked).
x_axis <- text_grob(
    label = "Features",
    rot = 90,
    face = "bold",
    family = "Source Sans Pro"
)
y_axis <- text_grob(
    label = "Feature importance (10-fold cross-validation)",
    face = "bold",
    family = "Source Sans Pro"
)

# grid.arrange() draws the figure and returns the assembled object.
feat_importance_wgs <- gridExtra::grid.arrange(
    gt,
    left = x_axis,
    bottom = y_axis
)

# Make sure the figure directory exists before writing; ggsave() otherwise
# fails on a fresh checkout where output/figures has not been created yet.
dir.create(here("output", "figures"), recursive = TRUE, showWarnings = FALSE)

# Save the WGS feature-importance figure as PNG and as EPS (via cairo).
ggsave(
    filename = here("output", "figures", "feat_importance_wgs.png"),
    plot = feat_importance_wgs,
    width = 10, height = 11
)
ggsave(
    filename = here("output", "figures", "feat_importance_wgs.eps"),
    plot = feat_importance_wgs,
    width = 10, height = 11, device = cairo_ps
)

16S feature importance

metabar_imp <- plot_imp_df %>% filter(seq == "16s rRNA metabarcoding")

# All (input type, condition) combinations to plot.
# expand_grid() varies its last argument fastest, so `cond` is listed first to
# reproduce the panel order of the deprecated cross_df() (input type changes
# fastest); select() restores the original column order.
plot_grid <- expand_grid(
    cond = c("Inflammatory bowel disease", "Colorectal cancer"),
    ipt = c("CBEA trait scores", "PICRUSt2 predicted abundances")
) %>%
    dplyr::select(ipt, cond)

# One horizontal bar chart per combination: mean importance of ten features.
plots_list <- pmap(plot_grid, function(ipt, cond, ...) {
    # .env$ makes explicit that cond/ipt are function arguments, not columns.
    panel_df <- metabar_imp %>%
        filter(condition == .env$cond, input == .env$ipt)
    perf <- panel_df %>% pull(performance) %>% unique() %>% as.numeric()

    # NOTE(review): the ascending sort followed by slice(1:10) keeps the ten
    # features with the LOWEST mean value -- confirm this is intended rather
    # than the ten highest.
    top_feats <- panel_df %>%
        group_by(features) %>%
        summarise(m = mean(value)) %>%
        arrange(m) %>%
        slice(1:10)

    # Wrap long names and fix the bar order to the sorted order above.
    wrapped <- str_wrap(top_feats$features, width = 30)
    top_feats <- top_feats %>%
        mutate(features = factor(wrapped, levels = wrapped))

    # Axis titles are blanked here; shared titles are added when the panels
    # are assembled.
    ggplot(top_feats, aes(x = features, y = m)) +
        geom_col(fill = pal_npg()(1)) +
        coord_flip() +
        labs(
            title = cond,
            subtitle = ipt,
            caption = glue("Test set AUROC: {perf}",
                           perf = round(perf, digits = 2))
        ) +
        theme(legend.position = "none", axis.title = element_blank())
})

# Combine the panels into one patchwork and convert it to a grob so that
# gridExtra can add shared axis titles around it.
gt <- plots_list %>%
    purrr::reduce(`+`) %>%
    patchwork::patchworkGrob()

# Shared axis titles (the individual panels have theirs blanked).
x_axis <- text_grob(
    label = "Features",
    rot = 90,
    face = "bold",
    family = "Source Sans Pro"
)
y_axis <- text_grob(
    label = "Feature importance (10-fold cross-validation)",
    face = "bold",
    family = "Source Sans Pro"
)

# grid.arrange() draws the figure and returns the assembled object.
feat_importance_16s <- gridExtra::grid.arrange(
    gt,
    left = x_axis,
    bottom = y_axis
)

# Make sure the figure directory exists before writing; ggsave() otherwise
# fails on a fresh checkout where output/figures has not been created yet.
dir.create(here("output", "figures"), recursive = TRUE, showWarnings = FALSE)

# Save the 16S feature-importance figure as PNG and as EPS (via cairo).
ggsave(
    filename = here("output", "figures", "feat_importance_16s.png"),
    plot = feat_importance_16s,
    width = 10, height = 11
)
ggsave(
    filename = here("output", "figures", "feat_importance_16s.eps"),
    plot = feat_importance_16s,
    width = 10, height = 11, device = cairo_ps
)

Copy files

# Mirror the rendered figures into the manuscript repository.
# A destination is only configured for macOS ("Darwin"); on other systems
# path_to is NULL and the copy is skipped, instead of failing with
# "object 'path_to' not found" as before.
path_to <- if (Sys.info()[["sysname"]] == "Darwin") {
    "../../microbe_trait_manuscript/figures/"
} else {
    NULL
}

if (!is.null(path_to) && dir.exists(path_to)) {
    figure_files <- c(
        Sys.glob(here("output", "figures", "*.eps")),
        Sys.glob(here("output", "figures", "*.png"))
    )
    file.copy(figure_files,
              to = path_to,
              recursive = TRUE, overwrite = TRUE)
} else {
    message("Skipping figure copy: manuscript figures directory is not available.")
}


qpmnguyen/microbe_set_trait documentation built on April 11, 2024, 6:07 p.m.