# R/helpers-surv.R

# Function to predict the survival
#'
#' Predict survival from a stan_surv object
#'
#' Computes posterior draws of survival from the spline baseline hazard of a
#' fitted \code{stan_surv} model. Survival is evaluated at the observed event
#' times of \code{fit$data} as \code{exp(-H)}, then interpolated (and
#' extrapolated) onto \code{timepoints} with \code{Hmisc::approxExtrap} and
#' clamped to [0, 1].
#'
#' @param fit A fitted \code{stan_surv} object whose \code{data} element has
#'   \code{time} and \code{status} columns.
#' @param testx Optional data with \code{time} and \code{status} columns; its
#'   event times are used as \code{timepoints} when those are not supplied.
#' @param timepoints Numeric vector of times at which to evaluate survival.
#' @param method Interpolation method, \code{"linear"} or \code{"constant"}.
#' @param time,status Currently unused.
#' @return A numeric matrix with one column per posterior draw.
#' @export link_surv
#' @importFrom Hmisc approxExtrap
#'
#' @examples
#' pbc2 <- survival::pbc
#' pbc2 <- pbc2[!is.na(pbc2$trt),]
#' pbc2$status <- as.integer(pbc2$status > 0)
#' m1 <- stan_surv(survival::Surv(time, status) ~ trt, data = pbc2)
#'
#' df <- flexsurv::bc
#' m2 <- stan_surv(survival::Surv(rectime, censrec) ~ group,
#'                 data = df, cores = 1, chains = 1, iter = 2000,
#'                 basehaz = "fpm", iknots = c(6.594869,  7.285963 ),
#'                 degree = 2, prior_aux = normal(0, 2, autoscale = FALSE))
#'
link_surv <- function(fit, testx = NULL, timepoints = NULL, method = "linear",
                      time = "time", status = "status") {
  obs.time <- get_obs.time(fit$data)
  if (method != "linear" && method != "constant") {
    stop2(paste0("method ", method, " not implement."))
  }
  if (is.null(timepoints)) {
    timepoints <- get_obs.time(testx)
  }
  basehaz.samples <- extract_basehaz_draws(fit)
  spline.basis <- extract_splines_bases(fit)

  # Baseline value at every basis row (rows) for every posterior draw
  # (columns). A single matrix product replaces the per-element dot products.
  cumhaz <- spline.basis %*% t(basehaz.samples)

  null.model <- check_null_model(fit)
  if (!null.model) {
    beta_draws <- extract_beta_draws(fit)
    varis.obs <- get_predictors(fit)
  }

  surv.post <- lapply(seq_len(nrow(basehaz.samples)), function(s) {
    surv_base <- link_surv_base.helper(b = cumhaz[, s])
    surv.approx <- Hmisc::approxExtrap(x = obs.time, y = surv_base,
                                       xout = timepoints, method = method)$y
    # Extrapolation can leave [0, 1]; clamp back to a valid probability
    surv.approx[surv.approx > 1] <- 1
    surv.approx[surv.approx < 0] <- 0
    if (null.model) {
      return(surv.approx)
    }
    # NOTE(review): varis.obs holds the predictors of fit$data, so the linear
    # predictor has one entry per training row while surv.approx has one per
    # time point; link_surv_lin.helper() combines them element-wise with
    # recycling. Confirm this is the intended summary.
    linear_predictor <- varis.obs %*% beta_draws[s, ]
    link_surv_lin.helper(p = linear_predictor, s = surv.approx)
  })
  do.call(cbind, surv.post)
}

link_surv_base.helper <- function(b) {
  # Translate a (cumulative) baseline hazard into baseline survival,
  # S = exp(-H); c() flattens any matrix input to a plain vector.
  neg_hazard <- -c(b)
  exp(neg_hazard)
}

