Objective

Evaluate the model trained on the cross-sectional population against case-control test cohorts constructed with varying degrees of matching.

library(knitr); library(pander); library(tibble); library(caret)
library(magrittr); library(dplyr); library(purrr); library(tidyr)
library(stringr); library(forcats); library(ggplot2)
library(AnalysisToolkit)
devtools::load_all()

knitr::opts_chunk$set(comment="#>",
                      fig.show='hold', fig.align="center", fig.height=8, fig.width=8,
                      message=FALSE,
                      warning=FALSE, cache=FALSE, rownames.print=FALSE)

ggplot2::theme_set(vizR::theme_())

set.seed(123)

Report Options

kRUN_STATS <- TRUE

Load Models

kDIR_BY_COHORT <- Fp_ml_dir("by_cohort")
InFp <- function(fn) {
    file.path(kDIR_BY_COHORT, fn)
}
models <- readRDS(file = InFp("trained_models.rds"))
load(InFp("test_cohorts.Rdata"))

data("caseControlCohorts", package = "hips")

Originally, I had been using models trained on variable patient populations, but after discussing with the team, I'm going to use just the model trained on the cross-sectional population.

model <- models$full
test_cohort <- test_cohorts$full

Evaluation

Test Cohorts

For the test sets, I take the cross-sectional test set and filter it by membership in each case-control cohort.

test_cohort %<>% unnest(pre)

matched_test_cohorts <- map(caseControlCohorts, ~filter(test_cohort, img %in% .x))

test_cohorts <- c(list(test_cohort), matched_test_cohorts)
names(test_cohorts) <- hipsOpt(cohorts)
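As a quick sanity check (a minimal sketch; it only assumes each cohort keeps the img and fx columns used below), tabulate the size and outcome balance of every test cohort so an empty or badly imbalanced subset is caught before inference:

# Sanity check (illustrative): cohort sizes and fracture label balance
map_int(test_cohorts, nrow)
map(test_cohorts, ~table(.x$fx))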

Inference

# Model inference ----
pY_lst <- map(test_cohorts, ~predict_pY(model, newdata=.x))
Y_lst <- map(test_cohorts, "fx")


# Assert disjoint image partitions
train_imgs <- model %>% train_eg_ids()
test_img_sets <- test_cohorts %>% map("img")
map(test_img_sets, ~ .x %in% train_imgs) %>% 
    map_lgl(any) %>% 
    `!` %>% 
    stopifnot()
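It's also worth confirming that each cohort's predictions line up with its labels before any curves are built; a minimal sketch, assuming predict_pY() returns a numeric probability vector per cohort:

# Predictions and labels should be aligned, complete vectors per cohort
stopifnot(identical(lengths(pY_lst), lengths(Y_lst)))
stopifnot(!map_lgl(pY_lst, anyNA))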

Build classifier curves

# Adl ----
adl_cohorts <- adlCohort()
adl_cCs <- map(adl_cohorts, ~ClassifierCurve(pY = .x$pY, Y = .x$fx))

# Msh ----
msh_cCs <- map2(pY_lst, Y_lst,
     ~ClassifierCurve(pY = ..1, Y = ..2))
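As an independent cross-check on the ClassifierCurve AUCs (a sketch, assuming a reasonably recent pROC is installed and that fx is a binary label pROC can interpret):

# Illustrative only: recompute each cohort's AUROC directly with pROC
map2_dbl(pY_lst, Y_lst,
         ~as.numeric(pROC::auc(pROC::roc(response = .y, predictor = .x, quiet = TRUE))))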

Tables

msh_cCs %>%
    glance_pretty() %>% 
    mutate(Classifier = mapvalues(Classifier,
                                  from = names(AES_COHORT_LABELS),
                                  to = unname(AES_COHORT_LABELS))) %>% 
    mapnames("Classifier", "Test Cohort") %>% 
    arrange(desc(auc)) %>% 
    kable(caption = "MSH Fracture prediction evaluated with various cohorts.")

adl_pretty_perf <- adl_cCs %>%
    glance_pretty() %>% 
    mutate(Classifier = mapvalues(Classifier,
                                  from = names(AES_COHORT_LABELS),
                                  to = unname(AES_COHORT_LABELS))) %>% 
    mapnames("Classifier", "Test Cohort") %>% 
    arrange(desc(auc)) 

adl_pretty_perf %>% 
    kable(caption = "Adl Fracture prediction evaluated with various cohorts.")

Tbl(adl_pretty_perf, bn = "SuppTable10_adlByCohortPerf")

Bootstrap Comparison

# Compare method uses space separator
compare_cCs(msh_cCs, pilot = FALSE) %>% 
    arrange(desc(p.value)) %>% 
    mutate(p.value = ifelse(p.value < 0.05,
                      yes = kableExtra::cell_spec(p.value, color = "green", bold = TRUE),
                      kableExtra::cell_spec(p.value, color = "red", bold = FALSE)),
           cC1 = mapvalues(cC1,
                           from = names(AES_COHORT_LABELS) %>% str_case_snake(),
                           to = unname(AES_COHORT_LABELS)),
           cC2 = mapvalues(cC2,
                           from = names(AES_COHORT_LABELS) %>% str_case_snake(),
                           to = unname(AES_COHORT_LABELS))) %>% 
    select("Classifier 1" = cC1, "Classifier 2" = cC2, "Bootstrap Test p-value"=p.value) %>% 
    knitr::kable(format = "markdown", caption = "Comparing MSH Fracture Models predicted by Various Predictor Sets.") %>%
    kableExtra::kable_styling(bootstrap_options = c("condensed", "hover", "striped"), full_width = FALSE)
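compare_cCs() performs the pairwise bootstrap tests reported above. Purely as an illustrative cross-check (assuming pROC is available; the cohort indices 1 and 2 are arbitrary picks, and note the matched cohorts are subsets of the full test set, so the samples are not independent), one pairwise comparison could be approximated with a bootstrap ROC test:

# Illustrative cross-check only (not the compare_cCs() implementation):
# unpaired bootstrap test of the AUC difference between two test cohorts
roc_1 <- pROC::roc(response = Y_lst[[1]], predictor = pY_lst[[1]], quiet = TRUE)
roc_2 <- pROC::roc(response = Y_lst[[2]], predictor = pY_lst[[2]], quiet = TRUE)
pROC::roc.test(roc_1, roc_2, method = "bootstrap", paired = FALSE, boot.n = 2000)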

comp_df <- compare_cCs(adl_cCs, pilot = FALSE) %>% 
    arrange(desc(p.value)) %>% 
    mutate(cC1 = mapvalues(cC1,
                           from = names(AES_COHORT_LABELS) %>% str_case_snake(),
                           to = unname(AES_COHORT_LABELS)),
           cC2 = mapvalues(cC2,
                           from = names(AES_COHORT_LABELS) %>% str_case_snake(),
                           to = unname(AES_COHORT_LABELS)))

comp_df %>% 
    mutate(p.value = ifelse(p.value < 0.05,
                      yes = kableExtra::cell_spec(p.value, color = "green", bold = TRUE),
                      kableExtra::cell_spec(p.value, color = "red", bold = FALSE))) %>% 
    mutate(p.value = str_replace(p.value, "([0-9]\\.[0-9]{3})[0-9]+", "\\1")) %>% 
    mutate(p.value = str_replace(p.value, "([0-9])\\.[0-9]{3}e", "\\1e")) %>% 
    select("Classifier 1" = cC1, "Classifier 2" = cC2, "Bootstrap Test p-value"=p.value) %>% 
    knitr::kable(format = "markdown", caption = "Comparing Adelaide Fracture Models predicted by Various Predictor Sets.") %>%
    kableExtra::kable_styling(bootstrap_options = c("condensed", "hover", "striped"), full_width = FALSE)

comp_df %>%
    mutate(p.value = str_replace(p.value, "([0-9]\\.[0-9]{3})[0-9]+", "\\1")) %>% 
    mutate(p.value = str_replace(p.value, "([0-9])\\.[0-9]{3}e", "\\1e")) %>% 
    Tbl(bn = "SuppTable11_adlByCohortComp")

comp_df %>% 
    mutate(p.value = ifelse(p.value < 0.05,
                      yes = kableExtra::cell_spec(p.value, color = "green", bold = TRUE),
                      kableExtra::cell_spec(p.value, color = "red", bold = FALSE))) %>% 
    mutate(p.value = str_replace(p.value, "([0-9]\\.[0-9]{3})[0-9]+", "\\1")) %>% 
    mutate(p.value = str_replace(p.value, "([0-9])\\.[0-9]{3}e", "\\1e")) %>% 
    select("Classifier 1" = cC1, "Classifier 2" = cC2, "Bootstrap Test p-value"=p.value) %>% 
    knitr::kable(format = "html", escape = FALSE, caption = "Comparing Adelaide Fracture Models predicted by Various Predictor Sets.") %>%
    kableExtra::kable_styling(bootstrap_options = c("condensed", "hover", "striped"), full_width = FALSE) %>% 
    Tbl(bn = "SuppTable11_adlByCohortComp", tbl_type = "html")

Overall Performance Views

# MSH ----
# Target learnability
# ROC by target
perf_tbl <- glance(msh_cCs) %>%
    mapnames("Classifier", "cohort")
DATA <- perf_tbl

DATA %<>% arrange(desc(auc))
DATA %<>% mapnames(from = "model_id", to = "cohort")

# Standardize order
DATA <- DATA[c(3,1,2,4), ]


gg_roc <- ggplot(DATA,
                 aes(x = cohort,
                     y = auc,
                     ymin = auc_lower,
                     ymax = auc_upper)) +
    geom_errorbar(width = 0.25, position = Pd()) +
    geom_point(position = Pd()) +
    scale_x_cohort() +
    geom_text(aes(label = is_sig), nudge_x = 0.2)
gg_roc <- gg_roc +
    geom_hline(yintercept = 0.5, alpha = 0.5, linetype = 2) +
    geom_hline(yintercept = 1, alpha = 0.5) +
    theme(axis.text.x = element_text(angle = 90)) +
    aes_cohort_col() +
    labs(y = "AUROC +/- 95% bootstrap CI") +
    theme(strip.text.y = element_text(angle = 0),
          axis.text.x = element_text(vjust = 0.5)) +
    scale_y_continuous(breaks = c(0.5, 0.75, 1), labels = c(0.5, 0.75, 1)) +
    theme(legend.position = "bottom", legend.direction = "vertical")
GG_ROC <- gg_roc %+%
    coord_flip() +
    #labs(title = "MSH: Different methods of subsampling a test-set lead to markedly different performance scores") +
    theme(plot.title = element_text(hjust = 1))
MSH_ROC <- GG_ROC


# Adl ----
# ROC by target
adl_perf_tbl <- glance(adl_cCs) %>%
    mapnames("Classifier", "cohort")

DATA <- adl_perf_tbl
DATA %<>% arrange(desc(auc))

DATA <- DATA[c(1, 4, 2, 3), ]
DATA$cohort %<>% as.factor %>% fct_reorder(DATA$auc, .desc = FALSE)

PLT_DAT_ADL_ROC_SUM <- DATA

gg_roc <- ggplot(DATA,
                 aes(x = cohort,
                     y = auc,
                     ymin = auc_lower,
                     ymax = auc_upper)) +
    geom_errorbar(width = 0.25, position = Pd()) +
    geom_point(position = Pd()) +
    scale_x_cohort() +
    geom_text(aes(label = is_sig), nudge_x = 0.2)
gg_roc <- gg_roc +
    geom_hline(yintercept = 0.5, alpha = 0.5, linetype = 2) +
    geom_hline(yintercept = 1, alpha = 0.5) +
    theme(axis.text.x = element_text(angle = 90)) +
    aes_cohort_col() +
    labs(y = "AUROC +/- 95% bootstrap CI") +
    theme(strip.text.y = element_text(angle = 0),
          axis.text.x = element_text(vjust = 0.5)) +
    scale_y_continuous(breaks = c(0.5, 0.75, 1), labels = c(0.5, 0.75, 1)) +
    theme(legend.position = "bottom", legend.direction = "vertical")
GG_ROC <- gg_roc %+%
    coord_flip() +
    #labs(title = "Adelaide: Different methods of subsampling a test-set lead to markedly different performance scores") +
    theme(plot.title = element_text(hjust = 1))
ADL_ROC <- GG_ROC


# Combine
MSH_ROC %<>% `+`(NL)
ADL_ROC %<>% `+`(NL)
ADL_ROC %<>% `+`(theme(axis.text.y = element_blank(), axis.title.y = element_blank(), axis.title.x = element_blank(), axis.ticks.y = element_blank()))

cowplot::plot_grid(MSH_ROC, ADL_ROC, rel_widths = c(1, 0.5), align = 'h')
# MSH ----
# ROC PRROC
roc_XY_byClassifier <- purrr::map(msh_cCs, gg_data_roc)
roc_XY_byGeom <- purrr::transpose(roc_XY_byClassifier)
roc_xy_geoms <- purrr::map(roc_XY_byGeom, purrr::lift_dl(dplyr::bind_rows, .id = "cohort"))

prc_XY_byGeom <- map(msh_cCs, gg_data_prc) %>% transpose()
prc_xy_geoms <- map(prc_XY_byGeom, lift_dl(bind_rows, .id="cohort"))

gg_roc <- ggplot2::ggplot(roc_xy_geoms$lines, ggplot2::aes(x=x, y=y, col = cohort)) +
    ggplot2::geom_line() +
    ggplot2::geom_point(data = roc_xy_geoms$points, shape=10, size = 3) +
    ggplot2::geom_rug(data = roc_xy_geoms$points) +
    ggplot2::geom_text(data = roc_xy_geoms$points, aes(label = is_sig), size = 5, nudge_y = 0.03, show.legend=FALSE) +
    gg_style_roc +
    labs(title = "MSH Whole image \nReceiver Operator Curve") +
    scale_color_cohort()

gg_prc <- ggplot2::ggplot(prc_xy_geoms$lines, ggplot2::aes(x=x, y=y, col=cohort)) +
    ggplot2::geom_line() +
    ggplot2::geom_point(data = prc_xy_geoms$points, shape=10, size = 3) +
    ggplot2::geom_rug(data = prc_xy_geoms$points) +
    gg_style_prc +
    labs(title = "MSH Whole image \nPrecision-Recall Curve") +
    scale_color_cohort()


gg_prc %<>% `+`(list(theme(legend.direction = "vertical", legend.position = "bottom")))

gg_prc %<>% `+`(guides(color = guide_legend(title = "Test Cohort", ncol = 2, title.hjust = 0.5)))
leg <- cowplot::get_legend(gg_prc)
gg_prc %<>% `+`(NL)
gg_roc %<>% `+`(NL)

MSH_ROC_SUMMARY <- MSH_ROC
ADL_ROC_SUMMARY <- ADL_ROC
MSH_ROC <- gg_roc
MSH_PRC <- gg_prc
# GG <- cowplot::plot_grid(
#     cowplot::plot_grid(gg_roc, gg_prc, ncol = 2, align='hv'),
#     cowplot::plot_grid(NULL, leg, NULL, ncol = 3, rel_widths = c(1, .1, 1)),
#     ncol = 1, rel_heights = c(.6, 0.4))
# GG

# ADL ----
# ROC PRROC
roc_XY_byClassifier <- purrr::map(adl_cCs, gg_data_roc)
roc_XY_byGeom <- purrr::transpose(roc_XY_byClassifier)
roc_xy_geoms <- purrr::map(roc_XY_byGeom, purrr::lift_dl(dplyr::bind_rows, .id = "cohort"))

PLT_DAT_ADL_ROC <- roc_xy_geoms

prc_XY_byGeom <- map(adl_cCs, gg_data_prc) %>% transpose()
prc_xy_geoms <- map(prc_XY_byGeom, lift_dl(bind_rows, .id="cohort"))

PLT_DAT_ADL_PRC <- prc_xy_geoms

gg_roc <- ggplot2::ggplot(roc_xy_geoms$lines, ggplot2::aes(x=x, y=y, col = cohort)) +
    ggplot2::geom_line() +
    ggplot2::geom_point(data = roc_xy_geoms$points, shape=10, size = 3) +
    ggplot2::geom_rug(data = roc_xy_geoms$points) +
    ggplot2::geom_text(data = roc_xy_geoms$points, aes(label = is_sig), size = 5, nudge_y = 0.03, show.legend=FALSE) +
    gg_style_roc +
    labs(title = "ADL Zoomed Hip \nReceiver Operator Curve") +
    scale_color_cohort()

gg_prc <- ggplot2::ggplot(prc_xy_geoms$lines, ggplot2::aes(x=x, y=y, col=cohort)) +
    ggplot2::geom_line() +
    ggplot2::geom_point(data = prc_xy_geoms$points, shape=10, size = 3) +
    ggplot2::geom_rug(data = prc_xy_geoms$points) +
    gg_style_prc +
    labs(title = "ADL Zoomed Hip \nPrecision-Recall Curve") +
    scale_color_cohort()


gg_prc %<>% `+`(list(theme(legend.direction = "vertical", legend.position = "bottom")))

gg_prc %<>% `+`(guides(color = guide_legend(title = "Test Cohort", ncol = 2, title.hjust = 0.5)))
leg <- cowplot::get_legend(gg_prc)
gg_prc %<>% `+`(NL)
gg_roc %<>% `+`(NL)

ADL_PRC <- gg_prc
ADL_ROC <- gg_roc
# GG <- cowplot::plot_grid(
#     cowplot::plot_grid(gg_roc, gg_prc, ncol = 2, align='hv'),
#     cowplot::plot_grid(NULL, leg, NULL, ncol = 3, rel_widths = c(1, .1, 1)),
#     ncol = 1, rel_heights = c(.6, 0.4))
# GG
# cowplot::save_plot("analysis/figures/ROC_PRROC.png", GG, base_width = 4, base_height = 6)

cowplot::plot_grid(
    cowplot::plot_grid(MSH_ROC, ADL_ROC, ncol = 2, align='hv'),
    cowplot::plot_grid(MSH_PRC, ADL_PRC, ncol = 2, align='hv'),
    ncol = 1
)

Odds Ratios

ADL_OR_DAT <- adlCohort(binarized=TRUE) %>%
    OddsRatios(grp_chr = "cohort")
OR_ADL <- ADL_OR_DAT %>% 
    ggOddsRatios() %+% 
    aes_cohort_col()

cohort_imgs <- map(test_cohorts, "img")
mshBinary <- hipsCohort(mutating = binary)
mshBinCohorts <- map(cohort_imgs, ~filter(mshBinary, img %in% .x))    

OR_MSH <- mshBinCohorts %>%
    OddsRatios(grp_chr = "cohort") %>%
    ggOddsRatios() %+% 
    aes_cohort_col()
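OddsRatios() summarises the binarized predictors against fracture status within each cohort. For reference, a single odds ratio and its exact confidence interval can be reproduced by hand (a sketch with hypothetical counts, not values from this data):

# Illustrative only: odds ratio for one binary predictor vs. fracture status
# (hypothetical 2x2 counts; OddsRatios() computes these per predictor and cohort)
tab <- matrix(c(30, 20, 10, 40), nrow = 2,
              dimnames = list(predictor = c("pos", "neg"), fx = c("yes", "no")))
(30 * 40) / (20 * 10)      # sample odds ratio = 6
fisher.test(tab)           # conditional MLE odds ratio with exact 95% CI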

# Build multiview
LEG <- cowplot::get_legend(OR_ADL)
OR_ADL %<>% `+`(NL)
OR_MSH %<>% `+`(NL)

cowplot::plot_grid(
    cowplot::plot_grid(OR_MSH, OR_ADL, ncol = 2),
    cowplot::plot_grid(NULL, LEG, NULL, ncol = 3, rel_widths = c(1, .1, 1)),
    ncol = 1, rel_heights = c(1, .1)
)
FP_OUT_INT_PLT_DATA <- file.path(kDIR_BY_COHORT, "adl_plot_data.Rdata")
save(ADL_ROC_SUMMARY, PLT_DAT_ADL_ROC_SUM, PLT_DAT_ADL_ROC, PLT_DAT_ADL_PRC, ADL_OR_DAT, file = FP_OUT_INT_PLT_DATA)

Codebase State: r dotfileR::gitState() on r date()


