R/loocv.R

Defines functions print.baggr_cv print.compare_baggr_cv loo_compare.baggr_cv loo_compare is.baggr_cv loocv

Documented in is.baggr_cv loo_compare loocv print.baggr_cv print.compare_baggr_cv

#' Leave one group out cross-validation for \code{baggr} models
#'
#' Performs exact leave-one-group-out cross-validation on a baggr model.
#'
#' @param data Input data frame - same as for [baggr] function.
#' @param return_models logical; if FALSE, summary statistics will be returned and the
#'                      models discarded;
#'                      if TRUE, a list of models will be returned alongside summaries
#' @param ... Additional arguments passed to [baggr].
#' @return log predictive density value, an object of class `baggr_cv`;
#' full model, prior values and _lpd_ of each model are also returned.
#' These can be examined by using `attributes()` function.
#' @seealso [loo_compare] for comparison of many LOO CV results
#' @details
#'
#' The values returned by `loocv()` can be used to understand how excluding
#' any one group affects the overall result, as well as how well the model
#' predicts the omitted group. LOO-CV approaches are a good general practice
#' for comparing Bayesian models, not only in meta-analysis.
#'
#' This function automatically runs _K_ baggr models, where _K_ is number of groups (e.g. studies),
#' leaving out one group at a time. For each run, it calculates
#' _expected log predictive density_ (ELPD) for that group (see Gelman et al 2013).
#' (In the logistic model, where the proportion in control group is unknown, each of
#' the groups is divided into data for controls, which is kept for estimation, and data for
#' treated units, which is not used for estimation but only for calculating predictive density.
#' This is akin to fixing the baseline risk and only trying to infer the odds ratio.)
#'
#' The main output is the cross-validation
#' information criterion, or -2 times the ELPD averaged over _K_ models.
#' This is related to, and often approximated by, the Watanabe-Akaike
#' Information Criterion. When comparing models, smaller values mean
#' a better fit. For more information on cross-validation see
#' [this overview article](http://www.stat.columbia.edu/~gelman/research/published/waic_understand3.pdf)
#'
#' For running more computation-intensive models, consider setting the
#' `mc.cores` option before running loocv, e.g. `options(mc.cores = 4)`
#' (by default baggr runs 4 MCMC chains in parallel).
#' As a default, rstan runs "silently" (`refresh=0`).
#' To see sampling progress, please set e.g. `loocv(data, refresh = 500)`.
#'
#' @examples
#' \dontrun{
#' # even simple examples may take a while
#' cv <- loocv(schools, pooling = "partial")
#' print(cv)      # returns the lpd value
#' attributes(cv) # more information is included in the object
#' }
#'
#' @author Witold Wiecek
#' @importFrom utils txtProgressBar
#' @importFrom utils setTxtProgressBar
#' @references
#' Gelman, Andrew, Jessica Hwang, and Aki Vehtari.
#' “Understanding Predictive Information Criteria for Bayesian Models.”
#' Statistics and Computing 24, no. 6 (November 2014): 997–1016.
#' @export

