#' Build Stan data for a two-predictor proportional hazards model
#'
#' Constructs the data list for a Bayesian proportional hazards model with a
#' parametric ("exponential", "weibull") or Royston-Parmar flexible parametric
#' ("fpm", "fpm2") baseline hazard and two design matrices, each with its own
#' prior. Follow-up times and the event indicator are taken from `formula1`;
#' `formula2` contributes only its design matrix (`x_2`).
#'
#' With the default `algorithm = "sampling"`, or whenever `out_data = TRUE`,
#' no model is fitted and the constructed data are returned. Only the
#' variational algorithms ("meanfield", "fullrank") currently fit the model,
#' via [rstan::vb()].
#'
#' @param formula1 Model formula with a `survival::Surv()` response. It
#'   supplies the time/event data and the first design matrix (`x_1`).
#' @param formula2 Model formula, parsed the same way as `formula1`. Only its
#'   design matrix (`x_2`) is used.
#' @param data Data frame containing the variables of both formulas.
#' @param basehaz Baseline hazard: one of "exponential", "weibull", "fpm",
#'   "fpm2".
#' @param timescale Time scale for the baseline hazard basis (default "log").
#' @param df,degree,iknots,bknots Baseline-hazard spline settings, passed to
#'   `handle_basehaz()`.
#' @param prior_1 Prior for the coefficients of `formula1`.
#' @param prior_2 Prior for the coefficients of `formula2`.
#' @param prior_intercept Prior for the intercept.
#' @param add_prior Currently unused.
#' @param prior_aux Prior for the baseline hazard (auxiliary) parameters.
#' @param prior_PD Logical. Stored as the integer flag `prior_PD` in the Stan
#'   data (presumably "draw from the prior only"; confirm against the Stan
#'   model).
#' @param algorithm One of "sampling", "meanfield", "fullrank".
#' @param adapt_delta,max_treedepth,init,cores Currently unused (reserved for
#'   the disabled `rstan::sampling()` path).
#' @param out_data If `TRUE`, return the constructed data without fitting,
#'   whatever the `algorithm`.
#' @param ... Further arguments passed to [rstan::vb()].
#'
#' @return If `algorithm = "sampling"` or `out_data = TRUE`: a list with
#'   elements `basehaz` (the processed baseline hazard) and `standata` (the
#'   Stan data list). Otherwise: a list with elements `stanfit`, `formula`
#'   (list of the parsed `formula1` and `formula2`), `data`, `basehaz`,
#'   `algorithm`, `stan_function` and `call`.
#'
#' @export
#'
#' @examples
#' pbc2 <- survival::pbc
#' pbc2 <- pbc2[!is.na(pbc2$trt), ]
#' pbc2$status <- as.integer(pbc2$status > 0)
#' d1 <- generate_stan_data_2(survival::Surv(time, status) ~ trt,
#'                            survival::Surv(time, status) ~ age + bili,
#'                            data = pbc2)
#' str(d1$standata)
#'
#' df <- flexsurv::bc
#' d2 <- generate_stan_data_2(survival::Surv(rectime, censrec) ~ group,
#'                            survival::Surv(rectime, censrec) ~ group,
#'                            data = df, basehaz = "fpm",
#'                            iknots = c(6.594869, 7.285963), degree = 2,
#'                            prior_aux = normal(0, 2, autoscale = FALSE))
#'
generate_stan_data_2 <- function(formula1, formula2, data, basehaz = "fpm",
                                 timescale = "log", df = 5L, degree = 3L,
                                 iknots = NULL, bknots = NULL,
                                 prior_1 = normal(),
                                 prior_2 = hs(),
                                 prior_intercept = normal(), add_prior = NULL,
                                 prior_aux = list(), prior_PD = FALSE,
                                 algorithm = c("sampling", "meanfield",
                                               "fullrank"),
                                 adapt_delta = 0.95, max_treedepth = 11L,
                                 init = "random", cores = 1L,
                                 out_data = FALSE, ...) {
  #-----------------------------
  # Pre-processing of arguments
  #-----------------------------
  dots <- list(...)
  algorithm <- match.arg(algorithm)
  formula1 <- parse_formula(formula1, data)
  formula2 <- parse_formula(formula2, data)
  data <- validate_arg(data, "data")

  #----------------
  # Construct data
  #----------------
  standata <- list()
  mf <- data # model data frame: the validated data are used as-is

  #----- response: times and event indicator (taken from formula1 only)
  standata$t_beg <- make_t(formula1, mf, type = "beg")
  standata$t_end <- make_t(formula1, mf, type = "end")
  standata$t_gap <- make_t(formula1, mf, type = "gap")
  standata$d <- make_d(formula1, mf)

  #----- design matrices (fixed effects) for formula1 and formula2
  x1 <- make_x(formula1, mf)
  standata$nrows <- x1$N
  standata$K_1 <- x1$K
  standata$x_1 <- x1$x
  standata$xbar_1 <- x1$xbar

  x2 <- make_x(formula2, mf)
  standata$K_2 <- x2$K
  standata$x_2 <- x2$x
  standata$xbar_2 <- x2$xbar
  # random effects (make_z / make_g) are not implemented yet

  #----- time-dependent effects (non-proportional hazards): not implemented
  # yet, so every formula1 coefficient gets zero tde degrees of freedom
  standata$df_tde <- aa(rep(0L, standata$K_1))

  #----- baseline hazard
  ok_basehaz <- c("exponential", "weibull", "fpm", "fpm2")
  basehaz <- handle_basehaz(basehaz, df = df, degree = degree,
                            iknots = iknots, bknots = bknots,
                            t_beg = standata$t_beg, t_end = standata$t_end,
                            d = standata$d, ok_basehaz = ok_basehaz,
                            timescale = timescale)
  standata$type <- ai(basehaz$type)
  standata$df <- ai(basehaz$df)
  # NOTE(review): the basis at the entry times is zero-filled instead of
  # make_basehaz_x(t_beg, ...); presumably delayed entry is not supported
  # yet -- confirm.
  standata$basehaz_x_beg <- matrix(0, standata$nrows, standata$df)
  standata$basehaz_dx_beg <- matrix(0, standata$nrows, standata$df)
  standata$basehaz_x_end <- make_basehaz_x(standata$t_end, basehaz = basehaz,
                                           timescale = timescale)
  standata$basehaz_dx_end <- make_basehaz_x(standata$t_end, basehaz = basehaz,
                                            timescale = timescale,
                                            deriv = TRUE)
  standata$has_intercept <- ai(has_intercept(basehaz))

  #----- priors and hyperparameters: formula1 coefficients, intercept, aux
  prior_1_stuff <- handle_prior(prior_1, nvars = x1$K,
                                default_scale = 2,
                                ok_dists = ok_dists())
  prior_intercept_stuff <- handle_prior(prior_intercept, nvars = 1,
                                        default_scale = 2,
                                        ok_dists = ok_dists_for_intercept())
  prior_aux_stuff <- handle_prior(prior_aux, nvars = basehaz$df,
                                  default_scale =
                                    get_default_aux_scale(basehaz$name),
                                  ok_dists = ok_dists_for_aux())
  # autoscaling (the coefficient prior is scaled by its predictors)
  prior_1_stuff <- autoscale_prior(prior_1_stuff, predictors = x1$x)
  prior_intercept_stuff <- autoscale_prior(prior_intercept_stuff)
  prior_aux_stuff <- autoscale_prior(prior_aux_stuff)
  # prior distributions
  standata$prior_1_dist <- prior_1_stuff$prior_dist
  standata$prior_dist_for_intercept <- prior_intercept_stuff$prior_dist
  standata$prior_dist_for_aux <- prior_aux_stuff$prior_dist
  # hyperparameters
  standata$prior_1_mean <- prior_1_stuff$prior_mean
  standata$prior_1_scale <- prior_1_stuff$prior_scale
  standata$prior_1_df <- prior_1_stuff$prior_df
  standata$prior_mean_for_intercept <- c(prior_intercept_stuff$prior_mean)
  standata$prior_scale_for_intercept <- c(prior_intercept_stuff$prior_scale)
  standata$prior_df_for_intercept <- c(prior_intercept_stuff$prior_df)
  standata$prior_scale_for_aux <- prior_aux_stuff$prior_scale
  standata$prior_df_for_aux <- prior_aux_stuff$prior_df
  standata$global_prior_1_scale <- prior_1_stuff$global_prior_scale
  standata$global_prior_1_df <- prior_1_stuff$global_prior_df
  standata$slab_1_df <- prior_1_stuff$slab_df
  standata$slab_1_scale <- prior_1_stuff$slab_scale
  # TODO: `add_prior` (an additional user-specified prior) is not implemented

  #----- priors and hyperparameters: formula2 coefficients
  prior_2_stuff <- handle_prior(prior_2, nvars = x2$K,
                                default_scale = 2,
                                ok_dists = ok_dists())
  prior_2_stuff <- autoscale_prior(prior_2_stuff, predictors = x2$x)
  standata$prior_2_dist <- prior_2_stuff$prior_dist
  standata$prior_2_mean <- prior_2_stuff$prior_mean
  standata$prior_2_scale <- prior_2_stuff$prior_scale
  standata$prior_2_df <- prior_2_stuff$prior_df
  standata$global_prior_2_scale <- prior_2_stuff$global_prior_scale
  standata$global_prior_2_df <- prior_2_stuff$global_prior_df
  standata$slab_2_df <- prior_2_stuff$slab_df
  standata$slab_2_scale <- prior_2_stuff$slab_scale

  #----- additional flags
  standata$prior_PD <- ai(prior_PD)
  # delayed entry if any row has a non-zero start time
  standata$delayed <- ai(!all_zero(standata$t_beg))
  standata$npats <- standata$nevents <- 0L # not currently used

  #--------------------------------------
  # Return the data, or fit the model (vb)
  #--------------------------------------
  # Fitting with rstan::sampling() is disabled, so for the default
  # algorithm = "sampling" the constructed data are returned. The same
  # happens for any algorithm when out_data = TRUE.
  if (isTRUE(out_data) || algorithm == "sampling") {
    return(list(basehaz = basehaz, standata = standata))
  }

  stanpars <- pars_to_monitor_2(standata)
  args <- nlist(
    object = stanmodels$surv,
    data = standata,
    pars = stanpars,
    algorithm # "meanfield" or "fullrank" variational Bayes
  )
  args[names(dots)] <- dots # user-supplied arguments override the defaults
  stanfit <- do.call(rstan::vb, args)
  # TODO: check_stanfit(stanfit) and stansurv() post-processing

  nlist(stanfit,
        # the parsed model formulas; a bare `formula` here would resolve to
        # the base function stats::formula
        formula = list(formula1 = formula1, formula2 = formula2),
        data, basehaz, algorithm,
        stan_function = "stan_surv",
        call = match.call(expand.dots = TRUE))
}
# (Residue from the web page this file was copied from; not R code.)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.