# R/loo.R

# Part of the rstanarm package for estimating model parameters
# Copyright (C) 2015, 2016, 2017 Trustees of Columbia University
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

#' Information criteria and cross-validation
#'
#' @description For models fit using MCMC, compute approximate leave-one-out
#'   cross-validation (LOO, LOOIC) or, less preferably, the Widely Applicable
#'   Information Criterion (WAIC) using the \pkg{\link[=loo-package]{loo}}
#'   package. Functions for \eqn{K}-fold cross-validation, model comparison,
#'   and model weighting/averaging are also provided. \strong{Note}:
#'   these functions are not guaranteed to work properly unless the \code{data}
#'   argument was specified when the model was fit. Also, as of \pkg{loo}
#'   version \code{2.0.0} the default number of cores is now only 1, but we
#'   recommend using as many (or close to as many) cores as possible by setting
#'   the \code{cores} argument or using \code{options(mc.cores = VALUE)} to set
#'   it for an entire session.
#'
#' @aliases loo waic
#'
#' @export
#'
#' @param x For \code{loo}, \code{waic}, and \code{kfold} methods, a fitted
#'   model object returned by one of the rstanarm modeling functions. See
#'   \link{stanreg-objects}.
#'
#'   For \code{loo_model_weights}, \code{x} should be a "stanreg_list"
#'   object, which is a list of fitted model objects created by
#'   \code{\link{stanreg_list}}.
#'
#' @param ... For \code{compare_models}, \code{...} should contain two or more
#'   objects returned by the \code{loo}, \code{kfold}, or \code{waic} method
#'   (see the \strong{Examples} section, below).
#'
#'   For \code{loo_model_weights}, \code{...} should contain arguments
#'   (e.g. \code{method}) to pass to the default
#'   \code{\link[loo]{loo_model_weights}} method from the \pkg{loo} package.
#'
#' @param cores,save_psis Passed to \code{\link[loo]{loo}}.
#' @param k_threshold Threshold for flagging estimates of the Pareto shape
#'   parameters \eqn{k} estimated by \code{loo}. See the \emph{How to proceed
#'   when \code{loo} gives warnings} section, below, for details.
#'
#' @return The structure of the objects returned by \code{loo} and \code{waic}
#'   methods are documented in detail in the \strong{Value} section in
#'   \code{\link[loo]{loo}} and \code{\link[loo]{waic}} (from the \pkg{loo}
#'   package).
#'
#' @section Approximate LOO CV: The \code{loo} method for stanreg objects
#'   provides an interface to the \pkg{\link[=loo-package]{loo}} package for
#'   approximate leave-one-out cross-validation (LOO). The LOO Information
#'   Criterion (LOOIC) has the same purpose as the Akaike Information Criterion
#'   (AIC) that is used by frequentists. Both are intended to estimate the
#'   expected log predictive density (ELPD) for a new dataset. However, the AIC
#'   ignores priors and assumes that the posterior distribution is multivariate
#'   normal, whereas the functions from the \pkg{loo} package do not make this
#'   distributional assumption and integrate over uncertainty in the parameters.
#'   This only assumes that any one observation can be omitted without having a
#'   major effect on the posterior distribution, which can be judged using the
#'   diagnostic plot provided by the \code{\link[loo]{plot.loo}} method and the
#'   warnings provided by the \code{\link[loo]{print.loo}} method (see the
#'   \emph{How to Use the rstanarm Package} vignette for an example of this
#'   process).
#'
#'   \subsection{How to proceed when \code{loo} gives warnings (k_threshold)}{
#'   The \code{k_threshold} argument to the \code{loo} method for \pkg{rstanarm}
#'   models is provided as a possible remedy when the diagnostics reveal
#'   problems stemming from the posterior's sensitivity to particular
#'   observations. Warnings about Pareto \eqn{k} estimates indicate observations
#'   for which the approximation to LOO is problematic (this is described in
#'   detail in Vehtari, Gelman, and Gabry (2017) and the
#'   \pkg{\link[=loo-package]{loo}} package documentation). The
#'   \code{k_threshold} argument can be used to set the \eqn{k} value above
#'   which an observation is flagged. If \code{k_threshold} is not \code{NULL}
#'   and there are \eqn{J} observations with \eqn{k} estimates above
#'   \code{k_threshold} then when \code{loo} is called it will refit the
#'   original model \eqn{J} times, each time leaving out one of the \eqn{J}
#'   problematic observations. The pointwise contributions of these observations
#'   to the total ELPD are then computed directly and substituted for the
#'   previous estimates from these \eqn{J} observations that are stored in the
#'   object created by \code{loo}.
#'
#'   \strong{Note}: in the warning messages issued by \code{loo} about large
#'   Pareto \eqn{k} estimates we recommend setting \code{k_threshold} to at
#'   least \eqn{0.7}. There is a theoretical reason, explained in Vehtari,
#'   Gelman, and Gabry (2017), for setting the threshold to the stricter value
#'   of \eqn{0.5}, but in practice they find that errors in the LOO
#'   approximation start to increase non-negligibly when \eqn{k > 0.7}.
#'   }
#'
#' @seealso
#' \itemize{
#'   \item The new \href{http://mc-stan.org/loo/articles/}{\pkg{loo} package vignettes}
#'   and various \href{http://mc-stan.org/rstanarm/articles/}{\pkg{rstanarm} vignettes}
#'   for more examples using \code{loo} and related functions with \pkg{rstanarm} models.
#'   \item \code{\link[loo]{pareto-k-diagnostic}} in the \pkg{loo} package for
#'   more on Pareto \eqn{k} diagnostics.
#'   \item \code{\link{log_lik.stanreg}} to directly access the pointwise
#'   log-likelihood matrix.
#' }
#'
#' @examples
#' \donttest{
#' fit1 <- stan_glm(mpg ~ wt, data = mtcars)
#' fit2 <- stan_glm(mpg ~ wt + cyl, data = mtcars)
#'
#' # compare on LOOIC
#' # (for bigger models use as many cores as possible)
#' loo1 <- loo(fit1, cores = 2)
#' print(loo1)
#' loo2 <- loo(fit2, cores = 2)
#' print(loo2)
#'
#' # when comparing exactly two models, the reported 'elpd_diff'
#' # will be positive if the expected predictive accuracy for the
#' # second model is higher. the approximate standard error of the
#' # difference is also reported.
#' compare_models(loo1, loo2)
#' compare_models(loos = list(loo1, loo2)) # can also provide list
#'
#' # when comparing three or more models they are ordered by
#' # expected predictive accuracy. elpd_diff and se_diff are relative
#' # to the model with best elpd_loo (first row)
#' fit3 <- stan_glm(mpg ~ disp * as.factor(cyl), data = mtcars)
#' loo3 <- loo(fit3, cores = 2, k_threshold = 0.7)
#' compare_models(loo1, loo2, loo3)
#'
#' # setting detail=TRUE will also print model formulas
#' compare_models(loo1, loo2, loo3, detail=TRUE)
#'
#' # Computing model weights
#' model_list <- stanreg_list(fit1, fit2, fit3)
#' loo_model_weights(model_list, cores = 2) # can specify k_threshold=0.7 if necessary
#'
#' # if you have already computed loo then it's more efficient to pass a list
#' # of precomputed loo objects than a "stanreg_list", avoiding the need
#' # for loo_model_weights to call loo() internally
#' loo_list <- list(fit1 = loo1, fit2 = loo2, fit3 = loo3) # names optional (affects printing)
#' loo_model_weights(loo_list)
#'
#' # 10-fold cross-validation
#' (kfold1 <- kfold(fit1, K = 10))
#' kfold2 <- kfold(fit2, K = 10)
#' compare_models(kfold1, kfold2, detail=TRUE)
#'
#' # Cross-validation stratifying by a grouping variable
#' # (note: might get some divergence warnings with this model but
#' # this is just intended as a quick example of how to code this)
#' library(loo)
#' fit4 <- stan_lmer(mpg ~ disp + (1|cyl), data = mtcars)
#' table(mtcars$cyl)
#' folds_cyl <- kfold_split_stratified(K = 3, x = mtcars$cyl)
#' table(cyl = mtcars$cyl, fold = folds_cyl)
#' kfold4 <- kfold(fit4, K = 3, folds = folds_cyl)
#' }
#'
#' @importFrom loo loo loo.function loo.matrix
#'
loo.stanreg <-
  function(x,
           ...,
           cores = getOption("mc.cores", 1),
           save_psis = FALSE,
           k_threshold = NULL) {
    # PSIS-LOO for a stanreg model. When the caller supplies 'k_threshold',
    # observations whose Pareto k exceeds it are handled by reloo() (exact
    # refits). Otherwise a fixed threshold of 0.7 is only used to decide which
    # recommendation warning to emit.
    if (!used.sampling(x))
      STOP_sampling_only("loo")
    if (model_has_weights(x))
      recommend_exact_loo(reason = "model has weights")

    threshold_given <- !is.null(k_threshold)
    if (!threshold_given) {
      k_threshold <- 0.7
    } else {
      validate_k_threshold(k_threshold)
    }

    # chain membership of every draw, needed by loo::relative_eff
    chain_ids <- chain_id_for_loo(x)

    # PSIS-LOO from an explicit pointwise log-likelihood matrix (one column
    # per observation), with relative efficiencies computed from exp(log-lik).
    psis_loo_matrix <- function(ll_mat) {
      rel_eff <- loo::relative_eff(exp(ll_mat), chain_id = chain_ids,
                                   cores = cores)
      suppressWarnings(
        loo.matrix(ll_mat, r_eff = rel_eff, cores = cores,
                   save_psis = save_psis)
      )
    }

    if (is.stanjm(x)) {
      loo_x <- psis_loo_matrix(log_lik(x))
    } else if (is.stanmvreg(x)) {
      submodel_ll <- lapply(1:get_M(x), function(m) log_lik(x, m = m))
      loo_x <- psis_loo_matrix(do.call("cbind", submodel_ll))
    } else if (is_clogit(x)) {
      ll <- log_lik.stanreg(x)
      # columns whose log-lik does not vary across draws are dropped
      constant <- apply(ll, MARGIN = 2, FUN = function(col) sd(col) < 1e-15)
      if (any(constant)) {
        message(
          "The following strata were dropped from the ",
          "loo calculation because log-lik is constant: ",
          paste(which(constant), collapse = ", ")
        )
        ll <- ll[, !constant, drop = FALSE]
      }
      loo_x <- psis_loo_matrix(ll)
    } else if (is.stansurv(x) && x$has_quadrature) {
      loo_x <- psis_loo_matrix(log_lik.stanreg(x))
    } else {
      # function method: the log-likelihood is evaluated one observation at a
      # time from ll_fun()/ll_args() instead of materialising a matrix
      args <- ll_args(x)
      llfun <- ll_fun(x)
      r_eff <- loo::relative_eff(
        x = function(data_i, draws) exp(llfun(data_i, draws)),
        chain_id = chain_ids,
        data = args$data,
        draws = args$draws,
        cores = cores,
        ...
      )
      loo_x <- suppressWarnings(
        loo.function(llfun, data = args$data, draws = args$draws,
                     r_eff = r_eff, ..., cores = cores,
                     save_psis = save_psis)
      )
    }

    flagged <- loo::pareto_k_ids(loo_x, k_threshold)
    n_flagged <- length(flagged)

    result <- structure(
      loo_x,
      name = deparse(substitute(x)),
      discrete = is_discrete(x),
      yhash = hash_y(x),
      formula = loo_model_formula(x)
    )

    if (n_flagged == 0) {
      if (threshold_given) {
        message(
          "All pareto_k estimates below user-specified threshold of ",
          k_threshold,
          ". \nReturning loo object."
        )
      }
      return(result)
    }

    if (!threshold_given) {
      # no refitting without an explicit threshold; just advise the user
      if (n_flagged > 10) recommend_kfold(n_flagged) else recommend_reloo(n_flagged)
      return(result)
    }

    structure(
      reloo(x, loo_x, obs = flagged),
      name = attr(result, "name"),
      discrete = attr(result, "discrete"),
      yhash = attr(result, "yhash"),
      formula = loo_model_formula(x)
    )
  }

