R/kfold.R

Defines functions kfold_predict .kfold kfold.brmsfit

Documented in kfold.brmsfit kfold_predict

#' K-Fold Cross-Validation
#'
#' Perform exact K-fold cross-validation by refitting the model \eqn{K}
#' times each leaving out one-\eqn{K}th of the original data.
#' Folds can be run in parallel using the \pkg{future} package.
#'
#' @aliases kfold
#'
#' @inheritParams loo.brmsfit
#' @param K The number of subsets of equal (if possible) size
#'   into which the data will be partitioned for performing
#'   \eqn{K}-fold cross-validation. The model is refit \code{K} times, each time
#'   leaving out one of the \code{K} subsets. If \code{K} is equal to the total
#'   number of observations in the data then \eqn{K}-fold cross-validation is
#'   equivalent to exact leave-one-out cross-validation.
#' @param Ksub Optional number of subsets (of those subsets defined by \code{K})
#'   to be evaluated. If \code{NULL} (the default), \eqn{K}-fold cross-validation
#'   will be performed on all subsets. If \code{Ksub} is a single integer,
#'   \code{Ksub} subsets (out of all \code{K}) subsets will be randomly chosen.
#'   If \code{Ksub} consists of multiple integers or a one-dimensional array
#'   (created via \code{as.array}) potentially of length one, the corresponding
#'   subsets will be used. This argument is primarily useful, if evaluation of
#'   all subsets is infeasible for some reason.
#' @param folds Determines how the subsets are being constructed.
#'   Possible values are \code{NULL} (the default), \code{"stratified"},
#'   \code{"grouped"}, or \code{"loo"}. May also be a vector of length
#'   equal to the number of observations in the data. Alters the way
#'   \code{group} is handled. More information is provided in the 'Details'
#'   section.
#' @param group Optional name of a grouping variable or factor in the model.
#'   What exactly is done with this variable depends on argument \code{folds}.
#'   More information is provided in the 'Details' section.
#' @param exact_loo Deprecated! Please use \code{folds = "loo"} instead.
#' @param save_fits If \code{TRUE}, a component \code{fits} is added to
#'   the returned object to store the cross-validated \code{brmsfit}
#'   objects and the indices of the omitted observations for each fold.
#'   Defaults to \code{FALSE}.
#' @param recompile Logical, indicating whether the Stan model should be
#'   recompiled. This may be necessary if you are running \code{reloo} on
#'   another machine than the one used to fit the model.
#' @param future_args A list of further arguments passed to
#'   \code{\link[future:future]{future}} for additional control over parallel
#'   execution if activated.
#' @param ... Further arguments passed to \code{\link{brm}}.
#'
#' @return \code{kfold} returns an object that has a similar structure as the
#'   objects returned by the \code{loo} and \code{waic} methods and
#'   can be used with the same post-processing functions.
#'
#' @details The \code{kfold} function performs exact \eqn{K}-fold
#'   cross-validation. First the data are partitioned into \eqn{K} folds
#'   (i.e. subsets) of equal (or as close to equal as possible) size by default.
#'   Then the model is refit \eqn{K} times, each time leaving out one of the
#'   \code{K} subsets. If \eqn{K} is equal to the total number of observations
#'   in the data then \eqn{K}-fold cross-validation is equivalent to exact
#'   leave-one-out cross-validation (to which \code{loo} is an efficient
#'   approximation). The \code{compare_ic} function is also compatible with
#'   the objects returned by \code{kfold}.
#'
#'   The subsets can be constructed in multiple different ways:
#'   \itemize{
#'   \item If both \code{folds} and \code{group} are \code{NULL}, the subsets
#'   are randomly chosen so that they have equal (or as close to equal as
#'   possible) size.
#'   \item If \code{folds} is \code{NULL} but \code{group} is specified, the
#'   data is split up into subsets, each time omitting all observations of one
#'   of the factor levels, while ignoring argument \code{K}.
#'   \item If \code{folds = "stratified"} the subsets are stratified after
#'   \code{group} using \code{\link[loo:kfold-helpers]{loo::kfold_split_stratified}}.
#'   \item If \code{folds = "grouped"} the subsets are split by
#'   \code{group} using \code{\link[loo:kfold-helpers]{loo::kfold_split_grouped}}.
#'   \item If \code{folds = "loo"} exact leave-one-out cross-validation
#'   will be performed and \code{K} will be ignored. Further, if \code{group}
#'   is specified, all observations corresponding to the factor level of the
#'   currently predicted single value are omitted. Thus, in this case, the
#'   predicted values are only a subset of the omitted ones.
#'   \item If \code{folds} is a numeric vector, it must contain one element per
#'   observation in the data. Each element of the vector is an integer in
#'   \code{1:K} indicating to which of the \code{K} folds the corresponding
#'   observation belongs. There are some convenience functions available in
#'   the \pkg{loo} package that create integer vectors to use for this purpose
#'   (see the Examples section below and also the
#'   \link[loo:kfold-helpers]{kfold-helpers} page).
#'   }
#'
#'   When running \code{kfold} on a \code{brmsfit} created with the
#'   \pkg{cmdstanr} backend in a different \R session, several recompilations
#'   will be triggered because by default, \pkg{cmdstanr} writes the model
#'   executable to a temporary directory. To avoid that, set option
#'   \code{"cmdstanr_write_stan_file_dir"} to a nontemporary path of your choice
#'   before creating the original \code{brmsfit} (see section 'Examples' below).
#'
#' @examples
#' \dontrun{
#' fit1 <- brm(count ~ zAge + zBase * Trt + (1|patient) + (1|obs),
#'            data = epilepsy, family = poisson())
#' # throws warning about some pareto k estimates being too high
#' (loo1 <- loo(fit1))
#' # perform 10-fold cross validation
#' (kfold1 <- kfold(fit1, chains = 1))
#'
#' # use the future package for parallelization
#' library(future)
#' plan(multiprocess)
#' kfold(fit1, chains = 1)
#'
#' ## to avoid recompilations when running kfold() on a 'cmdstanr'-backend fit
#' ## in a fresh R session, set option 'cmdstanr_write_stan_file_dir' before
#' ## creating the initial 'brmsfit'
#' ## CAUTION: the following code creates some files in the current working
#' ## directory: two 'model_<hash>.stan' files, one 'model_<hash>(.exe)'
#' ## executable, and one 'fit_cmdstanr_<some_number>.rds' file
#' set.seed(7)
#' fname <- paste0("fit_cmdstanr_", sample.int(.Machine$integer.max, 1))
#' options(cmdstanr_write_stan_file_dir = getwd())
#' fit_cmdstanr <- brm(rate ~ conc + state,
#'                     data = Puromycin,
#'                     backend = "cmdstanr",
#'                     file = fname)
#' # now restart the R session and run the following (after attaching 'brms')
#' set.seed(7)
#' fname <- paste0("fit_cmdstanr_", sample.int(.Machine$integer.max, 1))
#' fit_cmdstanr <- brm(rate ~ conc + state,
#'                     data = Puromycin,
#'                     backend = "cmdstanr",
#'                     file = fname)
#' kfold_cmdstanr <- kfold(fit_cmdstanr, K = 2)
#' }
#'
#' @seealso \code{\link{loo}}, \code{\link{reloo}}
#'
#' @importFrom loo kfold
#' @export kfold
#' @export
kfold.brmsfit <- function(x, ..., K = 10, Ksub = NULL, folds = NULL,
                          group = NULL, exact_loo = NULL, compare = TRUE,
                          resp = NULL, model_names = NULL, save_fits = FALSE,
                          recompile = NULL, future_args = list()) {
  args <- split_dots(x, ..., model_names = model_names)
  use_stored <- ulapply(args$models, function(x) is_equal(x$kfold$K, K))
  if (!is.null(exact_loo) && as_one_logical(exact_loo)) {
    warning2("'exact_loo' is deprecated. Please use folds = 'loo' instead.")
    folds <- "loo"
  }
  c(args) <- nlist(
    criterion = "kfold", K, Ksub, folds, group, compare,
    resp, save_fits, recompile, future_args, use_stored
  )
  do_call(compute_loolist, args)
}

