R/tempo_mcmc.R

Defines functions tempo_mcmc

Documented in tempo_mcmc

#' tempo_mcmc
#'
#' MCMC sampler for the discrete-time binary state-transition model.
#'
#' @param y dataframe; A dataframe with 3 columns: \code{obs_id}
#' containing a unique ID for each observation; \code{c1}, the time step of the
#' last observed 0 (or \code{NA} if a 0 was never observed); and \code{c2}, the
#' time step of the first observed 1 (or \code{NA} if a 1 was never observed).
#' In the case of exact observations, e.g. it is known that the event occurred
#' on day 42, \code{c1} should be 41, and \code{c2} should be 42.
#' @param covariates list; A list of \code{c} covariate matrices, each of
#' dimension \code{[n,t]}, where n is the number of observations, t is the
#' number of time steps during which state transition could take place, and c
#' is the number of covariates. Use the function
#' \code{tempo_wrangle} provided in this package to convert long form
#' covariate data to the format necessary for input into \code{tempo_mcmc}.
#' @param beta_start vector; a vector of starting values for model regression
#' coefficients of length \code{c + 1}, where \code{c} is the number of
#' covariates. The first value of \code{beta_start} is for the intercept, and
#' the remaining \code{c} values correspond to the covariates in the same order
#' as \code{covariates}.
#' @param n_iter integer; The number of MCMC iterations per chain after any burn
#' in. Defaults to 1000.
#' @param n_burnin integer; The number of samples per chain to discard as burn
#' in. Defaults to 1000.
#' @param n_chains integer; the number of MCMC chains to sample. Defaults to 3.
#' @param n_workers integer; the number of parallel workers to use. Defaults to
#' 1.
#' @param pp_checks Character vector; one or both of "mean", "variance".
#' Specifies the statistics for which posterior predictive checks should be run.
#' Set to \code{NA} to prevent posterior predictive checks from being
#' calculated.
#' @param beta_prior_mean vector; Prior means for model regression coefficients
#' in the same order as beta_start. If \code{NULL} (the default) then prior
#' means are set to 0.
#' @param beta_prior_sd vector; Prior standard deviations for model
#' regression coefficients in the same order as beta_start. If \code{NULL}
#' (the default) then prior standard deviations are set to 5. Note that
#' standard deviations that are too restrictive (including the somewhat standard
#' value of 1.5) can influence model fit significantly because posterior
#' parameter magnitudes can be very large.
#' @param random_effects vector; a vector of the names of covariates that
#' should be modeled with group-level (random) effects.
#' Names correspond to the names of the list elements of \code{covariates}. Use
#' "intercept" to specify if the intercept should be modeled with random
#' effects. For example, \code{random_effects = c("intercept", "precipitation",
#' "temperature")} would specify that random effects should be modeled on the
#' intercept and "precipitation" and "temperature" covariates. Defaults to
#' \code{NULL} in which case no random effects are modeled. An error is raised
#' if a name does not match any covariate.
#' @param group_ids vector; If \code{random_effects != NULL}, indices specifying
#' to which group each row in \code{y} belongs. Defaults to \code{NULL}.
#' @param sd_eps_start vector; a vector of starting values for random effect
#' standard deviations. Values correspond to the items in \code{random_effects}.
#' May be set to \code{NULL} if \code{random_effects = NULL}.
#' @param correlated Boolean; Should correlations between random effects be
#' modeled? Only applies if there is more than one random effect modeled.
#' Defaults to \code{TRUE}.
#' @param monitor_random_effects Boolean; Defaults to \code{FALSE}. Return
#' posterior samples for group effects (model parameter \code{eps})?
#' Only applies when modeling random effects; ignored otherwise.
#' @return A list with one data frame of posterior samples per chain, named
#' \code{chain_1}, \code{chain_2}, etc. Columns are ordered as \code{beta_*},
#' \code{mean_beta_*}, \code{sd_eps_*}, \code{rho_*} and \code{eps_*} (only
#' those that were monitored are present). The list carries the attributes
#' \code{n_chains}, \code{covariates}, \code{y}, \code{ppp_mean},
#' \code{ppp_variance}, \code{rand_eff_bool}, \code{rand_eff_ind},
#' \code{correlated}, \code{monitor_random_effects} and \code{group_ids} for
#' use by later functions. \code{ppp_mean}/\code{ppp_variance} are the string
#' "not calculated" when the corresponding check was not requested.
#' @import nimble
#' @importFrom parallel makeCluster stopCluster
#' @importFrom doParallel registerDoParallel
#' @importFrom foreach "%dopar%" foreach
#' @importFrom stringr str_replace regex str_sub str_detect str_split
#' @rdname tempo_mcmc
#' @export
tempo_mcmc <- function(y,
                       covariates,
                       beta_start,
                       n_iter = 1000,
                       n_burnin = 1000,
                       n_chains = 3,
                       n_workers = 1,
                       pp_checks = c("mean", "variance"),
                       beta_prior_mean = NULL,
                       beta_prior_sd = NULL,
                       random_effects = NULL,
                       group_ids = NULL,
                       sd_eps_start = NULL,
                       correlated = TRUE,
                       monitor_random_effects = FALSE
                       ) {
  if (is.null(beta_prior_mean)) {
    beta_prior_mean <- rep(0, length(covariates) + 1)
  }
  if (is.null(beta_prior_sd)) {
    beta_prior_sd <- rep(5, length(covariates) + 1)
  }
  # An empty vector of random effects means "no random effects".
  if (length(random_effects) == 0) {
    random_effects <- NULL
  }

  # Prepend an all-ones intercept term with the same shape as the covariates.
  covariates <- c(list(intercept = array(1, dim = dim(covariates[[1]]))),
                  covariates)

  # Indices (into covariates, intercept first) of terms with random effects.
  if (!is.null(random_effects)) {
    unknown_effects <- setdiff(random_effects, names(covariates))
    if (length(unknown_effects) > 0) {
      stop("Unknown name(s) in random_effects: ",
           paste(unknown_effects, collapse = ", "),
           ". Valid names are: ",
           paste(names(covariates), collapse = ", "),
           call. = FALSE)
    }
    # TODO double check influence of order of random_effects elements
    beta_random_indices <- seq_along(covariates)[names(covariates)
                                                 %in% random_effects]
  } else {
    beta_random_indices <- NULL
  }

  # Stack the covariate matrices along a third dimension.
  covariate_array <- do.call(abind, c(covariates, along = 3))
  dimnames(covariate_array) <- list(rownames(covariates[[2]]),
                                    NULL,
                                    names(covariates))
  # Missing c1 (no 0 observed) becomes 0; missing c2 (no 1 observed) becomes
  # one past the last time step.
  y$c1[is.na(y$c1)] <- 0
  y$c2[is.na(y$c2)] <- dim(covariate_array)[2] + 1

  # get model inputs
  inputs <- nimble_inputs(y, covariate_array, beta_prior_mean, beta_prior_sd,
                          beta_start, beta_random_indices, group_ids,
                          sd_eps_start, correlated, pp_checks)

  # Parameters to monitor depend on the random-effects structure.
  if (is.null(random_effects)) {
    monitors <- c("Beta")
  } else {
    monitors <- c("mean_beta", "sd_eps")
    # Correlations exist only for correlated models with > 1 random effect.
    if (correlated && length(beta_random_indices) > 1) {
      monitors <- c(monitors, "rho")
    }
    if (monitor_random_effects) {
      monitors <- c(monitors, "eps")
    }
  }

  # Posterior predictive statistics are monitored as well; they are summarised
  # into p-values and stripped from the returned samples below.
  if ("mean" %in% pp_checks) {
    monitors <- c(monitors, "p_mean")
  }
  if ("variance" %in% pp_checks) {
    monitors <- c(monitors, "p_var")
  }

  # Parallel or serial processing
  if (n_workers > 1) {
    cat(
      sprintf(
        "Starting up nimble to run %s chain(s) in parallel on %s workers. Progress bars not shown in parallel mode...\n",
        n_chains,
        n_workers)
    )
    message("Compiling models and running samplers in parallel. This may take a while...")
    cl <- makeCluster(n_workers)
    # Shut the workers down even if model building or sampling fails.
    on.exit(stopCluster(cl), add = TRUE)
    registerDoParallel(cl)
    # No .combine: foreach then always returns a list with one element per
    # chain, including when n_chains == 1 (a combine function is never called
    # for a single task).
    out <- foreach(i = seq_len(n_chains),
                   .packages = c("nimble", "tempo")) %dopar% {

      suppressWarnings(registerDistributions(list(
        dtvgeom_nimble = list(
          BUGSdist = "dtvgeom_nimble(prob)",
          pqAvail = TRUE,
          discrete = TRUE,
          range = c(1, Inf),
          types = c("prob = double(1)")
        )), verbose = FALSE, userEnv = .GlobalEnv))

      # nimble_model() yields the model code as text, parsed here.
      pheno_model <- suppressMessages(
        nimbleModel(code = eval(parse(text = nimble_model(beta_random_indices,
                                                          correlated))),
                    name = "pheno",
                    constants = inputs$pheno_consts,
                    data = inputs$pheno_data,
                    inits = inputs$pheno_inits # TODO: [[i]] # different inits for each chain
        )
      )

      mcmc_cfg <- configureMCMC(pheno_model,
                                monitors = monitors)

      pheno_mcmc <- buildMCMC(mcmc_cfg)

      compiled_mcmc <- suppressMessages(compileNimble(pheno_model,
                                                      pheno_mcmc))

      runMCMC(compiled_mcmc$pheno_mcmc,
              niter = n_iter + n_burnin,
              nburnin = n_burnin)
    }
    names(out) <- paste0("chain_", seq_len(n_chains))
  } else {
    cat(sprintf("Starting up nimble to run %s chain(s) in serial\n",
                n_chains))
    message("Building and compiling model. This may take a few minutes...")

    suppressWarnings(registerDistributions(list(
      dtvgeom_nimble = list(
        BUGSdist = "dtvgeom_nimble(prob)",
        pqAvail = TRUE,
        discrete = TRUE,
        range = c(1, Inf),
        types = c("prob = double(1)")
      )), verbose = FALSE, userEnv = .GlobalEnv))

    # nimble_model() yields the model code as text, parsed here.
    pheno_model <- suppressMessages(
      nimbleModel(code = eval(parse(text = nimble_model(beta_random_indices,
                                                        correlated))),
                  name = "pheno",
                  constants = inputs$pheno_consts,
                  data = inputs$pheno_data,
                  inits = inputs$pheno_inits
      )
    )

    mcmc_cfg <- configureMCMC(pheno_model,
                              monitors = monitors)

    pheno_mcmc <- buildMCMC(mcmc_cfg)

    compiled_mcmc <- suppressMessages(compileNimble(pheno_model, pheno_mcmc))

    out <- runMCMC(compiled_mcmc$pheno_mcmc,
                   niter = n_iter + n_burnin,
                   nburnin = n_burnin,
                   nchains = n_chains)

    # A single chain comes back as a bare matrix; wrap it for uniformity.
    if (n_chains == 1) {
      out <- list(out)
    }
    names(out) <- paste0("chain_", seq_len(n_chains))
  }

  # Bayesian posterior predictive p-values, pooled over all chains.
  ppp_mean <- "not calculated"
  ppp_var <- "not calculated"
  if (any(c("variance", "mean") %in% pp_checks)) {
    draws_all <- do.call(rbind, out)
    if ("mean" %in% pp_checks) {
      ppp_mean <- mean(draws_all[, "p_mean"])
    }
    if ("variance" %in% pp_checks) {
      ppp_var <- mean(draws_all[, "p_var"])
    }
  }

  # Remove posterior predictive check columns from the samples. Columns that
  # were not monitored are simply absent, so both names can be dropped safely.
  pp_columns <- c("p_mean", "p_var")
  for (ch in seq_along(out)) {
    keep <- !colnames(out[[ch]]) %in% pp_columns
    out[[ch]] <- out[[ch]][, keep, drop = FALSE]
  }

  # Rename output columns chain by chain.
  for (d in seq_along(out)) {
    out[[d]] <- as.data.frame(out[[d]])

    ## Rename betas (the pattern matches "Beta[...]" and "mean_beta[...]")
    beta_columns <- str_detect(colnames(out[[d]]),
                               pattern = regex("beta\\[.+?\\]",
                                               ignore_case = TRUE))

    if (is.null(random_effects)) {
      colnames(out[[d]])[beta_columns] <-
        paste0("beta_", names(covariates))
    } else {
      # The ones with random effects
      colnames(out[[d]])[beta_columns][beta_random_indices] <-
        paste0("mean_beta_", names(covariates)[beta_random_indices])
      # The ones without random effects
      colnames(out[[d]])[beta_columns][-beta_random_indices] <-
        paste0("beta_", names(covariates)[-beta_random_indices])

      ## Rename standard deviations
      sd_eps_columns <- str_detect(colnames(out[[d]]), pattern = "^sd_eps")
      colnames(out[[d]])[sd_eps_columns] <-
        paste("sd_eps", names(covariates)[beta_random_indices], sep = "_")
    }

    # Rename rhos if needed; rho_indices columns give the pair of random
    # effects (positions within beta_random_indices) for each correlation.
    if (correlated && length(beta_random_indices) > 1) {
      random_names <- names(covariates)[beta_random_indices]
      rho_indices <- inputs$pheno_consts$rho_indices
      rho_names <- paste("rho",
                         random_names[rho_indices[, 1]],
                         random_names[rho_indices[, 2]],
                         sep = "_")
      rho_columns <- str_detect(colnames(out[[d]]), pattern = "rho\\[.+?\\]")
      colnames(out[[d]])[rho_columns] <- rho_names
    }

    # eps is only monitored when random effects are modeled.
    if (monitor_random_effects && !is.null(random_effects)) {
      eps_columns <- str_detect(names(out[[d]]), pattern = "^eps")
      eps_names_old <- names(out[[d]])[eps_columns]
      # "eps[k, g]" -> c("k", "g"): random-effect index k, group index g
      all_indices <- do.call(rbind,
                             str_split(str_sub(eps_names_old, 5, -2),
                                       pattern = ", "))

      covariate_indices <- as.numeric(all_indices[, 1])
      group_indices <- as.numeric(all_indices[, 2])

      colnames(out[[d]])[eps_columns] <- paste0(
        "eps_",
        names(covariates)[beta_random_indices[covariate_indices]],
        "_",
        group_indices
      )
    }

    # Reorder the columns sensibly (runMCMC puts them in alphabetical order).
    # drop = FALSE keeps single matching columns as named data frame columns.
    column_names <- names(out[[d]])
    out[[d]] <- cbind(
      out[[d]][, str_detect(column_names, pattern = "^beta"), drop = FALSE],
      out[[d]][, str_detect(column_names, pattern = "^mean"), drop = FALSE],
      out[[d]][, str_detect(column_names, pattern = "^sd"), drop = FALSE],
      out[[d]][, str_detect(column_names, pattern = "^rho"), drop = FALSE],
      out[[d]][, str_detect(column_names, pattern = "^eps"), drop = FALSE]
    )
  }

  ## Assign attributes to out for later functions
  # general attributes
  attr(out, "n_chains") <- n_chains
  attr(out, "covariates") <- covariate_array
  attr(out, "y") <- y
  attr(out, "ppp_mean") <- ppp_mean
  attr(out, "ppp_variance") <- ppp_var

  # specific to random effects
  if (!is.null(random_effects)) {
    attr(out, "rand_eff_bool") <- TRUE
    attr(out, "rand_eff_ind") <- beta_random_indices
    attr(out, "correlated") <- correlated && (length(beta_random_indices) > 1)
    attr(out, "monitor_random_effects") <- monitor_random_effects
    attr(out, "group_ids") <- group_ids
  } else {
    attr(out, "monitor_random_effects") <- FALSE
    attr(out, "correlated") <- FALSE
    attr(out, "rand_eff_bool") <- FALSE
    attr(out, "rand_eff_ind") <- NA
    attr(out, "group_ids") <- NA
  }

  out
}
vlandau/tempo documentation built on March 18, 2020, 12:04 a.m.