link_surv_lin.helper <- function(p, s) {
  # Proportional-hazards adjustment: raise survival s to exp(linear predictor).
  # c() flattens a 1x1 (or larger) matrix product into a vector first.
  hazard_ratio <- exp(c(p))
  s^hazard_ratio
}

prepare_surv <- function(d, time = "time", status = "status") {
  # Add canonical `time` / `status` columns copied from the user-named
  # columns, unless columns with exactly those names already exist.
  existing <- colnames(d)
  if (!is_in("time", existing)) {
    d$time <- d[[time]]
  }
  if (!is_in("status", existing)) {
    d$status <- d[[status]]
  }
  # Return rows sorted by increasing time
  row_order <- order(d$time)
  d[row_order, ]
}

get_obs.time <- function(d) {
  # Distinct event times (rows whose status is truthy), in increasing order;
  # sort() drops any NA times.
  is_event <- as.logical(d$status)
  sort(unique(d$time[is_event]))
}

extract_basehaz_draws <- function(fit) {
  # Posterior draws of the `basehaz_coefs` parameter, coerced to a matrix
  posterior <- rstan::extract(fit$stanfit)
  draws <- posterior$basehaz_coefs
  as.matrix(draws)
}

extract_splines_bases <- function(fit) {
  # Spline basis stored on the fitted baseline-hazard object; unlist() leaves
  # non-list input unchanged and flattens list input.
  basis <- fit$basehaz$spline_basis
  unlist(basis)
}

check_null_model <- function(fit) {
  # TRUE when the formula references at most two distinct variables
  # (presumably just the time and status of the Surv() response).
  n_vars <- n_distinct(fit$formula$allvars)
  n_vars <= 2
}

extract_beta_draws <- function(fit) {
  # Null models have no regression coefficients
  if (check_null_model(fit)) {
    return(NULL)
  }
  # Posterior draws of `beta`, coerced to a matrix
  posterior <- rstan::extract(fit$stanfit)
  as.matrix(posterior$beta)
}



# Kaplan-Meier estimates as a data frame (e.g. for plotting).
#
# obs:    data containing the survival columns.
# strata: optional character vector of stratification variable names.
# time, status: names of the time and event-indicator columns in `obs`.
# Returns a data frame with `time` and `surv` columns (plus one column per
# stratification variable when `strata` is given), with a (time = 0,
# surv = 1) row appended for each stratum.
get_km.frame <- function(obs, strata = NULL, time = "time", status = "status") {

  if (!is.null(strata)) {
    # Qualify Surv()/survfit() so this branch does not depend on survival
    # being attached, matching the unstratified branch below.
    form <- as.formula(paste0("survival::Surv(", time, ",", status, ") ~ ",
                              paste(strata, collapse = "+")))
    mle.surv <- survival::survfit(form, data = obs)
    obs.mortality <- data.frame(
      time = mle.surv$time,
      surv = mle.surv$surv,
      # censored = TRUE so the strata labels presumably line up one-to-one
      # with mle.surv$time (which includes censoring times)
      strata = summary(mle.surv, censored = TRUE)$strata
    )
    zeros <- data.frame(time = 0, surv = 1, strata = unique(obs.mortality$strata))
    obs.mortality <- rbind(obs.mortality, zeros)

    # Expand "var1=a, var2=b" labels into one column per variable
    strata <- str_strata(obs.mortality)
    obs.mortality$strata <- NULL
    obs.mortality <- cbind(obs.mortality, strata)
  } else {
    form <- as.formula(paste0("survival::Surv(", time, ",", status, ") ~ 1"))
    mle.surv <- survival::survfit(form, data = obs)
    obs.mortality <- data.frame(time = mle.surv$time,
                                surv = mle.surv$surv)
    zeros <- data.frame(time = 0, surv = 1)
    obs.mortality <- rbind(obs.mortality, zeros)
  }

  obs.mortality
}


