# R/stat_spike.R
#
# Defines functions compute_slab_spike

# spike stat
#
# Author: mjskay
###############################################################################


#' Spike plot (ggplot2 stat)
#'
#' Stat for drawing "spikes" (optionally with points on them) at specific points
#' on a distribution (numerical or determined as a function of the distribution),
#' intended for annotating [stat_slabinterval()] geometries.
#'
#' @details
#' This stat computes slab values (i.e. PDF and CDF values) at specified locations
#' on a distribution, as determined by the `at` parameter.
#' @param at The points at which to evaluate the PDF and CDF of the distribution. One of:
#'   - [numeric] vector: points to evaluate the PDF and CDF of the distributions at.
#'   - function or string: function (or name of a function) which,
#'     when applied on a distribution-like object (e.g. a \pkg{distributional} object or a
#'     [posterior::rvar()]), returns a vector of values to evaluate the distribution functions at.
#'   - a [list] where each element is any of the above (e.g. a [numeric], function, or
#'     name of a function): the evaluation points determined by each element of the
#'     list are concatenated together. This means, e.g., `c(0, median, qi)` would add
#'     a spike at `0`, the median, and the endpoints of the `qi` of the distribution.
#' @inheritParams stat_slab
#' @inheritParams geom_spike
#' @eval rd_layer_params("spike", StatSpike, as_dots = TRUE)
#' @param geom Use to override the default connection between [stat_spike()] and [geom_spike()].
#' @template details-x-y-xdist-ydist
#' @return A [ggplot2::Stat] representing a spike geometry which can be added to a [ggplot()] object.
#' @eval rd_spike_aesthetics("spike", StatSpike)
#' @eval rd_slabinterval_computed_variables(StatSpike)
#' @seealso See [geom_spike()] for the geom underlying this stat.
#'   See [stat_slabinterval()] for the stat this shortcut is based on.
#' @family slabinterval stats
#' @examples
#' library(ggplot2)
#' library(distributional)
#' library(dplyr)
#'
#' df = tibble(
#'   d = c(dist_normal(1), dist_gamma(2,2)), g = c("a", "b")
#' )
#'
#' # annotate the density at the mode of a distribution
#' df %>%
#'   ggplot(aes(y = g, xdist = d)) +
#'   stat_slab(aes(xdist = d)) +
#'   stat_spike(at = "Mode") +
#'   # need shared thickness scale so that stat_slab and geom_spike line up
#'   scale_thickness_shared()
#'
#' # annotate the endpoints of intervals of a distribution
#' # here we'll use an arrow instead of a point by setting size = 0
#' arrow_spec = arrow(angle = 45, type = "closed", length = unit(4, "pt"))
#' df %>%
#'   ggplot(aes(y = g, xdist = d)) +
#'   stat_halfeye(point_interval = mode_hdci) +
#'   stat_spike(
#'     at = function(x) hdci(x, .width = .66),
#'     size = 0, arrow = arrow_spec, color = "blue", linewidth = 0.75
#'   ) +
#'   scale_thickness_shared()
#'
#' # annotate quantiles of a sample
#' set.seed(1234)
#' data.frame(x = rnorm(1000, 1:2), g = c("a","b")) %>%
#'   ggplot(aes(x, g)) +
#'   stat_slab() +
#'   stat_spike(at = function(x) quantile(x, ppoints(10))) +
#'   scale_thickness_shared()
#'
#' @name stat_spike
NULL


# compute_slab ------------------------------------------------------------

#' StatSpike$compute_slab()
#'
#' Evaluates the slab functions of one slab at the locations requested
#' through `at`. The parent `StatSlab` computation is run first; its `.input`,
#' `pdf`, and `cdf` columns are handed to `approx_pdf()` / `approx_cdf()` to
#' obtain functions that are then evaluated at the spike locations.
#'
#' @param self ggproto object whose `StatSlab` parent method is invoked.
#' @param data data for the slab; its `dist` column supplies the distribution.
#' @param scales,trans,input,orientation,slab_type,... forwarded unchanged to
#'   the parent `compute_slab()`; `slab_type` is additionally passed to
#'   `get_slab_function()` to pick the values stored in `f`.
#' @param at a numeric vector, a function, the name of a function, or a list
#'   of these. Non-numeric entries are resolved with `match_function()` and
#'   applied to the distribution; all results are flattened into one vector.
#' @return A data frame with columns `.input` (evaluation points), `f`, `pdf`,
#'   `cdf`, and `n` (first element of the parent's `n` column).
#' @noRd
compute_slab_spike = function(
  self, data, scales, trans, input, orientation,
  slab_type, at,
  ...
) {
  # NOTE(review): presumably defines orientation-dependent variables in this
  # frame; none of them are referenced below, the call is kept as-is
  define_orientation_variables(orientation)

  # let the parent stat compute the slab this spike annotates
  slab_df = ggproto_parent(StatSlab, self)$compute_slab(
    data,
    scales = scales,
    trans = trans,
    input = input,
    orientation = orientation,
    slab_type = slab_type,
    ...
  )

  # density / distribution functions built from the parent's output
  dist_obj = data$dist
  eval_pdf = approx_pdf(dist_obj, slab_df$.input, slab_df$pdf)
  eval_cdf = approx_cdf(dist_obj, slab_df$.input, slab_df$cdf)

  # turn a single `at` entry into evaluation points: numerics are taken
  # literally, anything else is looked up as a function and applied to the
  # distribution
  resolve_at = function(spec) {
    if (is.numeric(spec)) return(spec)
    match_function(spec)(dist_obj)
  }
  at_specs = if (is.list(at)) at else list(at)

  # flatten to a plain vector: interval functions such as qi() can return
  # matrices, and the pieces from each entry are concatenated
  at_values = as.vector(unlist(lapply(at_specs, resolve_at), use.names = FALSE))

  # slab function values at the spike locations
  pdf_at = eval_pdf(at_values)
  cdf_at = eval_cdf(at_values)

  data.frame(
    .input = at_values,
    f = get_slab_function(slab_type, list(pdf = pdf_at, cdf = cdf_at)),
    pdf = pdf_at,
    cdf = cdf_at,
    n = slab_df$n[[1]]
  )
}


# stat_spike --------------------------------------------------------------

#' @rdname ggdist-ggproto
#' @format NULL
#' @usage NULL
#' @import ggplot2
#' @export
StatSpike = ggproto(
  "StatSpike", StatSlab,

  # `at` (where spikes are drawn) defaults to the median; this list is
  # combined with StatSlab's default parameters via defaults()
  default_params = defaults(
    list(at = "median"),
    StatSlab$default_params
  ),

  # workaround (#84): compute_slab is a thin wrapper that forwards everything
  # to compute_slab_spike() rather than being that function itself
  compute_slab = function(self, ...) {
    compute_slab_spike(self, ...)
  }
)
# stat_spike(): created by calling make_stat() on StatSpike with
# geom = "spike"; documented under the `stat_spike` Rd topic.
#' @rdname stat_spike
#' @export
stat_spike = make_stat(StatSpike, geom = "spike")

# Try the ggdist package in your browser
#
# Any scripts or data that you put into this service are public.
#
# ggdist documentation built on Nov. 27, 2023, 9:06 a.m.