R/ppc.mvgam.R

Defines functions pp_check.mvgam ppc.mvgam ppc

Documented in ppc pp_check.mvgam ppc.mvgam

#' @title Plot conditional posterior predictive checks from \pkg{mvgam} models
#' @importFrom stats quantile density ecdf formula terms
#' @importFrom graphics hist abline box rect lines polygon par
#' @importFrom grDevices rgb
#' @name ppc.mvgam
#' @param object \code{list} object returned from \code{mvgam}. See [mvgam()]
#' @param newdata Optional \code{dataframe} or \code{list} of test data containing at least 'series' and 'time'
#' for the forecast horizon, in addition to any other variables included in the linear predictor of \code{formula}. If
#' included, the observed values in the test data are compared to the model's forecast distribution for exploring
#' biases in model predictions.
#' Note this is only useful if the same \code{newdata} was also included when fitting the original model.
#' @param data_test Deprecated. Still works in place of \code{newdata} but users are recommended to use
#' \code{newdata} instead for more seamless integration into `R` workflows
#' @param series \code{integer} specifying which series in the set is to be plotted
#' @param type \code{character} specifying the type of posterior predictive check to calculate and plot.
#' Valid options are: 'rootogram', 'mean', 'hist', 'density', 'prop_zero', 'pit' and 'cdf'
#' @param n_bins \code{integer} specifying the number of bins to use for binning observed values when plotting
#' a rootogram or histogram. Default is `50` bins for a rootogram, which means that if
#' there are `>50` unique observed values, bins will
#' be used to prevent overplotting and facilitate interpretation. Default for a histogram is to use the
#' number of bins returned by a call to `hist` in base `R`
#' @param legend_position The location may also be specified by setting x to a single keyword from the
#' list "bottomright", "bottom", "bottomleft", "left", "topleft", "top", "topright", "right" and "center".
#' This places the legend on the inside of the plot frame at the given location. Or alternatively,
#' use "none" to hide the legend.
#' @param xlab label for x axis.
#' @param ylab label for y axis.
#' @param ... further \code{\link[graphics]{par}} graphical parameters.
#' @details Conditional posterior predictions are drawn from the fitted \code{mvgam} and compared against
#' the empirical distribution of the observed data for a specified series to help evaluate the model's
#' ability to generate unbiased predictions. For all plots apart from `type = 'rootogram'`, posterior predictions
#' can also be compared to out of sample observations as long as these observations were included as
#'' data_test' in the original model fit and supplied here. Rootograms are currently only plotted using the
#'' hanging' style.
#' \cr
#' Note that the predictions used for these plots are *conditional on the observed data*, i.e. they
#' are those predictions that have been generated directly within
#' the `mvgam()` model. They can be misleading if the model included flexible dynamic trend components. For
#' a broader range of posterior checks that are created using *unconditional* "new data" predictions, see
#' \code{\link{pp_check.mvgam}}
#' @return A base \code{R} graphics plot showing either a posterior rootogram (for \code{type == 'rootogram'}),
#' the predicted vs observed mean for the
#' series (for \code{type == 'mean'}), predicted vs observed proportion of zeroes for the
#' series (for \code{type == 'prop_zero'}),predicted vs observed histogram for the
#' series (for \code{type == 'hist'}), kernel density or empirical CDF estimates for
#' posterior predictions (for \code{type == 'density'} or \code{type == 'cdf'}) or a Probability
#' Integral Transform histogram (for \code{type == 'pit'}).
#' @author Nicholas J Clark
#' @seealso \code{\link{pp_check.mvgam}}, \code{\link{predict.mvgam}}
#' @examples
#' \donttest{
#' # Simulate some smooth effects and fit a model
#' set.seed(0)
#' dat <- mgcv::gamSim(1, n = 200, scale = 2)
#' mod <- mvgam(y ~ s(x0) + s(x1) + s(x2) + s(x3),
#'   data = dat,
#'   family = gaussian(),
#'   chains = 2,
#'   silent = 2
#' )
#'
#' # Posterior checks
#' ppc(mod, type = "hist")
#' ppc(mod, type = "density")
#' ppc(mod, type = "cdf")
#'
#' # Many more options are available with pp_check()
#' pp_check(mod)
#' pp_check(mod, type = "ecdf_overlay")
#' pp_check(mod, type = "freqpoly")
#' }
#' @export
ppc <- function(object, ...) {
  UseMethod("ppc", object)
}

