R/loo.R

Defines functions ps_khat_threshold print.iclist compare_ic add_ic.brmsfit add_ic add_waic add_loo print.loolist get_chain_id r_eff_log_lik.function r_eff_log_lik.matrix r_eff_log_lik r_eff_helper subset.psis recommend_loo_options validate_models match_nobs match_response hash_response get_criterion add_criterion.brmsfit add_criterion loo_model_weights.brmsfit loo_compare.brmsfit prepare_loo_args psis.brmsfit .psis .waic .loo loo_criteria compute_loo compute_loolist WAIC WAIC.brmsfit waic.brmsfit LOO LOO.brmsfit loo.brmsfit

Documented in add_criterion add_criterion.brmsfit add_ic add_ic.brmsfit add_loo add_waic compare_ic LOO loo.brmsfit LOO.brmsfit loo_compare.brmsfit loo_model_weights.brmsfit psis.brmsfit WAIC waic.brmsfit WAIC.brmsfit

#' Efficient approximate leave-one-out cross-validation (LOO)
#'
#' Perform approximate leave-one-out cross-validation based
#' on the posterior likelihood using the \pkg{loo} package.
#' For more details see \code{\link[loo:loo]{loo}}.
#'
#' @aliases loo LOO LOO.brmsfit
#'
#' @param x A \code{brmsfit} object.
#' @param ... More \code{brmsfit} objects or further arguments
#'   passed to the underlying post-processing functions.
#'   In particular, see \code{\link{prepare_predictions}} for further
#'   supported arguments.
#' @param compare A flag indicating if the information criteria
#'  of the models should be compared to each other
#'  via \code{\link{loo_compare}}.
#' @param pointwise A flag indicating whether to compute the full
#'  log-likelihood matrix at once or separately for each observation.
#'  The latter approach is usually considerably slower but
#'  requires much less working memory. Accordingly, if one runs
#'  into memory issues, \code{pointwise = TRUE} is the way to go.
#' @param moment_match Logical; Indicate whether \code{\link{loo_moment_match}}
#'  should be applied on problematic observations. Defaults to \code{FALSE}.
#'  For most models, moment matching will only work if you have set
#'  \code{save_pars = save_pars(all = TRUE)} when fitting the model with
#'  \code{\link{brm}}. See \code{\link{loo_moment_match.brmsfit}} for more
#'  details.
#' @param reloo Logical; Indicate whether \code{\link{reloo}}
#'  should be applied on problematic observations. Defaults to \code{FALSE}.
#' @param k_threshold The Pareto \eqn{k} threshold for which observations
#'   \code{\link{loo_moment_match}} or \code{\link{reloo}} is applied if
#'   argument \code{moment_match} or \code{reloo} is \code{TRUE}.
#'   Defaults to \code{0.7}.
#'   See \code{\link[loo:pareto-k-diagnostic]{pareto_k_ids}} for more details.
#' @param save_psis Should the \code{"psis"} object created internally be saved
#'   in the returned object? For more details see \code{\link[loo:loo]{loo}}.
#' @param moment_match_args Optional named \code{list} of additional arguments
#'  passed to \code{\link{loo_moment_match}}.
#' @param reloo_args Optional named \code{list} of additional arguments passed to
#'   \code{\link{reloo}}. This can be useful, among others, to control
#'   how many chains, iterations, etc. to use for the fitted sub-models.
#' @param model_names If \code{NULL} (the default) will use model names
#'   derived from deparsing the call. Otherwise will use the passed
#'   values as model names.
#' @inheritParams predict.brmsfit
#'
#' @details See \code{\link{loo_compare}} for details on model comparisons.
#'  For \code{brmsfit} objects, \code{LOO} is an alias of \code{loo}.
#'  Use method \code{\link{add_criterion}} to store
#'  information criteria in the fitted model object for later usage.
#'
#' @return If just one object is provided, an object of class \code{loo}.
#'  If multiple objects are provided, an object of class \code{loolist}.
#'
#' @examples
#' \dontrun{
#' # model with population-level effects only
#' fit1 <- brm(rating ~ treat + period + carry,
#'             data = inhaler)
#' (loo1 <- loo(fit1))
#'
#' # model with an additional varying intercept for subjects
#' fit2 <- brm(rating ~ treat + period + carry + (1|subject),
#'             data = inhaler)
#' (loo2 <- loo(fit2))
#'
#' # compare both models
#' loo_compare(loo1, loo2)
#' }
#'
#' @references
#' Vehtari, A., Gelman, A., & Gabry, J. (2016). Practical Bayesian model
#' evaluation using leave-one-out cross-validation and WAIC. In Statistics
#' and Computing, doi:10.1007/s11222-016-9696-4. arXiv preprint arXiv:1507.04544.
#'
#' Gelman, A., Hwang, J., & Vehtari, A. (2014).
#' Understanding predictive information criteria for Bayesian models.
#' Statistics and Computing, 24, 997-1016.
#'
#' Watanabe, S. (2010). Asymptotic equivalence of Bayes cross validation
#' and widely applicable information criterion in singular learning theory.
#' The Journal of Machine Learning Research, 11, 3571-3594.
#'
#' @importFrom loo loo is.loo
#' @export loo
#' @export
loo.brmsfit <- function(x, ..., compare = TRUE, resp = NULL,
                        pointwise = FALSE, moment_match = FALSE,
                        reloo = FALSE, k_threshold = 0.7, save_psis = FALSE,
                        moment_match_args = list(), reloo_args = list(),
                        model_names = NULL) {
  # gather 'x' and everything passed via '...' into one argument list;
  # the fitted models are stored under the name "models"
  args <- split_dots(x, ..., model_names = model_names)
  use_stored_given <- "use_stored" %in% names(args)
  if (!use_stored_given) {
    # by default, stored criteria are only reused if nothing but models
    # was passed and none of the arguments below was set in the call
    further_arg_names <- c(
      "resp", "moment_match", "reloo", "k_threshold",
      "save_psis", "moment_match_args", "reloo_args"
    )
    only_models <- all(names(args) %in% "models")
    args$use_stored <- only_models &&
      !any(further_arg_names %in% names(match.call()))
  }
  # append the named arguments of this method and delegate the work
  c(args) <- nlist(
    criterion = "loo", pointwise, compare,
    resp, k_threshold, save_psis, moment_match,
    reloo, moment_match_args, reloo_args
  )
  do_call(compute_loolist, args)
}

