R/tar_stan_mcmc.R

Defines functions tar_stan_mcmc_run tar_stan_mcmc

Documented in tar_stan_mcmc tar_stan_mcmc_run

#' @title One MCMC per model with multiple outputs
#' @export
#' @description `tar_stan_mcmc()` creates targets to run one MCMC
#'   per model and separately save summaries, draws, and diagnostics.
#' @details Most of the arguments are passed to the `$compile()`,
#'  `$sample()`, and `$summary()` methods of the `CmdStanModel` class. If you
#'   previously compiled the model in an upstream [tar_stan_compile()]
#'   target, then the model should not recompile.
#' @family MCMC
#' @return `tar_stan_mcmc()` returns a list of target objects.
#'   See the "Target objects" section for
#'   background.
#'   The target names use the `name` argument as a prefix, and the individual
#'   elements of `stan_files` appear in the suffixes where applicable.
#'   As an example, the specific target objects returned by
#'   `tar_stan_mcmc(name = x, stan_files = "y.stan", ...)`
#'   are as follows.
#'   * `x_file_y`: reproducibly track the Stan model file. Returns
#'     a character vector with paths to the
#'     model file and compiled executable.
#'   * `x_lines_y`: read the Stan model file for safe transport to
#'     parallel workers. Omitted if `compile = "original"`.
#'     Returns a character vector of lines in the model file.
#'   * `x_data`: run the R expression in the `data` argument to produce
#'     a Stan dataset for the model. Returns a Stan data list.
#'   * `x_mcmc_y`: run MCMC on the model and the dataset.
#'     Returns a `cmdstanr` `CmdStanMCMC` object with all the results.
#'   * `x_draws_y`: extract draws from `x_mcmc_y`.
#'     Omitted if `return_draws = FALSE`.
#'     Returns a tidy data frame of draws.
#'   * `x_summary_y`: extract compact summaries from `x_mcmc_y`.
#'     Returns a tidy data frame of summaries.
#'     Omitted if `return_summary = FALSE`.
#'   * `x_diagnostics_y`: extract HMC diagnostics from `x_mcmc_y`.
#'     Returns a tidy data frame of HMC diagnostics.
#'     Omitted if `return_diagnostics = FALSE`.
#' @inheritSection tar_stan_compile Target objects
#' @inheritParams cmdstanr::cmdstan_model
#' @inheritParams cmdstanr::`fit-method-draws`
#' @inheritParams tar_stan_compile_run
#' @inheritParams tar_stan_mcmc_run
#' @inheritParams tar_stan_summary
#' @inheritParams targets::tar_target
#' @param name Symbol, base name for the collection of targets.
#'   Serves as a prefix for target names.
#' @param data Code to generate the `data` for the Stan model.
#' @param stan_files Character vector of Stan model files. If you
#'   supply multiple files, each model will run on the one shared dataset
#'   generated by the code in `data`. If you supply an unnamed vector,
#'   `fs::path_ext_remove(basename(stan_files))` will be used
#'   as target name suffixes. If `stan_files` is a named vector,
#'   the suffixes will come from `names(stan_files)`.
#' @param return_draws Logical, whether to create a target for posterior draws.
#'   Saves `posterior::as_draws_df(fit$draws())` to a compressed `tibble`.
#'   Convenient, but duplicates storage.
#' @param return_summary Logical, whether to create a target for
#'   `fit$summary()`.
#' @param return_diagnostics Logical, whether to create a target for
#'   the HMC sampler diagnostics. Saves
#'   `posterior::as_draws_df(fit$sampler_diagnostics())` to a
#'   compressed `tibble`. Convenient, but duplicates storage.
#' @param format Character of length 1, storage format of the non-data-frame
#'   targets such as the Stan data and any CmdStanFit objects.
#'   Please choose an all-purpose
#'   format such as `"qs"` or `"aws_qs"` rather than a file format like
#'   `"file"` or a data frame format like `"parquet"`. For more on storage
#'   formats, see the help file of `targets::tar_target()`.
#' @param format_df Character of length 1, storage format of the data frame
#'   targets such as posterior draws. We recommend efficient data frame formats
#'   such as `"feather"` or `"aws_parquet"`. For more on storage formats,
#'   see the help file of `targets::tar_target()`.
#' @param draws Deprecated on 2022-07-22. Use `return_draws` instead.
#' @param summary Deprecated on 2022-07-22. Use `return_summary` instead.
#' @param variables_fit Character vector of variables to include in the
#'   big `CmdStanFit` object returned by the model fit target.
#'   The `variables` argument, by contrast, is for the `"draws"` target only.
#'   The `"draws"` target can only access the variables in the `CmdStanFit`
#'   target. Control the variables in each with the `variables`
#'   and `variables_fit` arguments.
#' @param inc_warmup_fit Logical of length 1, whether to include
#'   warmup draws in the big MCMC object (the target with `"mcmc"` in the name).
#'   `inc_warmup` must not be `TRUE` if `inc_warmup_fit` is `FALSE`.
#' @examples
#' if (Sys.getenv("TAR_LONG_EXAMPLES") == "true") {
#' targets::tar_dir({ # tar_dir() runs code from a temporary directory.
#' targets::tar_script({
#' library(stantargets)
#' # Do not use temporary storage for stan files in real projects
#' # or else your targets will always rerun.
#' path <- tempfile(pattern = "", fileext = ".stan")
#' tar_stan_example_file(path = path)
#' list(
#'   tar_stan_mcmc(
#'     your_model,
#'     stan_files = path,
#'     data = tar_stan_example_data(),
#'     variables = "beta",
#'     summaries = list(~quantile(.x, probs = c(0.25, 0.75))),
#'     stdout = R.utils::nullfile(),
#'     stderr = R.utils::nullfile()
#'   )
#' )
#' }, ask = FALSE)
#' targets::tar_make()
#' })
#' }
tar_stan_mcmc <- function(
  name,
  stan_files,
  data = list(),
  compile = c("original", "copy"),
  quiet = TRUE,
  stdout = NULL,
  stderr = NULL,
  dir = NULL,
  pedantic = FALSE,
  include_paths = NULL,
  cpp_options = list(),
  stanc_options = list(),
  force_recompile = FALSE,
  seed = NULL,
  refresh = NULL,
  init = NULL,
  save_latent_dynamics = FALSE,
  output_dir = NULL,
  output_basename = NULL,
  sig_figs = NULL,
  chains = 4,
  parallel_chains = getOption("mc.cores", 1),
  chain_ids = seq_len(chains),
  threads_per_chain = NULL,
  opencl_ids = NULL,
  iter_warmup = NULL,
  iter_sampling = NULL,
  save_warmup = FALSE,
  thin = NULL,
  max_treedepth = NULL,
  adapt_engaged = TRUE,
  adapt_delta = NULL,
  step_size = NULL,
  metric = NULL,
  metric_file = NULL,
  inv_metric = NULL,
  init_buffer = NULL,
  term_buffer = NULL,
  window = NULL,
  fixed_param = FALSE,
  show_messages = TRUE,
  diagnostics = c("divergences", "treedepth", "ebfmi"),
  variables = NULL,
  variables_fit = NULL,
  inc_warmup = FALSE,
  inc_warmup_fit = FALSE,
  summaries = list(),
  summary_args = list(),
  return_draws = TRUE,
  return_diagnostics = TRUE,
  return_summary = TRUE,
  draws = NULL,
  summary = NULL,
  tidy_eval = targets::tar_option_get("tidy_eval"),
  packages = targets::tar_option_get("packages"),
  library = targets::tar_option_get("library"),
  format = "qs",
  format_df = "fst_tbl",
  repository = targets::tar_option_get("repository"),
  error = targets::tar_option_get("error"),
  memory = targets::tar_option_get("memory"),
  garbage_collection = targets::tar_option_get("garbage_collection"),
  deployment = targets::tar_option_get("deployment"),
  priority = targets::tar_option_get("priority"),
  resources = targets::tar_option_get("resources"),
  storage = targets::tar_option_get("storage"),
  retrieval = targets::tar_option_get("retrieval"),
  cue = targets::tar_option_get("cue"),
  description = targets::tar_option_get("description")
) {
  # Builds (but does not run) the file, lines, data, MCMC, draws, summary,
  # and diagnostics target objects, then hands them to
  # tar_stan_target_list() for final assembly.

  # ---- Input validation ----
  # Check lengths before the cross-argument check so the latter only ever
  # receives scalars (NOTE(review): assert_inc_warmup_fit() is defined
  # elsewhere; presumably it uses both values in a scalar condition).
  targets::tar_assert_scalar(inc_warmup)
  targets::tar_assert_scalar(inc_warmup_fit)
  assert_variables_fit(variables, variables_fit)
  assert_inc_warmup_fit(inc_warmup, inc_warmup_fit)
  # The deprecated `draws` / `summary` arguments, when supplied, take
  # precedence over their `return_*` replacements.
  tar_stan_deprecate(draws, "return_draws")
  tar_stan_deprecate(summary, "return_summary")
  return_draws <- draws %|||% return_draws
  return_summary <- summary %|||% return_summary
  envir <- targets::tar_option_get("envir")
  compile <- match.arg(compile)
  targets::tar_assert_chr(stan_files)
  targets::tar_assert_unique(stan_files)
  lapply(stan_files, assert_stan_file)

  # ---- Target names and symbols ----
  # `name` arrives as an unevaluated symbol; every target name below is
  # that base name plus a role-specific suffix.
  name <- targets::tar_deparse_language(substitute(name))
  name_stan <- produce_stan_names(stan_files)
  name_file <- paste0(name, "_file")
  name_lines <- paste0(name, "_lines")
  name_data <- paste0(name, "_data")
  name_mcmc <- paste0(name, "_mcmc")
  name_draws <- paste0(name, "_draws")
  name_summary <- paste0(name, "_summary")
  name_diagnostics <- paste0(name, "_diagnostics")
  sym_stan <- as_symbols(name_stan)
  sym_file <- as.symbol(name_file)
  sym_lines <- as.symbol(name_lines)
  sym_data <- as.symbol(name_data)
  sym_mcmc <- as.symbol(name_mcmc)

  # ---- Commands (unevaluated R code for each target) ----
  command_data <- targets::tar_tidy_eval(
    substitute(data),
    envir = envir,
    tidy_eval = tidy_eval
  )
  # The values of `variables` and `inc_warmup` are spliced into the
  # command, and `fit` is replaced by the MCMC target's symbol.
  command_draws <- substitute(
    tibble::as_tibble(posterior::as_draws_df(
      fit$draws(variables = variables, inc_warmup = inc_warmup)
    )),
    env = list(
      fit = sym_mcmc,
      variables = variables,
      inc_warmup = inc_warmup
    )
  )
  command_summary <- tar_stan_summary_call(
    sym_fit = sym_mcmc,
    sym_data = sym_data,
    summaries = substitute(summaries),
    summary_args = substitute(summary_args),
    variables = variables
  )
  command_diagnostics <- substitute(
    tibble::as_tibble(
      posterior::as_draws_df(.targets_mcmc$sampler_diagnostics())
    ),
    env = list(.targets_mcmc = sym_mcmc)
  )
  # Call to stantargets::tar_stan_mcmc_run(). With compile = "original"
  # the worker receives the file target; otherwise it receives the lines
  # target. `variables_fit` / `inc_warmup_fit` control what is loaded into
  # the fit object itself.
  args_mcmc <- list(
    call_ns("stantargets", "tar_stan_mcmc_run"),
    stan_file = if_any(identical(compile, "original"), sym_file, sym_lines),
    data = sym_data,
    compile = compile,
    quiet = quiet,
    stdout = stdout,
    stderr = stderr,
    dir = dir,
    pedantic = pedantic,
    include_paths = include_paths,
    cpp_options = cpp_options,
    stanc_options = stanc_options,
    force_recompile = force_recompile,
    seed = seed,
    refresh = refresh,
    init = init,
    save_latent_dynamics = save_latent_dynamics,
    output_dir = output_dir,
    output_basename = output_basename,
    sig_figs = sig_figs,
    chains = chains,
    parallel_chains = parallel_chains,
    chain_ids = chain_ids,
    threads_per_chain = threads_per_chain,
    opencl_ids = opencl_ids,
    iter_warmup = iter_warmup,
    iter_sampling = iter_sampling,
    save_warmup = save_warmup,
    thin = thin,
    max_treedepth = max_treedepth,
    adapt_engaged = adapt_engaged,
    adapt_delta = adapt_delta,
    step_size = step_size,
    metric = metric,
    metric_file = metric_file,
    inv_metric = inv_metric,
    init_buffer = init_buffer,
    term_buffer = term_buffer,
    window = window,
    fixed_param = fixed_param,
    show_messages = show_messages,
    diagnostics = diagnostics,
    variables = variables_fit,
    inc_warmup = inc_warmup_fit
  )
  command_mcmc <- as.expression(as.call(args_mcmc))

  # ---- Target objects ----
  # The file and lines targets are pinned to the main process
  # (deployment = "main"); the file target is also pinned to local storage.
  # The file command is a placeholder symbol, presumably substituted per
  # Stan file by tar_stan_target_list() (TODO confirm).
  target_file <- targets::tar_target_raw(
    name = name_file,
    command = quote(._stantargets_file_50e43091),
    packages = character(0),
    format = "file",
    repository = "local",
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = "main",
    priority = priority,
    cue = cue,
    description = description
  )
  target_lines <- targets::tar_target_raw(
    name = name_lines,
    command = command_lines(sym_file),
    packages = character(0),
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = "main",
    priority = priority,
    cue = cue,
    description = description
  )
  # The data command is user code, so it gets the user's packages/library.
  target_data <- targets::tar_target_raw(
    name = name_data,
    command = command_data,
    packages = packages,
    library = library,
    format = format,
    repository = repository,
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = deployment,
    priority = priority,
    cue = cue,
    description = description
  )
  # The model fit is the only target that receives `resources`,
  # `storage`, and `retrieval`.
  target_output <- targets::tar_target_raw(
    name = name_mcmc,
    command = command_mcmc,
    format = format,
    repository = repository,
    packages = character(0),
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = deployment,
    priority = priority,
    resources = resources,
    storage = storage,
    retrieval = retrieval,
    cue = cue,
    description = description
  )
  target_draws <- targets::tar_target_raw(
    name = name_draws,
    command = command_draws,
    packages = character(0),
    format = format_df,
    repository = repository,
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = deployment,
    priority = priority,
    cue = cue,
    description = description
  )
  # The summary target loads the user's `packages`, so it must also honor
  # the user's `library` paths, just like the data target does.
  target_summary <- targets::tar_target_raw(
    name = name_summary,
    command = command_summary,
    packages = packages,
    library = library,
    format = format_df,
    repository = repository,
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = deployment,
    priority = priority,
    cue = cue,
    description = description
  )
  target_diagnostics <- targets::tar_target_raw(
    name = name_diagnostics,
    command = command_diagnostics,
    packages = character(0),
    format = format_df,
    repository = repository,
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = deployment,
    priority = priority,
    cue = cue,
    description = description
  )

  # ---- Assemble the final list of targets ----
  tar_stan_target_list(
    name_data = name_data,
    stan_files = stan_files,
    sym_stan = sym_stan,
    compile = compile,
    return_draws = return_draws,
    return_summary = return_summary,
    return_diagnostics = return_diagnostics,
    target_file = target_file,
    target_lines = target_lines,
    target_data = target_data,
    target_output = target_output,
    target_draws = target_draws,
    target_summary = target_summary,
    target_diagnostics = target_diagnostics
  )
}

