Evaluate a model trained on the cross-sectional population on case-control cohorts with variable matching.
```r
library(knitr)
library(pander)
library(tibble)
library(caret)
library(AnalysisToolkit)
devtools::load_all()

knitr::opts_chunk$set(comment = "#>", fig.show = "hold", fig.align = "center",
                      fig.height = 8, fig.width = 8, message = FALSE,
                      warning = FALSE, cache = FALSE, rownames.print = FALSE)
ggplot2::theme_set(vizR::theme_())
set.seed(123)
```
```r
kRUN_STATS <- TRUE

kDIR_BY_COHORT <- Fp_ml_dir("by_cohort")
InFp <- function(fn) { file.path(kDIR_BY_COHORT, fn) }

models <- readRDS(file = InFp("trained_models.rds"))
load(InFp("test_cohorts.Rdata"))
data("caseControlCohorts", package = "hips")
```
Originally, I had been using models trained on variable patient populations, but after discussing with the team, I'm going to use just the model trained on the cross-sectional population.
```r
model <- models$full
test_cohort <- test_cohorts$full
```
For the test sets, I'll have to take the cross-sectional test set and then filter by membership in different cohorts.
```r
matched_test_cohorts <- map(caseControlCohorts, ~ filter(test_cohort, img %in% .x))
test_cohorts <- c(list(test_cohort), matched_test_cohorts)
names(test_cohorts) <- c("crossSectional", "caseControl_matchNone",
                         "caseControl_matchDem", "caseControl_matchPt",
                         "caseControl_matchAll")
save(test_cohorts, file = file.path(kDIR_BY_COHORT, "by_test_cohorts.Rdata"))
```
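As a quick sanity check (an illustrative addition, not part of the original pipeline): each matched cohort is a subset of the cross-sectional test set, so no matched cohort can be larger than the full test set.

```r
# Sanity check (illustrative): matched cohorts are subsets of the full test set
sizes <- purrr::map_int(test_cohorts, nrow)
sizes
stopifnot(all(sizes <= sizes[["crossSectional"]]))
```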
```r
# Model inference ----
pY_lst <- map(test_cohorts, ~ predict_pY(model, newdata = .x))
Y_lst  <- map(test_cohorts, "fx")

# Assert that the train and test image partitions are disjoint
train_imgs    <- model %>% train_eg_ids()
test_img_sets <- test_cohorts %>% map("img")
map(test_img_sets, ~ .x %in% train_imgs) %>%
  map_lgl(any) %>%
  `!` %>%
  stopifnot()

cCs <- map2(pY_lst, Y_lst, ~ ClassifierCurve(pY = ..1, Y = ..2))
```
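As an optional cross-check (a sketch, assuming each element of `pY_lst` is a numeric score vector and each element of `Y_lst` a binary outcome vector), the AUROCs summarized by the `ClassifierCurve` objects can be recomputed directly with pROC:

```r
# Cross-check (sketch): recompute AUROC per cohort directly from the scores
aucs <- purrr::map2_dbl(
  pY_lst, Y_lst,
  ~ as.numeric(pROC::auc(pROC::roc(response = .y, predictor = .x, quiet = TRUE)))
)
sort(aucs, decreasing = TRUE)
```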
```r
pretty_perf <- cCs %>%
  glance_pretty() %>%
  mutate(Classifier = mapvalues(Classifier,
                                from = names(AES_COHORT_LABELS),
                                to = unname(AES_COHORT_LABELS))) %>%
  mapnames("Classifier", "Test Cohort") %>%
  arrange(desc(auc))

pretty_perf %>%
  kable(caption = "Fracture prediction with pretrained image models applied to various cohorts.")

Tbl(pretty_perf, bn = "SuppTable6_byCohortPerf")
```
```r
# Compare method uses spc separator
comp_tbl <- hips::compare_cCs(cCs, pilot = FALSE)

# 95% CIs for each AUROC (DeLong, then bootstrap)
cCs %>% map(cC2roc) %>% map(pROC::ci.auc)
cCs %>% map(cC2roc) %>% map(pROC::ci.auc, method = "bootstrap")

comp_tbl %<>%
  map_df(~ .x[c("p.value", "method", "statistic")], .id = "cC_pair") %>%
  tidyr::extract("cC_pair", c("cC1", "cC2"), "(\\w+)-(\\w+)")

# Preview with pretty cohort labels
comp_tbl %>%
  arrange(desc(p.value)) %>%
  mutate(cC1 = mapvalues(cC1,
                         from = names(AES_COHORT_LABELS) %>% str_case_snake(),
                         to = unname(AES_COHORT_LABELS)),
         cC2 = mapvalues(cC2,
                         from = names(AES_COHORT_LABELS) %>% str_case_snake(),
                         to = unname(AES_COHORT_LABELS)))

# Markdown table with significant p-values highlighted
comp_tbl %>%
  mutate(p.value = ifelse(p.value < 0.05,
                          yes = kableExtra::cell_spec(p.value, color = "green", bold = TRUE),
                          no  = kableExtra::cell_spec(p.value, color = "red", bold = FALSE))) %>%
  mutate(p.value = str_replace(p.value, "([0-9]\\.[0-9]{3})[0-9]+", "\\1")) %>%
  mutate(p.value = str_replace(p.value, "([0-9])\\.[0-9]{3}e", "\\1e")) %>%
  select("Classifier 1" = cC1, "Classifier 2" = cC2, "Bootstrap Test p-value" = p.value) %>%
  knitr::kable(format = "markdown",
               caption = "Comparing fracture models predicted by various predictor sets.") %>%
  kableExtra::kable_styling(bootstrap_options = c("condensed", "hover", "striped"),
                            full_width = FALSE)

# Plain version for the supplement
comp_tbl %>%
  mutate(p.value = str_replace(p.value, "([0-9]\\.[0-9]{3})[0-9]+", "\\1")) %>%
  mutate(p.value = str_replace(p.value, "([0-9])\\.[0-9]{3}e", "\\1e")) %>%
  select("Classifier 1" = cC1, "Classifier 2" = cC2, "Bootstrap Test p-value" = p.value) %>%
  Tbl(bn = "SuppTable7_byCohortComp")

# HTML version with highlighting
comp_tbl %>%
  mutate(p.value = ifelse(p.value < 0.05,
                          yes = kableExtra::cell_spec(p.value, color = "green", bold = TRUE),
                          no  = kableExtra::cell_spec(p.value, color = "red", bold = FALSE))) %>%
  mutate(p.value = str_replace(p.value, "([0-9]\\.[0-9]{3})[0-9]+", "\\1")) %>%
  mutate(p.value = str_replace(p.value, "([0-9])\\.[0-9]{3}e", "\\1e")) %>%
  select("Classifier 1" = cC1, "Classifier 2" = cC2, "Bootstrap Test p-value" = p.value) %>%
  knitr::kable(format = "html", escape = FALSE,
               caption = "Comparing fracture models predicted by various predictor sets.") %>%
  kableExtra::kable_styling(bootstrap_options = c("condensed", "hover", "striped"),
                            full_width = FALSE) %>%
  Tbl(bn = "SuppTable7_byCohortComp", tbl_type = "html")
```
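For reference, a single pairwise comparison can also be run directly with `pROC::roc.test` (a sketch, assuming `cC2roc()` returns pROC `roc` objects; `compare_cCs()` above is the canonical version):

```r
# Sketch: bootstrap test of the AUC difference between the cross-sectional
# and fully matched test cohorts, using pROC directly
roc_lst <- purrr::map(cCs, cC2roc)
pROC::roc.test(roc_lst$crossSectional, roc_lst$caseControl_matchAll,
               method = "bootstrap", boot.n = 2000)
```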
```r
DATA <- glance(cCs) %>% mapnames("Classifier", "cohort")
DATA %<>% arrange(desc(auc))

# glue() concatenates its unnamed arguments before interpolating
glue::glue(
  "We then evaluated the image-only classifier for fracture on each test set ",
  "(Figure 3B-D, Supplementary Tables 6&7). The area under the Precision ",
  "Recall Curve (PRC) depends on the disease prevalence, and since the ",
  "original population had a 3% fracture prevalence but case-control cohorts ",
  "have a 50% prevalence, the PRC is significantly higher for case-control ",
  "cohorts. Random subsampling had no effect on the primary evaluation ",
  "metric, AUC ",
  "({DATA %>% filter(cohort=='crossSectional') %$% auc %>% round(digits=2)} vs ",
  "{DATA %>% filter(cohort=='caseControl_matchNone') %$% auc %>% round(digits=2)}, ",
  "p={comp_tbl[1, 'p.value'] %>% round(digits=2)}). Balancing patient ",
  "demographics (age and gender) also made no difference to model performance ",
  "(p={comp_tbl[3, 'p.value'] %>% round(digits=2)}). Model performance was ",
  "significantly lower when evaluating on a test cohort matched by patient ",
  "demographics, BMI, and symptoms ",
  "(AUC={DATA %>% filter(cohort=='caseControl_matchPt') %$% auc %>% round(digits=2)}, ",
  "p={comp_tbl[6, 'p.value'] %>% round(digits=2)}). When evaluated on a test ",
  "cohort matched by all covariates, the fracture detector was no longer ",
  "better than random ",
  "(AUC={DATA %>% filter(cohort=='caseControl_matchAll') %$% auc %>% round(digits=2)}, ",
  "95% CI {DATA %>% filter(cohort=='caseControl_matchAll') %>% with(., glue('{auc_lower %>% round(digits=2)} - {auc_upper %>% round(digits=2)}'))}) ",
  "and significantly worse than when assessed on all other test cohorts."
)
```
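The prevalence dependence of the PRC baseline can be illustrated with a short simulation (a self-contained sketch in base R, not part of the analysis): with uninformative scores, average precision converges to the prevalence, so the 50% case-control cohorts start from a much higher PRC floor than the 3% cross-sectional cohort.

```r
# Average precision of a chance-level classifier is roughly the prevalence
sim_baseline_ap <- function(prev, n = 1e4) {
  y <- rbinom(n, 1, prev)            # outcomes at the given prevalence
  s <- runif(n)                      # uninformative scores
  o <- order(s, decreasing = TRUE)
  prec <- cumsum(y[o]) / seq_len(n)  # precision at each score threshold
  mean(prec[y[o] == 1])              # average precision over the positives
}
sapply(c(crossSectional = 0.03, caseControl = 0.50), sim_baseline_ap)
#> roughly 0.03 and 0.50
```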
```r
# Target learnability ----
# AUROC summary by test cohort
DATA$cohort %<>% fct_reorder(DATA$auc)
PLT_DAT_ROC_SUM <- DATA  # Save intermediate plot data

gg_roc <- ggplot(DATA, aes(x = fct_reorder(cohort, auc), y = auc,
                           ymin = auc_lower, ymax = auc_upper)) +
  geom_errorbar(width = 0.25, position = Pd()) +
  geom_point(position = Pd()) +
  geom_text(aes(label = is_sig), nudge_x = 0.2) +
  aes_cohort_col() +
  scale_x_cohort()

gg_roc <- gg_roc +
  geom_hline(yintercept = 0.5, alpha = 0.5, linetype = 2) +
  geom_hline(yintercept = 1, alpha = 0.5) +
  theme(axis.text.x = element_text(angle = 90)) +
  labs(y = "AUROC +/- 95% bootstrap CI") +
  theme(strip.text.y = element_text(angle = 0),
        axis.text.x = element_text(vjust = 0.5)) +
  scale_y_continuous(breaks = c(0.5, 0.75, 1), labels = c(0.5, 0.75, 1)) +
  theme(legend.position = "bottom", legend.direction = "vertical")

GG_ROC <- gg_roc %+% coord_flip() +
  labs(title = "Different methods of subsampling a test set lead to markedly different performance scores") +
  theme(plot.title = element_text(hjust = 1))
GG_ROC
```
```r
# ROC
XY_byClassifier <- purrr::map(cCs, gg_data_roc)
XY_byGeom <- purrr::transpose(XY_byClassifier)
xy_geoms <- purrr::map(XY_byGeom, purrr::lift_dl(dplyr::bind_rows, .id = "cohort"))
PLT_DAT_ROC <- xy_geoms  # Save intermediate plot data

gg_roc <- ggplot2::ggplot(xy_geoms$lines, ggplot2::aes(x = x, y = y)) +
  ggplot2::geom_line() +
  ggplot2::geom_point(data = xy_geoms$points, shape = 10, size = 3) +
  ggplot2::geom_rug(data = xy_geoms$points) +
  ggplot2::geom_text(data = xy_geoms$points, aes(label = is_sig),
                     size = 5, nudge_y = 0.03, show.legend = FALSE) +
  gg_style_roc +
  aes_cohort_col()

# PRROC
XY_byGeom <- map(cCs, gg_data_prc) %>% transpose()
xy_geoms <- map(XY_byGeom, lift_dl(bind_rows, .id = "cohort"))
PLT_DAT_PRC <- xy_geoms  # Save intermediate plot data

gg_prc <- ggplot2::ggplot(xy_geoms$lines, ggplot2::aes(x = x, y = y)) +
  ggplot2::geom_line() +
  ggplot2::geom_point(data = xy_geoms$points, shape = 10, size = 3) +
  ggplot2::geom_rug(data = xy_geoms$points) +
  gg_style_prc +
  aes_cohort_col()

# Share one legend between the two panels
gg_prc %<>% `+`(list(theme(legend.direction = "vertical", legend.position = "bottom")))
gg_prc %<>% `+`(guides(color = guide_legend(title = "Cohort", ncol = 2, title.hjust = 0.5)))
leg <- cowplot::get_legend(gg_prc)
gg_prc %<>% `+`(NL)
gg_roc %<>% `+`(NL)

GG <- cowplot::plot_grid(
  cowplot::plot_grid(gg_roc, gg_prc, ncol = 2, align = "hv"),
  cowplot::plot_grid(NULL, leg, NULL, ncol = 3, rel_widths = c(1, .1, 1)),
  ncol = 1, rel_heights = c(.6, 0.4))
GG
# cowplot::save_plot("analysis/figures/ROC_PRROC.png", GG, base_width = 4, base_height = 6)
```
```r
cohort_imgs <- map(test_cohorts, "img")
mshBinary <- hipsCohort(mutating = binary)
mshBinCohorts <- map(cohort_imgs, ~ filter(mshBinary, img %in% .x))

OR_DATA <- mshBinCohorts %>% OddsRatios(grp_chr = "cohort")
OR_DATA %>% ggOddsRatios() %+% aes_cohort_col()

sig_ors <- OR_DATA %>%
  mutate(is_sig = map_lgl(p.value, ~ .x < 0.05)) %>%
  group_by(cohort) %>%
  filter(is_sig)

sig_ors %>%
  mutate(target = map_chr(target, PrettyTargets)) %>%
  summarise(`Sig. Assoc` = str_x(target)) %>%
  kable(caption = "Significant fracture-covariate associations in each cohort")

sig_ors %>%
  summarise(n_sig = n()) %>%
  kable(caption = "Number of significant fracture-covariate associations in each cohort")
```
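For interpretation, each row of `OR_DATA` corresponds to a single fracture-covariate association; a self-contained sketch on simulated data shows the underlying computation (`fx` and `cov` are hypothetical column names here, and `OddsRatios()` is the canonical implementation):

```r
set.seed(1)
d <- data.frame(fx  = rbinom(500, 1, 0.3),   # simulated binary fracture outcome
                cov = rbinom(500, 1, 0.5))   # simulated binary covariate
fit <- glm(fx ~ cov, data = d, family = binomial)
exp(cbind(OR = coef(fit), confint.default(fit)))  # odds ratio with Wald 95% CI
```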
```r
models %>% map(train_terms)
models %>% map(train_n_eg)
```
```r
FP_OUT_INT_PLT_DATA <- file.path(kDIR_BY_COHORT, "plot_data.Rdata")
save(PLT_DAT_ROC_SUM, PLT_DAT_ROC, PLT_DAT_PRC, file = FP_OUT_INT_PLT_DATA)
```
Codebase State: `r dotfileR::gitState()` on `r date()`