R/cmdstan_momentum_draws.R

Defines functions cmdstan_momentum_draws

Documented in cmdstan_momentum_draws

#' Get momentum draws from a cmdstanr fit
#'
#' Extracts the post-warmup momentum draws from a cmdstanr MCMC fit that was
#' run with \code{save_latent_dynamics = TRUE}, by reading the fit's latent
#' dynamics CSV files with \code{cmdstanr::read_cmdstan_csv()}.
#'
#' Column positions are computed under the assumption that the post-warmup
#' draws read from the latent dynamics files have \code{3 * N + 4} columns,
#' where \code{N} is the dimension of the unconstrained space, and that the
#' momenta occupy columns \code{N + 2} to \code{2 * N + 1}. Presumably the
#' layout is \code{lp__}, then the \code{N} parameters, \code{N} momenta and
#' \code{N} gradients, then the \code{.chain}, \code{.iteration} and
#' \code{.draw} meta columns added by \code{posterior}. NOTE(review): verify
#' this layout against the cmdstanr version in use.
#'
#' An error is signalled if \code{fit} does not inherit from
#' \code{CmdStanMCMC}, or if the number of columns does not fit the layout
#' above.
#'
#' @export
#' @param fit A cmdstanr MCMC fit object (class \code{CmdStanMCMC}) run with
#'   \code{save_latent_dynamics = TRUE}.
#' @return A data frame (built with \code{posterior::as_draws_df()}) holding
#'   the \code{N} momentum columns followed by the \code{.chain},
#'   \code{.iteration} and \code{.draw} columns.
cmdstan_momentum_draws <- function(fit) {
  # inherits() instead of comparing the whole class vector, so that fits whose
  # class vector contains additional entries are still accepted.
  if (!inherits(fit, "CmdStanMCMC")) {
    stop(
      "fit argument should be an MCMC fit object from cmdstanr",
      call. = FALSE
    )
  }

  latent_dynamics_files <- fit$latent_dynamics_files()
  latent_dynamics <- cmdstanr::read_cmdstan_csv(latent_dynamics_files)
  draws_df <- latent_dynamics$post_warmup_draws %>%
    posterior::as_draws_df()

  # Columns that are not parameter/momentum/gradient values: the three
  # posterior meta columns plus (presumably) lp__.
  n_cols <- ncol(draws_df)
  n_other <- 4L
  # Integer arithmetic, so no floating-point comparison is needed to check
  # that the value columns split evenly into three groups of size N.
  if (n_cols < n_other || (n_cols - n_other) %% 3L != 0L) {
    stop(
      "Unexpected column layout in latent dynamics draws: ", n_cols,
      " columns do not match the expected 3 * N + 4 layout.",
      call. = FALSE
    )
  }

  # Dimension of unconstrained space
  N <- (n_cols - n_other) %/% 3L
  # Momentum columns are N + 2, ..., 2 * N + 1. seq_len() keeps this empty
  # when N == 0, whereas (N + 2):(2 * N + 1) would yield c(2, 1).
  inds <- N + 1L + seq_len(N)

  meta_cols <- c(".chain", ".iteration", ".draw")
  draws_df %>%
    dplyr::select(dplyr::all_of(inds), dplyr::all_of(meta_cols))
}
jtimonen/stanbreaker documentation built on Jan. 20, 2021, 12:34 a.m.