# Function to predict the survival for each individual
#'
#' Predict individual survival from a stan_surv object
#'
#' For each row of \code{testx}, summarises posterior draws of survival at
#' \code{timepoints}. Models with covariates use the median over draws;
#' intercept-only models use the mean over draws, and every row receives
#' the same curve.
#'
#' @param fit A fitted \code{stan_surv} object whose \code{data} element has
#'   \code{time} and \code{status} columns.
#' @param testx Data of the individuals to predict for; when
#'   \code{timepoints} is \code{NULL}, its \code{time}/\code{status} columns
#'   supply the evaluation times.
#' @param timepoints Numeric vector of times at which to evaluate survival.
#' @param method Interpolation method, \code{"linear"} or \code{"constant"}.
#' @param time,status Currently unused.
#' @param ncores Number of cores passed to \code{parallel::mclapply}.
#' @return A numeric matrix with one row per row of \code{testx} and one
#'   column per time point.
#' @export pred_surv
#' @importFrom Hmisc approxExtrap
#'
#' @examples
#' pbc2 <- survival::pbc
#' pbc2 <- pbc2[!is.na(pbc2$trt),]
#' pbc2$status <- as.integer(pbc2$status > 0)
#' m1 <- stan_surv(survival::Surv(time, status) ~ trt, data = pbc2)
#'
#' df <- flexsurv::bc
#' m2 <- stan_surv(survival::Surv(rectime, censrec) ~ group,
#'                 data = df, cores = 1, chains = 1, iter = 2000,
#'                 basehaz = "fpm", iknots = c(6.594869,  7.285963 ),
#'                 degree = 2, prior_aux = normal(0, 2, autoscale = FALSE))
#'
pred_surv <- function(fit, testx = NULL, timepoints = NULL, method = "linear",
                      time = "time", status = "status", ncores = 1L) {
  obs.time <- get_obs.time(fit$data)
  # Only linear/constant interpolation is supported (validated up front)
  if (method != "linear" && method != "constant") {
    stop2(paste0("method ", method, " not implement."))
  }
  if (is.null(timepoints)) {
    timepoints <- get_obs.time(testx)
  }
  basehaz.samples <- extract_basehaz_draws(fit)
  spline.basis <- extract_splines_bases(fit)
  n.draws <- nrow(basehaz.samples)

  # Baseline value at every basis row (rows) for every posterior draw
  # (columns); one matrix product replaces the per-element dot products.
  cumhaz <- spline.basis %*% t(basehaz.samples)

  # Baseline survival on `timepoints`, one vector per draw. It does not depend
  # on the individual, so it is computed once instead of once per test row.
  surv.base.post <- lapply(seq_len(n.draws), function(s) {
    surv.base <- link_surv_base.helper(b = cumhaz[, s])
    surv.approx <- Hmisc::approxExtrap(x = obs.time, y = surv.base,
                                       xout = timepoints, method = method)$y
    # Extrapolation can leave [0, 1]; clamp back to a valid probability
    surv.approx[surv.approx > 1] <- 1
    surv.approx[surv.approx < 0] <- 0
    surv.approx
  })

  if (!check_null_model(fit)) {
    beta_draws <- extract_beta_draws(fit)
    varis.obs <- get_predictors(fit)

    ind.post <- parallel::mclapply(seq_len(nrow(testx)), function(i) {
      # NOTE(review): varis.obs[i, ] is row i of the predictors of fit$data,
      # not of testx; results only describe testx when its rows match the
      # training data. Confirm this is intended.
      surv.post <- lapply(seq_len(n.draws), function(s) {
        linear_predictor <- varis.obs[i, ] %*% beta_draws[s, ]
        link_surv_lin.helper(p = linear_predictor, s = surv.base.post[[s]])
      })
      surv.post <- do.call(cbind, surv.post)
      apply(surv.post, 1, median)
    }, mc.cores = ncores)
    ind.post <- do.call(rbind, ind.post)
    rownames(ind.post) <- rownames(testx)
  } else {
    # NOTE(review): the null model summarises draws with the mean while the
    # covariate model above uses the median; kept as-is, confirm intent.
    surv.post <- apply(do.call(cbind, surv.base.post), 1, mean)
    # No difference in survival expected in the null model: repeat the curve
    # for every individual. ncol follows the curve length, which stays correct
    # even when `timepoints` contains duplicates.
    ind.post <- matrix(rep(surv.post, nrow(testx)),
                       nrow = nrow(testx),
                       ncol = length(surv.post),
                       byrow = TRUE)
  }
  ind.post
}



