R/brms_tidiers.R

Defines functions augment.brmsfit glance.brmsfit sigma.brmsfit tidy.brmsfit

Documented in augment.brmsfit glance.brmsfit tidy.brmsfit

#' Tidying methods for a brms model
#'
#' These methods tidy the estimates from
#' \code{\link[brms:brmsfit-class]{brmsfit-objects}}
#' (fitted model objects from the \pkg{brms} package) into summary tables:
#' \code{tidy} summarizes the posterior distributions of model parameters,
#' \code{glance} gives model-level summaries, and \code{augment} appends
#' fitted values (plus residuals, when no new data are supplied) to the data.
#'
#' @return All tidying methods return a \code{data.frame} without rownames.
#' The structure depends on the method chosen.
#'
#' @seealso \code{\link[brms]{brms}}, \code{\link[brms]{brmsfit-class}}
#'
#' @name brms_tidiers
#'
#' @param x Fitted model object from the \pkg{brms} package. See
#'   \code{\link[brms]{brmsfit-class}}.
#' @examples
#'  ## original model
#'  \dontrun{
#'     brms_crossedRE <- brm(mpg ~ wt + (1|cyl) + (1+wt|gear), data = mtcars,
#'            iter = 500, chains = 2)
#'  }
#'  \donttest{
#'    ## too slow for CRAN (>5 seconds)
#'    ## load stored object
#'    if (require("rstan") && require("brms")) {
#'       load(system.file("extdata", "brms_example.rda", package="broom.mixed"))
#'
#'       fit <- brms_crossedRE
#'       tidy(fit)
#'       tidy(fit, parameters = "^sd_", conf.int = FALSE)
#'       tidy(fit, effects = "fixed", conf.method="HPDinterval")
#'       tidy(fit, effects = "ran_vals")
#'       tidy(fit, effects = "ran_pars", robust = TRUE)
#'       if (require("posterior")) {
#'         tidy(fit, effects = "ran_pars", rhat = TRUE, ess = TRUE)
#'
#'    }
#'    }
#'    if (require("rstan") && require("brms")) {
#'    # glance method
#'    glance(fit)
#'    ## this example will give a warning that it should be run with
#'    ## reloo=TRUE; however, doing this will fail
#'    ## because the fit object has been stripped down to save space
#'    suppressWarnings(glance(fit, looic = TRUE, cores = 1))
#'    head(augment(fit))
#'   }
#' }
#'
NULL
## examples for all methods (tidy/glance/augment) are collected in the one
##  shared example section above, so they can be guarded in a single place by
##  'if (require("rstan") && require("brms"))' checks

