R/density_plot.R

Defines functions density_plot

Documented in density_plot

#' Create a density plot for a single desired node
#'
#' Draws one kernel density curve per MCMC chain for the posterior samples of
#' a single node, and optionally annotates the plot with convergence
#' diagnostics (Rhat and effective sample size).
#'
#' @param post an object of class \code{mcmc.list}
#' @param p_one a character vector of length == 1. Should be a node reference
#'   to a single element in the model. E.g., \code{"pi[1]"}, not \code{"pi"}.
#' @param show_diags a character vector of length == 1. Must be one of
#'   \code{"always"}, \code{"never"}, \code{"if_poor_Rhat"}. Defaults to
#'   \code{"if_poor_Rhat"}, which will display the Rhat and effective MCMC samples
#'   only if the Rhat statistic is at least 1.1. Diagnostics are never
#'   displayed when the Rhat statistic is \code{NA}.
#' @details The x-axis limits come from a density fit to the samples pooled
#'   across chains; the y-axis limit comes from the tallest chain-specific
#'   density. Each chain's curve is colored by its chain number.
#'
#'   This function changes the graphical parameters \code{yaxs}, \code{mar},
#'   \code{tcl}, and \code{mgp} via \code{par()} and does not restore them.
#' @return \code{NULL}, invisibly. Called for its side effect of drawing a plot
#'   on the current graphics device.
#' @importFrom StatonMisc %!in%
#' @export
density_plot = function(post, p_one, show_diags = "if_poor_Rhat") {
  # return error unless p_one has exactly one element (this also rejects
  # zero-length input, which previously slipped past the check)
  if (length(p_one) != 1) stop ("p_one must have only one element, and it must match the desired node exactly")

  # return error if show_diags is invalid; the length check comes first so a
  # vector or NULL produces this message rather than a condition-length error
  valid_show_diags = c("always", "never", "if_poor_Rhat")
  if (length(show_diags) != 1 || show_diags %!in% valid_show_diags) {
    stop ("show_diags must be one of: ", StatonMisc::list_out(valid_show_diags, final = "or", wrap = '"'))
  }

  # lock in the search string. It must be exact for the rest of this code to work
  p_one = ins_regex_lock(p_one)

  # extract this node's samples
  # NOTE(review): the code below assumes the node's samples are in column 3
  # (presumably after the chain and iteration columns) -- confirm against
  # post_subset()
  post_sub = post_subset(post, p_one, matrix = TRUE, iters = TRUE, chains = TRUE)
  samps = post_sub[, 3]
  chain_id = post_sub[, "CHAIN"]

  # fit the KDE to the samples pooled across chains
  dens = suppressWarnings(density(samps))

  # determine axis limits: x from the pooled density, y from the tallest
  # chain-specific density so no individual curve is clipped
  x_lim = range(dens$x)
  y_max = max(tapply(samps, chain_id, function(x) suppressWarnings(max(density(x)$y))))

  # set up device
  # NOTE(review): these settings are intentionally left in place on exit, as
  # in the original; callers building multi-panel figures may rely on them
  par(yaxs = "i", mar = c(1.5, 2, 2.5, 0.5), tcl = -0.25,
      mgp = c(1.5, 0.4, 0))

  # create an empty plot
  plot(1, 1, type = "n",
       xlim = x_lim, xlab = "",
       ylim = c(0, y_max) * 1.05, ylab = "")

  # loop through chains plotting the distribution of each; the chain number
  # doubles as the line color
  for (chain in unique(chain_id)) {
    chain_dens = suppressWarnings(density(samps[chain_id == chain]))
    lines(chain_dens$x, chain_dens$y, col = chain)
  }

  # calculate convergence diagnostics if requested
  if (show_diags %in% c("always", "if_poor_Rhat")) {
    diags = post_summ(post, p_one, Rhat = TRUE, ess = TRUE)[c("Rhat", "ess"), ]

    # nothing is displayed when Rhat could not be calculated
    show_legend = !is.na(diags["Rhat"]) &&
      (show_diags == "always" || diags["Rhat"] >= 1.1)

    if (show_legend) {
      Rhat_text = paste0("Rhat: ", diags["Rhat"])
      ess_text = paste0("ESS: ", StatonMisc::prettify(diags["ess"]))
      legend("topright", legend = c(Rhat_text, ess_text), bty = "n")
    }
  }

  # redraw the box in case any curve was drawn over the plot border
  box()
  invisible(NULL)
}
bstaton1/codaTools documentation built on Dec. 8, 2019, 7:54 p.m.