# R/plotting.R
#
# Defines functions: imp_plot, var_plot, exp_plot

# Plotting emulator expectation
#
# Draws a filled-contour plot of an emulator's expectation over a
# two-dimensional slice of its input space.
#
# Arguments:
#   em        Object exposing `ranges` (a named list of c(lower, upper)
#             pairs), `get_exp(points)`, `output_name` and, optionally,
#             `em_type`.
#   plotgrid  Data frame of evaluation points. Its first two columns are the
#             plotted axes, and it must hold a column for every name in
#             `em$ranges`. If NULL, a ppd-by-ppd grid over the first two
#             parameters is built, with every other parameter fixed at the
#             midpoint of its range.
#   ppd       Points per plotted dimension; only used when `plotgrid` is NULL.
#
# Returns a ggplot object.
exp_plot <- function(em, plotgrid = NULL, ppd = 30) {
  ranges <- em$ranges
  if (is.null(plotgrid)) {
    plotgrid <- setNames(
      expand.grid(
        seq(ranges[[1]][1], ranges[[1]][2], length.out = ppd),
        seq(ranges[[2]][1], ranges[[2]][2], length.out = ppd)),
      names(ranges)[1:2])
    # seq_along(ranges)[-(1:2)] is empty for a two-parameter emulator;
    # 3:length(ranges) would instead count down (3, 2) and index past the
    # end of `ranges`.
    for (i in seq_along(ranges)[-(1:2)]) {
      plotgrid[[names(ranges)[i]]] <- sum(ranges[[i]])/2
    }
  }
  em_exp <- em$get_exp(plotgrid[,names(ranges)])
  grid_data <- setNames(cbind(plotgrid[,1:2], em_exp),
                        c(names(plotgrid)[1:2],"E"))
  # The main branch signals a warning whenever native contouring would be
  # degenerate; the handler below then draws the plot from bin indices.
  g <- tryCatch(
    {
      ## To account for rounding problems in ggplot's pretty_isoband_levels
      rng <- range(grid_data$E)
      if (diff(rng) == 0) warning("All output values are identical.")
      # Widen the value range outwards to a multiple of one tenth of its
      # (one significant figure) width, then cut it into 25 bands.
      ac <- signif(diff(rng), 1)/10
      rng[1] <- floor(rng[1]/ac)*ac
      rng[2] <- ceiling(rng[2]/ac)*ac
      # NOTE: emulator_plot() reads `bks` back out of the plot environment
      # to build a shared legend, so this local must keep its name and
      # ggplot() must be called from this frame.
      bks <- seq(rng[1], rng[2], length.out = 26)
      # Arrange the values as a matrix (rows = y, columns = x) for isoband.
      x_pos <- as.integer(factor(grid_data[,1],
                                 levels = sort(unique(grid_data[,1]))))
      y_pos <- as.integer(factor(grid_data[,2],
                                 levels = sort(unique(grid_data[,2]))))
      raster <- matrix(NA_real_, nrow = max(y_pos), ncol = max(x_pos))
      raster[cbind(y_pos, x_pos)] <- grid_data$E
      ibs <- isoband::isobands(x = sort(unique(grid_data[,1])),
                               y = sort(unique(grid_data[,2])),
                               z = raster, levels_low = bks[-length(bks)],
                               levels_high = bks[-1])
      # Rebuild the band labels at three significant digits; if they all
      # collapse to a single label the bands cannot be told apart.
      int_lo <- gsub(":.*$", "", names(ibs))
      int_hi <- gsub("^[^:]*:", "", names(ibs))
      lab_lo <- format(as.numeric(int_lo), digits = 3, trim = TRUE)
      lab_hi <- format(as.numeric(int_hi), digits = 3, trim = TRUE)
      lab_check <- sprintf("(%s, %s]", lab_lo, lab_hi)
      if (length(unique(lab_check)) == 1)
        warning("Can't produce contours natively due to internal accuracy issues.")
      bns <- min(length(unique(lab_check)), 25)
      ggplot(data = grid_data, aes(x = grid_data[,1],
                                   y = grid_data[,2])) +
        geom_contour_filled(aes(z = grid_data[,'E']),
                            bins = bns, colour = 'black') +
        scale_fill_viridis(discrete = TRUE, option = "magma",
                                    name = "exp",
                                    guide = guide_legend(ncol = 1))
    },
    warning = function(w) { #nocov start
      # Fallback: contour the index of the break interval each value falls
      # in, and relabel the legend with the corresponding break values.
      exp_breaks <- seq(min(em_exp) - diff(range(em_exp))/(2 * 23),
                        max(em_exp) + diff(range(em_exp))/(2*23),
                        length.out = 25)
      exp_breaks <- unique(signif(exp_breaks, 10))
      if (length(exp_breaks) == 1) exp_breaks <- c(min(em_exp)-1e-6,
                                                   max(em_exp)+1e-6)
      intervals <- findInterval(em_exp, exp_breaks)
      fake_breaks <- seq_along(exp_breaks)
      ggplot(data = grid_data, aes(x = grid_data[,1],
                                   y = grid_data[,2])) +
        geom_contour_filled(aes(z = intervals), breaks = fake_breaks,
                            colour = 'black') +
        scale_fill_viridis(discrete = TRUE, option = "magma",
                                    name = "exp",
                                    guide = guide_legend(ncol = 1),
                                    labels = function(b)
                                      {signif(
                                        exp_breaks[as.numeric(
                                          grep("\\d+", b, value = TRUE))],
                                        6)})
    } #nocov end
  )
  # Optional emulator-type qualifier inserted into the plot title.
  if (is.null(em$em_type)) extra_ident <- NULL
  else if (em$em_type == "mean") extra_ident <- "Mean"
  else if (em$em_type == "variance") extra_ident <- "Variance"
  else extra_ident <- em$em_type
  g <- g +
    labs(title = paste(em$output_name, extra_ident, "Emulator Expectation"),
         x = names(grid_data)[1], y = names(grid_data)[2]) +
    scale_x_continuous(expand = c(0,0)) +
    scale_y_continuous(expand = c(0,0)) +
    theme_minimal()
  return(g)
}