#' @rdname brms_tidiers
#' @param parameters Names of parameters for which a summary should be
#'   returned, as given by a character vector or regular expressions.
#'   If \code{NA} (the default) summarized parameters are specified
#'   by the \code{effects} argument.
#' @param effects A character vector including one or more of \code{"fixed"},
#'   \code{"ran_vals"}, or \code{"ran_pars"}.
#'   See the Value section for details. (\code{"ran_coefs"} is recognized
#'   but not yet implemented for \pkg{brms} fits; it is dropped with a warning.)
#' @param robust Whether to use median and median absolute deviation of
#' the posterior distribution, rather
#'   than mean and standard deviation, to derive point estimates and uncertainty
#' @param conf.int If \code{TRUE} columns for the lower (\code{conf.low})
#' and upper bounds (\code{conf.high}) of posterior uncertainty intervals are included.
#' @param exponentiate  whether to exponentiate the coefficient estimates and confidence intervals (common for logistic regression); if \code{TRUE}, also scales the standard errors by the exponentiated coefficient, transforming them to the new scale. Note that the transformation is currently applied to every returned row, not only to fixed effects.
#' @param conf.level Defines the range of the posterior uncertainty interval,
#'  such that \code{100 * conf.level}\% of the parameter's posterior distribution
#'  lies within the corresponding interval.
#'  Only used if \code{conf.int = TRUE}.
#' @param conf.method method for computing confidence intervals
#' ("quantile" or "HPDinterval")
#' @param rhat whether to calculate the *Rhat* convergence metric
#' (\code{FALSE} by default)
#' @param ess whether to calculate the *effective sample size* (ESS) convergence metric
#' (\code{FALSE} by default)
#' @param fix.intercept rename "Intercept" parameter to "(Intercept)", to match
#' behaviour of other model types?
#' @param looic Should the LOO Information Criterion (and related info) be
#'   included? See \code{\link[rstan]{loo.stanfit}} for details. (This
#'   can be slow for models fit to large datasets.)
#' @param ... Extra arguments, not used
#' @return
#' When \code{parameters = NA}, the \code{effects} argument is used
#' to determine which parameters to summarize, and an \code{effect}
#' column records which kind of effect each row belongs to.
#'
#' Generally, \code{tidy.brmsfit} returns
#' one row for each coefficient, with at least three columns:
#' \item{term}{The name of the model parameter.}
#' \item{estimate}{A point estimate of the coefficient (mean or median).}
#' \item{std.error}{A standard error for the point estimate (sd or mad).}
#'
#' When \code{effects = "fixed"}, only population-level
#' effects are returned.
#'
#' When \code{effects = "ran_vals"}, only group-level effects are returned.
#' In this case, two additional columns are added:
#' \item{group}{The name of the grouping factor.}
#' \item{level}{The name of the level of the grouping factor.}
#'
#' Specifying \code{effects = "ran_pars"} selects the
#' standard deviations and correlations of the group-level parameters.
#'
#' If \code{conf.int = TRUE}, columns \code{conf.low} and \code{conf.high}
#' give the lower and upper bounds of the posterior uncertainty interval.
#' If \code{rhat = TRUE} and/or \code{ess = TRUE}, columns \code{rhat}
#' and/or \code{ess} hold the corresponding convergence diagnostics.
#' A \code{component} column (\code{"cond"}, \code{"zi"} or \code{"disp"})
#' indicates the model component of each term.
#'
#' @note The names \sQuote{fixed}, \sQuote{ran_pars}, and \sQuote{ran_vals}
#' (corresponding to "non-varying", "hierarchical", and "varying" respectively
#' in previous versions of the package), while technically inappropriate in
#' a Bayesian setting where "fixed" and "random" effects are not well-defined,
#' are used for compatibility with other (frequentist) mixed model types.
#' @note At present, the components of parameter estimates are separated by parsing the column names of \code{as_draws} (e.g. \code{r_patient[1,Intercept]} for the random effect on the intercept for patient 1, or \code{b_Trt1} for the fixed effect \code{Trt1}. We try to detect underscores in parameter names and warn, but detection may be imperfect.
#' @export
tidy.brmsfit <- function(x, parameters = NA,
                         effects = c("fixed", "ran_pars"),
                         robust = FALSE,
                         conf.int = TRUE, conf.level = 0.95,
                         conf.method = c("quantile", "HPDinterval"),
                         rhat = FALSE, ess = FALSE,
                         fix.intercept = TRUE,
                         exponentiate = FALSE,
                         ...) {

  check_dots(...)
  bad_effects <- setdiff(effects, c("fixed", "ran_pars", "ran_vals", "ran_coefs"))
  if (length(bad_effects) > 0) {
      stop("unrecognized effects: ", paste(bad_effects, collapse = ", "))
  }
  ## select parameters via `effects` (rather than by name/regex)?
  use_effects <- anyNA(parameters)
  ## "ran_coefs" is an accepted name but has no parameter prefix and no
  ## implementation here (see the commented-out sketch further down).
  ## Left in place it was silently ignored, or -- as the only effect --
  ## produced an empty pattern "(^|_)()" that matches *every* parameter and
  ## then failed with a confusing length error.
  if (use_effects && "ran_coefs" %in% effects) {
      warning("effects = \"ran_coefs\" is not yet implemented for brmsfit objects; ignoring it",
              call. = FALSE)
      effects <- setdiff(effects, "ran_coefs")
  }
  if (use_effects && length(effects) == 0) {
      stop("no (implemented) effects specified", call. = FALSE)
  }
  std.error <- NULL ## NSE/code check
  if (!requireNamespace("brms", quietly=TRUE)) {
      stop("can't tidy brms objects without brms installed")
  }
  xr <- brms::restructure(x)
  has_ranef <- nrow(xr$ranef)>0
  if (any(grepl("_", rownames(fixef(x)))) ||
        (has_ranef && any(grepl("_", names(ranef(x)))))) {
      warning("some parameter names contain underscores: term naming may be unreliable!")
  }
  conf.method <- match.arg(conf.method)
  is.multiresp <- length(x$formula$forms)>1
  ## make regular expression from a list of prefixes; each prefix must occur
  ## at the start of the name or right after an underscore. With LB=TRUE that
  ## condition becomes a lookbehind, so the underscore is not consumed.
  mkRE <- function(x,LB=FALSE) {
      pref <- "(^|_)"
      if (LB) pref <- sprintf("(?<=%s)",pref)
      sprintf("%s(%s)", pref, paste(unlist(x), collapse = "|"))
  }
  ## parameter-name prefixes distinguishing fixed effects, random-effect
  ## values, random-effect parameters and model components.
  ## No lookaround in these (doesn't work with grep[l]).
  prefs <- list(
      fixed = "b_", ran_vals = "r_",
      ran_pars = c("sd_", "cor_", "sigma"),
      components = c("zi_", "disp_")
  )
  pref_RE <- mkRE(prefs[effects])
  if (use_effects) {
    ## prefixes distinguishing fixed, random effects
    parameters <- pref_RE
  }
  samples_perchain <- brms::as_draws_array(x, parameters, regex = TRUE)
  if (is.null(samples_perchain) || posterior::nvariables(samples_perchain) == 0) {
    stop("No parameter name matches the specified pattern.",
      call. = FALSE
    )
  }
  samples <- brms::as_draws_matrix(samples_perchain)
  terms <- colnames(samples)
  if (use_effects) {
      if (is.multiresp) {
        if ("ran_pars" %in% effects && any(grepl("^sd",terms))) {
           warning("ran_pars response/group tidying for multi-response models is currently incorrect")
        }
        ## FIXME: unfinished attempt to fix GH #39
        ## extract response component from terms
        ## resp0 <- strsplit(terms, "_+")
        ## resp1 <- sapply(resp0,
        ##          function(x) if (length(x)==2) x[2] else x[length(x)-1])
        ## ## put the pieces back together
        ## t0 <- lapply(resp0,
        ##          function(x) if (length(x)==2) x[1] else x[-(length(x)-1)])
        ## t1 <- lapply(t0,
        ##          function(x)
        ##              case_when(
        ##                  x[[1]]=="b"  ~ sprintf("b%s",x[[2]]),
        ##                  x[[2]]=="sd" ~ sprintf("sd_%s__%s",x[[2]],x[[3]]),
        ##                  x[[3]]=="cor" ~ sprintf("cor_%s_%s_%s_%s",
        ##                                          x[[2]],x[[3]],x[[4]],x[[5]])
        ##              ))
        ## resp0 <- stringr::str_extract_all(terms, "_[^_]+")
        ## resp1 <- lapply(resp0, gsub, pattern= "^_", replacement="")
        response <- gsub("^_","",stringr::str_extract(terms,"_[^_]+"))
        terms <- sub("_[^_]+","",terms)
    }
    res_list <- list()
    fixed.only <- identical(effects, "fixed")
    if ("fixed" %in% effects) {
      ## empty tibble: NA columns will be filled in as appropriate.
      ## Count with the anchored pattern (as for ran_pars/ran_vals below):
      ## a bare "b_" also matches inside e.g. "sd_sub__..." or "r_sub[...]",
      ## which would misalign the fixed-effect rows with `terms`.
      nfixed <- sum(grepl(mkRE(prefs[["fixed"]]), terms))
      res_list$fixed <- as_tibble(matrix(nrow = nfixed, ncol = 0))
    }
    ## group of a ran_pars term: "Residual" for sigma, else the grouping factor
    grpfun <- function(x) {
        if (grepl("sigma",x[[1]])) "Residual" else x[[2]]
    }
    if ("ran_pars" %in% effects) {
      rterms <- grep(mkRE(prefs$ran_pars), terms, value = TRUE)
      ss <- strsplit(rterms, "__")
      pp <- "^(cor|sd)(?=(_))"
      nodash <- function(x) gsub("^_", "", x)
      ##  split the first term (cor/sd) into tag + group
      ss2 <- lapply(
        ss,
        function(x) {
          if (!is.na(pref <- stringr::str_extract(x[1], pp))) {
            return(c(pref, nodash(stringr::str_remove(x[1], pp)), x[-1]))
          }
          return(x)
        }
      )
      sep <- getOption("broom.mixed.sep1")
      termfun <- function(x) {
        if (grepl("^sigma",x[[1]])) {
            paste("sd", "Observation", sep = sep)
        } else {
            ## re-attach remaining terms
            paste(x[[1]],
                  paste(x[3:length(x)], collapse = "."),
                  sep = sep
          )
        }
      }
      res_list$ran_pars <-
        dplyr::tibble(
          group = sapply(ss2, grpfun),
          term = sapply(ss2, termfun)
        )
    }

    ## nice, but needs to be done outside averaging loop ...
    ##   meltfun <- function(a) {
    ##     dd <- as.data.frame(ftable(a)) |>
    ##         setNames(c("level", "var", "term", "value")) |>
    ##         tidyr::pivot_wider(names_from = var, values_from = value) |>
    ##         rename(estimate = "Estimate",
    ##                std.error = "Est.Error",
    ##                ## FIXME: not robust to changing levels
    ##                conf.low = "Q2.5",
    ##                conf.high = "Q97.5")
    ##   }

    ## ## purrr:::map_dfr(ranef(x), meltfun, .id = "group")

    ## if ("ran_coefs" %in% effects) {
    ##     res_list$ran_coefs <- purrr:::map_dfr(coef(x), meltfun, .id = "group")
    ## }
    if ("ran_vals" %in% effects) {
      rterms <- grep(mkRE(prefs$ran_vals), terms, value = TRUE)
      ## r_<group>[<level>,<term>]
      vals <- stringr::str_match_all(rterms, "_(.+?)\\[(.+?),(.+?)\\]")

      res_list$ran_vals <-
        dplyr::tibble(
          group = purrr::map_chr(vals, function (v) { v[[2]] }),
          term = purrr::map_chr(vals, function (v) { v[[4]] }),
          level = purrr::map_chr(vals, function (v) { v[[3]] })
        )

    }
    out <- dplyr::bind_rows(res_list, .id = "effect")
    # In the case where nrow(res_list$fixed) > 0 but nrow(res_list$ran_pars) == 0,
    # the out object needs to be fixed a bit (replace columns with unexpected
    # lists of NULL by expected vectors of NA).
    for (col in c("group", "term")) {
      if (is.list(out[[col]]) && all(sapply(out[[col]], is.null))) {
        out[[col]] <- rep(NA, nrow(out))
      }
    }
    ## fixed-effect rows come first and still lack a term: fill them from
    ## `terms`, stripping the "b_" prefix
    v <- if (fixed.only) seq_len(nrow(out)) else is.na(out$term)
    newterms <- stringr::str_remove(terms[v], mkRE(prefs[c("fixed")]))
    if (length(newterms)>0) {
      if (fixed.only) {
        out$term <- newterms
      } else {
        out$term[v] <- newterms
      }
    }
    if (is.multiresp) {
        out$response <- response
    }
    ## prefixes already removed for ran_vals; don't remove for ran_pars
  } else {
    ## if !use_effects
    out <- dplyr::tibble(term = terms)
  }
  pointfun <- if (robust) stats::median else base::mean
  stdfun <- if (robust) stats::mad else stats::sd
  out$estimate <- apply(samples, 2, pointfun)
  out$std.error <- apply(samples, 2, stdfun)
  if (conf.int) {

    stopifnot(length(conf.level) == 1L)
    probs <- c((1 - conf.level) / 2, 1 - (1 - conf.level) / 2)
    if (conf.method == "HPDinterval") {
        cc <- coda::HPDinterval(coda::as.mcmc(samples), prob=conf.level)
    } else {
        cc <- t(apply(samples, 2, stats::quantile, probs = probs))
    }
    out$conf.low <- cc[,1]
    out$conf.high <- cc[,2]
  }
  if (rhat || ess) {
    metric_names <- c(if (rhat) "rhat", if (ess) "ess")
    ## check availability *before* any posterior:: access in this branch,
    ## so that a missing package produces this informative error
    if (!requireNamespace("posterior", quietly=TRUE)) {
        stop(paste0(paste0(metric_names, collapse=", "),
             " calculation for brmsfit objects requires posterior package"))
    }
    posterior_metrics <- c(if (rhat) list(rhat = posterior::rhat),
                           if (ess) list(ess = posterior::ess_basic))
    out[names(posterior_metrics)] <-
      posterior::summarise_draws(samples_perchain, posterior_metrics)[names(posterior_metrics)]
  }
  ## figure out component
  out$component <- dplyr::case_when(grepl("(^|_)zi",out$term) ~ "zi",
                                    ## ??? is this possible in brms models
                                    grepl("^disp",out$term) ~ "disp",
                                    TRUE ~ "cond")

  if (exponentiate) {
    ## NOTE(review): applied to every row (including sd/cor terms), not only
    ## to fixed effects -- confirm this is intended.
    ## std.error is rescaled by the already-exponentiated estimate.
    vv <- c("estimate", "conf.low", "conf.high")
    out <- (out
      %>% mutate(across(contains(vv), exp))
      %>% mutate(across(std.error, ~ . * estimate))
    )
  }

  out$term <- stringr::str_remove(out$term,mkRE(prefs[["components"]],
                                                LB=TRUE))
  if (fix.intercept) {
      ## use lookahead/lookbehind: replace Intercept with word boundary
      ## or underscore before/after by (Intercept) - without removing
      ## underscores!
      out$term <- stringr::str_replace(out$term,
                                        "(?<=(\\b|_))Intercept(?=(\\b|_))",
                                        "(Intercept)")
  }
  out <- reorder_cols(out)
  return(out)
}


