R/fit_stan_model.R

Defines functions fit_cmdstan_model get_model fit_stan_model

Documented in fit_stan_model

#' Fit a stan model using (modified) code and data generated by brms
#'
#' @param file Load/save csvs to this filepath and name.
#' @param seed Random seed passed to `rstan::stan()` or to the `cmdstanr` sampler, depending on the backend selected.
#' @param bform A `brms` formula. N.B., transformations of variables within the formula may not be picked up by
#' post-processing functions; it's generally better to create new, transformed variables.
#' @param bdata A data frame used to fit the model.
#' @param bpriors Priors specified by `brms` functions.
#' @param car1 Logical. Generate CAR(1) errors?
#' @param sample_prior Passed on to `brms::brm()`.
#' @param knots Passed on to `brms::brm()`.
#' @param d_x A vector representing the spacing in time of observations in each series,
#' equal to zero at the first timestep. If `NULL` (the default), `d_x` is drawn from the dataframe `bdata`.
#' @param ... Passed on to `rstan::stan()` or `cmdstanr::sample()`, depending on the backend selected.
#' @param family A `brmsfamily` object. Note that some post-processing functions assume a student-t likelihood.
#' @param backend Run Stan's algorithms using `rstan` or `cmdstanr`.
#' @param overwrite Overwrite an existing model stored as CSVs? Defaults to `FALSE`.
#' @param var_car1 For multivariate models, specify which variable to generate CAR(1) errors for.
#' @param lcl A numeric vector of censoring limits for left-censored variables. Alternatively, a list of numeric vectors with censoring limits for
#' each variable and observation; the order should correspond to that of `var_xcens` and `cens_ind`.
#' @param lower_bound An optional lower bound for left-censored parameters. The default is no lower bound.
#' @param var_xcens A character vector containing the names of variables with left-censoring; the order should
#' correspond to that of `lcl` and `cens_ind`.
#' @param cens_ind A character vector containing the names of left-censoring indicator variables corresponding
#' to the contents of `var_xcens`. (Indicator variables should take the form of character vectors with `"left"` for left-censored and
#' `"none"` for observed.)
#' @param stancode A named character vector of length one, where the name is the model block to modify
#' and the value is the additional Stan code. This has only been validated for the generated quantities block.
#'
#' @return A `brms` model object fitted with `rstan` or `cmdstanr`.
#' @importFrom stringr str_remove str_extract str_detect
#' @importFrom brms brm rename_pars make_stancode make_standata student is.brmsfit read_csv_as_stanfit
#' @importFrom rstan stan read_stan_csv
#' @importFrom dplyr %>%
#' @export
#'
#' @examples
#' library("brms")
#' seed <- 1
#' data <- read.csv(paste0(system.file("extdata", package = "bgamcar1"), "/data.csv"))
#' fit <- fit_stan_model(
#'   paste0(system.file("extdata", package = "bgamcar1"), "/test"),
#'   seed,
#'   bf(y | cens(ycens, y2 = y2) ~ 1),
#'   data,
#'   prior(normal(0, 1), class = Intercept),
#'   car1 = FALSE,
#'   save_warmup = FALSE,
#'   chains = 3
#' )
fit_stan_model <- function(file,
                           seed,
                           bform,
                           bdata,
                           bpriors = NULL,
                           car1 = TRUE,
                           sample_prior = "no",
                           knots = NULL,
                           d_x = NULL,
                           family = student(),
                           backend = "rstan",
                           overwrite = FALSE,
                           var_car1 = NULL,
                           var_xcens = NULL,
                           cens_ind = NULL,
                           lcl = NULL,
                           lower_bound = NULL,
                           stancode = NULL,
                           ...) {

  # validate the backend before doing any work; otherwise an unknown backend
  # combined with previously saved CSVs would silently yield a NULL fit:
  if (!is.character(backend) || length(backend) != 1L ||
      !backend %in% c("rstan", "cmdstanr")) {
    stop("Backend must be either \"rstan\" or \"cmdstanr\".", call. = FALSE)
  }

  # look up previously saved CSV/RDS output matching `file`:
  model_saved <- get_model(file)

  # generate stan data:

  data <- brms::make_standata(
    bform,
    data = bdata,
    prior = bpriors,
    family = family,
    sample_prior = sample_prior,
    knots = knots
  )

  # check for presence of d_x in data or supplied as an argument, then add to stan data:

  if (car1) {
    if (is.null(d_x)) {
      # `[[` requires an exact column name (no partial matching as with `$`)
      d_x_column <- bdata[["d_x"]]
      stopifnot("column d_x not found in data" = !is.null(d_x_column))
      data$s <- d_x_column
    } else {
      data$s <- d_x
    }
  }

  # modify standata:

  if (!is.null(var_xcens)) {
    data <- modify_standata(data, bdata, lcl, var_xcens = var_xcens, cens_ind = cens_ind)
  }

  # generate stan code:

  code <- brms::make_stancode(
    bform,
    data = bdata,
    prior = bpriors,
    family = family,
    sample_prior = sample_prior,
    knots = knots
  )

  # modify stancode:

  if (car1) {
    code <- modify_stancode(code, var_car1 = var_car1)
  }

  if (!is.null(var_xcens)) {
    code <- modify_stancode(code, modify = "xcens", var_xcens = var_xcens, lower_bound = lower_bound, lcl = lcl)
  }

  if (!is.null(stancode)) {
    code <- add_stancode(code, stancode, block = names(stancode))
  }

  # fit model (or reload saved CSVs unless `overwrite` is set):

  use_saved_csvs <- length(model_saved$csvs) > 0 && !overwrite

  stanmod <- if (use_saved_csvs) {
    if (backend == "rstan") {
      rstan::read_stan_csv(model_saved$csvs)
    } else {
      brms::read_csv_as_stanfit(model_saved$csvs)
    }
  } else if (backend == "rstan") {
    rstan::stan(
      model_code = code,
      data = data,
      sample_file = file, # output in csv format
      seed = seed,
      ...
    )
  } else {
    # backend == "cmdstanr" (validated above); cmdstanr is an optional dependency
    if (!requireNamespace("cmdstanr", quietly = TRUE)) {
      stop(
        "Package \"cmdstanr\" must be installed for backend == \"cmdstanr\".",
        call. = FALSE
      )
    }
    fit_cmdstan_model(
      code, data, seed, model_saved$path, model_saved$basename, file, ...
    )
  }

  # feed back into brms:

  if (length(model_saved$rds) > 0 && !overwrite) {
    # a saved brmsfit exists: let brm() load it from `file` without refitting
    brmsmod <- brm(
      bform,
      data = bdata,
      prior = bpriors,
      family = family,
      knots = knots,
      file = file,
      file_refit = "never"
    )
  } else {
    brmsmod <- brm(
      bform,
      data = bdata,
      prior = bpriors,
      family = family,
      knots = knots,
      empty = TRUE
    )
    # save empty fit:
    if (!is.null(file)) brms:::write_brmsfit(brmsmod, file = file)
  }

  # add stan model to fit slot:
  brmsmod$fit <- stanmod
  brmsmod <- rename_pars(brmsmod)
  brmsmod$model <- code # replace the original Stan code with the modified code

  brmsmod
}