# Plotting emulator variance
#
# Draws a filled-contour plot of an emulator's variance (or standard
# deviation) over a two-dimensional slice of its input space.
#
# Arguments:
#   em        Object exposing `ranges` (a named list of c(lower, upper)
#             pairs), `get_cov(points)`, `output_name` and, optionally,
#             `em_type`.
#   plotgrid  Data frame of evaluation points. Its first two columns are the
#             plotted axes, and it must hold a column for every name in
#             `em$ranges`. If NULL, a ppd-by-ppd grid over the first two
#             parameters is built, with every other parameter fixed at the
#             midpoint of its range.
#   ppd       Points per plotted dimension; only used when `plotgrid` is NULL.
#   sd        If TRUE, plot the square root of `get_cov()` and label the plot
#             as a standard deviation.
#
# Returns a ggplot object.
var_plot <- function(em, plotgrid = NULL, ppd = 30, sd = FALSE) {
  ranges <- em$ranges
  if (is.null(plotgrid)) {
    plotgrid <- setNames(
      expand.grid(
        seq(
          ranges[[1]][1],
          ranges[[1]][2],
          length.out = ppd),
        seq(
          ranges[[2]][1],
          ranges[[2]][2],
          length.out = ppd)),
      names(ranges)[1:2])
    # seq_along(ranges)[-(1:2)] is empty for a two-parameter emulator;
    # 3:length(ranges) would instead count down (3, 2) and index past the
    # end of `ranges`.
    for (i in seq_along(ranges)[-(1:2)]) {
      plotgrid[[names(ranges)[i]]] <- sum(ranges[[i]])/2
    }
  }
  if(sd)
    em_cov <- sqrt(em$get_cov(plotgrid[,names(ranges)]))
  else
    em_cov <- em$get_cov(plotgrid[,names(ranges)])
  grid_data <- setNames(cbind(plotgrid[,1:2], em_cov),
                        c(names(plotgrid)[1:2], "V"))
  # The main branch signals a warning whenever native contouring would be
  # degenerate; the handler below then draws the plot from bin indices.
  g <- tryCatch(
    {
      ## To account for rounding problems in ggplot's pretty_isoband_levels
      rng <- range(grid_data$V)
      if (diff(rng) == 0) warning("All output values are identical.")
      # Widen the value range outwards to a multiple of one tenth of its
      # (one significant figure) width, then cut it into 25 bands.
      ac <- signif(diff(rng), 1)/10
      rng[1] <- floor(rng[1]/ac)*ac
      rng[2] <- ceiling(rng[2]/ac)*ac
      # NOTE: emulator_plot() reads `bks` back out of the plot environment
      # to build a shared legend, so this local must keep its name and
      # ggplot() must be called from this frame.
      bks <- seq(rng[1], rng[2], length.out = 26)
      # Arrange the values as a matrix (rows = y, columns = x) for isoband.
      x_pos <- as.integer(factor(grid_data[,1],
                                 levels = sort(unique(grid_data[,1]))))
      y_pos <- as.integer(factor(grid_data[,2],
                                 levels = sort(unique(grid_data[,2]))))
      raster <- matrix(NA_real_, nrow = max(y_pos), ncol = max(x_pos))
      raster[cbind(y_pos, x_pos)] <- grid_data$V
      ibs <- isoband::isobands(x = sort(unique(grid_data[,1])),
                               y = sort(unique(grid_data[,2])),
                               z = raster, levels_low = bks[-length(bks)],
                               levels_high = bks[-1])
      # Rebuild the band labels at three significant digits; if they all
      # collapse to a single label the bands cannot be told apart.
      int_lo <- gsub(":.*$", "", names(ibs))
      int_hi <- gsub("^[^:]*:", "", names(ibs))
      lab_lo <- format(as.numeric(int_lo), digits = 3, trim = TRUE)
      lab_hi <- format(as.numeric(int_hi), digits = 3, trim = TRUE)
      lab_check <- sprintf("(%s, %s]", lab_lo, lab_hi)
      if (length(unique(lab_check)) == 1)
        warning("Can't produce contours natively due to internal accuracy issues.")
      bns <- min(length(unique(lab_check)), 25)
      ggplot(data = grid_data, aes(x = grid_data[,1], y = grid_data[,2])) +
        geom_contour_filled(aes(z = grid_data[,'V']), bins = bns,
                            colour = 'black') +
        scale_fill_viridis(discrete = TRUE, option = "plasma",
                                    name = if(sd) "sd" else "var",
                                    guide = guide_legend(ncol = 1))
    },
    warning = function(w) {
      # Fallback: contour the index of the break interval each value falls
      # in, and relabel the legend with the corresponding break values.
      cov_breaks <- seq(
        min(em_cov) - diff(range(em_cov))/(2 * 23),
        max(em_cov) + diff(range(em_cov))/(2*23),
        length.out = 25)
      cov_breaks <- unique(signif(cov_breaks, 10))
      if (length(cov_breaks) == 1) cov_breaks <- c(min(em_cov)-1e-6,
                                                   max(em_cov)+1e-6)
      intervals <- findInterval(em_cov, cov_breaks)
      fake_breaks <- seq_along(cov_breaks)
      ggplot(data = grid_data, aes(x = grid_data[,1], y = grid_data[,2])) +
        geom_contour_filled(aes(z = intervals), breaks = fake_breaks,
                            colour = 'black') +
        scale_fill_viridis(discrete = TRUE, option = "plasma",
                                    name = if (sd) "sd" else "var",
                                    guide = guide_legend(ncol = 1),
                                    labels = function(b)
                                      {signif(
                                        cov_breaks[as.numeric(
                                          grep("\\d+", b, value = TRUE))],
                                        6)})
    }
  )
  # Optional emulator-type qualifier inserted into the plot title.
  if (is.null(em$em_type)) extra_ident <- NULL
  else if (em$em_type == "mean") extra_ident <- "Mean"
  else if (em$em_type == "variance") extra_ident <- "Variance"
  else extra_ident <- em$em_type
  g <- g +
    labs(title = paste(em$output_name, extra_ident, "Emulator",
                       (if(sd) "Standard Deviation" else "Variance")),
         x = names(grid_data)[1], y = names(grid_data)[2]) +
    scale_x_continuous(expand = c(0,0)) +
    scale_y_continuous(expand = c(0,0)) +
    theme_minimal()
  return(g)
}

