R/metrics.R

# Metrics comparing Cox and spline risk predictions
#'
#' Compute the concordance index, Brier score and time-dependent ROC
#'
#' @export get_metrics
get_metrics <- function(train, test = NULL, cox_rp, splines_rp, cox_rp_train = NULL, splines_rp_train = NULL) {
  # Score two sets of risk predictions (Cox and splines) against observed
  # survival with survcomp: concordance index, Brier score and
  # time-dependent ROC.
  #
  # Args:
  #   train: data with `time` and `status` columns; reference sample for the
  #     Brier score.
  #   test: data with `time` and `status` columns to evaluate on. When NULL,
  #     `train` and the supplied predictions are reused for both roles.
  #   cox_rp, splines_rp: risk predictions for the evaluation data.
  #   cox_rp_train, splines_rp_train: risk predictions for `train`.
  #
  # Returns:
  #   A named list: Cindex.cox, Cindex.splines, bs.cox, bs.splines, AUC.cox,
  #   AUC.splines.

  if (is.null(test)) {
    test <- train
    cox_rp_train <- cox_rp
    splines_rp_train <- splines_rp
  }

  # Coerce the status to a logical mask before subsetting: a 0/1 coded status
  # used directly as a subscript would select row 1 repeatedly instead of the
  # rows where an event happened. Logical input is unaffected.
  event_times <- test$time[as.logical(test$status)]

  # Down-weight observations that lie beyond the last observed event.
  ws <- rep(1, length(test$time))
  ws[test$time > max(event_times)] <- 0.01

  obs.test.time <- sort(unique(event_times))
  roc_time <- max(obs.test.time)

  cindex_of <- function(risk) {
    survcomp::concordance.index(
      x = risk, surv.time = test$time, surv.event = test$status,
      method = "conservative", na.rm = TRUE, weights = ws
    )$c.index
  }

  brier_of <- function(risk_train, risk_test) {
    survcomp::sbrier.score2proba(
      data.tr = data.frame(time = train$time, event = train$status, score = risk_train),
      data.ts = data.frame(time = test$time, event = test$status, score = risk_test),
      method = c("prodlim")
    )
  }

  auc_of <- function(risk) {
    survcomp::tdrocc(
      x = as.vector(risk), surv.time = test$time, surv.event = test$status,
      time = roc_time
    )
  }

  list(
    "Cindex.cox" = cindex_of(cox_rp),
    "Cindex.splines" = cindex_of(splines_rp),
    "bs.cox" = brier_of(cox_rp_train, cox_rp),
    "bs.splines" = brier_of(splines_rp_train, splines_rp),
    "AUC.cox" = auc_of(cox_rp),
    "AUC.splines" = auc_of(splines_rp)
  )
}