# helper function to perform k-fold cross-validation
# @inheritParams kfold.brmsfit
# @param model_name ignored but included to avoid being passed to '...'
.kfold <- function(x, K, Ksub, folds, group, save_fits,
                   newdata, resp, model_name, recompile = NULL,
                   future_args = list(), newdata2 = NULL, ...) {
  stopifnot(is.brmsfit(x), is.list(future_args))
  if (is.brmsfit_multiple(x)) {
    warn_brmsfit_multiple(x)
    class(x) <- "brmsfit"
  }
  if (is.null(newdata)) {
    newdata <- x$data
  } else {
    newdata <- as.data.frame(newdata)
  }
  if (is.null(newdata2)) {
    newdata2 <- x$data2
  } else {
    bterms <- brmsterms(x$formula)
    newdata2 <- validate_data2(newdata2, bterms)
  }
  N <- nrow(newdata)
  # validate argument 'group'
  if (!is.null(group)) {
    valid_groups <- get_cat_vars(x)
    if (length(group) != 1L || !group %in% valid_groups) {
      stop2("Group '", group, "' is not a valid grouping factor. ",
            "Valid groups are: \n", collapse_comma(valid_groups))
    }
    gvar <- factor(get(group, newdata))
  }
  # validate argument 'folds'
  if (is.null(folds)) {
    if (is.null(group)) {
      fold_type <- "random"
      folds <- loo::kfold_split_random(K, N)
    } else {
      fold_type <- "group"
      folds <- as.numeric(gvar)
      K <- length(levels(gvar))
      message("Setting 'K' to the number of levels of '", group, "' (", K, ")")
    }
  } else if (is.character(folds) && length(folds) == 1L) {
    opts <- c("loo", "stratified", "grouped")
    fold_type <- match.arg(folds, opts)
    req_group_opts <- c("stratified", "grouped")
    if (fold_type %in% req_group_opts && is.null(group)) {
      stop2("Argument 'group' is required for fold type '", fold_type, "'.")
    }
    if (fold_type == "loo") {
      folds <- seq_len(N)
      K <- N
      message("Setting 'K' to the number of observations (", K, ")")
    } else if (fold_type == "stratified") {
      folds <- loo::kfold_split_stratified(K, gvar)
    } else if (fold_type == "grouped") {
      folds <- loo::kfold_split_grouped(K, gvar)
    }
  } else {
    fold_type <- "custom"
    folds <- as.numeric(factor(folds))
    if (length(folds) != N) {
      stop2("If 'folds' is a vector, it must be of length N.")
    }
    K <- max(folds)
    message("Setting 'K' to the number of folds (", K, ")")
  }
  # validate argument 'Ksub'
  if (is.null(Ksub)) {
    Ksub <- seq_len(K)
  } else {
    # see issue #441 for reasons to check for arrays
    is_array_Ksub <- is.array(Ksub)
    Ksub <- as.integer(Ksub)
    if (any(Ksub <= 0 | Ksub > K)) {
      stop2("'Ksub' must contain positive integers not larger than 'K'.")
    }
    if (length(Ksub) == 1L && !is_array_Ksub) {
      Ksub <- sample(seq_len(K), Ksub)
    } else {
      Ksub <- unique(Ksub)
    }
    Ksub <- sort(Ksub)
  }

  # split dots for use in log_lik and update
  dots <- list(...)
  ll_arg_names <- arg_names("log_lik")
  ll_args <- dots[intersect(names(dots), ll_arg_names)]
  ll_args$allow_new_levels <- TRUE
  ll_args$resp <- resp
  ll_args$combine <- TRUE
  up_args <- dots[setdiff(names(dots), ll_arg_names)]
  up_args$refresh <- 0

  # function to be run inside future::future
  .kfold_k <- function(k) {
    if (fold_type == "loo" && !is.null(group)) {
      omitted <- which(folds == folds[k])
      predicted <- k
    } else {
      omitted <- predicted <- which(folds == k)
    }
    newdata_omitted <- newdata[-omitted, , drop = FALSE]
    fit <- x
    up_args$object <- fit
    up_args$newdata <- newdata_omitted
    up_args$data2 <- subset_data2(newdata2, -omitted)
    fit <- SW(do_call(update, up_args))
    ll_args$object <- fit
    ll_args$newdata <- newdata[predicted, , drop = FALSE]
    ll_args$newdata2 <- subset_data2(newdata2, predicted)
    lppds <- do_call(log_lik, ll_args)
    out <- nlist(lppds, omitted, predicted)
    if (save_fits) out$fit <- fit
    return(out)
  }

  futures <- vector("list", length(Ksub))
  lppds <- obs_order <- vector("list", length(Ksub))
  if (save_fits) {
    fits <- array(list(), dim = c(length(Ksub), 3))
    dimnames(fits) <- list(NULL, c("fit", "omitted", "predicted"))
  }

  x <- recompile_model(x, recompile = recompile)
  future_args$FUN <- .kfold_k
  future_args$seed <- TRUE
  for (k in Ksub) {
    ks <- match(k, Ksub)
    message("Fitting model ", k, " out of ", K)
    future_args$args <- list(k)
    futures[[ks]] <- do_call("futureCall", future_args, pkg = "future")
  }
  for (k in Ksub) {
    ks <- match(k, Ksub)
    tmp <- future::value(futures[[ks]])
    if (save_fits) {
      fits[ks, ] <- tmp[c("fit", "omitted", "predicted")]
    }
    obs_order[[ks]] <- tmp$predicted
    lppds[[ks]] <- tmp$lppds
  }

  lppds <- do_call(cbind, lppds)
  elpds <- apply(lppds, 2, log_mean_exp)
  # make sure elpds are put back in the right order
  obs_order <- unlist(obs_order)
  elpds <- elpds[order(obs_order)]
  # compute effective number of parameters
  ll_args$object <- x
  ll_args$newdata <- newdata
  ll_args$newdata2 <- newdata2
  ll_full <- do_call(log_lik, ll_args)
  lpds <- apply(ll_full, 2, log_mean_exp)
  ps <- lpds - elpds
  # put everything together in a loo object
  pointwise <- cbind(elpd_kfold = elpds, p_kfold = ps, kfoldic = -2 * elpds)
  est <- colSums(pointwise)
  se_est <- sqrt(nrow(pointwise) * apply(pointwise, 2, var))
  estimates <- cbind(Estimate = est, SE = se_est)
  rownames(estimates) <- colnames(pointwise)
  out <- nlist(estimates, pointwise)
  atts <- nlist(K, Ksub, group, folds, fold_type)
  attributes(out)[names(atts)] <- atts
  if (save_fits) {
    out$fits <- fits
    out$data <- newdata
  }
  structure(out, class = c("kfold", "loo"))
}