#' @export
LOO.brmsfit <- function(x, ..., compare = TRUE, resp = NULL,
                        pointwise = FALSE, moment_match = FALSE,
                        reloo = FALSE, k_threshold = 0.7, save_psis = FALSE,
                        moment_match_args = list(), reloo_args = list(),
                        model_names = NULL) {
  # alias of 'loo': re-issue the matched call with 'loo' as the function
  # and evaluate it where the original call was made
  loo_call <- match.call()
  loo_call[[1L]] <- as.name("loo")
  eval(loo_call, envir = parent.frame())
}

#' @export
LOO <- function(x, ...) UseMethod("LOO")  # S3 generic dispatching on 'x'

#' Widely Applicable Information Criterion (WAIC)
#'
#' Compute the widely applicable information criterion (WAIC)
#' based on the posterior likelihood using the \pkg{loo} package.
#' For more details see \code{\link[loo:waic]{waic}}.
#'
#' @aliases waic WAIC WAIC.brmsfit
#'
#' @inheritParams loo.brmsfit
#'
#' @details See \code{\link{loo_compare}} for details on model comparisons.
#'  For \code{brmsfit} objects, \code{WAIC} is an alias of \code{waic}.
#'  Use method \code{\link[brms:add_criterion]{add_criterion}} to store
#'  information criteria in the fitted model object for later usage.
#'
#' @return If just one object is provided, an object of class \code{loo}.
#'  If multiple objects are provided, an object of class \code{loolist}.
#'
#' @examples
#' \dontrun{
#' # model with population-level effects only
#' fit1 <- brm(rating ~ treat + period + carry,
#'             data = inhaler)
#' (waic1 <- waic(fit1))
#'
#' # model with an additional varying intercept for subjects
#' fit2 <- brm(rating ~ treat + period + carry + (1|subject),
#'             data = inhaler)
#' (waic2 <- waic(fit2))
#'
#' # compare both models
#' loo_compare(waic1, waic2)
#' }
#'
#' @references
#' Vehtari, A., Gelman, A., & Gabry, J. (2016). Practical Bayesian model
#' evaluation using leave-one-out cross-validation and WAIC. In Statistics
#' and Computing, doi:10.1007/s11222-016-9696-4. arXiv preprint arXiv:1507.04544.
#'
#' Gelman, A., Hwang, J., & Vehtari, A. (2014).
#' Understanding predictive information criteria for Bayesian models.
#' Statistics and Computing, 24, 997-1016.
#'
#' Watanabe, S. (2010). Asymptotic equivalence of Bayes cross validation
#' and widely applicable information criterion in singular learning theory.
#' The Journal of Machine Learning Research, 11, 3571-3594.
#'
#' @importFrom loo waic
#' @export waic
#' @export
waic.brmsfit <- function(x, ..., compare = TRUE, resp = NULL,
                         pointwise = FALSE, model_names = NULL) {
  # gather 'x' and everything passed via '...' into one argument list;
  # the fitted models are stored under the name "models"
  args <- split_dots(x, ..., model_names = model_names)
  use_stored_given <- "use_stored" %in% names(args)
  if (!use_stored_given) {
    # by default, stored criteria are only reused if nothing but models
    # was passed and 'resp' was not set in the call
    further_arg_names <- c("resp")
    only_models <- all(names(args) %in% "models")
    args$use_stored <- only_models &&
      !any(further_arg_names %in% names(match.call()))
  }
  # append the named arguments of this method and delegate the work
  c(args) <- nlist(criterion = "waic", pointwise, compare, resp)
  do_call(compute_loolist, args)
}

#' @export
WAIC.brmsfit <- function(x, ..., compare = TRUE, resp = NULL,
                         pointwise = FALSE, model_names = NULL) {
  # alias of 'waic': re-issue the matched call with 'waic' as the function
  # and evaluate it where the original call was made
  waic_call <- match.call()
  waic_call[[1L]] <- as.name("waic")
  eval(waic_call, envir = parent.frame())
}

#' @export
WAIC <- function(x, ...) UseMethod("WAIC")  # S3 generic dispatching on 'x'

# build 'loo' objects for one or several models
# @param models named list of brmsfit objects
# @param criterion name of the criterion to compute;
#   must be one of loo_criteria()
# @param use_stored use precomputed criterion objects if possible?
#   a single value is recycled across all models
# @param compare compare models using 'loo_compare'?
#   only used if more than one model is passed
# @param ... more arguments passed to compute_loo
# @return If length(models) > 1 an object of class 'loolist'
#   If length(models) == 1 an object of class 'loo'
compute_loolist <- function(models, criterion, use_stored = TRUE,
                            compare = TRUE, ...) {
  criterion <- match.arg(criterion, loo_criteria())
  args <- nlist(criterion, ...)
  for (i in seq_along(models)) {
    models[[i]] <- restructure(models[[i]])
  }
  n_models <- length(models)
  if (!(n_models > 1L)) {
    # single model: return its criterion object directly
    args$x <- models[[1]]
    args$model_name <- names(models)
    args$use_stored <- use_stored
    return(do_call(compute_loo, args))
  }
  if (!match_nobs(models)) {
    stop2("Models have different number of observations.")
  }
  if (length(use_stored) == 1L) {
    use_stored <- rep(use_stored, n_models)
  }
  model_labels <- names(models)
  loos <- named_list(model_labels)
  for (i in seq_len(n_models)) {
    args$x <- models[[i]]
    args$model_name <- model_labels[i]
    args$use_stored <- use_stored[i]
    loos[[i]] <- do_call(compute_loo, args)
  }
  out <- list(loos = loos)
  compare <- as_one_logical(compare)
  if (compare) {
    out$diffs <- loo_compare(out$loos)
    # for backwards compatibility; remove in brms 3.0
    out$ic_diffs__ <- SW(compare_ic(x = out$loos))$ic_diffs__
  }
  class(out) <- "loolist"
  out
}