# Function to predict the survival for each individual
#'
#' Predict survival from a stan_surv object
#'
#' @export link_surv
#' @importFrom Hmisc approxExtrap
#'
#' @examples
#' pbc2 <- survival::pbc
#' pbc2 <- pbc2[!is.na(pbc2$trt),]
#' pbc2$status <- as.integer(pbc2$status > 0)
#' m1 <- stan_surv(survival::Surv(time, status) ~ trt, data = pbc2)
#'
#' df <- flexsurv::bc
#' m2 <- stan_surv(survival::Surv(rectime, censrec) ~ group,
#'                 data = df, cores = 1, chains = 1, iter = 2000,
#'                 basehaz = "fpm", iknots = c(6.594869,  7.285963 ),
#'                 degree = 2, prior_aux = normal(0, 2, autoscale = F))
#'
pred_surv2 <- function(fit, holdout = NULL, timepoints = NULL, time = "time", status = "status", ncores = 1L, method = "gp") {
  # Predict a survival curve for every row of `holdout` at `timepoints`.
  #
  # Args:
  #   fit: one of
  #     - an object whose `stanfit` element is a "stanfit",
  #     - a "gam" object,
  #     - an object whose `mfit` element is an "ahaz" model.
  #   holdout: data frame to predict for.
  #   timepoints: times to predict at; defaults to get_obs.time(holdout).
  #   time, status, ncores: accepted for interface compatibility; not used in
  #     this function body.
  #   method: how the baseline is carried from the training times to
  #     `timepoints`. Stan fits accept "linear", "constant", "gp" /
  #     "gaussian process", "tgp" / "treed gaussian process"; NULL is treated
  #     as the signature default "gp". For ahaz fits the value is passed
  #     straight to Hmisc::approxExtrap().
  #
  # Returns:
  #   A matrix with one row per individual and one column per time point.

  obs.time <- get_obs.time(fit$data)
  if (is.null(timepoints)) {
    timepoints <- get_obs.time(holdout)
  }

  # Survival probabilities must stay inside [0, 1] after extrapolation.
  clamp_surv <- function(p) {
    p[p > 1] <- 1
    p[p < 0] <- 0
    p
  }

  if (inherits(fit$stanfit, "stanfit")) {
    if (is.null(method)) {
      method <- "gp"
    }
    if (!is.character(method) || length(method) != 1L) {
      stop("`method` must be a single character string", call. = FALSE)
    }

    # Carry the baseline hazard from the training times to `timepoints`.
    interpolate_hazard <- function(haz) {
      if (method %in% c("linear", "constant")) {
        # approxExtrap extrapolates outside the observed range as well
        Hmisc::approxExtrap(x = obs.time, y = haz, xout = timepoints, method = method)$y
      } else if (method %in% c("gp", "gaussian process", "tgp", "treed gaussian process")) {
        fit_gp <- if (method %in% c("gp", "gaussian process")) tgp::bgp else tgp::btgp
        gp <- fit_gp(X = obs.time, Z = haz, verb = 1)
        gp$Xsplit <- rbind(data.frame(x1 = obs.time), data.frame(x1 = timepoints))
        predict(object = gp, XX = timepoints)$ZZ.mean
      } else {
        stop("method ", method, " not implemented", call. = FALSE)
      }
    }

    basehaz.samples <- extract_basehaz_draws(fit)
    spline.basis <- extract_splines_bases(fit)

    # Posterior mean baseline per training time: entry [n, s] of the product
    # is the inner product of basis row n with draw s. This is computed
    # before branching because the null model needs it too (it used to be
    # defined only in the non-null branch).
    basehaz.post <- unname(rowMeans(tcrossprod(spline.basis, basehaz.samples)))

    surv.approx <- clamp_surv(link_surv_base.helper(b = interpolate_hazard(basehaz.post)))

    if (!check_null_model(fit)) {
      beta_draws <- extract_beta_draws(fit)
      # NOTE(review): predictors are taken from `fit` but indexed by the rows
      # of `holdout` -- confirm get_predictors() lines up with `holdout`.
      varis.obs <- get_predictors(fit)

      ind.post <- lapply(seq_len(nrow(holdout)), function(i) {
        surv.post <- sapply(seq_len(nrow(beta_draws)), function(s) {
          lin_post <- varis.obs[i, ] %*% beta_draws[s, ]
          link_surv_lin.helper(p = lin_post, s = surv.approx)
        })
        # average over posterior draws, one value per time point
        apply(surv.post, 1, mean)
      })

      ind.post <- do.call(rbind, ind.post)
      rownames(ind.post) <- rownames(holdout)
    } else {
      # Null model: no covariates, so every individual shares the baseline
      # curve. The column count follows the data being replicated.
      ind.post <- matrix(rep(surv.approx, nrow(holdout)),
                         nrow = nrow(holdout),
                         ncol = length(surv.approx),
                         byrow = TRUE)
    }
  } else if (inherits(fit, "gam")) {
    split_holdout <- split(holdout, as.factor(holdout$patient_id))
    linked.surv <- lapply(split_holdout, function(s) {
      long_holdout <- gen_long_data_backup(x = s, timepoints = timepoints)
      long_holdout$bh <- predict(fit, newdata = long_holdout)

      # survival = exp(-cumulative sum of exp(predicted value)) per patient
      long_holdout %>%
        dplyr::group_by(patient_id) %>%
        dplyr::mutate(surv = exp(-cumsum(exp(bh)))) %>%
        dplyr::ungroup() %>%
        dplyr::select(surv) %>%
        unlist()
    })
    ind.post <- do.call(rbind, linked.surv)
  } else if (inherits(fit$mfit, "ahaz")) {
    cumhaz <- predict(fit$mfit, type = "cumhaz")
    obs.haz <- cumhaz$cumhaz
    obs.time <- cumhaz$time

    # NOTE(review): `method` goes straight to approxExtrap here, so the
    # default "gp" is presumably rejected -- confirm callers pass "linear" or
    # "constant" for ahaz fits.
    cumhaz.approx <- Hmisc::approxExtrap(x = obs.time, y = obs.haz, xout = timepoints, method = method)$y

    surv.approx <- clamp_surv(exp(-cumhaz.approx))

    varis.obs <- as.matrix(gam_matrix(x = holdout, vars = fit$formula))
    coefs <- as.vector(unlist(coef(fit$mfit)))

    ind.post <- lapply(seq_len(nrow(holdout)), function(i) {
      lin_post <- varis.obs[i, ] %*% coefs
      link_surv_lin.helper(p = lin_post, s = surv.approx)
    })

    ind.post <- do.call(rbind, ind.post)
    rownames(ind.post) <- rownames(holdout)
  } else {
    stop("Error: No method for evaluating predicted probabilities from objects in class ",
         paste(class(fit), collapse = "/"))
  }

  ind.post
}