loocv <- function(data, return_models = FALSE, ...) {
  # Set prior, if not specified by the user
  args <- list(...)
  # Stop if pooling == "none"
  if(!is.null(args[["pooling"]]) && (args[["pooling"]] == "none"))
    stop("For a model with no pooling LOO CV doesn't exist.")

  # Set parallel up...
  # if(is.null(getOption("mc.cores"))){
  #   cat(paste0(
  #     "loocv() temporarily set options(mc.cores = parallel::detectCores()) \n"))
  #   temp_cores <- TRUE
  #   options(mc.cores = parallel::detectCores())
  # } else {
  #   temp_cores <- FALSE
  # }
  # temp_cores <- FALSE

  # Model with all of data:
  full_fit <- try(baggr(data, refresh = 0, ...))
  if(class(full_fit) == "try-error")
    stop("Inference failed for the model with all data, please try fitting outside of loocv()")

  # Prepare the arguments
  args[["data"]] <- data
  args[["model"]] <- full_fit$model
  # if(full_fit$model == "rubin_full")
    # stop("LOO CV is not implemented for full data models yet.")
  if(!("prior" %in% names(args))) {
    message("(Prior distributions taken from the model with all data. See $prior.)")
    args[["formatted_prior"]] <- full_fit$formatted_prior
  }

  # Determine number and names of groups
  if(args[["model"]] %in% c("rubin_full", "quantiles", "logit")) {
    #For individual-level data models we need to calculate N groups
    if(!is.null(args[["group"]]))
      group_col <- args[["group"]]
    else
      group_col <- "group"
    group_names <- unique(data[[group_col]])
    K <- length(group_names)

  } else {
    K <- nrow(data)
  }
  if(K > 50){
    message(paste("Large number of groups - note that the
                   model will be repeated", K, "times"))
  } else {
    message(paste0("Repeating baggr() for ", K, " separate models"))
  }

  # LOO CV models
  cat("\n")
  pb <- utils::txtProgressBar(style = 3)
  kfits <- lapply(as.list(1:K), function(i) {

    # Partitioning data into the fit data and LOO
    if(args[["model"]] %in% c("quantiles")) {
      args$data      <- data[data[[group_col]] != group_names[i], ]
      args$test_data <- data[data[[group_col]] == group_names[i], ]
    } else if(args[["model"]] %in% c("rubin_full", "logit")) {
      trt_column <- ifelse(is.null(args[["treatment"]]), "treatment", args[["treatment"]])
      args$data      <- data[data[[group_col]] != group_names[i] | data[[trt_column]] == 0, ]
      args$test_data <- data[data[[group_col]] == group_names[i] & data[[trt_column]] == 1, ]
    } else {
      args$data      <- data[-i,]
      args$test_data <- data[i,]
    }
    # Done partitioning data.

    utils::setTxtProgressBar(pb, (i-1)/K)

    # Run baggr models:
    if(is.null(args[["refresh"]]))
      args[["refresh"]] <- 0
    res <- do.call(baggr, args)

    # Sanitized version:
    # res <- try(do.call(baggr, args))
    # if(class(res) == "try-error") {
    #   stop(paste0("Inference failed for model number ", i, " out of ", K))
    # } else {
    #   return(res)
    # }
  })
  utils::setTxtProgressBar(pb, 1)
  close(pb)

  tau_estimate <-
    lapply(kfits, function(x) mean(treatment_effect(x)[[1]]))
  sd_estimate <-
    lapply(kfits, function(x) mean(treatment_effect(x)[[2]]))
  loglik <-
    lapply(kfits, function(x) apply(as.matrix(x$fit, "logpd[1]"), 2, mean))

  elpds <- unlist(loglik)

  out <- list(
    se = sqrt(length(elpds) * var(elpds)),
    elpd = sum(unlist(elpds)),
    looic = -2*sum(unlist(elpds)),
    df  = data.frame(
      "tau" = unlist(tau_estimate),
      "sigma_tau" = unlist(sd_estimate),
      "lpd" = unlist(loglik)),
    full_model = full_fit,
    prior = args[["prior"]],
    K = K,
    pointwise = elpds
  )
  if(return_models)
    out$models <- kfits
  class(out) <- "baggr_cv"

  # if(temp_cores)
    # options(mc.cores = NULL)

  return(out)
}

#' Check if something is a baggr_cv object
#' @param x object to check
is.baggr_cv <- function(x) {
  inherits(x, "baggr_cv")
}

#' Compare LOO CV models
#'
#' Given multiple [loocv] outputs, calculate differences in their expected log
#' predictive density.
#'
#' @param x An object of class `baggr_cv` or a list of such objects.
#' @param ... Additional objects of class "baggr_cv"
#' @export loo_compare
#' @seealso [loocv] for fitting LOO CV objects and explanation of the procedure
#' @examples
#' \dontrun{
#' # 2 models with more/less informative priors -- this will take a while to run
#' cv_1 <- loocv(schools, model = "rubin", pooling = "partial")
#' cv_2 <- loocv(schools, model = "rubin", pooling = "partial",
#'               prior_hypermean = normal(0, 5), prior_hypersd = cauchy(0,4))
#' loo_compare(cv_1, cv_2)
#' }
#' @export
loo_compare <- function(x, ...) {
  UseMethod("loo_compare")
}

#' @aliases loo_compare
#' @export
loo_compare.baggr_cv <- function(x, ...) {
  if (is.baggr_cv(x)) {
    dots <- list(...)
    loos <- c(list(x), dots)
  } else {
    if (!is.list(x) || !length(x)) {
      stop("'x' must be a list if not a 'loo' object.")
    }
    loos <- x
  }
  if (!all(sapply(loos, is.baggr_cv))) {
    stop("All inputs should have class 'baggr_cv'.")
  }
  if (length(loos) <= 1L) {
    stop("'loo_compare' requires at least two models.")
  }
  Ns <- sapply(loos, function(x) nrow(x$df))
  if (!all(Ns == Ns[1L])) {
    stop("Not all models have the same number of data points.")
  }

  elpds <- lapply(loos, function(x) x$pointwise)
  comp <- Reduce(cbind, elpds)

  diffs <- list()

  for(i in 2:ncol(comp)) {
    diffname <- paste0("Model 1", " - ", "Model ", i)
    rawdiffs <- comp[,1] - comp[,i]
    diffs[[i-1]] <-
      matrix(nrow = 1, ncol = 2,
             c(sum(rawdiffs),
                     sqrt(length(rawdiffs))*sd(rawdiffs)),
               dimnames = list(diffname, c("ELPD", "ELPD SE")))
  }

  diffs <- Reduce(rbind, diffs)
  class(diffs) <- c("compare_baggr_cv", class(comp))
  diffs
}

#' Print baggr_cv comparisons
#' @param x baggr_cv comparison to print
#' @param digits number of digits to print
#' @param ... additional arguments for s3 consistency
#' @importFrom testthat capture_output
#' @importFrom crayon bold
#' @export
print.compare_baggr_cv <- function(x, digits = 3, ...) {
  mat <- as.matrix(x)
  class(mat) <- "matrix"
  cat(
    crayon::bold(paste0("Comparison of cross-validation\n")),
    "\n",
    testthat::capture_output(print(signif(mat, digits = digits)))
  )
}

#' Print baggr cv objects nicely
#' @param x baggr_cv object to print
#' @param digits number of digits to print
#' @param ... additional arguments for s3 consistency
#' @importFrom testthat capture_output
#' @importFrom crayon bold
#' @export
print.baggr_cv <- function(x, digits = 3, ...) {

  mat <- matrix(nrow = 2, ncol = 2)

  mat[1,] <- c(x$elpd, x$se)
  mat[2,] <- c(x$looic, -2*x$se)

  colnames(mat) <- c("Estimate", "Standard Error")
  rownames(mat) <- c("elpd", "looic")

  cat(
    crayon::bold(paste0("Based on ", x$K, "-fold cross-validation\n")),
    "\n",
    testthat::capture_output(print(signif(mat, digits = digits)))
  )

}

Try the baggr package in your browser

Any scripts or data that you put into this service are public.

baggr documentation built on Sept. 5, 2021, 5:46 p.m.