#' Extract posterior draws from a stanfit
#'
#' Extracts the posterior draws excluding warmup and formats them into a tidy
#' tibble. Draws are kept in sampling order (all iterations of chain 1, then
#' chain 2, ...) rather than in the randomly permuted order that
#' `rstan::extract()` uses by default. As a result, the draw index `s` refers
#' to the same iteration as in `extract_div_tidy()` and
#' `extract_max_tree_tidy()`.
#'
#' @param f An object of class `stanfit`.
#' @param ... The names of the parameters to be extracted. If empty, all the
#'   parameters are extracted.
#' @param with_div Should divergence information be included? Defaults to
#'   `FALSE`.
#' @param with_max_tree Should max treedepth information be included?
#'   Defaults to `FALSE`.
#'
#' @return A tibble with `d+3` columns where `d` is the maximum dimensionality
#'   of the extracted parameters. A `div` and/or `max_tree` column is added
#'   when `with_div` and/or `with_max_tree` is `TRUE`.
#' * `s` The index of the draw.
#' * `d1:dn` The dimensions of the parameters. `NA` if the parameter is of
#'   lower dimension.
#' * `par` The name of the parameter, factor.
#' * `val` The value of the parameter.
#'
#' @export
extract_tidy <- function(f, ..., with_div = FALSE, with_max_tree = FALSE) {
  par_names <- enquos(...) %>% map_chr(as_name)
  model_pars <- names(f@sim$dims_oi)
  if (is_empty(par_names)) par_names <- model_pars
  if (any(!par_names %in% model_pars)) stop("Unknown parameters specified.")
  # Unpermuted draws, so that row `s` matches the sampler diagnostics.
  draw_arrays <- map(par_names, function(par) extract_ordered_draws(f, par))
  names(draw_arrays) <- par_names
  draws <- draw_arrays %>%
    map(arr_to_tbl) %>%
    bind_rows(.id = "par") %>%
    mutate_at("par", factor) %>%
    select("s", matches("^d\\d+$"), "par", "val")
  if (with_div) {
    draws <- inner_join(draws, extract_div_tidy(f), by = "s")
  }
  if (with_max_tree) {
    draws <- inner_join(draws, extract_max_tree_tidy(f), by = "s")
  }
  draws
}

# Post-warmup draws of a single parameter. The result is shaped like one
# element of `rstan::extract(permuted = TRUE)`: draws along the first
# dimension, followed by the parameter's own dimensions. Unlike that result,
# the draws are NOT randomly permuted. Rows run through all iterations of
# chain 1, then chain 2, and so on.
extract_ordered_draws <- function(f, par) {
  # Array of iterations x chains x flattened parameter elements.
  flat <- rstan::extract(f, pars = par, permuted = FALSE)
  n_draws <- dim(flat)[1L] * dim(flat)[2L]
  # The column-major reshape stacks the chains one after another and spreads
  # the flattened elements back over the parameter's dimensions. This assumes
  # rstan's column-major flattening of parameter elements.
  out <- array(as.vector(flat), dim = c(n_draws, f@sim$dims_oi[[par]]))
  # Mirror the dimnames set by the permuted extract so downstream code sees
  # the same structure.
  dimnames(out) <- list(iterations = NULL)
  out
}
# Reshape one parameter's draws into a long tibble.
#
# `arr` is expected to hold draws along its first dimension, followed by the
# parameter's own dimensions. This is the layout of the per-parameter arrays
# returned by `rstan::extract()`. The result has one row per array element,
# with these columns:
# * `s` the draw index (first array index),
# * `d1`, `d2`, ... the remaining array indices as integers (absent for
#   scalars),
# * `val` the element's value.
# Rows are in column-major order: `s` varies fastest, then `d1`, `d2`, ...
arr_to_tbl <- function(arr) {
  dims <- dim(arr)
  # Scalar parameters arrive as a plain vector or a 1-d array of draws.
  if (length(dims) <= 1L) {
    return(tibble(s = seq_along(arr), val = as.vector(arr)))
  }
  # Index every element directly from the dimensions. This avoids parsing the
  # column names that `as_tibble()` invents, whose format differs between
  # matrices and higher-rank arrays. It also works for zero draws.
  idx <- arrayInd(seq_along(arr), dims)
  colnames(idx) <- c("s", paste0("d", seq_len(length(dims) - 1L)))
  tbl <- as_tibble(idx)
  tbl$val <- as.vector(arr)
  tbl
}
#' Extract divergences from a stanfit
#'
#' Reports whether each post-warmup iteration ended in a divergent
#' transition. The per-iteration flags are returned as a tidy tibble.
#'
#' @param f An object of class `stanfit`.
#'
#' @return A tibble with `2` columns:
#' * `s` The index of the draw.
#' * `div` Boolean value of whether the draw was divergent or not.
#'
#' @export
extract_div_tidy <- function(f) {
  divergent <- get_divergent_iterations(f)
  n_draws <- length(divergent)
  tibble(s = seq_len(n_draws), div = divergent)
}
#' Extract max treedepth iterations from a stanfit
#'
#' Reports whether each post-warmup iteration hit the sampler's maximum tree
#' depth. The per-iteration flags are returned as a tidy tibble.
#'
#' @param f An object of class `stanfit`.
#'
#' @return A tibble with `2` columns:
#' * `s` The index of the draw.
#' * `max_tree` Boolean value of whether the maximum treedepth was saturated
#' or not.
#'
#' @export
extract_max_tree_tidy <- function(f) {
  saturated <- get_max_treedepth_iterations(f)
  tibble(
    s = seq_len(length(saturated)),
    max_tree = saturated
  )
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.