R/tar_jags.R

Defines functions tar_jags_run tar_jags

Documented in tar_jags tar_jags_run

#' @title One MCMC per model with multiple outputs
#' @export
#' @description Targets to run a JAGS model once with MCMC
#'   and save multiple outputs.
#' @details The MCMC targets use `R2jags::jags()` if `n.cluster` is `1` and
#'   `R2jags::jags.parallel()` otherwise. Most arguments to `tar_jags()`
#'   are forwarded to these functions.
#' @return `tar_jags()` returns list of target objects.
#'   See the "Target objects" section for
#'   background.
#'   The target names use the `name` argument as a prefix, and the individual
#'   elements of `jags_files` appear in the suffixes where applicable.
#'   As an example,
#'   `tar_jags(name = x, jags_files = "y.jags", ...)` returns a list
#'   of `targets::tar_target()` objects:
#'   * `x_file_y`: reproducibly track the JAGS model file. Returns
#'     a character vector of length 1 with the path to the JAGS
#'     model file.
#'   * `x_lines_y`: read the contents of the JAGS model file
#'     for safe transport to parallel workers.
#'     Returns a character vector of lines in the model file.
#'   * `x_data`: run the R expression in the `data` argument to produce
#'     a JAGS dataset for the model. Returns a JAGS data list.
#'   * `x_mcmc_y`: run MCMC on the model and dataset.
#'     Returns an `rjags` object from `R2jags` with all the MCMC results.
#'   * `x_draws_y`: extract posterior samples from `x_mcmc_y`.
#'     Returns a tidy data frame of MCMC draws. Omitted if `draws = FALSE`.
#'   * `x_summary_y`: extract posterior summaries from `x_mcmc_y`.
#'     Returns a tidy data frame of posterior summaries.
#'     Omitted if `summary = FALSE`.
#'   * `x_dic_y`: extract deviance information criterion (DIC) info
#'     from `x_mcmc_y`. Returns a tidy data frame of DIC info.
#'     Omitted if `dic = FALSE`.
#' @section Target objects:
#'   Most `jagstargets` functions are target factories,
#'   which means they return target objects
#'   or lists of target objects.
#'   Target objects represent skippable steps of the analysis pipeline
#'   as described at <https://books.ropensci.org/targets/>.
#'   Please read the walkthrough at
#'   <https://books.ropensci.org/targets/walkthrough.html>
#'   to understand the role of target objects in analysis pipelines.
#'
#'   For developers,
#'   <https://wlandau.github.io/targetopia/contributing.html#target-factories>
#'   explains target factories (functions like this one which generate targets)
#'   and the design specification at
#'   <https://books.ropensci.org/targets-design/>
#'   details the structure and composition of target objects.
#' @inheritParams tar_jags_run
#' @inheritParams tar_jags_df
#' @inheritParams targets::tar_target
#' @param name Symbol, base name for the collection of targets.
#'   Serves as a prefix for target names.
#' @param data Code to generate the `data` list for the JAGS model.
#'   Optionally include a `.join_data` element to join parts of the data
#'   to correspondingly named parameters in the summary output.
#'   See the vignettes for details.
#' @param jags_files Character vector of JAGS model files. If you
#'   supply multiple files, each model will run on the one shared dataset
#'   generated by the code in `data`. If you supply an unnamed vector,
#'   `tools::file_path_sans_ext(basename(jags_files))` will be used
#'   as target name suffixes. If `jags_files` is a named vector,
#'   the suffixes will come from `names(jags_files)`.
#' @param draws Logical, whether to create a target for posterior draws.
#'   Saves draws as a compressed `posterior::as_draws_df()` `tibble`.
#'   Convenient, but duplicates storage.
#' @param summary Logical, whether to create a target to store a small
#'   data frame of posterior summary statistics and convergence diagnostics.
#' @param dic Logical, whether to create a target with deviance
#'   information criterion (DIC) results.
#' @param format Character of length 1, storage format of the non-data-frame
#'   targets such as the JAGS data and any JAGS fit objects.
#'   Please choose an all-purpose
#'   format such as `"qs"` or `"aws_qs"` rather than a file format like
#'   `"file"` or a data frame format like `"parquet"`. For more on storage
#'   formats, see the help file of `targets::tar_target()`.
#' @param format_df Character of length 1, storage format of the data frame
#'   targets such as posterior draws. We recommend efficient data frame formats
#'   such as `"feather"` or `"aws_parquet"`. For more on storage formats,
#'   see the help file of `targets::tar_target()`.
#' @examples
#' if (requireNamespace("R2jags", quietly = TRUE)) {
#' targets::tar_dir({ # tar_dir() runs code from a temporary directory.
#' targets::tar_script({
#' library(jagstargets)
#' # Do not use a temp file for a real project
#' # or else your targets will always rerun.
#' tmp <- tempfile(pattern = "", fileext = ".jags")
#' tar_jags_example_file(tmp)
#' list(
#'   tar_jags(
#'     your_model,
#'     jags_files = tmp,
#'     data = tar_jags_example_data(),
#'     parameters.to.save = "beta",
#'     stdout = R.utils::nullfile(),
#'     stderr = R.utils::nullfile()
#'   )
#' )
#' }, ask = FALSE)
#' targets::tar_make()
#' })
#' }
tar_jags <- function(
  name,
  jags_files,
  parameters.to.save,
  data = list(),
  summaries = list(),
  summary_args = list(),
  n.cluster = 1,
  n.chains = 3,
  n.iter = 2e3,
  n.burnin = as.integer(n.iter / 2),
  n.thin = 1,
  jags.module = c("glm", "dic"),
  inits = NULL,
  RNGname = c(
    "Wichmann-Hill",
    "Marsaglia-Multicarry",
    "Super-Duper",
    "Mersenne-Twister"
  ),
  jags.seed = 1,
  stdout = NULL,
  stderr = NULL,
  progress.bar = "text",
  refresh = 0,
  draws = TRUE,
  summary = TRUE,
  dic = TRUE,
  tidy_eval = targets::tar_option_get("tidy_eval"),
  packages = targets::tar_option_get("packages"),
  library = targets::tar_option_get("library"),
  format = "qs",
  format_df = "fst_tbl",
  repository = targets::tar_option_get("repository"),
  error = targets::tar_option_get("error"),
  memory = targets::tar_option_get("memory"),
  garbage_collection = targets::tar_option_get("garbage_collection"),
  deployment = targets::tar_option_get("deployment"),
  priority = targets::tar_option_get("priority"),
  resources = targets::tar_option_get("resources"),
  storage = targets::tar_option_get("storage"),
  retrieval = targets::tar_option_get("retrieval"),
  cue = targets::tar_option_get("cue"),
  description = targets::tar_option_get("description")
) {
  # Validate the toolchain and the user inputs before building any targets.
  targets::tar_assert_package("rjags")
  targets::tar_assert_package("R2jags")
  envir <- targets::tar_option_get("envir")
  targets::tar_assert_chr(jags_files)
  targets::tar_assert_unique(jags_files)
  lapply(jags_files, assert_jags_file)
  targets::tar_assert_in(
    as.integer(n.cluster),
    as.integer(c(1L, n.chains)),
    msg = "due to R2jags constraints, n.cluster must be 1 or n.chains."
  )
  # Base target names: `name` is the shared prefix. The per-model suffixes
  # are added later by tarchetypes::tar_map().
  name <- targets::tar_deparse_language(substitute(name))
  name_jags <- produce_jags_names(jags_files)
  name_file <- paste0(name, "_file")
  name_lines <- paste0(name, "_lines")
  name_data <- paste0(name, "_data")
  name_mcmc <- paste0(name, "_mcmc")
  name_draws <- paste0(name, "_draws")
  name_summary <- paste0(name, "_summary")
  name_dic <- paste0(name, "_dic")
  # Symbols let the generated commands refer to upstream targets by name.
  sym_jags <- as_symbols(name_jags)
  sym_file <- as.symbol(name_file)
  sym_lines <- as.symbol(name_lines)
  sym_data <- as.symbol(name_data)
  sym_mcmc <- as.symbol(name_mcmc)
  # Command of the "lines" target: read the tracked model file.
  command_lines <- call_function(
    "readLines",
    args = list(con = sym_file)
  )
  # Command of the "data" target: the user's `data` code, captured unevaluated.
  command_data <- targets::tar_tidy_eval(
    substitute(data),
    envir = envir,
    tidy_eval = tidy_eval
  )
  # Commands of the data frame targets, all derived from the MCMC target.
  command_draws <- substitute(
    jagstargets::tar_jags_df(fit = fit, data = data, output = "draws"),
    env = list(fit = sym_mcmc, data = sym_data)
  )
  # `summaries` and `summary_args` are spliced in unevaluated and wrapped in
  # quote() inside the generated command.
  command_summary <- substitute(
    jagstargets::tar_jags_df(
      fit,
      data = data,
      output = "summary",
      summaries = quote(summaries),
      summary_args = quote(summary_args)
    ),
    env = list(
      fit = sym_mcmc,
      data = sym_data,
      summaries = substitute(summaries),
      summary_args = substitute(summary_args)
    )
  )
  command_dic <- substitute(
    jagstargets::tar_jags_df(fit = fit, data = data, output = "dic"),
    env = list(fit = sym_mcmc, data = sym_data)
  )
  # Command of the MCMC target: a call to jagstargets::tar_jags_run() with
  # the MCMC settings embedded as literal argument values.
  args_mcmc <- list(
    call_ns("jagstargets", "tar_jags_run"),
    jags_lines = sym_lines,
    parameters.to.save = parameters.to.save,
    data = sym_data,
    inits = inits,
    n.cluster = n.cluster,
    n.chains = n.chains,
    n.iter = n.iter,
    n.burnin = n.burnin,
    n.thin = n.thin,
    jags.module = jags.module,
    RNGname = RNGname,
    jags.seed = jags.seed,
    stdout = stdout,
    stderr = stderr,
    progress.bar = progress.bar,
    refresh = refresh
  )
  command_mcmc <- as.expression(as.call(args_mcmc))
  # The command of the file target is a placeholder symbol;
  # tarchetypes::tar_map() below substitutes each element of `jags_files`.
  target_object_file <- targets::tar_target_raw(
    name = name_file,
    command = quote(._jagstargets_file_50e43091),
    packages = character(0),
    format = "file",
    repository = "local",
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = "main",
    priority = priority,
    cue = cue,
    description = description
  )
  target_object_lines <- targets::tar_target_raw(
    name = name_lines,
    command = command_lines,
    packages = character(0),
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = "main",
    priority = priority,
    cue = cue,
    description = description
  )
  # Only the data target receives the user's `packages` and `library`
  # because it is the only target that runs user-supplied code.
  target_object_data <- targets::tar_target_raw(
    name = name_data,
    command = command_data,
    packages = packages,
    library = library,
    format = format,
    repository = repository,
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = deployment,
    priority = priority,
    cue = cue,
    description = description
  )
  target_object_mcmc <- targets::tar_target_raw(
    name = name_mcmc,
    command = command_mcmc,
    format = format,
    repository = repository,
    packages = character(0),
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = deployment,
    priority = priority,
    resources = resources,
    storage = storage,
    retrieval = retrieval,
    cue = cue,
    description = description
  )
  target_object_draws <- targets::tar_target_raw(
    name = name_draws,
    command = command_draws,
    packages = character(0),
    format = format_df,
    repository = repository,
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = deployment,
    priority = priority,
    cue = cue,
    description = description
  )
  target_object_summary <- targets::tar_target_raw(
    name = name_summary,
    command = command_summary,
    packages = character(0),
    format = format_df,
    repository = repository,
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = deployment,
    priority = priority,
    cue = cue,
    description = description
  )
  target_object_dic <- targets::tar_target_raw(
    name = name_dic,
    command = command_dic,
    packages = character(0),
    format = format_df,
    repository = repository,
    error = error,
    memory = memory,
    garbage_collection = garbage_collection,
    deployment = deployment,
    priority = priority,
    cue = cue,
    description = description
  )
  # Targets to map over the model files. The optional outputs are included
  # only when their flag is exactly TRUE.
  out <- list(
    target_object_file,
    target_object_lines,
    target_object_mcmc,
    if_any(identical(draws, TRUE), target_object_draws, NULL),
    if_any(identical(summary, TRUE), target_object_summary, NULL),
    if_any(identical(dic, TRUE), target_object_dic, NULL)
  )
  out <- list_nonempty(out)
  values <- list(
    ._jagstargets_file_50e43091 = jags_files,
    ._jagstargets_name_50e43091 = sym_jags
  )
  out <- tarchetypes::tar_map(
    values = values,
    names = tidyselect::any_of("._jagstargets_name_50e43091"),
    descriptions = tidyselect::any_of("._jagstargets_file_50e43091"),
    unlist = TRUE,
    out
  )
  # The data target is added after the mapping step, so there is a single
  # `<name>_data` target shared by all the models.
  out[[name_data]] <- target_object_data
  out
}