# Plotting emulator implausibility
#
# Draws a filled-contour plot of implausibility over a two-dimensional slice
# of the input space, either for one emulator or as the nth-maximum
# implausibility across a collection of emulators.
#
# Arguments:
#   em          A single emulator (when `nth` is NULL) exposing `ranges`,
#               `implausibility(points, z)`, `output_name` and, optionally,
#               `em_type`; or, when `nth` is given, a collection of emulators
#               that is handed to nth_implausible().
#   z           The target (single emulator) or the set of targets (nth-max).
#   plotgrid    Data frame of evaluation points. Its first two columns are the
#               plotted axes, and it must hold a column for every parameter
#               name. If NULL, a ppd-by-ppd grid over the first two
#               parameters is built, with every other parameter fixed at the
#               midpoint of its range.
#   ppd         Points per plotted dimension; only used when `plotgrid` is
#               NULL.
#   cb          If TRUE use the `colourblind` palette, otherwise `redgreen`
#               (both are defined outside this function).
#   nth         NULL for single-emulator implausibility; otherwise which
#               maximum implausibility to plot.
#   imp_breaks  Optional contour levels: a strictly increasing numeric vector
#               of length 20 with no missing values.
#
# Returns a ggplot object. Stops if `imp_breaks` is invalid, or if `nth` is
# supplied together with a single Emulator object.
imp_plot <- function(em, z, plotgrid = NULL, ppd = 30, cb = FALSE, nth = NULL,
                     imp_breaks = NULL) {
  if (!is.null(nth)) {
    # Ordinal used in the plot title.
    if (nth == 1) ns <- ""
    else if (nth == 2) ns <- "Second"
    else if (nth == 3) ns <- "Third"
    else ns <- paste0(nth, "th")
    # The emulator collection may be nested; take the ranges from the first
    # emulator found in whichever layout is present.
    if (!is.null(em$expectation)) ranges <- em$expectation[[1]]$ranges
    else if (!is.null(em$mode1)) ranges <- em$mode1$expectation[[1]]$ranges
    else ranges <- em[[1]]$ranges
  }
  else ranges <- em$ranges
  if (is.null(plotgrid)) {
    plotgrid <- setNames(
      expand.grid(
        seq(
          ranges[[1]][1],
          ranges[[1]][2],
          length.out = ppd),
        seq(ranges[[2]][1],
            ranges[[2]][2],
            length.out = ppd)),
      names(ranges)[1:2])
    # seq_along(ranges)[-(1:2)] is empty for a two-parameter emulator;
    # 3:length(ranges) would instead count down (3, 2) and index past the
    # end of `ranges`.
    for (i in seq_along(ranges)[-(1:2)]) {
      plotgrid[[names(ranges)[i]]] <- sum(ranges[[i]])/2
    }
  }
  if (is.null(imp_breaks)) {
    imp_breaks <- c(0, 0.3, 0.7, 1, 1.3, 1.7, 2, 2.3, 2.7,
                    3, 3.5, 4, 4.5, 5, 6, 7, 8, 10, 15, Inf)
    # Only a subset of the default levels is labelled in the legend.
    imp_names <- c(0, '', '', 1, '', '', 2, '', '', 3, '',
                  '', '', 5, '', '', '', 10, 15, '')
  } else {
    if (!is.numeric(imp_breaks)) stop("imp_breaks not a numeric vector.")
    if (any(is.na(imp_breaks))) stop("imp_breaks cannot contain missing values.")
    if (any(diff(imp_breaks) <= 0)) stop("imp_breaks must be in ascending order.")
    if (length(imp_breaks) != 20) stop("imp_breaks must be of length 20.")
    imp_names <- as.character(imp_breaks)
  }
  if (!is.null(nth)) {
    if (!inherits(em, "Emulator")) {
      em_imp <- nth_implausible(em, plotgrid[names(ranges)],
                                z, n = nth, max_imp = 99)
    }
    else
      stop(paste("Not all required parameters",
                 "(emulator list, target list, nth)",
                 "passed for nth maximum implausibility."))
  }
  else{
    em_imp <- em$implausibility(plotgrid[names(ranges)], z)
  }
  # Drop the leading levels whose band lies entirely below the smallest
  # implausibility; the final level is always kept.
  included <- c(map_lgl(imp_breaks[-1], ~any(em_imp < .)), TRUE)
  grid_data <- setNames(
    cbind(plotgrid[,1:2], em_imp), c(names(plotgrid)[1:2],"I"))
  col_scale <- if(cb) colourblind else redgreen
  g <- ggplot(data = grid_data, aes(x = grid_data[,1], y = grid_data[,2]))
  if (length(unique(em_imp)) == 1) {
    # A constant surface has no contours: fill the whole panel with the
    # colour of the level nearest to that single value.
    col <- col_scale[which.min(abs(imp_breaks - c(unique(em_imp))))]
    g <- g + geom_raster(aes(fill = grid_data[,"I"])) +
      scale_fill_gradient(low = col, high = col, name = "I")
  }
  else
    g <- g + geom_contour_filled(aes(z = grid_data[,"I"]),
                                 colour = 'black',
                                 breaks = imp_breaks[included]) +
    scale_fill_manual(values = col_scale[included], name = "I",
                      labels = imp_names[included],
                      guide = guide_legend(reverse = TRUE))
  # Optional emulator-type qualifier inserted into the plot title.
  if (is.null(em$em_type)) extra_ident <- NULL
  else if (em$em_type == "mean") extra_ident <- "Mean"
  else if (em$em_type == "variance") extra_ident <- "Variance"
  else extra_ident <- em$em_type
  g <- g + labs(title = paste(
    em$output_name,
    extra_ident,
    (if (is.null(nth)) "Emulator Implausibility" else paste(ns, "Maximum Implausibility"))),
    x = names(grid_data)[1], y = names(grid_data)[2]) +
    scale_x_continuous(expand = c(0,0)) +
    scale_y_continuous(expand = c(0,0)) +
    theme_minimal()
  return(g)
}