#' Predictions from K-Fold Cross-Validation
#'
#' Compute and evaluate predictions after performing K-fold
#' cross-validation via \code{\link{kfold}}.
#'
#' @param x Object of class \code{'kfold'} computed by \code{\link{kfold}}.
#'   For \code{kfold_predict} to work, the fitted model objects need to have
#'   been stored via argument \code{save_fits} of \code{\link{kfold}}.
#' @param method The method used to make predictions. Either \code{"predict"}
#'   or \code{"fitted"}. See \code{\link{predict.brmsfit}} for details.
#' @inheritParams predict.brmsfit
#'
#' @return A \code{list} with two slots named \code{'y'} and \code{'yrep'}.
#'   Slot \code{y} contains the vector of observed responses.
#'   Slot \code{yrep} contains the matrix of predicted responses,
#'   with rows being posterior draws and columns being observations.
#'
#' @seealso \code{\link{kfold}}
#'
#' @examples
#' \dontrun{
#' fit <- brm(count ~ zBase * Trt + (1|patient),
#'            data = epilepsy, family = poisson())
#'
#' # perform k-fold cross validation
#' (kf <- kfold(fit, save_fits = TRUE, chains = 1))
#'
#' # define a loss function
#' rmse <- function(y, yrep) {
#'   yrep_mean <- colMeans(yrep)
#'   sqrt(mean((yrep_mean - y)^2))
#' }
#'
#' # predict responses and evaluate the loss
#' kfp <- kfold_predict(kf)
#' rmse(y = kfp$y, yrep = kfp$yrep)
#' }
#'
#' @export
kfold_predict <- function(x, method = c("predict", "fitted"),
                          resp = NULL, ...) {
  if (!inherits(x, "kfold")) {
    stop2("'x' must be a 'kfold' object.")
  }
  if (!all(c("fits", "data") %in% names(x))) {
    stop2(
      "Slots 'fits' and 'data' are required. ",
      "Please run kfold with 'save_fits = TRUE'."
    )
  }
  method <- get(match.arg(method), mode = "function")
  resp <- validate_resp(resp, x$fits[[1, "fit"]], multiple = FALSE)
  all_predicted <- as.character(sort(unlist(x$fits[, "predicted"])))
  npredicted <- length(all_predicted)
  ndraws <- ndraws(x$fits[[1, "fit"]])
  y <- rep(NA, npredicted)
  yrep <- matrix(NA, nrow = ndraws, ncol = npredicted)
  names(y) <- colnames(yrep) <- all_predicted
  for (k in seq_rows(x$fits)) {
    fit_k <- x$fits[[k, "fit"]]
    predicted_k <- x$fits[[k, "predicted"]]
    obs_names <- as.character(predicted_k)
    newdata <- x$data[predicted_k, , drop = FALSE]
    y[obs_names] <- get_y(fit_k, resp, newdata = newdata, ...)
    yrep[, obs_names] <- method(
      fit_k, newdata = newdata, resp = resp,
      allow_new_levels = TRUE, summary = FALSE, ...
    )
  }
  nlist(y, yrep)
}

Try the brms package in your browser

Any scripts or data that you put into this service are public.

brms documentation built on Sept. 26, 2023, 1:08 a.m.