# Get the encoded predictor (design) matrix of a fitted model's training data

get_predictors <- function(fit) {
  # Intercept-only models have no predictors
  if (check_null_model(fit)) {
    return(NULL)
  }
  predictors <- extract_predictors(fit)
  # Replace columns wrapped in factor() in the formula by factor versions
  factor_cols <- get_factors(fit = fit)
  predictors[, match(colnames(factor_cols), colnames(predictors))] <- factor_cols

  # Pick an encoding strategy by width / shape of the predictor set
  design <- if (n_distinct(colnames(predictors)) > 1000) {
    preproc_big_mat(predictors)
  } else if (ncol(predictors) == 1 && nlevels(predictors[[1]]) < 3) {
    preproc_one_mat(predictors)
  } else {
    model.matrix(~ ., predictors)[, -1]
  }
  as.matrix(design)
}



# Better approach to model matrix with high dimensional datasets
#
# da:    data frame of candidate predictors.
# varis: optional character vector of column names to use.
# form:  optional formula (one- or two-sided) whose right-hand side lists the
#        columns, e.g. `y ~ a + b` or `~ a + b`; used when `varis` is NULL.
# With neither, every column of `da` is used (with a warning).
# Returns a numeric matrix.
model_matrix <- function(da, varis = NULL, form = NULL) {
  if (!is.null(varis)) {
    obs.p <- da[varis]
  } else if (!is.null(form)) {
    # The right-hand side is the last element for both one- and two-sided
    # formulas. deparse() may wrap long formulas over several lines, so rejoin
    # them, and trim the spaces around each term so it matches a column name.
    rhs <- paste(deparse(form[[length(form)]]), collapse = " ")
    varis <- trimws(unlist(strsplit(rhs, "\\+")))
    obs.p <- da[varis]
  } else {
    varis <- colnames(da)
    obs.p <- da
    warning("Model matrix build with all variables from dataset")
  }
  # Replace columns wrapped in factor() by factor versions (if any)
  factor_vars <- get_factors(p = obs.p, v = varis)
  if (!is.null(factor_vars)) {
    obs.p[, match(colnames(factor_vars), colnames(obs.p))] <- factor_vars
  }
  # Pick an encoding strategy by width / shape of the predictor set
  if (n_distinct(colnames(obs.p)) > 1000) {
    obs.p <- preproc_big_mat(obs.p)
  } else if (ncol(obs.p) == 1 && nlevels(obs.p[[1]]) < 3) {
    obs.p <- preproc_one_mat(obs.p)
  } else {
    obs.p <- model.matrix(~ ., obs.p)[, -1]
  }
  as.matrix(obs.p)
}

# Deal with factors in the posterior fit
#
# p:   predictor data (columns in the same order as the terms in `v`).
# v:   term labels; a term containing "factor(" marks the column at the same
#      position in `p` as categorical.
# fit: optional fitted model; when given, `p` and `v` are taken from it.
# Returns a data frame of those columns converted to factors, or NULL when
# there are none.
get_factors <- function(p, v, fit = NULL) {
  if (!is.null(fit)) {
    p <- extract_predictors(fit)
    v <- extract_vars(fit)
  }
  # Terms written as factor(...) in the formula (matched by position)
  fact.1 <- p[grep("factor\\(", v)]
  names_fact.1 <- colnames(fact.1)
  if (length(names_fact.1) == 0) {
    return(NULL)
  }
  # [[ ]] yields a plain vector for both data frames and tibbles
  fact.list <- lapply(seq_along(names_fact.1), function(i) as.factor(fact.1[[i]]))
  fact.1 <- do.call(cbind.data.frame, fact.list)
  colnames(fact.1) <- names_fact.1
  fact.1
}