# WAIC
#
#' @rdname loo.stanreg
#' @export
#' @importFrom loo waic waic.function waic.matrix
#'
waic.stanreg <- function(x, ...) {
  # WAIC for a stanreg model, tagged with the same attributes as loo objects
  # so the result can be passed to compare_models().
  if (!used.sampling(x))
    STOP_sampling_only("waic")

  # Model types whose pointwise log-likelihood is used as a precomputed
  # matrix; every other model goes through the function method.
  ll_mat <- NULL
  if (is.stanjm(x)) {
    ll_mat <- log_lik(x)
  } else if (is.stanmvreg(x)) {
    per_submodel <- lapply(1:get_M(x), function(m) log_lik(x, m = m))
    ll_mat <- do.call("cbind", per_submodel)
  } else if (is_clogit(x) || (is.stansurv(x) && x$has_quadrature)) {
    ll_mat <- log_lik(x)
  }

  waic_x <- if (!is.null(ll_mat)) {
    waic.matrix(ll_mat)
  } else {
    args <- ll_args(x)
    waic.function(ll_fun(x), data = args$data, draws = args$draws)
  }

  structure(
    waic_x,
    class = c("waic", "loo"),
    name = deparse(substitute(x)),
    discrete = is_discrete(x),
    yhash = hash_y(x),
    formula = loo_model_formula(x)
  )
}


