Warning: this code may or may not run depending on your current version of mlr3proba. First we load some additional mlr3 packages, along with aorsf and the tidyverse.
suppressPackageStartupMessages({
  library(mlr3verse)
  library(mlr3proba)
  library(mlr3extralearners)
  library(mlr3viz)
  library(mlr3benchmark)
  library(aorsf)
  library(tidyverse)
})
Next we'll define some tasks for our learners to engage with.
# Mayo Clinic Primary Biliary Cholangitis Data
task_pbc <- TaskSurv$new(
  id = 'pbc',
  backend = select(pbc_orsf, -id) %>%
    mutate(stage = as.numeric(stage)),
  time = "time",
  event = "status"
)

# Veteran's Administration Lung Cancer Trial
data(veteran, package = "randomForestSRC")
task_veteran <- TaskSurv$new(
  id = 'veteran',
  backend = veteran,
  time = "time",
  event = "status"
)

# NKI 70 gene signature
data_nki <- OpenML::getOMLDataSet(data.id = 1228)
task_nki <- TaskSurv$new(
  id = 'nki',
  backend = data_nki$data,
  time = "time",
  event = "event"
)

# Gene Expression-Based Survival Prediction in Lung Adenocarcinoma
data_lung <- OpenML::getOMLDataSet(data.id = 1245)
task_lung <- TaskSurv$new(
  id = 'lung',
  backend = data_lung$data %>%
    mutate(OS_event = as.numeric(OS_event) - 1) %>%
    select(-histology),
  time = "OS_years",
  event = "OS_event"
)

# Chemotherapy for Stage B/C colon cancer
# (there are two rows per person, one for death
#  and the other for recurrence, hence the two tasks)
task_colon_death <- TaskSurv$new(
  id = 'colon_death',
  backend = survival::colon %>%
    filter(etype == 2) %>%
    drop_na() %>%
    # drop id, redundant variables
    select(-id, -study, -node4, -etype),
  time = "time",
  event = "status"
)

task_colon_recur <- TaskSurv$new(
  id = 'colon_recur',
  backend = survival::colon %>%
    filter(etype == 1) %>%
    drop_na() %>%
    # drop id, redundant variables
    select(-id, -study, -node4, -etype),
  time = "time",
  event = "status"
)

# putting them all together
tasks <- list(task_pbc, task_veteran, task_nki, task_lung,
              task_colon_death, task_colon_recur,
              # add a few more pre-made ones
              tsk("actg"), tsk('gbcs'), tsk('grace'),
              tsk("unemployment"))
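Several of these datasets come from external sources (OpenML, randomForestSRC, and survival), so a quick sanity check can save trouble later. The snippet below is an optional sketch that only uses standard mlr3 task fields ($id, $nrow, $feature_names) to print the dimensions of each task in the list:

# optional sanity check: id, number of observations, and number of features
data.frame(
  id = sapply(tasks, function(x) x$id),
  n_obs = sapply(tasks, function(x) x$nrow),
  n_features = sapply(tasks, function(x) length(x$feature_names))
)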
Now we can make a benchmark designed to compare our three favorite learners:
# Learners with default parameters
learners <- lrns(c("surv.ranger", "surv.rfsrc", "surv.aorsf"))

# Brier (Graf) score, c-index and training time as measures
measures <- msrs(c("surv.graf", "surv.cindex", "time_train"))

# Benchmark with 5-fold CV
design <- benchmark_grid(
  tasks = tasks,
  learners = learners,
  resamplings = rsmps("cv", folds = 5)
)

benchmark_result <- benchmark(design)

bm_scores <- benchmark_result$score(measures, predict_sets = "test")
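mlr3 can also aggregate each measure over the cross-validation folds for every task/learner pair directly. The two lines below are a small optional sketch using the benchmark_result and measures objects defined above (the column selection assumes the usual data.table output of $aggregate()):

# mean of each measure across the 5 folds, per task and learner
bm_aggregate <- benchmark_result$aggregate(measures)
bm_aggregate[, .(task_id, learner_id, surv.graf, surv.cindex)]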
Let's look at the overall results:
bm_scores %>%
  select(task_id, learner_id, surv.graf, surv.cindex, time_train) %>%
  group_by(learner_id) %>%
  filter(!is.infinite(surv.graf)) %>%
  summarize(
    across(
      .cols = c(surv.graf, surv.cindex, time_train),
      .fns = ~ mean(.x, na.rm = TRUE)
    )
  )
# results saved from a previous run of the summary above
tbl_data <- structure(
  list(
    learner_id = c("surv.aorsf", "surv.ranger", "surv.rfsrc"),
    surv.graf = c(0.150484184019999, 0.164170239607235, 0.154920952989591),
    surv.cindex = c(0.734037935733795, 0.716040564907922, 0.724149275709082),
    time_train = c(0.382962962962959, 2.04055555555556, 0.757407407407405)
  ),
  row.names = c(NA, -3L),
  class = c("tbl_df", "tbl", "data.frame")
)

tbl_data
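Because mlr3viz is loaded in the setup chunk, the fold-level scores can also be compared visually. This is an optional sketch, assuming the benchmark_result object from above; autoplot() on a BenchmarkResult draws boxplots of the chosen measure per learner, split by task:

# boxplots of the concordance index per learner, one panel per task
autoplot(benchmark_result, measure = msr("surv.cindex")) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))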
From inspection,

- aorsf has a higher expected value for 'surv.cindex' (higher is better)
- aorsf has a lower expected value for 'surv.graf' (lower is better)
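Mean differences of this size can also be checked a little more formally. The sketch below assumes the BenchmarkAggr interface from mlr3benchmark (as_benchmark_aggr(), $friedman_test(), $friedman_posthoc()); method names and defaults may differ across versions:

# aggregate the c-index per task/learner pair, then test for differences
bm_aggr <- as_benchmark_aggr(benchmark_result, measures = msr("surv.cindex"))

# global Friedman test: do the learners differ across tasks?
bm_aggr$friedman_test()

# pairwise post-hoc comparisons (only meaningful if the global test is significant)
bm_aggr$friedman_posthoc()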