#'  Calculate tdBrier
#'
#' These functions calculate a survival analysis metric for a
#' system compared to a hold-out test set. The measurement and the "truth"
#' have a survival time and a 0/1 indicator recording whether the observation
#' was censored (0) or the event occurred (1).
#'
#'
#' The Brier score is defined as the squared distance between the
#' expected survival probability and the observed survival.
#' Therefore, it measures the discrepancy between observation
#' and model-based prediction.
#'
#' The integrated Brier Score summarises the Brier Score over the range
#' of observed events. Similar to the original Brier score [40], the iBrier
#' ranges from 0 to 1; the model with an out-of-training sample value closer
#' to 0 outperforms the rest.
#' @aliases tdbrier tdbrier.model.list tdbrier.int.matrix
#' stdbrier.int.reference
#' @param holdout A data frame containing the survival time, the event
#' status (0: censored / 1: event), and the explanatory variables.
#' @param fit A fitted model: a coxph object, a gam object, or a fit carrying
#' a stanfit (element stanfit) or an ahaz model (element mfit).
#' @return A tdBrier object
#' @seealso [iBrier]
#' @keywords brier
#' @examples
#' require(survival)
#' require(dplyr)
#' data(lung)
#' lung <- lung %>%
#' mutate(status = (status == 2))
#'
#' mod <- coxph(Surv(time, status)~ age, data = lung)
#'
#' brier <- tdbrier(lung, mod)
#' integrate_tdbrier(brier$brier.object)
#' @export tdbrier
#' @references
#'
#'  Ulla B. Mogensen, Hemant Ishwaran, Thomas A. Gerds (2012).
#' Evaluating Random Forests for Survival Analysis Using Prediction Error
#' Curves. Journal of Statistical Software, 50(11), 1-23.
#' URL http://www.jstatsoft.org/v50/i11/.
#' @export tdbrier