#' Plot Emulator Outputs
#'
#' A function for plotting emulator expectations, variances, and implausibilities
#'
#' Given a single emulator, or a set of emulators, the emulator statistics can be plotted
#' across a two-dimensional slice of the parameter space. Which statistic is plotted is
#' determined by \code{plot_type}: options are `exp', `var', `sd', `imp', and `nimp', which
#' correspond to expectation, variance, standard deviation, implausibility, and nth-max
#' implausibility.
#'
#' By default, the slice varies in the first two parameters of the emulators, and all other
#' parameters are taken to be fixed at their mid-range values. This behaviour can be changed
#' with the \code{params} and \code{fixed_vals} parameters (see examples).
#'
#' If the statistic is `exp', `var' or `sd', then the minimal set of parameters to pass to this
#' function are \code{ems} (which can be a list of emulators or a single one) and \code{plot_type}.
#' If the statistic is `imp' or `nimp', then the \code{targets} must be supplied - it is not
#' necessary to specify the individual target for a single emulator plot. If the statistic is
#' `nimp', then the level of maximum implausibility can be chosen with the parameter \code{nth}.
#'
#' Implausibility plots are typically coloured from green (low implausibility) to red (high
#' implausibility): a colourblind-friendly option is available and can be turned on by setting
#' \code{cb = TRUE}.
#'
#' The granularity of the plot is controlled by the \code{ppd} parameter, determining the number
#' of points per dimension in the grid. For higher detail, at the expense of longer computing
#' time, increase this value. The default is 30.
#'
#' When several emulators are plotted, the individual plots are arranged in a grid, each under
#' a small label giving its output name. A combined legend is only attached when two, three or
#' four plots are drawn, and never for `imp' or `nimp' plots.
#'
#' @import ggplot2
#' @importFrom viridis scale_fill_viridis
#' @importFrom GGally ggally_text ggally_blank ggmatrix putPlot
#' @importFrom stringr str_pad
#'
#' @param ems An \code{\link{Emulator}} object, or a list thereof.
#' @param plot_type The statistic to plot (see description or examples).
#' @param ppd The number of points per plotting dimension
#' @param targets If required, the targets from which to calculate implausibility
#' @param cb A boolean representing whether a colourblind-friendly plot is produced.
#' @param params Which two input parameters should be plotted? If this is not a vector of
#'               two distinct parameter names, the first two parameters are used. The axes
#'               follow the order of the emulator ranges rather than the order given here.
#' @param fixed_vals For fixed input parameters, the values they are held at. Ignored unless
#'                   every name is a parameter name.
#' @param nth If plotting nth maximum implausibility, which level maximum to plot.
#' @param imp_breaks If plotting implausibility (`imp' or `nimp'), defines the levels at
#'                   which to draw contours: an increasing numeric vector of length 20.
#' @param include_legend For multiple plots, should a combined legend be appended?
#'
#' @return A ggplot object, or collection thereof.
#'
#' @family visualisation tools
#' @export
#'
#' @examples
#'  # Reducing ppd to 10 for speed.
#'  emulator_plot(SIREmulators$ems, ppd = 10)
#'  emulator_plot(SIREmulators$ems$nS, ppd = 10)
#'  emulator_plot(SIREmulators$ems, plot_type = 'var', ppd = 10, params = c('aIR', 'aSR'))
#'  \donttest{ # Excessive runtime
#'     emulator_plot(SIREmulators$ems, plot_type = 'imp', ppd = 10,
#'      targets = SIREmulators$targets,
#'      fixed_vals = list(aSR = 0.02))
#'     emulator_plot(SIREmulators$ems, plot_type = 'nimp', cb = TRUE,
#'      targets = SIREmulators$targets, nth = 2, ppd = 10)
#'  }
#'
emulator_plot <- function(ems, plot_type = 'exp', ppd = 30, targets = NULL,
                          cb = FALSE, params = NULL, fixed_vals = NULL,
                          nth = 1, imp_breaks = NULL, include_legend = TRUE) {
  if (plot_type == "imp" || plot_type == "nimp") {
    # Fail up front with a clear message instead of somewhere inside the
    # implausibility calculation.
    if (is.null(targets))
      stop("Cannot plot implausibility without target value.")
    # No combined legend is built for implausibility plots.
    include_legend <- FALSE
  }
  if (inherits(ems, "Emulator")){
    ranges <- ems$ranges
    single_em <- TRUE
  }
  else {
    # The emulator collection may be nested; take the ranges from the first
    # emulator found in whichever layout is present.
    if (!is.null(ems$expectation)) ranges <- ems$expectation[[1]]$ranges
    else if (!is.null(ems$mode1)) ranges <- ems$mode1$expectation[[1]]$ranges
    else ranges <- ems[[1]]$ranges
    single_em <- FALSE
  }
  if (!is.null(ems$expectation)) ems <- ems$expectation
  # Fall back to the first two parameters unless `params` names exactly two
  # distinct, known parameters (a duplicated name would leave only one axis).
  if (is.null(params) ||
      length(params) != 2 ||
      anyDuplicated(params) > 0 ||
      any(!params %in% names(ranges))) p_vals <- c(1,2)
  else p_vals <- which(names(ranges) %in% params)
  plotgrid <- setNames(
    expand.grid(
      seq(
        ranges[[p_vals[1]]][1],
        ranges[[p_vals[1]]][2],
        length.out = ppd),
      seq(
        ranges[[p_vals[2]]][1],
        ranges[[p_vals[2]]][2],
        length.out = ppd)),
    names(ranges)[p_vals])
  if (!is.null(fixed_vals) && all(names(fixed_vals) %in% names(ranges))) {
    for (i in seq_along(fixed_vals))
      plotgrid[[names(fixed_vals)[i]]] <- fixed_vals[[names(fixed_vals)[i]]]
    used_names <- c(names(ranges)[p_vals], names(fixed_vals))
  }
  else used_names <- names(ranges)[p_vals]
  used_names <- unique(used_names)
  # Every parameter that is neither plotted nor explicitly fixed is held at
  # the midpoint of its range.
  if (length(used_names) < length(ranges)) {
    unused_nms <- names(ranges)[which(!names(ranges) %in% used_names)]
    for (i in seq_along(unused_nms))
      plotgrid[[unused_nms[i]]] <- sum(ranges[[unused_nms[i]]])/2
  }
  # Single-emulator plot for the requested statistic. Any plot_type other
  # than 'exp', 'var' or 'sd' is drawn as that emulator's implausibility.
  get_plot <- function(em) {
    if (plot_type == 'exp') return(exp_plot(em, plotgrid, ppd))
    if (plot_type == 'var') return(var_plot(em, plotgrid, ppd))
    if (plot_type == 'sd') return(var_plot(em, plotgrid, ppd, sd = TRUE))
    # `targets` is either one target (an atomic pair, or a list with a `val`
    # entry) or a list of targets named by emulator output.
    if (is.atomic(targets) || !is.null(targets$val))
      return(imp_plot(em, targets, plotgrid, ppd, cb, NULL, imp_breaks))
    else return(imp_plot(em, targets[[em$output_name]], plotgrid, ppd, cb, NULL, imp_breaks))
  }
  if (single_em) return(get_plot(ems))
  if (plot_type == 'nimp') return(imp_plot(ems, targets, plotgrid, ppd, cb, nth, imp_breaks))
  else {
    plotlist <- map(ems, get_plot)
    # Arrange the individual plots in a ggmatrix: rows alternate between a
    # thin strip of name labels and the plots themselves.
    replacement_function <- function(plots, title = NULL) {
      # Reduce each plot title to the part before " Emulator <statistic>".
      titles <- map_chr(
        plots,
        ~sub("(.*) Emulator (Expectation|Variance|Standard Deviation|Implausibility)",
             "\\1", .$labels$title))
      create_name_plot <- function(name) {
        return(ggally_text(label = name, colour = 'black') + theme_void())
      }
      plot_cols <- ceiling(sqrt(length(plots)))
      plot_rows <- ceiling(length(plots)/plot_cols)
      plot_list <- list()
      for (i in seq_len(2*plot_rows)) {
        for (j in seq_len(plot_cols)) {
          if (i %% 2 == 1) {
            # Odd rows: name labels.
            in_bounds <- plot_cols*(i-1)/2+j
            if (in_bounds <= length(titles)) {
              plot_list[[length(plot_list)+1]] <- create_name_plot(titles[in_bounds])
            }
            else plot_list[[length(plot_list)+1]] <- ggally_blank()
          }
          else {
            # Even rows: the plots themselves.
            in_bounds <- plot_cols*(i-2)/2+j
            if (in_bounds <= length(plots)) {
              plot_list[[length(plot_list)+1]] <- plots[[in_bounds]]
            }
            else plot_list[[length(plot_list)+1]] <- ggally_blank()
          }
        }
      }
      if (include_legend) {
        # Each plot's contour breaks are read from the `bks` variable left in
        # its plot environment by exp_plot()/var_plot().
        plt_labs <- map(plots, ~.$plot_env$bks)
        # One legend label per break level: the values for every plot, side
        # by side.
        plt_labs_comb <- apply(do.call('cbind.data.frame', plt_labs), 1, function(x) {
          paste0(str_pad(round(x,4), 4, side = 'right', pad = " "), collapse = "      ")
        })
        if (plot_type == 'exp')
          plt_cols <- viridis::viridis(length(plt_labs[[1]]), option = "magma")
        else
          plt_cols <- viridis::viridis(length(plt_labs[[1]]), option = "plasma")
        # Throwaway plot whose values sit at the band midpoints; it exists
        # only so that its legend can be extracted.
        fake_dat <- expand.grid(x = seq(1, 5), y = seq(1, 5))
        fake_dat$z <- (plt_labs[[1]][-1] + plt_labs[[1]][-length(plt_labs[[1]])])/2
        x <- y <- z <- NULL
        p_temp <- ggplot(data = fake_dat, aes(x = x, y = y, z = z)) +
          geom_contour_filled(colour = 'black', breaks = plt_labs[[1]]) +
          scale_fill_manual(name = paste0("      ", map_chr(ems, "output_name"), collapse = ""),
                            values = plt_cols, labels = plt_labs_comb) +
          theme(legend.key.size = unit(0.5, 'cm'))
        if (length(plots) == 3)
          p_temp <- p_temp + guides(fill = guide_legend(ncol = 2))
        else
          p_temp <- p_temp + guides(fill = guide_legend(ncol = 1))
        the_legend <- grab_legend(p_temp)
      }
      if (include_legend && (length(plots) == 4 || length(plots) == 2)) {
        # Two or four plots fill their grid, so append a narrow extra column
        # of blanks for the legend to occupy.
        plot_list <- append(plot_list, list(ggally_blank()), plot_cols)
        if (length(plots) == 4) {
          plot_list <- append(plot_list, list(ggally_blank()), 2*(plot_cols+1)-1)
          plot_list <- append(plot_list, list(ggally_blank()), 3*(plot_cols+1)-1)
        }
        plot_list[[length(plot_list)+1]] <- ggally_blank()
        main_plt <- ggmatrix(plot_list, ncol = plot_cols+1, nrow = 2*plot_rows,
                 xlab = plots[[1]]$labels$x, ylab = plots[[1]]$labels$y,
                 title = title,
                 xProportions = c(rep(1, plot_cols), 0.5),
                 yProportions = rep(c(0.05, 1), plot_rows),
                 progress = FALSE)
      }
      else {
        main_plt <- ggmatrix(plot_list, ncol = plot_cols, nrow = 2*plot_rows,
                 xlab = plots[[1]]$labels$x, ylab = plots[[1]]$labels$y,
                 title = title, yProportions = rep(c(0.05, 1), plot_rows),
                 progress = FALSE)
      }
      if (include_legend) {
        # With three plots the legend takes the empty fourth cell; with two
        # or four it goes into the extra column.
        if (length(plots) == 2)
          main_plt <- putPlot(main_plt, the_legend, 2, 3)
        else if (length(plots) == 3)
          main_plt <- putPlot(main_plt, the_legend, 4, 2)
        else if (length(plots) == 4)
          main_plt <- putPlot(main_plt, the_legend, 4, 3)
      }
      return(main_plt)
    }
    if (plot_type == "exp") plot_title <- "Emulator Expectations"
    else if (plot_type == "var") plot_title <- "Emulator Variances"
    else if (plot_type == "sd") plot_title <- "Emulator Standard Deviations"
    else if (plot_type == "imp") plot_title <- "Emulator Implausibilities"
    else plot_title <- NULL
    return(replacement_function(plotlist, plot_title))
  }
}