# K-fold CV
#
#' @rdname loo.stanreg
#' @export
#' @param K For \code{kfold}, the number of subsets (folds)
#'   into which the data will be partitioned for performing
#'   \eqn{K}-fold cross-validation. The model is refit \code{K} times, each time
#'   leaving out one of the \code{K} folds. If \code{K} is equal to the total
#'   number of observations in the data then \eqn{K}-fold cross-validation is
#'   equivalent to exact leave-one-out cross-validation.
#' @param save_fits For \code{kfold}, if \code{TRUE}, a component \code{'fits'}
#'   is added to the returned object to store the cross-validated
#'   \link[=stanreg-objects]{stanreg} objects and the indices of the omitted
#'   observations for each fold. Defaults to \code{FALSE}.
#' @param folds For \code{kfold}, an optional integer vector with one element
#'   per observation in the data used to fit the model. Each element of the
#'   vector is an integer in \code{1:K} indicating to which of the \code{K}
#'   folds the corresponding observation belongs. There are some convenience
#'   functions available in the \pkg{loo} package that create integer vectors to
#'   use for this purpose (see the \strong{Examples} section below and also the
#'   \link[loo]{kfold-helpers} page).
#'
#'   If \code{folds} is not specified then the default is to call
#'   \code{loo::\link[loo]{kfold_split_random}} to randomly partition the data
#'   into \code{K} subsets of equal (or as close to equal as possible) size.
#'
#' @return \code{kfold} returns an object with classes 'kfold' and 'loo' that
#'   has a similar structure as the objects returned by the \code{loo} and
#'   \code{waic} methods.
#'
#' @section K-fold CV: The \code{kfold} function performs exact \eqn{K}-fold
#'   cross-validation. First the data are randomly partitioned into \eqn{K}
#'   subsets of equal (or as close to equal as possible) size (unless the folds
#'   are specified manually). Then the model is refit \eqn{K} times, each time
#'   leaving out one of the \eqn{K} subsets. If \eqn{K} is equal to the total
#'   number of observations in the data then \eqn{K}-fold cross-validation is
#'   equivalent to exact leave-one-out cross-validation (to which \code{loo} is
#'   an efficient approximation). The \code{compare_models} function is also
#'   compatible with the objects returned by \code{kfold}.
#'
kfold <- function(x, K = 10, save_fits = FALSE, folds = NULL) {
  # Exact K-fold cross-validation: refit 'x' K times, each time without the
  # observations of one fold, and score those held-out observations with the
  # refitted posterior. Returns a list of class c("kfold", "loo").
  validate_stanreg_object(x)
  stopifnot(K > 1, K <= nobs(x))
  if (!used.sampling(x)) {
    STOP_sampling_only("kfold")
  }
  if (is.stanmvreg(x)) {
    STOP_if_stanmvreg("kfold")
  }
  if (model_has_weights(x)) {
    stop("kfold is not currently available for models fit using weights.")
  }

  # data rows actually used in the fit (see kfold_and_reloo_data())
  d <- kfold_and_reloo_data(x)
  N <- nrow(d)
  K <- as.integer(K)

  if (is.null(folds)) {
    folds <- loo::kfold_split_random(K = K, N = N)
  } else {
    # user-supplied folds: one whole-number label in 1:K per row of 'd', and
    # every fold label must be used at least once
    stopifnot(
      length(folds) == N,
      all(folds == as.integer(folds)),
      all(folds %in% 1L:K),
      all(1:K %in% folds)
    )
    folds <- as.integer(folds)
  }

  lppds <- list()
  # K x 2 list-matrix holding each refitted model and its omitted row indices
  fits <- array(list(), c(K, 2), list(NULL, c("fit", "omitted")))
  for (k in 1:K) {
    message("Fitting model ", k, " out of ", K)
    omitted <- which(folds == k)
    # Build (without evaluating) the refit call on the retained rows.
    # NOTE(review): 'data' and 'subset' appear to be stored as unevaluated
    # expressions that are eval()'d below in this function's frame, so they
    # rely on the local names 'd' and 'omitted' -- confirm in update.stanreg().
    fit_k_call <- update.stanreg(
      object = x,
      data = d[-omitted,, drop=FALSE],
      subset = rep(TRUE, nrow(d) - length(omitted)),
      weights = NULL,
      refresh = 0,
      open_progress = FALSE,
      evaluate = FALSE
    )
    if (!is.null(getCall(x)$offset)) {
      fit_k_call$offset <- x$offset[-omitted]
    }
    # stansurv models are refit without a 'subset' argument
    fit_k_call$subset <- if (!is.stansurv(x)) eval(fit_k_call$subset) else NULL
    fit_k_call$data <- eval(fit_k_call$data)
    # refit; any printed output from the fitting call is discarded
    capture.output(
      fit_k <- eval(fit_k_call)
    )

    # log-likelihood of the held-out rows under the refitted posterior
    # (one column per omitted observation)
    lppds[[k]] <-
      log_lik.stanreg(
        fit_k,
        newdata = d[omitted, , drop = FALSE],
        offset = x$offset[omitted],
        newx = get_x(x)[omitted, , drop = FALSE],
        newz = x$z[omitted, , drop = FALSE], # NULL other than for some stan_betareg models
        stanmat = as.matrix.stanreg(fit_k)
      )
    if (save_fits) {
      fits[k, ] <- list(fit = fit_k, omitted = omitted)
    }
  }
  # per-observation elpd: log of the mean (over draws) of exp(log-lik),
  # concatenated fold by fold
  elpds_unord <- unlist(lapply(lppds, function(x) {
    apply(x, 2, log_mean_exp)
  }))

  # make sure elpds are put back in the right order
  obs_order <- unlist(lapply(1:K, function(k) which(folds == k)))
  elpds <- rep(NA, length(elpds_unord))
  elpds[obs_order] <- elpds_unord

  out <- list(
    elpd_kfold = sum(elpds),
    se_elpd_kfold = sqrt(N * var(elpds)),
    pointwise = cbind(elpd_kfold = elpds)
  )

  # for compatibility with new structure of loo package objects
  out$estimates <- cbind(Estimate = out$elpd_kfold, SE = out$se_elpd_kfold)
  rownames(out$estimates) <- c("elpd_kfold")

  if (save_fits) {
    out$fits <- fits
  }

  structure(out,
            class = c("kfold", "loo"),
            K = K,
            name = deparse(substitute(x)),
            discrete = is_discrete(x),
            yhash = hash_y(x),
            formula = loo_model_formula(x))
}