# Locate previously saved model output for a given path/basename combination.
#
# `file` is a path without extension (e.g. "dir/model"). Returns a list with:
#   csvs     - full paths of files named "<basename>-<n>.csv" or
#              "<basename>_<n>.csv" (n = one or more digits) in the directory
#   path     - the directory part of `file`
#   basename - the final path component of `file`
#   rds      - full path of "<basename>.rds" in the directory, if present
# A `NULL` `file` yields empty results for every element.
get_model <- function(file) {
  if (is.null(file)) {
    return(list(
      csvs = character(0), path = character(0),
      basename = character(0), rds = character(0)
    ))
  }

  # dirname() gives "." when `file` has no directory component, so a bare
  # basename is looked up in the working directory
  path <- dirname(file)
  bname <- basename(file)

  # list the directory once and match the basename literally, so that regex
  # metacharacters in the basename (e.g. ".") are not interpreted as patterns
  candidates <- list.files(path = path, full.names = TRUE)
  fnames <- basename(candidates)
  has_prefix <- startsWith(fnames, bname)
  suffix <- substring(fnames, nchar(bname) + 1L)

  # "\\d+" so that chain ids of 10 and above are found; anchored at both ends
  csvfiles <- candidates[has_prefix & grepl("^[-_]\\d+\\.csv$", suffix)]
  rdsfiles <- candidates[has_prefix & suffix == ".rds"]

  list(csvs = csvfiles, path = path, basename = bname, rds = rdsfiles)
}

# Fit a model with cmdstanr from Stan code supplied as a string.
#
# The code is written to a Stan file, formatted in place, compiled and sampled;
# the sampler's output files are then saved to `path` under `basename` and read
# back via `brms::read_csv_as_stanfit()`. Arguments in `...` are forwarded to
# the model's `$sample()` method. `file` is accepted but not used here.
fit_cmdstan_model <- function(code, data, seed, path, basename, file, ...) {
  # write the Stan program to disk and set up the model without compiling yet
  stan_path <- cmdstanr::write_stan_file(code)
  cmdstan_mod <- cmdstanr::cmdstan_model(stan_file = stan_path, compile = FALSE)

  # format the Stan file in place (overwriting it, no backup), then compile
  cmdstan_mod$format(overwrite_file = TRUE, canonicalize = TRUE, backup = FALSE)
  cmdstan_mod$compile()

  draws <- cmdstan_mod$sample(data = data, seed = seed, ...)

  # save the output files under the requested directory and basename, with the
  # random and timestamp naming options switched off
  draws$save_output_files(
    dir = path,
    basename = basename,
    random = FALSE,
    timestamp = FALSE
  )

  saved_csvs <- draws$output_files()
  brms::read_csv_as_stanfit(saved_csvs)
}
bentrueman/bgamcar1 documentation built on July 6, 2024, 11:16 p.m.