## Residual standard deviation of a brms fit: the posterior median of the
## "sigma" parameter, or 1 when the model has no "sigma" parameter.
#' @importFrom stats quantile
#' @export
sigma.brmsfit <- function (object, ...)  {
  has_sigma <- "sigma" %in% brms::variables(object)
  if (!has_sigma) {
    return(1)
  }
  sigma_draws <- brms::as_draws_array(object, "sigma")
  ## probs = 0.5 -> posterior median (result keeps quantile()'s "50%" name)
  stats::quantile(sigma_draws, probs = 0.5)
}

#' @rdname brms_tidiers
#' @export
glance.brmsfit <- function(x, looic = FALSE, ...) {
  ## Delegate to the shared Stan glance helper (defined in
  ## rstanarm_tidiers.R), flagging that the fit comes from brms.
  fit_type <- "brmsfit"
  glance_stan(x, looic = looic, type = fit_type, ...)
}

#' @rdname brms_tidiers
#' @param data data frame
#' @param newdata new data frame
#' @param se.fit return standard errors of fit?
#' @export
augment.brmsfit <- function(x, data = stats::model.frame(x), newdata = NULL,
                            se.fit = TRUE, ...) {
  ## augment_columns() can't be used here: residuals.brmsfit() returns a
  ## 4-column summary matrix (summary = TRUE by default) and there is no way
  ## to suppress that from within augment_columns().
  ## ... add resids.arg to augment_columns?
  ## FIXME: influence measures?? Also allow optional arguments to augment,
  ## e.g. pred.type, residual.type, re.form ...
  pred_args <- list(x, se.fit = se.fit)
  if (!is.null(newdata)) {
    pred_args[["newdata"]] <- newdata
  }
  pred <- do.call(stats::predict, pred_args)
  aug <- dplyr::tibble(.fitted = pred[, "Estimate"])
  if (se.fit) {
    aug[[".se.fit"]] <- pred[, "Est.Error"]
  }
  ## new data: predictions only (residuals need observed responses)
  if (!is.null(newdata)) {
    return(dplyr::bind_cols(as_tibble(newdata), aug))
  }
  aug[[".resid"]] <- stats::residuals(x)[, "Estimate"]
  dplyr::bind_cols(as_tibble(data), aug)
}

Try the broom.mixed package in your browser

Any scripts or data that you put into this service are public.

broom.mixed documentation built on Oct. 16, 2024, 1:06 a.m.