R/backends.R

Defines functions read_csv_as_stanfit file_refit_options repair_stanfit repair_variable_names validate_silent use_opencl validate_opencl is.brmsopencl opencl use_threading validate_threads is.brmsthreads threading require_backend algorithm_choices backend_choices elapsed_time recompile_model needs_recompilation compiled_model .fit_model_mock .fit_model_cmdstanr .fit_model_rstan fit_model .compile_model_mock .compile_model_cmdstanr .compile_model_rstan compile_model .parse_model_mock .parse_model_cmdstanr .parse_model_rstan parse_model

Documented in opencl read_csv_as_stanfit recompile_model threading

# parse Stan model code by dispatching to a backend-specific parser
# @param model Stan model code
# @param backend name of the backend; the function '.parse_model_<backend>'
#   is looked up and called
# @param ... further arguments passed to the backend-specific parser
# @return validated Stan model code
parse_model <- function(model, backend, ...) {
  backend <- as_one_character(backend)
  parser_name <- paste0(".parse_model_", backend)
  parser <- get(parser_name, mode = "function")
  parser(model, ...)
}

# parse Stan model code with rstan
# @param model Stan model code
# @param silent verbosity setting forwarded to eval_silent
# @param ... further arguments passed to rstan::stanc
# @return validated Stan model code (element 'model_code' of the stanc output)
.parse_model_rstan <- function(model, silent = 1, ...) {
  stanc_result <- eval_silent(
    rstan::stanc(model_code = model, ...),
    silent = silent, try = TRUE, type = "message"
  )
  stanc_result$model_code
}

# parse Stan model code with cmdstanr
# @param model Stan model code
# @param silent verbosity setting forwarded to eval_silent
# @param ... further arguments passed to cmdstanr::cmdstan_model
# @return validated Stan model code as a single string
.parse_model_cmdstanr <- function(model, silent = 1, ...) {
  require_package("cmdstanr")
  # cmdstanr works on files, so write the code to a Stan file first
  stan_file <- cmdstanr::write_stan_file(model)
  # if (cmdstanr::cmdstan_version() >= "2.29.0") {
  #   .canonicalize_stan_model(stan_file, overwrite_file = TRUE)
  # }
  # create the model object without compiling it
  cmdstan_mod <- eval_silent(
    cmdstanr::cmdstan_model(stan_file, compile = FALSE, ...),
    silent = silent, try = TRUE, type = "message"
  )
  cmdstan_mod$check_syntax(quiet = TRUE)
  # code() returns one element per line; glue them back together
  collapse(cmdstan_mod$code(), "\n")
}

# parse model with a mock backend for testing
# @param model Stan model code
# @param silent verbosity setting forwarded to the real parser, if one is used
# @param parse_error optional error message; if not NULL, parsing fails
#   with this message
# @param parse_check which real parser to use for checking the code:
#   "rstan", "cmdstanr", or NULL to skip parsing and return "mock_code"
# @param ... further arguments passed to the real parser
# @return validated Stan model code, or "mock_code" if 'parse_check' is NULL
.parse_model_mock <- function(model, silent = TRUE, parse_error = NULL,
                              parse_check = "rstan", ...) {
  if (!is.null(parse_error)) {
    stop2(parse_error)
  } else if (is.null(parse_check)) {
    # must be tested before any '==' comparison: comparing NULL to a string
    # yields logical(0), which is not a valid 'if' condition
    out <- "mock_code"
  } else if (parse_check == "rstan") {
    out <- .parse_model_rstan(model, silent = silent, ...)
  } else if (parse_check == "cmdstanr") {
    out <- .parse_model_cmdstanr(model, silent = silent, ...)
  } else {
    stop2("Unknown 'parse_check' value.")
  }
  out
}

# compile Stan model by dispatching to a backend-specific compiler
# @param model Stan model code
# @param backend name of the backend; the function '.compile_model_<backend>'
#   is looked up and called
# @param ... further arguments passed to the backend-specific compiler
# @return the compiled model as returned by the backend-specific compiler
compile_model <- function(model, backend, ...) {
  backend <- as_one_character(backend)
  compiler_name <- paste0(".compile_model_", backend)
  compiler <- get(compiler_name, mode = "function")
  compiler(model, ...)
}

# compile Stan model with rstan
# @param model Stan model code
# @param threads threading specification; if active, rstan >= 2.26 is required
# @param opencl OpenCL specification; an error is raised if it is active
# @param silent verbosity level; the start message is shown if below 2 and
#   compiler messages are silenced if 2 or above
# @param ... further arguments passed to rstan::stan_model
# @return model compiled with rstan
.compile_model_rstan <- function(model, threads, opencl, silent = 1, ...) {
  stan_args <- list(...)
  stan_args$model_code <- model
  if (silent < 2) {
    message("Compiling Stan program...")
  }
  if (use_threading(threads, force = TRUE)) {
    if (utils::packageVersion("rstan") < "2.26") {
      stop2("Threading is not supported by backend 'rstan' version ",
            utils::packageVersion("rstan"), ".")
    }
    # remember the current global option so that it is restored on exit
    old_threads_per_chain <- rstan::rstan_options("threads_per_chain")
    on.exit(rstan::rstan_options(threads_per_chain = old_threads_per_chain))
    rstan::rstan_options(threads_per_chain = threads$threads)
  }
  if (use_opencl(opencl)) {
    stop2("OpenCL is not supported by backend 'rstan' version ",
          utils::packageVersion("rstan"), ".")
  }
  eval_silent(
    do_call(rstan::stan_model, stan_args),
    silent = silent >= 2, try = TRUE, type = "message"
  )
}

