# R/fit_stem.R

# Defines functions fit_stem

# Documented in fit_stem

#' Fit a stochastic epidemic model using the linear noise approximation or
#' ordinary differential equations to approximate the latent epidemic process.
#'
#' @param stem_object a stochastic epidemic model object containing the dataset,
#'   model dynamics, and measurement process.
#' @param method either \code{"lna"} (linear noise approximation) or
#'   \code{"ode"} (ordinary differential equations).
#' @param mcmc_kern MCMC transition kernel generated by a call to the
#'   \code{mcmc_kernel} function.
#' @param iterations number of MCMC iterations.
#' @param initialization_attempts number of attempts made to initialize the
#'   latent path, parameters, and initial conditions; defaults to 500.
#' @param ess_warmup number of preliminary elliptical slice sampling (ESS)
#'   iterations for the LNA path, initial conditions, and time-varying
#'   parameters prior to starting the MCMC; defaults to 50.
#' @param thinning_interval thinning interval for posterior samples, defaults to
#'   saving every 100th sample
#' @param return_adapt_rec should MCMC samples from the adaptation phase be
#'   returned? Defaults to FALSE.
#' @param return_ess_rec should elliptical slice sampling steps and angles be
#'   returned? Defaults to FALSE.
#' @param print_progress interval at which to print progress to a text file.
#'   If 0 (the default), progress is not printed.
#' @param status_filename prefix prepended to the names of status files;
#'   defaults to "LNA" or "ODE" depending on the method used.
#'
#' @return list with posterior samples for the parameters and the latent
#'   process, along with MCMC diagnostics.
#' @export
fit_stem =
    function(stem_object,
             method,
             mcmc_kern,
             iterations,
             initialization_attempts = 500,
             ess_warmup = 50,
             thinning_interval = 100,
             return_adapt_rec = FALSE,
             return_ess_rec = FALSE,
             print_progress = 0,
             status_filename = NULL) {

        # check that the data, dynamics and measurement process are all supplied
        if(is.null(stem_object$measurement_process$data) ||
           is.null(stem_object$dynamics) ||
           is.null(stem_object$measurement_process)) {
            stop("The dataset, dynamics, and measurement process must all be specified.")
        }

        # check that the appropriate model object is compiled
        if(method == "lna") {
            if(is.null(stem_object$dynamics$lna_pointers)) {
                stop("LNA is not compiled.")
            }

            if(is.null(status_filename)) status_filename <- "LNA"

        } else if(method == "ode") {
            if(is.null(stem_object$dynamics$ode_pointers)) {
                stop("ODE is not compiled.")
            }

            if(is.null(status_filename)) status_filename <- "ODE"
        }

        # if the MCMC is being restarted, save the existing results
        mcmc_restart <- !is.null(stem_object$restart)

        # grab time-varying parameters
        if(mcmc_restart) {
            path   <- stem_object$restart$path
            tparam <- stem_object$restart$tparam
        } else {
            path   <- NULL
            tparam <- stem_object$dynamics$tparam
        }

        # grab parameter names
        parameters      <- stem_object$dynamics$parameters
        param_names_nat <- names(stem_object$dynamics$param_codes)[!grepl("_0", names(stem_object$dynamics$param_codes))]
        initdist_names  <- names(stem_object$dynamics$param_codes)[grepl("_0", names(stem_object$dynamics$param_codes))]
        param_names_est <- c(sapply(mcmc_kern$parameter_blocks, function(x) x$pars_est))
        n_model_params  <- length(param_names_est)

        # unpack mcmc kernel
        param_blocks         <- mcmc_kern$parameter_blocks
        lna_ess_control      <- mcmc_kern$lna_ess_control
        initdist_ess_control <- mcmc_kern$initdist_ess_control
        tparam_ess_control   <- mcmc_kern$tparam_ess_control

        # parameter codes
        if(method == "lna") {
            n_pars_tot  <- length(stem_object$dynamics$lna_rates$lna_param_codes)
            param_codes <- stem_object$dynamics$lna_rates$lna_param_codes
        } else {
            n_pars_tot  <- length(stem_object$dynamics$ode_rates$ode_param_codes)
            param_codes <- stem_object$dynamics$ode_rates$ode_param_codes
        }

        # prepare param_blocks
        param_blocks <-
            prepare_param_blocks(
                param_blocks = param_blocks,
                parameters   = parameters,
                param_codes  = param_codes,
                iterations   = iterations,
                mcmc_restart = mcmc_restart)

        # maximum number of adaptive iterations
        max_adaptation <- max(sapply(param_blocks, function(x) x$control$stop_adaptation))

        # number of posterior samples
        n_samples <-
            if(return_adapt_rec) {
                floor(iterations / thinning_interval)
            } else {
                if(max_adaptation %% thinning_interval != 0) {
                    ceiling((iterations - max_adaptation) / thinning_interval)
                } else {
                    floor((iterations - max_adaptation) / thinning_interval)
                }
            }
        n_ess_recs <- n_samples
        record_sample <- FALSE

        # progress printing interval
        if(print_progress != 0) {
            progress_interval <- print_progress
            print_progress <- TRUE
        } else {
            progress_interval <- NULL
            print_progress <- FALSE
        }

        ### Unpack stem_object-------------------------

        # dynamics that are not method specific
        censusmat   <- stem_object$measurement_process$censusmat
        constants   <- stem_object$dynamics$constants
        initializer <- stem_object$dynamics$initializer
        fixed_inits <- stem_object$dynamics$fixed_inits
        n_strata    <- stem_object$dynamics$n_strata
        step_size   <- stem_object$dynamics$dynamics_args$step_size

        # measurement process objects that are not method specific
        measproc_indmat <- stem_object$measurement_process$measproc_indmat
        dat             <- stem_object$measurement_process$obsmat
        obstimes        <- stem_object$measurement_process$obstimes

        # initialize the ess_record
        ess_record <- list(lna_ess_record = NULL,
                           initdist_ess_record = NULL,
                           tparam_ess_record = NULL)

        # dynamics and measurement process objects that are method specific
        if(method == "lna") {

            # extract objects from dynamics
            flow_matrix      <- stem_object$dynamics$flow_matrix_lna
            n_compartments   <- ncol(flow_matrix)
            n_rates          <- nrow(flow_matrix)
            stoich_matrix    <- stem_object$dynamics$stoich_matrix_lna
            proc_pointer     <- stem_object$dynamics$lna_pointers$lna_ptr
            set_pars_pointer <- stem_object$dynamics$lna_pointers$set_lna_params_ptr
            do_prevalence    <- stem_object$measurement_process$lna_prevalence
            event_inds       <- stem_object$measurement_process$incidence_codes_lna
            initdist_inds    <- stem_object$dynamics$lna_initdist_inds
            approx_warmup    <- lna_ess_control$approx_warmup

            # measurement process
            d_meas_pointer   <- stem_object$measurement_process$meas_pointers_lna$d_measure_ptr

            # indices of parameters, constants, and time-varying covariates
            param_inds <-
                setdiff(stem_object$dynamics$param_codes,
                        stem_object$dynamics$lna_initdist_inds)
            const_inds <-
                length(stem_object$dynamics$param_codes) +
                seq_along(stem_object$dynamics$const_codes) - 1
            tcovar_inds <-
                length(stem_object$dynamics$param_codes) +
                length(const_inds) + seq_along(stem_object$dynamics$tcovar_codes) - 1

            # should initial concentrations be updated jointly with the LNA path
            joint_initdist_update <- !fixed_inits & lna_ess_control$joint_initdist_update

            # objects for computing the SVD of the LNA diffusion matrix
            svd_U <- diag(0.0, n_rates)
            svd_V <- diag(0.0, n_rates)
            svd_d <- rep(0.0, n_rates)

            # grab the tparam indices and update scheme
            if(!is.null(tparam)) {
                tparam_inds <-
                    stem_object$dynamics$lna_rates$lna_param_codes[
                        sapply(tparam, function(x) x$tparam_name)]
            } else {
                tparam_inds <- NULL
            }

            lna_ess_schedule <-
                prepare_lna_ess_schedule(
                    stem_object = stem_object,
                    initializer = initializer,
                    lna_ess_control = lna_ess_control)

            if(return_ess_rec) {
                ess_record$lna_ess_record <-
                    list(ess_steps  =
                             array(1.0,
                                   dim = c(lna_ess_control$n_updates,
                                           length(lna_ess_schedule),
                                           n_ess_recs)),
                         ess_angles =
                             array(1.0,
                                   dim = c(lna_ess_control$n_updates,
                                           length(lna_ess_schedule),
                                           n_ess_recs)))
            }

        } else if(method == "ode") {

            # extract objects from dynamics
            flow_matrix         <- stem_object$dynamics$flow_matrix_ode
            n_compartments      <- ncol(flow_matrix)
            n_rates             <- nrow(flow_matrix)
            stoich_matrix       <- stem_object$dynamics$stoich_matrix_ode
            proc_pointer        <- stem_object$dynamics$ode_pointers$ode_ptr
            set_pars_pointer    <- stem_object$dynamics$ode_pointers$set_ode_params_ptr
            do_prevalence       <- stem_object$measurement_process$ode_prevalence
            event_inds          <- stem_object$measurement_process$incidence_codes_ode
            initdist_inds       <- stem_object$dynamics$ode_initdist_inds

            # measurement process
            d_meas_pointer <- stem_object$measurement_process$meas_pointers_lna$d_measure_ptr

            # indices of parameters, constants, and time-varying covariates
            param_inds <-
                setdiff(stem_object$dynamics$param_codes,
                        stem_object$dynamics$ode_initdist_inds)
            const_inds <-
                length(stem_object$dynamics$param_codes) +
                seq_along(stem_object$dynamics$const_codes) - 1
            tcovar_inds <-
                length(stem_object$dynamics$param_codes) +
                length(const_inds) + seq_along(stem_object$dynamics$tcovar_codes) - 1

            # obviously, initial distribution is its own ESS update
            joint_initdist_update <- FALSE

            # grab the tparam indices
            if(!is.null(tparam)) {
                tparam_inds <-
                    stem_object$dynamics$ode_rates$ode_param_codes[
                        sapply(tparam, function(x) x$tparam_name)]
            } else {
                tparam_inds <- NULL
            }

            # set some LNA objects to NULL
            svd_U <- NULL
            svd_V <- NULL
            svd_d <- NULL
            lna_ess_schedule <- NULL
        }

        ### Initial distribution objects --------------------------------------------
        if(n_strata == 1) {
            comp_size_vec <- constants["popsize"]
        } else {
            comp_size_vec <- constants[paste0("popsize_", sapply(initializer,"[[","strata"))]
        }

        # list for initial compartment volume objects
        if(mcmc_restart) {
            initdist_objects <- stem_object$restart$initdist_objects

        } else {
            initdist_objects <-
                prepare_initdist_objects(initializer = initializer,
                                         param_codes = param_codes,
                                         comp_size_vec = comp_size_vec)
        }

        # initialize initdist draws
        if(!fixed_inits) {

            # for recording the ess initdist updates
            if(!joint_initdist_update) {
                initdist_ess_control$steps  <- rep(1.0, initdist_ess_control$n_updates)
                initdist_ess_control$angles <- rep(0.0, initdist_ess_control$n_updates)

                initdist_ess_control$angle_mean  <- 0.0
                initdist_ess_control$angle_resid <- 0.0
                initdist_ess_control$angle_var   <- 0.0
            }

            if(return_ess_rec) {
                if(method != "lna" | !joint_initdist_update) {
                    ess_record$initdist_ess_record <-
                        list(ess_steps =
                                 matrix(1.0,
                                        nrow = initdist_ess_control$n_updates,
                                        ncol = n_ess_recs),
                             ess_angles =
                                 matrix(0.0,
                                        nrow = initdist_ess_control$n_updates,
                                        ncol = n_ess_recs))
                }
            }
        }

        # vector of census times
        census_times <-
            sort(unique(c(obstimes,
                          stem_object$dynamics$tcovar[, 1],
                          seq(stem_object$dynamics$t0,
                              stem_object$dynamics$tmax,
                              by = stem_object$dynamics$timestep),
                          stem_object$dynamics$tmax)))
        census_indices <- unique(c(0, findInterval(obstimes, census_times) - 1))
        n_times        <- length(census_times)

        # make sure no times are less than t0
        if(any(obstimes < stem_object$dynamics$t0)) {
            stop("Cannot have observations before time t0.")
        }

        if(any(stem_object$dynamics$tcovar[,1] < stem_object$dynamics$t0)) {
            stop("Cannot have time-varying covariates specified before time t0.")
        }

        ### Set up parameter objects ---------------------------------
        parmat <-
            matrix(0.0,
                   nrow = n_times,
                   ncol = n_pars_tot,
                   dimnames = list(NULL, names(param_codes)))

        # insert parameters into the parameter matrix
        insert_params(parmat       = parmat,
                      param_blocks = param_blocks,
                      nat          = TRUE,
                      prop         = FALSE,
                      rowind       = 0)

        # insert initial conditions
        insert_initdist(parmat           = parmat,
                        initdist_objects = initdist_objects,
                        prop             = FALSE,
                        rowind           = 0,
                        mcmc_rec         = FALSE)

        # insert the constants
        parmat[, const_inds + 1] <-
            matrix(stem_object$dynamics$constants,
                   nrow = nrow(parmat),
                   ncol = length(const_inds), byrow = T)

        # generate forcing indices and other objects
        forcing_inds <- rep(FALSE, length(census_times))

        # insert time varying covariates
        if(!is.null(stem_object$dynamics$tcovar)) {

            tcovar_rowinds <-
                findInterval(census_times, stem_object$dynamics$tcovar[, 1])
            parmat[, tcovar_inds + 1] <-
                stem_object$dynamics$tcovar[tcovar_rowinds, -1]

            # zero out forcings if necessary
            if(!is.null(stem_object$dynamics$forcings)) {

                # get the forcing indices (supplied in the original tcovar matrix)
                for(f in seq_along(stem_object$dynamics$forcings)) {
                    forcing_inds <-
                        forcing_inds |
                        stem_object$dynamics$tcovar[stem_object$dynamics$tcovar[,1] %in% census_times,
                                                    stem_object$dynamics$forcings[[f]]$tcovar_name] != 0
                }

                zero_inds <- !forcing_inds

                # zero out the tcovar elements corresponding to times with no forcings
                for(l in seq_along(stem_object$dynamics$dynamics_args$forcings)) {
                    parmat[zero_inds, stem_object$dynamics$dynamics_args$forcings[[l]]$tcovar_name] <- 0
                }
            }
        }

        # get indices for time-varying parameters
        if (!is.null(tparam)) {

            # verify whether the mcmc is being restarted
            if (!mcmc_restart) {

                # generate the indices for updating the time-varying parameter and initialize the values
                for (s in seq_along(tparam)) {

                    # can get rid of the values slot
                    tparam[[s]]$values <- NULL

                    # indices
                    tparam[[s]]$col_ind <- param_codes[tparam[[s]]$tparam_name]
                    tparam[[s]]$tpar_inds_R <-
                        findInterval(census_times, tparam[[s]]$times, left.open = F)
                    tparam[[s]]$tpar_inds_R[tparam[[s]]$tpar_inds_R == 0] <- 1
                    tparam[[s]]$tpar_inds_Cpp <- tparam[[s]]$tpar_inds_R - 1

                    # values
                    tparam[[s]]$draws_cur  <- rnorm(tparam[[s]]$n_draws)
                    tparam[[s]]$draws_prop <- rnorm(tparam[[s]]$n_draws)
                    tparam[[s]]$draws_ess  <- rnorm(tparam[[s]]$n_draws)
                    tparam[[s]]$log_lik <- sum(dnorm(tparam[[s]]$draws_cur, log = TRUE))

                    # get values
                    tparam[[s]]$tpar_cur <-
                        tparam[[s]]$draws2par(
                            parameters = parmat[1,],
                            draws = tparam[[s]]$draws_cur)[tparam[[s]]$tpar_inds_R]

                    # insert into the parameter matrix
                    vec_2_mat(dest = parmat,
                              orig = tparam[[s]]$tpar_cur,
                              ind  = tparam[[s]]$col_ind)

                    # bracket width
                    tparam[[s]]$bracket_width <- tparam_ess_control$bracket_width

                    # ess counters
                    tparam[[s]]$steps  <- rep(1.0, tparam_ess_control$n_updates)
                    tparam[[s]]$angles <- rep(0.0, tparam_ess_control$n_updates)

                    # for tuning the bracket
                    tparam[[s]]$angle_mean  <- 0
                    tparam[[s]]$angle_var   <- pi^2/3
                    tparam[[s]]$angle_resid <- 0
                }
            } else {

                # if restarting, just copy the values into the parameter matrices
                for (s in seq_along(tparam)) {

                    # insert into the parameter matrix
                    vec_2_mat(dest = parmat,
                              orig = tparam[[s]]$tpar_cur,
                              ind  = tparam[[s]]$col_ind)
                }
            }

            # check if tparam[[s]] depends on initial conditions
            tparam <- check_tpar_depends(tparam = tparam,
                                        initdist_objects = initdist_objects,
                                        parmat = parmat)

            if(return_ess_rec) {
                ess_record$tparam_ess_record <-
                    list(ess_steps  =
                             array(1.0,
                                   dim = c(tparam_ess_control$n_updates,
                                           length(tparam),
                                           n_ess_recs)),
                         ess_angles =
                             array(1.0,
                                   dim = c(tparam_ess_control$n_updates,
                                           length(tparam),
                                           n_ess_recs)))
            }

        } else {
            tparam <- NULL
        }

        # indices for when to update the parameters
        param_update_inds <- rep(FALSE, length(census_times))
        param_update_inds[1] <- TRUE

        if(!is.null(stem_object$dynamics$tcovar)) {
            param_update_inds[census_times %in% stem_object$dynamics$tcovar[,1]] <- TRUE
        }

        if(!is.null(tparam)) {
            for(s in seq_along(tparam)) {
                param_update_inds[census_times %in% tparam[[s]]$times] <- TRUE
            }
        }

        if(length(param_update_inds) == nrow(stem_object$dynamics$tcovar_changemat)) {
            param_update_inds <-
                param_update_inds | apply(stem_object$dynamics$tcovar_changemat, 1, any)
        }

        # generate forcing objects
        if(!is.null(stem_object$dynamics$forcings)) {

            forcings <- stem_object$dynamics$forcings

            # names and indices
            forcing_tcovars   <- sapply(forcings, function(x) x$tcovar_name)
            forcing_tcov_inds <- match(forcing_tcovars, colnames(parmat)) - 1
            forcing_events    <- c(sapply(forcings, function(x) paste0(x$from, "2", x$to)))

            # matrix indicating which compartments are involved in which forcings in and out
            forcings_out <-
                matrix(0.0,
                       nrow = ncol(flow_matrix),
                       ncol = length(forcings),
                       dimnames = list(colnames(flow_matrix),
                                       forcing_tcovars))

            forcing_transfers <-
                array(0.0,
                      dim = c(ncol(flow_matrix),
                              ncol(flow_matrix),
                              length(forcings)),
                      dimnames = list(colnames(flow_matrix),
                                      colnames(flow_matrix),
                                      forcing_tcovars))

            for(s in seq_along(forcings)) {

                forcings_out[forcings[[s]]$from, s] <- 1

                for(t in seq_along(forcings[[s]]$from)) {
                    forcing_transfers[forcings[[s]]$from[t], forcings[[s]]$from[t], s] <- -1
                    forcing_transfers[forcings[[s]]$to[t], forcings[[s]]$from[t], s] <- 1
                }
            }

        } else {
            forcing_tcovars   <- character(0L)
            forcing_tcov_inds <- integer(0L)
            forcing_events    <- character(0L)
            forcings_out      <- matrix(0.0, nrow = 0, ncol = 0)
            forcing_transfers <- array(0.0, dim = c(0,0,0))
        }

        # matrix in which to store the emission probabilities
        emitmat <- cbind(dat[, 1, drop = F],
                         matrix(0.0,
                                nrow = nrow(measproc_indmat),
                                ncol = ncol(measproc_indmat),
                                dimnames = list(NULL, colnames(measproc_indmat))))

        pathmat_prop <- cbind(census_times,
                              matrix(0.0,
                                     nrow = length(census_times),
                                     ncol = nrow(flow_matrix),
                                     dimnames = list(NULL, c(rownames(flow_matrix)))))

        # initialize the lna_param_vec for the measurement process
        param_vec <- parmat[1,]

        # initialize the latent path
        if (mcmc_restart) {
            # recompute the data log likelihood
            data_log_lik_prop <- NULL
            try({
                census_latent_path(
                    path                = path$latent_path,
                    census_path         = censusmat,
                    census_inds         = census_indices,
                    event_inds          = event_inds,
                    flow_matrix         = flow_matrix,
                    do_prevalence       = do_prevalence,
                    parmat              = parmat,
                    initdist_inds       = initdist_inds,
                    forcing_inds        = forcing_inds,
                    forcing_tcov_inds   = forcing_tcov_inds,
                    forcings_out        = forcings_out,
                    forcing_transfers   = forcing_transfers
                )

                # evaluate the density of the incidence counts
                evaluate_d_measure_LNA(
                    emitmat           = emitmat,
                    obsmat            = dat,
                    censusmat         = censusmat,
                    measproc_indmat   = measproc_indmat,
                    parameters        = parmat,
                    param_inds        = param_inds,
                    const_inds        = const_inds,
                    tcovar_inds       = tcovar_inds,
                    param_update_inds = param_update_inds,
                    census_indices    = census_indices,
                    param_vec         = param_vec,
                    d_meas_ptr        = d_meas_pointer
                )

                # compute the data log likelihood
                data_log_lik_prop <- sum(emitmat[, -1][measproc_indmat])
                if (is.nan(data_log_lik_prop)) data_log_lik_prop <- -Inf
            }, silent = TRUE)

            if (is.null(data_log_lik_prop)) {
                stop("Restart attempted with data log likelihood of negative infinity.")
            } else {
                path$data_log_lik <- data_log_lik_prop
            }

        } else {

            if(method == "lna") {
                inits = initialize_lna(
                    dat                     = dat,
                    parmat                  = parmat,
                    param_blocks            = param_blocks,
                    tparam                  = tparam,
                    censusmat               = censusmat,
                    emitmat                 = emitmat,
                    stoich_matrix           = stoich_matrix,
                    census_times            = census_times,
                    proc_pointer            = proc_pointer,
                    set_pars_pointer        = set_pars_pointer,
                    d_meas_pointer          = d_meas_pointer,
                    param_vec               = param_vec,
                    param_inds              = param_inds,
                    const_inds              = const_inds,
                    tcovar_inds             = tcovar_inds,
                    initdist_inds           = initdist_inds,
                    param_update_inds       = param_update_inds,
                    census_indices          = census_indices,
                    event_inds              = event_inds,
                    measproc_indmat         = measproc_indmat,
                    do_prevalence           = do_prevalence,
                    forcing_inds            = forcing_inds,
                    forcing_tcov_inds       = forcing_tcov_inds,
                    forcings_out            = forcings_out,
                    forcing_transfers       = forcing_transfers,
                    initialization_attempts = initialization_attempts,
                    step_size               = step_size,
                    initdist_objects        = initdist_objects,
                    ess_warmup              = approx_warmup)

            } else {
                inits = initialize_ode(
                    dat                     = dat,
                    parmat                  = parmat,
                    param_blocks            = param_blocks,
                    tparam                  = tparam,
                    censusmat               = censusmat,
                    emitmat                 = emitmat,
                    stoich_matrix           = stoich_matrix,
                    proc_pointer            = proc_pointer,
                    set_pars_pointer        = set_pars_pointer,
                    d_meas_pointer          = d_meas_pointer,
                    census_times            = census_times,
                    param_vec               = param_vec,
                    param_inds              = param_inds,
                    const_inds              = const_inds,
                    tcovar_inds             = tcovar_inds,
                    initdist_inds           = initdist_inds,
                    param_update_inds       = param_update_inds,
                    census_indices          = census_indices,
                    event_inds              = event_inds,
                    measproc_indmat         = measproc_indmat,
                    do_prevalence           = do_prevalence,
                    forcing_inds            = forcing_inds,
                    forcing_tcov_inds       = forcing_tcov_inds,
                    forcings_out            = forcings_out,
                    forcing_transfers       = forcing_transfers,
                    initialization_attempts = initialization_attempts,
                    step_size               = step_size,
                    initdist_objects        = initdist_objects)
            }

            # grab the initial path, param_blocks, initdist_objects, and tparam
            path             <- inits$path
            param_blocks     <- inits$param_blocks
            initdist_objects <- inits$initdist_objects
            tparam           <- inits$tparam
        }

        if(method == "lna") {

            # object for proposing new stochastic perturbations
            draws_prop <- matrix(0.0, nrow = nrow(path$draws), ncol = ncol(path$draws))
            copy_mat(draws_prop, path$draws)

            # instatiate matrix for elliptical slice sampling draws
            ess_draws_prop <- matrix(0.0, nrow = nrow(path$draws), ncol = ncol(path$draws))
            copy_mat(ess_draws_prop, path$draws)

            # elliptical slice sampling MCMC record
            for(s in seq_along(lna_ess_schedule)) {
                lna_ess_schedule[[s]]$steps  <- rep(1.0, lna_ess_control$n_updates)
                lna_ess_schedule[[s]]$angles <- rep(0.0, lna_ess_control$n_updates)
            }
        }

        # warmup the LNA, initial conditions, or time-varying parameters
        #
        # For ess_warmup sweeps, refresh (in this order) the LNA path when
        # method == "lna", the initial compartment volumes when they are neither
        # fixed nor updated jointly with the LNA, and the time-varying parameters
        # when the model has any. The model parameter blocks are NOT updated here.
        # Every warmup update is called with iter = 0.
        # NOTE(review): the return values of these update functions are discarded,
        # so they presumably modify path, parmat, the ESS controls, etc. in place.
        # Confirm before refactoring these calls (e.g. into do.call() with a shared
        # argument list), since copying the arguments could break that aliasing.
        for(warmup in seq_len(ess_warmup)) {

            # elliptical slice sampling update of the LNA path
            if(method == "lna") {
                lna_update(
                    path                  = path,
                    dat                   = dat,
                    iter                  = 0,
                    parmat                = parmat,
                    lna_ess_schedule      = lna_ess_schedule,
                    lna_ess_control       = lna_ess_control,
                    initdist_objects      = initdist_objects,
                    tparam                = tparam,
                    pathmat_prop          = pathmat_prop,
                    censusmat             = censusmat,
                    draws_prop            = draws_prop,
                    ess_draws_prop        = ess_draws_prop,
                    emitmat               = emitmat,
                    flow_matrix           = flow_matrix,
                    stoich_matrix         = stoich_matrix,
                    census_times          = census_times,
                    forcing_inds          = forcing_inds,
                    forcing_tcov_inds     = forcing_tcov_inds,
                    forcings_out          = forcings_out,
                    forcing_transfers     = forcing_transfers,
                    param_vec             = param_vec,
                    param_inds            = param_inds,
                    const_inds            = const_inds,
                    tcovar_inds           = tcovar_inds,
                    initdist_inds         = initdist_inds,
                    param_update_inds     = param_update_inds,
                    census_indices        = census_indices,
                    event_inds            = event_inds,
                    measproc_indmat       = measproc_indmat,
                    svd_d                 = svd_d,
                    svd_U                 = svd_U,
                    svd_V                 = svd_V,
                    proc_pointer          = proc_pointer,
                    set_pars_pointer      = set_pars_pointer,
                    d_meas_pointer        = d_meas_pointer,
                    do_prevalence         = do_prevalence,
                    joint_initdist_update = joint_initdist_update,
                    step_size             = step_size
                )
            }

            # separate update of the initial compartment volumes; skipped when
            # they are fixed, or when joint_initdist_update is set (that flag is
            # passed to lna_update above, which presumably handles them jointly)
            if(!fixed_inits && !joint_initdist_update) {
                initdist_update(
                    path                 = path,
                    dat                  = dat,
                    iter                 = 0,
                    parmat               = parmat,
                    initdist_objects     = initdist_objects,
                    initdist_ess_control = initdist_ess_control,
                    tparam               = tparam,
                    pathmat_prop         = pathmat_prop,
                    censusmat            = censusmat,
                    draws_prop           = draws_prop,
                    ess_draws_prop       = ess_draws_prop,
                    emitmat              = emitmat,
                    flow_matrix          = flow_matrix,
                    stoich_matrix        = stoich_matrix,
                    census_times         = census_times,
                    forcing_inds         = forcing_inds,
                    forcing_tcov_inds    = forcing_tcov_inds,
                    forcings_out         = forcings_out,
                    forcing_transfers    = forcing_transfers,
                    param_vec            = param_vec,
                    param_inds           = param_inds,
                    const_inds           = const_inds,
                    tcovar_inds          = tcovar_inds,
                    initdist_inds        = initdist_inds,
                    param_update_inds    = param_update_inds,
                    census_indices       = census_indices,
                    event_inds           = event_inds,
                    measproc_indmat      = measproc_indmat,
                    svd_d                = svd_d,
                    svd_U                = svd_U,
                    svd_V                = svd_V,
                    proc_pointer         = proc_pointer,
                    set_pars_pointer     = set_pars_pointer,
                    d_meas_pointer       = d_meas_pointer,
                    do_prevalence        = do_prevalence,
                    step_size            = step_size
                )
            }

            # update of the time-varying parameter draws, if the model has any
            if(!is.null(tparam)) {
                tparam_update(
                    path               = path,
                    dat                = dat,
                    iter               = 0,
                    parmat             = parmat,
                    tparam             = tparam,
                    tparam_ess_control = tparam_ess_control,
                    pathmat_prop       = pathmat_prop,
                    censusmat          = censusmat,
                    draws_prop         = draws_prop,
                    ess_draws_prop     = ess_draws_prop,
                    emitmat            = emitmat,
                    flow_matrix        = flow_matrix,
                    stoich_matrix      = stoich_matrix,
                    census_times       = census_times,
                    forcing_inds       = forcing_inds,
                    forcing_tcov_inds  = forcing_tcov_inds,
                    forcings_out       = forcings_out,
                    forcing_transfers  = forcing_transfers,
                    param_vec          = param_vec,
                    param_inds         = param_inds,
                    const_inds         = const_inds,
                    tcovar_inds        = tcovar_inds,
                    initdist_inds      = initdist_inds,
                    param_update_inds  = param_update_inds,
                    census_indices     = census_indices,
                    event_inds         = event_inds,
                    measproc_indmat    = measproc_indmat,
                    svd_d              = svd_d,
                    svd_U              = svd_U,
                    svd_V              = svd_V,
                    proc_pointer       = proc_pointer,
                    set_pars_pointer   = set_pars_pointer,
                    d_meas_pointer     = d_meas_pointer,
                    do_prevalence      = do_prevalence,
                    step_size          = step_size
                )
            }
        }

        # Preallocate the posterior storage, with one entry, row, or slice per
        # saved sample:
        #   - data log-likelihood and parameter log-prior (vectors),
        #   - model parameters on the natural and estimation scales (matrices),
        #   - latent paths: census time x ("time" + one column per row of
        #     flow_matrix) x saved sample.
        mcmc_samples <- list(
            data_log_lik          = numeric(n_samples),
            params_log_prior      = numeric(n_samples),
            parameter_samples_nat = matrix(0.0,
                                           nrow     = n_samples,
                                           ncol     = n_model_params,
                                           dimnames = list(NULL, param_names_nat)),
            parameter_samples_est = matrix(0.0,
                                           nrow     = n_samples,
                                           ncol     = n_model_params,
                                           dimnames = list(NULL, param_names_est)),
            latent_paths          = array(0.0,
                                          dim      = c(length(census_times),
                                                       1 + n_rates,
                                                       n_samples),
                                          dimnames = list(NULL,
                                                          c("time", rownames(flow_matrix)),
                                                          NULL))
        )

        # LNA fits also keep the log-likelihood of the LNA draws and the draws
        # themselves: one row per row of flow_matrix, length(census_times) - 1
        # columns, and one slice per saved sample.
        if(method == "lna") {
            mcmc_samples[["lna_log_lik"]] <- numeric(n_samples)
            mcmc_samples[["lna_draws"]]   <- array(
                0.0,
                dim      = c(n_rates, length(census_times) - 1, n_samples),
                dimnames = list(rownames(flow_matrix), NULL, NULL)
            )
        }

        # When the initial compartment volumes are estimated, also keep:
        #   - their log-likelihood at each saved sample,
        #   - their values (one column per compartment),
        #   - for each initdist object, a matrix with one row per element of its
        #     draws_cur and one column per saved sample (an unnamed list).
        if(!fixed_inits) {
            mcmc_samples$initdist_log_lik <- numeric(n_samples)

            mcmc_samples$initdist_samples <- matrix(
                0.0,
                nrow     = n_samples,
                ncol     = n_compartments,
                dimnames = list(NULL, initdist_names)
            )

            mcmc_samples$initdist_draws <- unname(lapply(
                initdist_objects,
                function(obj) matrix(0.0, nrow = length(obj$draws_cur), ncol = n_samples)
            ))
        }

        # Storage for the time-varying parameters, if the model has any.
        if(!is.null(tparam)) {

            # log-likelihood of the time-varying parameter draws at each saved sample
            mcmc_samples$tparam_log_lik <- rep(0.0, n_samples)

            # n_times x length(tparam) x n_samples array, presumably holding the
            # values of each time-varying parameter over time at each saved sample
            mcmc_samples$tparam_samples <-
                array(0.0, dim = c(n_times, length(tparam), n_samples))

            # one matrix per time-varying parameter, with tparam[[i]]$n_draws rows
            # and one column per saved sample. The list is sized by tparam (not by
            # initdist_objects) so its length always equals the number of
            # time-varying parameters, with no trailing NULL entries.
            mcmc_samples$tparam_draws <-
                vector("list", length = length(tparam))

            for(i in seq_along(tparam)) {
                mcmc_samples$tparam_draws[[i]] <-
                    matrix(0.0,
                           nrow = tparam[[i]]$n_draws,
                           ncol = n_samples)
            }
        }

        # Counters for the next free slot in the posterior storage and, when
        # requested, in the ESS record.
        # NOTE(review): neither counter is advanced anywhere below in this
        # function, so the save helpers presumably increment them in place
        # (avoid changing their type) -- confirm.
        rec_ind <- 0
        if (return_ess_rec) {
            ess_rec_ind <- 0
        }

        # When progress reporting is on, start a fresh time-stamped status file.
        if (print_progress) {
            status_file <- paste0(status_filename, "_inference_status_",
                                  as.numeric(Sys.time()), ".txt")
            cat("Beginning MCMC", file = status_file, sep = "\n", append = FALSE)
        }

        # Run the MCMC. Each iteration:
        #   1. updates every parameter block, visiting the blocks in random order,
        #      with either the "mvnmh" update (tracks acceptances and a proposal
        #      scaling) or the "mvnss" update (tracks contractions/expansions);
        #   2. updates the initial compartment volumes, unless they are fixed or
        #      updated jointly with the LNA path;
        #   3. updates the time-varying parameter draws, when present;
        #   4. updates the LNA path via elliptical slice sampling (method "lna");
        #   5. once adaptation has ended, saves every thinning_interval-th state;
        #   6. optionally appends progress information to the status file.
        # NOTE(review): the update/save functions' return values are discarded, so
        # they presumably modify their arguments (param_blocks, path, parmat,
        # mcmc_samples, rec_ind, ...) in place. The order of these steps, and
        # passing the objects themselves rather than copies, matter -- confirm
        # before restructuring.
        start.time <- Sys.time()
        for (iter in seq_len(iterations)) {

            # Sample new parameter values, visiting the blocks in random order
            block_order <- sample.int(length(param_blocks))
            for(ind in block_order) {

                if(param_blocks[[ind]]$alg == "mvnmh") {

                    # at the final adaptation iteration, freeze the proposal
                    # covariance (adapted kernel covariance times the adapted
                    # proposal scaling), label it, and recompute its Cholesky
                    # factor (comp_chol presumably writes into kernel_cov_chol)
                    if (iter == min(max_adaptation, iterations)) {

                        param_blocks[[ind]]$sigma <-
                            param_blocks[[ind]]$mvnmh_objects$proposal_scaling *
                            param_blocks[[ind]]$kernel_cov

                        colnames(param_blocks[[ind]]$sigma) <-
                            rownames(param_blocks[[ind]]$sigma) <-
                                param_blocks[[ind]]$param_names_est

                        comp_chol(param_blocks[[ind]]$kernel_cov_chol,
                                  param_blocks[[ind]]$sigma)
                    }

                    # sample new parameters
                    mvnmh_update(
                        param_blocks      = param_blocks,
                        ind               = ind,
                        iter              = iter,
                        parmat            = parmat,
                        dat               = dat,
                        path              = path,
                        pathmat_prop      = pathmat_prop,
                        tparam            = tparam,
                        census_times      = census_times,
                        flow_matrix       = flow_matrix,
                        stoich_matrix     = stoich_matrix,
                        censusmat         = censusmat,
                        emitmat           = emitmat,
                        param_vec         = param_vec,
                        param_inds        = param_inds,
                        const_inds        = const_inds,
                        tcovar_inds       = tcovar_inds,
                        param_update_inds = param_update_inds,
                        initdist_inds     = initdist_inds,
                        census_indices    = census_indices,
                        event_inds        = event_inds,
                        measproc_indmat   = measproc_indmat,
                        forcing_inds      = forcing_inds,
                        forcing_tcov_inds = forcing_tcov_inds,
                        forcings_out      = forcings_out,
                        forcing_transfers = forcing_transfers,
                        proc_pointer      = proc_pointer,
                        d_meas_pointer    = d_meas_pointer,
                        set_pars_pointer  = set_pars_pointer,
                        do_prevalence     = do_prevalence,
                        step_size         = step_size,
                        svd_d             = svd_d,
                        svd_U             = svd_U,
                        svd_V             = svd_V)

                } else if(param_blocks[[ind]]$alg == "mvnss") {

                    # at the final adaptation iteration, freeze the (unscaled)
                    # kernel covariance, label it, and recompute its Cholesky factor
                    if (iter == min(max_adaptation, iterations)) {

                        # covariance matrix
                        param_blocks[[ind]]$sigma <- param_blocks[[ind]]$kernel_cov

                        colnames(param_blocks[[ind]]$sigma) <-
                            rownames(param_blocks[[ind]]$sigma) <-
                                param_blocks[[ind]]$param_names_est

                        # cholesky
                        comp_chol(param_blocks[[ind]]$kernel_cov_chol,
                                  param_blocks[[ind]]$sigma)

                    }

                    # sample new parameters
                    mvnss_update(
                        param_blocks      = param_blocks,
                        ind               = ind,
                        iter              = iter,
                        parmat            = parmat,
                        dat               = dat,
                        path              = path,
                        pathmat_prop      = pathmat_prop,
                        tparam            = tparam,
                        census_times      = census_times,
                        flow_matrix       = flow_matrix,
                        stoich_matrix     = stoich_matrix,
                        censusmat         = censusmat,
                        emitmat           = emitmat,
                        param_vec         = param_vec,
                        param_inds        = param_inds,
                        const_inds        = const_inds,
                        tcovar_inds       = tcovar_inds,
                        param_update_inds = param_update_inds,
                        initdist_inds     = initdist_inds,
                        census_indices    = census_indices,
                        event_inds        = event_inds,
                        measproc_indmat   = measproc_indmat,
                        forcing_inds      = forcing_inds,
                        forcing_tcov_inds = forcing_tcov_inds,
                        forcings_out      = forcings_out,
                        forcing_transfers = forcing_transfers,
                        proc_pointer      = proc_pointer,
                        d_meas_pointer    = d_meas_pointer,
                        set_pars_pointer  = set_pars_pointer,
                        do_prevalence     = do_prevalence,
                        step_size         = step_size,
                        svd_d             = svd_d,
                        svd_U             = svd_U,
                        svd_V             = svd_V)
                }
            }

            # update the initial compartment volumes (skipped when fixed, or when
            # joint_initdist_update hands them to lna_update below)
            if(!fixed_inits && !joint_initdist_update) {

                initdist_update(
                    path                 = path,
                    dat                  = dat,
                    iter                 = iter,
                    parmat               = parmat,
                    initdist_objects     = initdist_objects,
                    initdist_ess_control = initdist_ess_control,
                    tparam               = tparam,
                    pathmat_prop         = pathmat_prop,
                    censusmat            = censusmat,
                    draws_prop           = draws_prop,
                    ess_draws_prop       = ess_draws_prop,
                    emitmat              = emitmat,
                    flow_matrix          = flow_matrix,
                    stoich_matrix        = stoich_matrix,
                    census_times         = census_times,
                    forcing_inds         = forcing_inds,
                    forcing_tcov_inds    = forcing_tcov_inds,
                    forcings_out         = forcings_out,
                    forcing_transfers    = forcing_transfers,
                    param_vec            = param_vec,
                    param_inds           = param_inds,
                    const_inds           = const_inds,
                    tcovar_inds          = tcovar_inds,
                    initdist_inds        = initdist_inds,
                    param_update_inds    = param_update_inds,
                    census_indices       = census_indices,
                    event_inds           = event_inds,
                    measproc_indmat      = measproc_indmat,
                    svd_d                = svd_d,
                    svd_U                = svd_U,
                    svd_V                = svd_V,
                    proc_pointer         = proc_pointer,
                    set_pars_pointer     = set_pars_pointer,
                    d_meas_pointer       = d_meas_pointer,
                    do_prevalence        = do_prevalence,
                    step_size            = step_size
                )
            }

            # update the tparam draws
            if (!is.null(tparam)) {

                tparam_update(
                    path               = path,
                    dat                = dat,
                    iter               = iter,
                    parmat             = parmat,
                    tparam             = tparam,
                    tparam_ess_control = tparam_ess_control,
                    pathmat_prop       = pathmat_prop,
                    censusmat          = censusmat,
                    draws_prop         = draws_prop,
                    ess_draws_prop     = ess_draws_prop,
                    emitmat            = emitmat,
                    flow_matrix        = flow_matrix,
                    stoich_matrix      = stoich_matrix,
                    census_times       = census_times,
                    forcing_inds       = forcing_inds,
                    forcing_tcov_inds  = forcing_tcov_inds,
                    forcings_out       = forcings_out,
                    forcing_transfers  = forcing_transfers,
                    param_vec          = param_vec,
                    param_inds         = param_inds,
                    const_inds         = const_inds,
                    tcovar_inds        = tcovar_inds,
                    initdist_inds      = initdist_inds,
                    param_update_inds  = param_update_inds,
                    census_indices     = census_indices,
                    event_inds         = event_inds,
                    measproc_indmat    = measproc_indmat,
                    svd_d              = svd_d,
                    svd_U              = svd_U,
                    svd_V              = svd_V,
                    proc_pointer       = proc_pointer,
                    set_pars_pointer   = set_pars_pointer,
                    d_meas_pointer     = d_meas_pointer,
                    do_prevalence      = do_prevalence,
                    step_size          = step_size
                )
            }

            # Update the path via elliptical slice sampling
            if(method == "lna") {
                lna_update(
                    path                  = path,
                    dat                   = dat,
                    iter                  = iter,
                    parmat                = parmat,
                    lna_ess_schedule      = lna_ess_schedule,
                    lna_ess_control       = lna_ess_control,
                    initdist_objects      = initdist_objects,
                    tparam                = tparam,
                    pathmat_prop          = pathmat_prop,
                    censusmat             = censusmat,
                    draws_prop            = draws_prop,
                    ess_draws_prop        = ess_draws_prop,
                    emitmat               = emitmat,
                    flow_matrix           = flow_matrix,
                    stoich_matrix         = stoich_matrix,
                    census_times          = census_times,
                    forcing_inds          = forcing_inds,
                    forcing_tcov_inds     = forcing_tcov_inds,
                    forcings_out          = forcings_out,
                    forcing_transfers     = forcing_transfers,
                    param_vec             = param_vec,
                    param_inds            = param_inds,
                    const_inds            = const_inds,
                    tcovar_inds           = tcovar_inds,
                    initdist_inds         = initdist_inds,
                    param_update_inds     = param_update_inds,
                    census_indices        = census_indices,
                    event_inds            = event_inds,
                    measproc_indmat       = measproc_indmat,
                    svd_d                 = svd_d,
                    svd_U                 = svd_U,
                    svd_V                 = svd_V,
                    proc_pointer          = proc_pointer,
                    set_pars_pointer      = set_pars_pointer,
                    d_meas_pointer        = d_meas_pointer,
                    do_prevalence         = do_prevalence,
                    joint_initdist_update = joint_initdist_update,
                    step_size             = step_size
                )
            }

            # Save the MCMC sample if called for in this iteration. rec_ind and
            # ess_rec_ind are not advanced here; presumably the save helpers
            # increment them in place -- confirm.
            if(record_sample && iter %% thinning_interval == 0) {
                save_mcmc_sample(
                    mcmc_samples     = mcmc_samples,
                    rec_ind          = rec_ind,
                    path             = path,
                    parmat           = parmat,
                    param_blocks     = param_blocks,
                    initdist_objects = initdist_objects,
                    tparam           = tparam,
                    tparam_inds      = tparam_inds,
                    method           = method
                )

                if(return_ess_rec) {
                    save_ess_rec(
                        ess_record            = ess_record,
                        ess_rec_ind           = ess_rec_ind,
                        lna_ess_schedule      = lna_ess_schedule,
                        tparam                = tparam,
                        initdist_ess_control  = initdist_ess_control
                    )
                }
            }

            # start recording once the iteration after max_adaptation has run;
            # the first saved sample therefore comes no earlier than the next
            # multiple of thinning_interval
            if(!record_sample && iter == (max_adaptation+1)) {
                record_sample <- TRUE
            }

            # print status messages if called for
            if(print_progress && iter %% progress_interval == 0) {

                cat(paste0("Iteration: ",iter),
                    file = status_file,
                    sep = "\n",
                    append = TRUE)

                # per-block diagnostics; the acceptance rate is cumulative
                # acceptances divided by the current iteration number
                # NOTE(review): the contraction/expansion counts are reported
                # minus 0.5, presumably because those counters start at 0.5
                # elsewhere -- confirm.
                for(s in seq_along(param_blocks)) {
                    if(param_blocks[[s]]$alg == "mvnmh") {
                        cat(paste0("\t", "Parameter block: ", s),
                            paste0("\t", "\t", "Accepted proposals: ", param_blocks[[s]]$mvnmh_objects$acceptances),
                            paste0("\t", "\t", "Acceptance rate: ", param_blocks[[s]]$mvnmh_objects$acceptances / iter),
                            paste0("\t", "\t", "Proposal scaling: ", param_blocks[[s]]$mvnmh_objects$proposal_scaling),
                            file = status_file,
                            sep = "\n",
                            append = TRUE)
                    } else {
                        cat(paste0("\t", "Parameter block: ", s),
                            paste0("\t", "\t", "Contractions: ", param_blocks[[s]]$mvnss_objects$n_contractions - 0.5),
                            paste0("\t", "\t", "Expansions: ", param_blocks[[s]]$mvnss_objects$n_expansions - 0.5),
                            file = status_file,
                            sep = "\n",
                            append = TRUE)
                    }
                }
            }
        }

        # Wall-clock time at which sampling finished.
        finish_time <- Sys.time()

        # Results: total runtime in hours, the posterior storage, and (when
        # requested) the elliptical slice sampling record.
        results <- list(runtime   = difftime(finish_time, start.time, units = "hours"),
                        posterior = mcmc_samples)
        if (return_ess_rec) {
            results$ess_record <- ess_record
        }
        stem_object$results <- results

        # Write the final state back into the model so a later call can resume
        # from it. Values are read from row 1 of parmat; param_inds and
        # initdist_inds appear to be zero-based, hence the + 1.
        stem_object$dynamics$parameters <- parmat[1, param_inds + 1]
        if (!fixed_inits) {
            stem_object$dynamics$initdist_params <- parmat[1, initdist_inds + 1]
        }
        if (!is.null(tparam)) {
            stem_object$dynamics$tparam <- tparam
        }

        # Sampler state needed to restart the chain.
        stem_object$restart <- list(
            path             = path,
            param_blocks     = param_blocks,
            initdist_objects = initdist_objects,
            tparam           = tparam
        )

        # the updated model object is the function's return value
        stem_object
    }
# fintzij/stemr documentation built on March 25, 2022, 12:25 p.m.