#' Emulator Expectation Against Target Outputs
#'
#' Plots emulator expectation across the parameter space, with comparison to the corresponding
#' target values (with appropriate uncertainty).
#'
#' If a \code{points} data.frame is not provided, then points are sampled uniformly from the
#' input region. Otherwise, the provided points are used: for example, if a representative
#' sample of the current NROY space is available.
#'
#' Each evaluated point is drawn as one line across the outputs. A target given as a
#' (val, sigma) pair is shown as the interval val +/- 3 sigma; a target given as a pair of
#' bounds is shown as is.
#'
#' @importFrom ggplot2 ggplot aes labs geom_line geom_point geom_errorbar
#'
#' @param ems An \code{\link{Emulator}} object, or a list thereof.
#' @param targets A named list of observations, given in the usual form. For a single
#'                emulator, the individual target may be supplied directly.
#' @param points A list of points at which the emulators should be evaluated.
#' @param npoints If no points are provided, the number of input points to evaluate at.
#'
#' @return A ggplot object
#'
#' @family visualisation tools
#' @export
#'
#' @examples
#'  output_plot(SIREmulators$ems, SIREmulators$targets)
#'  output_plot(SIREmulators$ems, SIREmulators$targets, points = SIRSample$training)
output_plot <- function(ems, targets, points = NULL, npoints = 1000) {
  if (inherits(ems, "Emulator")) {
    # Treat a lone emulator as a one-element list, paired with just its own
    # target, so that the list-based code below applies unchanged.
    ems <- setNames(list(ems), ems$output_name)
    if (is.atomic(targets) || !is.null(targets[["val"]]))
      targets <- setNames(list(targets), names(ems))
    else {
      if (!all(names(ems) %in% names(targets)))
        stop("No target found for the supplied emulator.")
      targets <- targets[names(ems)]
    }
  }
  ranges <- ems[[1]]$ranges
  if (is.null(points)) {
    points <- data.frame(map(ranges, ~runif(npoints, .[1], .[2])))
  }
  # Labels for the emulator outputs: an unnamed list falls back to each
  # emulator's own output name.
  em_names <- names(ems)
  if (is.null(em_names)) em_names <- map_chr(ems, "output_name")
  em_exp <- setNames(
    data.frame(map(ems, ~.$get_exp(points))), names(targets))
  em_exp$run <- seq_len(nrow(points))
  # Wide to long: one row per (run, output), with the output held as an
  # index that is then mapped back to its name.
  em_exp <- reshape(em_exp, varying = seq_len(length(em_exp)-1),
                      times = seq_len(length(em_exp)-1), idvar = "run",
                      direction = "long", v.names = "values") |>
    setNames(c("run", "name", "value"))
  em_exp$name <- em_names[em_exp$name]
  for (i in seq_along(targets))
  {
    # Convert (val, sigma) targets into a lower/upper pair at three sigma.
    if (!is.atomic(targets[[i]]))
      targets[[i]] <- c(targets[[i]]$val - 3 * targets[[i]]$sigma,
                        targets[[i]]$val + 3*targets[[i]]$sigma)
  }
  target_data <- data.frame(label = names(targets),
                            mn = map_dbl(targets, ~.[1]),
                            md = map_dbl(targets, mean),
                            mx = map_dbl(targets, ~.[2]))
  # Placeholder bindings for the column names used inside aes() below.
  name <- value <- run <- mn <- md <- mx <- label <- NULL
  # Order the outputs along the x axis as they appear in `targets`.
  em_exp$name <- factor(em_exp$name, levels = names(targets))
  ggplot(data = em_exp, aes(x = name, y = value)) +
    geom_line(colour = 'purple', aes(group = run), linewidth = 1) +
    geom_point(data = target_data, aes(x = label, y = md), size = 2) +
    geom_errorbar(data = target_data, aes(x = label, y = md, ymin = mn,
                                          ymax = mx), width = .1, linewidth = 1.25) +
    labs(title = "Emulator Runs versus Observations")
}

