R/plot.R

Defines functions pairs.brmsfit stanplot.brmsfit stanplot mcmc_plot mcmc_plot.brmsfit default_plot_variables plot.brmsfit

Documented in mcmc_plot mcmc_plot.brmsfit pairs.brmsfit plot.brmsfit stanplot stanplot.brmsfit

#' Trace and Density Plots for MCMC Draws
#' 
#' @param x An object of class \code{brmsfit}.
#' @param pars Deprecated alias of \code{variable}.
#'   Names of the parameters to plot, as given by a
#'   character vector or a regular expression.
#' @param variable Names of the variables (parameters) to plot, as given by a
#'   character vector or a regular expression (if \code{regex = TRUE}). By
#'   default, a hopefully not too large selection of variables is plotted.
#' @param combo A character vector with at least two elements. 
#'   Each element of \code{combo} corresponds to a column in the resulting 
#'   graphic and should be the name of one of the available 
#'   \code{\link[bayesplot:MCMC-overview]{MCMC}} functions 
#'   (omitting the \code{mcmc_} prefix).
#' @param N The number of parameters plotted per page.
#' @param theme A \code{\link[ggplot2:theme]{theme}} object 
#'   modifying the appearance of the plots. 
#'   For some basic themes see \code{\link[ggplot2:ggtheme]{ggtheme}}
#'   and \code{\link[bayesplot:theme_default]{theme_default}}.
#' @param regex Logical; Indicates whether \code{variable} should
#'   be treated as regular expressions. Defaults to \code{FALSE}.
#' @param fixed (Deprecated) Indicates whether parameter names 
#'   should be matched exactly (\code{TRUE}) or treated as
#'   regular expressions (\code{FALSE}). Default is \code{FALSE}
#'   and only works with argument \code{pars}.
#' @param plot Logical; indicates if plots should be
#'   plotted directly in the active graphic device.
#'   Defaults to \code{TRUE}.
#' @param ask Logical; indicates if the user is prompted 
#'   before a new page is plotted. 
#'   Only used if \code{plot} is \code{TRUE}.
#' @param newpage Logical; indicates if the first set of plots
#'   should be plotted to a new page. 
#'   Only used if \code{plot} is \code{TRUE}.
#' @param ... Further arguments passed to 
#'   \code{\link[bayesplot:MCMC-combos]{mcmc_combo}}.
#' 
#' @return An invisible list of 
#'   \code{\link[gtable:gtable]{gtable}} objects.
#' 
#' @examples
#' \dontrun{ 
#' fit <- brm(count ~ zAge + zBase * Trt 
#'            + (1|patient) + (1|visit), 
#'            data = epilepsy, family = "poisson")
#' plot(fit)
#' ## plot population-level effects only
#' plot(fit, variable = "^b_", regex = TRUE) 
#' }
#' 
#' @method plot brmsfit
#' @import ggplot2
#' @importFrom graphics plot
#' @importFrom grDevices devAskNewPage
#' @export
plot.brmsfit <- function(x, pars = NA, combo = c("dens", "trace"), 
                         N = 5, variable = NULL, regex = FALSE, fixed = FALSE, 
                         theme = NULL, plot = TRUE, ask = TRUE, 
                         newpage = TRUE, ...) {
  contains_draws(x)
  if (!is_wholenumber(N) || N < 1) {
    stop2("Argument 'N' must be a positive integer.")
  }
  variable <- use_variable_alias(variable, x, pars, fixed = fixed)
  if (is.null(variable)) {
    variable <- default_plot_variables(x)
    regex <- TRUE
  }
  draws <- as.array(x, variable = variable, regex = regex)
  variables <- dimnames(draws)[[3]]
  if (!length(variables)) {
    stop2("No valid variables selected.")
  }
  
  if (plot) {
    default_ask <- devAskNewPage()
    on.exit(devAskNewPage(default_ask))
    devAskNewPage(ask = FALSE)
  }
  n_plots <- ceiling(length(variables) / N)
  plots <- vector(mode = "list", length = n_plots)
  for (i in seq_len(n_plots)) {
    sub_vars <- variables[((i - 1) * N + 1):min(i * N, length(variables))]
    sub_draws <- draws[, , sub_vars, drop = FALSE]
    plots[[i]] <- bayesplot::mcmc_combo(
      sub_draws, combo = combo, gg_theme = theme, ...
    )
    if (plot) {
      plot(plots[[i]], newpage = newpage || i > 1)
      if (i == 1) {
        devAskNewPage(ask = ask)
      }
    }
  }
  invisible(plots) 
}