# compute a single model fit criterion using the 'loo' package
# @param x an object of class brmsfit
# @param criterion the criterion to be computed;
#   must be one of loo_criteria()
# @param newdata optional data.frame of new data
# @param resp optional names of the predicted response variables
# @param model_name original variable name of object 'x'
# @param use_stored use precomputed criterion objects if possible?
# @param ... passed to the individual methods
# @return an object of class 'loo' carrying the attribute 'model_name'
compute_loo <- function(x, criterion, newdata = NULL, resp = NULL,
                        model_name = "", use_stored = TRUE, ...) {
  criterion <- match.arg(criterion, loo_criteria())
  model_name <- as_one_character(model_name)
  use_stored <- as_one_logical(use_stored)
  out <- get_criterion(x, criterion)
  has_stored <- is.loo(out)
  if (has_stored && !use_stored) {
    message("Recomputing '", criterion, "' for model '", model_name, "'")
  }
  if (!(has_stored && use_stored)) {
    # call the internal function named '.<criterion>'
    fun_args <- nlist(x, newdata, resp, model_name, ...)
    out <- do_call(paste0(".", criterion), fun_args)
    # store a hash of the response along with the newly computed object
    attr(out, "yhash") <- hash_response(x, newdata = newdata, resp = resp)
  }
  attr(out, "model_name") <- model_name
  out
}

# names of the criteria that can be evaluated via the loo package
# @return a character vector of criterion names
loo_criteria <- function() {
  criteria <- c("loo", "waic", "psis", "kfold", "loo_subsample")
  criteria
}

# compute the 'loo' criterion using the 'loo' package
# optionally followed by moment matching and/or exact refits ('reloo')
# of observations exceeding 'k_threshold'
# @param model_name only used in the warnings of recommend_loo_options
# @return the (possibly post-processed) 'loo' object
.loo <- function(x, pointwise, k_threshold, moment_match, reloo,
                 moment_match_args, reloo_args, newdata,
                 resp, model_name, save_psis, ...) {
  loo_args <- prepare_loo_args(
    x, newdata = newdata, resp = resp,
    pointwise = pointwise, save_psis = save_psis,
    ...
  )
  loo_out <- SW(do_call("loo", loo_args, pkg = "loo"))
  if (moment_match) {
    # user-supplied arguments come first, followed by the internal ones
    c(moment_match_args) <- nlist(
      x, loo = loo_out, newdata, resp,
      k_threshold, check = FALSE, ...
    )
    loo_out <- do_call("loo_moment_match", moment_match_args)
  }
  if (reloo) {
    # user-supplied arguments come first, followed by the internal ones
    c(reloo_args) <- nlist(
      x, loo = loo_out, newdata, resp,
      k_threshold, check = FALSE, ...
    )
    loo_out <- do_call("reloo", reloo_args)
  }
  # called for its warnings only; the returned recommendation is not used
  recommend_loo_options(loo_out, k_threshold, moment_match, model_name)
  loo_out
}

# compute the 'waic' criterion using the 'loo' package
# @param model_name ignored but included to avoid being passed to '...'
.waic <- function(x, pointwise, newdata, resp, model_name, ...) {
  waic_args <- prepare_loo_args(
    x, newdata = newdata, resp = resp, pointwise = pointwise, ...
  )
  do_call("waic", waic_args, pkg = "loo")
}

# thin wrapper around psis() following the '.<criterion>' naming scheme
# so that it can be called by name from compute_loo()
.psis <- function(x, newdata, resp, model_name, ...) {
  psis(
    x, model_name = model_name,
    newdata = newdata, resp = resp, ...
  )
}

#' @inherit loo::psis return title description details references
#'
#' @aliases psis psis.brmsfit
#'
#' @param log_ratios A fitted model object of class \code{brmsfit}.
#'   Argument is named "log_ratios" to match the argument name of the
#'   \code{\link[loo:psis]{loo::psis}} generic function.
#' @param model_name Currently ignored.
#' @param ... Further arguments passed to \code{\link{log_lik}} and
#'   \code{\link[loo:psis]{loo::psis}}.
#' @inheritParams log_lik.brmsfit
#'
#' @examples
#' \dontrun{
#' fit <- brm(rating ~ treat + period + carry, data = inhaler)
#' psis(fit)
#'}
#' @importFrom loo psis
#' @export psis
#' @export
psis.brmsfit <- function(log_ratios, newdata = NULL, resp = NULL,
                         model_name = NULL, ...) {
  psis_args <- prepare_loo_args(
    log_ratios, newdata = newdata, resp = resp,
    pointwise = FALSE, ...
  )
  # loo::psis expects log ratios, i.e., the negated log-likelihood values,
  # under the name 'log_ratios' instead of 'x'
  log_lik_values <- psis_args$x
  psis_args$x <- NULL
  psis_args$log_ratios <- -log_lik_values
  do_call("psis", psis_args, pkg = "loo")
}