# compile Stan model with cmdstanr
# @param model Stan model code
# @param threads threading specification; if active, the model is compiled
#   with cpp option 'stan_threads'
# @param opencl OpenCL specification; if active, the model is compiled
#   with cpp option 'stan_opencl'
# @param silent verbosity level; compiler messages are silenced if 2 or above
# @param ... further arguments passed to cmdstanr::cmdstan_model
# @return model compiled with cmdstanr
.compile_model_cmdstanr <- function(model, threads, opencl, silent = 1, ...) {
  require_package("cmdstanr")
  cmdstan_args <- list(...)
  # cmdstanr compiles from a file, so write the code to disk first
  cmdstan_args$stan_file <- cmdstanr::write_stan_file(model)
  # if (cmdstanr::cmdstan_version() >= "2.29.0") {
  #   .canonicalize_stan_model(cmdstan_args$stan_file, overwrite_file = TRUE)
  # }
  with_threads <- use_threading(threads, force = TRUE)
  if (with_threads) {
    cmdstan_args$cpp_options$stan_threads <- TRUE
  }
  with_opencl <- use_opencl(opencl)
  if (with_opencl) {
    cmdstan_args$cpp_options$stan_opencl <- TRUE
  }
  eval_silent(
    do_call(cmdstanr::cmdstan_model, cmdstan_args),
    silent = silent >= 2, try = TRUE, type = "message"
  )
}

# compile model with a mock backend for testing
# no actual compilation takes place: the model code is only parsed
# with the requested real backend
# @param model Stan model code
# @param threads,opencl accepted for interface compatibility; not used here
# @param compile_check which real parser to use for checking the code:
#   "rstan", "cmdstanr", or NULL to skip the check and return an empty list
# @param compile_error optional error message; if not NULL, compilation
#   fails with this message
# @param silent verbosity setting forwarded to the real parser, if one is used
# @param ... further arguments passed to the real parser
# @return the output of the chosen parser, or an empty list if
#   'compile_check' is NULL
.compile_model_mock <- function(model, threads, opencl, compile_check = "rstan",
                                compile_error = NULL, silent = 1, ...) {
  if (!is.null(compile_error)) {
    stop2(compile_error)
  } else if (is.null(compile_check)) {
    # must be tested before any '==' comparison: comparing NULL to a string
    # yields logical(0), which is not a valid 'if' condition
    out <- list()
  } else if (compile_check == "rstan") {
    out <- .parse_model_rstan(model, silent = silent, ...)
  } else if (compile_check == "cmdstanr") {
    out <- .parse_model_cmdstanr(model, silent = silent, ...)
  } else {
    stop2("Unknown 'compile_check' value.")
  }
  out
}

# fit Stan model by dispatching to a backend-specific fitting function
# @param model a compiled Stan model
# @param backend name of the backend; the function '.fit_model_<backend>'
#   is looked up and called
# @param ... further arguments passed to the backend-specific fitting function
# @return the fitted model as returned by the backend-specific function
fit_model <- function(model, backend, ...) {
  backend <- as_one_character(backend)
  fitter_name <- paste0(".fit_model_", backend)
  fitter <- get(fitter_name, mode = "function")
  fitter(model, ...)
}

# fit Stan model with rstan
# @param model a compiled Stan model
# @param sdata named list to be passed to Stan as data
# @param algorithm one of "sampling", "fixed_param", "fullrank" or
#   "meanfield"; any other value raises an error
# @param iter,seed passed to rstan for all supported algorithms
# @param warmup,thin,control passed to rstan::sampling only
# @param chains,cores number of chains and cores; 'cores' is ignored
#   (with a warning if > 1) when chains are run via 'future'
# @param threads threading specification; requires rstan >= 2.26 if active
# @param opencl OpenCL specification; an error is raised if it is active
# @param init NULL (treated as "random"), "random", "0", the name of a
#   function (which is then looked up), or anything else passed on as is
# @param exclude passed to rstan as 'pars' together with 'include = FALSE'
# @param silent verbosity level; "Start sampling" is shown if below 2
# @param future logical; only has an effect for algorithm "sampling", in
#   which case every chain is run as a separate future
# @param ... further arguments; they are added to (and may overwrite) the
#   arguments assembled here before rstan is called
# @return a fitted Stan model
.fit_model_rstan <- function(model, sdata, algorithm, iter, warmup, thin,
                             chains, cores, threads, opencl, init, exclude,
                             seed, control, silent, future, ...) {

  # some input checks and housekeeping
  if (use_threading(threads, force = TRUE)) {
    if (utils::packageVersion("rstan") >= "2.26") {
      # temporarily change rstan's global 'threads_per_chain' option and
      # restore the previous value when this function exits
      threads_per_chain_def <- rstan::rstan_options("threads_per_chain")
      on.exit(rstan::rstan_options(threads_per_chain = threads_per_chain_def))
      rstan::rstan_options(threads_per_chain = threads$threads)
    } else {
      stop2("Threading is not supported by backend 'rstan' version ",
            utils::packageVersion("rstan"), ".")
    }
  }
  if (use_opencl(opencl)) {
    stop2("OpenCL is not supported by backend 'rstan' version ",
          utils::packageVersion("rstan"), ".")
  }
  if (is.null(init)) {
    init <- "random"
  } else if (is.character(init) && !init %in% c("random", "0")) {
    # any other string is interpreted as the name of a function,
    # which is looked up starting in the calling frame
    init <- get(init, mode = "function", envir = parent.frame())
  }
  # futures are only used for the "sampling" algorithm
  future <- future && algorithm %in% "sampling"
  # NOTE(review): 'nlist' is defined elsewhere; presumably it creates a list
  # that names unnamed arguments after their expressions -- confirm
  args <- nlist(
    object = model, data = sdata, iter, seed,
    init = init, pars = exclude, include = FALSE
  )
  # user-supplied arguments take precedence over the ones set above
  dots <- list(...)
  args[names(dots)] <- dots

  # do the actual sampling
  if (silent < 2) {
    message("Start sampling")
  }
  if (algorithm %in% c("sampling", "fixed_param")) {
    # NOTE(review): 'c(args) <- ...' relies on a replacement function defined
    # elsewhere; presumably it appends the new elements to 'args' -- confirm
    # '!silent' is TRUE only if 'silent' is 0
    c(args) <- nlist(warmup, thin, control, show_messages = !silent)
    if (algorithm == "fixed_param") {
      args$algorithm <- "Fixed_param"
    }
    if (future) {
      if (cores > 1L) {
        warning2("Argument 'cores' is ignored when using 'future'.")
      }
      # run each chain as its own single-chain fit in a separate future
      args$chains <- 1L
      out <- futures <- vector("list", chains)
      for (i in seq_len(chains)) {
        args$chain_id <- i
        if (is.list(init)) {
          # hand each chain only its own element of the list of inits
          args$init <- init[i]
        }
        futures[[i]] <- future::future(
          brms::do_call(rstan::sampling, args),
          packages = "rstan",
          seed = TRUE
        )
      }
      # collect the results only after all futures have been created
      for (i in seq_len(chains)) {
        out[[i]] <- future::value(futures[[i]])
      }
      # merge the single-chain fits into one stanfit object
      out <- rstan::sflist2stanfit(out)
      rm(futures)
    } else {
      c(args) <- nlist(chains, cores)
      out <- do_call(rstan::sampling, args)
    }
  } else if (algorithm %in% c("fullrank", "meanfield")) {
    # vb does not support parallel execution
    c(args) <- nlist(algorithm)
    out <- do_call(rstan::vb, args)
  } else {
    stop2("Algorithm '", algorithm, "' is not supported.")
  }
  # TODO: add support for pathfinder and laplace
  out <- repair_stanfit(out)
  out
}

