Event time prediction with tidymodels

#| label: get-engines
#| include: false

library(tidymodels)
library(censored)
library(joineR)
library(doMC)

# ------------------------------------------------------------------------------

tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)
registerDoMC(cores = 10)

# ------------------------------------------------------------------------------

engines <- 
  get_from_env("models") %>% 
  map(~ get_from_env(.x) %>% mutate(model = .x)) %>% 
  list_rbind() %>% 
  filter(mode == "censored regression")

num_models <- engines %>% distinct(model) %>% nrow()

The censored package was released in June 2022, enabling users to fit event time (survival) models with the tidymodels framework. As of this writing, there are r nrow(engines) engine/model combinations available across r num_models different model types.

This document is intended as a tutorial for using the broader tidymodels framework for event time analysis, including model tuning, evaluation, and selection.

To reproduce these results, you might need to update some package versions:

#| label: installs
#| eval: false

# Get CRAN versions of:
pak::pak(c("parsnip", "censored"), ask = FALSE)

# Get GitHub versions of: 
pak::pak(c("tidymodels/tune@ipcw"), ask = FALSE)
pak::pak(c("tidymodels/yardstick"), ask = FALSE)

An Example

We'll use the heart valve data set in the joineR package (also described in this publication). There are r length(unique(heart.valve$num)) patients in the study who underwent aortic valve replacement surgery. The data set has time-dependent covariates, but we will skip those to simplify the analysis here. The outcome is the time to death after surgery:

#| label: print-data-head
library(joineR)

data(heart.valve, package = "joineR")

str(heart.valve)

Loading needed tidymodels packages:

#| label: load-tm
library(tidymodels)
library(censored)

# ------------------------------------------------------------------------------

tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)

We'll retrieve the appropriate event times for the outcome (since there are multiple time points where patients were measured). Then, we'll identify the predictors that have the same values across the multiple time points and merge them. Functions in the joineR package will help us out here:

outcome_data <- 
  UniqueVariables(heart.valve, var.col = c("fuyrs", "status"), id.col = "num")
covar_data <- 
  UniqueVariables(heart.valve, 
                  var.col = c("age", "hs", "sex", "lv", "emergenc", "hc", "sten.reg.mix"), 
                  id.col = "num")

heart_data <- 
  full_join(outcome_data, covar_data, by = "num") %>% 
  select(-num) %>%
  as_tibble()

heart_data

We'll reformat some of the categorical predictors since they are currently encoded as integers.

Also, tidymodels expects that the event times and corresponding status data are pre-formatted using the Surv function in the survival package. We'll do that, then remove the original fuyrs and status columns.

#| label: reformat-data
heart_data <- 
  heart_data %>% 
  mutate(
    event_time = Surv(fuyrs, status),
    lv =
      case_when(
        lv == 1 ~ "good",
        lv == 2 ~ "moderate",
        lv == 3 ~ "poor"
      ),
    emergenc =
      case_when(
        emergenc == 0 ~ "elective",
        emergenc == 1 ~ "urgent",
        emergenc == 2 ~ "emergency"
      ),
    hc =
      case_when(
        hc == 0 ~ "absent",
        hc == 1 ~ "present_treated",
        hc == 2 ~ "present_untreated"
      ),
    sten.reg.mix =
      case_when(
        sten.reg.mix == 1 ~ "stenosis",
        sten.reg.mix == 2 ~ "regurgitation",
        sten.reg.mix == 3 ~ "mixed"
      ),
    hs =
      case_when(
        hs == "Homograft" ~ "homograft",
        TRUE ~ "stentless_porcine_tissue"
      ),
    across(where(is.character), factor)
  ) %>% 
  select(-fuyrs, -status)

Since our focus is on prediction, the standard tidymodels functions for data splitting are used to create training and test sets (we'll make cross-validation folds later, when we resample the model):

#| label: data-splitting
set.seed(6941)
valve_split <- initial_split(heart_data)
valve_tr <- training(valve_split)
valve_te <- testing(valve_split)

In the training set, the observed time values range from r format(min(valve_tr$event_time[,1]), digits = 2) years to r format(max(valve_tr$event_time[,1]), digits = 2) years and r round(mean(valve_tr$event_time[,2]) * 100, 2)% of the patients died (i.e. were events).
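As an aside, the Surv column can be indexed like a matrix to pull out these quantities; a quick sketch (not evaluated here):

#| label: surv-column-summary
#| eval: false

# Column 1 holds the observed times and column 2 the event indicator
# (1 = death, 0 = censored)
range(valve_tr$event_time[, 1])
mean(valve_tr$event_time[, 2])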

New Prediction Types

There are different types of predictions for event time analysis. Dynamic predictions require a specific time at which to make the prediction. That time is sometimes called a "landmark time"; we call it the "evaluation time" since our focus is on prediction. For example, we might want to know the probability of survival up to some evaluation time $t$. A static prediction is one that does not depend on an evaluation time point. For example, we might predict the event time itself from a model.

To demonstrate, let's fit a bagged tree to the training data:

#| label: initial-example

bag_spec <- 
  bag_tree() %>%
  set_mode("censored regression") %>% 
  set_engine("rpart", nbagg = 50)

set.seed(29872)
bag_fit <- 
  bag_spec %>% 
  fit(event_time ~ ., data = valve_tr)

Rather than predicting the entire training or test set, let's make two fake patients by randomly sampling rows from the training set:

#| label: fake-patients
set.seed(4853)
fake_examples <- 
  slice_sample(valve_tr, n = 2)

fake_examples

The standard predict() machinery can be used to get static (e.g., type = "time") or dynamic predictions (e.g., type = "survival"). We'll create a grid of r length(seq(0, 10, by = .1)) evaluation time points for the latter:

#| label: fake-predictions
time_points <- seq(0, 10, by = .1)
bag_pred <- 
  predict(bag_fit, fake_examples, type = "survival", eval_time = time_points) %>% 
  bind_cols(
    predict(bag_fit, fake_examples),
    fake_examples %>% select(event_time)
  ) %>% 
  add_rowindex()
bag_pred

As usual, the prediction columns are prefixed with .pred_. What is unusual is that .pred is a list column, and each list element is a tibble with r ncol(bag_pred$.pred[[1]]) columns and r nrow(bag_pred$.pred[[1]]) rows. They contain the survival estimates for each patient:

#| label: dot-pred
bag_pred$.pred[[1]] %>% slice(1:5)

We can unnest these and plot the per-patient survival curves:

#| label: example-survival-probs
#| out-width: "60%"
#| fig-align: center

bag_pred %>% 
  unnest(.pred) %>% 
  mutate(sample = format(.row)) %>% 
  ggplot(aes(.eval_time, .pred_survival, group = sample, col = sample)) + 
  geom_step() + 
  lims(y = 0:1) +
  labs(x = "Time", y = "Probability of Survival") +
  scale_color_brewer(palette = "Paired")

The static/dynamic distinction makes tuning and evaluating these models a little more complex. In many tidymodels functions, there is a new argument called eval_time that is used to specify the time points for dynamic predictions (as we'll see in a minute).

Measures of Performance

Metrics to measure how well our model performs can also be split into dynamic and static metrics.

For static predictions, a common choice is the concordance statistic, accessible via the concordance_survival() function. Looking at the test set results for the bagged tree model:

#| label: test-concordance

test_pred <- 
  predict(bag_fit, valve_te, type = "survival", eval_time = time_points) %>% 
  bind_cols(
    predict(bag_fit, valve_te),
    valve_te %>% select(event_time)
  )

test_pred %>% 
  concordance_survival(truth = event_time, estimate = .pred_time)

Dynamic metrics are usually classification metrics re-purposed for survival analysis. For example, if we wanted to evaluate the model at $t = 5$, we could use the predicted survival probabilities to classify each data point as dead or alive at that time. This is a two-class situation, so metrics like the Brier score or the area under the ROC curve can be used to quantify how well the model works at evaluation time $t$.

The main difficulty is that, due to censoring, some data points can't be cleanly classified. If a data point is censored at time 6, we know for certain that it was not an event at $t = 5$. However, if the observed time were 2 and censored, we don't know whether it was an event by $t = 5$ or not.

There are a lot of ways to deal with this issue. We've done an exhaustive reading of the literature and have come to a somewhat opinionated conclusion. Most of the survival metrics in the literature are developed to univariately score a collection of predictors, typically biomarkers, regarding how well they are associated with the event times. That's not what we are doing; we have model predictions.

Our choice for dynamic metrics is to use inverse probability of censoring weights (IPCW), specifically the scheme used by Graf et al. (1999). This approach estimates the probability that each data point would have been censored and uses the inverse of that probability as a case weight. If the observed time is a censoring time that occurs before the evaluation time, the data point makes no contribution to the performance metric.
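As a rough sketch of how these weights enter the metric, the IPCW version of the Brier score at evaluation time $t$ (following Graf et al., with $\hat{G}$ the Kaplan-Meier estimate of the censoring distribution and $\delta_i = 1$ for observed events) looks like:

$$
\text{Brier}(t) = \frac{1}{n}\sum_{i=1}^{n} \left[ \frac{I(t_i \le t,\, \delta_i = 1)\, \hat{S}(t \mid x_i)^2}{\hat{G}(t_i)} + \frac{I(t_i > t)\, \big(1 - \hat{S}(t \mid x_i)\big)^2}{\hat{G}(t)} \right]
$$

Data points that are censored before $t$ fall into neither term and therefore contribute nothing, as described above.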

If you were to compute model performance manually (as above), these weights are computed using:

#| label: ipcw-comp
ipcw_data <- 
  test_pred %>% 
  .censoring_weights_graf(bag_fit, .) %>% 
  select(-.pred_time)

This adds a column called .weight_censored to each of the nested tibbles of predicted survival probabilities; it is used as a case weight when calculating the performance metric.

ipcw_data

# The adjusted data:
ipcw_data$.pred[[1]] %>% slice(1:5)

With the data in this format, we can use a yardstick function for dynamic metrics like brier_survival():

#| label: brier-surv

brier_scores <-
  ipcw_data %>% 
  # No argument name is used for .pred
  brier_survival(truth = event_time, .pred)
brier_scores %>% slice(1:5)

There is one score for each evaluation time. Plotting them over time:

#| label: brier-surv-plot
#| out-width: "60%"
#| fig-align: center

brier_scores %>% 
  ggplot(aes(.eval_time, .estimate)) +
  geom_hline(yintercept = 0.25, col = "red", alpha = 1 / 2, lty = 2) +
  geom_line() +
  labs(x = "years", y = "Brier Score")

The horizontal line at 0.25 is the level of performance that you would get from a non-informative model (one that always predicts a survival probability of 1/2).

The other dynamic metrics that are currently implemented are brier_survival_integrated() (which integrates the Brier score across the evaluation times, i.e., the area under a curve like the one above) and roc_auc_survival().
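For example, the ROC AUC version uses the same interface as brier_survival(); a quick sketch on the weighted predictions from above (not run here):

#| label: roc-auc-surv
#| eval: false

# Same pattern: the truth column plus the nested .pred column
ipcw_data %>% 
  roc_auc_survival(truth = event_time, .pred)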

Multiple static and dynamic metrics can be combined via a metric set.
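For instance, a single metric set might hold both kinds (a sketch; the object name is arbitrary and not used later):

#| label: mixed-metric-set
#| eval: false

# Two dynamic metrics plus a static one
survival_metrics <- metric_set(brier_survival, roc_auc_survival, concordance_survival)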

Resampling the Model

tidymodels strongly emphasizes empirical validation via resampling, and event time models are no exception.

We can use the fit_resamples() function with an rsample object to compute performance without using the test set. We need to tell the function what times to use for the dynamic metrics:

#| label: resample

# Create resamples
set.seed(12)
valve_rs <- vfold_cv(valve_tr, repeats = 5)

bag_tree_res <- 
  bag_spec %>% 
  fit_resamples(event_time ~ ., resamples = valve_rs, eval_time = time_points)

By default, the Brier score is used:

#| label: resampled-brier
#| out-width: "60%"
#| fig-align: center

collect_metrics(bag_tree_res) %>% slice(1:5)

bag_tree_res  %>%
  collect_metrics() %>%
  mutate(
    lower = mean - 1.96 * std_err,
    upper = mean + 1.96 * std_err
  ) %>%
  ggplot(aes(.eval_time)) +
  geom_hline(yintercept = 0.25, col = "red", alpha = 1 / 2, lty = 2) +
  geom_line(aes(y = mean)) +
  geom_ribbon(aes(ymin = lower, ymax = upper),
              col = NA,
              alpha = 1 / 10) +
  labs(x = "years", y = "Brier Score") 

Model Tuning

Suppose we try a regularized Cox model for these data. We'll add a recipe to the analysis and tune a lasso model. The code is pretty standard tidymodels syntax, apart from the added eval_time argument. We'll also use a metric set that includes the integrated Brier score (the area under the Brier score vs. time curve):

#| label: lasso-model
#| out-width: "60%"
#| fig-align: center

lasso_spec <- 
  proportional_hazards(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet") %>%
  set_mode("censored regression")

lasso_rec <- 
  recipe(event_time ~ ., data = valve_tr) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric_predictors())

lasso_wflow <- workflow(lasso_rec, lasso_spec)

surv_metrics <- metric_set(brier_survival_integrated, brier_survival)

lasso_tune_res <-
  lasso_wflow %>%
  tune_grid(
    resamples = valve_rs,
    eval_time = time_points,
    grid = tibble(penalty = 10^seq(-3, -1, length.out = 20)),
    metrics = surv_metrics
  )

We can plot the results for the Brier score at a specific evaluation time:

#| label: integrated-tune
#| out-width: "60%"
#| fig-align: center
autoplot(lasso_tune_res, metric = "brier_survival", eval_time = 5)

For these plot methods, eval_time can be passed in as shown. If a dynamic metric is used and eval_time is not set, the function will pick a time near the middle of the range.
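For instance, omitting the evaluation time falls back to that default (a sketch, not run here):

#| label: autoplot-default-time
#| eval: false

# No eval_time: a default evaluation time is chosen for the dynamic metric
autoplot(lasso_tune_res, metric = "brier_survival")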

We can also choose the best penalty value. Since we select with the integrated metric, no eval_time is needed:

best_penalty <- select_best(lasso_tune_res, metric = "brier_survival_integrated")
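If we had chosen a dynamic metric instead, we would also supply an evaluation time; a sketch, assuming $t = 5$ is the time of interest:

#| label: select-best-dynamic
#| eval: false

select_best(lasso_tune_res, metric = "brier_survival", eval_time = 5)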

Now we can update the workflow and, assuming that this is the model that we want to keep, evaluate it on the test set:

#| label: finalize
lasso_final_wflow <- 
  lasso_wflow %>% 
  finalize_workflow(best_penalty)

lasso_final_wflow

To assess performance on the test set, you can either predict it manually and compute the metrics (as we did above) or use last_fit() with the original split object to do those steps for you:

#| label: last-fit
test_res <- 
  last_fit(lasso_final_wflow, valve_split, eval_time = time_points)

As usual, you can get the test set metrics via:

#| label: test-metrics
collect_metrics(test_res)

How do the Brier Score estimates compare between the test set and resampling?

#| label: lasso-brier-curves
#| out-width: "60%"
#| fig-align: center
#| warning: false

collect_metrics(test_res) %>% 
  mutate(estimator = "testing") %>% 
  select(.eval_time, estimator, Brier = .estimate) %>% 
  bind_rows(
    lasso_tune_res %>% 
    collect_metrics() %>% 
      mutate(estimator = "resampling") %>% 
      select(.eval_time, estimator, Brier = mean, penalty) %>% 
      inner_join(best_penalty, by = "penalty")
  ) %>% 
  ggplot(aes(.eval_time)) +
  geom_hline(yintercept = 0.25, col = "red", alpha = 1 / 2, lty = 2) +
  geom_line(aes(y = Brier, col = estimator)) +
  labs(x = "years", y = "Brier Score") +
  scale_color_brewer(palette = "Set2")

Good!

Things still to do

Session Info

#| label: session

sessioninfo::session_info()

