#' Posterior survival curves from a stan_surv object
#'
#' For every posterior draw the baseline curve is evaluated at the observed
#' event times of the training data (`fit$data`), interpolated/extrapolated to
#' `timepoints` with `Hmisc::approxExtrap()` and clipped to the interval
#' [0, 1]. For models with covariates the curve is then raised to the
#' exponential of the linear predictor of the training observations.
#'
#' @param fit A fitted `stan_surv` object. `fit$data` must contain the columns
#'   `time` and `status`.
#' @param testx Optional data frame with `time` and `status` columns. Its
#'   observed event times are used as `timepoints` when none are supplied.
#' @param timepoints Optional numeric vector of times at which survival is
#'   evaluated. When both `timepoints` and `testx` are `NULL`, the event times
#'   of the training data are used.
#' @param method Interpolation method handed to `Hmisc::approxExtrap()`;
#'   `"linear"` or `"constant"`.
#' @param time,status Currently unused; kept for backward compatibility.
#' @return A numeric matrix with one column per posterior draw.
#'
#' @export
#' @importFrom Hmisc approxExtrap
#'
#' @examples
#' pbc2 <- survival::pbc
#' pbc2 <- pbc2[!is.na(pbc2$trt),]
#' pbc2$status <- as.integer(pbc2$status > 0)
#' m1 <- stan_surv(survival::Surv(time, status) ~ trt, data = pbc2)
#'
#' df <- flexsurv::bc
#' m2 <- stan_surv(survival::Surv(rectime, censrec) ~ group,
#' data = df, cores = 1, chains = 1, iter = 2000,
#' basehaz = "fpm", iknots = c(6.594869, 7.285963 ),
#' degree = 2, prior_aux = normal(0, 2, autoscale = FALSE))
#'
link_surv <- function(fit, testx = NULL, timepoints = NULL, method = "linear",
                      time = "time", status = "status") {
  if (length(method) != 1L || !method %in% c("linear", "constant")) {
    stop2(paste0("method ", method, " not implement."))
  }
  obs.time <- get_obs.time(fit$data)
  if (is.null(timepoints)) {
    # Without test data fall back to the training event times rather than
    # interpolating onto an empty grid.
    timepoints <- if (is.null(testx)) obs.time else get_obs.time(testx)
  }
  basehaz.samples <- extract_basehaz_draws(fit)
  spline.basis <- extract_splines_bases(fit)

  null.model <- check_null_model(fit)
  if (!null.model) {
    beta_draws <- extract_beta_draws(fit)
    varis.obs <- get_predictors(fit)
  }

  # Entry [s, n] is the inner product of draw s with row n of the spline
  # basis; one matrix product replaces the nested per-row loop.
  basehaz.post <- basehaz.samples %*% t(spline.basis)

  surv.post <- lapply(seq_len(nrow(basehaz.samples)), function(s) {
    surv_base <- link_surv_base.helper(b = basehaz.post[s, ])
    surv.approx <- Hmisc::approxExtrap(
      x = obs.time, y = surv_base, xout = timepoints, method = method
    )$y
    # Extrapolation can leave the valid range of a survival probability.
    surv.approx <- pmin(pmax(surv.approx, 0), 1)
    if (null.model) {
      return(surv.approx)
    }
    # NOTE(review): the curve (one value per time point) and the linear
    # predictor (one value per training row) are combined element-wise with
    # recycling -- confirm this pairing is intended.
    linear_predictor <- varis.obs %*% beta_draws[s, ]
    link_surv_lin.helper(p = linear_predictor, s = surv.approx)
  })
  do.call(cbind, surv.post)
}
# Map the values in `b` to exp(-b).
# `c()` flattens matrix input, so the result is always a plain vector.
link_surv_base.helper <- function(b) {
  flat <- c(b)
  exp(-1 * flat)
}
# Raise the survival values `s` to the power exp(p).
# `p` is flattened with `c()` first, so a matrix-shaped linear predictor is
# treated as a plain vector.
link_surv_lin.helper <- function(p, s) {
  exponent <- exp(c(p))
  s^exponent
}
# Standardise a survival data frame.
#
# Ensures `d` has columns named `time` and `status` (copied from the columns
# named by the `time` / `status` arguments when they are missing) and returns
# the rows ordered by `time`.
#
# d      : data frame with the survival data.
# time   : name of the column holding the follow-up time.
# status : name of the column holding the event indicator.
#
# Stops with an error when a required column cannot be found; previously a
# wrong column name silently produced a data frame with zero rows.
prepare_surv <- function(d, time = "time", status = "status") {
  columns <- colnames(d)
  if (!"time" %in% columns) {
    if (!time %in% columns) {
      stop("column '", time, "' not found in data", call. = FALSE)
    }
    d$time <- d[[time]]
  }
  if (!"status" %in% columns) {
    if (!status %in% columns) {
      stop("column '", status, "' not found in data", call. = FALSE)
    }
    d$status <- d[[status]]
  }
  d[order(d$time), ]
}
# Sorted, de-duplicated values of `d$time` for the rows whose `d$status`
# coerces to TRUE.
get_obs.time <- function(d) {
  had_event <- as.logical(d$status)
  event_times <- d$time[had_event]
  unique(sort(event_times))
}
# Pull the `basehaz_coefs` element out of the draws extracted from
# `fit$stanfit` and return it as a matrix.
extract_basehaz_draws <- function(fit) {
  posterior <- rstan::extract(fit$stanfit)
  as.matrix(posterior$basehaz_coefs)
}
# Return the spline basis stored in `fit$basehaz$spline_basis`, passed
# through `unlist()`.
extract_splines_bases <- function(fit) {
  basis <- fit$basehaz$spline_basis
  unlist(basis)
}
# TRUE when `fit$formula$allvars` holds at most two distinct names, which the
# rest of this file treats as a model without predictors.
check_null_model <- function(fit) {
  n_vars <- n_distinct(fit$formula$allvars)
  n_vars <= 2
}
# Matrix built from the `beta` element of the draws extracted from
# `fit$stanfit`, or NULL when `check_null_model(fit)` is TRUE.
extract_beta_draws <- function(fit) {
  if (check_null_model(fit)) {
    return(NULL)
  }
  posterior <- rstan::extract(fit$stanfit)
  as.matrix(posterior$beta)
}
# Survival estimates from `survival::survfit()` as a data frame.
#
# obs    : data frame holding the survival data.
# strata : optional character vector of stratification variables; they are
#          joined with "+" on the right-hand side of the model formula.
# time   : name of the follow-up time column in `obs`.
# status : name of the event indicator column in `obs`.
#
# Returns a data frame with columns `time` and `surv`, plus the columns
# produced by `str_strata()` when `strata` is given. A starting row
# (time = 0, surv = 1) is appended, one per stratum in the stratified case.
get_km.frame <- function(obs, strata = NULL, time = "time", status = "status") {
  stratified <- !is.null(strata)
  rhs <- if (stratified) paste(strata, collapse = "+") else "1"
  # Always qualify with `survival::` so the function works whether or not
  # the survival package is attached.
  form <- as.formula(paste0("survival::Surv(", time, ",", status, ") ~ ", rhs))
  mle.surv <- survival::survfit(form, data = obs)
  obs.mortality <- data.frame(time = mle.surv$time, surv = mle.surv$surv)

  if (!stratified) {
    zeros <- data.frame(time = 0, surv = 1)
    return(rbind(obs.mortality, zeros))
  }

  # The summary with censored times has one stratum label per time row.
  obs.mortality$strata <- summary(mle.surv, censored = TRUE)$strata
  zeros <- data.frame(time = 0, surv = 1, strata = unique(obs.mortality$strata))
  obs.mortality <- rbind(obs.mortality, zeros)
  strata <- str_strata(obs.mortality)
  obs.mortality$strata <- NULL
  cbind(obs.mortality, strata)
}
#' Individual survival predictions from a stan_surv object
#'
#' Builds one baseline curve per posterior draw (evaluated at the event times
#' of the training data, interpolated/extrapolated to `timepoints` and clipped
#' to [0, 1]) and summarises it per individual. With covariates the curve of
#' each draw is raised to the exponential of the individual's linear predictor
#' and the median over draws is returned; for a model without covariates the
#' mean over draws is returned for everybody.
#'
#' @param fit A fitted `stan_surv` object. `fit$data` must contain the columns
#'   `time` and `status`.
#' @param testx Data frame of test observations; required. It must contain
#'   `time` and `status` columns when `timepoints` is not supplied.
#' @param timepoints Optional numeric vector of times at which survival is
#'   evaluated; defaults to the observed event times of `testx`.
#' @param method Interpolation method handed to `Hmisc::approxExtrap()`;
#'   `"linear"` or `"constant"`.
#' @param time,status Currently unused; kept for backward compatibility.
#' @param ncores Number of cores passed to `parallel::mclapply()`.
#' @return A numeric matrix with one row per row of `testx` and one column per
#'   element of `timepoints`.
#'
#' @export
#' @importFrom Hmisc approxExtrap
#'
#' @examples
#' pbc2 <- survival::pbc
#' pbc2 <- pbc2[!is.na(pbc2$trt),]
#' pbc2$status <- as.integer(pbc2$status > 0)
#' m1 <- stan_surv(survival::Surv(time, status) ~ trt, data = pbc2)
#'
#' df <- flexsurv::bc
#' m2 <- stan_surv(survival::Surv(rectime, censrec) ~ group,
#' data = df, cores = 1, chains = 1, iter = 2000,
#' basehaz = "fpm", iknots = c(6.594869, 7.285963 ),
#' degree = 2, prior_aux = normal(0, 2, autoscale = FALSE))
#'
pred_surv <- function(fit, testx = NULL, timepoints = NULL, method = "linear",
                      time = "time", status = "status", ncores = 1L) {
  # Only these two methods are supported; anything else is rejected here.
  if (length(method) != 1L || !method %in% c("linear", "constant")) {
    stop2(paste0("method ", method, " not implement."))
  }
  if (is.null(testx)) {
    stop2("testx must be provided.")
  }
  obs.time <- get_obs.time(fit$data)
  if (is.null(timepoints)) {
    timepoints <- get_obs.time(testx)
  }
  basehaz.samples <- extract_basehaz_draws(fit)
  spline.basis <- extract_splines_bases(fit)
  n.test <- nrow(testx)

  # The baseline curve of a draw does not depend on the individual, so it is
  # computed once per draw: a (time points x draws) matrix.
  basehaz.post <- basehaz.samples %*% t(spline.basis)
  base.surv <- lapply(seq_len(nrow(basehaz.samples)), function(s) {
    surv.base <- link_surv_base.helper(b = basehaz.post[s, ])
    surv.approx <- Hmisc::approxExtrap(
      x = obs.time, y = surv.base, xout = timepoints, method = method
    )$y
    # Extrapolation can leave the valid range of a survival probability.
    pmin(pmax(surv.approx, 0), 1)
  })
  base.surv <- do.call(cbind, base.surv)

  if (!check_null_model(fit)) {
    beta_draws <- extract_beta_draws(fit)
    # NOTE(review): the predictor matrix is built from the training data in
    # `fit`, yet it is indexed by the rows of `testx` -- confirm that `testx`
    # is expected to be row-aligned with `fit$data`.
    varis.obs <- get_predictors(fit)
    ind.post <- parallel::mclapply(seq_len(n.test), function(i) {
      # One linear predictor per draw for individual i.
      linear_predictor <- drop(beta_draws %*% varis.obs[i, ])
      # Column s of the baseline matrix is raised to exp(linear predictor s).
      surv.post <- sweep(base.surv, 2L, exp(linear_predictor), "^")
      apply(surv.post, 1L, median)
    }, mc.cores = ncores)
    ind.post <- do.call(rbind, ind.post)
    rownames(ind.post) <- rownames(testx)
  } else {
    # No covariates: every individual shares the same curve.
    # NOTE(review): this branch averages with the mean while the covariate
    # branch uses the median -- confirm the difference is deliberate.
    surv.post <- apply(base.surv, 1L, mean)
    ind.post <- matrix(rep(surv.post, n.test),
      nrow = n.test,
      # One column per requested time point, duplicates included.
      ncol = length(surv.post),
      byrow = TRUE
    )
  }
  ind.post
}
# Build the predictor matrix of a fitted model.
#
# Takes the predictor columns of `fit`, replaces the ones flagged by
# `get_factors()` with their factor versions and encodes the result:
#   * more than 1000 distinct column names -> `preproc_big_mat()`
#   * a single column with fewer than 3 levels -> `preproc_one_mat()`
#   * otherwise -> `model.matrix()` without the intercept column
# Returns a matrix, or NULL when `check_null_model(fit)` is TRUE.
get_predictors <- function(fit) {
  if (check_null_model(fit)) {
    return(NULL)
  }
  obs.p <- extract_predictors(fit)
  factor_vars <- get_factors(fit = fit)
  # `get_factors()` returns NULL when no term is wrapped in factor().
  if (!is.null(factor_vars)) {
    obs.p[, match(colnames(factor_vars), colnames(obs.p))] <- factor_vars
  }
  if (n_distinct(colnames(obs.p)) > 1000) {
    obs.p <- preproc_big_mat(obs.p)
  } else if (ncol(obs.p) == 1 && nlevels(obs.p[[1]]) < 3) {
    obs.p <- preproc_one_mat(obs.p)
  } else {
    # drop = FALSE keeps the matrix shape for one-row or one-column results.
    obs.p <- model.matrix(~ ., obs.p)[, -1, drop = FALSE]
  }
  as.matrix(obs.p)
}
# Build a predictor matrix from a data frame, with a separate path for
# high-dimensional data.
#
# da    : data frame with the candidate predictors.
# varis : optional character vector of column names to use.
# form  : optional formula; the terms on its right-hand side, split on "+",
#         are used as column names when `varis` is NULL.
#
# When neither `varis` nor `form` is given, all columns are used and a
# warning is raised. Columns whose term contains "factor(" are converted to
# factors through `get_factors()`. Encoding rules:
#   * more than 1000 distinct column names -> `preproc_big_mat()`
#   * a single column with fewer than 3 levels -> `preproc_one_mat()`
#   * otherwise -> `model.matrix()` without the intercept column
# Returns a matrix.
model_matrix <- function(da, varis = NULL, form = NULL) {
  if (!is.null(varis)) {
    obs.p <- da[varis]
  } else if (!is.null(form)) {
    # The right-hand side is the last element for one- and two-sided formulas.
    rhs <- form[[length(form)]]
    # deparse() may wrap long expressions over several strings.
    rhs_text <- paste(deparse(rhs), collapse = "")
    varis <- unlist(strsplit(rhs_text, "+", fixed = TRUE))
    # Strip the blanks deparse() puts around "+", otherwise the names do not
    # match any column of `da`.
    varis <- gsub(" ", "", varis, fixed = TRUE)
    obs.p <- da[varis]
  } else {
    varis <- colnames(da)
    obs.p <- da
    warning("Model matrix build with all variables from dataset")
  }
  factor_vars <- get_factors(p = obs.p, v = varis)
  # `get_factors()` returns NULL when no term is wrapped in factor().
  if (!is.null(factor_vars)) {
    obs.p[, match(colnames(factor_vars), colnames(obs.p))] <- factor_vars
  }
  if (n_distinct(colnames(obs.p)) > 1000) {
    obs.p <- preproc_big_mat(obs.p)
  } else if (ncol(obs.p) == 1 && nlevels(obs.p[[1]]) < 3) {
    obs.p <- preproc_one_mat(obs.p)
  } else {
    # drop = FALSE keeps the matrix shape for one-row or one-column results.
    obs.p <- model.matrix(~ ., obs.p)[, -1, drop = FALSE]
  }
  as.matrix(obs.p)
}
# Convert the predictors that were wrapped in factor() to factor columns.
#
# p   : data frame of predictors.
# v   : character vector of model terms, positionally aligned with the
#       columns of `p`; terms containing "factor(" mark the columns to convert.
# fit : optional fitted model; when given, `p` and `v` are taken from it via
#       `extract_predictors()` and `extract_vars()`.
#
# Returns a data frame holding only the converted columns (original column
# names kept), or NULL when no term contains "factor(".
get_factors <- function(p, v, fit = NULL) {
  if (!is.null(fit)) {
    p <- extract_predictors(fit)
    v <- extract_vars(fit)
  }
  # from formula
  fact.1 <- p[grep("factor\\(", v)]
  names_fact.1 <- colnames(fact.1)
  if (length(names_fact.1) == 0L) {
    return(NULL)
  }
  # `[[i]]` extracts a plain vector for data frames and tibbles alike.
  fact.1 <- lapply(seq_along(names_fact.1), function(i) as.factor(fact.1[[i]]))
  fact.1 <- do.call(cbind.data.frame, fact.1)
  colnames(fact.1) <- names_fact.1
  fact.1
}
# Columns of `fit$data` named by `fit$formula$allvars` once its first two
# entries are dropped; NULL when `check_null_model(fit)` is TRUE.
extract_predictors <- function(fit) {
  if (check_null_model(fit)) {
    return(NULL)
  }
  predictor_names <- fit$formula$allvars[-(1:2)]
  fit$data[predictor_names]
}
# Split the deparsed right-hand side of the model formula (`fit$formula$rhs`)
# on "+" and return the pieces with every space removed.
extract_vars <- function(fit) {
  rhs_text <- deparse(fit$formula$rhs)
  pieces <- unlist(strsplit(rhs_text, "[+]"))
  gsub(" ", "", pieces)
}
# Encode a single predictor with `model.matrix()`, drop the intercept column
# and return the result as a data frame that carries the original column
# name(s).
preproc_one_mat <- function(o) {
  frame <- as.data.frame(o)
  original_names <- colnames(frame)
  design <- model.matrix(~ ., frame)
  encoded <- as.data.frame(design[, -1])
  colnames(encoded) <- original_names
  encoded
}
# Encode a wide predictor table column by column.
#
# big_mat    : matrix or data frame of predictors.
# patient_id : currently unused; kept for backward compatibility.
#
# Every column is converted to a factor with NA as an explicit level and
# encoded through `preproc_one_mat()`; the encoded columns are bound into one
# data frame. Columns that consist only of NA are then removed.
# NOTE(review): `preproc_one_mat()` assigns a single column name to its
# output -- confirm it copes with factors of more than two levels (which
# `addNA()` makes likely here).
preproc_big_mat <- function(big_mat, patient_id = NULL) {
  big_mat <- as.data.frame(big_mat)
  # Convert to factor and add NA as a level.
  big_mat <- lapply(big_mat, function(column) addNA(as.factor(column)))
  X <- lapply(big_mat, preproc_one_mat)
  big_mat <- do.call(cbind.data.frame, X)
  # The filtered result must be returned; it used to be computed and
  # thrown away.
  Filter(function(x) !all(is.na(x)), big_mat)
}
#' String method for strata in survival analysis
#'
#' Splits the comma-separated `name=value` labels found in the `strata`
#' column into one column per stratification variable, to facilitate
#' plotting. Column names are taken from the first row; every space is
#' removed from names and values.
#'
#' @param c A data frame with a `strata` column holding labels such as
#'   `"sex=1, trt=2"`.
#' @return A tibble of character columns, one per stratification variable,
#'   holding the values only.
#' @export
#' @importFrom magrittr %>%
str_strata <- function(c) {
  pieces <- strsplit(as.character(c$strata), "\\,")
  # One row per observation, one column per "name=value" piece.
  pieces <- do.call(rbind, pieces)
  variable_names <- gsub("=.*| ", "", pieces[1, ])
  # `[]` keeps the matrix shape while the cells are replaced by their values.
  pieces[] <- gsub(".*=| ", "", pieces)
  colnames(pieces) <- variable_names
  # The matrix is cleaned before conversion, so the deprecated
  # as_data_frame() / funs() helpers are no longer needed.
  dplyr::as_tibble(pieces)
}
# Replace placeholder strings for missing data with NA.
#
# x : any vector. Non-character input is returned unchanged.
#
# For character input the values "", "[Not Available]", "--", "not reported"
# and "[Not Applicable]" become NA. The result is always a character vector,
# including when every element is a placeholder.
convert_blank_to_na <- function(x) {
  if (!is.character(x)) {
    return(x)
  }
  placeholders <- c("", "[Not Available]", "--", "not reported", "[Not Applicable]")
  x[x %in% placeholders] <- NA_character_
  x
}
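# NOTE: the two lines below are residue from the web page this file was
# copied from; they are not part of the package source.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.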