tdbrier <- function(holdout, fit, time = "time", status = "status", ncores = 1L, method = NULL) {
  # Time-dependent Brier score of `fit` on `holdout`, evaluated with pec at
  # the times returned by get_obs.time(), plus integrated and R-squared
  # summaries.
  #
  # Args:
  #   holdout: evaluation data, passed through prepare_surv().
  #   fit: a "coxph" object, a "gam" object, or a fit carrying a "stanfit"
  #     (in `stanfit`) or an "ahaz" model (in `mfit`).
  #   time, status: column names handed to prepare_surv().
  #   ncores: forwarded to pred_surv2().
  #   method: forwarded to pred_surv2() when not NULL; NULL leaves
  #     pred_surv2()'s own default in place.
  #
  # Returns:
  #   A list with elements brier.object, rsquaredbs, ibrier_model, ibrier_ref,
  #   model_error, ref_error, time, irsquared and method.

  holdout <- prepare_surv(holdout, time = time, status = status)
  timepoints <- get_obs.time(holdout)

  if (inherits(fit, "coxph")) {
    probs <- pec::predictSurvProb(fit, newdata = holdout, times = timepoints)
  } else if (inherits(fit$stanfit, "stanfit") || inherits(fit$mfit, "ahaz") || inherits(fit, "gam")) {
    # Forwarding an explicit NULL would override pred_surv2()'s default and
    # break its `if (method == ...)` tests, so only pass a real value on.
    probs <- if (is.null(method)) {
      pred_surv2(fit = fit, holdout = holdout, timepoints = timepoints, ncores = ncores)
    } else {
      pred_surv2(fit = fit, holdout = holdout, timepoints = timepoints, ncores = ncores, method = method)
    }
  } else {
    stop("Error: No method for evaluating predicted probabilities from objects in class ",
         paste(class(fit), collapse = "/"))
  }

  out.brier <- suppressMessages(pec::pec(probs, Surv(time, status) ~ 1,
                                         data = holdout,
                                         maxtime = max(timepoints),
                                         exact = FALSE,
                                         exactness = length(timepoints) - 1))

  apperror <- out.brier$AppErr$matrix
  ref <- out.brier$AppErr$Reference
  ibrier <- integrate_tdbrier(out.brier)
  ibrier_ref <- integrate_tdbrier_reference(out.brier)

  list(
    brier.object = out.brier,
    # share of the reference error removed by the model, per time point
    rsquaredbs = 1 - (apperror / ref),
    ibrier_model = ibrier,
    ibrier_ref = ibrier_ref,
    model_error = apperror,
    ref_error = ref,
    time = timepoints,
    irsquared = unclass(1 - (ibrier / ibrier_ref)),
    method = method
  )
}

#' @rdname tdbrier
#' @export
integrate_tdbrier <- function(x, ...) {
  # Integrate the "matrix" model of a pec object up to its `maxtime`
  # using pec::crps().
  pec::crps(x, models = "matrix", times = x$maxtime)
}
#' @export
#' @rdname tdbrier
integrate_tdbrier_reference <- function(x, ...) {
  # Integrate the "Reference" model of a pec object up to its `maxtime`
  # using pec::crps().
  pec::crps(x, models = "Reference", times = x$maxtime)
}

integrate_rsquared <- function(m, r) {
  # One minus the ratio of `m` to `r` (vectorised over both arguments).
  error_ratio <- m / r
  1 - error_ratio
}


#' Get c-index
#'
#'
#' The efficacy of the survival model can be measured by
#'  the concordance statistic
#'
#' @param data For the default functions, a data frame containing the survival
#' time, the status (0: censored / 1: event), and the explanatory variables.
#' @param mod Coxph model object fitted with coxph (survival).
#' @return A cindex object
#' @seealso [coxph]
#' @keywords cindex
#' @examples
#' require(survival)
#' require(dplyr)
#' data(lung)
#' lung <- lung %>%
#' mutate(status = (status == 2))
#'
#' mod <- coxph(Surv(time, status)~ age, data = lung)
#'
#' get_cindex(lung, mod)
#'
#' @export get_cindex
#' @author Carlos S Traynor
#' @references
#'
#'  Terry M. Therneau and Patricia M. Grambsch (2000).
#'   _Modeling Survival Data: Extending the Cox Model_.
#'   Springer, New York. ISBN 0-387-98784-3.
#'   @export get_cindex
#'
get_cindex <- function(data, mod, ...) {
  # S3 generic dispatching on `data`.
  # NOTE(review): the next assignment to `get_cindex` in this file replaces
  # this generic with a plain function, so no dispatch happens once the whole
  # file is loaded.
  UseMethod("get_cindex")
}