#' Various print methods
#'
#' @keywords internal
#' @export
#' @method print kfold
#' @param x,digits,... See \code{\link{print}}.
print.kfold <- function(x, digits = 1, ...) {
  # Header such as "10-fold cross-validation", then a one-row table with the
  # elpd estimate and its standard error. Returns 'x' invisibly.
  fold_label <- paste0(attr(x, "K"), "-fold")
  cat("\n", fold_label, "cross-validation\n\n")
  summary_tab <- data.frame(
    Estimate = x$elpd_kfold,
    SE = x$se_elpd_kfold,
    row.names = "elpd_kfold"
  )
  .printfr(summary_tab, digits)
  invisible(x)
}


#' @rdname loo.stanreg
#' @export
#'
#' @param loos For \code{compare_models}, a list of two or more objects returned
#'   by the \code{loo}, \code{kfold}, or \code{waic} method. This argument can
#'   be used as an alternative to passing these objects via \code{...}.
#' @param detail For \code{compare_models}, if \code{TRUE} then extra
#'   information about each model (currently just the model formulas) will be
#'   printed with the output.
#'
#' @return \code{compare_models} returns a vector or matrix with class
#'   'compare.loo'. See the \strong{Comparing models} section below for more
#'   details.
#'
#' @section Comparing models: \code{compare_models} is a method for the
#'   \code{\link[loo]{compare}} function in the \pkg{loo} package that
#'   performs some extra checks to make sure the \pkg{rstanarm} models are
#'   suitable for comparison. These extra checks include verifying that all
#'   models to be compared were fit using the same outcome variable and
#'   likelihood family.
#'
#'   If exactly two models are being compared then \code{compare_models} returns
#'   a vector containing the difference in expected log predictive density
#'   (ELPD) between the models and the standard error of that difference (the
#'   documentation for \code{\link[loo]{compare}} in the \pkg{loo}
#'   package has additional details about the calculation of the standard error
#'   of the difference). The difference in ELPD will be negative if the expected
#'   out-of-sample predictive accuracy of the first model is higher. If the
#'   difference is positive then the second model is preferred.
#'
#'   If more than two models are being compared then \code{compare_models}
#'   returns a matrix with one row per model. This matrix summarizes the objects
#'   and arranges them in descending order according to expected out-of-sample
#'   predictive accuracy. That is, the first row of the matrix will be
#'   for the model with the largest ELPD (smallest LOOIC).
#'   The columns containing the ELPD difference and the standard error of the
#'   difference contain values relative to the model with the best ELPD.
#'   See the \strong{Details} section at the \code{\link[loo]{compare}}
#'   page in the \pkg{loo} package for more information.
#'
compare_models <- function(..., loos = list(), detail = FALSE) {
  # Validate the loo/waic/kfold objects (passed either through '...' or as
  # the list 'loos', never both) and compare them with loo::compare(). Model
  # names are kept as an attribute. Formulas are kept only if 'detail' is TRUE.
  from_dots <- list(...)
  have_dots <- length(from_dots) > 0
  have_list <- length(loos) > 0
  if (have_dots && have_list) {
    stop("'...' and 'loos' can't both be specified.", call. = FALSE)
  }
  if (have_dots) {
    loos <- from_dots
  } else {
    stopifnot(is.list(loos))
  }

  checked <- validate_loos(loos)
  comparison <- loo::compare(x = checked)
  model_formulas <- if (detail) lapply(checked, attr, "formula") else NULL
  structure(
    comparison,
    class = c("compare_rstanarm_loos", class(comparison)),
    model_names = names(checked),
    formulas = model_formulas
  )
}

