#' Fit a stan model using (modified) code and data generated by brms
#'
#' @param file Load/save csvs to this filepath and name.
#' @param seed Random seed passed to `rstan::stan()`.
#' @param bform A `brms` formula. N.B., transformations of variables within the formula may not be picked up by
#' post-processing functions; it's generally better to create new, transformed variables.
#' @param bdata A data frame used to fit the model.
#' @param bpriors Priors specified by `brms` functions.
#' @param car1 Logical. Generate CAR(1) errors?
#' @param sample_prior Passed on to `brms::brm()`.
#' @param knots Passed on to `brms::brm()`.
#' @param d_x A vector representing the spacing in time of observations in each series,
#' equal to zero at the first timestep. If `NULL` (the default), `d_x` is drawn from the dataframe `bdata`.
#' @param ... Passed on to `rstan::stan()` or `cmdstanr::sample()`, depending on the backend selected.
#' @param family A `brmsfamily` object. Note that some post-processing functions assume a student-t likelihood.
#' @param backend Run Stan's algorithms using `rstan` or `cmdstanr`.
#' @param overwrite Overwrite an existing model stored as CSVs? Defaults to `FALSE`.
#' @param var_car1 For multivariate models, specify which variable to generate CAR(1) errors for.
#' @param lcl A numeric vector of censoring limits for left-censored variables. Alternatively, a list of numeric vectors with censoring limits for
#' each variable and observation; the order should correspond to that of `var_xcens` and `cens_ind`.
#' @param lower_bound An optional lower bound for left-censored parameters. The default is no lower bound.
#' @param var_xcens A character vector containing the names of variables with left-censoring; the order should
#' correspond to that of `lcl` and `cens_ind`.
#' @param cens_ind A character vector containing the names of left-censoring indicator variables corresponding
#' to the contents of `var_xcens`. (Indicator variables should take the form of character vectors with `"left"` for left-censored and
#' `"none"` for observed.)
#' @param stancode A named character vector of length one, where the name is the model block to modify
#' and the value is the additional Stan code. This has only been validated for the generated quantities block.
#'
#' @return A `brms` model object fitted with `rstan` or `cmdstanr`.
#' @importFrom stringr str_remove str_extract str_detect
#' @importFrom brms brm rename_pars make_stancode make_standata student is.brmsfit read_csv_as_stanfit
#' @importFrom rstan stan read_stan_csv
#' @importFrom dplyr %>%
#' @export
#'
#' @examples
#' library("brms")
#' seed <- 1
#' data <- read.csv(paste0(system.file("extdata", package = "bgamcar1"), "/data.csv"))
#' fit <- fit_stan_model(
#' paste0(system.file("extdata", package = "bgamcar1"), "/test"),
#' seed,
#' bf(y | cens(ycens, y2 = y2) ~ 1),
#' data,
#' prior(normal(0, 1), class = Intercept),
#' car1 = FALSE,
#' save_warmup = FALSE,
#' chains = 3
#' )
fit_stan_model <- function(file,
                           seed,
                           bform,
                           bdata,
                           bpriors = NULL,
                           car1 = TRUE,
                           sample_prior = "no",
                           knots = NULL,
                           d_x = NULL,
                           family = student(),
                           backend = "rstan",
                           overwrite = FALSE,
                           var_car1 = NULL,
                           var_xcens = NULL,
                           cens_ind = NULL,
                           lcl = NULL,
                           lower_bound = NULL,
                           stancode = NULL,
                           ...) {
  # Validate the backend before doing any work. Previously an unsupported
  # backend was only rejected when a new fit was required; when saved CSVs were
  # found, the read step silently produced NULL instead of an error.
  valid_backend <- is.character(backend) && length(backend) == 1L &&
    !is.na(backend) && backend %in% c("rstan", "cmdstanr")
  if (!valid_backend) {
    stop("Backend must be either \"rstan\" or \"cmdstanr\".")
  }

  # look for previously saved CSV/RDS output matching `file`:
  model_saved <- get_model(file)

  # generate stan data:
  data <- brms::make_standata(
    bform,
    data = bdata,
    prior = bpriors,
    family = family,
    sample_prior = sample_prior,
    knots = knots
  )

  # check for presence of d_x in data or supplied as an argument, then add to
  # stan data as `s`:
  if (car1) {
    if (is.null(d_x)) {
      # `[[` avoids the partial name matching that `$` performs on data frames
      spacing <- bdata[["d_x"]]
      stopifnot("column d_x not found in data" = !is.null(spacing))
      data$s <- spacing
    } else {
      data$s <- d_x
    }
  }

  # modify standata for left-censored variables:
  if (!is.null(var_xcens)) {
    data <- modify_standata(
      data, bdata, lcl,
      var_xcens = var_xcens, cens_ind = cens_ind
    )
  }

  # generate stan code:
  code <- brms::make_stancode(
    bform,
    data = bdata,
    prior = bpriors,
    family = family,
    sample_prior = sample_prior,
    knots = knots
  )

  # modify stancode (CAR(1) errors, left-censoring, user-supplied additions):
  if (car1) {
    code <- modify_stancode(code, var_car1 = var_car1)
  }
  if (!is.null(var_xcens)) {
    code <- modify_stancode(
      code,
      modify = "xcens",
      var_xcens = var_xcens,
      lower_bound = lower_bound,
      lcl = lcl
    )
  }
  if (!is.null(stancode)) {
    code <- add_stancode(code, stancode, block = names(stancode))
  }

  # fit model, or read the saved CSVs back in when they exist and overwriting
  # was not requested:
  use_saved_csvs <- length(model_saved$csvs) > 0 && !overwrite
  stanmod <- if (use_saved_csvs) {
    if (backend == "rstan") {
      rstan::read_stan_csv(model_saved$csvs)
    } else {
      brms::read_csv_as_stanfit(model_saved$csvs)
    }
  } else if (backend == "rstan") {
    rstan::stan(
      model_code = code,
      data = data,
      sample_file = file, # output in csv format
      seed = seed,
      ...
    )
  } else {
    if (!requireNamespace("cmdstanr", quietly = TRUE)) {
      # `call.` (not `.call`) is the argument of stop(); the misspelling was
      # pasted onto the end of the message.
      stop(
        "Package \"cmdstanr\" must be installed for backend == \"cmdstanr\".",
        call. = FALSE
      )
    }
    fit_cmdstan_model(
      code, data, seed, model_saved$path, model_saved$basename, file, ...
    )
  }

  # feed back into brms:
  if (length(model_saved$rds) > 0 && !overwrite) {
    # a saved brmsfit exists: load it rather than rebuilding
    brmsmod <- brm(
      bform,
      data = bdata,
      prior = bpriors,
      family = family,
      knots = knots,
      file = file,
      file_refit = "never"
    )
  } else {
    brmsmod <- brm(
      bform,
      data = bdata,
      prior = bpriors,
      family = family,
      knots = knots,
      empty = TRUE
    )
    # save empty fit:
    if (!is.null(file)) brms:::write_brmsfit(brmsmod, file = file)
  }

  # add stan model to fit slot:
  brmsmod$fit <- stanmod
  brmsmod <- rename_pars(brmsmod)
  brmsmod$model <- code # replace the original Stan code with the modified code
  brmsmod
}
# Locate previously saved model output for a file stem.
#
# `file` is a path whose final component is the base name shared by the saved
# files (no extension). Returns a list with:
#   csvs     - full paths of files named <basename>-<id>.csv or
#              <basename>_<id>.csv in the directory of `file`
#   path     - the directory component of `file` ("." when there is none)
#   basename - the final component of `file`
#   rds      - full path of <basename>.rds in the same directory, if present
# A NULL `file` yields zero-length entries throughout.
get_model <- function(file) {
  if (is.null(file)) {
    none <- character(0)
    return(list(csvs = none, path = none, basename = none, rds = none))
  }
  # dirname()/basename() handle stems without a directory component, where
  # stripping "/<name>" by regex left the whole string as the "directory".
  path <- dirname(file)
  bname <- basename(file)
  candidates <- list.files(path = path, full.names = TRUE)
  fnames <- basename(candidates)
  # Match the stem literally so regex metacharacters in the name (e.g. ".")
  # cannot match unrelated files, then test only what follows the stem.
  has_stem <- startsWith(fnames, bname)
  suffix <- substring(fnames, nchar(bname) + 1L)
  # one or more digits so that chain ids of 10 and above are found; anchored
  # so that e.g. "<basename>_1.csv.bak" is not picked up
  is_csv <- has_stem & grepl("^[-_][0-9]+\\.csv$", suffix)
  is_rds <- has_stem & suffix == ".rds"
  list(
    csvs = candidates[is_csv],
    path = path,
    basename = bname,
    rds = candidates[is_rds]
  )
}
# Fit Stan code with cmdstanr and return the result as a stanfit object.
#
# The code is written to a Stan file, formatted in place, compiled and
# sampled; the output files are then saved under `path` using `basename`
# (without random or timestamp suffixes) and read back via
# brms::read_csv_as_stanfit(). Additional arguments in `...` are forwarded to
# the sampler. `file` is accepted but not used in the body.
fit_cmdstan_model <- function(code, data, seed, path, basename, file, ...) {
  stan_path <- cmdstanr::write_stan_file(code)
  # defer compilation until the Stan file has been formatted
  mod <- cmdstanr::cmdstan_model(stan_file = stan_path, compile = FALSE)
  mod$format(overwrite_file = TRUE, canonicalize = TRUE, backup = FALSE)
  mod$compile()
  sampled <- mod$sample(data = data, seed = seed, ...)
  sampled$save_output_files(
    dir = path, basename = basename, random = FALSE, timestamp = FALSE
  )
  # read the saved CSVs through brms rather than rstan::read_stan_csv()
  csv_paths <- sampled$output_files()
  brms::read_csv_as_stanfit(csv_paths)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.