# Source file: R/projpred.R
#
# Defines functions: latent_ilink_cumulative, .extract_model_data,
# get_refmodel.brmsfit
#
# Documented in: get_refmodel.brmsfit

#' Projection Predictive Variable Selection: Get Reference Model
#'
#' The \code{get_refmodel.brmsfit} method can be used to create the reference
#' model structure which is needed by the \pkg{projpred} package for performing
#' a projection predictive variable selection. This method is called
#' automatically when performing variable selection via
#' \code{\link[projpred:varsel]{varsel}} or
#' \code{\link[projpred:cv_varsel]{cv_varsel}}, so you will rarely need to call
#' it manually yourself.
#'
#' @inheritParams posterior_predict.brmsfit
#' @param cvfun Optional cross-validation function
#' (see \code{\link[projpred:get_refmodel]{get_refmodel}} for details).
#' If \code{NULL} (the default), \code{cvfun} is defined internally
#' based on \code{\link{kfold.brmsfit}}.
#' @param dis Passed to argument \code{dis} of
#'   \code{\link[projpred:init_refmodel]{init_refmodel}}, but leave this at
#'   \code{NULL} unless \pkg{projpred} complains about it.
#' @param latent See argument \code{latent} of
#'   \code{\link[projpred:extend_family]{extend_family}}. Setting this to
#'   \code{TRUE} requires a \pkg{projpred} version >= 2.4.0.
#' @param brms_seed A seed used to infer seeds for \code{\link{kfold.brmsfit}}
#'   and for sampling group-level effects for new levels (in multilevel models).
#'   If \code{NULL}, then \code{\link{set.seed}} is not called at all. If not
#'   \code{NULL}, then the pseudorandom number generator (PRNG) state is reset
#'   (to the state before calling this function) upon exiting this function.
#'   If no PRNG state existed before the call, the state created by
#'   \code{\link{set.seed}} is removed again upon exit.
#' @param ... Further arguments passed to
#' \code{\link[projpred:init_refmodel]{init_refmodel}}.
#'
#' @details Note that the \code{extract_model_data} function used internally by
#'   \code{get_refmodel.brmsfit} ignores arguments \code{wrhs} and \code{orhs}.
#'   This is relevant for
#'   \code{\link[projpred:predict.refmodel]{predict.refmodel}}, for example.
#'
#'   Argument \code{newdata} is deprecated: a warning is issued whenever it is
#'   not \code{NULL}.
#'
#' @return A \code{refmodel} object to be used in conjunction with the
#'   \pkg{projpred} package.
#'
#' @examples
#' \dontrun{
#' # fit a simple model
#' fit <- brm(count ~ zAge + zBase * Trt,
#'            data = epilepsy, family = poisson())
#' summary(fit)
#'
#' # The following code requires the 'projpred' package to be installed:
#' library(projpred)
#'
#' # perform variable selection without cross-validation
#' vs <- varsel(fit)
#' summary(vs)
#' plot(vs)
#'
#' # perform variable selection with cross-validation
#' cv_vs <- cv_varsel(fit)
#' summary(cv_vs)
#' plot(cv_vs)
#' }
#' @exportS3Method projpred::get_refmodel brmsfit
get_refmodel.brmsfit <- function(object, newdata = NULL, resp = NULL,
                                 cvfun = NULL, dis = NULL, latent = FALSE,
                                 brms_seed = NULL, ...) {
  require_package("projpred")
  object <- restructure(object)
  stopifnot_resp(object, resp)
  resp <- validate_resp(resp, object, multiple = FALSE)
  formula <- formula(object)
  if (!is.null(resp)) {
    formula <- formula$forms[[resp]]
  }

  # Infer "sub-seeds":
  if (!is.null(brms_seed)) {
    # Restore the caller's PRNG state upon exit. If there was no PRNG state
    # before (e.g., in a fresh session), remove the one created by set.seed()
    # so that the seeded state does not leak into the caller's session.
    if (exists(".Random.seed", envir = .GlobalEnv, inherits = FALSE)) {
      rng_state_old <- get(
        ".Random.seed", envir = .GlobalEnv, inherits = FALSE
      )
      on.exit(
        assign(".Random.seed", rng_state_old, envir = .GlobalEnv),
        add = TRUE
      )
    } else {
      on.exit(
        if (exists(".Random.seed", envir = .GlobalEnv, inherits = FALSE)) {
          rm(list = ".Random.seed", envir = .GlobalEnv)
        },
        add = TRUE
      )
    }
    set.seed(brms_seed)
  }
  # These draws are made regardless of `brms_seed`; with a non-NULL `brms_seed`
  # they are reproducible because of the set.seed() call above.
  kfold_seed <- sample.int(.Machine$integer.max, 1)
  refprd_seed <- sample.int(.Machine$integer.max, 1)

  # prepare the family object for use in projpred
  family <- family(object, resp = resp)
  if (family$family == "bernoulli") {
    family$family <- "binomial"
  } else if (family$family == "gamma") {
    family$family <- "Gamma"
  } else if (family$family == "beta") {
    family$family <- "Beta"
  }
  aug_data <- (is_categorical(family) || is_ordinal(family)) && !latent
  # For the augmented-data and the latent approach, do not re-define the family
  # to preserve family-specific extra arguments ("extra" meaning "additionally
  # to `link`") like `refcat` and `thresholds` (see ?brmsfamily):
  if (!aug_data && !latent) {
    family <- get(family$family, mode = "function")(link = family$link)
  }

  # check if the model is supported by projpred
  bterms <- brmsterms(formula)
  if (length(bterms$dpars) > 1L && !conv_cats_dpars(family)) {
    stop2("Projpred does not support distributional models.")
  }
  if (conv_cats_dpars(family) && length(formula$pforms)) {
    stop2("Projpred does not support category-specific formulas.")
  }
  if (length(bterms$nlpars) > 0L) {
    stop2("Projpred does not support non-linear models.")
  }
  not_ok_term_types <- setdiff(all_term_types(), c("fe", "re", "offset", "sm"))
  if (any(not_ok_term_types %in% names(bterms$dpars$mu))) {
    stop2("Projpred only supports standard multilevel and smoothing terms as ",
          "well as offsets.")
  }

  # only use the raw formula for selection of terms
  formula <- formula$formula
  # LHS should only contain the response variable
  formula[[2]] <- bterms$respform[[2]]

  # projpred requires the dispersion parameter if present
  if (is.null(dis) && !latent) {
    if (family$family == "gaussian") {
      dis <- paste0("sigma", usc(resp))
      dis <- as.data.frame(object, variable = dis)[[dis]]
    } else if (family$family == "Gamma") {
      dis <- paste0("shape", usc(resp))
      dis <- as.data.frame(object, variable = dis)[[dis]]
    }
  }

  # allows to handle additional arguments implicitly
  extract_model_data <- function(object, newdata = NULL, ...) {
    .extract_model_data(object, newdata = newdata, resp = resp, ...)
  }

  # The default `ref_predfun` from projpred does not set `allow_new_levels`, so
  # use a customized `ref_predfun` which also handles some preparations for the
  # augmented-data projection:
  ref_predfun <- function(fit, newdata = NULL) {
    # Setting a seed is necessary for reproducible sampling of group-level
    # effects for new levels. The previous PRNG state is restored upon exit (or
    # removed again if there was none before):
    if (exists(".Random.seed", envir = .GlobalEnv, inherits = FALSE)) {
      rng_state_old <- get(
        ".Random.seed", envir = .GlobalEnv, inherits = FALSE
      )
      on.exit(
        assign(".Random.seed", rng_state_old, envir = .GlobalEnv),
        add = TRUE
      )
    } else {
      on.exit(
        if (exists(".Random.seed", envir = .GlobalEnv, inherits = FALSE)) {
          rm(list = ".Random.seed", envir = .GlobalEnv)
        },
        add = TRUE
      )
    }
    set.seed(refprd_seed)
    lprd_args <- nlist(
      object = fit, newdata, resp, allow_new_levels = TRUE,
      sample_new_levels = "gaussian"
    )
    if (is_ordinal(family) && !latent) {
      c(lprd_args) <- list(incl_thres = TRUE)
    }
    out <- do_call(posterior_linpred, lprd_args)
    # Only 2-dimensional results are transposed; results with more dimensions
    # are returned as they are.
    if (length(dim(out)) == 2) {
      out <- t(out)
    }
    out
  }

  if (utils::packageVersion("projpred") <= "2.0.2" && NROW(object$ranef)) {
    warning2("In projpred versions <= 2.0.2, projpred's K-fold CV results may ",
             "not be reproducible for multilevel brms reference models.")
  }

  # extract a list of K-fold sub-models
  if (is.null(cvfun)) {
    cvfun <- function(folds, ...) {
      kfold(
        object, K = max(folds), save_fits = TRUE, folds = folds,
        seed = kfold_seed, ...
      )$fits[, "fit"]
    }
  } else {
    if (!is.function(cvfun)) {
      stop2("'cvfun' should be a function.")
    }
  }

  cvrefbuilder <- function(cvfit) {
    # For `brms_seed` in fold `cvfit$projpred_k` (= k) of K, choose a new seed
    # which is based on the original `brms_seed`:
    if (is.null(brms_seed)) {
      brms_seed_k <- NULL
    } else {
      brms_seed_k <- brms_seed + cvfit$projpred_k
    }
    projpred::get_refmodel(cvfit, resp = resp, dis = dis, latent = latent,
                           brms_seed = brms_seed_k,
                           called_from_cvrefbuilder = TRUE, ...)
  }

  # prepare data passed to projpred
  if (!is.null(newdata)) {
    warning2("Argument 'newdata' of get_refmodel.brmsfit() is deprecated and ",
             "will be removed in the future.")
  }
  data <- current_data(
    object, newdata, resp = resp, check_response = TRUE,
    allow_new_levels = TRUE
  )
  attr(data, "terms") <- NULL
  args <- nlist(
    object, data, formula, family, dis, ref_predfun,
    cvfun, extract_model_data, cvrefbuilder, latent, ...
  )
  if (aug_data) {
    # The link and inverse-link functions are looked up by family name:
    c(args) <- list(
      augdat_link = get(paste0("link_", family$family), mode = "function"),
      augdat_ilink = get(paste0("inv_link_", family$family), mode = "function")
    )
    if (is_ordinal(family)) {
      c(args) <- list(
        augdat_args_link = list(link = family$link),
        augdat_args_ilink = list(link = family$link)
      )
    }
  } else if (latent) {
    require_package("projpred", "2.4.0")
    if (family$family == "cumulative") {
      args$latent_ilink <- latent_ilink_cumulative(
        object = object, family = family, bterms = bterms, resp = resp
      )
    }
    # TODO: If requested by users, add response-scale support for more families:
    # For response-scale support, they all need a specific `latent_ilink`
    # function; some families (those for which the response can be numeric) also
    # require specific `latent_ll_oscale` and `latent_ppd_oscale` functions. The
    # binomial family (and thereby also the brms::bernoulli() family) has
    # response-scale support implemented natively in projpred.
  }
  do_call(projpred::init_refmodel, args)
}