# fit Stan model with cmdstanr
# @param model a compiled Stan model
# @param sdata named list to be passed to Stan as data
# @param algorithm one of "sampling", "fixed_param", "fullrank", "meanfield",
#   "pathfinder" or "laplace"; any other value raises an error
# @param iter,warmup,thin iteration settings; for sampling, the number of
#   post-warmup iterations is computed as 'iter - warmup'
# @param chains,cores number of chains and of chains run in parallel;
#   'chains <= 0' requests an 'empty' model (see below)
# @param threads threading specification; if active, the number of threads
#   is passed on under the argument name the respective method uses
# @param opencl OpenCL specification; if active, its 'ids' are passed on
#   as 'opencl_ids'
# @param init initial values; "random" is translated to NULL and "0" to 0
# @param exclude variables to exclude; forwarded to read_csv_as_stanfit
# @param seed random seed; an NA seed is translated to NULL
# @param control list whose elements are merged directly into the arguments
# @param silent verbosity level; "Start sampling" is shown if below 2
# @param future logical; only has an effect for algorithm "sampling", in
#   which case every chain is run as a separate future
# @param ... further arguments; they are added to (and may overwrite) the
#   arguments assembled before them
# @return a fitted Stan model, converted via read_csv_as_stanfit
.fit_model_cmdstanr <- function(model, sdata, algorithm, iter, warmup, thin,
                                chains, cores, threads, opencl, init, exclude,
                                seed, control, silent, future, ...) {

  require_package("cmdstanr")
  # some input checks and housekeeping
  # drop any special class so that the data is passed on as a plain list
  class(sdata) <- "list"
  if (isNA(seed)) {
    seed <- NULL
  }
  # translate the rstan-style init strings
  if (is_equal(init, "random")) {
    init <- NULL
  } else if (is_equal(init, "0")) {
    init <- 0
  }
  # futures are only used for the "sampling" algorithm
  future <- future && algorithm %in% "sampling"
  # NOTE(review): 'nlist' is defined elsewhere; presumably it creates a list
  # that names unnamed arguments after their expressions -- confirm
  args <- nlist(data = sdata, seed, init)
  if (use_opencl(opencl)) {
    args$opencl_ids <- opencl$ids
  }
  # user-supplied arguments and control elements overwrite earlier entries
  dots <- list(...)
  args[names(dots)] <- dots
  args[names(control)] <- control

  chains <- as_one_numeric(chains)
  empty_model <- chains <= 0
  if (empty_model) {
    # fit the model with minimal amount of draws
    # TODO: replace with a better solution
    chains <- 1
    iter <- 2
    warmup <- 1
    thin <- 1
    cores <- 1
  }

  # do the actual sampling
  if (silent < 2) {
    message("Start sampling")
  }
  # local logical of the same name as the function it is computed with
  use_threading <- use_threading(threads, force = TRUE)
  if (algorithm %in% c("sampling", "fixed_param")) {
    # NOTE(review): 'c(args) <- ...' relies on a replacement function defined
    # elsewhere; presumably it appends the new elements to 'args' -- confirm
    c(args) <- nlist(
      iter_sampling = iter - warmup,
      iter_warmup = warmup,
      chains, thin,
      parallel_chains = cores,
      show_messages = silent < 2,
      show_exceptions = silent == 0,
      fixed_param = algorithm == "fixed_param"
    )
    if (use_threading) {
      args$threads_per_chain <- threads$threads
    }
    if (future) {
      if (cores > 1L) {
        warning2("Argument 'cores' is ignored when using 'future'.")
      }
      # run each chain as its own single-chain fit in a separate future
      args$chains <- 1L
      out <- futures <- vector("list", chains)
      for (i in seq_len(chains)) {
        args$chain_ids <- i
        if (is.list(init)) {
          # hand each chain only its own element of the list of inits
          args$init <- init[i]
        }
        futures[[i]] <- future::future(
          brms::do_call(model$sample, args),
          packages = "cmdstanr",
          seed = TRUE
        )
      }
      # collect the results only after all futures have been created
      for (i in seq_len(chains)) {
        out[[i]] <- future::value(futures[[i]])
      }
      rm(futures)
    } else {
      out <- do_call(model$sample, args)
    }
  } else if (algorithm %in% c("fullrank", "meanfield")) {
    c(args) <- nlist(iter, algorithm)
    if (use_threading) {
      args$threads <- threads$threads
    }
    out <- do_call(model$variational, args)
  } else if (algorithm %in% c("pathfinder")) {
    if (use_threading) {
      # here the thread count is passed as 'num_threads'
      args$num_threads <- threads$threads
    }
    out <- do_call(model$pathfinder, args)
  } else if (algorithm %in% c("laplace")) {
    if (use_threading) {
      args$threads <- threads$threads
    }
    out <- do_call(model$laplace, args)
  } else {
    stop2("Algorithm '", algorithm, "' is not supported.")
  }

  if (future) {
    # 'out' is a list of fitted models
    # gather the output files of all chains; the variable names are
    # taken from the first chain only
    output_files <- ulapply(out, function(x) x$output_files())
    stan_variables <- out[[1]]$metadata()$stan_variables
  } else {
    # 'out' is a single fitted model
    output_files <- out$output_files()
    stan_variables <- out$metadata()$stan_variables
  }

  # read the CSV output back in as a stanfit object
  out <- read_csv_as_stanfit(
    output_files, variables = stan_variables,
    model = model, exclude = exclude, algorithm = algorithm
  )

  if (empty_model) {
    # allow correct updating of an 'empty' model
    # discard the draws of the minimal fit performed above
    out@sim <- list()
  }
  out
}