# list all parameter classes to be included in plots by default
default_plot_variables <- function(family) {
  c(fixef_pars(), "^sd_", "^cor_", "^sigma_", "^rescor_", 
    paste0("^", valid_dpars(family), "$"), "^delta$",
    "^theta", "^ar", "^ma", "^arr", "^sderr", "^lagsar", "^errorsar", 
    "^car", "^sdcar", "^sds_", "^sdgp_", "^lscale_")
}

#' MCMC Plots Implemented in \pkg{bayesplot} 
#' 
#' Convenient way to call MCMC plotting functions 
#' implemented in the \pkg{bayesplot} package.
#' 
#' @aliases stanplot stanplot.brmsfit 
#' 
#' @inheritParams plot.brmsfit
#' @param object An \R object typically of class \code{brmsfit}
#' @param type The type of the plot. 
#'   Supported types are (as names) \code{hist}, \code{dens}, 
#'   \code{hist_by_chain}, \code{dens_overlay}, 
#'   \code{violin}, \code{intervals}, \code{areas}, \code{acf}, 
#'   \code{acf_bar},\code{trace}, \code{trace_highlight}, \code{scatter},
#'   \code{rhat}, \code{rhat_hist}, \code{neff}, \code{neff_hist}
#'   \code{nuts_acceptance}, \code{nuts_divergence},
#'   \code{nuts_stepsize}, \code{nuts_treedepth}, and \code{nuts_energy}. 
#'   For an overview on the various plot types see
#'   \code{\link[bayesplot:MCMC-overview]{MCMC-overview}}.
#' @param ... Additional arguments passed to the plotting functions.
#'   See \code{\link[bayesplot:MCMC-overview]{MCMC-overview}} for
#'   more details.
#' 
#' @return A \code{\link[ggplot2:ggplot]{ggplot}} object 
#'   that can be further customized using the \pkg{ggplot2} package.
#' 
#' @details 
#'   Also consider using the \pkg{shinystan} package available via 
#'   method \code{\link{launch_shinystan}} in \pkg{brms} for flexible 
#'   and interactive visual analysis. 
#' 
#' @examples
#' \dontrun{
#' model <- brm(count ~ zAge + zBase * Trt + (1|patient),
#'              data = epilepsy, family = "poisson")
#'              
#' # plot posterior intervals
#' mcmc_plot(model)
#' 
#' # only show population-level effects in the plots
#' mcmc_plot(model, variable = "^b_", regex = TRUE)
#' 
#' # show histograms of the posterior distributions
#' mcmc_plot(model, type = "hist")
#' 
#' # plot some diagnostics of the sampler
#' mcmc_plot(model, type = "neff")
#' mcmc_plot(model, type = "rhat")
#' 
#' # plot some diagnostics specific to the NUTS sampler
#' mcmc_plot(model, type = "nuts_acceptance")
#' mcmc_plot(model, type = "nuts_divergence")
#' }
#' 
#' @export
mcmc_plot.brmsfit <- function(object, pars = NA, type = "intervals", 
                              variable = NULL, regex = FALSE,
                              fixed = FALSE, ...) {
  contains_draws(object)
  object <- restructure(object)
  type <- as_one_character(type)
  variable <- use_variable_alias(variable, object, pars, fixed = fixed)
  if (is.null(variable)) {
    variable <- default_plot_variables(object)
    regex <- TRUE
  }
  valid_types <- as.character(bayesplot::available_mcmc(""))
  valid_types <- sub("^mcmc_", "", valid_types)
  if (!type %in% valid_types) {
    stop2("Invalid plot type. Valid plot types are: \n",
          collapse_comma(valid_types))
  }
  mcmc_fun <- get(paste0("mcmc_", type), asNamespace("bayesplot"))
  mcmc_arg_names <- names(formals(mcmc_fun))
  mcmc_args <- list(...)
  if ("x" %in% mcmc_arg_names) {
    if (grepl("^nuts_", type)) {
      # x refers to a molten data.frame of NUTS parameters
      mcmc_args$x <- nuts_params(object)
    } else {
      # x refers to a data.frame of draws
      draws <- as.array(object, variable = variable, regex = regex)
      if (!length(draws)) {
        stop2("No valid parameters selected.")
      }
      sel_variables <- dimnames(draws)[[3]]
      if (type %in% c("scatter", "hex") && length(sel_variables) != 2L) {
        stop2("Exactly 2 parameters must be selected for this type.",
              "\nParameters selected: ", collapse_comma(sel_variables))
      }
      mcmc_args$x <- draws
    }
  }
  if ("lp" %in% mcmc_arg_names) {
    mcmc_args$lp <- log_posterior(object)
  }
  use_nuts <- isTRUE(object$algorithm == "sampling")
  if ("np" %in% mcmc_arg_names && use_nuts) {
    mcmc_args$np <- nuts_params(object)
  }
  interval_type <- type %in% c("intervals", "areas")
  if ("rhat" %in% mcmc_arg_names && !interval_type) {
    mcmc_args$rhat <- rhat(object)
  }
  if ("ratio" %in% mcmc_arg_names) {
    mcmc_args$ratio <- neff_ratio(object)
  }
  do_call(mcmc_fun, mcmc_args)
}