#' @title Compile and run a Stan model and return the `CmdStanFit` object.
#' @export
#' @keywords internal
#' @description Not a user-side function. Do not invoke directly.
#' @return A `CmdStanFit` object.
#' @inheritParams tar_stan_compile
#' @inheritParams cmdstanr::cmdstan_model
#' @inheritParams cmdstanr::`model-method-sample`
#' @param compile Character of length 1. If `"original"`, then
#'   `cmdstan` will compile the source file right before running
#'   it (or skip compilation if the binary is up to date). This
#'   assumes the worker has access to the file. If the worker
#'   is running on a remote computer that does not have access
#'   to the model file, set to `"copy"` instead. `compile = "copy"`
#'   means the pipeline will read the lines of the original Stan model file
#'   and send them to the worker. The worker writes the lines
#'   to a local copy and compiles the model from there, so it
#'   no longer needs access to the original Stan model file on your
#'   local machine. However, as a result, the Stan model re-compiles
#'   every time the main target reruns.
tar_stan_mcmc_run <- function(
  stan_file,
  data,
  compile,
  quiet,
  stdout,
  stderr,
  dir,
  pedantic,
  include_paths,
  cpp_options,
  stanc_options,
  force_recompile,
  seed,
  refresh,
  init,
  save_latent_dynamics,
  output_dir,
  output_basename,
  sig_figs,
  chains,
  parallel_chains,
  chain_ids,
  threads_per_chain,
  opencl_ids,
  iter_warmup,
  iter_sampling,
  save_warmup,
  thin,
  max_treedepth,
  adapt_engaged,
  adapt_delta,
  step_size,
  metric,
  metric_file,
  inv_metric,
  init_buffer,
  term_buffer,
  window,
  fixed_param,
  show_messages,
  diagnostics,
  variables,
  inc_warmup
) {
  # Worker-side routine: compile the model, sample from it, load the
  # results into the fit object, and return that object.

  # Divert printed output and messages to the requested destinations.
  # These withr::local_*() calls must stay directly in this function's
  # frame so the sinks are tied to this call.
  if (!is.null(stdout)) {
    withr::local_output_sink(new = stdout, append = TRUE)
  }
  if (!is.null(stderr)) {
    withr::local_message_sink(new = stderr, append = TRUE)
  }

  # With compile = "copy", `stan_file` holds the model's lines rather than
  # a path, so write them to a fresh temporary .stan file and compile that.
  model_path <- stan_file
  if (identical(compile, "copy")) {
    model_path <- tempfile(pattern = "", fileext = ".stan")
    writeLines(stan_file, model_path)
  }
  stan_model <- cmdstanr::cmdstan_model(
    stan_file = model_path,
    compile = TRUE,
    quiet = quiet,
    dir = dir,
    pedantic = pedantic,
    include_paths = include_paths,
    cpp_options = cpp_options,
    stanc_options = stanc_options,
    force_recompile = force_recompile
  )

  # Without an explicit seed, derive one from the seed that targets
  # reports for the running target.
  if (is.null(seed)) {
    seed <- abs(targets::tar_seed_get()) + 1L
  }

  # Drop the `.join_data` element before handing the data to Stan.
  data$.join_data <- NULL
  fit <- stan_model$sample(
    data = data,
    seed = seed,
    refresh = refresh,
    init = init,
    save_latent_dynamics = save_latent_dynamics,
    output_dir = output_dir,
    output_basename = output_basename,
    sig_figs = sig_figs,
    chains = chains,
    parallel_chains = parallel_chains,
    chain_ids = chain_ids,
    threads_per_chain = threads_per_chain,
    opencl_ids = opencl_ids,
    iter_warmup = iter_warmup,
    iter_sampling = iter_sampling,
    save_warmup = save_warmup,
    thin = thin,
    max_treedepth = max_treedepth,
    adapt_engaged = adapt_engaged,
    adapt_delta = adapt_delta,
    step_size = step_size,
    metric = metric,
    metric_file = metric_file,
    inv_metric = inv_metric,
    init_buffer = init_buffer,
    term_buffer = term_buffer,
    window = window,
    fixed_param = fixed_param,
    show_messages = show_messages,
    diagnostics = diagnostics
  )

  # Load all the data and return the whole unserialized fit object:
  # https://github.com/stan-dev/cmdstanr/blob/d27994f804c493ff3047a2a98d693fa90b83af98/R/fit.R#L16-L18 # nolint
  # Draws must load successfully; the remaining pieces are best-effort,
  # so their errors are silenced.
  fit$draws(variables = variables, inc_warmup = inc_warmup)
  best_effort_loaders <- list(
    function() fit$sampler_diagnostics(inc_warmup = inc_warmup),
    function() fit$init(),
    function() fit$profiles()
  )
  for (loader in best_effort_loaders) {
    try(loader(), silent = TRUE)
  }
  remove_temp_files(fit)
  fit
}
ropensci/stantargets documentation built on Feb. 8, 2025, 10:34 p.m.