#' Plot Lattice of Emulator Implausibilities
#'
#' Plots a set of projections of the full-dimensional input space.
#'
#' The plots are:
#'
#' One dimensional optical depth plots (diagonal);
#'
#' Two dimensional optical depth plots (lower triangle);
#'
#' Two dimensional minimum implausibility plots (upper triangle).
#'
#' The optical depth is calculated as follows. A set of points is constructed across the
#' full d-dimensional parameter space, and implausibility is calculated at each point.
#' The points are collected into groups based on their placement in a projection to a
#' one- or two-dimensional slice of the parameter space. For each group, the proportion
#' of non-implausible points is calculated, and this value in [0,1] is plotted. The
#' minimum implausibility plots are similar, but with minimum implausibility calculated
#' rather than proportion of non-implausible points.
#'
#' The \code{maxpoints} argument is used as a cutoff for if a regular ppd grid would
#' result in a very large number of points. If this is the case, then \code{maxpoints} points
#' are sampled uniformly from the region instead of regularly spacing them.
#'
#' If only a subset of parameters are relevant, then the \code{plot_vars} and \code{fixed_vars}
#' can be used to specify the subset. If \code{plot_vars} is provided, corresponding to a list
#' of parameter names, then those parameters not included are fixed to their mid-range values;
#' if \var{fixed_vars} is provided as a named list, then the parameters not included are fixed
#' to the corresponding specified values.
#'
#' @importFrom stats xtabs
#' @importFrom GGally ggmatrix grab_legend
#' @import ggplot2
#'
#' @param ems The \code{\link{Emulator}} objects in question.
#' @param targets The corresponding target values.
#' @param ppd The number of points to sample per dimension.
#' @param cb Whether or not a colourblind-friendly plot should be produced.
#' @param cutoff The cutoff value for non-implausible points.
#' @param maxpoints The limit on the number of points to be evaluated.
#' @param imp_breaks If plotting nth maximum implausibility, defines the levels at
#'                   which to draw contours.
#' @param contour Logical determining whether to plot implausibility contours or not.
#' @param ranges Parameter ranges. If not supplied, defaults to emulator ranges.
#' @param raster_imp Should the implausibility plots be rasterised?
#' @param plot_vars If provided, indicates which subset of parameters to plot.
#' @param fixed_vars If provided, indicates the fixed value of the plot-excluded parameters.
#'
#' @return A ggplot object
#'
#' @family visualisation tools
#' @export
#'
#' @references Bower, Goldstein & Vernon (2010) <doi:10.1214/10-BA524>
#'
#' @examples
#' \donttest{ # Excessive runtime
#'  plot_lattice(SIREmulators$ems, SIREmulators$targets, ppd = 10)
#'  plot_lattice(SIREmulators$ems$nS, SIREmulators$targets)
#'  plot_lattice(SIREmulators$ems, SIREmulators$targets, plot_vars = c('aSI', 'aIR'))
#'  plot_lattice(SIREmulators$ems, SIREmulators$targets, fixed_vars = list(aSR = 0.03))
#' }
plot_lattice <- function(ems, targets, ppd = 20, cb = FALSE,
                         cutoff = 3, maxpoints = 5e4, imp_breaks = NULL,
                         contour = TRUE, ranges = NULL, raster_imp = FALSE,
                         plot_vars = NULL, fixed_vars = NULL) {
  ems <- collect_emulators(ems)
  if (is.null(ranges))
    ranges <- if (inherits(ems, "Emulator")) ems$ranges else ems[[1]]$ranges
  if (!is.null(plot_vars) && is.null(fixed_vars)) {
    pvars <- names(ranges)[!names(ranges) %in% plot_vars]
    fixed_vars <- purrr::map(pvars, ~mean(ranges[[.]])) |>
      setNames(pvars)
  }
  if (!is.null(fixed_vars))
    if (length(ranges) < length(fixed_vars)+2)
      stop("Fewer than two parameters are varying - cannot produce plot.")
  if (is.null(fixed_vars)) red_ranges <- ranges
  else red_ranges <- ranges[!names(ranges) %in% names(fixed_vars)]
  if (ppd^length(red_ranges) > maxpoints) {
    point_grid <- setNames(
      data.frame(
        do.call('cbind',
                map(red_ranges, ~runif(maxpoints, .[[1]], .[[2]])))),
      names(red_ranges))
    nbins <- 19
  }
  else {
    dim_bounds <- map(red_ranges, ~seq(.[[1]], .[[2]], length.out = ppd+1))
    dim_unif <- map(dim_bounds,
                           ~map_dbl(1:(length(.)-1),
                                           function(i) mean(.[i:(i+1)])))
    point_grid <- expand.grid(dim_unif)
  }
  if (!is.null(fixed_vars)) {
    fixed_df <- cbind.data.frame(purrr::map(names(fixed_vars), ~rep(fixed_vars[[.]], nrow(point_grid)))) |>
      setNames(names(fixed_vars))
    point_grid <- (cbind.data.frame(point_grid, fixed_df))[,names(ranges)]
  }
  point_grid$I <- nth_implausible(ems, point_grid, targets)
  one_dim <- function(data, parameter) {
    param_seq <- seq(
      red_ranges[[parameter]][1],
      red_ranges[[parameter]][2], length.out = ppd + 1)
    collection <- map(1:ppd, function(x) {
      valid_points <- data[data[,parameter] >= param_seq[x] &
                             data[,parameter] <= param_seq[x+1],]
      how_many_valid <- if (nrow(valid_points) == 0)
        NA
      else
        sum(valid_points$I <= cutoff)/nrow(valid_points)
      return(c(param_seq[x], how_many_valid))
    })
    setNames(do.call('rbind.data.frame', collection), c(parameter, 'op'))
  }
  two_dim <- function(data, parameters, op = FALSE) {
    param_seqs <- map(red_ranges[parameters],
                             ~seq(.[[1]], .[[2]], length.out = ppd + 1))
    param_list <- list()
    for (i in 1:ppd) {
      for (j in 1:ppd) {
        valid_points <- data[data[,parameters[1]] >= param_seqs[[1]][i] &
                               data[,parameters[1]] <= param_seqs[[1]][i+1] &
                               data[,parameters[2]] >= param_seqs[[2]][j] &
                               data[,parameters[2]] <= param_seqs[[2]][j+1],]
        if (op) {
          rel_stat <- if (nrow(valid_points) == 0)
            NA
          else
            sum(valid_points$I <= cutoff)/nrow(valid_points)
        }
        else {
          rel_stat <- if (nrow(valid_points) == 0) NA else min(valid_points$I)
        }
        param_list[[length(param_list)+1]] <- c(param_seqs[[1]][i],
                                                param_seqs[[2]][j], rel_stat)
      }
    }
    return(setNames(do.call('rbind.data.frame', param_list),
                    c(parameters, if(op) "op" else "imp")))
  }
  cols <- if(cb) colourblind else redgreen
  colscont <- if(cb) colourblindcont else redgreencont
  if (is.null(imp_breaks)) {
    imp_breaks <- c(0, 0.3, 0.7, 1, 1.3, 1.7, 2, 2.3, 2.7,
                    3, 3.5, 4, 4.5, 5, 6, 7, 8, 10, 15, Inf)
    imp_names <- c(0, '', '', 1, '', '', 2, '', '', 3, '',
                   '', '', 5, '', '', '', 10, 15, '')
  } else {
    if (!is.numeric(imp_breaks)) stop("imp_breaks not a numeric vector.")
    if (any(is.na(imp_breaks))) stop("imp_breaks cannot contain missing values.")
    if (any(diff(imp_breaks) <= 0)) stop("imp_breaks must be in ascending order.")
    if (length(imp_breaks) != 20) stop("imp_breaks must be of length 20.")
    imp_names <- as.character(imp_breaks)
  }
  parameter_combinations <- expand.grid(
    names(red_ranges),
    names(red_ranges),
    stringsAsFactors = FALSE)
  plot_list <- map(seq_len(nrow(parameter_combinations)), function(x) {
    parameters <- unlist(parameter_combinations[x,], use.names = FALSE)
    if (parameters[1] == parameters[2]) {
      pt <- one_dim(point_grid, parameters[1])
      g <- suppressMessages(ggplot(mapping = aes(x = pt[,1], y = pt[,2])) +
        geom_smooth(colour = 'black', se = FALSE) +
        scale_y_continuous(expand = c(0, 0), limits = c(0, 1)))
    }
    else if (which(
      names(red_ranges) == parameters[1]) >
      which(names(red_ranges) == parameters[2])) {
      pt <- two_dim(point_grid, parameters)
      g <- ggplot(mapping = aes(x = pt[,1], y = pt[,2]))
      if (!raster_imp)
        g <- g + geom_contour_filled(aes(z = pt[,3]), breaks = imp_breaks, colour = ifelse(contour, "black", NA)) +
        scale_fill_manual(values = cols, labels = imp_names,
                          name = "Min. I",
                          guide = guide_legend(reverse = TRUE), drop = FALSE)
      else
        g <- g + geom_raster(aes(fill = pt[,3]), interpolate = TRUE) +
        scale_fill_gradient2(low = colscont$low, mid = colscont$mid, high = colscont$high,
                             midpoint = 3, breaks = imp_breaks, labels = imp_names, name = "Min. I",
                             guide = guide_legend(reverse = TRUE))
      g <- g + scale_y_continuous(expand = c(0,0))
    }
    else {
      pt <- two_dim(point_grid, parameters, TRUE)
      g <- ggplot(mapping = aes(x = pt[,1], y = pt[,2], fill = pt[,3])) +
        geom_raster(interpolate = TRUE) +
        scale_fill_gradient(low = 'black', high = 'white',
                            breaks = seq(0, 1, by = 0.1),
                            name = "Op. Depth") +
        scale_y_continuous(expand = c(0,0))
    }
    return(g + scale_x_continuous(expand = c(0,0)) + theme_minimal())
  })
  x <- y <- z <- fill <- NULL
  pointless_data <- expand.grid(x = 1:10, y = 1:10)
  pointless_data$fill <- seq(0, 1, length.out = 100)
  pointless_data$z <- seq(0, 20, length.out = 100)
  pointless_plot <- ggplot(data = pointless_data, aes(x = x, y = y)) +
    geom_line(aes(colour = fill)) +
    geom_contour_filled(aes(z = z), breaks = imp_breaks) +
    scale_fill_manual(values = cols, labels = imp_names,
                      name = "Min. Imp", drop = FALSE) +
    scale_colour_gradient(low = 'black', high = 'white',
                          breaks = seq(0, 1, by = 0.1), name = "Op. Depth") +
    guides(fill = guide_legend(order = 1, reverse = TRUE))
  return(ggmatrix(plot_list, length(red_ranges), length(red_ranges),
                  title = "Minimum Implausibility and Optical Depth",
                  xAxisLabels = names(red_ranges), yAxisLabels = names(red_ranges),
                  showYAxisPlotLabels = FALSE,
                  legend = grab_legend(pointless_plot)))
}