# prepare arguments passed to the methods of the `loo` package
# @param x a brmsfit object
# @param newdata,resp,... passed to log_lik; arguments in '...' are also
#   included in the returned list
# @param pointwise flag passed to log_lik; if TRUE the 'draws' and 'data'
#   attributes of the log_lik result are added to the returned list
# @return a named list containing the elements of '...', the log_lik
#   result as 'x', and the relative efficiencies as 'r_eff'
prepare_loo_args <- function(x, newdata, resp, pointwise, ...) {
  pointwise <- as_one_logical(pointwise)
  loo_args <- list(...)
  log_lik_obj <- do_call(
    log_lik, nlist(object = x, newdata, resp, pointwise, ...)
  )
  loo_args$x <- log_lik_obj
  if (pointwise) {
    loo_args$draws <- attr(log_lik_obj, "draws")
    loo_args$data <- attr(log_lik_obj, "data")
  }
  # compute pointwise relative efficiencies
  r_eff_args <- loo_args
  r_eff_args$fit <- x
  loo_args$r_eff <- do_call(r_eff_log_lik, r_eff_args)
  loo_args
}

#' Model comparison with the \pkg{loo} package
#'
#' For more details see \code{\link[loo:loo_compare]{loo_compare}}.
#'
#' @aliases loo_compare
#'
#' @inheritParams loo.brmsfit
#' @param ... More \code{brmsfit} objects.
#' @param criterion The name of the criterion to be extracted
#'   from \code{brmsfit} objects.
#'
#' @details All \code{brmsfit} objects should contain precomputed
#'   criterion objects. See \code{\link{add_criterion}} for more help.
#'
#' @return An object of class "\code{compare.loo}".
#'
#' @examples
#' \dontrun{
#' # model with population-level effects only
#' fit1 <- brm(rating ~ treat + period + carry,
#'             data = inhaler)
#' fit1 <- add_criterion(fit1, "waic")
#'
#' # model with an additional varying intercept for subjects
#' fit2 <- brm(rating ~ treat + period + carry + (1|subject),
#'             data = inhaler)
#' fit2 <- add_criterion(fit2, "waic")
#'
#' # compare both models
#' loo_compare(fit1, fit2, criterion = "waic")
#' }
#'
#' @importFrom loo loo_compare
#' @export loo_compare
#' @export
loo_compare.brmsfit <- function(x, ..., criterion = c("loo", "waic", "kfold"),
                                model_names = NULL) {
  criterion <- match.arg(criterion)
  models <- split_dots(x, ..., model_names = model_names, other = FALSE)
  loos <- named_list(names(models))
  for (i in seq_along(models)) {
    models[[i]] <- restructure(models[[i]])
    stored <- get_criterion(models[[i]], criterion)
    if (is.null(stored)) {
      # criteria are never computed here; they must have been stored before
      stop2(
        "Model '", names(models)[i], "' does not contain a precomputed '",
        criterion, "' criterion. See ?loo_compare.brmsfit for help."
      )
    }
    # assign only after the NULL check: assigning NULL would remove
    # the list element and shift the indices of the following models
    loos[[i]] <- stored
  }
  loo_compare(loos)
}

#' Model averaging via stacking or pseudo-BMA weighting.
#'
#' Compute model weights for \code{brmsfit} objects via stacking
#' or pseudo-BMA weighting. For more details, see
#' \code{\link[loo:loo_model_weights]{loo::loo_model_weights}}.
#'
#' @aliases loo_model_weights
#'
#' @inheritParams loo.brmsfit
#'
#' @return A named vector of model weights.
#'
#' @examples
#' \dontrun{
#' # model with population-level effects only
#' fit1 <- brm(rating ~ treat + period + carry,
#'             data = inhaler, family = "gaussian")
#' # model with an additional varying intercept for subjects
#' fit2 <- brm(rating ~ treat + period + carry + (1|subject),
#'             data = inhaler, family = "gaussian")
#' loo_model_weights(fit1, fit2)
#' }
#'
#' @method loo_model_weights brmsfit
#' @importFrom loo loo_model_weights
#' @export loo_model_weights
#' @export
loo_model_weights.brmsfit <- function(x, ..., model_names = NULL) {
  args <- split_dots(x, ..., model_names = model_names)
  # separate the models from the remaining arguments
  models <- args$models
  args$models <- NULL
  # one log-likelihood object per model, computed with the same arguments
  log_lik_list <- lapply(
    models,
    function(fit) do_call(log_lik, c(list(fit), args))
  )
  args$x <- log_lik_list
  # relative efficiencies, pairing each log-likelihood with its model
  args$r_eff_list <- Map(r_eff_log_lik, log_lik_list, fit = models)
  weights <- do_call(loo::loo_model_weights, args)
  names(weights) <- names(models)
  weights
}

#' Add model fit criteria to model objects
#'
#' @param x An \R object typically of class \code{brmsfit}.
#' @param criterion Names of model fit criteria
#'   to compute. Currently supported are \code{"loo"},
#'   \code{"waic"}, \code{"kfold"}, \code{"loo_subsample"},
#'   \code{"bayes_R2"} (Bayesian R-squared),
#'   \code{"loo_R2"} (LOO-adjusted R-squared), and
#'   \code{"marglik"} (log marginal likelihood).
#' @param model_name Optional name of the model. If \code{NULL}
#'   (the default) the name is taken from the call to \code{x}.
#' @param overwrite Logical; Indicates if already stored fit
#'   indices should be overwritten. Defaults to \code{FALSE}.
#'   Setting it to \code{TRUE} is useful for example when changing
#'   additional arguments of an already stored criterion.
#' @param file Either \code{NULL} or a character string. In the latter case, the
#'   fitted model object including the newly added criterion values is saved via
#'   \code{\link{saveRDS}} in a file named after the string supplied in
#'   \code{file}. The \code{.rds} extension is added automatically. If \code{x}
#'   was already stored in a file before, the file name will be reused
#'   automatically (with a message) unless overwritten by \code{file}. In any
#'   case, \code{file} only applies if new criteria were actually added via
#'   \code{add_criterion} or if \code{force_save} was set to \code{TRUE}.
#' @param force_save Logical; only relevant if \code{file} is specified and
#'   ignored otherwise. If \code{TRUE}, the fitted model object will be saved
#'   regardless of whether new criteria were added via \code{add_criterion}.
#' @param ... Further arguments passed to the underlying
#'   functions computing the model fit criteria. If you are recomputing
#'   an already stored criterion with other \code{...} arguments, make
#'   sure to set \code{overwrite = TRUE}.
#'
#' @return An object of the same class as \code{x}, but
#'   with model fit criteria added for later usage.
#'
#' @details Functions \code{add_loo} and \code{add_waic} are aliases of
#'   \code{add_criterion} with fixed values for the \code{criterion} argument.
#'
#' @examples
#' \dontrun{
#' fit <- brm(count ~ Trt, data = epilepsy)
#' # add both LOO and WAIC at once
#' fit <- add_criterion(fit, c("loo", "waic"))
#' print(fit$criteria$loo)
#' print(fit$criteria$waic)
#' }
#'
#' @export
add_criterion <- function(x, ...) UseMethod("add_criterion")  # S3 generic

