Custom predict function for survival models

# Default chunk options for this vignette: echo code, but keep messages and
# warnings out of the rendered output.
knitr::opts_chunk$set(
  echo = TRUE,
  message = FALSE,
  warning = FALSE
)

Introduction

This vignette contains example predict functions for survival models. Predict functions for some model classes are already implemented in survxai, so for those models there is no need to specify a predict function.

# PBC data shipped with randomForestSRC, restricted to rows without missing values.
data(pbc, package = "randomForestSRC")
pbc <- pbc[complete.cases(pbc), ]
# sex and stage are used as categorical covariates.
pbc[c("sex", "stage")] <- lapply(pbc[c("sex", "stage")], as.factor)

Implemented Models

Model classes with an implemented predict function are listed below. Objects of these classes do not need a custom predict function.

set.seed(1024)
library(rms)
library(survxai)

# Cox model fitted with rms::cph(). surv, x and y are stored on the fit
# (presumably needed by survxai's built-in cph predictions — confirm).
cph_model <- cph(
  Surv(days / 365, status) ~ treatment + age + sex + ascites + hepatom +
    spiders + edema + bili + chol + albumin + copper + alk + sgot + trig +
    platelet + prothrombin + stage,
  data = pbc,
  surv = TRUE,
  x = TRUE,
  y = TRUE
)

# No predict_function is passed: cph is one of the implemented model classes.
# Columns 1 and 2 (days and status in this dataset) are excluded from `data`.
surve_cph <- explain(
  model = cph_model,
  data = pbc[, -c(1, 2)],
  y = Surv(pbc$days / 365, pbc$status)
)

randomForestSRC

A predict function for class rfsrc is not implemented, so a custom predict function must be provided.

library(prodlim)
library(randomForestSRC)

# Survival-probability predictions for a randomForestSRC::rfsrc model.
#
# object  - fitted rfsrc survival forest
# newdata - data frame of covariates to predict for
# times   - numeric vector of evaluation times
# ...     - ignored; accepted so the signature matches other predict functions
#
# Returns a matrix with one row per row of `newdata` and one column per
# element of `times`. Each entry is the forest's survival estimate at the last
# jump time not exceeding that time, or 1 before the first jump time.
predict_rf <- function(object, newdata, times, ...) {
  # Columns stored as integer in newdata are coerced to integer in the
  # training covariates too (presumably so predict() sees matching types).
  # vapply() always returns a logical vector, even for zero-column newdata.
  is_int <- vapply(newdata, is.integer, logical(1))
  int_cols <- names(newdata)[is_int]
  object$xvar[int_cols] <- lapply(object$xvar[int_cols], as.integer)

  surv <- predict(object, newdata = newdata, importance = "none")$survival

  # sindex() gives, for each evaluation time, the index of the last jump time
  # <= that time (0 if it precedes the first jump). Prepending a column of 1s
  # maps index 0 to survival probability 1.
  pos <- prodlim::sindex(jump.times = object$time.interest, eval.times = times)
  cbind(1, surv)[, pos + 1, drop = FALSE]
}
# Survival time in years; this becomes the response of the forest.
pbc$year <- pbc$days / 365
# Column 1 (days) is left out so the raw time is not used as a predictor.
rf_model <- rfsrc(Surv(year, status) ~ ., data = pbc[, -1])

# rfsrc has no built-in predict function, so predict_rf is supplied.
# Columns 1, 2 and 20 (days, status, year) are excluded from `data`.
surve_rf <- explain(
  model = rf_model,
  data = pbc[, -c(1, 2, 20)],
  y = Surv(pbc$year, pbc$status),
  predict_function = predict_rf
)

survreg

A predict function for class survreg is not implemented, so a custom predict function must be provided.

library(survival)

# Survival-probability predictions for a survival::survreg model.
#
# model   - fitted survreg object (fitted with x = TRUE)
# newdata - data frame of covariates; any variables from the left-hand side of
#           the model formula (e.g. year, status) are dropped if present
# times   - evaluation times; sorted ascending before use, so the result's
#           columns follow sorted order
# ...     - ignored; accepted for consistency with predict_rf()
#
# Returns a nrow(newdata) x length(times) matrix whose row i is the survival
# probability curve for observation i of newdata.
predict_reg <- function(model, newdata, times, ...) {
  times <- sort(times)

  # Variables of the response, e.g. "year" and "status" for Surv(year, status).
  # formula() works however the formula was passed to survreg().
  response_vars <- all.vars(stats::formula(model)[[2]])
  keep <- !(colnames(newdata) %in% response_vars)
  # drop = FALSE keeps a data frame even when only one covariate remains.
  covariates <- newdata[, keep, drop = FALSE]

  # Replace the stored design matrix with the one for the new observations;
  # cfc.survreg.survprob() evaluates observation n of args$x.
  model$x <- model.matrix(~ ., covariates)

  res <- matrix(ncol = length(times), nrow = nrow(covariates))
  # NOTE(review): cfc.survreg.survprob() comes from the CFC package, which is
  # not visibly attached in this vignette — confirm library(CFC) is loaded.
  for (i in seq_len(nrow(covariates))) {
    res[i, ] <- cfc.survreg.survprob(t = times, args = model, n = i)
  }
  res
}
# Parametric survival regression. x = TRUE stores the design matrix, which
# predict_reg() overwrites with the design matrix of new data.
# Column 1 (days) is left out so the raw time is not used as a predictor.
reg_model <- survreg(Surv(year, status) ~ ., data = pbc[, -1], x = TRUE)

# predict_reg expects a survreg fit, so the survreg model (not rf_model) is
# explained here. Columns 1, 2 and 20 (days, status, year) are excluded.
surve_reg <- explain(
  model = reg_model,
  data = pbc[, -c(1, 2, 20)],
  y = Surv(pbc$year, pbc$status),
  predict_function = predict_reg
)


Try the survxai package in your browser

Any scripts or data that you put into this service are public.

survxai documentation built on Aug. 28, 2020, 5:07 p.m.