# fit model with a mock backend for testing
# all arguments except 'mock_fit' are accepted for interface compatibility
# with the other backends and are not used
# @param mock_fit either the object to return as the "fit" or a function
#   without required arguments that creates it
# @return 'mock_fit' itself, or the value of calling it if it is a function
.fit_model_mock <- function(model, sdata, algorithm, iter, warmup, thin,
                            chains, cores, threads, opencl, init, exclude,
                            seed, control, silent, future, mock_fit, ...) {
  out <- if (is.function(mock_fit)) mock_fit() else mock_fit
  out
}

# extract the compiled stan model
# @param x brmsfit object; a missing backend is treated as "rstan"
# @return the compiled model stored in the fit: the stanmodel for backend
#   "rstan" or the 'CmdStanModel' attribute of the fit for backend "cmdstanr";
#   an error is raised for the mock backend and for unknown backends
compiled_model <- function(x) {
  stopifnot(is.brmsfit(x))
  backend <- x$backend %||% "rstan"
  if (backend == "rstan") {
    out <- rstan::get_stanmodel(x$fit)
  } else if (backend == "cmdstanr") {
    out <- attributes(x$fit)$CmdStanModel
  } else if (backend == "mock") {
    stop2("'compiled_model' is not supported in the mock backend.")
  } else {
    # fail with an informative message instead of "object 'out' not found"
    stop2("Backend '", backend, "' is not supported by 'compiled_model'.")
  }
  out
}

# Does the model need recompilation before being able to sample again?
# @param x brmsfit object; a missing backend is treated as "rstan"
# @return a single logical value; always FALSE for backends "rstan" and "mock";
#   an error is raised for unknown backends
needs_recompilation <- function(x) {
  stopifnot(is.brmsfit(x))
  backend <- x$backend %||% "rstan"
  if (backend == "rstan") {
    # TODO: figure out when rstan requires recompilation
    out <- FALSE
  } else if (backend == "cmdstanr") {
    # recompile if the model reports no executable path
    # or if the executable no longer exists on disk
    exe_file <- attributes(x$fit)$CmdStanModel$exe_file()
    out <- !is.character(exe_file) || !file.exists(exe_file)
  } else if (backend == "mock") {
    out <- FALSE
  } else {
    # fail with an informative message instead of "object 'out' not found"
    stop2("Backend '", backend, "' is not supported by 'needs_recompilation'.")
  }
  out
}

#' Recompile Stan models in \code{brmsfit} objects
#'
#' Recompile the Stan model inside a \code{brmsfit} object, if necessary.
#' This does not change the model, it simply recreates the executable
#' so that sampling is possible again.
#'
#' @param x An object of class \code{brmsfit}.
#' @param recompile Logical, indicating whether the Stan model should be
#'   recompiled. If \code{NULL} (the default), \code{recompile_model} tries
#'   to figure out internally, if recompilation is necessary. Setting it to
#'   \code{FALSE} will cause \code{recompile_model} to always return the
#'   \code{brmsfit} object unchanged.
#'
#' @return A (possibly updated) \code{brmsfit} object.
#'
#' @export
recompile_model <- function(x, recompile = NULL) {
  stopifnot(is.brmsfit(x))
  if (is.null(recompile)) {
    # figure out internally whether recompilation is necessary
    recompile <- needs_recompilation(x)
  }
  recompile <- as_one_logical(recompile)
  if (!recompile) {
    # nothing to do: return the object unchanged
    return(x)
  }
  message("Recompiling the Stan model")
  # a missing backend is treated as "rstan"
  backend <- x$backend %||% "rstan"
  new_model <- compile_model(
    stancode(x), backend = backend, threads = x$threads,
    opencl = x$opencl, silent = 2
  )
  if (backend == "rstan") {
    x$fit@stanmodel <- new_model
  } else if (backend == "cmdstanr") {
    # the compiled model is read from the attributes of the fit itself
    # (see compiled_model() and needs_recompilation()), so it has to be
    # stored there and not in the attributes of the brmsfit object
    attributes(x$fit)$CmdStanModel <- new_model
  } else if (backend == "mock") {
    stop2("'recompile_model' is not supported in the mock backend.")
  }
  x
}

# extract the elapsed time during model fitting
# @param x brmsfit object; a missing backend is treated as "rstan"
# @return for backend "rstan" a data.frame with columns 'chain_id', 'warmup',
#   'sampling' and 'total'; for backend "cmdstanr" the per-chain timings
#   stored in the metadata attribute of the fit; an error is raised for the
#   mock backend and for unknown backends
elapsed_time <- function(x) {
  stopifnot(is.brmsfit(x))
  backend <- x$backend %||% "rstan"
  if (backend == "rstan") {
    out <- rstan::get_elapsed_time(x$fit)
    # bring the rstan output (one row per chain) into a data.frame
    out <- data.frame(
      chain_id = seq_len(nrow(out)),
      warmup = out[, "warmup"],
      sampling = out[, "sample"]
    )
    out$total <- out$warmup + out$sampling
    rownames(out) <- NULL
  } else if (backend == "cmdstanr") {
    out <- attributes(x$fit)$metadata$time$chains
  } else if (backend == "mock") {
    stop2("'elapsed_time' not supported in the mock backend.")
  } else {
    # fail with an informative message instead of "object 'out' not found"
    stop2("Backend '", backend, "' is not supported by 'elapsed_time'.")
  }
  out
}

# names of the supported Stan backends
# @return character vector of backend names
backend_choices <- function() {
  choices <- c("rstan", "cmdstanr", "mock")
  choices
}

# names of the supported Stan algorithms
# @return character vector of algorithm names
algorithm_choices <- function() {
  choices <- c(
    "sampling", "meanfield", "fullrank",
    "pathfinder", "laplace", "fixed_param"
  )
  choices
}

# check if the model was fit with the required backend
# @param backend name of the required backend (one of backend_choices())
# @param x brmsfit object
# @return TRUE (invisibly); an error is raised if the backend stored in 'x'
#   differs from the required one
require_backend <- function(backend, x) {
  stopifnot(is.brmsfit(x))
  backend <- match.arg(backend, backend_choices())
  # isTRUE() lets fits without a stored backend pass the check
  backend_mismatch <- isTRUE(x$backend != backend)
  if (backend_mismatch) {
    stop2("Backend '", backend, "' is required for this method.")
  }
  invisible(TRUE)
}

