mlr3 pipelines

Warning: this code may or may not run depending on your current version of mlr3proba. First we load some additional mlr3 libraries.

suppressPackageStartupMessages({
 library(mlr3verse)
 library(mlr3proba)
 library(mlr3extralearners)
 library(mlr3viz)
 library(mlr3benchmark)
 library(aorsf)
 library(tidyverse)
})

Next we'll define some tasks for our learners to engage with.

# Mayo Clinic Primary Biliary Cholangitis Data
# The `id` column is dropped and `stage` is coerced to numeric before the
# data are handed to the task as its backend.
task_pbc <- TaskSurv$new(
  id = "pbc",
  backend = pbc_orsf %>%
    select(-id) %>%
    mutate(stage = as.numeric(stage)),
  time = "time",
  event = "status"
)

# Veteran's Administration Lung Cancer Trial
# Load the `veteran` data set from randomForestSRC into the workspace.
data(veteran, package = "randomForestSRC")

# Survival task on the data as loaded; outcome columns are `time`/`status`.
task_veteran <- TaskSurv$new(
  id = "veteran",
  backend = veteran,
  time = "time",
  event = "status"
)

# NKI 70 gene signature
# Fetched from OpenML (data set 1228); requires network access.
data_nki <- OpenML::getOMLDataSet(data.id = 1228)

# The downloaded object keeps the data frame in its `data` element.
task_nki <- TaskSurv$new(
  id = "nki",
  backend = data_nki$data,
  time = "time",
  event = "event"
)

# Gene Expression-Based Survival Prediction in Lung Adenocarcinoma
# Fetched from OpenML (data set 1245); requires network access.
data_lung <- OpenML::getOMLDataSet(data.id = 1245)

task_lung <-
 TaskSurv$new(
  # Give this task its own id: it previously reused 'nki', which made the
  # lung and NKI results indistinguishable by task_id in the benchmark.
  id = "lung",
  backend = data_lung$data %>%
   # presumably OS_event is a two-level factor, so subtracting 1 from its
   # integer codes yields a 0/1 indicator -- TODO confirm against the data
   mutate(OS_event = as.numeric(OS_event) - 1) %>%
   # histology is excluded from the task's columns
   select(-histology),
  time = "OS_years",
  event = "OS_event"
 )


# Chemotherapy for Stage B/C colon cancer
# (there are two rows per person, one for death
#  and the other for recurrence, hence the two tasks)

# Death outcome: keep the etype == 2 rows, drop rows with missing values,
# then remove the id and redundant columns.
# A stray `mutate(OS_event = ...)` argument used to follow `backend`; it had
# no data argument and referred to a column this block never uses, so it has
# been removed. `status` is used as the event column as-is.
task_colon_death <-
 TaskSurv$new(
  id = "colon_death",
  backend = survival::colon %>%
   filter(etype == 2) %>%
   drop_na() %>%
   # drop id, redundant variables
   select(-id, -study, -node4, -etype),
  time = "time",
  event = "status"
 )

# Recurrence outcome for the colon cancer data: keep the etype == 1 rows,
# drop rows with missing values, then remove the id and redundant columns.
# Two fixes relative to the earlier version:
#  - the id was a copy of 'colon_death', so the two colon tasks could not be
#    told apart by task_id; it is now 'colon_recur'
#  - a stray `mutate(OS_event = ...)` argument after `backend` (no data
#    argument, column not used in this block) has been removed
task_colon_recur <-
 TaskSurv$new(
  id = "colon_recur",
  backend = survival::colon %>%
   filter(etype == 1) %>%
   drop_na() %>%
   # drop id, redundant variables
   select(-id, -study, -node4, -etype),
  time = "time",
  event = "status"
 )

# Collect every task for the benchmark: the six built above, followed by
# four pre-made tasks retrieved by key with tsk().
tasks <- list(
  task_pbc,
  task_veteran,
  task_nki,
  task_lung,
  task_colon_death,
  task_colon_recur,
  tsk("actg"),
  tsk("gbcs"),
  tsk("grace"),
  tsk("unemployment")
)

Now we can make a benchmark designed to compare our three favorite learners:

# Candidate learners, each left at its default parameters
learners <- lrns(c("surv.ranger", "surv.rfsrc", "surv.aorsf"))

# Evaluation measures: Brier (Graf) score, c-index, and training time
measures <- msrs(c("surv.graf", "surv.cindex", "time_train"))

# Cross every task with every learner under 5-fold cross-validation
design <- benchmark_grid(tasks, learners, rsmps("cv", folds = 5))

# Run the full design
benchmark_result <- benchmark(design)

# Score the results with the chosen measures on the test sets
bm_scores <- benchmark_result$score(measures, predict_sets = "test")

Let's look at the overall results:

# Average each measure per learner.
# Rows with an infinite Graf score are excluded first; the check is row-wise,
# so applying it before grouping gives the same rows as applying it after.
bm_scores %>%
  select(task_id, learner_id, surv.graf, surv.cindex, time_train) %>%
  filter(!is.infinite(surv.graf)) %>%
  group_by(learner_id) %>%
  summarize(
    across(
      .cols = c(surv.graf, surv.cindex, time_train),
      .fns = ~ mean(.x, na.rm = TRUE)
    )
  )
# Hard-coded per-learner means for the three measures, stored with tibble
# classes so the table can be shown without re-running the benchmark.
tbl_data <- data.frame(
  learner_id = c("surv.aorsf", "surv.ranger", "surv.rfsrc"),
  surv.graf = c(0.150484184019999, 0.164170239607235, 0.154920952989591),
  surv.cindex = c(0.734037935733795, 0.716040564907922, 0.724149275709082),
  time_train = c(0.382962962962959, 2.04055555555556, 0.757407407407405),
  stringsAsFactors = FALSE
)
class(tbl_data) <- c("tbl_df", "tbl", "data.frame")

tbl_data

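From inspection of the table above, `surv.aorsf` has the lowest mean Graf (Brier) score (0.150; lower is better), the highest mean c-index (0.734; higher is better), and the shortest mean training time (0.38, versus 0.76 for `surv.rfsrc` and 2.04 for `surv.ranger`).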



bcjaeger/aorsf documentation built on April 3, 2025, 4:16 p.m.