#' @export
#' @rdname get_cindex
get_cindex <-
  function(holdout, fit, time = "time", status = "status", ncores = 1L, method = NULL) {
    # Concordance index of `fit` on `holdout`, evaluated with pec::cindex() at
    # the times returned by get_obs.time().
    #
    # Args:
    #   holdout: evaluation data, passed through prepare_surv().
    #   fit: a "coxph" object, a "gam" object, or a fit carrying a "stanfit"
    #     (in `stanfit`) or an "ahaz" model (in `mfit`).
    #   time, status: column names handed to prepare_surv().
    #   ncores: forwarded to pred_surv2().
    #   method: forwarded to pred_surv2() when not NULL; NULL leaves
    #     pred_surv2()'s own default in place.
    #
    # Returns:
    #   A list with elements concordance (mean of the per-time values),
    #   apperror (per-time values), reference (0.5), time and object (the raw
    #   pec::cindex() result).

    holdout <- prepare_surv(holdout, time = time, status = status)
    timepoints <- get_obs.time(holdout)

    if (inherits(fit, "coxph")) {
      probs <- pec::predictSurvProb(fit, newdata = holdout, times = timepoints)
    } else if (inherits(fit$stanfit, "stanfit") || inherits(fit$mfit, "ahaz") || inherits(fit, "gam")) {
      # Forwarding an explicit NULL would override pred_surv2()'s default and
      # break its `if (method == ...)` tests, so only pass a real value on.
      probs <- if (is.null(method)) {
        pred_surv2(fit = fit, holdout = holdout, timepoints = timepoints, ncores = ncores)
      } else {
        pred_surv2(fit = fit, holdout = holdout, timepoints = timepoints, ncores = ncores, method = method)
      }
    } else {
      stop("Error: No method for evaluating predicted probabilities from objects in class ",
           paste(class(fit), collapse = "/"))
    }

    statistic <- suppressMessages(pec::cindex(probs,
                                              Surv(time, status) ~ 1,
                                              data = holdout,
                                              pred.times = timepoints,
                                              eval.times = timepoints))

    apperror <- unlist(statistic$AppCindex)

    list(
      concordance = mean(apperror),
      apperror = apperror,
      reference = 0.5,
      time = timepoints,
      object = statistic
    )
  }