# Collect the auxiliary data which projpred needs for its predictions.
#
# @param object A 'brmsfit' object (enforced via is.brmsfit()).
# @param newdata Optional new data, handed over to current_data() and
#   standata().
# @param resp Optional response name, validated via validate_resp().
# @param extract_y Should the response be extracted? If FALSE, slot 'y' of the
#   output is NULL.
# @param ... Accepted but not used inside of this function.
#
# @return a named list with slots 'y', 'weights', and 'offset'
.extract_model_data <- function(object, newdata = NULL, resp = NULL,
                                extract_y = TRUE, ...) {
  stopifnot(is.brmsfit(object))
  resp <- validate_resp(resp, object, multiple = FALSE)

  bform <- formula(object)
  if (!is.null(resp)) {
    bform <- bform$forms[[resp]]
  }
  bterms <- brmsterms(bform)

  # The response is pulled out of the data by hand rather than via
  # make_standata so that it passes the input checks of validate_newdata
  # later on (#1314).
  y <- NULL
  if (extract_y) {
    resp_form <- bterms$respform
    resp_data <- current_data(
      object, newdata, resp = resp, check_response = TRUE,
      allow_new_levels = TRUE, req_vars = all.vars(resp_form)
    )
    resp_frame <- model.frame(resp_form, resp_data, na.action = na.pass)
    y <- unname(model.response(resp_frame))
  }

  # Offsets and weights (or numbers of trials) are taken from standata() to
  # ensure the correct format of the data. Because of issue #1457, the call
  # uses `check_response = FALSE` and `req_vars` only lists the variables
  # behind offsets and weights (or numbers of trials). Note that all.vars(NULL)
  # gives character(0), as desired.
  offset_vars <- lapply(
    bterms$dpars, function(dpar) all.vars(dpar[["offset"]])
  )
  req_vars <- unique(unlist(offset_vars))
  c(req_vars) <- all.vars(bterms$adforms$weights)
  c(req_vars) <- all.vars(bterms$adforms$trials)
  # NOTE: Missing weights don't cause an error here (see #1459)
  sdata <- do_call(standata, nlist(
    object, newdata, resp,
    allow_new_levels = TRUE,
    check_response = FALSE,
    internal = TRUE,
    req_vars = req_vars
  ))

  # Look up a (possibly response-specific) element of the Stan data:
  usc_resp <- usc(resp)
  sdata_entry <- function(prefix) {
    sdata[[paste0(prefix, usc_resp)]]
  }
  n_obs <- sdata_entry("N")

  weights <- as.vector(sdata_entry("weights"))
  # For binary responses, every observation counts as a single trial:
  if (is_binary(bform)) {
    trials <- rep(1, n_obs)
  } else {
    trials <- as.vector(sdata_entry("trials"))
  }
  has_trials <- !is.null(trials)
  if (has_trials && !is.null(weights)) {
    stop2("Projpred cannot handle 'trials' and 'weights' at the same time.")
  }
  # Numbers of trials take the role of the weights; without either of them,
  # unit weights are used:
  if (has_trials) {
    weights <- trials
  } else if (is.null(weights)) {
    weights <- rep(1, n_obs)
  }

  offset <- as.vector(sdata_entry("offsets"))
  if (is.null(offset)) {
    offset <- rep(0, n_obs)
  }
  nlist(y, weights, offset)
}