extract_predictors <- function(fit) {
  # Intercept-only models have no predictor columns
  if (check_null_model(fit)) {
    return(NULL)
  }
  # allvars lists the two response variables first; the rest are predictors
  predictor_names <- fit$formula$allvars[-(1:2)]
  fit$data[predictor_names]
}


# Term labels on the right-hand side of a fitted model's formula, with all
# whitespace removed, e.g. `a + factor(b)` -> c("a", "factor(b)").
extract_vars <- function(fit) {
  # deparse() can wrap a long right-hand side over several lines; rejoin the
  # pieces first so a line break never produces an empty term.
  rhs <- paste(deparse(fit$formula$rhs), collapse = "")
  vars <- unlist(strsplit(gsub("\\s", "", rhs), "[+]"))
  vars
}



preproc_one_mat <- function(o) {
  # Dummy-encode one predictor (intercept column dropped) and give the
  # result the original column name(s).
  frame <- as.data.frame(o)
  original_names <- colnames(frame)
  design <- as.data.frame(model.matrix(~ ., frame)[, -1])
  colnames(design) <- original_names
  design
}


# Encode a wide predictor set column by column.
#
# big_mat:    data frame (or coercible) of predictors.
# patient_id: currently unused.
# Every column is converted to a factor with NA added as an explicit level
# (addNA), dummy-encoded on its own with preproc_one_mat(), and the pieces are
# bound side by side. Returns a data frame.
preproc_big_mat <- function(big_mat, patient_id = NULL) {
  big_mat <- as.data.frame(big_mat)
  # Convert to factor and add NA as a factor level
  big_mat <- lapply(big_mat, function(col) addNA(as.factor(col)))

  encoded <- lapply(big_mat, preproc_one_mat)
  big_mat <- do.call(cbind.data.frame, encoded)

  # Drop columns that are entirely NA. The filtered result must be kept;
  # Filter() does not modify its argument in place.
  Filter(function(x) !all(is.na(x)), big_mat)
}

#' String method for strata in surv analysis
#'
#' Splits survival strata labels into one column per stratification variable,
#' to facilitate plotting.
#'
#' @param c A data frame with a \code{strata} column whose labels have the
#'   form \code{"var1=a, var2=b"} (e.g. from \code{survival::survfit}). \cr
#' @return A tibble with one character column per stratification variable,
#'   named after the variables in the first label; spaces are removed.
#' @export
#' @importFrom magrittr %>%
str_strata <- function(c) {
  # One row per label, one column per "name=value" piece
  pieces <- do.call(rbind, strsplit(as.character(c$strata), "\\,"))
  # Variable names come from the first label; values from every label.
  # gsub() keeps the matrix dimensions of `pieces`.
  var_names <- gsub("=.*| ", "", pieces[1, ])
  values <- gsub(".*=| ", "", pieces)
  out <- as.data.frame(values, stringsAsFactors = FALSE)
  colnames(out) <- var_names
  dplyr::as_tibble(out)
}



# Replace blank / "not available" style markers in a character vector by NA.
# Non-character input is returned unchanged. The result always keeps the type
# and attributes of `x` (a character vector stays character even when every
# element is replaced).
convert_blank_to_na <- function(x) {
  if (!is.character(x)) {
    return(x)
  }
  missing_markers <- c("", "[Not Available]", "--", "not reported", "[Not Applicable]")
  x[x %in% missing_markers] <- NA
  x
}
# csetraynor/rstanhaz documentation built on May 9, 2019, 8:14 a.m.