R/forecasting.R

Defines functions crps daily_coverage daily_error plot_metrics plot_coverage posterior_metrics posterior_coverage evaluate_forecast

Documented in evaluate_forecast plot_coverage plot_metrics posterior_coverage posterior_metrics

#' Posterior model evaluations
#'
#' Calculate daily forecast error using up to three metrics, and also return
#' coverage of credible intervals. The available metrics are the continuous
#' ranked probability score (CRPS), mean absolute error and median absolute
#' error.
#'
#' @inherit plot_obs params
#' @param newdata  If provided, the original \code{data} used
#'  in \code{object} is overridden. Useful for forecasting
#' @param metrics A string or character vector specifying the forecast
#'  error metrics to compute. Any of \code{"crps"},
#'  \code{"mean_abs_error"} and \code{"median_abs_error"}. If \code{NULL}
#'  (the default), all three are computed.
#' @return A named list with two dataframes: \code{error}, giving the
#'  requested metrics, and \code{coverage}, giving coverage of the credible
#'  intervals.
#' @export
evaluate_forecast <-
  function(object,
           newdata = NULL,
           type,
           groups = NULL,
           metrics = NULL,
           levels = c(50, 95)) {
    if (is.null(type)) {
      stop("must specify an observation type")
    }
    if (length(type) != 1) {
      stop("'type' must be a single observation type", call. = FALSE)
    }
    alltypes <- sapply(object$obs, function(x) .get_obs(formula(x)))
    # Position of the requested type within the observation models. The
    # lookup must index into `alltypes` (not into `type`), otherwise the
    # first observation model is always selected.
    w <- match(type, alltypes)
    if (is.na(w)) {
      stop(paste0("obs does not contain any observations
    for type '", type, "'"), call. = FALSE)
    }

    ok_metrics <- c("crps", "mean_abs_error", "median_abs_error")
    metrics <- metrics %ORifNULL% ok_metrics
    if (any(!(metrics %in% ok_metrics))) {
      stop("Unrecognised metrics. Allowed metrics include ",
        paste(ok_metrics, collapse = ", "),
        call. = FALSE
      )
    }
    levels <- check_levels(levels)

    # process data
    groups <- groups %ORifNULL% object$groups
    # simulate from posterior predictive
    obs <- posterior_predict(
      object = object,
      types = type,
      newdata = newdata
    )

    obs <- gr_subset(obs, groups)

    if (is.null(newdata)) {
      data <- object$data
      data <- data[data$group %in% groups, ]
    } else {
      # NOTE(review): unlike the branch above, the parsed newdata is not
      # filtered by `groups` here -- confirm that parse_data()/epiobs_()
      # yield outcomes aligned with the group-subsetted draws.
      check_data(newdata, object$rt, object$inf, object$obs, object$groups)
      data <- parse_data(newdata, object$rt, object$inf, object$obs, object$groups)
    }

    # get observed outcomes for the requested observation model
    obj <- epiobs_(object$obs[[w]], data)
    y <- get_obs(obj)

    list(
      error = daily_error(obs, metrics, y),
      coverage = daily_coverage(obs, levels, y)
    )
  }

#' Coverage of posterior credible intervals
#'
#' Convenience wrapper around \code{evaluate_forecast} that keeps only the
#' coverage component of its result.
#'
#' @inherit evaluate_forecast
#' @return A dataframe indicating whether observations fall within the
#'  specified credible intervals
#' @export
posterior_coverage <- function(object,
                               type,
                               newdata = NULL,
                               groups = NULL,
                               levels = c(50, 95)) {
  # Metrics are left at their default; only the coverage table is returned.
  evaluation <- evaluate_forecast(
    object = object,
    newdata = newdata,
    type = type,
    groups = groups,
    levels = levels
  )
  evaluation$coverage
}

#' CRPS, Mean Absolute Error, Median Absolute Error
#'
#' Convenience wrapper around \code{evaluate_forecast} that keeps only the
#' error component of its result.
#'
#' @inherit evaluate_forecast
#' @return A dataframe giving forecast error for each metric and observation
#' @export
posterior_metrics <- function(object,
                              type,
                              newdata = NULL,
                              groups = NULL,
                              metrics = NULL) {
  # Credible-interval levels are left at their default; only the error
  # table is returned.
  evaluation <- evaluate_forecast(
    object = object,
    newdata = newdata,
    type = type,
    groups = groups,
    metrics = metrics
  )
  evaluation$error
}

#' Plot coverage probability of posterior credible intervals
#'
#' Plots bar charts showing empirical coverage of credible intervals
#' specified using 'levels'. Can bucket by time period, by group, by
#' whether the observation is new (not used in fitting).
#'
#' @inherit evaluate_forecast params
#' @inherit plot_obs params
#' @param period Buckets computed empirical probabilities into time periods
#' if specified.
#' @param by_group Plot coverage for each group individually
#' @param by_unseen Plot coverage separately for seen and unseen observations.
#' Observations are 'seen' if they were used for fitting.
#' @export
#' @return A \code{ggplot} object.
plot_coverage <-
  function(object,
           type,
           newdata = NULL,
           groups = NULL,
           levels = c(50, 95),
           period = NULL,
           by_group = FALSE,
           by_unseen = FALSE) {

    groups <- groups %ORifNULL% object$groups
    cov <- posterior_coverage(
      object = object,
      type = type,
      groups = groups,
      newdata = newdata,
      levels = levels
    )

    if (!is.null(period)) {
      cov$period <- cut(cov$date, period)
    }

    # columns over which mean coverage is computed
    cols <- c(
      "tag",
      if (!is.null(period)) "period",
      if (by_group) "group",
      if (by_unseen) "unseen"
    )

    if (by_unseen) {
      # An observation is tagged "Seen" when the fitted data holds a
      # non-missing value of `type` for the same group and date.
      fitted <- object$data
      fitted <- fitted[fitted$group %in% groups, c("group", "date", type)]
      fitted <- dplyr::rename(fitted, "unseen" = dplyr::all_of(type))
      cov <- dplyr::left_join(cov, fitted, by = c("group", "date"))
      cov$unseen <- ifelse(is.na(cov$unseen), "Unseen", "Seen")
    }

    df <- cov %>%
      dplyr::group_by_at(cols) %>%
      dplyr::summarise(value = mean(.data$in_ci))

    if (is.null(period)) {
      p <- ggplot2::ggplot(
        df,
        ggplot2::aes(x = .data$tag, y = .data$value, fill = .data$tag)
      ) +
        ggplot2::labs(
          y = "Mean Coverage",
          x = "Credible Interval"
        )
    } else {
      # use the pronoun so the column, not the `period` argument, is mapped
      p <- ggplot2::ggplot(
        df,
        ggplot2::aes(x = .data$period, y = .data$value, fill = .data$tag)
      ) +
        ggplot2::labs(
          y = "Mean Coverage",
          x = "period"
        )
    }

    # general formatting
    p <- p + ggplot2::geom_bar(
      stat = "identity",
      position = "dodge"
    ) +
      ggplot2::scale_y_continuous(
        labels = scales::percent_format(),
        minor_breaks = seq(0, 1, 0.05),
        breaks = seq(0, 1, 0.1)
      ) +
      hrbrthemes::theme_ipsum() +
      ggplot2::theme(
        axis.text.x = ggplot2::element_text(angle = 50, vjust = 0.5)
      )

    # Facet variables must be quoted with vars(): `.data` can only be
    # subset inside a data mask, so passing `.data$group` directly fails.
    if ("group" %in% cols && "unseen" %in% cols) {
      p <- p + ggplot2::facet_grid(
        ggplot2::vars(.data$group),
        ggplot2::vars(.data$unseen)
      )
    } else if ("group" %in% cols) {
      p <- p + ggplot2::facet_wrap(ggplot2::vars(.data$group))
    } else if ("unseen" %in% cols) {
      p <- p + ggplot2::facet_wrap(ggplot2::vars(.data$unseen))
    }

    p <- p +
      ggplot2::scale_fill_manual(
        name = "Fill",
        values = ggplot2::alpha(
          "deepskyblue4",
          rev(levels) / 100
        )
      )

    p
  }


#' Plot CRPS, Median/Mean Absolute Error
#'
#' Plots various metrics for evaluating probabilistic forecasts by group.
#' Each metric is drawn as a line over time, coloured by whether the
#' observation was used for fitting.
#'
#' @inherit evaluate_forecast params
#' @inherit plot_rt return
#' @export
plot_metrics <-
  function(object,
           groups = NULL,
           type,
           metrics = NULL,
           newdata = NULL) {
    groups <- groups %ORifNULL% object$groups

    df <- posterior_metrics(
      object = object,
      type = type,
      groups = groups,
      newdata = newdata,
      metrics = metrics
    )
    # metric columns actually present in the returned data frame
    metrics <- intersect(
      colnames(df),
      c("crps", "mean_abs_error", "median_abs_error")
    )

    df <- df %>%
      tidyr::pivot_longer(
        dplyr::all_of(metrics),
        names_to = "metric",
        values_to = "value"
      )

    # An observation is tagged "Seen" when the fitted data holds a
    # non-missing value of `type` for the same group and date.
    fitted <- object$data
    fitted <- fitted[fitted$group %in% groups, c("group", "date", type)]
    fitted <- dplyr::rename(fitted, "unseen" = dplyr::all_of(type))
    df <- dplyr::left_join(df, fitted, by = c("group", "date"))
    df$unseen <- ifelse(is.na(df$unseen), "Unseen", "Seen")

    p <- ggplot2::ggplot(
      df,
      ggplot2::aes(
        x = .data$date,
        y = .data$value,
        linetype = .data$metric,
        color = .data$unseen
      )
    ) +
      ggplot2::geom_line(alpha = 0.7, size = 0.8) +
      # Facet variables must be quoted with vars(): `.data` can only be
      # subset inside a data mask, so passing `.data$group` directly fails.
      ggplot2::facet_wrap(
        ggplot2::vars(.data$group),
        scales = "free_y"
      ) +
      ggplot2::labs(
        y = "Value",
        x = "Date",
        linetype = "Metric"
      ) +
      hrbrthemes::theme_ipsum() +
      ggplot2::theme(legend.position = "right")

    # Named values keep the colour of each category fixed even when only
    # one of "Seen"/"Unseen" occurs in the data.
    p <- p + ggplot2::scale_color_manual(
      values = c(Seen = "coral4", Unseen = "darkslategray4")
    )

    p
  }

# Per-observation forecast error computed from posterior predictive draws.
#
# @param obs A list with elements `draws`, `group` and `time`. `draws` is
#   transposed before use so that each row lines up with one element of `y`.
# @param metrics Character vector; any of "crps", "mean_abs_error" and
#   "median_abs_error". Only the requested columns are added.
# @param y Observed outcomes, one per observation.
# @return A data frame with columns `group` and `date`, plus one column per
#   requested metric.
daily_error <- function(obs, metrics, y) {
  # one row per observation, one column per draw
  draws_by_obs <- t(obs$draws)
  abs_err <- abs(sweep(draws_by_obs, 1, y))
  out <- data.frame(
    group = obs$group,
    date = obs$time
  )

  if ("crps" %in% metrics) {
    # Score every observation against its own draws. Handing the whole
    # matrix to crps() would pool the draws of all observations into a
    # single sample, because crps() flattens and sorts `dat`.
    out$crps <- vapply(
      seq_len(nrow(draws_by_obs)),
      function(i) crps(y[i], draws_by_obs[i, ]),
      numeric(1)
    )
  }

  if ("mean_abs_error" %in% metrics) {
    out$mean_abs_error <- rowMeans(abs_err)
  }

  if ("median_abs_error" %in% metrics) {
    out$median_abs_error <- apply(abs_err, 1, median)
  }

  out
}

# Flags, for each credible level, whether each observed outcome lies inside
# the corresponding interval.
#
# For every element of `levels`, get_quantiles(obs, level) is called and its
# `date`, `tag`, `lower` and `upper` fields are read. The per-level data
# frames are stacked row-wise. `in_ci` is 1 when lower <= y <= upper and 0
# otherwise.
daily_coverage <- function(obs, levels, y) {
  coverage_at <- function(level) {
    bounds <- get_quantiles(obs, level)
    inside <- (bounds$lower <= y) * (y <= bounds$upper)
    data.frame(
      group = obs$group,
      date = bounds$date,
      tag = bounds$tag[1],
      in_ci = inside
    )
  }
  per_level <- lapply(levels, coverage_at)
  do.call(rbind, per_level)
}

# Continuous ranked probability score of each element of `y` against the
# empirical distribution of `dat`.
#
# `dat` is treated as ONE sample: it is sorted as a whole, so a matrix is
# flattened and all of its entries are pooled. Returns one score per element
# of `y`.
crps <- function(y, dat) {
  n_draws <- length(dat)
  weight <- 1 / n_draws
  ordered <- sort(dat)
  # mid-point plotting positions of the sorted sample
  positions <- seq.int(0.5 * weight, 1 - 0.5 * weight, length.out = n_draws)
  score_one <- function(obs_value) {
    2 * weight * sum(((obs_value < ordered) - positions) * (ordered - obs_value))
  }
  sapply(y, score_one)
}

Try the epidemia package in your browser

Any scripts or data that you put into this service are public.

epidemia documentation built on Oct. 25, 2021, 9:09 a.m.