#' Threading in Stan
#'
#' Use threads for within-chain parallelization in \pkg{Stan} via the \pkg{brms}
#' interface. Within-chain parallelization is experimental! We recommend its use
#' only if you are experienced with Stan's \code{reduce_sum} function and have a
#' slow running model that cannot be sped up by any other means.
#'
#' @param threads Number of threads to use in within-chain parallelization.
#' @param grainsize Number of observations evaluated together in one chunk on
#'   one of the CPUs used for threading. If \code{NULL} (the default),
#'   \code{grainsize} is currently chosen as \code{max(100, N / (2 *
#'   threads))}, where \code{N} is the number of observations in the data. This
#'   default is experimental and may change in the future without prior notice.
#' @param static Logical. Apply the static (non-adaptive) version of
#'   \code{reduce_sum}? Defaults to \code{FALSE}. Setting it to \code{TRUE}
#'   is required to achieve exact reproducibility of the model results
#'   (if the random seed is set as well).
#' @param force Logical. Defaults to \code{FALSE}. If \code{TRUE}, this will
#'   force the Stan model to compile with threading enabled without altering the
#'   Stan code generated by brms. This can be useful if your own custom Stan
#'   functions use threading internally.
#'
#' @return A \code{brmsthreads} object which can be passed to the
#'   \code{threads} argument of \code{brm} and related functions.
#'
#' @details The adaptive scheduling procedure used by \code{reduce_sum} will
#'   prevent the results from being exactly reproducible even if you set the
#'   random seed. If you need exact reproducibility, you have to set argument
#'   \code{static = TRUE} which may reduce efficiency a bit.
#'
#'   To ensure that chunks (whose size is defined by \code{grainsize}) require
#'   roughly the same amount of computing time, we recommend storing
#'   observations in random order in the data. At least, please avoid sorting
#'   observations after the response values. This is because the latter often
#'   cause variations in the computing time of the pointwise log-likelihood,
#'   which makes up a big part of the parallelized code.
#'
#' @examples
#' \dontrun{
#' # this model just serves as an illustration
#' # threading may not actually speed things up here
#' fit <- brm(count ~ zAge + zBase * Trt + (1|patient),
#'            data = epilepsy, family = negbinomial(),
#'            chains = 1, threads = threading(2, grainsize = 100),
#'            backend = "cmdstanr")
#' summary(fit)
#' }
#'
#' @export
threading <- function(threads = NULL, grainsize = NULL, static = FALSE,
                      force = FALSE) {
  # both counts are optional but must be whole numbers >= 1 if given
  if (!is.null(threads)) {
    threads <- as_one_numeric(threads)
    if (!is_wholenumber(threads) || threads < 1) {
      stop2("Number of threads needs to be positive.")
    }
  }
  if (!is.null(grainsize)) {
    grainsize <- as_one_numeric(grainsize)
    if (!is_wholenumber(grainsize) || grainsize < 1) {
      stop2("The grainsize needs to be positive.")
    }
  }
  # unspecified counts are kept as NULL elements of the list
  structure(
    list(
      threads = threads,
      grainsize = grainsize,
      static = as_one_logical(static),
      force = as_one_logical(force)
    ),
    class = "brmsthreads"
  )
}

# check if 'x' is an object of class 'brmsthreads'
is.brmsthreads <- function(x) {
  inherits(x, what = "brmsthreads")
}

# validate the 'threads' argument
# @param threads NULL, a single number, or a 'brmsthreads' object
# @return a 'brmsthreads' object
validate_threads <- function(threads) {
  if (is.null(threads)) {
    # no threading specified
    return(threading())
  }
  if (is.numeric(threads)) {
    # a plain number is taken as the number of threads
    return(threading(as_one_numeric(threads)))
  }
  if (!is.brmsthreads(threads)) {
    stop2("Argument 'threads' needs to be numeric or ",
          "specified via the 'threading' function.")
  }
  threads
}

# is threading activated?
# @param threads NULL, a single number, or a 'brmsthreads' object
# @param force if FALSE, threading only counts as activated when the
#   'force' element of the 'brmsthreads' object is not TRUE
# @return a single logical value
use_threading <- function(threads, force = FALSE) {
  threads <- validate_threads(threads)
  has_threads <- isTRUE(threads$threads > 0)
  if (force) {
    return(has_threads)
  }
  # Stan code will only be altered in non-forced mode
  has_threads && !isTRUE(threads$force)
}

#' GPU support in Stan via OpenCL
#'
#' Use OpenCL for GPU support in \pkg{Stan} via the \pkg{brms} interface. Only
#' some \pkg{Stan} functions can be run on a GPU at this point and so
#' a lot of \pkg{brms} models won't benefit from OpenCL for now.
#'
#' @param ids (integer vector of length 2) The platform and device IDs of the
#'   OpenCL device to use for fitting. If you don't know the IDs of your OpenCL
#'   device, \code{c(0,0)} is most likely what you need.
#'
#' @return A \code{brmsopencl} object which can be passed to the
#'   \code{opencl} argument of \code{brm} and related functions.
#'
#' @details For more details on OpenCL in \pkg{Stan}, check out
#' \url{https://mc-stan.org/docs/2_26/cmdstan-guide/parallelization.html#opencl}
#' as well as \url{https://mc-stan.org/docs/2_26/stan-users-guide/opencl.html}.
#'
#' @examples
#' \dontrun{
#' # this model just serves as an illustration
#' # OpenCL may not actually speed things up here
#' fit <- brm(count ~ zAge + zBase * Trt + (1|patient),
#'            data = epilepsy, family = poisson(),
#'            chains = 2, cores = 2, opencl = opencl(c(0, 0)),
#'            backend = "cmdstanr")
#' summary(fit)
#' }
#'
#' @export
opencl <- function(ids = NULL) {
  if (!is.null(ids)) {
    # platform and device ID are stored as integers
    ids <- as.integer(ids)
    if (length(ids) != 2L) {
      stop2("OpenCl 'ids' needs to be an integer vector of length 2.")
    }
  }
  # unspecified 'ids' are kept as a NULL element of the list
  structure(list(ids = ids), class = "brmsopencl")
}

