R/pp_int_plot.R

Defines functions pp_int_plot

Documented in pp_int_plot

#' Posterior Predictive Intervals
#'
#' Plots, for each observation, the mean of \code{ySim} together with an inner
#' and an outer credible interval, with the observed \code{y} values overlaid
#' as points. Modification of \code{ppc_intervals} in the bayesplot package for
#' use with the abdiststan package.
#'
#' @param y A numeric vector (or univariate \code{ts}) of observed values.
#'   \code{NA}s are not allowed.
#' @param ySim A numeric matrix or data frame of the posterior predictive
#'   distribution, with one row per draw and one column per element of
#'   \code{y}. \code{NA}s are not allowed.
#' @param x A numeric vector the same length as \code{y} to use as the x-axis
#'   variable. For example, \code{x} could be a predictor variable from a
#'   regression model, a time variable for time-series models, etc. If \code{x}
#'   is missing or \code{NULL}, then \code{seq_along(y)} is used for the x-axis
#'   (\code{time(y)} when \code{y} is a univariate \code{ts}).
#' @param ... Currently ignored.
#' @param cred_inner Probability mass of the inner credible interval, strictly
#'   between 0 and 1. Defaults to .05 so that it is barely visible.
#' @param cred_outer Probability mass of the outer credible interval, greater
#'   than 0 and at most 1. Defaults to .95. If \code{cred_inner} is larger than
#'   \code{cred_outer}, the two are swapped.
#' @param size,fatten Passed to \code{\link[ggplot2]{geom_pointrange}} for both
#'   interval layers.
#' @param point.size Size of the points representing \code{y}. Defaults to 2.4.
#' @param alpha Transparency of the outer-interval layer, passed to
#'   \code{\link[ggplot2]{geom_pointrange}}. Defaults to 0.3.
#' @return A \code{ggplot} object.
#' @export
pp_int_plot <- function(y, ySim, x = NULL, ..., cred_inner = 0.05,
                        cred_outer = 0.95, size = 1, fatten = 1,
                        point.size = 2.4, alpha = 0.3) {
  # TRUE for an atomic vector without extra attributes, or a 1-D array.
  is_vector_or_1Darray <- function(v) {
    if (is.vector(v) && !is.list(v)) {
      return(TRUE)
    }
    isTRUE(is.array(v) && length(dim(v)) == 1L)
  }

  # Check `y`; a univariate ts keeps its class so validate_x() can use time().
  validate_y <- function(y) {
    stopifnot(is.numeric(y))
    if (!(inherits(y, "ts") && is.null(dim(y)))) {
      if (!is_vector_or_1Darray(y)) {
        stop("'y' must be a vector or 1D array.")
      }
      y <- as.vector(y)
    }
    if (anyNA(y)) {
      stop("NAs not allowed in 'y'.")
    }
    unname(y)
  }

  # Check `ySim` against `y` and return it as an unnamed double matrix.
  validate_ySim <- function(ySim, y) {
    # Data frames are documented as valid input; coerce them so the matrix
    # checks below apply uniformly (non-numeric columns still fail).
    if (is.data.frame(ySim)) {
      ySim <- as.matrix(ySim)
    }
    stopifnot(is.matrix(ySim), is.numeric(ySim))
    if (anyNA(ySim)) {
      stop("NAs not allowed in 'ySim'.")
    }
    if (ncol(ySim) != length(y)) {
      stop("ncol(ySim) must be equal to length(y).")
    }
    # Without any draws every interval is undefined and the plot would be
    # silently empty.
    if (nrow(ySim) < 1L) {
      stop("'ySim' must contain at least one draw (row).")
    }
    storage.mode(ySim) <- "double"
    unname(ySim)
  }

  # Check `x`, defaulting to time(y) for a univariate ts and to the index
  # otherwise.
  validate_x <- function(x, y) {
    if (is.null(x)) {
      x <- if (inherits(y, "ts") && is.null(dim(y))) {
        # time() returns a ts object, which is_vector_or_1Darray() rejects;
        # strip it to a plain numeric vector.
        as.vector(stats::time(y))
      } else {
        seq_along(y)
      }
    }
    stopifnot(is.numeric(x))
    if (!is_vector_or_1Darray(x)) {
      stop("'x' must be a vector or 1D array.")
    }
    x <- as.vector(x)
    if (length(x) != length(y)) {
      stop("length(x) must be equal to length(y).")
    }
    if (anyNA(x)) {
      stop("NAs not allowed in 'x'.")
    }
    unname(x)
  }

  # Pick the axis title before `x` is replaced by its validated value: with no
  # user-supplied x the horizontal axis is just an index.
  x_title <- if (is.null(x)) "Index" else expression(italic(x))

  stopifnot(cred_inner > 0 && cred_inner < 1)
  stopifnot(cred_outer > 0 && cred_outer <= 1)
  cred <- sort(c(cred_inner, cred_outer))

  y <- validate_y(y)
  ySim <- validate_ySim(ySim, y)
  x <- validate_x(x, y)

  # Equal-tailed quantile levels, ordered: outer-lower, inner-lower,
  # inner-upper, outer-upper.
  tail_mass <- (1 - rev(cred)) / 2
  q_probs <- c(tail_mass, 1 - rev(tail_mass))

  # One column of `draw_q` per observation, one row per quantile level.
  draw_q <- vapply(
    seq_len(ncol(ySim)),
    function(j) stats::quantile(ySim[, j], probs = q_probs, names = FALSE),
    numeric(4L)
  )

  interval_d <- data.frame(
    x = x,
    y_obs = as.vector(y),
    ll = draw_q[1L, ],
    l = draw_q[2L, ],
    m = colMeans(ySim),
    h = draw_q[3L, ],
    hh = draw_q[4L, ]
  )

  legend_labels <- c(ySim = expression(bold(y)[rep]), y = expression(bold(y)))

  ggplot2::ggplot(
    data = interval_d,
    mapping = ggplot2::aes(
      x = .data$x, y = .data$m, ymin = .data$l, ymax = .data$h
    )
  ) +
    # Outer interval, drawn translucent underneath the inner one.
    ggplot2::geom_pointrange(
      mapping = ggplot2::aes(
        color = "ySim", fill = "ySim", ymin = .data$ll, ymax = .data$hh
      ),
      shape = 21, alpha = alpha, size = size, fatten = fatten
    ) +
    # Inner interval with the posterior predictive mean as the point.
    ggplot2::geom_pointrange(
      mapping = ggplot2::aes(color = "ySim", fill = "ySim"),
      shape = 21, size = size, fatten = fatten
    ) +
    # Observed values.
    ggplot2::geom_point(
      mapping = ggplot2::aes(y = .data$y_obs, color = "y", fill = "y"),
      shape = 21, size = point.size
    ) +
    ggplot2::scale_color_manual(
      name = "",
      values = c(ySim = "#d0d9db", y = "#0031d3"),
      labels = legend_labels
    ) +
    ggplot2::scale_fill_manual(
      name = "",
      values = c(ySim = "#8e0c09", y = "#2695ff"),
      labels = legend_labels
    ) +
    ggplot2::labs(y = NULL, x = x_title)
}
abnormally-distributed/abdisttools documentation built on May 5, 2019, 7:07 a.m.