#' @rdname print.kfold
#' @keywords internal
#' @export
#' @method print compare_rstanarm_loos
print.compare_rstanarm_loos <- function(x, ...) {
  # Optionally list each model's formula, print a header explaining how to
  # read the comparison, then print the underlying "compare.loo" object.
  # Returns 'x' invisibly.
  model_formulas <- attr(x, "formulas")
  model_names <- attr(x, "model_names")
  if (!is.null(model_formulas)) {
    cat("Model formulas: ")
    for (i in seq_len(NROW(x))) {
      cat("\n ", paste0(model_names[i], ": "),
          formula_string(model_formulas[[i]]))
    }
    cat("\n")
  }

  header <- if (NROW(x) == 2) {
    "\n(negative 'elpd_diff' favors 1st model, positive favors 2nd) \n\n"
  } else {
    "\n(ordered by highest ELPD)\n\n"
  }
  cat("\nModel comparison: ")
  cat(header)

  as_compare_loo <- x
  class(as_compare_loo) <- "compare.loo"
  print(as_compare_loo, ...)

  invisible(x)
}


#' @rdname loo.stanreg
#' @aliases loo_model_weights
#'
#' @importFrom loo loo_model_weights
#' @export loo_model_weights
#'
#' @export
#'
#'
#' @section Model weights: The \code{loo_model_weights} method can be used to
#'   compute model weights for a \code{"stanreg_list"} object, which is a list
#'   of fitted model objects made with \code{\link{stanreg_list}}. The end of
#'   the \strong{Examples} section has a demonstration. For details see the
#'   \code{\link[loo]{loo_model_weights}} documentation in the \pkg{loo}
#'   package.
#'
loo_model_weights.stanreg_list <-
  function(x,
           ...,
           cores = getOption("mc.cores", 1),
           k_threshold = NULL) {
    # Run loo() on every fitted model in the stanreg_list, hand the resulting
    # list (plus '...') to the loo package's default method, and label the
    # weights with the names of 'x'.
    per_model_loo <- lapply(seq_along(x), function(j) {
      loo.stanreg(x[[j]], cores = cores, k_threshold = k_threshold)
    })
    model_wts <- loo::loo_model_weights.default(x = per_model_loo, ...)
    setNames(model_wts, names(x))
  }