# Construct the inverse-link function required for the latent projection in case
# of the cumulative family.
#
# @param object See argument `object` of get_refmodel.brmsfit(), but here, the
#   `object` as modified inside of get_refmodel.brmsfit() is required.
# @param family The `family` object corresponding to `object` (taking `resp`
#   into account). Could be re-inferred from `object` and `resp`, but for
#   computational efficiency, this is avoided. Its element `cats` must not be
#   `NULL`.
# @param bterms The `brmsterms` object corresponding to `object` (or rather
#   `object`'s `formula`, taking `resp` into account). Could be re-inferred from
#   `object` and `resp`, but for computational efficiency, this is avoided.
# @param resp See argument `resp` of get_refmodel.brmsfit(), but here, the
#   `resp` as modified inside of get_refmodel.brmsfit() is required.
#
# @details Throws an error if more threshold draws are found than there are
#   categories minus one (group-specific thresholds).
#
# @return A function to be supplied to projpred::extend_family()'s argument
#   `latent_ilink`. That function takes arguments `lpreds`, `cl_ref`, and
#   `wdraws_ref` (the latter defaulting to unit weights) and returns the output
#   of inv_link_cumulative().
latent_ilink_cumulative <- function(object, family, bterms, resp) {
  stopifnot(!is.null(family$cats))
  draws_mat <- as_draws_matrix(object)
  thres_regex <- paste0("^b", usc(combine_prefix(bterms)), "_Intercept\\[")
  thres_draws <- prepare_draws(draws_mat, variable = thres_regex, regex = TRUE)
  if (ncol(thres_draws) > length(family$cats) - 1L) {
    stop2("Currently, projpred does not support group-specific thresholds ",
          "(argument `gr` of resp_thres()).")
  }
  # Note: Currently, `disc` should always be constantly 1 because
  # distributional models are not allowed here.
  # The response name is appended via usc() so that it is separated from the
  # parameter name by an underscore, just like it is done for `sigma` and
  # `shape` in get_refmodel.brmsfit(). For `resp = NULL`, this gives "^disc$".
  disc_regex <- paste0("^", "disc", usc(resp), "$")
  disc_draws <- prepare_draws(draws_mat, variable = disc_regex, regex = TRUE)

  out <- function(lpreds, cl_ref, wdraws_ref = rep(1, length(cl_ref))) {
    thres_agg <- projpred::cl_agg(thres_draws, cl = cl_ref, wdraws = wdraws_ref)
    disc_agg <- projpred::cl_agg(disc_draws, cl = cl_ref, wdraws = wdraws_ref)
    disc_agg <- as.vector(disc_agg)
    lpreds_thres <- apply(thres_agg, 2, function(thres_agg_c) {
      # Notes on dimensionalities (with S_agg = `nrow(lpreds)`):
      # * `disc_agg` is a vector of length S_agg (because `disc` is not
      #   predicted here),
      # * `thres_agg` is S_agg x C_lat (with C_lat = `ncats - 1L` =
      #   `nthres`) and thus `thres_agg_c` is a vector of length S_agg,
      # * `lpreds` is S_agg x N (with N denoting the number of (possibly
      #   new) observations (not necessarily the original number of
      #   observations)).
      disc_agg * (thres_agg_c - lpreds)
    }, simplify = FALSE)
    # Coerce to an S_agg x N x C_lat array:
    lpreds_thres <- do.call(abind, c(lpreds_thres, rev.along = 0))
    # Transform to response space, yielding an S_agg x N x C_cat array:
    inv_link_cumulative(lpreds_thres, link = family$link)
  }
  # Free up some memory (keeping `draws_mat` would lead to unnecessary memory
  # usage because `draws_mat` would continue to live in the environment of the
  # returned function):
  rm(draws_mat)
  out
}

# Try the brms package in your browser
#
# Any scripts or data that you put into this service are public.
#
# brms documentation built on Sept. 26, 2023, 1:08 a.m.