Objective

Evaluate a collection of trained fracture models, each built from a different predictor set, on the held-out test cohort.

library(knitr); library(pander); library(tibble); library(caret)
library(plyr)       # mapvalues(); loaded before dplyr so dplyr verbs are not masked
library(tidyverse)  # dplyr, purrr, tidyr, stringr, forcats, ggplot2
library(magrittr)   # %<>% assignment pipe
library(AnalysisToolkit)
devtools::load_all()

knitr::opts_chunk$set(comment="#>",
                      fig.show='hold', fig.align="center", fig.height=8, fig.width=8,
                      message=FALSE,
                      warning=FALSE, cache=FALSE, rownames.print=FALSE)

ggplot2::theme_set(vizR::theme_())

set.seed(123)

Report Inputs

# All model products contained in multimodal directory
kDIR_MULTIMODAL <- hips::Fp_ml_dir("multimodal")

FpMultimodal <- function(s_chr) {
    MyUtils::fp_mostRecent(file.path(kDIR_MULTIMODAL, s_chr), verbose = TRUE)
}

fp_models <- FpMultimodal("trained_models.rds")
FUNC_load_models <- readRDS

fp_cohort <- FpMultimodal("cohorts.Rdata")
FUNC_load_cohort <- load

# Preconditions
stopifnot(all(map_lgl(.x = c(fp_models, fp_cohort), .f = file.exists)))

Load Models

models <- FUNC_load_models(fp_models)
FUNC_load_cohort(fp_cohort)
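
Because load() assigns the cohort objects directly into the calling environment, a quick existence check guards against a renamed or incomplete cohort file. This is a minimal sketch; `test_df` is the only object name assumed here, based on its use in the inference chunk below.

# Sanity check (sketch): the cohort .Rdata is expected to supply the held-out test set
stopifnot(exists("test_df"), is.data.frame(test_df))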

Evaluation

Inference

# Model inference ----
# `test_df` (held-out test set) and its outcome column `fx` are supplied by the
# cohort .Rdata loaded above via FUNC_load_cohort().
pY_lst <- map(models, predict_pY, newdata = test_df)
Y <- test_df[["fx"]]
cCs <- imap(pY_lst, ~ClassifierCurve(pY = .x, Y = Y, id = .y))
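
For reference, if the elements of `models` are caret::train fits (an assumption; predict_pY is a package helper whose internals are not shown in this report), the score extraction is roughly equivalent to pulling the positive-class probability from predict(). A sketch:

# Sketch only: extract the positive-class probability from a caret fit.
# The positive-class label "fracture" is a placeholder, not taken from this report.
PredictPySketch <- function(fit, newdata, positive = "fracture") {
    predict(fit, newdata = newdata, type = "prob")[[positive]]
}
# pY_lst_alt <- map(models, PredictPySketch, newdata = test_df)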

Tables

pretty_perf <- cCs %>% 
    AnalysisToolkit::glance_pretty() %>% 
    arrange(desc(auc)) %>% 
    mutate(Classifier = mapvalues(Classifier, 
                                  from = names(AES_PREDICTOR_LABELS),
                                  to = unname(AES_PREDICTOR_LABELS))) %>% 
    mapnames("Classifier", "Predictor Set")

pretty_perf %>% 
    kable(caption = "Performance of fracture models with various predictor sets")

Tbl(pretty_perf, bn = "SuppTable4_MultimodalPerf")
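
The mapvalues() call above swaps the internal predictor-set codes for the display labels stored in the named vector AES_PREDICTOR_LABELS. An equivalent lookup with dplyr::recode() (a sketch, not part of the original pipeline) splices the same vector as old-value/new-value pairs:

# Sketch: relabel predictor-set codes using the same named vector;
# recode() treats the names of AES_PREDICTOR_LABELS as the old values.
RelabelPredictor <- function(x) dplyr::recode(x, !!!AES_PREDICTOR_LABELS)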

Bootstrap ROC Comparison

compare_tbl <- hips::compare_cCs(cCs, pilot = FALSE)
compare_tbl %<>% 
    map_df( ~.x[c("p.value", "method")], .id = "cC_pair") %>%
    tidyr::extract("cC_pair", c("cC1", "cC2"), "(\\w+)-(\\w+)")

compare_tbl %<>% 
    #mutate(p.value = p.adjust(p.value, method = "bonferroni")) %>% 
    mutate(cC1 = mapvalues(cC1 %>% str_case_camel(), 
                           from = names(AES_PREDICTOR_LABELS),
                           to = unname(AES_PREDICTOR_LABELS)),
           cC2 = mapvalues(cC2 %>% str_case_camel(), 
                           from = names(AES_PREDICTOR_LABELS),
                           to = unname(AES_PREDICTOR_LABELS))) %>% 
    arrange(desc(p.value))
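
compare_cCs() produces the pairwise comparisons whose p-values are tabulated below. For a single pair, the same DeLong test can be reproduced directly from the raw scores with pROC; this sketch reuses pY_lst and Y from the inference chunk and is not part of the original pipeline.

# Sketch: DeLong's test for one pair of classifiers, from raw scores and labels
roc1 <- pROC::roc(response = Y, predictor = pY_lst[[1]])
roc2 <- pROC::roc(response = Y, predictor = pY_lst[[2]])
pROC::roc.test(roc1, roc2, method = "delong")$p.value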

compare_tbl %>% 
    mutate(p.value = ifelse(p.value < 0.05,
                            yes = kableExtra::cell_spec(p.value, color = "green", bold = TRUE),
                            no = kableExtra::cell_spec(p.value, color = "red", bold = FALSE))) %>%
    mutate(p.value = str_replace(p.value, "([0-9]\\.[0-9]{3})[0-9]+", "\\1")) %>% 
    mutate(p.value = str_replace(p.value, "([0-9])\\.[0-9]{3}e", "\\1e")) %>% 
    select("Classifier 1" = cC1, "Classifier 2" = cC2, "DeLong's Test p-value" = p.value) %>% 
    knitr::kable(format = "html", escape = FALSE, caption = "Comparison of fracture models built from various predictor sets.") %>%
    kableExtra::kable_styling(bootstrap_options = c("condensed", "hover", "striped"), full_width = FALSE)

compare_tbl %>% 
    mutate(p.value = str_replace(p.value, "([0-9]\\.[0-9]{3})[0-9]+", "\\1")) %>% 
    mutate(p.value = str_replace(p.value, "([0-9])\\.[0-9]{3}e", "\\1e")) %>% 
    select("Classifier 1" = cC1, "Classifier 2" = cC2, "DeLong's Test p-value"=p.value) %>% 
    Tbl(bn = "SuppTable5_MultimodalComp")

compare_tbl %>% 
    mutate(p.value = ifelse(p.value < 0.05,
                         yes = kableExtra::cell_spec(p.value, color = "green", bold = TRUE),
                         kableExtra::cell_spec(p.value, color = "red", bold = FALSE))) %>% 
    mutate(p.value = str_replace(p.value, "([0-9]\\.[0-9]{3})[0-9]+", "\\1")) %>% 
    mutate(p.value = str_replace(p.value, "([0-9])\\.[0-9]{3}e", "\\1e")) %>% 
    select("Classifier 1" = cC1, "Classifier 2" = cC2, "DeLong's Test p-value"=p.value) %>% 
    knitr::kable(format = "html", escape = FALSE, caption = "Comparison of fracture models built from various predictor sets.") %>%
    kableExtra::kable_styling(bootstrap_options = c("condensed", "hover", "striped"), full_width = FALSE) %>%
    Tbl(bn = "SuppTable5_MultimodalComp", tbl_type = "html")
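
The two-step p-value truncation (keep three decimals, then collapse scientific notation to its leading digit) is repeated in each of the three table chunks above; a small helper could centralize it. This is a sketch, not part of the original pipeline.

# Sketch: format p-values for display, e.g. "0.04321" -> "0.043", "1.234e-05" -> "1e-05"
TruncatePval <- function(p_chr) {
    p_chr %>%
        str_replace("([0-9]\\.[0-9]{3})[0-9]+", "\\1") %>%
        str_replace("([0-9])\\.[0-9]{3}e", "\\1e")
}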

Plots

Summary

# Target learnability
perf_tbl <- glance(cCs)

# ROC by target
DATA <- perf_tbl
DATA %<>% arrange(desc(auc))

PLT_DAT_ROC_SUM <- DATA[, c("Classifier", "auc", "auc_lower", "auc_upper", "is_sig")]  # Collect data intermediates


gg_roc <- ggplot(DATA,
                 aes(x = fct_reorder(Classifier, auc),
                     y = auc,
                     ymin = auc_lower,
                     ymax = auc_upper,
                     col = Classifier)) +
    geom_errorbar(width = 0.25, position = Pd()) +
    geom_point(position = Pd()) +
    geom_text(aes(label = is_sig), nudge_x = 0.2) +
    scale_predictor(color)
gg_roc <- gg_roc +
    geom_hline(yintercept = 0.5, alpha = 0.5, linetype = 2) +
    geom_hline(yintercept = 1, alpha = 0.5) +
    theme(axis.text.x = element_text(angle = 90)) +
    scale_x_target() +
    labs(y = "AUROC +/- 95% bootstrap CI") +
    theme(strip.text.y = element_text(angle = 0),
          axis.text.x = element_text(vjust = 0.5)) +
    scale_y_continuous(breaks = c(0.5, 0.75, 1), labels = c(0.5, 0.75, 1)) +
    theme(legend.position = "bottom", legend.direction = "vertical")
GG_ROC <- gg_roc +
    coord_flip() +
    labs(title = "A pre-trained CNN has non-random performance\non every classification target attempted") +
    theme(plot.title = element_text(hjust = 1))
GG_ROC

# cowplot::save_plot("analysis/figures/roc_multitarget.png",
#                    plot = GG_ROC, base_width = 4, base_height = 6)
# ROC PRROC
XY_byClassifier <- purrr::map(cCs, gg_data_roc)
XY_byGeom <- purrr::transpose(XY_byClassifier)
xy_geoms <- purrr::map(XY_byGeom, purrr::lift_dl(dplyr::bind_rows, .id = "Classifier"))
xy_geoms$lines$Classifier %<>% as_factor() %>% fct_relevel(DATA$Classifier)


PLT_DAT_ROC <- xy_geoms  # Save intermediate plot data

gg_roc <- ggplot2::ggplot(xy_geoms$lines, ggplot2::aes(x=x, y=y, col = Classifier)) +
    ggplot2::geom_line() +
    ggplot2::geom_point(data = xy_geoms$points, shape=10, size = 3) +
    ggplot2::geom_rug(data = xy_geoms$points) +
    ggplot2::geom_text(data = xy_geoms$points, aes(label = is_sig), size = 5, nudge_y = 0.03, show.legend=FALSE) +
    gg_style_roc +
    scale_predictor(color)

XY_byGeom <- map(cCs, gg_data_prc) %>% transpose()
xy_geoms <- map(XY_byGeom, lift_dl(bind_rows, .id="Classifier"))
xy_geoms$lines$Classifier %<>% as_factor() %>% fct_relevel(DATA$Classifier)


PLT_DAT_PRC <- xy_geoms  # Save intermediate plot data

gg_prc <- ggplot2::ggplot(xy_geoms$lines, ggplot2::aes(x=x, y=y, col=Classifier)) +
    ggplot2::geom_line() +
    ggplot2::geom_point(data = xy_geoms$points, shape=10, size = 3) +
    ggplot2::geom_rug(data = xy_geoms$points) +
    gg_style_prc +
    scale_predictor(color)


gg_prc <- gg_prc +
    theme(legend.direction = "vertical", legend.position = "bottom") +
    guides(color = guide_legend(ncol = 2, title.hjust = 0.5))

# Extract the shared legend, then add NL (package constant; presumably suppresses
# the per-panel legend) before combining the panels.
leg <- cowplot::get_legend(gg_prc)
gg_prc <- gg_prc + NL
gg_roc <- gg_roc + NL

GG <- cowplot::plot_grid(
    cowplot::plot_grid(gg_roc, gg_prc, ncol = 2, align='hv'),
    cowplot::plot_grid(NULL, leg, NULL, ncol = 3, rel_widths = c(1, .1, 1)),
    ncol = 1, rel_heights = c(.6, 0.4))
GG

# cowplot::save_plot("analysis/figures/ROC_PRROC.png", GG, base_width = 4, base_height = 6)

Model training stats

models %>% map(train_terms)
models %>% map(train_n_eg)
FP_OUT_INT_PLT_DATA <- file.path(kDIR_MULTIMODAL, "plot_data.Rdata")
save(PLT_DAT_ROC_SUM, PLT_DAT_ROC, PLT_DAT_PRC, file = FP_OUT_INT_PLT_DATA)

Codebase State: `r dotfileR::gitState()` on `r date()`


