knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)

library(tidyverse)

# Load description embeddings from a CSV instead of computing them in this
# chunk (presumably precomputed with clinspacy). The default path is the
# original author's location; set the MTSAMPLES_DESCRIPTION_CSV environment
# variable to point at a local copy when knitting on another machine.
mtsamples_embeddings <- readr::read_csv(
  Sys.getenv(
    'MTSAMPLES_DESCRIPTION_CSV',
    unset = 'Z:/kdpsingh/mtsamples_description.csv'
  )
) %>%
  # Binary outcome with 'Yes' as the FIRST factor level: yardstick metrics (and
  # tune's control_resamples(event_level = "first")) treat the first level as
  # the event, so the alphabetical order c('No', 'Yes') would make pr_auc score
  # the non-cardiology class rather than cardiology notes.
  mutate(
    is_cardiology_note = factor(
      if_else(medical_specialty == 'Cardiovascular / Pulmonary', 'Yes', 'No'),
      levels = c('Yes', 'No')
    )
  ) %>%
  # Keep only the embedding predictors plus the outcome; drop incomplete rows.
  select(contains('emb_'), is_cardiology_note) %>%
  na.omit()
library(tidyverse) library(tidymodels) library(clinspacy) library(runway)
# Load the mtsamples example data; the embedding pipeline below reads its
# `description` and `medical_specialty` columns.
mtsamples <- dataset_mtsamples()
Here, we aim to predict which descriptions belong to 'Cardiovascular / Pulmonary' notes, so we convert the outcome (`medical_specialty`) into a binary outcome (`is_cardiology_note`). We then remove all of the predictor variables other than the scispaCy embeddings (the columns whose names contain `emb_`).
# Embed each note's description with scispaCy (via clinspacy), attach the
# embedding columns to the original data, and build the binary outcome.
mtsamples_embeddings <- mtsamples %>%
  clinspacy(
    df_col = 'description',
    return_scispacy_embeddings = TRUE,
    verbose = FALSE
  ) %>%
  bind_clinspacy_embeddings(mtsamples) %>%
  # 'Yes' is deliberately the FIRST factor level: yardstick metrics (and tune's
  # control_resamples(event_level = "first")) treat the first level as the
  # event, so alphabetical levels c('No', 'Yes') would make pr_auc describe
  # the non-cardiology class instead of cardiology notes.
  mutate(
    is_cardiology_note = factor(
      if_else(medical_specialty == 'Cardiovascular / Pulmonary', 'Yes', 'No'),
      levels = c('Yes', 'No')
    )
  ) %>%
  # Keep only the embedding predictors plus the outcome; drop incomplete rows.
  select(contains('emb_'), is_cardiology_note) %>%
  na.omit()
# Logistic regression (glm engine) on every embedding column, assessed on a
# single validation split: 2/3 of rows for fitting, 1/3 held out.
# The seed fixes that split (the same seed is reused for the random forest).
set.seed(1)

logreg_workflow <- workflow() %>%
  add_recipe(recipe(is_cardiology_note ~ ., data = mtsamples_embeddings)) %>%
  add_model(set_engine(logistic_reg(), 'glm'))

# Keep the held-out predictions so they can be plotted later.
logreg_result <- fit_resamples(
  logreg_workflow,
  resamples = validation_split(data = mtsamples_embeddings, prop = 2/3),
  metrics = metric_set(roc_auc, pr_auc),
  control = control_resamples(save_pred = TRUE)
)

collect_metrics(logreg_result)
# Random forest (ranger engine, 1000 trees) on the same predictors. Reusing
# seed 1 reproduces the logistic regression's 2/3 vs 1/3 validation split, so
# both models are scored on the same held-out rows.
set.seed(1)

rf_workflow <- workflow() %>%
  add_recipe(recipe(is_cardiology_note ~ ., data = mtsamples_embeddings)) %>%
  add_model(
    set_engine(rand_forest(mode = 'classification', trees = 1000), 'ranger')
  )

# Keep the held-out predictions so they can be plotted later.
rf_result <- fit_resamples(
  rf_workflow,
  resamples = validation_split(data = mtsamples_embeddings, prop = 2/3),
  metrics = metric_set(roc_auc, pr_auc),
  control = control_resamples(save_pred = TRUE)
)

collect_metrics(rf_result)
# Stack the held-out predictions of both models into one long table, tagging
# each row with its model's display name. `.id` puts the tag column first, so it
# is moved back to the end to keep the usual prediction columns up front.
combined_predictions <- bind_rows(
  `Logistic regression` = collect_predictions(logreg_result),
  `Random forest` = collect_predictions(rf_result),
  .id = 'model_name'
) %>%
  relocate(model_name, .after = last_col())
# Threshold-performance plot comparing both models. The outcome label is
# recoded to numeric 1 ('Yes') / 0 (anything else) before plotting.
threshperf_plot_multi(
  mutate(
    combined_predictions,
    is_cardiology_note = as.numeric(is_cardiology_note == 'Yes')
  ),
  outcome = 'is_cardiology_note',
  prediction = '.pred_Yes',
  model = 'model_name'
)
# Calibration plot (5 bins) comparing both models, using the same numeric
# 1 ('Yes') / 0 recoding of the outcome.
combined_predictions %>%
  mutate(is_cardiology_note = 1 * (is_cardiology_note == 'Yes')) %>%
  cal_plot_multi(
    outcome = 'is_cardiology_note',
    prediction = '.pred_Yes',
    model = 'model_name',
    n_bins = 5
  )
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.