#'  Calculate survival ROC
#'
#' These functions calculate a survival analysis metric for a
#' system compared to a hold-out test set. The measurement and the "truth"
#' have a survival time and a 0/1 indicator recording whether the observation
#' was censored (0) or the event occurred (1).
#'
#'
#' The Brier score is defined as the squared distance between the
#' expected survival probability and the observed survival.
#' Therefore, it measures the discrepancy between observation
#' and model-based prediction.
#'
#' The integrated Brier Score summarises the Brier Score over the range
#' of observed events.Similar to the original Brier score [40] the iBrier:
#' ranges from 0 to 1; the model with an out-of-training sample value closer
#' to 0 outperforms the rest.
#' @aliases tdbrier tdbrier.model.list tdbrier.int.matrix
#' stdbrier.int.reference
#' @param holdout A data frame containing the survival time, the event
#' status (0: censored / 1: event), and the explanatory variables.
#' @param fit A fitted model: a coxph object, a gam object, or a fit whose
#' stanfit element is a stanfit object.
#' @return A tdBrier object
#' @seealso [iBrier]
#' @keywords brier
#' @examples
#' require(survival)
#' require(dplyr)
#' data(lung)
#' lung <- lung %>%
#' mutate(status = (status == 2))
#'
#' mod <- coxph(Surv(time, status)~ age, data = lung)
#'
#' roc <- get_survroc(lung, mod)
#' roc$iroc
#'
#' @export get_survroc
#' @author Carlos S Traynor
#' @references
#'
#'  Ulla B. Mogensen, Hemant Ishwaran, Thomas A. Gerds (2012).
#' Evaluating Random Forests for Survival Analysis Using Prediction Error
#' Curves. Journal of Statistical Software, 50(11), 1-23.
#' URL http://www.jstatsoft.org/v50/i11/.
get_survroc <- function(holdout, fit, time = "time", status = "status", ncores = 1L) {
  # Time-dependent ROC of a model's linear predictor on `holdout`, computed
  # with tdROC::tdROC() at the largest time indexed by the status column.
  #
  # Args:
  #   holdout: evaluation data, passed through prepare_surv().
  #   fit: a "coxph" object, a "gam" object, or a fit whose `stanfit` element
  #     is a "stanfit".
  #   time, status: names of the time and status columns.
  #   ncores: not used in this function body (the previous default,
  #     `ncores = ncores`, referred to itself).
  #
  # Returns:
  #   A list with elements iroc, sens, spec and grid. For a null stan model
  #   a character message is returned instead, and for an unsupported class
  #   NULL is returned invisibly.
  #   NOTE(review): both of those paths look like they should be errors --
  #   kept as-is so existing callers are not broken.

  holdout <- prepare_surv(holdout, time = time, status = status)

  # Shared by all model types: run tdROC on the non-missing scores and
  # collect the summary.
  roc_summary <- function(probs) {
    keep <- !is.na(probs)
    statistic <- tdROC::tdROC(X = probs[keep],
                              Y = holdout[[time]][keep],
                              delta = holdout[[status]][keep],
                              tau = max(holdout[[time]][holdout[[status]]]),
                              n.grid = 1000)
    list(
      iroc = integrate_tdroc(statistic),
      sens = statistic$ROC$sens,
      spec = statistic$ROC$spec,
      # read from the same ROC table as sens/spec; `statistic[[1]]$ROC$grid`
      # was left over from the per-draw prototype and yielded NULL
      grid = statistic$ROC$grid
    )
  }

  if (inherits(fit, "coxph")) {
    roc_summary(predict(fit, newdata = holdout, type = "lp"))
  } else if (inherits(fit, "gam")) {
    beta <- fit$coefficients[!grepl("Intercept|tend", names(fit$coefficients))]

    # recover covariate terms from the right-hand side of the formula
    vars <- unlist(strsplit(deparse(fit$formula[[3]]), "[+]"))
    vars <- gsub(" ", "", vars)
    vars <- vars[!grepl("1|offset|tend", vars)]
    vars <- vars[vars != ""]

    X <- gam_matrix(x = holdout, vars = vars)

    roc_summary(as.matrix(X) %*% as.vector(unlist(beta)))
  } else if (inherits(fit$stanfit, "stanfit")) {
    if (!check_null_model(fit)) {
      beta_draws <- extract_beta_draws(fit)

      vars <- unlist(strsplit(deparse(fit$formula$rhs), "[+]"))
      vars <- gsub(" ", "", vars)
      vars <- vars[vars != ""]

      x <- holdout[fit$formula$allvars[-(1:2)]]

      # `id` is only a scaffold column to cbind onto; it is dropped below
      X <- data.frame(id = seq_len(nrow(x)))

      for (i in seq_along(vars)) {
        if (is.numeric(x[[i]])) {
          if (grepl("log", vars[i])) {
            x[[i]] <- log(x[[i]])
          }
          matrixna <- x[i]
          X <- cbind(X, matrixna)
        } else {
          x[[i]] <- as.factor(x[[i]])
          coluna <- x[i]
          matrixna <- model.matrix(~ ., coluna)[, -1]
          # A two-level factor collapses to a vector, and as.data.frame()
          # then names the column after this variable ("matrixna"); the
          # check below relies on that, so do not rename the variable.
          matrixna <- as.data.frame(matrixna)
          if ("matrixna" %in% colnames(matrixna)) {
            colnames(matrixna) <- paste0(colnames(coluna), "1")
          }
          X <- cbind(X, matrixna)
        }
      }
      X <- X[-match("id", colnames(X))]

      # Scores use the posterior-mean coefficients (one ROC curve), not an
      # average of per-draw ROC curves.
      beta_draws_mean <- apply(beta_draws, 2, mean)

      roc_summary(as.matrix(X) %*% as.vector(unlist(beta_draws_mean)))
    } else {
      "Print object must be either coxph or stanreg"
    }
  }
}

#'  Calculate survival ROC
#' @export integrate_tdroc
#' @rdname integrate_tdroc
integrate_tdroc <-
  function(fit, ...) {
    # Return the first entry of `fit$AUC`, flattened with unlist().
    #
    # Args:
    #   fit: an object with an `AUC` element (get_survroc() passes the result
    #     of tdROC::tdROC()).
    #   ...: ignored.
    #
    # Written as a direct call so the function does not depend on the
    # magrittr pipe being attached.
    unlist(fit$AUC[1])
  }
csetraynor/rstanhaz documentation built on May 9, 2019, 8:14 a.m.