#' @rdname ppc.mvgam
#' @method ppc mvgam
#' @export
ppc.mvgam <- function(
  object,
  newdata,
  data_test,
  series = 1,
  type = "hist",
  n_bins,
  legend_position,
  xlab,
  ylab,
  ...
) {
  # Check arguments
  type <- match.arg(
    arg = type,
    choices = c(
      "rootogram",
      "mean",
      "histogram",
      "density",
      "pit",
      "cdf",
      "prop_zero"
    )
  )
  if (type == "histogram") type <- "hist"

  if (type == "rootogram") {
    if (
      !object$family %in%
        c(
          "poisson",
          "negative binomial",
          "tweedie",
          "nmix",
          "binomial",
          "beta_binomial"
        )
    ) {
      stop(
        "Rootograms not supported for checking non-count data",
        call. = FALSE
      )
    }
  }

  optional_args <- list(...)

  if (!(inherits(object, "mvgam"))) {
    stop('argument "object" must be of class "mvgam"')
  }

  if (sign(series) != 1) {
    stop('argument "series" must be a positive integer', call. = FALSE)
  } else {
    if (series %% 1 != 0) {
      stop('argument "series" must be a positive integer', call. = FALSE)
    }
  }

  if (series > NCOL(object$ytimes)) {
    stop(
      paste0(
        "object only contains data / predictions for ",
        NCOL(object$ytimes),
        " series"
      ),
      call. = FALSE
    )
  }

  if (type == "rootogram" & missing(n_bins)) {
    n_bins <- 50

    if (sign(n_bins) != 1) {
      stop('argument "n_bins" must be a positive integer', call. = FALSE)
    } else {
      if (n_bins %% 1 != 0) {
        stop('argument "n_bins" must be a positive integer', call. = FALSE)
      }
    }
  }

  if (!missing("newdata")) {
    data_test <- newdata

    # Ensure outcome is labelled 'y' when feeding data to the model for simplicity
    if (terms(formula(object$call))[[2]] != "y") {
      data_test$y <- data_test[[terms(formula(object$call))[[2]]]]
    }
  }

  # Pull out observations and posterior predictions for the specified series
  data_train <- object$obs_data
  ends <- seq(
    0,
    dim(mcmc_chains(object$model_output, "ypred"))[2],
    length.out = NCOL(object$ytimes) + 1
  )
  starts <- ends + 1
  starts <- c(1, starts[-c(1, (NCOL(object$ytimes) + 1))])
  ends <- ends[-1]

  # Colours needed for plotting quantiles
  c_light <- c("#DCBCBC")
  c_light_highlight <- c("#C79999")
  c_mid <- c("#B97C7C")
  c_mid_highlight <- c("#A25050")
  c_dark <- c("#8F2727")
  c_dark_highlight <- c("#7C0000")
  s_name <- levels(data_train$series)[series]

  if (!missing(data_test)) {
    if (class(data_test)[1] == "list") {
      if (!"time" %in% names(data_test)) {
        stop('data_test does not contain a "time" column')
      }

      if (!"series" %in% names(data_test)) {
        data_test$series <- factor("series1")
      }
    } else {
      if (!"time" %in% colnames(data_test)) {
        stop('data_test does not contain a "time" column')
      }

      if (!"series" %in% colnames(data_test)) {
        data_test$series <- factor("series1")
      }
    }

    if (class(object$obs_data)[1] == "list") {
      truths <- data.frame(
        y = data_test$y,
        time = data_test$time,
        series = data_test$series
      ) %>%
        dplyr::arrange(time, series) %>%
        dplyr::filter(series == s_name) %>%
        dplyr::pull(y)

      if (object$fit_engine == "stan") {
        # For stan objects, ypred is stored as a vector in column-major order
        preds <- mcmc_chains(object$model_output, "ypred")[, seq(
          series,
          dim(mcmc_chains(object$model_output, "ypred"))[2],
          by = NCOL(object$ytimes)
        )]
      } else {
        preds <- mcmc_chains(object$model_output, "ypred")[,
          starts[series]:ends[series]
        ]
      }

      preds <- preds[,
        ((length(data_train$y) / NCOL(object$ytimes)) + 1):((length(
          data_train$y
        ) /
          NCOL(object$ytimes)) +
          length(truths))
      ]
    } else {
      truths <- data_test %>%
        dplyr::filter(series == s_name) %>%
        dplyr::select(time, y) %>%
        dplyr::distinct() %>%
        dplyr::arrange(time) %>%
        dplyr::pull(y)

      if (object$fit_engine == "stan") {
        # For stan objects, ypred is stored as a vector in column-major order
        preds <- mcmc_chains(object$model_output, "ypred")[, seq(
          series,
          dim(mcmc_chains(object$model_output, "ypred"))[2],
          by = NCOL(object$ytimes)
        )]
      } else {
        preds <- mcmc_chains(object$model_output, "ypred")[,
          starts[series]:ends[series]
        ]
      }

      preds <- preds[,
        ((NROW(data_train) / NCOL(object$ytimes)) + 1):((NROW(data_train) /
          NCOL(object$ytimes)) +
          length(truths))
      ]
    }

    if (NROW(preds) > 4000) {
      preds <- preds[sample(1:NROW(preds), 4000, F), ]
    }
  } else {
    if (class(object$obs_data)[1] == "list") {
      truths <- data.frame(
        y = data_train$y,
        time = data_train$time,
        series = data_train$series
      ) %>%
        dplyr::arrange(series, time) %>%
        dplyr::filter(series == s_name) %>%
        dplyr::pull(y)
    } else {
      truths <- data_train %>%
        dplyr::filter(series == s_name) %>%
        dplyr::select(time, y) %>%
        dplyr::distinct() %>%
        dplyr::arrange(time) %>%
        dplyr::pull(y)
    }

    if (object$fit_engine == "stan") {
      # For stan objects, ypred is stored as a vector in column-major order
      preds <- mcmc_chains(object$model_output, "ypred")[, seq(
        series,
        dim(mcmc_chains(object$model_output, "ypred"))[2],
        by = NCOL(object$ytimes)
      )]
    } else {
      preds <- mcmc_chains(object$model_output, "ypred")[,
        starts[series]:ends[series]
      ]
    }

    preds <- preds[, 1:length(truths), drop = FALSE]

    if (NROW(preds) > 4000) {
      preds <- preds[sample(1:NROW(preds), 4000, F), ]
    }
  }

  # Can't deal with missing values in these diagnostic plots
  preds[is.nan(preds)] <- NA
  preds <- preds[, !is.na(truths)]
  truths <- truths[!is.na(truths)]
  probs <- c(0.05, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.95)

  if (type == "prop_zero") {
    pred_props <- apply(preds, 1, function(x) length(which(x == 0)) / length(x))
    lower <- quantile(pred_props, probs = 0.01, na.rm = TRUE)
    upper <- quantile(pred_props, probs = 0.99, na.rm = TRUE)

    if (lower == 0 & upper == 0) {
      stop("No predictions covered zero")
    }

    pred_props <- pred_props[-which(pred_props > upper)]
    if (lower != 0) {
      pred_props <- pred_props[-which(pred_props < lower)]
    }
    obs_prop <- length(which(truths == 0)) / length(truths)

    if (missing(ylab)) {
      ylab <- "Density"
    }

    if (missing(xlab)) {
      xlab <- paste0(
        "Predicted proportion of zeroes for ",
        levels(data_train$series)[series]
      )
    }

    hist(
      pred_props,
      lwd = 2,
      xlim = c(
        min(min(pred_props), min(obs_prop)),
        max(max(pred_props), max(obs_prop))
      ),
      main = "",
      breaks = seq(min(pred_props), max(pred_props), length.out = 15),
      border = "#B97C7C",
      col = "#C79999",
      ylab = ylab,
      xlab = xlab,
      ...
    )
    abline(v = obs_prop, lwd = 3, col = "white")
    abline(v = obs_prop, lwd = 2.5, col = "black")
    box(bty = "L", lwd = 2)

    if (missing(legend_position)) {
      legend_position <- "topright"
    }

    if (legend_position != "none") {
      legend(
        legend_position,
        legend = c(expression(hat(y)[propzero]), expression(y[propzero])),
        bg = "white",
        col = c(c_mid, "black"),
        lty = 1,
        lwd = 2,
        bty = "n"
      )
    }
  }

  if (type == "rootogram") {
    ymax <- floor(max(max(truths), quantile(preds, prob = 0.99, na.rm = TRUE)))
    ymin <- 0L
    xpos <- ymin:ymax

    # Bin if necessary to prevent overplotting
    if (length(xpos) > n_bins) {
      cutpoints <- seq(ymin, ymax, length.out = n_bins)
      xpos <- floor(cutpoints)

      # Find the cutpoint interval that each prediction falls in
      tpreds <- as.list(rep(NA, NROW(preds)))
      for (i in seq_along(tpreds)) {
        tpreds[[i]] <- table(xpos[findInterval(preds[i, ], cutpoints)])
        matches <- match(xpos, rownames(tpreds[[i]]))
        tpreds[[i]] <- as.numeric(tpreds[[i]][matches])
      }
      tpreds <- do.call(rbind, tpreds)
      tpreds[is.na(tpreds)] <- 0
      tyquantile <- sqrt(t(apply(
        tpreds,
        2,
        quantile,
        probs = probs,
        na.rm = TRUE
      )))
      tyexp <- tyquantile[, 5]

      # Repeat for truths
      ty <- table(xpos[findInterval(truths, cutpoints)])
      ty <- sqrt(as.numeric(ty[match(xpos, rownames(ty))]))
    } else {
      tpreds <- as.list(rep(NA, NROW(preds)))
      for (i in seq_along(tpreds)) {
        tpreds[[i]] <- table(as.vector(preds[i, ]))
        matches <- match(xpos, rownames(tpreds[[i]]))
        tpreds[[i]] <- as.numeric(tpreds[[i]][matches])
      }
      tpreds <- do.call(rbind, tpreds)
      tpreds[is.na(tpreds)] <- 0
      tyquantile <- sqrt(t(apply(
        tpreds,
        2,
        quantile,
        probs = probs,
        na.rm = TRUE
      )))
      tyexp <- tyquantile[, 5]
      ty <- table(truths)
      ty <- sqrt(as.numeric(ty[match(xpos, rownames(ty))]))
    }

    ty[is.na(ty)] <- 0
    ypos <- ty / 2
    ypos <- tyexp - ypos
    data <- data.frame(xpos, ypos, ty, tyexp, tyquantile)
    N <- length(xpos)
    idx <- rep(1:N, each = 2)
    repped_x <- rep(xpos, each = 2)
    x <- sapply(
      1:length(idx),
      function(k) {
        if (k %% 2 == 0) {
          repped_x[k] + min(diff(xpos)) / 2
        } else {
          repped_x[k] - min(diff(xpos)) / 2
        }
      }
    )

    if (missing(xlab)) {
      xlab <- expression(y)
    }

    if (missing(ylab)) {
      ylab <- expression(sqrt(frequency))
    }

    # Plot the rootogram
    plot(
      1,
      type = "n",
      bty = "L",
      xlab = xlab,
      ylab = ylab,
      xlim = range(xpos),
      ylim = range(c(data$tyexp, data[, 13], data[, 5], data$tyexp - data$ty)),
      ...
    )
    rect(
      xleft = x[seq(1, N * 2, by = 2)],
      xright = x[seq(2, N * 2, by = 2)],
      ytop = data$tyexp,
      ybottom = data$tyexp - data$ty,
      col = "grey80",
      border = "grey10",
      lwd = 2
    )
    rect(
      xleft = x[seq(1, N * 2, by = 2)],
      xright = x[seq(2, N * 2, by = 2)],
      ytop = data[, 13],
      ybottom = data[, 5],
      col = "#DCBCBC85",
      border = "transparent"
    )
    rect(
      xleft = x[seq(1, N * 2, by = 2)],
      xright = x[seq(2, N * 2, by = 2)],
      ytop = data[, 12],
      ybottom = data[, 6],
      col = "#C7999985",
      border = "transparent"
    )
    rect(
      xleft = x[seq(1, N * 2, by = 2)],
      xright = x[seq(2, N * 2, by = 2)],
      ytop = data[, 11],
      ybottom = data[, 7],
      col = "#B97C7C85",
      border = "transparent"
    )
    rect(
      xleft = x[seq(1, N * 2, by = 2)],
      xright = x[seq(2, N * 2, by = 2)],
      ytop = data[, 10],
      ybottom = data[, 8],
      col = "#A2505085",
      border = "transparent"
    )
    abline(h = 0, col = "white", lwd = 3)
    abline(h = 0, col = "black", lwd = 2.5)
    for (k in 1:N) {
      lines(
        x = c(x[seq(1, N * 2, by = 2)][k], x[seq(2, N * 2, by = 2)][k]),
        y = c(data[k, 9], data[k, 9]),
        col = "#8F2727",
        lwd = 3
      )
    }
    box(bty = "L", lwd = 2)
  }

  if (type == "mean") {
    # Plot observed and predicted means
    pred_means <- apply(preds, 1, mean, na.rm = TRUE)
    lower <- quantile(pred_means, probs = 0.01, na.rm = TRUE)
    upper <- quantile(pred_means, probs = 0.99, na.rm = TRUE)
    pred_means <- pred_means[-which(pred_means > upper)]
    if (lower != 0) {
      pred_means <- pred_means[-which(pred_means < lower)]
    }
    obs_mean <- mean(truths)

    if (missing(ylab)) {
      ylab <- "Density"
    }

    if (missing(xlab)) {
      xlab <- paste0("Predicted mean for ", levels(data_train$series)[series])
    }

    hist(
      pred_means,
      xlim = c(
        min(min(pred_means, na.rm = TRUE), min(obs_mean, na.rm = TRUE)),
        max(max(pred_means, na.rm = TRUE), max(obs_mean, na.rm = TRUE))
      ),
      lwd = 2,
      main = "",
      breaks = seq(
        min(pred_means, na.rm = TRUE),
        max(pred_means, na.rm = TRUE),
        length.out = 20
      ),
      border = "#B97C7C",
      col = "#C79999",
      ylab = ylab,
      xlab = xlab,
      ...
    )
    abline(v = obs_mean, lwd = 3, col = "white")
    abline(v = obs_mean, lwd = 2.5, col = "black")
    box(bty = "L", lwd = 2)

    if (missing(legend_position)) {
      legend_position <- "topright"
    }

    if (legend_position != "none") {
      legend(
        legend_position,
        legend = c(expression(hat(mu)), expression(mu)),
        bg = "white",
        col = c(c_mid, "black"),
        lty = 1,
        lwd = 2,
        bty = "n"
      )
    }
  }

  # Generate a sample sequence and plot
  if (type == "density") {
    max_x <- max(
      max(density(preds[1, ], na.rm = TRUE)$x),
      max(density(truths, na.rm = TRUE)$x)
    )
    min_x <- min(
      min(density(preds[1, ], na.rm = TRUE)$x),
      min(density(truths, na.rm = TRUE)$x)
    )
    pred_densities <- do.call(
      rbind,
      (lapply(1:NROW(preds), function(x) {
        if (length(which(is.na(preds[x, ]))) > (length(preds[x, ]) - 3)) {
          rep(
            0,
            length(density(truths, from = min_x, to = max_x, na.rm = TRUE)$y)
          )
        } else {
          dens <- density(preds[x, ], from = min_x, to = max_x, na.rm = TRUE)
          dens$y
        }
      }))
    )

    cred <- sapply(
      1:NCOL(pred_densities),
      function(n) quantile(pred_densities[, n], probs = probs, na.rm = TRUE)
    )
    true_dens <- density(truths, from = min_x, to = max_x, na.rm = TRUE)
    ymax <- max(c(max(cred, na.rm = TRUE), max(true_dens$y, na.rm = TRUE)))

    if (missing(ylab)) {
      ylab <- paste0(
        "Predictive density for ",
        levels(data_train$series)[series]
      )
    }

    if (missing(xlab)) {
      xlab <- ""
    }

    if (object$family == "beta") {
      xlimits <- c(0, 1)
    } else if (
      object$family %in% c("poisson", "negative binomial", "lognormal", "Gamma")
    ) {
      xlimits <- c(0, max_x)
    } else {
      xlimits <- c(min_x, max_x)
    }

    plot(
      1,
      type = "n",
      bty = "L",
      xlab = xlab,
      ylab = ylab,
      xlim = xlimits,
      ylim = c(0, ymax),
      ...
    )

    polygon(
      c(true_dens$x, rev(true_dens$x)),
      c(cred[1, ], rev(cred[9, ])),
      col = c_light,
      border = NA
    )
    polygon(
      c(true_dens$x, rev(true_dens$x)),
      c(cred[2, ], rev(cred[8, ])),
      col = c_light_highlight,
      border = NA
    )
    polygon(
      c(true_dens$x, rev(true_dens$x)),
      c(cred[3, ], rev(cred[7, ])),
      col = c_mid,
      border = NA
    )
    polygon(
      c(true_dens$x, rev(true_dens$x)),
      c(cred[4, ], rev(cred[6, ])),
      col = c_mid_highlight,
      border = NA
    )
    lines(true_dens$x, cred[5, ], col = c_dark, lwd = 2.5)

    lines(x = true_dens$x, y = true_dens$y, lwd = 3, col = "white")
    lines(x = true_dens$x, y = true_dens$y, lwd = 2.5, col = "black")

    if (missing(legend_position)) {
      legend_position <- "topright"
    }

    if (legend_position != "none") {
      legend(
        legend_position,
        legend = c(expression(hat(y)), "y"),
        bg = "white",
        col = c(c_mid, "black"),
        lty = 1,
        lwd = 2,
        bty = "n"
      )
    }
    box(bty = "L", lwd = 2)
  }

  if (type == "hist") {
    if (missing(n_bins)) {
      n_bins <- max(c(
        length(hist(c(truths, as.vector(preds)), plot = F)$breaks),
        20
      ))
    }

    if (sign(n_bins) != 1) {
      stop('argument "n_bins" must be a positive integer', call. = FALSE)
    } else {
      if (n_bins %% 1 != 0) {
        stop('argument "n_bins" must be a positive integer', call. = FALSE)
      }
    }

    xlim <- c(
      min(
        min(density(preds[1, ], na.rm = TRUE)$x),
        min(density(truths, na.rm = TRUE)$x)
      ),
      max(
        max(density(preds[1, ], na.rm = TRUE)$x),
        max(density(truths, na.rm = TRUE)$x)
      )
    )

    if (object$family == "beta") {
      xlim <- c(0, 1)
    } else if (
      object$family %in% c("poisson", "negative binomial", "lognormal", "Gamma")
    ) {
      xlim <- c(0, xlim[2])
    } else {
      xlim <- xlim
    }

    breaks <- seq(xlim[1], xlim[2], length.out = n_bins)
    truths <- truths[truths <= xlim[2]]
    truths <- truths[truths >= xlim[1]]
    preds <- preds[preds <= xlim[2]]
    preds <- preds[preds >= xlim[1]]

    ylim <- c(
      0,
      max(
        c(
          max(hist(truths, breaks = breaks, plot = F)$density, na.rm = TRUE),
          max(hist(preds, breaks = breaks, plot = F)$density, na.rm = TRUE)
        ),
        na.rm = TRUE
      )
    )

    if (missing(xlab)) {
      xlab <- paste0("Count")
    }

    if (missing(ylab)) {
      ylab <- ""
    }

    hist(
      preds,
      breaks = breaks,
      lwd = 2,
      main = "",
      xlab = xlab,
      ylab = ylab,
      ylim = ylim,
      xlim = xlim,
      border = "#B97C7C",
      col = "#C79999",
      freq = F,
      ...
    )

    par(lwd = 2)
    hist(
      truths,
      breaks = breaks,
      main = "",
      xlab = "",
      ylim = ylim,
      xlim = xlim,
      ylab = "",
      yaxt = "n",
      col = rgb(red = 0, green = 0, blue = 0, alpha = 0),
      border = "black",
      add = T,
      freq = F
    )
    par(lwd = 1)
    box(bty = "L", lwd = 2)

    if (missing(legend_position)) {
      legend_position <- "topright"
    }

    if (legend_position != "none") {
      legend(
        legend_position,
        legend = c(expression(hat(y)), "y"),
        bg = "white",
        col = c(c_mid, "black"),
        lty = 1,
        lwd = 2,
        bty = "n"
      )
    }
  }

  if (type == "cdf") {
    ecdf_plotdat <- function(vals, x) {
      if (length(which(is.na(vals))) > (length(vals) - 3)) {
      } else {
        func <- ecdf(vals)
        func(x)
      }
    }

    plot_x <- seq(
      from = min(truths, na.rm = T),
      to = max(truths, na.rm = T),
      length.out = 100
    )

    pred_cdfs <- do.call(
      rbind,
      (lapply(1:NROW(preds), function(x) {
        ecdf_plotdat(preds[x, ], x = plot_x)
      }))
    )

    probs <- c(0.05, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.95)
    cred <- sapply(
      1:NCOL(pred_cdfs),
      function(n) quantile(pred_cdfs[, n], probs = probs, na.rm = TRUE)
    )

    if (missing(ylab)) {
      ylab <- paste0("Predictive CDF for ", levels(data_train$series)[series])
    }

    if (missing(xlab)) {
      xlab <- ""
    }

    plot(
      1,
      type = "n",
      bty = "L",
      xlab = xlab,
      ylab = ylab,
      xlim = c(min(plot_x, na.rm = TRUE), max(plot_x, na.rm = TRUE)),
      ylim = c(0, 1),
      ...
    )

    polygon(
      c(plot_x, rev(plot_x)),
      c(cred[1, ], rev(cred[9, ])),
      col = c_light,
      border = NA
    )
    polygon(
      c(plot_x, rev(plot_x)),
      c(cred[2, ], rev(cred[8, ])),
      col = c_light_highlight,
      border = NA
    )
    polygon(
      c(plot_x, rev(plot_x)),
      c(cred[3, ], rev(cred[7, ])),
      col = c_mid,
      border = NA
    )
    polygon(
      c(plot_x, rev(plot_x)),
      c(cred[4, ], rev(cred[6, ])),
      col = c_mid_highlight,
      border = NA
    )
    lines(plot_x, cred[5, ], col = c_dark, lwd = 2.5)

    lines(x = plot_x, y = ecdf_plotdat(truths, plot_x), col = "white", lwd = 3)
    lines(
      x = plot_x,
      y = ecdf_plotdat(truths, plot_x),
      col = "black",
      lwd = 2.5
    )

    if (missing(legend_position)) {
      legend_position <- "bottomright"
    }

    if (legend_position != "none") {
      legend(
        legend_position,
        legend = c(expression(hat(y)), "y"),
        bg = "white",
        col = c(c_mid, "black"),
        lty = 1,
        lwd = 2,
        bty = "n"
      )
    }
    box(bty = "L", lwd = 2)
  }

  if (type == "pit") {
    # Calculate emipirical cumulative distribution function as the
    # portion of (y_predicted <= y_true)
    n_pred <- ncol(preds)
    P_x <- vapply(
      seq_along(truths),
      function(i) {
        sum(preds[i, ] <= truths[i]) / n_pred
      },
      .0
    )

    P_xm1 <- vapply(
      seq_along(truths),
      function(i) {
        sum(preds[i, ] <= truths[i] - 1.e-6) / n_pred
      },
      .0
    )
    # 1000 replicates for randomised PIT
    u <- replicate(1000, P_xm1 + stats::runif(length(truths)) * (P_x - P_xm1))
    pit_hist <- hist(u, breaks = seq(0, 1, by = 0.1), plot = F)$density
    pit_hist <- (pit_hist / sum(pit_hist)) * 10

    barplot(
      pit_hist,
      lwd = 2,
      col = "#B97C7C",
      xlab = paste0("Predictive PIT for ", levels(data_train$series)[series]),
      border = NA,
      ...
    )
    abline(h = 1, col = "#FFFFFF60", lwd = 2.85)
    abline(h = 1, col = "black", lwd = 2.5, lty = "dashed")
    box(bty = "L", lwd = 2)
  }
}