#' @title Run a JAGS model and return the whole output object.
#' @export
#' @keywords internal
#' @description Not a user-side function. Do not invoke directly.
#' @return An `R2jags` output object.
#' @param jags_lines Character vector of lines from a JAGS model file.
#' @param stdout Character of length 1, file path to write the stdout stream
#'   of the model when it runs. Set to `NULL` to print to the console.
#'   Set to `R.utils::nullfile()` to suppress stdout.
#'   Does not apply to messages, warnings, or errors.
#' @param stderr Character of length 1, file path to write the stderr stream
#'   of the model when it runs. Set to `NULL` to print to the console.
#'   Set to `R.utils::nullfile()` to suppress stderr.
#'   Does not apply to messages, warnings, or errors.
#' @param parameters.to.save Model parameters to save, passed to
#'   `R2jags::jags()`. See the argument documentation of the
#'   `R2jags::jags()` help file for details.
#' @param n.cluster Number of parallel processes, passed to
#'   `R2jags::jags()` or `R2jags::jags.parallel()`.
#'   See the argument documentation of the
#'   `R2jags::jags()` and `R2jags::jags.parallel()` help files for details.
#' @param n.chains Number of MCMC chains, passed to
#'   `R2jags::jags()` or `R2jags::jags.parallel()`.
#'   See the argument documentation of the
#'   `R2jags::jags()` and `R2jags::jags.parallel()` help files for details.
#' @param n.iter Number of iterations (including warmup), passed to
#'   `R2jags::jags()` or `R2jags::jags.parallel()`.
#'   See the argument documentation of the
#'   `R2jags::jags()` and `R2jags::jags.parallel()` help files for details.
#' @param n.burnin Number of warmup iterations, passed to
#'   `R2jags::jags()` or `R2jags::jags.parallel()`.
#'   See the argument documentation of the
#'   `R2jags::jags()` and `R2jags::jags.parallel()` help files for details.
#' @param n.thin Thinning interval, passed to
#'   `R2jags::jags()` or `R2jags::jags.parallel()`.
#'   See the argument documentation of the
#'   `R2jags::jags()` and `R2jags::jags.parallel()` help files for details.
#' @param jags.module Character vector of JAGS modules to load, passed to
#'   `R2jags::jags()` or `R2jags::jags.parallel()`.
#'   See the argument documentation of the
#'   `R2jags::jags()` and `R2jags::jags.parallel()` help files for details.
#' @param parameters.to.save Model parameters to save, passed to
#'   `R2jags::jags()` or `R2jags::jags.parallel()`.
#'   See the argument documentation of the
#'   `R2jags::jags()` and `R2jags::jags.parallel()` help files for details.
#' @param inits Initial values of model parameters, passed to
#'   `R2jags::jags()` or `R2jags::jags.parallel()`.
#'   See the argument documentation of the
#'   `R2jags::jags()` and `R2jags::jags.parallel()` help files for details.
#' @param RNGname Choice of random number generator, passed to
#'   `R2jags::jags()` or `R2jags::jags.parallel()`.
#'   See the argument documentation of the
#'   `R2jags::jags()` and `R2jags::jags.parallel()` help files for details.
#' @param jags.seed Seeds to apply to JAGS, passed to
#'   `R2jags::jags()` or `R2jags::jags.parallel()`.
#'   See the argument documentation of the
#'   `R2jags::jags()` and `R2jags::jags.parallel()` help files for details.
#' @param progress.bar Type of progress bar, passed to
#'   `R2jags::jags()` or `R2jags::jags.parallel()`.
#'   See the argument documentation of the
#'   `R2jags::jags()` and `R2jags::jags.parallel()` help files for details.
#' @param refresh Frequency for refreshing the progress bar, passed to
#'   `R2jags::jags()` or `R2jags::jags.parallel()`.
#'   See the argument documentation of the
#'   `R2jags::jags()` and `R2jags::jags.parallel()` help files for details.
tar_jags_run <- function(
  jags_lines,
  parameters.to.save,
  data,
  inits,
  n.cluster,
  n.chains,
  n.iter,
  n.burnin,
  n.thin,
  jags.module,
  RNGname,
  jags.seed,
  stdout,
  stderr,
  progress.bar,
  refresh
) {
  targets::tar_assert_package("rjags")
  targets::tar_assert_package("R2jags")
  # Load coda before any sink is active so its startup messages are not
  # redirected.
  requireNamespace("coda")
  # Open the sinks before changing the working directory: relative
  # `stdout`/`stderr` paths then resolve against the caller's directory
  # rather than the throwaway directory created below.
  if (!is.null(stdout)) {
    withr::local_output_sink(new = stdout, append = TRUE)
  }
  if (!is.null(stderr)) {
    withr::local_message_sink(new = stderr, append = TRUE)
  }
  # Run inside a scratch directory. Both the directory and the model file are
  # removed automatically when this function exits, so repeated runs on a
  # long-lived worker do not accumulate temporary files.
  tmp <- withr::local_tempdir()
  withr::local_dir(tmp)
  file <- withr::local_tempfile(pattern = "", fileext = ".jags")
  writeLines(jags_lines, file)
  # Passed to R2jags::jags.parallel() as its `envir` argument.
  envir <- environment()
  withr::local_seed(jags.seed)
  lapply(jags.module, rjags::load.module, quiet = TRUE)
  # `.join_data` is only used for the summary output (see the `data` argument
  # of tar_jags()), so drop it before handing the data to JAGS.
  jags_data <- data
  jags_data$.join_data <- NULL
  if_any(
    n.cluster > 1L,
    R2jags::jags.parallel(
      data = jags_data,
      inits = inits,
      parameters.to.save = parameters.to.save,
      model.file = file,
      n.chains = n.chains,
      n.iter = n.iter,
      n.burnin = n.burnin,
      n.thin = n.thin,
      n.cluster = n.cluster,
      DIC = TRUE,
      jags.seed = jags.seed,
      RNGname = RNGname,
      jags.module = jags.module,
      envir = envir
    ),
    R2jags::jags(
      data = jags_data,
      inits = inits,
      parameters.to.save = parameters.to.save,
      model.file = file,
      n.chains = n.chains,
      n.iter = n.iter,
      n.burnin = n.burnin,
      n.thin = n.thin,
      DIC = TRUE,
      jags.seed = jags.seed,
      refresh = refresh,
      progress.bar = progress.bar,
      RNGname = RNGname,
      jags.module = jags.module
    )
  )
}
wlandau/jagstargets documentation built on April 19, 2024, 8:23 p.m.