R/parblock.R

Defines functions parblock

Documented in parblock

#' Define a parameter block for an MCMC kernel
#'
#' @param pars_nat character vector of parameter names on their natural scales
#'   (i.e., corresponding to parameters supplied to stem_dynamics)
#' @param pars_est character vector of parameter names on their estimation scale
#'   (i.e., corresponding to row/column names in the kernel covariance matrix).
#'   Must have the same length as \code{pars_nat}.
#'   IMPORTANT: the estimation scale for a parameter should be unconstrained,
#'   e.g., if mu is the recovery rate in an SIR model, it is natural to use
#'   the log infectious period duration, -log(mu) as the estimation scale.
#' @param priors A list of three functions supplied by the user with names
#'   "logprior", "to_estimation_scale", and "from_estimation_scale" (N.B. All
#'   three must be supplied). \code{logprior} should take as its argument a
#'   numeric vector of parameters on their estimation scales.
#'   \code{to_estimation_scale} and \code{from_estimation_scale} convert
#'   parameters to and from their estimation scales, respectively. The priors
#'   should not include priors for the initial compartment counts or
#'   time-varying parameters. This list is stored as-is; it is not validated
#'   here.
#' @param alg either "mvnmh" or "mvnss" for multivariate normal
#'   Metropolis-Hastings updates or multivariate normal slice sampling updates,
#'   respectively. Must be a single string.
#' @param sigma initial covariance matrix for the parameter block, possibly to
#'   be adapted. If it has no row (column) names, they are set to
#'   \code{pars_est}.
#' @param initializer optional function for initializing the parameters in the
#'   parameter block
#' @param control list of mcmc control settings, generated by a call to
#'   \code{mvnmh_control} or \code{mvnss_control} as appropriate. If
#'   \code{NULL}, the generator matching \code{alg} is called with its
#'   defaults. If \code{control$nugget} is \code{NULL}, it is set to
#'   \code{0.001 * min(diag(sigma))} for "mvnmh" or to \code{0.5} for "mvnss".
#'
#' @return parameter block for use in MCMC kernel: a list with elements
#'   \code{pars_nat}, \code{pars_est}, \code{priors}, \code{alg},
#'   \code{sigma} (with dimnames filled in), \code{initializer}, and
#'   \code{control} (with defaults filled in).
#' @export
parblock <-
    function(pars_nat,
             pars_est,
             priors,
             alg,
             sigma,
             initializer = NULL,
             control = NULL) {

    # alg must be exactly one string. The length check lets NULL or a vector
    # of choices hit the informative message below instead of an opaque
    # "argument is of length zero" / "condition has length > 1" error from if.
    if (length(alg) != 1 || !alg %in% c("mvnmh", "mvnss")) {
        stop("MCMC algorithm for updating parameters must be one of 'mvnmh' or 'mvnss'.")
    }

    if (length(pars_nat) != length(pars_est)) {
        stop("pars_nat and pars_est must have the same length.")
    }

    # make sure sigma has row and column names, defaulting to pars_est
    if (is.null(rownames(sigma))) rownames(sigma) <- pars_est
    if (is.null(colnames(sigma))) colnames(sigma) <- pars_est

    # fall back to the default control settings for the chosen algorithm
    if (is.null(control)) {
        control <- if (alg == "mvnmh") mvnmh_control() else mvnss_control()
    }

    # default nugget: for MH it scales with the smallest diagonal entry of
    # sigma; for slice sampling it is a fixed 0.5
    if (is.null(control$nugget)) {
        control$nugget <- if (alg == "mvnmh") 0.001 * min(diag(sigma)) else 0.5
    }

    # Detect a control list that does not match alg. Presumably only
    # mvnmh_control() sets target_acceptance and only mvnss_control() sets
    # bracket_limits (inferred from the messages below) -- confirm there.
    if (alg == "mvnmh" && is.null(control$target_acceptance)) {
        stop('alg is "mvnmh", but control is "mvnss_control"')
    }
    if (alg == "mvnss" && is.null(control$bracket_limits)) {
        stop('alg is "mvnss", but control is "mvnmh_control"')
    }

    list(pars_nat = pars_nat,
         pars_est = pars_est,
         priors = priors,
         alg = alg,
         sigma = sigma,
         initializer = initializer,
         control = control)
}
fintzij/stemr documentation built on March 25, 2022, 12:25 p.m.