#| label: get-engines
#| include: false

library(tidymodels)
library(censored)
library(joineR)
library(doMC)

# ------------------------------------------------------------------------------

tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)
registerDoMC(cores = 10)

# ------------------------------------------------------------------------------

engines <-
  get_from_env("models") %>%
  map(~ get_from_env(.x) %>% mutate(model = .x)) %>%
  list_rbind() %>%
  filter(mode == "censored regression")

num_models <- engines %>% distinct(model) %>% nrow()
The censored package was released in June 2022, enabling users to fit event time/survival time models using the tidymodels framework. As of this writing, there are a total of `r nrow(engines)` different engines that can be used with `r num_models` different models.
This document is intended as a tutorial for using the broader tidymodels framework for event time analysis, including model tuning, evaluation, and selection.
To reproduce these results, you might need to update some package versions:
#| label: installs
#| eval: false

# Get CRAN versions of:
pak::pak(c("parsnip", "censored"), ask = FALSE)

# Get GitHub versions of:
pak::pak(c("tidymodels/tune@ipcw"), ask = FALSE)
pak::pak(c("tidymodels/yardstick"), ask = FALSE)
We'll use the heart valve data set in the joineR package (also described in this publication). There are `r length(unique(heart.valve$num))` patients in the study who underwent aortic valve replacement surgery. The data have time-dependent covariates, but we will skip those to simplify the analysis here. The outcome is the time to death after surgery:
#| label: print-data-head
library(joineR)
data(heart.valve, package = "joineR")
str(heart.valve)
Loading needed tidymodels packages:
#| label: load-tm
library(tidymodels)
library(censored)

# ------------------------------------------------------------------------------

tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)
We'll retrieve the appropriate event times for the outcome (since there are multiple time points where patients were measured). Then, we'll identify the predictors that have the same values across the multiple time points and merge them. Functions in the joineR package will help us out here:
outcome_data <-
  UniqueVariables(heart.valve, var.col = c("fuyrs", "status"), id.col = "num")

covar_data <-
  UniqueVariables(
    heart.valve,
    var.col = c("age", "hs", "sex", "lv", "emergenc", "hc", "sten.reg.mix"),
    id.col = "num"
  )

heart_data <-
  full_join(outcome_data, covar_data, by = "num") %>%
  select(-num) %>%
  as_tibble()

heart_data
We'll reformat some of the categorical predictors since they are currently encoded as integers.
Also, tidymodels expects that the event times and corresponding status data are pre-formatted using the `Surv()` function in the survival package. We'll do that, then remove the original `fuyrs` and `status` columns.
#| label: reformat-data
heart_data <-
  heart_data %>%
  mutate(
    event_time = Surv(fuyrs, status),
    lv = case_when(
      lv == 1 ~ "good",
      lv == 2 ~ "moderate",
      lv == 3 ~ "poor"
    ),
    emergenc = case_when(
      emergenc == 0 ~ "elective",
      emergenc == 1 ~ "urgent",
      emergenc == 2 ~ "emergency"
    ),
    hc = case_when(
      hc == 0 ~ "absent",
      hc == 1 ~ "present_treated",
      hc == 2 ~ "present_untreated"
    ),
    sten.reg.mix = case_when(
      sten.reg.mix == 1 ~ "stenosis",
      sten.reg.mix == 2 ~ "regurgitation",
      sten.reg.mix == 3 ~ "mixed"
    ),
    hs = case_when(
      hs == "Homograft" ~ "homograft",
      TRUE ~ "stentless_porcine_tissue"
    ),
    across(where(is.character), factor)
  ) %>%
  select(-fuyrs, -status)
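As a minimal illustration of what `Surv()` produces, here is a sketch on three made-up observations (the `toy_` names are hypothetical and not part of the heart valve data). A status of 1 marks an observed event and 0 marks a censored time:

```r
library(survival)

# Made-up follow-up times (years) and event indicators
toy_time   <- c(2.5, 6.0, 4.1)
toy_status <- c(1, 0, 1)  # 1 = event observed, 0 = censored

toy_surv <- Surv(toy_time, toy_status)

# Censored observations are printed with a trailing "+"
print(toy_surv)

# The object is a two-column matrix: the observed time and the status
obs_time  <- as.vector(toy_surv[, "time"])
obs_event <- as.vector(toy_surv[, "status"])
```

This two-column layout is why the summaries in this document index `event_time[, 1]` for the times and `event_time[, 2]` for the status.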
Since our focus is on prediction, the standard tidymodels methods for data splitting are used to create training and test sets. Cross-validation folds are created later, when we resample the model:
#| label: data-splitting
set.seed(6941)
valve_split <- initial_split(heart_data)
valve_tr <- training(valve_split)
valve_te <- testing(valve_split)
In the training set, the observed time values range from `r format(min(valve_tr$event_time[,1]), digits = 2)` years to `r format(max(valve_tr$event_time[,1]), digits = 2)` years and `r round(mean(valve_tr$event_time[,2]) * 100, 2)`% of the patients died (i.e., were events).
There are different types of predictions for event time analysis. Dynamic predictions require a specific time at which to make the prediction. That time is sometimes called a "landmark time"; we call it an "evaluation time" since our focus is prediction. For example, we might want to know the probability of survival up to some evaluation time $t$. A static prediction does not depend on an evaluation time point. For example, we might predict the event time from a model.
To demonstrate, let's fit a bagged tree to the training data:
#| label: initial-example
bag_spec <-
  bag_tree() %>%
  set_mode("censored regression") %>%
  set_engine("rpart", nbagg = 50)

set.seed(29872)
bag_fit <-
  bag_spec %>%
  fit(event_time ~ ., data = valve_tr)
Instead of using the training or testing sets, let's make two fake patients by randomly selecting rows from the training set:
#| label: fake-patients
set.seed(4853)
fake_examples <- slice_sample(valve_tr, n = 2)
fake_examples
The standard `predict()` machinery can be used to get static (e.g., `type = "time"`) or dynamic predictions (e.g., `type = "survival"`). We'll create a grid of `r length(seq(0, 10, by = .1))` evaluation time points for the latter:
#| label: fake-predictions
time_points <- seq(0, 10, by = .1)

bag_pred <-
  predict(bag_fit, fake_examples, type = "survival", eval_time = time_points) %>%
  bind_cols(
    predict(bag_fit, fake_examples),
    fake_examples %>% select(event_time)
  ) %>%
  add_rowindex()

bag_pred
As usual, the prediction columns are prefixed with `.pred_`. What is unusual is that `.pred` is a list column, and each list element is a tibble with `r ncol(bag_pred$.pred[[1]])` columns and `r nrow(bag_pred$.pred[[1]])` rows. They contain the survival estimates for each patient:
#| label: dot-pred
bag_pred$.pred[[1]] %>% slice(1:5)
We can unnest these and plot the per-patient survival curves:
#| label: example-survival-probs
#| out-width: "60%"
#| fig-align: center
bag_pred %>%
  unnest(.pred) %>%
  mutate(sample = format(.row)) %>%
  ggplot(aes(.eval_time, .pred_survival, group = sample, col = sample)) +
  geom_step() +
  lims(y = 0:1) +
  labs(x = "Time", y = "Probability of Survival") +
  scale_color_brewer(palette = "Paired")
The static/dynamic prediction types make these models' tuning and evaluations a little more complex. In many tidymodels functions, there is a new argument called `eval_time` that is used to specify the time points for dynamic predictions (as we'll see in a minute).
Metrics to measure how well our model performs can also be split into dynamic and static metrics.
For static metrics, a common choice is the concordance statistic, accessible via the `concordance_survival()` function. If we were looking at the test set results for the bagged tree model:
#| label: test-concordance
test_pred <-
  predict(bag_fit, valve_te, type = "survival", eval_time = time_points) %>%
  bind_cols(
    predict(bag_fit, valve_te),
    valve_te %>% select(event_time)
  )

test_pred %>%
  concordance_survival(truth = event_time, estimate = .pred_time)
Dynamic metrics are usually classification metrics repurposed for survival analysis. For example, if we wanted to evaluate the model at $t = 5$, we could use the predicted survival probabilities and try to classify each data point as dead or alive. This ends up being a two-class situation, and metrics like the Brier score or the area under the ROC curve can be used to quantify how well the model works at evaluation time $t$.
The main difficulty is that, due to censoring, some data can't be cleanly classified. If an observation is censored at time 6, we know that it should not be classified as an event at $t = 5$. However, if the observed time were 2 and censored, we don't know whether it is an event at $t = 5$ or not.
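To make these situations concrete, here is a small base R sketch that sorts made-up observations into groups at an evaluation time of 5. The data and the group labels are hypothetical, and the toy times are chosen so that none falls exactly on the evaluation time, which sidesteps the question of how ties should be handled:

```r
# Made-up observed times and status values (1 = event, 0 = censored)
toy <- data.frame(
  time   = c(2, 2, 6, 6),
  status = c(1, 0, 1, 0)
)
eval_time <- 5

toy$class_at_eval <- ifelse(
  toy$time > eval_time,
  "non-event",                                   # still event-free at eval_time
  ifelse(toy$status == 1, "event", "ambiguous")  # censored early = unknown
)
toy
```

Only the observation censored at time 2 cannot be classified; both observations with a time of 6 are known to be event-free at time 5, whether or not the event was later observed.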
There are a lot of ways to deal with this issue. We've done an exhaustive reading of the literature and have come to a somewhat opinionated conclusion. Most of the survival metrics in the literature are developed to univariately score a collection of predictors, typically biomarkers, regarding how well they are associated with the event times. That's not what we are doing; we have model predictions.
Our choice for dynamic metrics is to use the inverse probability of censoring weights (IPCW), specifically the scheme used by Graf et al. (1999). They compute the probability that every data point might have been censored and use the inverse of this value as a case weight. If the observed time is a censoring time that occurs before the evaluation time, the data point should make no contribution to the performance metric.
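As a rough illustration of the weighting idea (not the tidymodels implementation), the sketch below estimates the probability of remaining uncensored with a "reverse" Kaplan-Meier curve and inverts it. The toy data, the small time offset used for events, and all object names are assumptions made for this example:

```r
library(survival)

# Made-up observed times and status values (1 = event, 0 = censored)
toy <- data.frame(
  time   = c(1.2, 2.0, 3.5, 4.1, 6.0, 7.3, 8.8, 9.5),
  status = c(1,   0,   1,   0,   0,   1,   0,   1)
)
eval_time <- 5

# "Reverse" Kaplan-Meier: treat censoring as the event of interest to
# estimate G(t), the probability of still being uncensored at time t
cens_fit <- survfit(Surv(time, 1 - status) ~ 1, data = toy)
G <- stepfun(cens_fit$time, c(1, cens_fit$surv))

# Time at which G is evaluated for each observation:
#  - event before eval_time: just before the observed event time
#  - observed time after eval_time: the evaluation time itself
#  - censored before eval_time: no weight (NA)
event_before <- toy$time <= eval_time & toy$status == 1
after_eval   <- toy$time > eval_time
weight_time  <- ifelse(event_before, toy$time - 1e-10,
                       ifelse(after_eval, eval_time, NA_real_))

# Inverse probability of censoring weights
toy$ipcw <- NA_real_
has_weight <- !is.na(weight_time)
toy$ipcw[has_weight] <- 1 / G(weight_time[has_weight])
toy
```

Observations that are still being followed at the evaluation time get larger weights, since they also stand in for similar patients who were censored earlier.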
If you were to compute model performance manually (as above), these weights can be computed using:
#| label: ipcw-comp
ipcw_data <-
  test_pred %>%
  .censoring_weights_graf(bag_fit, .) %>%
  select(-.pred_time)
This adds a column called `.weight_censored` to the tibble of predicted survival probabilities, which is used as a case weight in calculating the performance metric.
ipcw_data

# The adjusted data:
ipcw_data$.pred[[1]] %>% slice(1:5)
With the data in this format, we can use a yardstick function for dynamic metrics like `brier_survival()`:
#| label: brier-surv
brier_scores <-
  ipcw_data %>%
  # No argument name is used for .pred
  brier_survival(truth = event_time, .pred)

brier_scores %>% slice(1:5)
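For reference, one common way to write the IPCW Brier score of Graf et al. (1999) at evaluation time $t$ is given below. The notation is introduced here for illustration: $T_i$ is the observed time, $\delta_i$ the event indicator, $\hat{S}(t \mid x_i)$ the predicted survival probability, and $\hat{G}$ the estimated probability of remaining uncensored. The yardstick implementation may differ in details, such as exactly where $\hat{G}$ is evaluated for events.

$$
\mathrm{BS}(t) = \frac{1}{N} \sum_{i=1}^{N} \left[
\frac{\left(0 - \hat{S}(t \mid x_i)\right)^2 \, I(T_i \le t,\ \delta_i = 1)}{\hat{G}(T_i)}
+ \frac{\left(1 - \hat{S}(t \mid x_i)\right)^2 \, I(T_i > t)}{\hat{G}(t)}
\right]
$$

Observations censored before $t$ match neither indicator and so contribute nothing to the sum.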
We compute a score for each evaluation time:
#| label: brier-surv-plot
#| out-width: "60%"
#| fig-align: center
brier_scores %>%
  ggplot(aes(.eval_time, .estimate)) +
  geom_hline(yintercept = 0.25, col = "red", alpha = 1 / 2, lty = 2) +
  geom_line() +
  labs(x = "years", y = "Brier Score")
The horizontal dashed line is the level of performance that you would get with a non-informative model.
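The 0.25 reference value follows from a model that predicts a survival probability of 0.5 for everyone: the squared error is then $0.5^2 = 0.25$ no matter what the outcome is. A quick check with made-up outcomes (ignoring censoring weights for simplicity):

```r
# Made-up binary outcomes at some evaluation time (1 = survived, 0 = died)
outcome <- c(0, 1, 1, 0, 1)

# A non-informative model predicts 0.5 for every observation
noninformative_pred <- rep(0.5, length(outcome))

brier_noninformative <- mean((outcome - noninformative_pred)^2)
brier_noninformative
```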
The other dynamic metrics that are currently implemented are `brier_survival_integrated()` (the area under the curve above) and `roc_auc_survival()`.
Multiple static and dynamic metrics can be combined via a metric set.
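As a sketch, a metric set that mixes the static and dynamic metrics named above could be built as follows. This assumes a yardstick version that includes the survival metrics; the name `all_surv_metrics` is just for illustration:

```r
library(yardstick)

all_surv_metrics <- metric_set(
  concordance_survival,       # static
  brier_survival,             # dynamic
  roc_auc_survival,           # dynamic
  brier_survival_integrated   # integrated over the evaluation times
)
```

Because the set contains both kinds of metric, the data it is applied to would need both the `.pred` list column of survival probabilities and the static `.pred_time` column.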
tidymodels strongly focuses on empirical validation via resampling, which is also true of event time models.
We can use the `fit_resamples()` function with an rsample object to compute performance without using the test set. We need to tell the function what times to use for the dynamic metrics:
#| label: resample
# Create resamples
set.seed(12)
valve_rs <- vfold_cv(valve_tr, repeats = 5)

bag_tree_res <-
  bag_spec %>%
  fit_resamples(event_time ~ ., resamples = valve_rs, eval_time = time_points)
By default, the Brier score is used:
#| label: resampled-brier
#| out-width: "60%"
#| fig-align: center
collect_metrics(bag_tree_res) %>% slice(1:5)

bag_tree_res %>%
  collect_metrics() %>%
  mutate(
    lower = mean - 1.96 * std_err,
    upper = mean + 1.96 * std_err
  ) %>%
  ggplot(aes(.eval_time)) +
  geom_hline(yintercept = 0.25, col = "red", alpha = 1 / 2, lty = 2) +
  geom_line(aes(y = mean)) +
  geom_ribbon(aes(ymin = lower, ymax = upper), col = NA, alpha = 1 / 10) +
  labs(x = "years", y = "Brier Score")
Suppose we try a regularized Cox model for these data. We'll add a recipe to the analysis and tune a lasso model. The code is pretty standard tidymodels syntax, with the added `eval_time` argument. We'll also use a metric set to include the integrated Brier score, which computes the area under the curve of the Brier score over time:
#| label: lasso-model
#| out-width: "60%"
#| fig-align: center
# mixture = 1 requests a pure lasso penalty (mixture = 0 would be ridge)
lasso_spec <-
  proportional_hazards(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet") %>%
  set_mode("censored regression")

lasso_rec <-
  recipe(event_time ~ ., data = valve_tr) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric_predictors())

lasso_wflow <- workflow(lasso_rec, lasso_spec)

surv_metrics <- metric_set(brier_survival_integrated, brier_survival)

lasso_tune_res <-
  lasso_wflow %>%
  tune_grid(
    resamples = valve_rs,
    eval_time = time_points,
    grid = tibble(penalty = 10^seq(-3, -1, length.out = 20)),
    metrics = surv_metrics
  )
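To show what "integrated" means for the Brier score, the sketch below applies the trapezoid rule to a made-up Brier curve and divides by the width of the time range so the result stays on the Brier scale. The helper `integrate_curve()` is hypothetical, and the normalization is an assumption; `brier_survival_integrated()` may differ in its details:

```r
# Trapezoid-rule area under a curve, scaled by the width of the time range
integrate_curve <- function(times, values) {
  widths  <- diff(times)
  heights <- (head(values, -1) + tail(values, -1)) / 2
  sum(widths * heights) / (max(times) - min(times))
}

# Made-up Brier scores that rise with the evaluation time
toy_times <- seq(0, 10, by = 0.1)
toy_brier <- 0.2 * (1 - exp(-toy_times / 3))

integrated_brier <- integrate_curve(toy_times, toy_brier)
integrated_brier
```

Collapsing the curve to one number is what makes it possible to rank tuning parameters without choosing a single evaluation time.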
We can plot the results for a specific metric, here the Brier score at an evaluation time of 5 years:
#| label: integrated-tune
#| out-width: "60%"
#| fig-align: center
autoplot(lasso_tune_res, metric = "brier_survival", eval_time = 5)
For these plot methods, `eval_time` can be passed in as shown. If a dynamic metric is used and `eval_time` is not set, the function will pick a time near the middle of the range.
We can also choose the best penalty. If we use an integrated method, no `eval_time` is needed:
best_penalty <- select_best(lasso_tune_res, metric = "brier_survival_integrated")
Now we can update the workflow and, assuming that this is the model that we want to keep, evaluate it on the test set:
#| label: finalize
lasso_final_wflow <-
  lasso_wflow %>%
  finalize_workflow(best_penalty)

lasso_final_wflow
For performance assessment on the test set, you can manually predict it and calculate performance or use `last_fit()` with the original split object to do these steps for you:
#| label: last-fit
test_res <- last_fit(lasso_final_wflow, valve_split, eval_time = time_points)
As usual, you can get the test set metrics via:
#| label: test-metrics
collect_metrics(test_res)
How do the Brier Score estimates compare between the test set and resampling?
#| label: lasso-brier-curves
#| out-width: "60%"
#| fig-align: center
#| warning: false
collect_metrics(test_res) %>%
  mutate(estimator = "testing") %>%
  select(.eval_time, estimator, Brier = .estimate) %>%
  bind_rows(
    lasso_tune_res %>%
      collect_metrics() %>%
      mutate(estimator = "resampling") %>%
      select(.eval_time, estimator, Brier = mean, penalty) %>%
      inner_join(best_penalty, by = "penalty")
  ) %>%
  ggplot(aes(.eval_time)) +
  geom_hline(yintercept = 0.25, col = "red", alpha = 1 / 2, lty = 2) +
  geom_line(aes(y = Brier, col = estimator)) +
  labs(x = "years", y = "Brier Score") +
  scale_color_brewer(palette = "Set2")
Good!
eval_time
augment()
to produce IPCW values.

#| label: session
sessioninfo::session_info()