#' Active variable plotting
#'
#' For a set of emulators, demonstrate which variables are active.
#'
#' Each emulator has a list of 'active' variables; those which contribute in an appreciable way
#' to its regression surface. It can be instructive to examine the differences in active variables
#' for a given collection of emulators. The plot here produces an nxp grid for n emulators in p
#' inputs; a square is blacked out if that variable does not contribute to that output.
#'
#' Both the outputs and inputs can be restricted to collections of interest, if desired, with the
#' optional \code{output_names} and \code{input_names} parameters.
#'
#' @param ems The list of emulators to consider
#' @param output_names The names of the outputs to include in the plot, if not all
#' @param input_names The names of the inputs to include in the plot, if not all
#' @return A ggplot object corresponding to the plot
#'
#' @family visualisation tools
#' @export
#'
#' @examples
#'  plot_actives(SIREmulators$ems)
#'  # Remove the nR output and aIR input from the plot
#'  plot_actives(SIREmulators$ems, c('nS', 'nI'), c('aSI', 'aSR'))
#'  # Note that we can equally restrict the emulator list...
#'  plot_actives(SIREmulators$ems[c('nS', 'nI')], input_names = c('aSI', 'aSR'))
plot_actives <- function(ems, output_names = NULL, input_names = NULL) {
  # A lone Emulator object is treated as a list of length one.
  if (inherits(ems, "Emulator")) {
    ems <- list(ems)
  }
  if (length(ems) == 0) stop("No emulators provided.")
  in_names <- names(ems[[1]]$ranges)
  # Rows are labelled by each emulator's own output_name rather than by the
  # names of the list. A bare Emulator or an unnamed list would otherwise get
  # row labels "1", "2", ..., which match none of the y-axis factor levels
  # below (so every tile would sit at NA) and none of `output_names`.
  all_outputs <- vapply(ems, function(em) as.character(em$output_name),
                        character(1), USE.NAMES = FALSE)
  # One row per emulator, one logical column per input parameter.
  active_list <- setNames(
    data.frame(
      do.call('rbind', unname(map(ems, ~.$active_vars)))), in_names)
  # Optional restriction to a subset of inputs (columns) and outputs (rows).
  keep_rows <- if (is.null(output_names)) {
    rep(TRUE, length(all_outputs))
  } else {
    all_outputs %in% output_names
  }
  if (!is.null(input_names)) active_list <- active_list[, input_names, drop = FALSE]
  active_list <- active_list[keep_rows, , drop = FALSE]
  shown_outputs <- all_outputs[keep_rows]
  if (nrow(active_list) == 0 || length(active_list) == 0)
    stop("No inputs/outputs to plot.")
  # Long format: one row per (output, input) pair.
  pivoted <- stack(active_list) |> setNames(c("value", "Var2"))
  # stack() concatenates column by column, so the output labels repeat once
  # for every retained input.
  pivoted$Var1 <- rep(shown_outputs, times = length(active_list))
  pivoted$value <- factor(pivoted$value, levels = c("FALSE", "TRUE"))
  # Levels use the full emulator/input ordering so the axes keep the original
  # order even after subsetting.
  pivoted$Var1 <- factor(pivoted$Var1, levels = all_outputs)
  pivoted$Var2 <- factor(pivoted$Var2, levels = in_names)
  # Dummy bindings for the column names used inside aes().
  Var1 <- Var2 <- value <- NULL
  g <- ggplot(data = pivoted, aes(x = Var2, y = Var1, fill = value)) +
    geom_tile(colour = 'black') +
    scale_fill_manual(values = c('black', 'green'), drop = FALSE,
                      labels = c("FALSE", "TRUE"), name = "Active?") +
    labs(title = "Active variables", x = "Parameter", y = "Output") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust=1))
  return(g)
}