#' @rdname add_criterion
#' @export
add_criterion.brmsfit <- function(x, criterion, model_name = NULL,
                                  overwrite = FALSE, file = NULL,
                                  force_save = FALSE, ...) {
  # without an explicit name, use the expression passed as 'x'
  if (is.null(model_name)) {
    model_name <- deparse0(substitute(x))
  } else {
    model_name <- as_one_character(model_name)
  }
  criterion <- unique(as.character(criterion))
  is_old_r2 <- criterion == "R2"
  if (any(is_old_r2)) {
    # deprecated as of version 2.10.4
    warning2("Criterion 'R2' is deprecated. Please use 'bayes_R2' instead.")
    criterion[is_old_r2] <- "bayes_R2"
  }
  loo_options <- c("loo", "waic", "kfold", "loo_subsample")
  options <- c(loo_options, "bayes_R2", "loo_R2", "marglik")
  is_valid <- length(criterion) > 0L && all(criterion %in% options)
  if (!is_valid) {
    stop2("Argument 'criterion' should be a subset of ",
          collapse_comma(options))
  }
  auto_save <- FALSE
  if (is.null(file)) {
    # fall back to the file name stored in the model object, if any
    file <- x$file
    auto_save <- !is.null(file)
  } else {
    file <- paste0(as_one_character(file), ".rds")
  }
  force_save <- as_one_logical(force_save)
  overwrite <- as_one_logical(overwrite)
  # by default recompute all requested criteria ...
  new_criteria <- criterion
  if (!overwrite) {
    # ... otherwise only compute criteria not already stored
    not_stored <- ulapply(x$criteria[criterion], is.null)
    new_criteria <- criterion[not_stored]
  }
  # remove all criteria that are to be recomputed
  x$criteria[new_criteria] <- NULL
  # 'args' is extended step by step and reused for the later criteria
  args <- list(x, ...)
  for (fun in intersect(new_criteria, loo_options)) {
    # the criterion name is also the name of the function computing it
    args$model_names <- model_name
    x$criteria[[fun]] <- do_call(fun, args)
  }
  if ("bayes_R2" %in% new_criteria) {
    args$summary <- FALSE
    x$criteria$bayes_R2 <- do_call(bayes_R2, args)
  }
  if ("loo_R2" %in% new_criteria) {
    args$summary <- FALSE
    x$criteria$loo_R2 <- do_call(loo_R2, args)
  }
  if ("marglik" %in% new_criteria) {
    x$criteria$marglik <- do_call(bridge_sampler, args)
  }
  # save only if a file is known and something changed or saving is forced
  should_save <- !is.null(file) &&
    (force_save || length(new_criteria) > 0L)
  if (should_save) {
    if (auto_save) {
      message("Automatically saving the model object in '", file, "'")
    }
    x$file <- file
    saveRDS(x, file = file)
  }
  x
}

# extract a precomputed model fit criterion
# @param x a brmsfit object
# @param criterion name of the criterion to extract
# @return the element of 'x$criteria' named after 'criterion'
get_criterion <- function(x, criterion) {
  stopifnot(is.brmsfit(x))
  criterion_name <- as_one_character(criterion)
  stored <- x$criteria
  stored[[criterion_name]]
}

# create a hash based on the response of a model
# @param x a brmsfit object
# @param newdata optional new data passed to standata
# @param resp optional names of response variables used to select
#   the response-related variables
# @param ... passed to digest::sha1
# @return the hash as returned by digest::sha1
hash_response <- function(x, newdata = NULL, resp = NULL, ...) {
  require_package("digest")
  stopifnot(is.brmsfit(x))
  sdata <- standata(
    x, newdata = newdata, re_formula = NA, internal = TRUE,
    check_response = TRUE, only_response = TRUE
  )
  # variable name stems: "Y" plus the exported 'resp_*' functions
  # with their prefix removed
  add_funs <- lsp("brms", what = "exports", pattern = "^resp_")
  stems <- c("Y", sub("^resp_", "", add_funs))
  # combine every stem with every suffix derived from 'resp'
  stem_by_resp <- outer(stems, escape_all(usc(resp)), FUN = paste0)
  alternatives <- paste0("(", as.vector(stem_by_resp), ")", collapse = "|")
  regex <- paste0("^(", alternatives, ")(_|$)")
  is_response_var <- grepl(regex, names(sdata))
  out <- sdata[is_response_var]
  out <- as.matrix(as.data.frame(rmNULL(out)))
  out <- p(out, attr(sdata, "old_order"))
  # see issue #642
  attributes(out) <- NULL
  digest::sha1(x = out, ...)
}