# check if 'x' is an object of class 'brmsopencl'
is.brmsopencl <- function(x) {
  inherits(x, what = "brmsopencl")
}

# validate the 'opencl' argument
# @param opencl NULL, a numeric vector of IDs, or a 'brmsopencl' object
# @return a 'brmsopencl' object
validate_opencl <- function(opencl) {
  if (is.null(opencl)) {
    # OpenCL not specified
    return(opencl())
  }
  if (is.numeric(opencl)) {
    # a plain numeric vector is taken as the OpenCL IDs
    return(opencl(opencl))
  }
  if (!is.brmsopencl(opencl)) {
    stop2("Argument 'opencl' needs to an integer vector or ",
          "specified via the 'opencl' function.")
  }
  opencl
}

# is OpenCL activated?
# @param opencl NULL, a numeric vector of IDs, or a 'brmsopencl' object
# @return TRUE if the validated object contains OpenCL IDs, FALSE otherwise
use_opencl <- function(opencl) {
  opencl <- validate_opencl(opencl)
  !is.null(opencl$ids)
}

# validate the 'silent' argument
# @param silent verbosity level; must be a single integer between 0 and 2
# @return the validated verbosity level
validate_silent <- function(silent) {
  silent <- as_one_integer(silent)
  in_range <- silent >= 0 && silent <= 2
  if (!in_range) {
    stop2("'silent' must be between 0 and 2.")
  }
  silent
}

# ensure that variable dimensions at the end are correctly written
# convert names like b.1.1 to b[1,1]
# @param x character vector of variable names
# @return character vector of the same length with repaired names
repair_variable_names <- function(x) {
  # the first dot opens the index, all further dots separate dimensions
  out <- sub(".", "[", x, fixed = TRUE)
  out <- gsub(".", ",", out, fixed = TRUE)
  # close the bracket of every name that has an index
  has_index <- grepl("[", out, fixed = TRUE)
  out[has_index] <- paste0(out[has_index], "]")
  out
}

# repair parameter names of stanfit objects
# @param x stanfit object
# @return stanfit object with unique parameter names
repair_stanfit <- function(x) {
  stopifnot(is.stanfit(x))
  if (!length(x@sim$fnames_oi)) {
    # nothing to rename
    return(x)
  }
  # the posterior package cannot deal with non-unique parameter names
  # this case happens rarely but might happen when sample_prior = "yes"
  unique_names <- make.unique(as.character(x@sim$fnames_oi), sep = "__")
  x@sim$fnames_oi <- unique_names
  n_names <- length(unique_names)
  for (chain in seq_along(x@sim$samples)) {
    # stanfit may have renamed dimension suffixes (#1218)
    # only rename if the number of names matches the number of variables
    if (length(x@sim$samples[[chain]]) == n_names) {
      names(x@sim$samples[[chain]]) <- unique_names
    }
  }
  x
}

# possible options for argument 'file_refit'
# @return character vector of option names
file_refit_options <- function() {
  options <- c("never", "always", "on_change")
  options
}

# canonicalize Stan model file in accordance with the current Stan version
# this function may no longer be needed due to rstan 2.26+ now being on CRAN
# for more details see https://github.com/paul-buerkner/brms/issues/1544
# .canonicalize_stan_model <- function(stan_file, overwrite_file = TRUE) {
#   cmdstan_mod <- cmdstanr::cmdstan_model(stan_file, compile = FALSE)
#   out <- utils::capture.output(
#     cmdstan_mod$format(
#       canonicalize = list("deprecations", "braces", "parentheses"),
#       overwrite_file = overwrite_file, backup = FALSE
#     )
#   )
#   paste0(out, collapse = "\n")
# }

