R/plot-cdi.R

Defines functions plot_bayesian_cdi2 plot_bayesian_cdi

Documented in plot_bayesian_cdi plot_bayesian_cdi2

#' Bayesian version of the CDI plot
#' 
#' The CDI plot presents the coefficients for the variable of interest (top-left panel), the spread of the data 
#' (bottom-left panel), and the influence statistic (bottom-right panel). The top-right panel holds the legend of
#' the bubble plot, or is left empty when \code{legend = FALSE}.
#' 
#' If the coefficients returned for \code{xfocus} have a numeric \code{variable} column (i.e. \code{xfocus} is
#' treated as continuous), the top-left panel shows the posterior median as a line with a 95\% ribbon, and the data
#' are binned into 14 equal-width bins (labelled by their midpoints) for the bubble plot. Otherwise each level of
#' \code{xfocus} is drawn as a violin.
#' 
#' @param fit An object of class \code{brmsfit}.
#' @param xfocus The column name of the variable to be plotted on the x axis. 
#'   This column name must match one of the column names in the 
#'   \code{data.frame} that was passed to \code{brm} as the \code{data} argument.
#' @param yfocus The column name of the variable to be plotted on the y axis. 
#'   This column name must match one of the column names in the 
#'   \code{data.frame} that was passed to \code{brm} as the \code{data} argument. 
#'   This is generally the temporal variable in a generalised linear model (e.g. year).
#' @param hurdle If a hurdle model then use the hurdle.
#' @param sort_coefs Should the coefficients (and the matching columns of the bubble plot) be sorted by their
#'   posterior median, from lowest to highest. Ignored, with a warning, when \code{xfocus} is continuous.
#' @param axis.text.x.bl Include the x axis labels on the bottom-left (bl) bubble plot panel.
#' @param xlab The x axis label. Defaults to \code{xfocus}.
#' @param ylab The y axis label. Defaults to \code{yfocus}.
#' @param colour The colour to use in the plot.
#' @param p_margin The margin (in cm) between panels on the plot. This is passed to \code{margin} within \code{theme}.
#' @param legend To show the legend or not.
#' @param sum_by Sum to 1 by row, sum to 1 by column, sum to 1 across all data, or raw. The size of the bubbles will be 
#'   the same for all and raw, but the legend will change from numbers of records to a proportion.
#' @param ... Further arguments passed to nothing.
#' @return A \code{patchwork} object (built from \code{ggplot} panels) combining the four panels.
#' @seealso \code{\link{get_coefs}}, \code{\link{get_influ}}, \code{\link{plot_bubble}}
#' @importFrom gtable is.gtable gtable_filter
#' @importFrom stats poly
#' @importFrom tidyselect all_of
#' @importFrom stats median quantile
#' @import ggplot2
#' @import dplyr
#' @import patchwork
#' @export
#' 
plot_bayesian_cdi <- function(fit, 
                              xfocus = "area", yfocus = "fishing_year",
                              xlab = NULL, ylab = NULL, 
                              hurdle = FALSE,
                              sort_coefs = FALSE, 
                              axis.text.x.bl = TRUE,
                              colour = "purple", 
                              p_margin = 0.05, 
                              legend = TRUE, 
                              sum_by = "row", ...) {
  
  if (!is.brmsfit(fit)) stop("fit is not an object of class brmsfit.")
  if (is.null(xlab)) xlab <- xfocus
  if (is.null(ylab)) ylab <- yfocus
  y_coefs <- "Conditional effect"
  
  # Identify the type of variable we are dealing with
  type <- id_var_type(fit = fit, xfocus = xfocus, hurdle = hurdle)
  
  # Posterior samples of the effect of xfocus. Fixed and random effects come
  # from get_coefs(); anything else falls back to the marginal/conditional
  # effect, which ignores the hurdle part of a hurdle model.
  if (type %in% c("fixed_effect", "random_effect")) {
    coefs <- get_coefs(fit = fit, var = xfocus, hurdle = hurdle)
  } else {
    coefs <- get_marginal(fit = fit, var = xfocus)
  }
  
  # If using the lognormal distribution then transform the coefs
  if (fit$family$family == "lognormal") {
    coefs <- coefs %>% mutate(value = exp(.data$value))
    y_coefs <- "Coefficient"
  }
  
  # A numeric `variable` column means xfocus is continuous: the coefficients
  # are drawn as a ribbon and the data are binned for the bubble plot.
  continuous <- is.numeric(coefs$variable)
  
  # Model data
  data <- fit$data
  if (continuous) {
    data <- data %>% select(all_of(c(yfocus, xfocus)))
    x <- data[[xfocus]]
    breaks <- seq(min(x), max(x), length.out = 15)
    midpoints <- breaks[-length(breaks)] + diff(breaks) / 2
    # Label each bin by its midpoint so the bubble columns line up with the
    # breaks used on the continuous x axis of the coefficient panel.
    data[[xfocus]] <- cut(x, breaks = breaks, labels = sprintf("%.2f", round(midpoints, 2)),
                          include.lowest = TRUE)
  }
  
  # Sort the coefficients if required
  sort_order <- NULL
  if (sort_coefs && continuous) {
    warning("sort_coefs is ignored when xfocus is continuous.", call. = FALSE)
  } else if (sort_coefs) {
    # Posterior median per coefficient, one row per group in group_by() order
    coefs_1 <- coefs %>%
      group_by(.data$variable) %>%
      summarise(value = median(.data$value))
    sorted_vars <- coefs_1$variable[order(coefs_1$value)]
    
    # reorder coefficients (lowest median first)
    coefs$variable <- factor(coefs$variable, levels = sorted_vars)
    
    # reorder bubbles to match. rank_to_level[k] is the position (in coefs_1)
    # of the k-th smallest coefficient.
    # NOTE(review): like the previous implementation, this assumes the i-th
    # row of coefs_1 corresponds to the i-th level of data[[xfocus]] - confirm.
    x_fac <- data[[xfocus]]
    if (!is.factor(x_fac)) x_fac <- factor(x_fac)
    rank_to_level <- match(sorted_vars, coefs_1$variable)
    sort_order <- factor(levels(x_fac)[rank_to_level], levels = levels(x_fac))
    data[[xfocus]] <- factor(x_fac, levels = as.character(sort_order))
  }
  
  # Influence
  influ <- get_influ2(fit = fit, group = c(yfocus, xfocus), hurdle = hurdle)
  
  # Extract the legend (guide-box) of a plot on its own
  g2 <- function(a.gplot) {
    if (!is.gtable(a.gplot)) a.gplot <- ggplotGrob(a.gplot)
    gtable_filter(a.gplot, "guide-box", fixed = TRUE)
  }
  
  # The bubble plot (bottom-left) and the legend for the bubble plot (top-right)
  p3a <- plot_bubble(df = data, group = c(yfocus, xfocus), sum_by = sum_by, 
                     xlab = xlab, ylab = ylab, zlab = "", fill = colour, sort_order = sort_order)
  p2 <- g2(p3a)
  
  bl_axis_text <- if (axis.text.x.bl) element_text(angle = 45, hjust = 1) else element_blank()
  p3 <- p3a + theme(legend.position = "none",
                    plot.margin = margin(t = p_margin, r = p_margin, unit = "cm"),
                    axis.text.x = bl_axis_text)
  
  # The coefficients (top-left)
  p1 <- ggplot(data = coefs, aes(x = .data$variable, y = .data$value)) +
    labs(x = NULL, y = y_coefs) +
    theme_bw() +
    theme(axis.title.x = element_blank(), 
          axis.text.x = element_blank(), axis.ticks.x = element_blank(),
          plot.margin = margin(b = p_margin, r = p_margin, unit = "cm"))
  
  if (continuous) {
    p3 <- p3 + scale_x_discrete(expand = expansion(mult = 0.05))
    
    # Median line with a 95% credible ribbon
    p1 <- p1 +
      stat_summary(geom = "ribbon", alpha = 0.5, fill = colour, 
                   fun.min = function(x) quantile(x, probs = 0.025), 
                   fun.max = function(x) quantile(x, probs = 0.975)) +
      stat_summary(fun = "median", geom = "line", colour = colour) +
      scale_x_continuous(position = "top", breaks = midpoints, minor_breaks = NULL,
                         expand = expansion(mult = 0.05)) +
      coord_cartesian(xlim = c(midpoints[1], midpoints[length(midpoints)]))
  } else {
    p1 <- p1 +
      geom_violin(colour = colour, fill = colour, alpha = 0.5, draw_quantiles = 0.5, scale = "width") +
      geom_hline(yintercept = 0, linetype = "dashed") +
      scale_x_discrete(position = "top")
  }
  
  # The influence plot (bottom-right); delta is exponentiated so "no influence" sits at 1
  p4 <- ggplot(data = influ, aes(x = .data[[yfocus]])) +
    geom_hline(yintercept = 1, linetype = "dashed") +
    geom_violin(aes(y = exp(.data$delta)), colour = colour, fill = colour, alpha = 0.5,
                draw_quantiles = 0.5, scale = "width") +
    coord_flip() +
    scale_x_discrete(position = "top") +
    labs(x = NULL, y = "Influence") +
    theme_bw() +
    theme(legend.position = "none", plot.margin = margin(t = p_margin, l = p_margin, unit = "cm"))
  
  # Top-right panel: the bubble legend, or an empty placeholder
  top_right <- if (legend) p2 else ggplot() + theme_void()
  p1 + top_right + p3 + p4 + plot_layout(nrow = 2, ncol = 2, heights = c(1, 2), widths = c(2, 1))
}