#' @rdname mcmc_plot.brmsfit
#' @export
mcmc_plot <- function(object, ...) {
  UseMethod("mcmc_plot")
}

# 'stanplot' has been deprecated in brms 2.10.6; remove in brms 3.0
#' @export
stanplot <- function(object, ...) {
  UseMethod("stanplot")
}

#' @export
stanplot.brmsfit <- function(object, ...) {
  warning2("Method 'stanplot' is deprecated. Please use 'mcmc_plot' instead.")
  mcmc_plot.brmsfit(object, ...)
}

#' Create a matrix of output plots from a \code{brmsfit} object
#'
#' A \code{\link[graphics:pairs]{pairs}} 
#' method that is customized for MCMC output.
#' 
#' @param x An object of class \code{brmsfit}
#' @inheritParams plot.brmsfit
#' @param ... Further arguments to be passed to 
#'   \code{\link[bayesplot:MCMC-scatterplots]{mcmc_pairs}}.
#'  
#' @details For a detailed description see  
#'   \code{\link[bayesplot:MCMC-scatterplots]{mcmc_pairs}}.
#'  
#' @examples 
#' \dontrun{
#' fit <- brm(count ~ zAge + zBase * Trt 
#'            + (1|patient) + (1|visit), 
#'            data = epilepsy, family = "poisson")  
#' pairs(fit, variable = variables(fit)[1:3])
#' pairs(fit, variable = "^sd_", regex = TRUE)
#' }
#'
#' @export
pairs.brmsfit <- function(x, pars = NA, variable = NULL, regex = FALSE, 
                          fixed = FALSE, ...) {
  variable <- use_variable_alias(variable, x, pars, fixed = fixed)
  if (is.null(variable)) {
    variable <- default_plot_variables(x)
    regex <- TRUE
  }
  draws <- as.array(x, variable = variable, regex = regex)
  bayesplot::mcmc_pairs(draws, ...)
}

#' Default \pkg{bayesplot} Theme for \pkg{ggplot2} Graphics
#' 
#' This theme is imported from the \pkg{bayesplot} package.
#' See \code{\link[bayesplot:theme_default]{theme_default}}
#' for a complete documentation.
#' 
#' @name theme_default
#' 
#' @param base_size base font size
#' @param base_family base font family
#' 
#' @return A \code{theme} object used in \pkg{ggplot2} graphics.
#' 
#' @importFrom bayesplot theme_default
#' @export theme_default
NULL

Try the brms package in your browser

Any scripts or data that you put into this service are public.

brms documentation built on Aug. 23, 2021, 5:08 p.m.