#' Plot proposed points
#'
#' A wrapper around R's base plot to show proposed points
#'
#' Given a set of points proposed from emulators at a given wave, it's often useful to look at
#' how they are spread and where in parameter space they tend to lie relative to the original
#' ranges of the parameters. This function provides pairs plots of the parameters, with the
#' bounds of the plots calculated with respect to the parameter ranges provided.
#'
#' @param points The points to plot
#' @param ranges The parameter ranges
#' @param p_size The size of the plotted points (passed to \code{cex})
#'
#' @return The corresponding pairs plot
#'
#' @family visualisation tools
#' @export
#' @examples
#'  plot_wrap(SIRSample$training[,1:3], SIREmulators$ems[[1]]$ranges)
#'
plot_wrap <- function(points, ranges = NULL, p_size = 0.5) { #nocov start
  if (is.null(ranges)) {
    # No ranges given: use the per-column minimum and maximum of the points
    # themselves as the plot bounds (two rows, one column per variable).
    boundary_points <- setNames(
      do.call(
        'cbind.data.frame',
        map(names(points),
            ~c(min(points[,.]), max(points[,.])))), names(points))
  } else {
    # When every ranged parameter is present in `points`, drop any additional
    # columns (e.g. model outputs): rbind() below needs both data.frames to
    # have the same columns and would otherwise fail. Column order of
    # `points` is preserved, so inputs that already matched are unaffected.
    param_names <- names(ranges)
    if (!is.null(param_names) && all(param_names %in% names(points))) {
      points <- points[, names(points) %in% param_names, drop = FALSE]
    }
    # Two rows: all lower bounds, then all upper bounds.
    boundary_points <- setNames(
      do.call(
        'rbind.data.frame',
        map(
          1:2,
          ~map_dbl(ranges, function(x) x[[.]]))), names(ranges))
  }
  # The two boundary rows are appended and drawn in white: they are there only
  # to stretch each panel's axes to the requested bounds.
  plot(rbind(points, boundary_points),
       pch = 16, cex = p_size,
       col = c(rep('black', nrow(points)), 'white', 'white'))
} #nocov end
# Tandethsquire/hmer documentation built on Oct. 25, 2024, 11:09 a.m.