#' Posterior Predictive Checks for \code{mvgam} models
#'
#' Perform unconditional posterior predictive checks with the help
#' of the \pkg{bayesplot} package.
#'
#' @aliases pp_check
#' @inheritParams brms::pp_check
#' @importFrom insight get_predictors
#' @importFrom brms do_call
#' @param object An object of class \code{mvgam}.
#' @param newdata Optional \code{dataframe} or \code{list} of test data containing the
#' variables included in the linear predictor of \code{formula}. If not supplied,
#' predictions are generated for the original observations used for the model fit.
#' Ignored if using one of the residual plots (i.e. 'resid_hist')
#' @param ... Further arguments passed to \code{\link{predict.mvgam}}
#'   as well as to the PPC function specified in \code{type}.
#' @inheritParams brms::prepare_predictions.brmsfit
#'
#' @return A ggplot object that can be further
#'  customized using the \pkg{ggplot2} package.
#'
#' @details Unlike the conditional posterior checks provided by \code{\link{ppc}},
#' This function computes *unconditional* posterior predictive checks (i.e. it generates
#' predictions for fake data without considering the true observations associated with those
#' fake data). For a detailed explanation of each of the ppc functions,
#' see the \code{\link[bayesplot:PPC-overview]{PPC}}
#' documentation of the \pkg{\link[bayesplot:bayesplot-package]{bayesplot}}
#' package.
#' @seealso \code{\link{ppc}}, \code{\link{predict.mvgam}}
#' @author Nicholas J Clark
#' @examples
#' \donttest{
#' simdat <- sim_mvgam(seasonality = "hierarchical")
#' mod <- mvgam(
#'   y ~ series +
#'     s(season, bs = "cc", k = 6) +
#'     s(season, series, bs = "fs", k = 4),
#'   data = simdat$data_train,
#'   chains = 2,
#'   silent = 2
#' )
#'
#' # Use pp_check(mod, type = "xyz") for a list of available plot types
#'
#' # Default is a density overlay for all observations
#' pp_check(mod)
#'
#' # Rootograms particularly useful for count data
#' pp_check(mod, type = "rootogram")
#'
#' # Grouping plots by series is useful
#' pp_check(mod,
#'   type = "bars_grouped",
#'   group = "series", ndraws = 50
#' )
#' pp_check(mod,
#'   type = "ecdf_overlay_grouped",
#'   group = "series", ndraws = 50
#' )
#' pp_check(mod,
#'   type = "stat_freqpoly_grouped",
#'   group = "series", ndraws = 50
#' )
#'
#' # Several types can be used to plot distributions of randomized
#' # quantile residuals
#' pp_check(
#'   object = mod,
#'   x = "season",
#'   type = "resid_ribbon"
#' )
#' pp_check(
#'   object = mod,
#'   x = "season",
#'   group = "series",
#'   type = "resid_ribbon_grouped"
#' )
#' pp_check(mod,
#'   ndraws = 5,
#'   type = "resid_hist_grouped",
#'   group = "series"
#' )
#'
#' # Custom functions accepted
#' pp_check(mod, type = "stat", stat = function(x) mean(x == 0))
#' pp_check(mod,
#'   type = "stat_grouped",
#'   stat = function(x) mean(x == 0),
#'   group = "series"
#' )
#'
#' # Some functions accept covariates to set the x-axes
#' pp_check(mod,
#'   x = "season",
#'   type = "ribbon_grouped",
#'   prob = 0.5,
#'   prob_outer = 0.8,
#'   group = "series"
#' )
#'
#' # Many plots can be made without the observed data
#' pp_check(mod, prefix = "ppd")
#' }
#'
#' @importFrom bayesplot pp_check
#' @export pp_check
#' @author Nicholas J Clark
#' @export
pp_check.mvgam <- function(
  object,
  type,
  ndraws = NULL,
  prefix = c("ppc", "ppd"),
  group = NULL,
  x = NULL,
  newdata = NULL,
  ...
) {
  # Set red colour scheme
  col_scheme <- attr(color_scheme_get(), "scheme_name")
  color_scheme_set("red")

  dots <- list(...)
  if (missing(type)) {
    type <- "dens_overlay"
  }

  prefix <- match.arg(prefix)
  ndraws_given <- "ndraws" %in% names(match.call())

  if (is.null(newdata)) {
    newdata <- object$obs_data
  }

  if (prefix == "ppc") {
    # No type checking for prefix 'ppd' yet
    valid_types <- sort(
      c(
        as.character(bayesplot::available_ppc("")),
        "ppc_resid_hist",
        "ppc_resid_hist_grouped",
        "ppc_resid_ribbon",
        "ppc_resid_ribbon_grouped"
      )
    )
    valid_types <- sub("^ppc_", "", valid_types)
    if (!type %in% valid_types) {
      stop(
        "Type '",
        type,
        "' is not a valid ppc type. ",
        "Valid types are:\n",
        paste0("'", valid_types, "'", collapse = ", "),
        call. = FALSE
      )
    }
  }

  bptype <- type

  if (bptype %in% c("resid_hist", "resid_hist_grouped")) {
    if (is.null(object$resids)) {
      object <- add_residuals(object)
    }
    bptype <- sub("resid", "error", bptype)
  }

  if (bptype %in% c("resid_ribbon", "resid_ribbon_grouped")) {
    if (is.null(object$resids)) {
      object <- add_residuals(object)
    }
    bptype <- sub("resid_", "", bptype)
  }

  ppc_fun <- get(paste0(prefix, "_", bptype), asNamespace("bayesplot"))

  family <- object$family
  if (family == "nmix") {
    stop("'pp_check' is not implemented for this family.", call. = FALSE)
  }
  valid_vars <- names(get_predictors(object))
  if ("group" %in% names(formals(ppc_fun))) {
    if (is.null(group)) {
      stop(
        "Argument 'group' is required for ppc type '",
        type,
        "'.",
        call. = FALSE
      )
    }
    if (!group %in% valid_vars) {
      stop(
        "Variable '",
        group,
        "' could not be found in the data.",
        call. = FALSE
      )
    }
  }
  if ("x" %in% names(formals(ppc_fun))) {
    if (!is.null(x) && !x %in% valid_vars) {
      stop("Variable '", x, "' could not be found in the data.", call. = FALSE)
    }
  }
  if (type == "error_binned") {
    method <- "posterior_epred"
  } else {
    method <- "posterior_predict"
  }
  if (!ndraws_given) {
    aps_types <- c(
      "error_scatter_avg",
      "error_scatter_avg_vs_x",
      "intervals",
      "intervals_grouped",
      "loo_intervals",
      "loo_pit",
      "loo_pit_overlay",
      "loo_pit_qq",
      "loo_ribbon",
      "pit_ecdf",
      "pit_ecdf_grouped",
      "ribbon",
      "ribbon_grouped",
      "rootogram",
      "scatter_avg",
      "scatter_avg_grouped",
      "stat",
      "stat_2d",
      "stat_freqpoly_grouped",
      "stat_grouped",
      "violin_grouped"
    )
    if (type %in% aps_types) {
      ndraws <- NULL
      message("Using all posterior draws for ppc type '", type, "' by default.")
    } else {
      ndraws <- 10
      message("Using 10 posterior draws for ppc type '", type, "' by default.")
    }
  }

  y <- NULL
  if (prefix == "ppc") {
    # y is ignored in prefix 'ppd' plots; get the response variable,
    # but take care that binomial models use the cbind() lhs
    resp_terms <- as.character(terms(formula(object$call))[[2]])
    if (length(resp_terms) == 1) {
      out_name <- as.character(terms(object$call)[[2]])
    } else {
      if (any(grepl("cbind", resp_terms))) {
        resp_terms <- resp_terms[-grepl("cbind", resp_terms)]
        out_name <- resp_terms[1]
      }
    }
    y <- newdata[[out_name]]
  }

  # For plotting DS residuals, set y to zero and take
  # -1 * residual so that errors are in the correct direction
  if (grepl("resid", type)) {
    y[!is.na(y)] <- 0
    yrep <- t(-1 * residuals(object, summary = FALSE))

    if (!is.null(ndraws)) {
      yrep <- yrep[1:ndraws, ]
    }
  } else {
    pred_args <- list(
      object,
      newdata = newdata,
      ndraws = ndraws,
      ...
    )
    yrep <- do_call(method, pred_args)
  }

  if (anyNA(y)) {
    warning("NA responses are not shown in 'pp_check'.")
    take <- !is.na(y)
    y <- y[take]
    yrep <- yrep[, take, drop = FALSE]
  } else {
    take <- NULL
  }

  # Prepare plotting arguments
  ppc_args <- list()
  if (prefix == "ppc") {
    ppc_args$y <- y
    ppc_args$yrep <- yrep
  } else if (prefix == "ppd") {
    ppc_args$ypred <- yrep
  }
  if (!is.null(group)) {
    if (!exists(group, newdata)) {
      stop(paste0("Variable ", group, " not in newdata"), call. = FALSE)
    }
    ppc_args$group <- newdata[[group]]

    if (!is.null(take)) {
      ppc_args$group <- ppc_args$group[take]
    }
  }

  is_like_factor <- function(x) {
    is.factor(x) || is.character(x) || is.logical(x)
  }

  if (!is.null(x)) {
    ppc_args$x <- newdata[[x]]
    if (!is_like_factor(ppc_args$x)) {
      ppc_args$x <- as.numeric(ppc_args$x)
    }

    if (!is.null(take)) {
      ppc_args$x <- ppc_args$x[take]
    }
  }

  if ("psis_object" %in% setdiff(names(formals(ppc_fun)), names(ppc_args))) {
    # ppc_args$psis_object <- do_call(
    #   compute_loo, c(pred_args, criterion = "psis")
    # )
    # compute_loo() not available yet for mvgam
    ppc_args$psis_object <- NULL
  }
  if ("lw" %in% setdiff(names(formals(ppc_fun)), names(ppc_args))) {
    # ppc_args$lw <- weights(
    #   do_call(compute_loo, c(pred_args, criterion = "psis"))
    # )
    # compute_loo() not available yet for mvgam
    ppc_args$lw <- NULL
  }

  # Most ... arguments are meant for the prediction function
  for_pred <- names(dots) %in% names(formals(posterior_predict.mvgam))
  ppc_args <- c(ppc_args, dots[!for_pred])

  # Generate plot
  out_plot <- do_call(ppc_fun, ppc_args)

  if ("x" %in% names(formals(ppc_fun)) && !is.null(x)) {
    out_plot <- out_plot +
      ggplot2::labs(x = x)
  }

  # Improve labels for residual plots
  if (type %in% c("resid_hist", "resid_hist_grouped")) {
    out_plot <- out_plot +
      ggplot2::labs(x = "DS residuals")
  }

  if (type %in% c("resid_ribbon", "resid_ribbon_grouped")) {
    out_plot <- out_plot +
      ggplot2::theme(legend.position = "none") +
      ggplot2::labs(y = "DS residuals")
  }

  # Reset color scheme and return the plot
  color_scheme_set(col_scheme)
  return(out_plot)
}
nicholasjclark/mvgam documentation built on April 17, 2025, 9:39 p.m.