R/extract.R

Defines functions extract_tidy arr_to_tbl extract_div_tidy extract_max_tree_tidy

Documented in extract_div_tidy extract_max_tree_tidy extract_tidy

#' Extract posterior draws from a stanfit
#'
#' Extracts the posterior draws excluding warmup and formats them into a
#' tidy tibble. Draws are kept in their original (unpermuted) order:
#' the post-warmup iterations of chain 1 first, then those of chain 2, and
#' so on. Unlike `rstan::extract()` with its default `permuted = TRUE`, this
#' keeps the draw index `s` aligned with the chain-ordered sampler
#' diagnostics, so divergence / treedepth information can be joined on `s`.
#'
#' @param f An object of class `stanfit`.
#' @param ... The names of the parameters to be extracted. If empty, all the
#'   parameters are extracted.
#' @param with_div Should divergence information be included? Defaults to
#'   `FALSE`.
#' @param with_max_tree Should max treedepth information be included?
#'   Defaults to `FALSE`.
#'
#' @return A tibble with `d+3` columns where `d` is the maximum dimensionality
#' of the extracted parameters (plus `div` and/or `max_tree` when requested).
#' * `s` The index of the draw.
#' * `d1:dn` The indices into the parameter's dimensions. `NA` if the
#'   parameter is of lower dimension.
#' * `par` The name of the parameter, factor.
#' * `val` The value of the parameter.
#'
#' @export
extract_tidy <- function(f, ..., with_div = FALSE, with_max_tree = FALSE) {
  par_names <- enquos(...) %>% map_chr(as_name)
  model_pars <- names(f@sim$dims_oi)

  if (is_empty(par_names)) par_names <- model_pars
  unknown <- setdiff(par_names, model_pars)
  if (length(unknown) > 0L) {
    stop("Unknown parameters specified: ", str_c(unknown, collapse = ", "),
         call. = FALSE)
  }

  # Draws must stay in chain order (not rstan's permuted order) so that `s`
  # refers to the same iteration as in the diagnostics joined below.
  sample <- extract_unpermuted(f, par_names) %>%
    map(arr_to_tbl) %>%
    bind_rows(.id = "par") %>%
    mutate_at("par", factor) %>%
    select("s", matches("^d\\d+$"), "par", "val")
  if (with_div) sample <- inner_join(sample, extract_div_tidy(f), "s")
  if (with_max_tree) {
    sample <- inner_join(sample, extract_max_tree_tidy(f), "s")
  }
  sample
}

# Extract the post-warmup draws of `par_names` in chain order.
#
# Returns a list named by `par_names` with one array per parameter, shaped
# like the elements of `rstan::extract(f, par_names)`: the first dimension
# indexes the draw, the remaining dimensions are `f@sim$dims_oi[[par]]`.
# The draws are NOT permuted; draw `i` is iteration `i` when the chains are
# concatenated one after another.
extract_unpermuted <- function(f, par_names) {
  # iterations x chains x flattened parameters (names like "beta[1,2]");
  # `rstan::` avoids masking by tidyr::extract().
  draws <- rstan::extract(f, par_names, permuted = FALSE)
  n_draws <- dim(draws)[1L] * dim(draws)[2L]
  base_names <- str_remove(dimnames(draws)[[3L]], "\\[.*$")
  out <- map(par_names, function(p) {
    # Column-major order: iteration fastest, then chain, then the
    # parameter's flattened (column-major) elements.
    arr <- array(draws[, , base_names == p],
                 dim = c(n_draws, f@sim$dims_oi[[p]]))
    # Same dimnames as the arrays returned by rstan::extract(permuted = TRUE)
    dimnames(arr) <- list(iterations = NULL)
    arr
  })
  names(out) <- par_names
  out
}

# Convert the draws of a single parameter into a long tibble.
#
# `arr` is a vector or array whose first dimension indexes the draw; any
# remaining dimensions are the parameter's own dimensions. Returns a tibble
# with columns `s` (draw index), `d1`..`dk` (integer indices into the
# parameter; only present when it has dimensions) and `val`. Rows follow the
# column-major order of `arr`, i.e. `s` varies fastest.
arr_to_tbl <- function(arr) {
  dims <- dim(arr)
  nd <- length(dims)
  # Scalar parameter: a plain vector or a 1-d array of draws.
  if (nd <= 1L) return(tibble(s = seq_along(arr), val = as.vector(arr)))
  # Compute the index tuple of every cell directly rather than parsing
  # column names generated by as.data.frame()/as_tibble(), whose format
  # ("V3" vs "2.3") depends on the array rank and its dimnames.
  idx <- arrayInd(seq_along(arr), dims)
  colnames(idx) <- c("s", str_c("d", seq_len(nd - 1L)))
  mutate(as_tibble(idx), val = as.vector(arr))
}


#' Extract divergences from a stanfit
#'
#' Extracts, for each post-warmup draw, whether the sampler reported a
#' divergent transition, and formats the results into a tidy tibble. The
#' draw index `s` follows the order of `get_divergent_iterations()`.
#'
#' @param f An object of class `stanfit`.
#'
#' @return A tibble with `2` columns:
#' * `s` The index of the draw.
#' * `div` Boolean value of whether the draw was divergent or not.
#'
#' @export
extract_div_tidy <- function(f) {
  div <- get_divergent_iterations(f)
  tibble(s = seq_along(div), div = div)
}

#' Extract max treedepth iterations from a stanfit
#'
#' Extracts, for each post-warmup draw, whether the sampler hit the maximum
#' treedepth, and formats the results into a tidy tibble. The draw index `s`
#' follows the order of `get_max_treedepth_iterations()`.
#'
#' @param f An object of class `stanfit`.
#'
#' @return A tibble with `2` columns:
#' * `s` The index of the draw.
#' * `max_tree` Boolean value of whether the maximum treedepth was saturated
#'   or not.
#'
#' @export
extract_max_tree_tidy <- function(f) {
  max_tree <- get_max_treedepth_iterations(f)
  tibble(s = seq_along(max_tree), max_tree = max_tree)
}
paasim/stanutils documentation built on July 19, 2019, 12:47 a.m.