Nothing
# Print `gg` (typically a ggplot2 object) with every message and warning
# raised during printing silenced, then return `gg` invisibly so the call can
# still be assigned or chained.
quietgg <- function(gg) {
  suppressWarnings(
    suppressMessages(
      print(gg)
    )
  )
  invisible(gg)
}
# traceplot ---------------------------------------------------------------
# Trace plot: sampled value against iteration, one coloured path per chain.
#
# object, pars, include, inc_warmup, unconstrain
#             handed to .check_object() / .make_plot_data() to select draws.
# nrow, ncol  facet layout, used only when more than one parameter is plotted.
# ...         forwarded to ggplot2::geom_path().
# window      optional numeric vector of length 2; restricts the visible
#             iteration range through coord_cartesian(xlim = window).
#
# Returns a ggplot object. Stops when `window` is supplied but is not a
# numeric vector of length 2.
stan_trace <- function(object, pars, include = TRUE,
                       unconstrain = FALSE,
                       inc_warmup = FALSE,
                       nrow = NULL, ncol = NULL,
                       ..., window = NULL) {
  .check_object(object, unconstrain)
  draws <- .make_plot_data(object, pars, include, inc_warmup, unconstrain)
  trace_theme <- rstanvis_theme()
  chain_palette <- rep_len(rstanvis_aes_ops("chain_colors"), draws$nchains)

  graph <- ggplot2::ggplot(
    draws$samp,
    ggplot2::aes(x = .data$iteration, y = .data$value, color = .data$chain)
  )

  # Shade everything up to the end of warmup when warmup draws are shown.
  if (inc_warmup) {
    warmup_band <- ggplot2::annotate(
      "rect",
      xmin = -Inf, xmax = draws$warmup,
      ymin = -Inf, ymax = Inf,
      fill = rstanvis_aes_ops("grays")[2L]
    )
    graph <- graph + warmup_band
  }

  graph <- graph +
    ggplot2::geom_path(...) +
    ggplot2::scale_color_manual(values = chain_palette) +
    ggplot2::labs(x = "", y = "") +
    trace_theme

  # A single parameter is named on the y axis; several get one facet each.
  if (draws$nparams == 1) {
    graph <- graph + ggplot2::ylab(unique(draws$samp$parameter))
  } else {
    graph <- graph +
      ggplot2::facet_wrap(~parameter, nrow = nrow, ncol = ncol, scales = "free")
  }

  if (is.null(window)) {
    return(graph)
  }
  if (!is.numeric(window) || length(window) != 2) {
    stop("'window' should be a numeric vector of length 2.")
  }
  graph + ggplot2::coord_cartesian(xlim = window)
}
# scatterplot -------------------------------------------------------------
# Scatterplot of the draws of exactly two parameters against each other.
#
# pars        character vector naming the two parameters; pars[1] goes on the
#             x axis and pars[2] on the y axis.
# ...         passed through .add_aesthetics() (with the names "fill",
#             "pt_color", "pt_size", "alpha", "shape") and then on to
#             ggplot2::geom_point().
# nrow, ncol  accepted but not used by this function.
#
# Returns a ggplot object. Stops when `pars` is missing or does not have
# length 2.
stan_scat <- function(object, pars, unconstrain = FALSE, inc_warmup = FALSE,
                      nrow = NULL, ncol = NULL, ...) {
  .check_object(object, unconstrain)
  scat_theme <- rstanvis_theme()
  point_args <- .add_aesthetics(
    list(...),
    c("fill", "pt_color", "pt_size", "alpha", "shape")
  )
  if (missing(pars) || length(pars) != 2L) {
    stop("'pars' must contain exactly two parameter names", call. = FALSE)
  }

  # Disabled: flag divergent draws and draws that hit the max treedepth.
  # ndivergent <-
  #   .sampler_params_post_warmup(object, "divergent__", as.df = TRUE)[, -1L]
  # treedepth <-
  #   .sampler_params_post_warmup(object, "treedepth__", as.df = TRUE)[, -1L]
  # max_td <- .max_td(object)
  # div <- unname(rowSums(ndivergent) == 1)
  # hit_max_td <- sapply(1:nrow(treedepth), function(i) any(treedepth[i,] >= max_td))

  draws <- .make_plot_data(
    object,
    pars = pars,
    include = TRUE,
    inc_warmup = inc_warmup,
    unconstrain = unconstrain
  )

  # Rows belonging to the first parameter form x; all remaining rows form y.
  is_first <- draws$samp$parameter == pars[1]
  xy <- data.frame(
    x = draws$samp[is_first, "value"],
    y = draws$samp[!is_first, "value"]
  )

  # Disabled (continuation of the overlay sketch above):
  # nchains <- plot_data$nchains
  # sel <- seq_len(nrow(df) / nchains)
  # div <- df[div[sel], ]
  # td <- df[hit_max_td[sel], ]
  # ... + geom_point(data = div, aes(x = .data$x, y = .data$y), color = "red")
  # ... + geom_point(data = td, aes(x = .data$x, y = .data$y), color = "yellow")

  ggplot2::ggplot(xy, ggplot2::aes(x = .data$x, y = .data$y)) +
    do.call(ggplot2::geom_point, point_args) +
    ggplot2::labs(x = pars[1], y = pars[2]) +
    scat_theme
}
# histograms ---------------------------------------------------------------
# Histogram(s) of the draws, drawn on the density scale.
#
# ...         passed through .add_aesthetics() (with the names "fill" and
#             "color") and then on to ggplot2::geom_histogram().
# nrow, ncol  facet layout, used only when more than one parameter is plotted.
#
# Returns a ggplot object: a single panel labelled with the parameter name,
# or one free-scale facet per parameter.
stan_hist <- function(object, pars, include = TRUE,
                      unconstrain = FALSE,
                      inc_warmup = FALSE,
                      nrow = NULL, ncol = NULL,
                      ...) {
  .check_object(object, unconstrain)
  hist_args <- .add_aesthetics(list(...), c("fill", "color"))
  draws <- .make_plot_data(object, pars, include, inc_warmup, unconstrain)
  hist_theme <- rstanvis_hist_theme()

  density_aes <- ggplot2::aes(
    x = .data$value,
    y = ggplot2::after_stat(.data$density)
  )
  graph <- ggplot2::ggplot(draws$samp, density_aes) +
    do.call(ggplot2::geom_histogram, hist_args) +
    ggplot2::xlab("") +
    hist_theme

  # One parameter: name it on the x axis. Several: facet by parameter.
  panel_layer <- if (draws$nparams == 1) {
    ggplot2::xlab(unique(draws$samp$parameter))
  } else {
    ggplot2::facet_wrap(~parameter, nrow = nrow, ncol = ncol, scales = "free")
  }
  graph + panel_layer
}
# densities -----------------------------------------------------------
# Kernel density plot(s) of the draws.
#
# ...              passed through .add_aesthetics() and then on to
#                  ggplot2::geom_density(). The aesthetic names handed to
#                  .add_aesthetics() are "fill"/"color" when chains are
#                  pooled and "color"/"alpha" when they are separated.
# separate_chains  when TRUE, the fill is mapped to the chain and a manual
#                  fill scale built from the chain colours is added.
# nrow, ncol       facet layout, used only for more than one parameter.
#
# Returns a ggplot object.
stan_dens <- function(object, pars, include = TRUE,
                      unconstrain = FALSE,
                      inc_warmup = FALSE,
                      nrow = NULL, ncol = NULL,
                      ...,
                      separate_chains = FALSE) {
  .check_object(object, unconstrain)
  draws <- .make_plot_data(object, pars, include, inc_warmup, unconstrain)
  chain_palette <- rep_len(rstanvis_aes_ops("chain_colors"), draws$nchains)
  dens_theme <- rstanvis_hist_theme()
  graph <- ggplot2::ggplot(draws$samp, ggplot2::aes(x = .data$value)) +
    ggplot2::xlab("")

  if (separate_chains) {
    dens_args <- .add_aesthetics(list(...), c("color", "alpha"))
    dens_args$mapping <- ggplot2::aes(fill = .data$chain)
  } else {
    dens_args <- .add_aesthetics(list(...), c("fill", "color"))
  }

  graph <- graph + do.call(ggplot2::geom_density, dens_args)
  if (separate_chains) {
    graph <- graph + ggplot2::scale_fill_manual(values = chain_palette)
  }
  graph <- graph + dens_theme

  # One parameter: name it on the x axis. Several: facet by parameter.
  panel_layer <- if (draws$nparams == 1) {
    ggplot2::xlab(unique(draws$samp$parameter))
  } else {
    ggplot2::facet_wrap(~parameter, nrow = nrow, ncol = ncol, scales = "free")
  }
  graph + panel_layer
}
# autocorrelation ---------------------------------------------------------
# Bar chart(s) of the (partial) autocorrelation of the draws by lag.
#
# object, pars, include, inc_warmup, unconstrain
#                  handed to .check_object() / .make_plot_data() to select
#                  the draws.
# nrow, ncol       layout for facet_wrap() where it is used.
# ...              passed through .add_aesthetics() (with the names "size",
#                  "color", "fill") and then on to ggplot2::geom_bar().
# separate_chains  FALSE (default): bars are drawn with stat = "summary"
#                  ("mean"/"mean_se"), i.e. averaged over the rows sharing a
#                  lag. TRUE: one panel per chain, values drawn as-is.
# lags, partial    forwarded to .ac_plot_data() / .ac_plot_data_multi();
#                  `partial` also switches the axis label.
#
# Returns a ggplot object.
stan_ac <- function(object, pars, include = TRUE,
                    unconstrain = FALSE,
                    inc_warmup = FALSE,
                    nrow = NULL, ncol = NULL,
                    ...,
                    separate_chains = FALSE,
                    lags = 25, partial = FALSE) {
  .check_object(object, unconstrain)
  plot_data <- .make_plot_data(object, pars, include, inc_warmup, unconstrain)
  clrs <- rep_len(rstanvis_aes_ops("chain_colors"), plot_data$nchains)
  thm <- rstanvis_theme()
  dots <- .add_aesthetics(list(...), c("size", "color", "fill"))
  single_param <- plot_data$nparams == 1

  # The helper is chosen by name; `nparams` is a scalar, so a plain if/else
  # is used rather than the vectorised ifelse().
  dat_args <- list(dat = plot_data$samp, lags = lags, partial = partial)
  dat_fn <- if (single_param) ".ac_plot_data" else ".ac_plot_data_multi"
  ac_dat <- do.call(dat_fn, dat_args)

  if (!separate_chains) {
    dots$position <- "dodge"
    dots$stat <- "summary"
    dots$fun <- "mean"
    dots$fun.data <- "mean_se"
    y_lab <- paste("Avg.", if (partial) "partial", "autocorrelation")
    ac_labs <- ggplot2::labs(x = "Lag", y = y_lab)
    y_scale <- ggplot2::scale_y_continuous(breaks = seq(0, 1, 0.25))
    base <- ggplot2::ggplot(ac_dat, ggplot2::aes(x = .data$lag, y = .data$ac))
    graph <- base +
      do.call(ggplot2::geom_bar, dots) +
      y_scale +
      ac_labs +
      thm
    if (single_param) {
      # Label with the parameter that is actually plotted. The `pars`
      # argument may be missing, or (with include = FALSE) may name the
      # parameters that were excluded, so it cannot be used for the label.
      param_name <- as.character(unique(plot_data$samp$parameter))
      return(graph + ggplot2::ylab(paste0(y_lab, " (", param_name, ")")))
    }
    return(
      graph +
        ggplot2::facet_wrap(~parameters, nrow = nrow, ncol = ncol,
                            scales = "free_x")
    )
  }

  # Separate chains: draw the autocorrelations as computed, one panel each.
  dots$position <- "identity"
  dots$stat <- "identity"
  ac_labs <- ggplot2::labs(
    x = "Lag",
    y = if (partial) "Partial autocorrelation" else "Autocorrelation"
  )
  # Only the breaks at 0 and 0.5 carry a text label.
  y_scale <- ggplot2::scale_y_continuous(breaks = seq(0, 1, 0.25),
                                         labels = c("0", "", "0.5", "", ""))
  base <- ggplot2::ggplot(ac_dat, ggplot2::aes(x = .data$lag, y = .data$ac))
  graph <- base +
    do.call(ggplot2::geom_bar, dots) +
    y_scale +
    ac_labs +
    thm
  if (single_param) {
    graph + ggplot2::facet_wrap(~chains, nrow = nrow, ncol = ncol)
  } else {
    graph + ggplot2::facet_grid(parameters ~ chains, scales = "free_x")
  }
}
# parameter estimates -----------------------------------------------------
# Plot a point estimate and two nested intervals for each selected
# parameter, one parameter per row.
#
# object, pars, include, unconstrain
#     handed to .check_object() / .make_plot_data(); `inc_warmup` is fixed to
#     FALSE for this plot.
# ...
#     Named options; names not listed here are ignored:
#       point_est        "median" (default) or "mean"
#       show_density     FALSE (default) draws a thick segment for the inner
#                        interval; TRUE draws a density curve with the inner
#                        interval shaded
#       show_outer_line  TRUE (default) draws the thin outer-interval segment
#                        (always drawn when show_density is TRUE)
#       ci_level         probability mass of the inner interval (default 0.8)
#       outer_level      probability mass of the outer interval (default 0.95)
#       fill_color, outline_color, est_color
#                        colours; defaults come from rstanvis_aes_ops()
#
# Emits two messages reporting the interval levels and returns a ggplot
# object. Stops when `point_est` is not "mean"/"median" or when
# ci_level > outer_level.
stan_plot <- function(object, pars, include = TRUE, unconstrain = FALSE,
                      ...) {
  inc_warmup <- FALSE
  .check_object(object, unconstrain)
  thm <- rstanvis_multiparam_theme()
  plot_data <- .make_plot_data(object, pars, include, inc_warmup, unconstrain)
  color_by_rhat <- FALSE # FIXME

  # Resolve every option: the value from `...` when given, else the default.
  dots <- list(...)
  defs <- list(point_est = "median", show_density = FALSE,
               show_outer_line = TRUE, ci_level = 0.8, outer_level = 0.95,
               fill_color = rstanvis_aes_ops("fill"),
               outline_color = rstanvis_aes_ops("color"),
               est_color = rstanvis_aes_ops("color"))
  dotenv <- list()
  for (opt in names(defs)) {
    dotenv[[opt]] <- if (opt %in% names(dots)) dots[[opt]] else defs[[opt]]
  }

  if (!(dotenv[["point_est"]] %in% c("mean", "median")))
    stop("Point estimate should be either 'mean' or 'median'", call. = FALSE)
  # Colouring by Rhat is not implemented. The flag is hard-wired to FALSE
  # above, and the layers that used to sit behind it were unreachable (one of
  # them read an undefined `rhat_colors`), so they have been removed.
  if (color_by_rhat)
    stop("'color_by_rhat' not yet available", call. = FALSE)
  if (dotenv[["ci_level"]] > dotenv[["outer_level"]])
    stop("'ci_level' should be less than 'outer_level'", call. = FALSE)

  ci_level <- dotenv[["ci_level"]]
  outer_level <- dotenv[["outer_level"]]
  message("ci_level: ", ci_level," (",100 * ci_level, "% intervals)")
  message("outer_level: ", outer_level," (",100 * outer_level, "% intervals)")

  # Quantile probabilities: outer-low, inner-low, median, inner-high,
  # outer-high.
  probs.use <- c(0.5 - outer_level / 2, 0.5 - ci_level / 2, 0.5,
                 0.5 + ci_level / 2, 0.5 + outer_level / 2)
  samp <- plot_data$samp
  nparams <- plot_data$nparams

  # One row per parameter: the mean followed by the five quantiles.
  # as.matrix() of the aggregate() result is a character matrix whose first
  # column holds the parameter names, hence the conversion back to numeric.
  statmat <- as.matrix(aggregate(samp$value,
                                 by = list(parameter = samp$parameter),
                                 FUN = function(x, ...) c(mean(x), quantile(x, ...)),
                                 probs = probs.use))
  param_names <- rownames(statmat) <- statmat[, 1L]
  statmat <- apply(statmat[, -1L, drop = FALSE], 1:2, as.numeric)
  # Columns are addressed by position below; the names only record the role
  # of each column (the previous fixed "2.5%", "25%", ... labels did not
  # follow the requested levels).
  colnames(statmat) <- c("mean", "ll", "l", "m", "h", "hh")

  # Parameters are stacked top to bottom at y = nparams, ..., 1.
  y <- as.numeric(seq(plot_data$nparams, 1, by = -1))
  # x range spans the outer interval, padded by 5% on each side.
  xlim.use <- c(min(statmat[, 2L]), max(statmat[, 6L]))
  xlim.use <- xlim.use + diff(xlim.use) * c(-.05, .05)
  xy.df <- data.frame(params = rownames(statmat), y, statmat)
  colnames(xy.df) <- c("params", "y", "mean", "ll", "l", "m", "h", "hh")
  if (dotenv[["point_est"]] == "mean") xy.df$m <- xy.df$mean

  p.base <- ggplot2::ggplot(xy.df)
  p.name <- ggplot2::scale_y_continuous(breaks = y, labels = param_names,
                                        limits = c(0.5, nparams + 1))
  p.all <- p.base + ggplot2::xlim(xlim.use) + p.name + thm

  show_density <- dotenv[["show_density"]]
  outline_color <- dotenv[["outline_color"]] %ifNULL% rstanvis_aes_ops("color")
  fill_color <- dotenv[["fill_color"]]
  est_color <- dotenv[["est_color"]]

  # Thin segment across the outer interval.
  if (dotenv[["show_outer_line"]] || show_density) {
    p.ci <-
      ggplot2::geom_segment(
        mapping = ggplot2::aes(
          x = .data$ll,
          xend = .data$hh,
          y = .data$y,
          yend = .data$y
        ),
        color = outline_color
      )
    p.all <- p.all + p.ci
  }

  if (show_density) { # plot kernel density estimate
    npoint.den <- 512

    # Density outline over the outer interval; the peak is rescaled to a
    # height of 0.8 above the parameter's baseline y[i].
    y.den <- x.den <- matrix(0, nrow = npoint.den, ncol = nparams)
    for (i in seq_len(nparams)) {
      d.temp <- density(samp[samp$parameter == param_names[i], "value"],
                        from = statmat[i, 2L],
                        to = statmat[i, 6L],
                        n = npoint.den)
      x.den[, i] <- d.temp$x
      y.max <- max(d.temp$y)
      y.den[, i] <- d.temp$y / y.max * 0.8 + y[i]
    }
    df.den <- data.frame(x = as.vector(x.den), y = as.vector(y.den),
                         name = rep(param_names, each = npoint.den))
    p.den <-
      ggplot2::geom_line(
        data = df.den,
        mapping = ggplot2::aes(x = .data$x, y = .data$y, group = .data$name),
        color = outline_color
      )

    # Shaded polygon over the inner interval; the extra first and last
    # vertices drop to the baseline so the polygon closes.
    y.poly <- x.poly <- matrix(0, nrow = npoint.den + 2, ncol = nparams)
    for (i in seq_len(nparams)) {
      d.temp <- density(samp[samp$parameter == param_names[i], "value"],
                        from = statmat[i, 3L],
                        to = statmat[i, 5L],
                        n = npoint.den)
      x.poly[, i] <- c(d.temp$x[1L], as.vector(d.temp$x), d.temp$x[npoint.den])
      y.max <- max(d.temp$y)
      y.poly[, i] <- as.vector(c(0, as.vector(d.temp$y) / y.max * 0.8, 0) + y[i])
    }
    df.poly <- data.frame(x = as.vector(x.poly), y = as.vector(y.poly),
                          name = rep(param_names, each = npoint.den + 2))
    p.poly <- ggplot2::geom_polygon(
      data = df.poly,
      mapping = ggplot2::aes(x = .data$x, y = .data$y, group = .data$name,
                             fill = .data$y)
    )
    # Gradient with identical end colours, i.e. a constant fill colour.
    p.col <- ggplot2::scale_fill_gradient(low = fill_color, high = fill_color,
                                          guide = "none")

    # Point estimate drawn as a short vertical tick.
    p.point <- ggplot2::geom_segment(
      ggplot2::aes(x = .data$m, xend = .data$m,
                   y = .data$y, yend = .data$y + 0.25),
      colour = est_color, linewidth = 1.5
    )
    p.all + p.poly + p.den + p.col + p.point
  } else {
    # Thick segment across the inner interval plus a filled point estimate.
    p.ci.2 <- ggplot2::geom_segment(
      ggplot2::aes(x = .data$l, xend = .data$h, y = .data$y, yend = .data$y),
      colour = fill_color, linewidth = 2
    )
    p.point <- ggplot2::geom_point(
      ggplot2::aes(x = .data$m, y = .data$y), size = 4,
      color = fill_color, fill = est_color, shape = 21
    )
    p.all + p.ci.2 + p.point
  }
}
# rhat, ess, mcse ---------------------------------------------------------
# Histogram of the "rhat" statistic. After the object checks, the work is
# delegated to .rhat_neff_mcse_hist(); a missing `pars` is passed on as NULL
# and `...` is forwarded unchanged.
stan_rhat <- function(object, pars, ...) {
  .check_object(object)
  .vb_check(object)
  if (missing(pars)) {
    pars <- NULL
  }
  .rhat_neff_mcse_hist(object = object, pars = pars, which = "rhat", ...)
}
# Histogram of the "n_eff_ratio" statistic. After the object checks, the work
# is delegated to .rhat_neff_mcse_hist(); a missing `pars` is passed on as
# NULL and `...` is forwarded unchanged.
stan_ess <- function(object, pars, ...) {
  .check_object(object)
  .vb_check(object)
  if (missing(pars)) {
    pars <- NULL
  }
  .rhat_neff_mcse_hist(object = object, pars = pars, which = "n_eff_ratio", ...)
}
# Histogram of the "mcse_ratio" statistic. After the object checks, the work
# is delegated to .rhat_neff_mcse_hist(); a missing `pars` is passed on as
# NULL and `...` is forwarded unchanged.
stan_mcse <- function(object, pars, ...) {
  .check_object(object)
  .vb_check(object)
  if (missing(pars)) {
    pars <- NULL
  }
  .rhat_neff_mcse_hist(object = object, pars = pars, which = "mcse_ratio", ...)
}
# NUTS --------------------------------------------------------------------
# Dispatcher for the sampler diagnostic plots.
#
# information  which diagnostic to draw; the matched value `x` selects the
#              function named "stan_<x>".
# chain        passed on to the selected function as its second argument.
# ...          passed on as well; must not contain `pars`.
#
# Stops when `pars` is supplied through `...` or when the fit does not have
# more than one chain.
stan_diag <- function(object,
                      information = c("sample","stepsize","treedepth","divergence"),
                      chain = 0, ...) {
  .vb_check(object)

  extra_names <- names(list(...))
  if ("pars" %in% extra_names) {
    stop("'stan_diag' does not accept a 'pars' argument.")
  }

  # The chain count is the column count of the underlying stanfit.
  fit <- if (is.stanreg(object)) object$stanfit else object
  n_chains <- ncol(fit)
  if (!isTRUE(n_chains > 1)) {
    stop("'stan_diag' requires more than one chain.", call. = FALSE)
  }

  plot_fn <- paste0("stan_", match.arg(information))
  do.call(plot_fn, list(object, chain, ...))
}
# Step-size diagnostics: step size against lp__ and against the acceptance
# statistic, both themed with rstanvis_theme() and handed to .nuts_return().
#
# chain  forwarded to the plotting helpers.
# ...    checked by .nuts_args_check() and forwarded to .nuts_return().
stan_stepsize <- function(object, chain = 0, ...) {
  .nuts_args_check(...)
  base_theme <- rstanvis_theme()
  step_df <- .sampler_params_post_warmup(object, "stepsize__", as.df = TRUE)

  fit <- if (is.stanreg(object)) object$stanfit else object
  log_post <- extract(fit, pars = "lp__", permuted = FALSE)[, , 1L]

  plots <- list()
  # The first column of each sampler-parameter data frame is dropped before
  # plotting.
  plots$stepsize_vs_lp <- .sampler_param_vs_param(
    p = log_post, sp = step_df[, -1L],
    p_lab = .LP_LAB, sp_lab = .STEPSIZE_LAB,
    chain = chain, violin = TRUE
  )

  accept_df <- .sampler_params_post_warmup(object, "accept_stat__", as.df = TRUE)
  plots$stepsize_vs_metrop <- .sampler_param_vs_sampler_param_violin(
    round(step_df[, -1L], 4), accept_df[, -1L],
    lab_x = .STEPSIZE_LAB, lab_y = .METROP_LAB,
    chain = chain
  )

  plots <- lapply(plots, function(g) g + base_theme)
  .nuts_return(plots, ...)
}
# Sample diagnostics: a histogram of lp__, a histogram of the acceptance
# statistic (x axis limited to [0, 1]), and the acceptance statistic plotted
# against lp__. All three are themed with rstanvis_theme() and handed to
# .nuts_return().
#
# chain  forwarded to the plotting helpers.
# ...    checked by .nuts_args_check() and forwarded to .nuts_return().
stan_sample <- function(object, chain = 0, ...) {
  .nuts_args_check(...)
  thm <- rstanvis_theme()
  # NOTE(review): the previous version also created a histogram theme
  # (rstanvis_hist_theme()) but never applied it; the unused local has been
  # dropped. Confirm whether the two histograms were meant to use it.
  lp <- extract(if (is.stanreg(object)) object$stanfit else object,
                pars = "lp__", permuted = FALSE)[, , 1L]
  # Prepend an iteration index column to the lp__ values.
  lp_df <- as.data.frame(cbind(iterations = seq_len(nrow(lp)), lp))
  metrop <- .sampler_params_post_warmup(object, "accept_stat__", as.df = TRUE)

  graphs <- list()
  graphs$lp_hist <-
    .p_hist(lp_df, lab = .LP_LAB, chain = chain)
  # xlim is namespaced like every other ggplot2 call in this file, so it
  # resolves even when ggplot2 is not attached.
  graphs$metrop_hist <-
    .p_hist(metrop, lab = .METROP_LAB, chain = chain) + ggplot2::xlim(0, 1)
  graphs <- lapply(graphs, function(x) x + thm)

  graphs$metrop_vs_lp <-
    .sampler_param_vs_param(p = lp, sp = metrop[, -1L], p_lab = .LP_LAB,
                            sp_lab = .METROP_LAB, chain = chain) +
    thm
  .nuts_return(graphs, ...)
}
# Tree-depth diagnostics: tree depth against lp__ and against the acceptance
# statistic (themed with rstanvis_theme()), plus tree-depth histograms
# (themed with rstanvis_hist_theme()). The histograms cover all draws and,
# when any divergence indicator is non-zero, also divergent = 0 and
# divergent = 1 separately.
#
# chain  forwarded to the plotting helpers.
# ...    checked by .nuts_args_check() and forwarded to .nuts_return().
stan_treedepth <- function(object, chain = 0, ...) {
  .nuts_args_check(...)
  base_theme <- rstanvis_theme()
  histogram_theme <- rstanvis_hist_theme()

  fit <- if (is.stanreg(object)) object$stanfit else object
  log_post <- extract(fit, pars = "lp__", permuted = FALSE)[, , 1L]
  depth_df <- .sampler_params_post_warmup(object, "treedepth__", as.df = TRUE)
  diverge_df <- .sampler_params_post_warmup(object, "divergent__", as.df = TRUE)
  accept_df <- .sampler_params_post_warmup(object, "accept_stat__", as.df = TRUE)

  comparison_plots <- list()
  histogram_plots <- list()

  comparison_plots$treedepth_vs_lp <- .sampler_param_vs_param(
    p = log_post, sp = depth_df[, -1L],
    p_lab = .LP_LAB, sp_lab = .TREEDEPTH_LAB,
    chain = chain, violin = TRUE
  )
  comparison_plots$treedepth_vs_metrop <- .sampler_param_vs_sampler_param_violin(
    depth_df[, -1L], accept_df[, -1L],
    lab_x = .TREEDEPTH_LAB, lab_y = .METROP_LAB,
    chain = chain
  )

  histogram_plots$treedepth_ndivergent <- .treedepth_ndivergent_hist(
    depth_df, diverge_df, chain = chain, divergent = "All"
  )
  # The split histograms are only drawn when there is something to split.
  if (any(diverge_df[, -1L] != 0)) {
    histogram_plots$treedepth_ndivergent0 <- .treedepth_ndivergent_hist(
      depth_df, diverge_df, chain = chain, divergent = 0
    )
    histogram_plots$treedepth_ndivergent1 <- .treedepth_ndivergent_hist(
      depth_df, diverge_df, chain = chain, divergent = 1
    )
  }

  comparison_plots <- lapply(comparison_plots, function(g) g + base_theme)
  histogram_plots <- lapply(histogram_plots, function(g) g + histogram_theme)
  .nuts_return(c(comparison_plots, histogram_plots), ...)
}
# Divergence diagnostics: the divergence indicator against lp__ and against
# the acceptance statistic, both themed with rstanvis_theme() and handed to
# .nuts_return().
#
# chain  forwarded to the plotting helpers.
# ...    checked by .nuts_args_check() and forwarded to .nuts_return().
stan_divergence <- function(object, chain = 0, ...) {
  .nuts_args_check(...)
  base_theme <- rstanvis_theme()

  fit <- if (is.stanreg(object)) object$stanfit else object
  log_post <- extract(fit, pars = "lp__", permuted = FALSE)[, , 1L]
  diverge_df <- .sampler_params_post_warmup(object, "divergent__", as.df = TRUE)
  accept_df <- .sampler_params_post_warmup(object, "accept_stat__", as.df = TRUE)

  plots <- list()
  plots$ndivergent_vs_lp <- .sampler_param_vs_param(
    p = log_post, sp = diverge_df[, -1L],
    p_lab = .LP_LAB, sp_lab = .NDIVERGENT_LAB,
    chain = chain, violin = TRUE
  )
  plots$ndivergent_vs_metrop <- .sampler_param_vs_sampler_param_violin(
    diverge_df[, -1L], accept_df[, -1L],
    lab_x = .NDIVERGENT_LAB, lab_y = .METROP_LAB,
    chain = chain
  )

  plots <- lapply(plots, function(g) g + base_theme)
  .nuts_return(plots, ...)
}
# Diagnostics for one parameter: its draws plotted against lp__, against the
# acceptance statistic, and against the step size. The first two plots also
# receive the divergence indicators and a 0/1 flag marking draws whose tree
# depth reached the maximum.
#
# par    name of the parameter; required.
# chain  forwarded to the plotting helper.
# ...    forwarded to .nuts_return().
#
# Stops when `par` is missing or when the fit does not have more than one
# chain.
stan_par <- function(object, par, chain = 0, ...) {
  if (missing(par))
    stop("'par' must be specified", call. = FALSE)
  if (is.stanreg(object))
    object <- object$stanfit
  if (!isTRUE(ncol(object) > 1))
    stop("'stan_par' requires more than one chain.", call. = FALSE)
  thm <- rstanvis_theme()
  samp <- extract(object, pars = c("lp__", par), permuted = FALSE)
  par_sel <- which(dimnames(samp)$parameters == par)

  # The maximum tree depth comes from .max_td() alone. The earlier manual
  # lookup in the control list was always overwritten by this call (and used
  # inconsistent fallbacks of 11 and 10), so it has been removed.
  max_td <- .max_td(object)

  # Post-warmup sampler parameters, with the first column dropped.
  metrop <- .sampler_params_post_warmup(object, "accept_stat__", as.df = TRUE)[, -1L]
  stepsize <- .sampler_params_post_warmup(object, "stepsize__", as.df = TRUE)[, -1L]
  ndivergent <- .sampler_params_post_warmup(object, "divergent__", as.df = TRUE)[, -1L]
  treedepth <- .sampler_params_post_warmup(object, "treedepth__", as.df = TRUE)[, -1L]
  # 1 where a draw's tree depth reached the maximum, 0 otherwise.
  hit_max_td <- as.data.frame(
    apply(treedepth, 2L, function(y) as.numeric(y >= max_td))
  )

  graphs <- list()
  par_samp <- samp[, , par_sel]
  # Everything that is not `par`; given the extract() call above this is
  # lp__.
  lp <- samp[, , -par_sel]
  graphs[[paste0(par, "_vs_lp")]] <-
    .sampler_param_vs_param(sp = as.data.frame(lp),
                            p = par_samp,
                            divergent = ndivergent,
                            hit_max_td = hit_max_td,
                            sp_lab = .LP_LAB, p_lab = par,
                            chain = chain)
  graphs[[paste0(par, "_vs_metrop")]] <-
    .sampler_param_vs_param(p = par_samp, sp = metrop,
                            divergent = ndivergent,
                            hit_max_td = hit_max_td,
                            p_lab = par, sp_lab = .METROP_LAB,
                            chain = chain)
  graphs[[paste0(par, "_vs_stepsize")]] <-
    .sampler_param_vs_param(p = par_samp, sp = stepsize,
                            p_lab = par, sp_lab = .STEPSIZE_LAB,
                            chain = chain, violin = TRUE)
  graphs <- lapply(graphs, function(x) x + thm)
  .nuts_return(graphs, ...)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.