# internal ----------------------------------------------------------------
# Check a user-supplied 'k_threshold' for loo().
#
# @param k candidate threshold for flagging Pareto k estimates.
# @return NULL (invisibly) when valid. Errors if 'k' is not a single numeric
#   value, is NA/NaN, or is negative. Warns (but accepts) values above 1.
validate_k_threshold <- function(k) {
  if (!is.numeric(k) || length(k) != 1) {
    stop("'k_threshold' must be a single numeric value.",
         call. = FALSE)
  } else if (is.na(k)) {
    # without this check 'k < 0' below would fail with an uninformative
    # "missing value where TRUE/FALSE needed" error
    stop("'k_threshold' must not be NA.",
         call. = FALSE)
  } else if (k < 0) {
    stop("'k_threshold' < 0 not allowed.",
         call. = FALSE)
  } else if (k > 1) {
    warning(
      "Setting 'k_threshold' > 1 is not recommended.",
      "\nFor details see the PSIS-LOO section in help('loo-package', 'loo').",
      call. = FALSE
    )
  }
}
recommend_kfold <- function(n) {
  # Warn that 'n' observations exceed pareto_k 0.7 and suggest 10-fold CV.
  msg <- paste0(
    "Found ", n, " observations with a pareto_k > 0.7. ",
    "With this many problematic observations we recommend calling ",
    "'kfold' with argument 'K=10' to perform 10-fold cross-validation ",
    "rather than LOO.\n"
  )
  warning(msg, call. = FALSE)
}
recommend_reloo <- function(n) {
  # Warn that 'n' observations exceed pareto_k 0.7 and suggest rerunning loo()
  # with k_threshold = 0.7 so those observations are refit exactly.
  msg <- paste0(
    "Found ", n, " observation(s) with a pareto_k > 0.7. ",
    "We recommend calling 'loo' again with argument 'k_threshold = 0.7' ",
    "in order to calculate the ELPD without the assumption that ",
    "these observations are negligible. ",
    "This will refit the model ", n,
    " times to compute the ELPDs for the problematic observations directly.\n"
  )
  warning(msg, call. = FALSE)
}
recommend_exact_loo <- function(reason) {
  # Abort loo() for an unsupported model ('reason' completes the sentence) and
  # point the user to kfold() with K = number of observations (exact LOO-CV).
  msg <- paste0(
    "'loo' is not supported if ", reason, ". ",
    "If refitting the model 'nobs(x)' times is feasible, ",
    "we recommend calling 'kfold' with K equal to the ",
    "total number of observations in the data to perform exact LOO-CV.\n"
  )
  stop(msg, call. = FALSE)
}


