# R/fit_model.R
#
# Defines functions fit_model
#
# Documented in fit_model

#' Fit a mark-recapture model
#'
#' Fits the multistate mark-recapture model bundled with mrmr
#' (`twostate.stan` in the package's `extdata` directory) using cmdstanr.
#' Most arguments are passed straight through to the `$sample()` method of
#' the compiled cmdstanr model, providing fine-grained control over
#' posterior sampling.
#'
#' To use within-chain parallelization, provide both of the following
#' arguments:
#'
#' 1. `cpp_options = list(stan_threads = TRUE)` to compile with threading
#'    support.
#'
#' 2. `threads_per_chain` set to an integer value > 1.
#'
#' @param data A list of data generated by the function
#' \code{mrmr::clean_data()}. It must contain a `stan_d` element, which is
#' passed to Stan as the model data.
#' @param ... Additional arguments passed to [cmdstanr::cmdstan_model()]
#' (for example `cpp_options`).
#' @inheritParams cmdstanr::cmdstan_model
#' @inheritParams cmdstanr::`model-method-sample`
#' @return A list with two elements:
#' \describe{
#'   \item{`m_fit`}{The cmdstanr fit object returned by `$sample()`. Its
#'   posterior draws, and where available its sampler diagnostics, initial
#'   values and profiling output, are read into memory before returning.}
#'   \item{`data`}{The `data` argument, unchanged.}
#' }
#' @examples
#' \dontrun{
#' captures <- system.file("extdata", "capture-example.csv", package = "mrmr") %>%
#'   readr::read_csv()
#' translocations <- system.file("extdata", "translocation-example.csv",
#'                               package = "mrmr") %>%
#'   readr::read_csv()
#' surveys <- system.file("extdata", "survey-example.csv", package = "mrmr") %>%
#'   readr::read_csv()
#' out <- clean_data(captures, surveys, translocations)
#'
#' # fit a model with 4 chains, 2 threads per chain (to use 8 physical cores)
#' fit_model(
#'   data = out,
#'   chains = 4,
#'   parallel_chains = 4,
#'   cpp_options = list(stan_threads = TRUE),
#'   threads_per_chain = 2
#' )
#' }
#' @export
fit_model <- function(
  data,
  seed = NULL,
  refresh = NULL,
  init = NULL,
  save_latent_dynamics = FALSE,
  output_dir = NULL,
  output_basename = NULL,
  sig_figs = NULL,
  chains = 4,
  parallel_chains = getOption("mc.cores", 1),
  chain_ids = seq_len(chains),
  threads_per_chain = NULL,
  opencl_ids = NULL,
  iter_warmup = 1000,
  iter_sampling = 1000,
  save_warmup = FALSE,
  thin = NULL,
  max_treedepth = NULL,
  adapt_engaged = TRUE,
  adapt_delta = NULL,
  step_size = NULL,
  metric = NULL,
  metric_file = NULL,
  inv_metric = NULL,
  init_buffer = NULL,
  term_buffer = NULL,
  window = NULL,
  fixed_param = FALSE,
  show_messages = TRUE,
  compile = TRUE,
  ...
  ) {

  # Validate before the (slow) compile step. `[[` is used rather than `$`
  # because `$` partially matches list names (e.g. it would silently pick up
  # a `stan_data` element). Without this check, a missing `stan_d` would
  # reach Stan as `data = NULL` and fail with an opaque sampler error.
  if (!is.list(data) || is.null(data[["stan_d"]])) {
    stop(
      "`data` must be the list returned by mrmr::clean_data() ",
      "(it needs a `stan_d` element).",
      call. = FALSE
    )
  }
  stan_d <- data[["stan_d"]]

  stan_file <- system.file(
    "extdata", "twostate.stan", package = "mrmr", mustWork = TRUE
  )

  # `compile` is a formal argument of this function, so it is never part of
  # `...`; it must be forwarded explicitly or it is silently ignored.
  model <- cmdstanr::cmdstan_model(stan_file, compile = compile, ...)

  m_fit <- model$sample(
    data = stan_d,
    seed = seed,
    refresh = refresh,
    init = init,
    save_latent_dynamics = save_latent_dynamics,
    output_dir = output_dir,
    output_basename = output_basename,
    sig_figs = sig_figs,
    chains = chains,
    parallel_chains = parallel_chains,
    chain_ids = chain_ids,
    threads_per_chain = threads_per_chain,
    opencl_ids = opencl_ids,
    iter_warmup = iter_warmup,
    iter_sampling = iter_sampling,
    save_warmup = save_warmup,
    thin = thin,
    max_treedepth = max_treedepth,
    adapt_engaged = adapt_engaged,
    adapt_delta = adapt_delta,
    step_size = step_size,
    metric = metric,
    metric_file = metric_file,
    inv_metric = inv_metric,
    init_buffer = init_buffer,
    term_buffer = term_buffer,
    window = window,
    fixed_param = fixed_param,
    show_messages = show_messages
  )

  # Load all the data and return the whole unserialized fit object:
  # https://github.com/stan-dev/cmdstanr/blob/d27994f804c493ff3047a2a98d693fa90b83af98/R/fit.R#L16-L18 # nolint
  m_fit$draws() # Do not specify variables or inc_warmup.
  # The remaining outputs are loaded best-effort. Errors are deliberately
  # suppressed because not every run produces them (e.g. there are no
  # profiles unless the Stan program defines profile blocks).
  try(m_fit$sampler_diagnostics(), silent = TRUE)
  try(m_fit$init(), silent = TRUE)
  try(m_fit$profiles(), silent = TRUE)

  list(m_fit = m_fit, data = data)
}
# SNARL1/mrmr documentation built on Nov. 23, 2023, 7:04 a.m.