#' Read CmdStan CSV files as a brms-formatted stanfit object
#'
#' \code{read_csv_as_stanfit} is used internally to read CmdStan CSV files into a
#' \code{stanfit} object that is consistent with the structure of the fit slot of a
#' brmsfit object.
#'
#' @param files Character vector of CSV files names where draws are stored.
#' @param variables Character vector of variables to extract from the CSV files.
#' @param sampler_diagnostics Character vector of sampler diagnostics to extract.
#' @param model A compiled cmdstanr model object (optional). Provide this argument
#'  if you want to allow updating the model without recompilation.
#' @param exclude Character vector of variables to exclude from the stanfit. Only
#'  used when \code{variables} is also specified.
#' @param algorithm The algorithm with which the model was fitted.
#'   See \code{\link{brm}} for details.
#'
#' @return A stanfit object consistent with the structure of the \code{fit}
#'  slot of a brmsfit object.
#'
#' @examples
#' \dontrun{
#' # fit a model manually via cmdstanr
#' scode <- stancode(count ~ Trt, data = epilepsy)
#' sdata <- standata(count ~ Trt, data = epilepsy)
#' mod <- cmdstanr::cmdstan_model(cmdstanr::write_stan_file(scode))
#' stanfit <- mod$sample(data = sdata)
#'
#' # feed the Stan model back into brms
#' fit <- brm(count ~ Trt, data = epilepsy, empty = TRUE, backend = 'cmdstanr')
#' fit$fit <- read_csv_as_stanfit(stanfit$output_files(), model = mod)
#' fit <- rename_pars(fit)
#' summary(fit)
#' }
#'
#' @export
read_csv_as_stanfit <- function(files, variables = NULL, sampler_diagnostics = NULL,
                                model = NULL, exclude = "", algorithm = "sampling") {
  require_package("cmdstanr")

  if (!is.null(variables)) {
    # ensure that only relevant variables are read from CSV
    variables <- repair_variable_names(variables)
    variables <- unique(sub("\\[.+", "", variables))
    variables <- setdiff(variables, exclude)
    # temp fix for cmdstanr not recognizing the variable names it produces #1473
    if (algorithm %in% c("meanfield", "fullrank")) {
      variables <- ifelse(variables == "lp_approx__", "log_g__", variables)
    } else if (algorithm %in% "pathfinder") {
      variables <- setdiff(variables, "lp_approx__")
    } else if (algorithm %in% "laplace") {
      variables <- setdiff(variables, c("lp__", "lp_approx__"))
    }
  }

  csfit <- cmdstanr::read_cmdstan_csv(
    files = files, variables = variables,
    sampler_diagnostics = sampler_diagnostics,
    format = NULL
  )

  # @model_name
  model_name = gsub(".csv", "", basename(files[[1]]))

  # @model_pars
  model_pars <- csfit$metadata$stan_variables
  if (!is.null(variables)) {
    model_pars <- intersect(model_pars, variables)
  }
  # special variables will be added back later on
  special_vars <- c("lp__", "lp_approx__")
  model_pars <- setdiff(model_pars, special_vars)

  # @par_dims
  par_dims <- vector("list", length(model_pars))
  names(par_dims) <- model_pars
  par_dims <- lapply(par_dims, function(x) integer(0))
  pdims_num <- ulapply(model_pars, function(x)
    sum(grepl(paste0("^", x, "\\[.*\\]$"), csfit$metadata$model_params))
  )
  par_dims[pdims_num != 0] <-
    csfit$metadata$stan_variable_sizes[model_pars][pdims_num != 0]

  # @mode
  mode <- 0L

  # @sim
  rstan_diagn_order <- c("accept_stat__", "treedepth__", "stepsize__",
                         "divergent__", "n_leapfrog__", "energy__")

  if (!is.null(sampler_diagnostics)) {
    rstan_diagn_order <- rstan_diagn_order[rstan_diagn_order %in% sampler_diagnostics]
  }

  res_vars <- c(".chain", ".iteration", ".draw")
  if ("post_warmup_draws" %in% names(csfit)) {
    # for MCMC samplers
    n_chains <- max(
      nchains(csfit$warmup_draws),
      nchains(csfit$post_warmup_draws)
    )
    n_iter_warmup <- niterations(csfit$warmup_draws)
    n_iter_sample <- niterations(csfit$post_warmup_draws)
    if (n_iter_warmup > 0) {
      csfit$warmup_draws <- as_draws_df(csfit$warmup_draws)
      csfit$warmup_sampler_diagnostics <-
        as_draws_df(csfit$warmup_sampler_diagnostics)
    }
    if (n_iter_sample > 0) {
      csfit$post_warmup_draws <- as_draws_df(csfit$post_warmup_draws)
      csfit$post_warmup_sampler_diagnostics <-
        as_draws_df(csfit$post_warmup_sampler_diagnostics)
    }

    # called 'samples' for consistency with rstan
    samples <- rbind(csfit$warmup_draws, csfit$post_warmup_draws)
    # manage memory
    csfit$warmup_draws <- NULL
    csfit$post_warmup_draws <- NULL

    # prepare sampler diagnostics
    diagnostics <- rbind(
      csfit$warmup_sampler_diagnostics,
      csfit$post_warmup_sampler_diagnostics
    )

    # manage memory
    csfit$warmup_sampler_diagnostics <- NULL
    csfit$post_warmup_sampler_diagnostics <- NULL
    # convert to regular data.frame
    diagnostics <- as.data.frame(diagnostics)
    diag_chain_ids <- diagnostics$.chain
    diagnostics[res_vars] <- NULL

  } else if ("draws" %in% names(csfit)) {
    # for variational inference "samplers"
    n_chains <- 1
    n_iter_warmup <- 0
    n_iter_sample <- niterations(csfit$draws)
    if (n_iter_sample > 0) {
      csfit$draws <- as_draws_df(csfit$draws)
    }

    # called 'samples' for consistency with rstan
    samples <- csfit$draws
    # manage memory
    csfit$draws <- NULL

    # VI has no sampler diagnostics
    diag_chain_ids <- rep(1L, nrow(samples))
    diagnostics <- as.data.frame(matrix(nrow = nrow(samples), ncol = 0))
  }

  # some diagnostics may be missing in the output depending on the algorithm
  rstan_diagn_order <- intersect(rstan_diagn_order, names(diagnostics))

  # convert to regular data.frame
  samples <- as.data.frame(samples)
  chain_ids <- samples$.chain
  samples[res_vars] <- NULL

  # only add special variables to dims if there are present in samples
  # this ensures that dims_oi, pars_oi, and fnames_oi match with samples
  for (p in special_vars) {
    if (p %in% colnames(samples)) {
      samples <- move2end(samples, p)
      par_dims[[p]] <- integer(0)
    }
  }
  model_pars <- names(par_dims)
  fnames_oi <- colnames(samples)

  # split samples into chains
  samples <- split(samples, chain_ids)
  names(samples) <- NULL

  # split diagnostics into chains
  diagnostics <- split(diagnostics, diag_chain_ids)
  names(diagnostics) <- NULL

  #  @sim$sample: largely 113-130 from rstan::read_stan_csv
  values <- list()
  values$algorithm <- csfit$metadata$algorithm
  values$engine <- csfit$metadata$engine
  values$metric <- csfit$metadata$metric

  sampler_t <- NULL
  if (!is.null(values$algorithm)) {
    if (values$algorithm == "rwm" || values$algorithm == "Metropolis") {
      sampler_t <- "Metropolis"
    } else if (values$algorithm == "hmc") {
      if (values$engine == "static") {
        sampler_t <- "HMC"
      } else {
        if (values$metric == "unit_e") {
          sampler_t <- "NUTS(unit_e)"
        } else if (values$metric == "diag_e") {
          sampler_t <- "NUTS(diag_e)"
        } else if (values$metric == "dense_e") {
          sampler_t <- "NUTS(dense_e)"
        }
      }
    }
  }

  adapt_info <- vector("list", 4)
  idx_samples <- (n_iter_warmup + 1):(n_iter_warmup + n_iter_sample)

  for (i in seq_along(samples)) {
    m <- colMeans(samples[[i]][idx_samples, , drop = FALSE])
    rownames(samples[[i]]) <- seq_rows(samples[[i]])
    attr(samples[[i]], "sampler_params") <- diagnostics[[i]][rstan_diagn_order]
    rownames(attr(samples[[i]], "sampler_params")) <- seq_rows(diagnostics[[i]])

    # reformat back to text
    if (length(csfit$inv_metric)) {
      if (is_equal(sampler_t, "NUTS(dense_e)")) {
        mmatrix_txt <- "\n# Elements of inverse mass matrix:\n# "
        mmat <- paste0(apply(csfit$inv_metric[[i]], 1, paste0, collapse = ", "),
                       collapse = "\n# ")
      } else {
        mmatrix_txt <- "\n# Diagonal elements of inverse mass matrix:\n# "
        mmat <- paste0(csfit$inv_metric[[i]], collapse = ", ")
      }

      adapt_info[[i]] <- paste0("# Step size = ",
                                csfit$step_size[[i]],
                                mmatrix_txt,
                                mmat, "\n# ")
      attr(samples[[i]], "adaptation_info") <- adapt_info[[i]]
    } else {
      attr(samples[[i]], "adaptation_info") <- character(0)
    }


    attr(samples[[i]], "args") <- list(sampler_t = sampler_t, chain_id = i)

    if (NROW(csfit$metadata$time)) {
      time_i <- as.double(csfit$metadata$time[i, c("warmup", "sampling")])
      names(time_i) <- c("warmup", "sample")
      attr(samples[[i]], "elapsed_time") <- time_i
    }

    attr(samples[[i]], "mean_pars") <- m[-length(m)]
    attr(samples[[i]], "mean_lp__") <- m["lp__"]
  }

  perm_lst <- lapply(seq_len(n_chains), function(id) sample.int(n_iter_sample))

  # @sim
  sim <- list(
    samples = samples,
    iter = csfit$metadata$iter_sampling + csfit$metadata$iter_warmup,
    thin = csfit$metadata$thin,
    warmup = csfit$metadata$iter_warmup,
    chains = n_chains,
    n_save = rep(n_iter_sample + n_iter_warmup, n_chains),
    warmup2 = rep(n_iter_warmup, n_chains),
    permutation = perm_lst,
    pars_oi = model_pars,
    dims_oi = par_dims,
    fnames_oi = fnames_oi,
    n_flatnames = length(fnames_oi)
  )

  # @stan_args
  sargs <- list(
    stan_version_major = as.character(csfit$metadata$stan_version_major),
    stan_version_minor = as.character(csfit$metadata$stan_version_minor),
    stan_version_patch = as.character(csfit$metadata$stan_version_patch),
    model = csfit$metadata$model_name,
    start_datetime = gsub(" ", "", csfit$metadata$start_datetime),
    method = csfit$metadata$method,
    iter = csfit$metadata$iter_sampling + csfit$metadata$iter_warmup,
    warmup = csfit$metadata$iter_warmup,
    save_warmup = csfit$metadata$save_warmup,
    thin = csfit$metadata$thin,
    engaged = as.character(csfit$metadata$adapt_engaged),
    gamma = csfit$metadata$gamma,
    delta = csfit$metadata$adapt_delta,
    kappa = csfit$metadata$kappa,
    t0 = csfit$metadata$t0,
    init_buffer = as.character(csfit$metadata$init_buffer),
    term_buffer = as.character(csfit$metadata$term_buffer),
    window = as.character(csfit$metadata$window),
    algorithm = csfit$metadata$algorithm,
    engine = csfit$metadata$engine,
    max_depth = csfit$metadata$max_treedepth,
    metric = csfit$metadata$metric,
    metric_file = character(0), # not stored in metadata
    stepsize = NA, # add in loop
    stepsize_jitter = csfit$metadata$stepsize_jitter,
    num_chains = as.character(csfit$metadata$num_chains),
    chain_id = NA, # add in loop
    file = character(0), # not stored in metadata
    init = NA, # add in loop
    seed = as.character(csfit$metadata$seed),
    file = NA, # add in loop
    diagnostic_file = character(0), # not stored in metadata
    refresh = as.character(csfit$metadata$refresh),
    sig_figs = as.character(csfit$metadata$sig_figs),
    profile_file = csfit$metadata$profile_file,
    num_threads = as.character(csfit$metadata$threads_per_chain),
    stanc_version = gsub(" ", "", csfit$metadata$stanc_version),
    stancflags = character(0), # not stored in metadata
    adaptation_info = NA, # add in loop
    has_time = is.numeric(csfit$metadata$time$total),
    time_info = NA, # add in loop
    sampler_t = sampler_t
  )

  sargs_rep <- replicate(n_chains, sargs, simplify = FALSE)

  for (i in seq_along(sargs_rep)) {
    sargs_rep[[i]]$chain_id <- i
    sargs_rep[[i]]$stepsize <- csfit$metadata$step_size[i]
    sargs_rep[[i]]$init <- as.character(csfit$metadata$init[i])
    # two 'file' elements: select the second
    file_idx <- which(names(sargs_rep[[i]]) == "file")
    sargs_rep[[i]][[file_idx[2]]] <- files[[i]]

    sargs_rep[[i]]$adaptation_info <- adapt_info[[i]]

    if (NROW(csfit$metadata$time)) {
      sargs_rep[[i]]$time_info <- paste0(
        c("#  Elapsed Time: ", "#                ", "#                ", "# "),
        c(csfit$metadata$time[i, c("warmup", "sampling", "total")], ""),
        c(" seconds (Warm-up)", " seconds (Sampling)", " seconds (Total)", "")
      )
    }
  }

  # @stanmodel
  cxxdso_class <- "cxxdso"
  attr(cxxdso_class, "package") <- "rstan"
  null_dso <- new(
    cxxdso_class, sig = list(character(0)), dso_saved = FALSE,
    dso_filename = character(0), modulename = character(0),
    system = R.version$system, cxxflags = character(0),
    .CXXDSOMISC = new.env(parent = emptyenv())
  )
  null_sm <- new(
    "stanmodel", model_name = model_name, model_code = character(0),
    model_cpp = list(), dso = null_dso
  )

  # @date
  sdate <- do.call(max, lapply(files, function(csv) file.info(csv)$mtime))
  sdate <- format(sdate, "%a %b %d %X %Y")

  out <- new(
    "stanfit",
    model_name = model_name,
    model_pars = model_pars,
    par_dims = par_dims,
    mode = mode,
    sim = sim,
    inits = list(),
    stan_args = sargs_rep,
    stanmodel = null_sm,
    date = sdate,  # not the time of sampling
    .MISC = new.env(parent = emptyenv())
  )

  attributes(out)$metadata <- csfit
  attributes(out)$CmdStanModel <- model
  out
}

# Try the brms package in your browser
#
# Any scripts or data that you put into this service are public.
#
# brms documentation built on Sept. 23, 2024, 5:08 p.m.