# Refit model leaving out specific observations
#
# @param x stanreg object
# @param loo_x the result of loo(x)
# @param obs vector of observation indexes. the model will be refit length(obs)
#   times, each time leaving out one of the observations specified in 'obs'.
#   For clogit models the whole stratum containing the observation is left
#   out. If 'obs' is empty, 'loo_x' is returned unchanged.
# @param ... unused currently
# @param refit logical, to toggle whether refitting actually happens (only used
#   to avoid refitting in tests). When FALSE, NULL is returned after the
#   informational message.
#
# @return A modified version of 'loo_x'. For the 'obs' rows, the pointwise
#   elpd_loo, p_loo and looic are replaced by values from the refitted models
#   and their pareto_k diagnostics are set to NA. The summary estimates and
#   standard errors are then recomputed.
# @importFrom utils capture.output
reloo <- function(x, loo_x, obs, ..., refit = TRUE) {
  if (is.stanmvreg(x))
    STOP_if_stanmvreg("reloo")
  stopifnot(!is.null(x$data), is.loo(loo_x))

  J <- length(obs)
  if (J == 0) {
    # Nothing to refit. Previously the loop ran over 1:0 == c(1, 0) and
    # attempted a refit with obs[1] (NA) as the omitted index.
    return(loo_x)
  }
  d <- kfold_and_reloo_data(x)
  lls <- vector("list", J)
  message(
    J, " problematic observation(s) found.",
    "\nModel will be refit ", J, " times."
  )

  if (!refit)
    return(NULL)

  for (j in seq_len(J)) {
    message(
      "\nFitting model ", j, " out of ", J,
      " (leaving out observation ", obs[j], ")"
    )
    omitted <- obs[j]

    if (is_clogit(x)) {
      # leave out every row in the same stratum as obs[j]
      strata_id <- model.weights(model.frame(x))
      omitted <- which(strata_id == strata_id[obs[j]])
    }

    # NOTE(review): the 'data'/'subset' expressions below appear to be kept
    # unevaluated by update() and are eval()'d in this frame, so they rely on
    # the local names 'd' and 'omitted' -- keep these names unchanged.
    fit_j_call <-
      update(
        x,
        data = d[-omitted, , drop = FALSE],
        subset = rep(TRUE, nrow(d) - length(omitted)),
        evaluate = FALSE,
        refresh = 0,
        open_progress = FALSE
      )
    fit_j_call$subset <- if (!is.stansurv(x)) eval(fit_j_call$subset) else NULL
    fit_j_call$data <- eval(fit_j_call$data)
    if (!is.null(getCall(x)$offset)) {
      fit_j_call$offset <- x$offset[-omitted]
    }
    capture.output(
      fit_j <- suppressWarnings(eval(fit_j_call))
    )

    lls[[j]] <-
      log_lik.stanreg(
        fit_j,
        newdata = d[omitted, , drop = FALSE],
        offset = x$offset[omitted],
        newx = get_x(x)[omitted, , drop = FALSE],
        newz = x$z[omitted, , drop = FALSE], # NULL other than for some stan_betareg models
        stanmat = as.matrix.stanreg(fit_j)
      )
  }

  # compute elpd_{loo,j} for each of the held out observations
  elpd_loo <- unlist(lapply(lls, log_mean_exp))

  # compute \hat{lpd}_j for each of the held out observations (using log-lik
  # matrix from full posterior, not the leave-one-out posteriors)
  ll_x <- log_lik(
    object = x,
    newdata = d[obs,, drop=FALSE],
    offset = x$offset[obs]
  )
  hat_lpd <- apply(ll_x, 2, log_mean_exp)

  # compute effective number of parameters
  p_loo <- hat_lpd - elpd_loo

  # replace parts of the loo object with these computed quantities
  sel <- c("elpd_loo", "p_loo", "looic")
  loo_x$pointwise[obs, sel] <- cbind(elpd_loo, p_loo,  -2 * elpd_loo)
  loo_x$estimates[sel, "Estimate"] <- with(loo_x, colSums(pointwise[, sel]))
  loo_x$estimates[sel, "SE"] <- with(loo_x, {
    N <- nrow(pointwise)
    sqrt(N * apply(pointwise[, sel], 2, var))
  })
  loo_x$diagnostics$pareto_k[obs] <- NA

  return(loo_x)
}

# Numerically stable log(exp(a) + exp(b)) (all elements of 'a' and 'b').
# If the maximum is not finite it is returned as-is: -Inf when everything is
# -Inf (log of 0), Inf when the maximum is Inf, NA/NaN propagated. The shift
# by the maximum would otherwise produce NaN (-Inf - -Inf or Inf - Inf).
log_sum_exp2 <- function(a, b) {
  m <- max(a, b)
  if (!is.finite(m)) {
    return(m)
  }
  m + log(sum(exp(c(a, b) - m)))
}

# Numerically stable log(sum(exp(x))).
#
# @param x numeric vector
# @return log(sum(exp(x))). If max(x) is not finite it is returned directly:
#   -Inf when every element is -Inf (log of 0), Inf when the maximum is Inf,
#   and NA/NaN propagated. Shifting by max(x) would otherwise produce NaN
#   from -Inf - -Inf or Inf - Inf.
log_sum_exp <- function(x) {
  max_x <- max(x)
  if (!is.finite(max_x)) {
    return(max_x)
  }
  max_x + log(sum(exp(x - max_x)))
}

# Numerically stable log(mean(exp(x))): the log-sum-exp of 'x' minus the log
# of the number of elements.
# @param x numeric vector
log_mean_exp <- function(x) {
  n_terms <- length(x)
  log_sum_exp(x) - log(n_terms)
}

