Nothing
#' Plot posterior distribution from dataframe of posterior draws.
#' @description
#' `r lifecycle::badge('stable')`
#'
#' Plot the posterior distribution of all latent parameters using a dataframe of posterior draws from a `causact_graph` model.
#' @param drawsDF the dataframe output of `dag_numpyro(mcmc=TRUE)` where each column is a parameter and each row a single draw from a representative sample.
#' @param densityPlot If `TRUE`, each parameter gets its own density plot. If `FALSE` (recommended usage), parameters are grouped into facets according to whether they share the same prior. Each posterior distribution is shown with two central credible intervals: a 10 percent interval (45th to 55th percentile) and a 90 percent interval (5th to 95th percentile).
#' @param abbrevLabels If `TRUE`, long labels on the plot are shortened with `abbreviate(minlength = 10)`, i.e. to roughly 10 characters (10 is a minimum, not a hard cap). If `FALSE`, the entire label is used.
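#' @return A plot. With `densityPlot = FALSE`, a grid of credible interval plots (one panel per prior group) covering all latent parameters. With `densityPlot = TRUE`, or if the interval plot cannot be built, a faceted density plot with one panel per parameter.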
#' @examples
#' # A simple example
#' posteriorDF = data.frame(x = rnorm(100),
#' y = rexp(100),
#' z = runif(100))
#' posteriorDF %>%
#' dagp_plot(densityPlot = TRUE)
#'
#' # More complicated example requiring 'numpyro'
#' \dontrun{
#' # Create a 2 node graph
#' graph = dag_create() %>%
#' dag_node("Get Card","y",
#' rhs = bernoulli(theta),
#' data = carModelDF$getCard) %>%
#' dag_node(descr = "Card Probability by Car",label = "theta",
#' rhs = beta(2,2),
#' child = "y")
#' graph %>% dag_render()
#'
#' # below requires a 'numpyro' installation
#' drawsDF = graph %>% dag_numpyro(mcmc=TRUE)
#' drawsDF %>% dagp_plot()
#' }
#'
#' # A multiple plate example
#' library(dplyr)
#' poolTimeGymDF = gymDF %>%
#' mutate(stretchType = ifelse(yogaStretch == 1,
#' "Yoga Stretch",
#' "Traditional")) %>%
#' group_by(gymID,stretchType,yogaStretch) %>%
#' summarize(nTrialCustomers = sum(nTrialCustomers),
#' nSigned = sum(nSigned))
#' graph = dag_create() %>%
#' dag_node("Cust Signed","k",
#' rhs = binomial(n,p),
#' data = poolTimeGymDF$nSigned) %>%
#' dag_node("Probability of Signing","p",
#' rhs = beta(2,2),
#' child = "k") %>%
#' dag_node("Trial Size","n",
#' data = poolTimeGymDF$nTrialCustomers,
#' child = "k") %>%
#' dag_plate("Yoga Stretch","x",
#' nodeLabels = c("p"),
#' data = poolTimeGymDF$stretchType,
#' addDataNode = TRUE) %>%
#' dag_plate("Observation","i",
#' nodeLabels = c("x","k","n")) %>%
#' dag_plate("Gym","j",
#' nodeLabels = "p",
#' data = poolTimeGymDF$gymID,
#' addDataNode = TRUE)
#' graph %>% dag_render()
#' \dontrun{
#' # below requires a 'numpyro' installation
#' drawsDF = graph %>% dag_numpyro(mcmc=TRUE)
#' drawsDF %>% dagp_plot()
#' }
#' @importFrom dplyr bind_rows filter group_by
#' @importFrom rlang is_empty UQ enexpr enquo expr_text quo_name eval_tidy .data
#' @importFrom ggplot2 ggplot geom_density facet_wrap aes theme_minimal theme scale_alpha_continuous guides labs geom_segment element_blank after_stat
#' @importFrom tidyr gather
#' @importFrom cowplot plot_grid
#' @importFrom stats quantile
#' @importFrom lifecycle badge
#' @export
dagp_plot = function(drawsDF, densityPlot = FALSE, abbrevLabels = FALSE) {
  # Two modes:
  #  * densityPlot = TRUE : one density facet per column of drawsDF.
  #  * densityPlot = FALSE: parameters are grouped by shared prior
  #    (via addPriorGroups()) and each group gets a credible-interval panel;
  #    the panels are combined with cowplot::plot_grid(). If anything in that
  #    path errors, we fall back to the density plot of the ORIGINAL draws.

  # Placeholders for column names referenced through non-standard evaluation
  # (dplyr data masking, ggplot2::aes) so devtools::check does not report
  # them as undefined globals. Data columns take precedence over these.
  q05 <- q25 <- q45 <- q50 <- q55 <- q75 <- q95 <- NULL
  density <- reasonableIntervalWidth <- credIQR <- alphaLevel <- NULL
  param <- priorGroup <- key <- value <- NULL

  ## --- Density plot branch --------------------------------------------------
  if (densityPlot) {
    # gather() reshapes to long format: `key` = column name, `value` = draw.
    longDraws <- tidyr::gather(drawsDF)
    if (abbrevLabels) { ## shorten labels if desired
      longDraws <- dplyr::mutate(longDraws,
                                 key = abbreviate(key, minlength = 10))
    }
    densityGG <- longDraws %>%
      ggplot2::ggplot(ggplot2::aes(x = value,
                                   y = ggplot2::after_stat(density))) +
      ggplot2::geom_density(ggplot2::aes(fill = key)) +
      ggplot2::facet_wrap(~ key, scales = "free_x") +
      ggplot2::theme_minimal() +
      ggplot2::theme(legend.position = "none")
    return(densityGG)
  }

  ## --- Credible-interval branch ---------------------------------------------
  # Work on a separate variable (tidyDraws) so that `drawsDF` stays untouched
  # and the error fallback below always receives the caller's original draws.
  tryCatch({
    tidyDraws <- addPriorGroups(drawsDF)
    if (abbrevLabels) { ## shorten labels if desired
      tidyDraws <- dplyr::mutate(tidyDraws,
                                 param = abbreviate(param, minlength = 10))
    }
    # Parameters without a prior group (NA; the original note cites draws
    # from an LKJ prior) are pooled into one catch-all group, 999999, rather
    # than being dropped.
    tidyDraws <- dplyr::mutate(
      tidyDraws,
      priorGroup = ifelse(is.na(priorGroup), 999999, priorGroup)
    )

    priorGroups <- unique(tidyDraws$priorGroup)
    numPriorGroups <- length(priorGroups)

    # One interval plot per prior group; only the last panel gets a caption.
    plotList <- lapply(seq_len(numPriorGroups), function(i) {
      groupCaption <- if (i == numPriorGroups) {
        "Credible Intervals - 10% (dark) & 90% (light)"
      } else {
        ""
      }
      tidyDraws %>%
        dplyr::filter(priorGroup == priorGroups[i]) %>%
        dplyr::group_by(param) %>%
        dplyr::summarize(q05 = stats::quantile(value, 0.05),
                         q25 = stats::quantile(value, 0.25),
                         q45 = stats::quantile(value, 0.45),
                         q50 = stats::quantile(value, 0.50),
                         q55 = stats::quantile(value, 0.55),
                         q75 = stats::quantile(value, 0.75),
                         q95 = stats::quantile(value, 0.95)) %>%
        # Parameters whose interquartile range exceeds 1.5x the group's
        # 75th-percentile IQR get a lower alphaLevel (drawn fainter).
        dplyr::mutate(credIQR = q75 - q25) %>%
        dplyr::mutate(reasonableIntervalWidth =
                        1.5 * stats::quantile(credIQR, 0.75)) %>%
        dplyr::mutate(alphaLevel = ifelse(
          .data$credIQR > .data$reasonableIntervalWidth, 0.3, 1)) %>%
        # Sort by alphaLevel then median; the factor levels pin this order
        # on the y axis.
        dplyr::arrange(alphaLevel, .data$q50) %>%
        dplyr::mutate(param = factor(param, levels = param)) %>%
        ggplot2::ggplot(ggplot2::aes(y = param, yend = param)) +
        # 90% interval (5th-95th percentile), light colour
        ggplot2::geom_segment(ggplot2::aes(x = q05, xend = q95,
                                           alpha = alphaLevel),
                              linewidth = 4, color = "#5f9ea0") +
        # 10% interval (45th-55th percentile), dark colour
        ggplot2::geom_segment(ggplot2::aes(x = q45, xend = q55,
                                           alpha = alphaLevel),
                              linewidth = 4, color = "#11114e") +
        ggplot2::scale_alpha_continuous(range = c(0.6, 1)) +
        ggplot2::guides(alpha = "none") +
        ggplot2::theme_minimal(12) +
        ggplot2::labs(y = ggplot2::element_blank(),
                      x = "parameter value",
                      caption = groupCaption)
    })

    nCol <- if (numPriorGroups == 1) 1 else floor(1 + sqrt(numPriorGroups))
    cowplot::plot_grid(plotlist = plotList, ncol = nCol)
  },
  # Deliberate best-effort: if the interval plot cannot be built, show the
  # per-parameter density plot of the untouched input instead.
  error = function(e) {
    dagp_plot(drawsDF, densityPlot = TRUE, abbrevLabels = abbrevLabels)
  })
} # end function
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.