# check whether the response parts of multiple brmsfit objects agree
# @param models A list of brmsfit objects
# @param ... passed to hash_response
# @return TRUE if the response hashes of all models equal the one of the
#   first model (or if at most one model is passed) and FALSE otherwise
match_response <- function(models, ...) {
  if (length(models) <= 1L) {
    # nothing to compare
    return(TRUE)
  }
  hashes <- lapply(models, hash_response, ...)
  same_as_first <- ulapply(hashes, is_equal, hashes[[1]])
  if (all(same_as_first)) TRUE else FALSE
}

# check whether multiple models have the same number of observations
# @param models A list of brmsfit objects
# @param ... currently ignored
# @return TRUE if the number of observations of all models equals the one
#   of the first model (or if at most one model is passed), else FALSE
match_nobs <- function(models, ...) {
  if (length(models) <= 1L) {
    # nothing to compare
    return(TRUE)
  }
  n_obs <- lapply(models, nobs)
  same_as_first <- ulapply(n_obs, is_equal, n_obs[[1]])
  if (all(same_as_first)) TRUE else FALSE
}

# validate and name models passed to loo and related methods
# @param models list of fitted model objects
# @param model_names names specified by the user
# @param sub_names names inferred by substitute();
#   only used if 'model_names' is empty
# @return 'models' named after the model names; an error is raised if
#   the number of names is wrong or any object is not a brmsfit
validate_models <- function(models, model_names, sub_names) {
  stopifnot(is.list(models))
  model_names <- as.character(model_names)
  if (length(model_names) == 0L) {
    model_names <- as.character(sub_names)
  }
  if (length(models) != length(model_names)) {
    stop2("Number of model names is not equal to the number of models.")
  }
  names(models) <- model_names
  for (pos in seq_along(models)) {
    if (is.brmsfit(models[[pos]])) {
      next
    }
    stop2("Object '", names(models)[pos], "' is not of class 'brmsfit'.")
  }
  models
}

# recommend options if approximate loo fails for some observations
# @param loo an object understood by loo::pareto_k_ids
# @param k_threshold Pareto k threshold above which observations
#   are counted as problematic
# @param moment_match has moment matching already been performed?
# @param model_name optional model name mentioned in the warnings
# @return invisibly, the name of the recommended option; a warning is
#   issued unless the recommendation is "loo"
recommend_loo_options <- function(loo, k_threshold = 0.7, moment_match = FALSE,
                                  model_name = "") {
  model_name <- if (isTRUE(nzchar(model_name))) {
    paste0(" in model '", model_name, "'")
  } else {
    ""
  }
  ndraws <- dim(loo)[1] %||% Inf
  n <- length(loo::pareto_k_ids(loo, threshold = k_threshold))
  n2 <- n
  # for small number of draws the threshold may be smaller than 0.7
  k_threshold2 <- ps_khat_threshold(ndraws)
  if (k_threshold2 < k_threshold) {
    n2 <- length(loo::pareto_k_ids(loo, threshold = k_threshold2))
  }
  if (n2 > n && k_threshold2 <= 0.7) {
    # more observations are flagged under the draw-dependent threshold
    warning2(
      "Found ", n2, " observations with a pareto_k > ", round(k_threshold2, 2),
      model_name, ". We recommend to run more iterations to get at least ",
      "about 2200 posterior draws to improve LOO-CV approximation accuracy."
    )
    return(invisible("loo_more_draws"))
  }
  if (n > 0 && !moment_match) {
    warning2(
      "Found ", n, " observations with a pareto_k > ", k_threshold,
      model_name, ". We recommend to set 'moment_match = TRUE' in order ",
      "to perform moment matching for problematic observations. "
    )
    return(invisible("loo_moment_match"))
  }
  if (n > 0 && n <= 10) {
    warning2(
      "Found ", n, " observations with a pareto_k > ", k_threshold,
      model_name, ". We recommend to set 'reloo = TRUE' in order to ",
      "calculate the ELPD without the assumption that these observations " ,
      "are negligible. This will refit the model ", n, " times to compute ",
      "the ELPDs for the problematic observations directly."
    )
    return(invisible("reloo"))
  }
  if (n > 10) {
    warning2(
      "Found ", n, " observations with a pareto_k > ", k_threshold,
      model_name, ". With this many problematic observations, it may be more ",
      "appropriate to use 'kfold' with argument 'K = 10' to perform ",
      "10-fold cross-validation rather than LOO."
    )
    return(invisible("kfold"))
  }
  # no problematic observations found
  invisible("loo")
}

# subset observations in a psis object
# observation-wise information is spread over the columns of
# 'log_weights', the elements of 'diagnostics', and several attributes,
# all of which need to be subsetted
# @param subset vector with which to subset
#' @export
subset.psis <- function(x, subset, ...) {
  stopifnot(is.vector(subset))
  # observations are stored in the columns of the weight matrix
  kept_weights <- x$log_weights[, subset, drop = FALSE]
  x$log_weights <- kept_weights
  for (diag_name in names(x$diagnostics)) {
    diag_values <- x$diagnostics[[diag_name]]
    x$diagnostics[[diag_name]] <- diag_values[subset]
  }
  for (attr_name in c("norm_const_log", "tail_len", "r_eff")) {
    attr(x, attr_name) <- attr(x, attr_name)[subset]
  }
  # keep the stored dimensions in line with the subsetted weights
  attr(x, "dims") <- dim(kept_weights)
  x
}

# helper function to compute relative efficiencies
# @param x object passed to loo::relative_eff; the callers in this file
#   pass either a matrix of draws or a function
# @param chain_id chain IDs passed to loo::relative_eff
# @param allow_na allow NA values in the output?
# @param ... passed to loo::relative_eff
# @return a numeric vector of relative efficiencies; if NAs occur but are
#   not allowed, all values are set to 1 with a warning
r_eff_helper <- function(x, chain_id, allow_na = TRUE, ...) {
  r_eff <- loo::relative_eff(x, chain_id = chain_id, ...)
  if (allow_na || !anyNA(r_eff)) {
    return(r_eff)
  }
  # avoid error in loo if some but not all r_effs are NA
  r_eff <- rep(1, length(r_eff))
  warning2(
    "Ignoring relative efficiencies as some were NA. ",
    "See argument 'r_eff' in ?loo::loo for more details."
  )
  r_eff
}

