R/control.R

Defines functions spim_control_cores spim_control_parallel spim_control

Documented in spim_control spim_control_cores

##' High-level control.
##'
##' TODO: document parallel detection
##'
##' @title High-level control
##'
##' @param short_run Logical, indicating if this is a debug run. This
##'   will use freer particles, mcmc steps and a small sample
##'
##' @param n_chains The number of chains to run
##'
##' @param deterministic Logical, indicating if the model to fit to
##'   data is run deterministically or stochastically
##'
##' @param multiregion Logical, indicating if we are fitting multiple
##'   regions at once (in which case even the deterministic model may
##'   benefit from multithreading).
##'
##' @param rt Logical, indicating if we are calculating Rt trajectories or
##'   not. Default is TRUE
##'
##' @param severity Logical, indicating if we are outputting severity
##'   trajectories (e.g. IFR, IHR, HFR). Default to FALSE.
##'
##' @param intrinsic_severity Logical, indicating if we are calculating
##'   intrinsic severity (e.g. IFR, IHR, HFR). Default to FALSE.
##'
##' @param simulate Logical, indicating if we are outputting a simulate object
##'   for onward simulation
##'
##' @param demography Logical, indicating if we are outputting model
##'   demography - admissions and deaths by age
##'
##' @param date_restart Optionally, dates save restart data in the
##'   pmcmc (see [mcstate::pmcmc]
##'
##' @param n_particles number of particles to be used in particle filter if
##'   `short_run = FALSE`
##'
##' @param n_mcmc number of steps to be used in PMCMC if
##'   `short_run = FALSE`
##'
##' @param n_sample number of steps to retain (across all chains) if
##'   `short_run = FALSE`
##'
##' @param burnin number of steps out of `n_mcmc` to be used as a burn-in in
##'    PMCMC if `short_run = FALSE`
##'
##' @param workers Logical, indicating if we should enable workers. If
##'   `TRUE`, then a number of workers between 1 and 4 will be used
##'   depending on `n_chains` and the detected number of cores.
##'
##' @param n_threads Explicit number of threads, overriding detection
##'   by [spim_control_cores]
##'
##' @param compiled_compare Use a compiled compare function (rather
##'   than the R version).  This can speed things up with the
##'   deterministic models in particular.
##'
##' @param mcmc_path Path to store the mcmc results in
##'
##' @param adaptive_proposal Control the adaptive proposal. By default
##'   this is disabled (value of `NULL` or `FALSE`). This can only be
##'   enabled for determinsitic models. Pass either `TRUE` here or the
##'   results from [mcstate::adaptive_proposal_control()]
##'
##' @param verbose Logical, indicating if we should print information
##'   about the parallel configuration
##'
##' @return A list of options
##' @export
spim_control <- function(short_run, n_chains, deterministic = FALSE,
                         multiregion = FALSE, rt = TRUE, severity = FALSE,
                         intrinsic_severity = FALSE,
                         simulate = FALSE, demography = FALSE,
                         date_restart = NULL, n_particles = 192, n_mcmc = 1500,
                         burnin = 500, workers = TRUE, n_sample = 1000,
                         n_threads = NULL, compiled_compare = FALSE,
                         adaptive_proposal = NULL,
                         mcmc_path = NULL, verbose = TRUE) {
  if (short_run) {
    n_particles <- min(10, n_particles)
    n_mcmc <- min(20, n_mcmc)
    n_sample <- min(10, n_mcmc)
    n_chains <- min(4, n_chains)
    burnin <- 1
  }

  n_steps_retain <- ceiling(n_sample / n_chains)

  rerun_every <- if (deterministic) Inf else 100
  if (!is.null(date_restart)) {
    date_restart <- sircovid::sircovid_date(date_restart)
  }

  parallel <- spim_control_parallel(n_chains, workers, n_threads,
                                    deterministic, multiregion,
                                    verbose)

  ## Once happy here, you could change the default behaviour, so that
  ## if determinsitic and not `FALSE` we set this up.
  adaptive_proposal <- adaptive_proposal %||% FALSE
  adaptive_proposal_on <- isTRUE(adaptive_proposal) ||
    class(adaptive_proposal) == "adaptive_proposal_control"
  if (isTRUE(adaptive_proposal_on) && !deterministic) {
    stop("Can't use adaptive_proposal with non-deterministic models")
  }

  pmcmc <- mcstate::pmcmc_control(n_mcmc,
                                  n_chains = n_chains,
                                  n_threads_total = parallel$n_threads_total,
                                  n_workers = parallel$n_workers,
                                  save_state = TRUE,
                                  save_trajectories = TRUE,
                                  use_parallel_seed = TRUE,
                                  progress = interactive(),
                                  save_restart = date_restart,
                                  nested_step_ratio = 1, # ignored if single
                                  rerun_every = rerun_every,
                                  rerun_random = TRUE,
                                  filter_early_exit = !deterministic,
                                  n_burnin = burnin,
                                  adaptive_proposal = adaptive_proposal,
                                  n_steps_retain = n_steps_retain,
                                  path = mcmc_path)

  n_threads <- parallel$n_threads_total / parallel$n_workers

  particle_filter <- list(n_particles = n_particles,
                          n_threads = n_threads,
                          seed = NULL,
                          compiled_compare = compiled_compare,
                          rt = rt,
                          severity = severity,
                          intrinsic_severity = intrinsic_severity,
                          simulate = simulate,
                          demography = demography)

  list(pmcmc = pmcmc,
       particle_filter = particle_filter)
}


spim_control_parallel <- function(n_chains, workers, n_threads,
                                  deterministic, multiregion,
                                  verbose = TRUE) {
  n_threads <- n_threads %||% spim_control_cores()
  max_workers <- 4

  if (workers) {
    if (deterministic && !multiregion) {
      ## Increase the number of workers because each will be running
      ## separately
      n_workers <- min(n_chains, n_threads)
      n_threads <- n_workers
    } else if (deterministic && multiregion) {
      ## In the case of the deterministic multiregion case, we get
      ## more from workers (where available) than from within-particle
      ## parallelisation, so let's try and get these onto up to max
      ## workers, then increase the number of threads a bit.
      n_workers <- min(n_chains, n_threads, max_workers)
      n_threads_given <- n_threads
      n_threads <- ceiling(n_threads / n_workers) * n_workers
      if (verbose && n_threads > n_threads_given) {
        message(sprintf("Increasing total threads from %d to %d",
                        n_threads_given, n_threads))
      }
    } else {
      pos <- seq_len(max_workers)
      n_workers <- max(pos[n_threads %% pos == 0 & pos <= n_chains])
    }
  } else {
    n_workers <- 1L
  }

  if (verbose) {
    message(sprintf("Running on %d workers with %d threads",
                    n_workers, n_threads))
  }
  list(n_threads_total = n_threads, n_workers = n_workers)
}


##' Return the number of available cores, looking at `CONTEXT_CORES`,
##' `MC_CORES` and `mc.cores` in turn
##'
##' @title Return the number of cores
##' @return An integer
##' @export
spim_control_cores <- function() {
  as.integer(Sys.getenv("CONTEXT_CORES",
                        Sys.getenv("MC_CORES",
                                   getOption("mc.cores", 1))))
}
mrc-ide/spimalot documentation built on Oct. 15, 2024, 12:15 p.m.