R/stem_measure.R

Defines functions stem_measure

Documented in stem_measure

#' Generate a list of objects governing the measurement process for a stochastic
#' epidemic model.
#'
#' Each emission list is expanded into one measurement process per stratum (an
#' emission without strata yields a single process). Names of parameters,
#' time-varying covariates, constants, and compartments appearing in the
#' emission parameters are replaced by indices into the corresponding C++
#' vectors. Code strings for simulating from (rmeasure) and evaluating the
#' log-density of (dmeasure) each emission distribution are then generated and
#' compiled. Compilation targets the exact simulator and/or the approximate
#' (LNA/ODE) simulators, depending on what \code{dynamics} provides. Supported
#' emission distributions are poisson, negbinomial, binomial, betabinomial, and
#' gaussian; any other distribution raises an error.
#'
#' @param emissions list of emission lists, each generated by a call to the
#'   \code{\link{emission}} function.
#' @param dynamics processed list of objects governing the model dynamics,
#'   returned by the \code{\link{stem_dynamics}} function.
#' @param data matrix/data frame, or a list of matrices/data frames. All columns
#'   must be named according to which compartment_strata are measured. The first
#'   column must consist of observation times, t_1,...,t_L. If data on all
#'   measured compartments are accrued at the same observation times, a single
#'   matrix or data frame may be provided (a data frame is converted to a
#'   matrix). If each compartment in the data was measured at different
#'   observation times, a list of matrices or data frames must be provided.
#'   Again, the first column of each matrix must consist of observation times,
#'   while subsequent columns must be labeled according to which compartment is
#'   being measured.
#' @param messages should compilation messages be printed? Defaults to FALSE.
#'
#' @return list with evaluated measurement process functions and objects. The
#'   list contains the following objects: \describe{
#'   \item{data}{the supplied data, with a data frame converted to a matrix}
#'   \item{meas_procs}{list of measurement processes, one per emission and
#'   stratum, each holding the distribution, measurement variable, emission
#'   parameters, incidence flag, observation times, and the generated rmeasure
#'   and dmeasure code strings}
#'   \item{meas_pointers}{external pointers to compiled functions to simulate
#'   from and evaluate the density of the measurement process for the exact
#'   simulator, or NULL when the dynamics provide no lumped rate pointer}
#'   \item{meas_pointers_lna}{the same for the LNA/ODE approximations, or NULL
#'   when neither an LNA nor an ODE pointer is available}
#'   \item{obstimes}{complete vector of observation times}
#'   \item{obstime_inds}{list of indices (C++) of observation times for each of
#'   the measurement processes}
#'   \item{obsmat}{either a template for an observation matrix, or an
#'   observation matrix that combines the supplied list of observation
#'   matrices}
#'   \item{obscomp_codes}{named vector mapping measurement variable names to
#'   their column positions in measproc_indmat}
#'   \item{measproc_indmat}{indicator matrix for which measurement variables are
#'   measured at which observation times}
#'   \item{meas_inds}{row/column indices (C++) of elements in the observation
#'   matrix, excluding the time column, that are not NA}
#'   \item{censusmat}{template matrix for storing the compartment counts at
#'   observation times}
#'   \item{lna_incidence, ode_incidence}{indicator for whether any measurement
#'   process measures incidence}
#'   \item{lna_prevalence, ode_prevalence}{indicator for whether any
#'   measurement process measures prevalence}
#'   \item{incidence_codes_lna}{positions (1-based, as returned by
#'   \code{which}) of the rows of the LNA flow matrix whose names are columns of
#'   censusmat; NULL when no process measures incidence, -1 when the dynamics
#'   have no LNA pointers}
#'   \item{incidence_codes_ode}{the same for the ODE flow matrix}
#'   \item{meas_proc_code}{list holding the generated code, with elements
#'   exact_meas_code and/or approx_meas_code when the corresponding pointers
#'   were compiled}}
#' @export
stem_measure <- function(emissions, dynamics, data = NULL, messages = FALSE) {

  # ---- Local helpers ---------------------------------------------------------

  # Fields of every measurement process, in the order they are stored.
  proc_fields <- c("dmeasure", "rmeasure", "distribution", "meas_var",
                   "emission_params", "incidence", "obstimes")

  # Build one measurement process from an emission list. When `stratum` is
  # given, "SELF" in meas_var and in the emission parameters is replaced by the
  # stratum name.
  make_proc <- function(emission, stratum = NULL) {
    proc <- vector(mode = "list", length = length(proc_fields))
    names(proc) <- proc_fields

    proc$distribution <- emission$distribution
    proc$meas_var     <- emission$meas_var
    proc$incidence    <- emission$incidence
    proc$obstimes     <- emission$obstimes

    if (is.null(stratum)) {
      proc$emission_params <- emission$emission_params
    } else {
      proc$meas_var <- gsub("SELF", stratum, proc$meas_var)
      proc$emission_params <- sapply(emission$emission_params,
                                     gsub,
                                     pattern = "SELF",
                                     replacement = stratum)
    }

    proc
  }

  # Replace whole-word occurrences of each name in `codes` with
  # "<prefix>[<code>]", one name at a time in the order of `codes`.
  substitute_codes <- function(params, codes, prefix) {
    for (i in seq_along(codes)) {
      params <- sapply(params,
                       FUN = gsub,
                       pattern = paste0("\\<", names(codes)[i], "\\>"),
                       replacement = paste0(prefix, "[", codes[i], "]"))
    }
    params
  }

  # Apply the parameter, time-varying covariate, constant, and state
  # substitutions, in that order.
  substitute_all <- function(params, tcovar_codes, state_codes) {
    params <- substitute_codes(params, dynamics$param_codes, "parameters")
    params <- substitute_codes(params, tcovar_codes, "tcovar")
    params <- substitute_codes(params, dynamics$const_codes, "constants")
    substitute_codes(params, state_codes, "state")
  }

  # C++ code strings that simulate from (rmeasure) and evaluate the
  # log-density of (dmeasure) an emission distribution with argument strings
  # `params`.
  meas_code <- function(distribution, params) {
    args    <- paste(params, collapse = ",")
    wrapped <- paste(paste0("Rcpp::wrap(", params, ")"), collapse = ",")

    switch(
      as.character(distribution),
      poisson = list(
        rmeasure = paste0("Rcpp::rpois(1,", args, ")"),
        dmeasure = paste0("Rcpp::dpois(obs,", args, ",1)")
      ),
      negbinomial = list(
        rmeasure = paste0("Rcpp::rnbinom_mu(1,", args, ")"),
        dmeasure = paste0("Rcpp::dnbinom_mu(obs,", args, ",1)")
      ),
      binomial = list(
        rmeasure = paste0("Rcpp::rbinom(1,", args, ")"),
        dmeasure = paste0("Rcpp::dbinom(obs,", args, ",1)")
      ),
      betabinomial = list(
        rmeasure = paste0("extraDistr::cpp_rbbinom(1,", wrapped, ")"),
        dmeasure = paste0("extraDistr::cpp_dbbinom(obs,", wrapped, ",true)")
      ),
      gaussian = list(
        rmeasure = paste0("Rcpp::rnorm(1,", args, ")"),
        dmeasure = paste0("Rcpp::dnorm(obs,", args, ",true)")
      ),
      stop("Unsupported emission distribution '", distribution, "'. ",
           "Supported distributions are poisson, negbinomial, binomial, ",
           "betabinomial, and gaussian.", call. = FALSE)
    )
  }

  # Positions of the flow-matrix rows whose names are columns of the census
  # matrix. Returns -1 when the approximation is not available (`pointers` is
  # NULL) and NULL when no measurement process records incidence.
  incidence_rows <- function(pointers, flow_matrix, want_incidence, census_cols) {
    if (is.null(pointers)) return(-1)
    if (!want_incidence) return(NULL)
    rate_names <- rownames(flow_matrix)
    codes <- which(rate_names %in% census_cols)
    names(codes) <- rate_names[codes]
    codes
  }

  # ---- Validate observation times --------------------------------------------

  emission_times <- unlist(lapply(emissions, "[[", "obstimes"))

  if (is.null(data)) {
    missing_times <- vapply(emissions,
                            function(em) is.null(em[["obstimes"]]),
                            logical(1))
    if (any(missing_times)) {
      stop("If no dataset is provided, the observation times for each measurement process must be supplied in an emission list.")
    }

    if (max(emission_times) != dynamics$tmax) {
      warning("The tmax argument in dynamics is not equal to the maximum observation time.")
    }

  } else {
    if (is.data.frame(data)) data <- as.matrix(data)

    # Skip the emission-list checks when no emission supplies observation
    # times; max() of an empty vector would only produce -Inf and a warning.
    if (length(emission_times) > 0) {
      emission_tmax <- max(emission_times)

      if (emission_tmax != dynamics$tmax) {
        warning("The tmax argument in dynamics is not equal to the maximum observation time in the emission lists.")
      }

      if (is.matrix(data)) {
        if (emission_tmax != max(data[, 1])) {
          warning("The maximum observation time in the data object is not equal to the maximum observation time in the emission lists.")
        }

        # Compare numerically: identical() would flag integer vs double times.
        times_agree <- isTRUE(all.equal(as.numeric(sort(unique(emission_times))),
                                        as.numeric(sort(unique(data[, 1])))))
        if (!times_agree) {
          warning("The observation times in the emission lists do not concord with the observation times in the data object.")
        }
      }
    }

    # `data` is either one matrix or a list of datasets. is.list() is used
    # because class() of a matrix has length two, which `&&` rejects.
    data_tmax <- if (is.list(data)) {
      max(unlist(lapply(data, function(x) max(x[, 1]))))
    } else {
      max(data[, 1])
    }

    if (data_tmax != dynamics$tmax) {
      warning("The tmax argument in dynamics is not equal to the maximum observation time in the data object.")
    }
  }

  # Compile the measurement process for the exact simulator and/or for the
  # LNA/ODE approximations, depending on which pointers the dynamics provide.
  do_exact  <- !is.null(dynamics$rate_ptrs$lumped_ptr)
  do_approx <- !is.null(dynamics$lna_pointers$lna_ptr) ||
    !is.null(dynamics$ode_pointers$ode_ptr)

  # ---- Expand emissions into measurement processes ---------------------------

  # One list of processes per emission: a single process when no strata are
  # given, otherwise one per stratum ("ALL" means every stratum in dynamics).
  meas_procs <- vector(mode = "list", length = length(emissions))
  for (k in seq_along(emissions)) {
    emission <- emissions[[k]]
    strata   <- emission[["strata"]]

    if (is.null(strata)) {
      meas_procs[[k]] <- list(make_proc(emission))
    } else {
      if (identical(strata, "ALL")) strata <- names(dynamics$strata_codes)
      meas_procs[[k]] <- lapply(unname(strata),
                                function(stratum) make_proc(emission, stratum))
    }
  }
  meas_procs <- unlist(meas_procs, recursive = FALSE)

  # Pass each emission parameter expression through sub_powers(), then
  # deparse it back to a string without whitespace.
  for (k in seq_along(meas_procs)) {
    for (j in seq_along(meas_procs[[k]]$emission_params)) {
      rewritten <- paste0(
        deparse(sub_powers(parse(text = meas_procs[[k]]$emission_params[j]))[[1]]),
        collapse = "")
      meas_procs[[k]]$emission_params[j] <- gsub(" ", "", rewritten)
    }
  }

  # ---- Observation matrix ----------------------------------------------------

  if (is.null(data)) {
    # no dataset: build a template observation matrix from the processes
    obsmat <- build_obsmat(meas_procs = meas_procs)
  } else {
    # combine the supplied dataset(s) into one observation matrix
    obsmat <- build_obsmat(datasets = data)
  }
  obstimes        <- obsmat[, "time"]
  measproc_indmat <- build_measproc_indmat(obsmat = obsmat)
  # C++ (zero-based) row/column indices of the non-NA measurements
  meas_inds       <- which(!is.na(obsmat[, -1, drop = FALSE]), arr.ind = TRUE) - 1

  # ---- Substitute argument-vector indices ------------------------------------

  obscomp_codes <- seq_len(ncol(measproc_indmat))
  names(obscomp_codes) <- colnames(measproc_indmat)

  # Every measurement variable must name a column of the observation matrix;
  # otherwise the generated code would index record[NA].
  meas_vars <- unlist(lapply(meas_procs, "[[", "meas_var"))
  unmatched <- meas_vars[!meas_vars %in% names(obscomp_codes)]
  if (length(unmatched) > 0) {
    stop("The following measurement variables do not match any column of the observation matrix: ",
         paste(unique(unmatched), collapse = ", "), call. = FALSE)
  }

  # copy of the processes for the LNA/ODE code
  meas_procs_lna <- meas_procs

  # The LNA/ODE code uses tcovar codes shifted down by one (presumably its
  # covariate vector excludes the time column -- TODO confirm).
  tcovar_codes_lna <- dynamics$tcovar_codes - 1

  for (s in seq_along(meas_procs)) {

    # meas_var becomes the location in a row vector of observations
    record_index <- paste0("record[", obscomp_codes[meas_procs[[s]]$meas_var], "]")
    meas_procs[[s]]$meas_var     <- record_index
    meas_procs_lna[[s]]$meas_var <- record_index

    # incidence processes read incidence counts, others read compartments; the
    # state index is offset by one (as in the original implementation)
    state_codes <- if (meas_procs[[s]]$incidence) {
      dynamics$incidence_codes
    } else {
      dynamics$comp_codes
    }
    state_codes <- state_codes + 1

    meas_procs[[s]]$emission_params <-
      substitute_all(meas_procs[[s]]$emission_params,
                     dynamics$tcovar_codes, state_codes)

    meas_procs_lna[[s]]$emission_params <-
      substitute_all(meas_procs_lna[[s]]$emission_params,
                     tcovar_codes_lna, state_codes)
  }

  # ---- Generate and compile rmeasure/dmeasure --------------------------------

  for (k in seq_along(meas_procs)) {
    exact_code <- meas_code(meas_procs[[k]]$distribution,
                            meas_procs[[k]]$emission_params)
    approx_code <- meas_code(meas_procs_lna[[k]]$distribution,
                             meas_procs_lna[[k]]$emission_params)

    meas_procs[[k]]$rmeasure     <- exact_code$rmeasure
    meas_procs[[k]]$dmeasure     <- exact_code$dmeasure
    meas_procs_lna[[k]]$rmeasure <- approx_code$rmeasure
    meas_procs_lna[[k]]$dmeasure <- approx_code$dmeasure
  }

  meas_pointers <- if (do_exact) {
    parse_meas_procs(meas_procs, messages = messages)
  } else {
    NULL
  }

  meas_pointers_lna <- if (do_approx) {
    parse_meas_procs(meas_procs_lna, messages = messages)
  } else {
    NULL
  }

  # move the generated code out of the pointer objects
  meas_proc_code <- list()

  if (do_exact) {
    meas_proc_code$exact_meas_code <- meas_pointers$meas_proc_code
    meas_pointers$meas_proc_code   <- NULL
  }

  if (do_approx) {
    meas_proc_code$approx_meas_code  <- meas_pointers_lna$meas_proc_code
    meas_pointers_lna$meas_proc_code <- NULL
  }

  # ---- Observation-time bookkeeping ------------------------------------------

  # Mark which observation times belong to each process, rounding exactly as
  # obstime_inds does below so the two agree.
  # NOTE(review): column s is assumed to correspond to meas_procs[[s]]; confirm
  # this holds when data columns are ordered differently from the emissions.
  for (s in seq_along(meas_procs)) {
    measproc_indmat[, s] <- round(obstimes, digits = 8) %in%
      round(meas_procs[[s]]$obstimes, digits = 8)
  }

  # matrix for storing the compartment counts at observation times
  censusmat <- matrix(0.0, nrow = length(obstimes),
                      ncol = length(dynamics$comp_codes) + length(dynamics$incidence_codes) + 1)
  colnames(censusmat) <- c("time", names(dynamics$comp_codes), names(dynamics$incidence_codes))
  censusmat[, "time"] <- obstimes

  # C++ (zero-based) positions of each process's observation times in obstimes
  obstime_inds <-
    lapply(meas_procs,
           FUN = function(proc)
             match(round(proc$obstimes, digits = 8),
                   round(obstimes, digits = 8)) - 1)

  # Prevalence is needed if any process is not an incidence process; incidence
  # is needed if any process is.
  incidence_flags <- unlist(lapply(meas_procs, "[[", "incidence"))
  lna_prevalence  <- any(!incidence_flags)
  lna_incidence   <- any(incidence_flags)

  incidence_codes_lna <- incidence_rows(dynamics$lna_pointers,
                                        dynamics$flow_matrix_lna,
                                        lna_incidence,
                                        colnames(censusmat))

  incidence_codes_ode <- incidence_rows(dynamics$ode_pointers,
                                        dynamics$flow_matrix_ode,
                                        lna_incidence,
                                        colnames(censusmat))

  list(data                = data,
       meas_procs          = meas_procs,
       meas_pointers       = meas_pointers,
       meas_pointers_lna   = meas_pointers_lna,
       obstimes            = obstimes,
       obstime_inds        = obstime_inds,
       obsmat              = obsmat,
       obscomp_codes       = obscomp_codes,
       measproc_indmat     = measproc_indmat,
       meas_inds           = meas_inds,
       censusmat           = censusmat,
       lna_incidence       = lna_incidence,
       lna_prevalence      = lna_prevalence,
       ode_incidence       = lna_incidence,
       ode_prevalence      = lna_prevalence,
       incidence_codes_lna = incidence_codes_lna,
       incidence_codes_ode = incidence_codes_ode,
       meas_proc_code      = meas_proc_code)
}
fintzij/stemr documentation built on March 25, 2022, 12:25 p.m.