# S3 generic wrapping r_eff_helper to compute relative efficiencies
# of likelihood draws based on log-likelihood draws
r_eff_log_lik <- function(x, ...) UseMethod("r_eff_log_lik")

#' @export
r_eff_log_lik.matrix <- function(x, fit, allow_na = FALSE, ...) {
  if (!is.brmsfit_multiple(fit)) {
    # efficiencies are computed on the likelihood scale, hence exp()
    chain_id <- get_chain_id(nrow(x), fit)
    return(
      r_eff_helper(exp(x), chain_id = chain_id, allow_na = allow_na, ...)
    )
  }
  # due to stacking of chains from multiple models
  # efficiency computations will likely be incorrect
  # assume relative efficiency of 1 for now
  rep(1, ncol(x))
}

#' @export
r_eff_log_lik.function <- function(x, fit, draws, allow_na = FALSE, ...) {
  if (!is.brmsfit_multiple(fit)) {
    # efficiencies are computed on the likelihood scale, hence the
    # log-likelihood function is wrapped in exp()
    lik_fun <- function(data_i, draws, ...) {
      log_lik_i <- x(data_i, draws, ...)
      exp(log_lik_i)
    }
    chain_id <- get_chain_id(draws$ndraws, fit)
    return(r_eff_helper(
      lik_fun, chain_id = chain_id, draws = draws,
      allow_na = allow_na, ...
    ))
  }
  # due to stacking of chains from multiple models
  # efficiency computations will likely be incorrect
  # assume relative efficiency of 1 for now
  rep(1, draws$nobs)
}

# get chain IDs per posterior draw
# @param ndraws number of draws for which chain IDs are required
# @param fit a brmsfit object
# @return an integer vector of chain IDs
get_chain_id <- function(ndraws, fit) {
  if (ndraws != ndraws(fit)) {
    # don't know the chain IDs of a subset of draws
    return(rep(1L, ndraws))
  }
  # all draws are present: split them evenly across the chains
  nchains <- fit$fit@sim$chains
  draws_per_chain <- ndraws / nchains
  rep(seq_len(nchains), each = draws_per_chain)
}

# print every 'loo' object of a 'loolist' followed by the
# model comparisons if they were computed
#' @export
print.loolist <- function(x, digits = 1, ...) {
  model_names <- loo::find_model_names(x$loos)
  for (i in seq_along(x$loos)) {
    header <- paste0("Output of model '", model_names[i], "':\n")
    cat(header)
    print(x$loos[[i]], digits = digits, ...)
    cat("\n")
  }
  has_diffs <- !is.null(x$diffs)
  if (has_diffs) {
    cat("Model comparisons:\n")
    print(x$diffs, digits = digits, ...)
  }
  invisible(x)
}

# ---------- deprecated functions ----------
#' @rdname add_ic
#' @export
add_loo <- function(x, model_name = NULL, ...) {
  warning2("'add_loo' is deprecated. Please use 'add_criterion' instead.")
  # resolve the name here as 'x' would deparse differently further down
  if (is.null(model_name)) {
    model_name <- deparse0(substitute(x))
  } else {
    model_name <- as_one_character(model_name)
  }
  add_criterion(x, criterion = "loo", model_name = model_name, ...)
}

#' @rdname add_ic
#' @export
add_waic <- function(x, model_name = NULL, ...) {
  warning2("'add_waic' is deprecated. Please use 'add_criterion' instead.")
  # resolve the name here as 'x' would deparse differently further down
  if (is.null(model_name)) {
    model_name <- deparse0(substitute(x))
  } else {
    model_name <- as_one_character(model_name)
  }
  add_criterion(x, criterion = "waic", model_name = model_name, ...)
}

#' Add model fit criteria to model objects
#'
#' Deprecated aliases of \code{\link{add_criterion}}.
#'
#' @inheritParams add_criterion
#' @param ic,value Names of model fit criteria
#'   to compute. Currently supported are \code{"loo"},
#'   \code{"waic"}, \code{"kfold"}, \code{"R2"} (R-squared), and
#'   \code{"marglik"} (log marginal likelihood).
#'
#' @return An object of the same class as \code{x}, but
#'   with model fit criteria added for later usage.
#'   Previously computed criterion objects will be overwritten.
#'
#' @export
add_ic <- function(x, ...) UseMethod("add_ic")  # S3 generic (deprecated)

#' @rdname add_ic
#' @export
add_ic.brmsfit <- function(x, ic = "loo", model_name = NULL, ...) {
  warning2("'add_ic' is deprecated. Please use 'add_criterion' instead.")
  # resolve the name here as 'x' would deparse differently further down
  if (is.null(model_name)) {
    model_name <- deparse0(substitute(x))
  } else {
    model_name <- as_one_character(model_name)
  }
  add_criterion(x, criterion = ic, model_name = model_name, ...)
}

#' @rdname add_ic
#' @export
`add_ic<-` <- function(x, ..., value) {
  # replacement form: add_ic(x) <- value passes 'value' on as 'ic'
  add_ic(x, ic = value, ...)
}