#' Bayesian version of the CDI plot (deprecated)
#' 
#' Older implementation of the CDI plot, kept for backward compatibility; see
#' \code{\link{plot_bayesian_cdi}}.
#' 
#' Coefficients are plotted on the exponentiated scale relative to 1. If any coefficient name contains
#' \code{"poly"}, \code{xfocus} is treated as a cubic orthogonal polynomial and evaluated at the midpoints of 19
#' equal-width bins; if there is a single coefficient it is treated as a slope and evaluated at those midpoints;
#' otherwise \code{xfocus} is treated as a factor.
#' 
#' @param fit A fitted model (a \code{brmsfit} object is assumed).
#' @param xfocus The column name of the variable to be plotted on the x axis. This column name must match one of the
#'   column names in the \code{data.frame} that was passed to \code{brm} as the \code{data} argument.
#' @param yfocus The column name of the variable to be plotted on the y axis. This column name must match one of the
#'   column names in the \code{data.frame} that was passed to \code{brm} as the \code{data} argument. This is generally the
#'   temporal variable in a generalised linear model (e.g. year).
#' @param hurdle if a hurdle model then use the hurdle
#' @param xlab the x axis label
#' @param ylab the y axis label
#' @param colour the colour to use in the plot
#' @return A \code{patchwork} object (built from \code{ggplot} panels) combining the four panels.
#' 
#' @importFrom gtable is.gtable gtable_filter
#' @importFrom stats poly
#' @import ggplot2
#' @import dplyr
#' @import patchwork
#' @export
#' 
plot_bayesian_cdi2 <- function(fit,
                               xfocus = "area", yfocus = "fishing_year",
                               hurdle = FALSE,
                               xlab = "Month", 
                               ylab = "Fishing year", 
                               colour = "purple") {

  # Posterior samples of coefficients
  coefs <- get_coefs(fit = fit, var = xfocus, normalise = TRUE, hurdle = hurdle)
  n_iterations <- max(coefs$iteration)
  
  # Midpoint of a cut() label such as "(1.5,2.5]" or "[0,1]"
  get_midpoint <- function(cut_label) {
    mean(as.numeric(unlist(strsplit(gsub("\\(|\\)|\\[|\\]", "", as.character(cut_label)), ","))))
  }
  
  # Replace each value of a numeric vector by the midpoint of the
  # equal-width bin (20 breaks, 19 bins) it falls in
  bin_midpoints <- function(x) {
    binned <- cut(x, breaks = seq(min(x), max(x), length.out = 20), include.lowest = TRUE)
    mids <- vapply(levels(binned), get_midpoint, numeric(1))
    unname(mids[as.integer(binned)])
  }
  
  # One row per distinct binned value, with an integer id used to join back
  new_grid <- function(x) {
    values <- sort(unique(x))
    data.frame(id = seq_along(values), variable = values)
  }
  
  # Evaluate X %*% beta for every posterior draw and return it in long format
  # (columns iteration, value, variable)
  predict_draws <- function(X, x_new, draws) {
    Xbeta <- matrix(NA_real_, nrow = n_iterations, ncol = nrow(X))
    for (i in seq_len(n_iterations)) {
      Xbeta[i, ] <- X %*% draws$value[which(draws$iteration == i)]
    }
    data.frame(iteration = c(row(Xbeta)), id = c(col(Xbeta)), value = c(Xbeta)) %>%
      left_join(x_new, by = "id") %>%
      select(-"id")
  }
  
  # Model data
  if (any(grepl("poly", coefs$variable))) {
    data <- fit$data %>%
      select(-starts_with("poly"))
    data[[xfocus]] <- bin_midpoints(data[[xfocus]])
    x_new <- new_grid(data[[xfocus]])
    
    # Re-use the orthogonal basis of the fitted data so the bin midpoints are
    # projected onto the same polynomial terms.
    # NOTE(review): degree 3 is hard-coded; assumes the model used poly(xfocus, 3).
    z <- poly(fit$data[[xfocus]], 3)
    x_poly <- poly(x_new$variable, 3, coefs = attr(z, "coefs"))
    coefs <- predict_draws(x_poly, x_new, coefs)
  } else if (length(unique(coefs$variable)) == 1) {
    # A single coefficient: treat it as a slope on xfocus
    data <- fit$data
    data[[xfocus]] <- bin_midpoints(data[[xfocus]])
    x_new <- new_grid(data[[xfocus]])
    coefs <- predict_draws(as.matrix(x_new$variable), x_new, coefs)
  } else {
    # all_of() selects exactly xfocus (a regex match would also hit columns
    # whose names merely contain xfocus)
    data <- fit$data %>%
      mutate(across(all_of(xfocus), factor))
  }
  
  # Influence
  influ <- get_influ(fit = fit, group = c(yfocus, xfocus), hurdle = hurdle)
  
  # NROW() is 0 for NULL, so models without a ranef table do not error here
  ylab1 <- if (NROW(fit$ranef) > 0) "Coefficient" else "Relative coefficient"

  # Extract the legend (guide-box) of a plot on its own
  g2 <- function(a.gplot) {
    if (!is.gtable(a.gplot)) a.gplot <- ggplotGrob(a.gplot)
    gtable_filter(a.gplot, "guide-box", fixed = TRUE)
  }

  # Build the plot
  sp <- 0.05
  
  # The coefficients (top-left)
  p1 <- ggplot(data = coefs, aes(x = factor(.data$variable), y = exp(.data$value))) +
    geom_hline(yintercept = 1, linetype = "dashed") +
    geom_violin(colour = colour, fill = colour, alpha = 0.5, draw_quantiles = 0.5, scale = "width") +
    labs(x = NULL, y = ylab1) +
    scale_x_discrete(position = "top") +
    theme_bw() +
    theme(axis.title.x = element_blank(), axis.text.x = element_blank(),
          plot.margin = margin(b = sp, r = sp, unit = "cm"))
  
  # The bubble plot (bottom-left) and the legend for the bubble plot (top-right)
  p3a <- plot_bubble(df = data, group = c(yfocus, xfocus), sum_by = "row",
                     xlab = xlab, ylab = ylab, zlab = "", fill = colour)
  p2 <- g2(p3a)
  p3 <- p3a +
    theme(legend.position = "none", plot.margin = margin(t = sp, r = sp, unit = "cm"),
          axis.text.x = element_text(angle = 45, hjust = 1))

  # The influence plot (bottom-right); delta is exponentiated so "no influence" sits at 1
  p4 <- ggplot(data = influ, aes(x = .data[[yfocus]])) +
    geom_hline(yintercept = 1, linetype = "dashed") +
    geom_violin(aes(y = exp(.data$delta)), colour = colour, fill = colour, alpha = 0.5,
                draw_quantiles = 0.5, scale = "width") +
    coord_flip() +
    scale_x_discrete(position = "top") +
    labs(x = NULL, y = "Influence") +
    theme_bw() +
    theme(legend.position = "none", plot.margin = margin(t = sp, l = sp, unit = "cm"))
  
  p1 + p2 + p3 + p4 + plot_layout(nrow = 2, ncol = 2, heights = c(1, 2), widths = c(2, 1))
}
quantifish/influ2 documentation built on Dec. 14, 2024, 5:10 a.m.