# Get correct data to use for kfold and reloo
#
# Rebuilds the data frame of observations used to fit 'x'. It keeps the
# variables referenced by the model formula (all columns when the formula
# uses '.'), applies the original 'subset' argument if one was given, and
# drops incomplete rows. For clogit models the strata column is overwritten
# with the strata ids taken from the model frame weights.
#
# @param x stanreg object
# @return data frame
kfold_and_reloo_data <- function(x) {
  # either data frame or environment
  d <- x[["data"]]

  sub <- getCall(x)[["subset"]]
  if (!is.null(sub)) {
    # Evaluate the subset expression the way model.frame() does: variables
    # are looked up in the data first, then in the environment of the model
    # formula. Previously this function's own frame was the enclosure, so
    # locals such as 'x' or 'd' could shadow the user's objects, and
    # variables local to the function that fit the model were not found.
    enclos <- environment(formula(x))
    if (!is.environment(enclos)) {
      enclos <- environment()
    }
    keep <- eval(sub, envir = d, enclos = enclos)
  }

  if (is.environment(d)) {
    # make data frame
    d <- get_all_vars(formula(x), data = d)
  } else {
    # already a data frame
    all_vars <- all.vars(formula(x))
    if ("." %in% all_vars) {
      all_vars <- seq_len(ncol(d))
    }
    d <- d[, all_vars, drop = FALSE]
  }

  if (!is.null(sub)) {
    d <- d[keep, , drop = FALSE]
  }

  d <- na.omit(d)

  if (is_clogit(x)) {
    strata_var <- as.character(getCall(x)$strata)
    d[[strata_var]] <- model.weights(model.frame(x))
  }

  return(d)
}


# Calculate a SHA1 hash of the outcome of a fitted model, with all attributes
# stripped so only the values are hashed. Used to check that compared models
# share the same y.
# @param x stanreg object
# @param ... Passed to digest::sha1
#
hash_y <- function(x, ...) {
  has_digest <- requireNamespace("digest", quietly = TRUE)
  if (!has_digest)
    stop("Please install the 'digest' package.")
  validate_stanreg_object(x)
  y_values <- get_y(x)
  attributes(y_values) <- NULL
  digest::sha1(x = y_values, ...)
}

# check if discrete or continuous
# polr models are always discrete and stansurv models never are. Otherwise a
# family counts as discrete if is.binomial(), is.poisson() or is.nb() accepts
# it. For stanmvreg models one flag is returned per element of the fetched
# families.
# @param object stanreg object
is_discrete <- function(object) {
  discrete_family <- function(fam) {
    is.binomial(fam) || is.poisson(fam) || is.nb(fam)
  }
  if (inherits(object, "polr")) {
    return(TRUE)
  }
  if (inherits(object, "stansurv")) {
    return(FALSE)
  }
  if (inherits(object, "stanmvreg")) {
    return(sapply(fetch(family(object), "family"), discrete_family))
  }
  discrete_family(family(object)$family)
}

# Class predicates for loo-style objects. 'kfold' and 'waic' objects must
# also carry the 'loo' class.
is.loo <- function(x) {
  inherits(x, "loo")
}
is.kfold <- function(x) {
  inherits(x, "kfold") && is.loo(x)
}
is.waic <- function(x) {
  inherits(x, "waic") && is.loo(x)
}

# validate objects for model comparison
# Requires at least two 'loo' objects, all of the same kind (loo, waic or
# kfold), with identical y hashes and the same discrete/continuous flag.
# Returns the list named by each object's "name" attribute.
validate_loos <- function(loos = list()) {
  if (length(loos) < 2)
    stop("At least two objects are required for model comparison.",
         call. = FALSE)

  if (!all(sapply(loos, is.loo)))
    stop("All objects must have class 'loo'", call. = FALSE)

  waic_flags <- sapply(loos, is.waic)
  kfold_flags <- sapply(loos, is.kfold)
  is_mixed <- function(flags) any(flags) && !all(flags)
  if (is_mixed(waic_flags) || is_mixed(kfold_flags))
    stop("Can't mix objects computed using 'loo', 'waic', and 'kfold'.",
         call. = FALSE)

  hashes <- lapply(loos, attr, which = "yhash")
  same_y <- sapply(hashes, function(h) isTRUE(all.equal(h, hashes[[1]])))
  if (!all(same_y))
    stop("Not all models have the same y variable.", call. = FALSE)

  discrete <- sapply(loos, attr, which = "discrete")
  if (!all(discrete == discrete[1]))
    stop("Discrete and continuous observation models can't be compared.",
         call. = FALSE)

  model_names <- lapply(loos, attr, which = "name")
  setNames(loos, nm = model_names)
}


# chain_id to pass to loo::relative_eff
# The first two dimensions of the stanfit are taken as (iterations, chains).
# The result labels every draw with its chain: each chain id is repeated once
# per iteration.
chain_id_for_loo <- function(object) {
  sizes <- dim(object$stanfit)
  iters_per_chain <- sizes[[1]]
  chain_count <- sizes[[2]]
  rep(1:chain_count, each = iters_per_chain)
}


# model formula to store in loo object
# Returns formula(x), or the string "formula not found" if that call errors
# or returns NULL.
# @param x stanreg object
loo_model_formula <- function(x) {
  form <- tryCatch(formula(x), error = function(e) NULL)
  if (is.null(form)) "formula not found" else form
}
# csetraynor/aeim documentation built on May 15, 2019, 6:25 p.m.