#' Compare Information Criteria of Different Models
#'
#' Compare information criteria of different models fitted
#' with \code{\link{waic}} or \code{\link{loo}}.
#' This function is deprecated and will be removed in the future.
#' Please use \code{\link{loo_compare}} instead.
#'
#' @param ... At least two objects returned by
#'   \code{\link{waic}} or \code{\link{loo}}.
#'   Alternatively, \code{brmsfit} objects with information
#'   criteria precomputed via \code{\link{add_ic}}
#'   may be passed, as well.
#' @param x A \code{list} containing the same types of objects as
#'   can be passed via \code{...}.
#' @param ic The name of the information criterion to be extracted
#'   from \code{brmsfit} objects. Ignored if information
#'   criterion objects are only passed directly.
#'
#' @return An object of class \code{iclist}.
#'
#' @details See \code{\link{loo_compare}} for the recommended way
#'   of comparing models with the \pkg{loo} package.
#'
#' @seealso
#'   \code{\link{loo}},
#'   \code{\link{loo_compare}},
#'   \code{\link{add_criterion}}
#'
#' @examples
#' \dontrun{
#' # model with population-level effects only
#' fit1 <- brm(rating ~ treat + period + carry,
#'             data = inhaler)
#' waic1 <- waic(fit1)
#'
#' # model with an additional varying intercept for subjects
#' fit2 <- brm(rating ~ treat + period + carry + (1|subject),
#'             data = inhaler)
#' waic2 <- waic(fit2)
#'
#' # compare both models
#' compare_ic(waic1, waic2)
#' }
#'
#' @export
compare_ic <- function(..., x = NULL, ic = c("loo", "waic", "kfold")) {
  # will be removed in brms 3.0
  warning2(
    "'compare_ic' is deprecated and will be removed ",
    "in the future. Please use 'loo_compare' instead."
  )
  ic <- match.arg(ic)
  if (!is.null(x) && !is.list(x)) {
    stop2("Argument 'x' should be a list.")
  }
  # discard comparisons stored in 'x' before merging it with '...'
  x$ic_diffs__ <- NULL
  x <- c(list(...), x)
  # swap brmsfit objects for their precomputed criterion where available
  x <- lapply(x, function(obj) {
    if (is.brmsfit(obj) && !is.null(obj[[ic]])) obj[[ic]] else obj
  })
  is_loo <- sapply(x, inherits, "loo")
  if (!all(is_loo)) {
    stop2("All inputs should have class 'loo' ",
          "or contain precomputed 'loo' objects.")
  }
  if (length(x) < 2L) {
    stop2("Expecting at least two objects.")
  }
  # the third row name of 'estimates' identifies the criterion
  ics <- unname(sapply(x, function(y) rownames(y$estimates)[3]))
  if (any(!(ics %in% ics[1]))) {
    stop2("All inputs should be from the same criterion.")
  }
  # warn if the stored response hashes differ between the inputs
  yhash <- lapply(x, attr, which = "yhash")
  same_response <- ulapply(yhash, is_equal, yhash[[1]])
  if (!all(same_response)) {
    warning2(
      "Model comparisons are likely invalid as the response ",
      "values of at least two models do not match."
    )
  }
  names(x) <- loo::find_model_names(x)
  n_models <- length(x)
  n_pairs <- n_models * (n_models - 1) / 2
  ic_diffs <- matrix(0, nrow = n_pairs, ncol = 2)
  pair_names <- rep("", nrow(ic_diffs))
  # pairwise comparison to get differences in ICs and their SEs
  row <- 0
  for (i in seq_len(n_models - 1)) {
    for (j in seq(i + 1, n_models)) {
      row <- row + 1
      cmp <- SW(loo::compare(x[[j]], x[[i]]))
      # scale the elpd difference and its SE by a factor of 2,
      # flipping the sign of the difference
      ic_diffs[row, ] <- c(-2 * cmp[["elpd_diff"]], 2 * cmp[["se"]])
      pair_names[row] <- paste(names(x)[i], "-", names(x)[j])
    }
  }
  dimnames(ic_diffs) <- list(pair_names, c(toupper(ics[1]), "SE"))
  x$ic_diffs__ <- ic_diffs
  structure(x, class = "iclist")
}

# print the output of LOO and WAIC with multiple models
# deprecated as of brms > 2.5.0 and will be removed in brms 3.0
# @param x an 'iclist' object: a list of criterion objects each holding
#   an 'estimates' matrix, optionally with pairwise comparisons stored
#   in the element 'ic_diffs__'
# @param digits number of decimal places used when rounding the output
# @param ... currently unused
# @return 'x', invisibly
#' @export
print.iclist <- function(x, digits = 2, ...) {
  m <- x
  m$ic_diffs__ <- NULL
  mat <- NULL
  if (length(m) > 0L) {
    # the third row of 'estimates' holds the criterion and its SE
    ic <- rownames(m[[1]]$estimates)[3]
    mat <- matrix(0, nrow = length(m), ncol = 2)
    dimnames(mat) <- list(names(m), c(toupper(ic), "SE"))
    for (i in seq_along(m)) {
      mat[i, ] <- m[[i]]$estimates[3, ]
    }
  }
  ic_diffs <- x$ic_diffs__
  if (is.matrix(attr(x, "compare"))) {
    # deprecated as of brms 1.4.0
    ic_diffs <- attr(x, "compare")
  }
  if (is.matrix(ic_diffs)) {
    # models were compared using the compare_ic function
    mat <- rbind(mat, ic_diffs)
  }
  # an empty list without comparisons leaves nothing to print;
  # calling round() on NULL would throw an error
  if (!is.null(mat)) {
    print(round(mat, digits = digits), na.print = "")
  }
  invisible(x)
}

# Pareto-smoothing k-hat threshold
# not yet exported by loo so copied over here for now
# @param S sample size
# @param ... currently unused
# @return the sample size dependent k-hat threshold, capped at 0.7;
#   has the same length as 'S'
ps_khat_threshold <- function(S, ...) {
  # loo documents the threshold as min(1 - 1 / log10(S), 0.7);
  # pmin applies the cap element-wise so vector input stays a vector
  pmin(1 - 1 / log10(S), 0.7)
}

Try the brms package in your browser

Any scripts or data that you put into this service are public.

brms documentation built on Sept. 23, 2024, 5:08 p.m.