R/conditional_effects.R


#' Display Conditional Effects of Predictors
#'
#' Display conditional effects of one or more numeric and/or categorical
#' predictors including two-way interaction effects.
#'
#' @aliases marginal_effects marginal_effects.brmsfit
#'
#' @param x An object of class \code{brmsfit}.
#' @param effects An optional character vector naming effects (main effects or
#'   interactions) for which to compute conditional plots. Interactions are
#'   specified by a \code{:} between variable names. If \code{NULL} (the
#'   default), plots are generated for all main effects and two-way interactions
#'   estimated in the model. When specifying \code{effects} manually, \emph{all}
#'   two-way interactions (including grouping variables) may be plotted
#'   even if not originally modeled.
#' @param conditions An optional \code{data.frame} containing variable values
#'   to condition on. Each effect defined in \code{effects} will
#'   be plotted separately for each row of \code{conditions}. Values in the
#'   \code{cond__} column will be used as titles of the subplots. If \code{cond__}
#'   is not given, the row names will be used for this purpose instead.
#'   It is recommended to only define a few rows in order to keep the plots clear.
#'   See \code{\link{make_conditions}} for an easy way to define conditions.
#'   If \code{NULL} (the default), numeric variables will be conditioned on
#'   their means and factors on their first level.
#'   \code{NA} values within factors are interpreted as if all dummy
#'   variables of this factor are zero. This makes it possible, for instance,
#'   to predict the grand mean when using sum coding.
#' @param int_conditions An optional named \code{list} whose elements are
#'   vectors of values of the variables specified in \code{effects}.
#'   At these values, predictions are evaluated. The names of
#'   \code{int_conditions} have to match the variable names exactly.
#'   Additionally, the elements of the vectors may be named themselves,
#'   in which case their names appear as labels for the conditions in the plots.
#'   Instead of vectors, functions returning vectors may be passed and are
#'   applied on the original values of the corresponding variable.
#'   If \code{NULL} (the default), predictions are evaluated at the
#'   \eqn{mean} and at \eqn{mean +/- sd} for numeric predictors and at
#'   all categories for factor-like predictors.
#' @param re_formula A formula containing group-level effects to be considered
#'   in the conditional predictions. If \code{NULL}, include all group-level
#'   effects; if \code{NA} (default), include no group-level effects.
#' @param robust If \code{TRUE} (the default) the median is used as the
#'   measure of central tendency. If \code{FALSE} the mean is used instead.
#' @param prob A value between 0 and 1 indicating the desired probability
#'   to be covered by the uncertainty intervals. The default is 0.95.
#' @param probs (Deprecated) The quantiles to be used in the computation of
#'   uncertainty intervals. Please use argument \code{prob} instead.
#' @param method Method used to obtain predictions. Can be set to
#'   \code{"posterior_epred"} (the default), \code{"posterior_predict"},
#'   or \code{"posterior_linpred"}. For more details, see the documentation
#'   of the respective functions.
#' @param spaghetti Logical. Indicates if predictions should
#'   be visualized via spaghetti plots. Only applied for numeric
#'   predictors. If \code{TRUE}, it is recommended
#'   to set argument \code{ndraws} to a relatively small value
#'   (e.g., \code{100}) in order to reduce computation time.
#' @param surface Logical. Indicates if interactions or
#'   two-dimensional smooths should be visualized as a surface.
#'   Defaults to \code{FALSE}. The surface type can be controlled
#'   via argument \code{stype} of the related plotting method.
#' @param categorical Logical. Indicates if effects of categorical
#'   or ordinal models should be shown in terms of probabilities
#'   of response categories. Defaults to \code{FALSE}.
#' @param ordinal (Deprecated) Please use argument \code{categorical}.
#'   Logical. Indicates if effects in ordinal models
#'   should be visualized as a raster with the response categories
#'   on the y-axis. Defaults to \code{FALSE}.
#' @param transform A function or a character string naming
#'   a function to be applied on the predicted responses
#'   before summary statistics are computed. Only allowed
#'   if \code{method = "posterior_predict"}.
#' @param resolution Number of support points used to generate
#'   the plots. Higher resolution leads to smoother plots.
#'   Defaults to \code{100}. If \code{surface} is \code{TRUE},
#'   this implies \code{10000} support points for interaction terms,
#'   so it might be necessary to reduce \code{resolution}
#'   when only little RAM is available.
#' @param too_far Positive number.
#'   For surface plots only: Grid points that are too
#'   far away from the actual data points can be excluded from the plot.
#'   \code{too_far} determines what is too far. The grid is scaled into
#'   the unit square and then grid points more than \code{too_far}
#'   from the predictor variables are excluded. By default, all
#'   grid points are used. Ignored for non-surface plots.
#' @param select_points Positive number.
#'   Only relevant if \code{points} or \code{rug} are set to \code{TRUE}:
#'   Actual data points of numeric variables that
#'   are too far away from the values specified in \code{conditions}
#'   can be excluded from the plot. Values are scaled into
#'   the unit interval and then points more than \code{select_points}
#'   from the values in \code{conditions} are excluded.
#'   By default, all points are used.
#' @param ... Further arguments such as \code{draw_ids} or \code{ndraws}
#'   passed to \code{\link{posterior_predict}} or \code{\link{posterior_epred}}.
#' @inheritParams plot.brmsfit
#' @param ncol Number of columns in which the subplots of each effect are arranged.
#'   If \code{NULL} (default), \code{ncol} is computed internally based
#'   on the number of rows of \code{conditions}.
#' @param points Logical. Indicates if the original data points should be added
#'   via \code{\link[ggplot2:geom_jitter]{geom_jitter}}. Default is
#'   \code{FALSE}. Can be controlled globally via the \code{brms.plot_points}
#'   option. Note that only those data points will be added that match the
#'   specified conditions defined in \code{conditions}. For categorical
#'   predictors, the conditions have to match exactly. For numeric predictors,
#'   argument \code{select_points} is used to determine which points match a
#'   condition.
#' @param rug Logical. Indicates if a rug representation of predictor values
#'   should be added via \code{\link[ggplot2:geom_rug]{geom_rug}}. Default is
#'   \code{FALSE}. Depends on \code{select_points} in the same way as
#'   \code{points} does. Can be controlled globally via the \code{brms.plot_rug}
#'   option.
#' @param mean Logical. Only relevant for spaghetti plots.
#'   If \code{TRUE} (the default), display the mean regression
#'   line on top of the regression lines for each sample.
#' @param jitter_width Only used if \code{points = TRUE}:
#'   Amount of horizontal jittering of the data points.
#'   Mainly useful for ordinal models. Defaults to \code{0}, that
#'   is, no jittering.
#' @param stype Indicates how surface plots should be displayed.
#'   Either \code{"contour"} or \code{"raster"}.
#' @param line_args Only used in plots of continuous predictors:
#'   A named list of arguments passed to
#'   \code{\link[ggplot2:geom_smooth]{geom_smooth}}.
#' @param cat_args Only used in plots of categorical predictors:
#'   A named list of arguments passed to
#'   \code{\link[ggplot2:geom_point]{geom_point}}.
#' @param errorbar_args Only used in plots of categorical predictors:
#'   A named list of arguments passed to
#'   \code{\link[ggplot2:geom_errorbar]{geom_errorbar}}.
#' @param surface_args Only used in surface plots:
#'   A named list of arguments passed to
#'   \code{\link[ggplot2:geom_contour]{geom_contour}} or
#'   \code{\link[ggplot2:geom_raster]{geom_raster}}
#'   (depending on argument \code{stype}).
#' @param spaghetti_args Only used in spaghetti plots:
#'   A named list of arguments passed to
#'   \code{\link[ggplot2:geom_smooth]{geom_smooth}}.
#' @param point_args Only used if \code{points = TRUE}:
#'   A named list of arguments passed to
#'   \code{\link[ggplot2:geom_jitter]{geom_jitter}}.
#' @param rug_args Only used if \code{rug = TRUE}:
#'   A named list of arguments passed to
#'   \code{\link[ggplot2:geom_rug]{geom_rug}}.
#' @param facet_args Only used if multiple conditions are provided:
#'   A named list of arguments passed to
#'   \code{\link[ggplot2:facet_wrap]{facet_wrap}}.
#'
#' @return An object of class \code{'brms_conditional_effects'} which is a
#'   named list with one data.frame per effect containing all information
#'   required to generate conditional effects plots. Among others, these
#'   data.frames contain some special variables, namely \code{estimate__}
#'   (predicted values of the response), \code{se__} (standard error of the
#'   predicted response), \code{lower__} and \code{upper__} (lower and upper
#'   bounds of the uncertainty interval of the response), as well as
#'   \code{cond__} (used in faceting when \code{conditions} contains multiple
#'   rows).
#'
#'   The corresponding \code{plot} method returns a named
#'   list of \code{\link[ggplot2:ggplot]{ggplot}} objects, which can be further
#'   customized using the \pkg{ggplot2} package.
#'
#' @details When creating \code{conditional_effects} for a particular predictor
#'   (or interaction of two predictors), one has to choose the values of all
#'   other predictors to condition on. By default, the mean is used for
#'   continuous variables and the reference category is used for factors, but
#'   you may change these values via argument \code{conditions}. This also has
#'   an implication for the \code{points} argument: In the created plots, only
#'   those points will be shown that correspond to the factor levels actually
#'   used in the conditioning, in order not to create the false impression of
#'   bad model fit, where it is just due to conditioning on certain factor
#'   levels.
#'
#'   To fully change colors of the created plots, one has to amend both
#'   \code{scale_colour} and \code{scale_fill}. See
#'   \code{\link[ggplot2:scale_color_grey]{scale_colour_grey}} or
#'   \code{\link[ggplot2:scale_color_gradient]{scale_colour_gradient}} for
#'   more details.
#'
#' @examples
#' \dontrun{
#' fit <- brm(count ~ zAge + zBase * Trt + (1 | patient),
#'            data = epilepsy, family = poisson())
#'
#' ## plot all conditional effects
#' plot(conditional_effects(fit), ask = FALSE)
#'
#' ## change colours to grey scale
#' library(ggplot2)
#' ce <- conditional_effects(fit, "zBase:Trt")
#' plot(ce, plot = FALSE)[[1]] +
#'   scale_color_grey() +
#'   scale_fill_grey()
#'
#' ## only plot the conditional interaction effect of 'zBase:Trt'
#' ## for different values for 'zAge'
#' conditions <- data.frame(zAge = c(-1, 0, 1))
#' plot(conditional_effects(fit, effects = "zBase:Trt",
#'                          conditions = conditions))
#'
#' ## also incorporate the group-level variance over patients
#' ## and add data points and a rug representation of predictor values
#' plot(conditional_effects(fit, effects = "zBase:Trt",
#'                          conditions = conditions, re_formula = NULL),
#'      points = TRUE, rug = TRUE)
#'
#' ## change handling of two-way interactions
#' int_conditions <- list(
#'   zBase = setNames(c(-2, 1, 0), c("b", "c", "a"))
#' )
#' conditional_effects(fit, effects = "Trt:zBase",
#'                     int_conditions = int_conditions)
#' conditional_effects(fit, effects = "Trt:zBase",
#'                     int_conditions = list(zBase = quantile))
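#'
#' ## a sketch of a spaghetti plot for a single numeric predictor;
#' ## 'ndraws = 100' is an arbitrary choice to limit computation time
#' ce <- conditional_effects(fit, effects = "zBase",
#'                           spaghetti = TRUE, ndraws = 100)
#' plot(ce)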
#'
#' ## fit a model to illustrate how to plot 3-way interactions
#' fit3way <- brm(count ~ zAge * zBase * Trt, data = epilepsy)
#' conditions <- make_conditions(fit3way, "zAge")
#' conditional_effects(fit3way, "zBase:Trt", conditions = conditions)
#' ## only include points close to the specified values of zAge
#' ce <- conditional_effects(
#'   fit3way, "zBase:Trt", conditions = conditions,
#'   select_points = 0.1
#' )
#' plot(ce, points = TRUE)
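#'
#' ## a sketch of a surface plot for two numeric predictors; the
#' ## interaction 'zAge:zBase' is not part of the model formula but can
#' ## still be requested manually via 'effects'. The value of 'too_far'
#' ## is an arbitrary choice for illustration and excludes grid points
#' ## that are far away from the observed data
#' ce <- conditional_effects(fit, effects = "zAge:zBase",
#'                           surface = TRUE, too_far = 0.2)
#' plot(ce, stype = "raster")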
#' }
#'
#' @export
conditional_effects.brmsfit <- function(x, effects = NULL, conditions = NULL,
                                        int_conditions = NULL, re_formula = NA,
                                        prob = 0.95, robust = TRUE,
                                        method = "posterior_epred",
                                        spaghetti = FALSE, surface = FALSE,
                                        categorical = FALSE, ordinal = FALSE,
                                        transform = NULL, resolution = 100,
                                        select_points = 0, too_far = 0,
                                        probs = NULL, ...) {
  probs <- validate_ci_bounds(prob, probs = probs)
  method <- validate_pp_method(method)
  spaghetti <- as_one_logical(spaghetti)
  surface <- as_one_logical(surface)
  categorical <- as_one_logical(categorical)
  ordinal <- as_one_logical(ordinal)
  contains_draws(x)
  x <- restructure(x)
  new_formula <- update_re_terms(x$formula, re_formula = re_formula)
  bterms <- brmsterms(new_formula)

  if (!is.null(transform) && method != "posterior_predict") {
    stop2("'transform' is only allowed if 'method = posterior_predict'.")
  }
  if (ordinal) {
    warning2("Argument 'ordinal' is deprecated. ",
             "Please use 'categorical' instead.")
  }
  rsv_vars <- rsv_vars(bterms)
  use_def_effects <- is.null(effects)
  if (use_def_effects) {
    effects <- get_all_effects(bterms, rsv_vars = rsv_vars)
  } else {
    # allow to define interactions in any order
    effects <- strsplit(as.character(effects), split = ":")
    if (any(unique(unlist(effects)) %in% rsv_vars)) {
      stop2("Variables ", collapse_comma(rsv_vars),
            " should not be used as effects for this model")
    }
    if (any(lengths(effects) > 2L)) {
      stop2("To display interactions of order higher than 2 ",
            "please use the 'conditions' argument.")
    }
    all_effects <- get_all_effects(
      bterms, rsv_vars = rsv_vars, comb_all = TRUE
    )
    ae_coll <- all_effects[lengths(all_effects) == 1L]
    ae_coll <- ulapply(ae_coll, paste, collapse = ":")
    matches <- match(lapply(all_effects, sort), lapply(effects, sort), 0L)
    if (sum(matches) > 0 && sum(matches > 0) < length(effects)) {
      invalid <- effects[setdiff(seq_along(effects), sort(matches))]
      invalid <- ulapply(invalid, paste, collapse = ":")
      warning2(
        "Some specified effects are invalid for this model: ",
        collapse_comma(invalid), "\nValid effects are ",
        "(combinations of): ", collapse_comma(ae_coll)
      )
    }
    effects <- unique(effects[sort(matches)])
    if (!length(effects)) {
      stop2(
        "All specified effects are invalid for this model.\n",
        "Valid effects are (combinations of): ",
        collapse_comma(ae_coll)
      )
    }
  }
  if (categorical || ordinal) {
    int_effs <- lengths(effects) == 2L
    if (any(int_effs)) {
      effects <- effects[!int_effs]
      warning2(
        "Interactions cannot be plotted directly if 'categorical' ",
        "is TRUE. Please use argument 'conditions' instead."
      )
    }
  }
  if (!length(effects)) {
    stop2("No valid effects detected.")
  }
  mf <- model.frame(x)
  conditions <- prepare_conditions(
    x, conditions = conditions, effects = effects,
    re_formula = re_formula, rsv_vars = rsv_vars
  )
  int_conditions <- lapply(int_conditions,
    function(x) if (is.numeric(x)) sort(x, TRUE) else x
  )
  int_vars <- get_int_vars(bterms)
  group_vars <- get_group_vars(bterms)
  out <- list()
  for (i in seq_along(effects)) {
    eff <- effects[[i]]
    cond_data <- prepare_cond_data(
      mf[, eff, drop = FALSE], conditions = conditions,
      int_conditions = int_conditions, int_vars = int_vars,
      group_vars = group_vars, surface = surface,
      resolution = resolution, reorder = use_def_effects
    )
    if (surface && length(eff) == 2L && too_far > 0) {
      # exclude prediction grid points too far from data
      ex_too_far <- mgcv::exclude.too.far(
        g1 = cond_data[[eff[1]]],
        g2 = cond_data[[eff[2]]],
        d1 = mf[, eff[1]],
        d2 = mf[, eff[2]],
        dist = too_far)
      cond_data <- cond_data[!ex_too_far, ]
    }
    c(out) <- conditional_effects(
      bterms, fit = x, cond_data = cond_data, method = method,
      surface = surface, spaghetti = spaghetti, categorical = categorical,
      ordinal = ordinal, re_formula = re_formula, transform = transform,
      conditions = conditions, int_conditions = int_conditions,
      select_points = select_points, probs = probs, robust = robust,
      ...
    )
  }
  structure(out, class = "brms_conditional_effects")
}

#' @rdname conditional_effects.brmsfit
#' @export
conditional_effects <- function(x, ...) {
  UseMethod("conditional_effects")
}

# compute expected values of MV models for use in conditional_effects
# @return a list of summarized prediction matrices
#' @export
conditional_effects.mvbrmsterms <- function(x, resp = NULL, ...) {
  resp <- validate_resp(resp, x$responses)
  x$terms <- x$terms[resp]
  out <- lapply(x$terms, conditional_effects, ...)
  unlist(out, recursive = FALSE)
}

# conditional_effects for univariate model
# @return a list with the summarized prediction matrix as the only element
# @note argument 'resp' exists only to be excluded from '...' (#589)
#' @export
conditional_effects.brmsterms <- function(
  x, fit, cond_data, int_conditions, method, surface,
  spaghetti, categorical, ordinal, probs, robust,
  dpar = NULL, nlpar = NULL, resp = NULL, ...
) {
  stopifnot(is.brmsfit(fit))
  effects <- attr(cond_data, "effects")
  types <- attr(cond_data, "types")
  catscale <- NULL
  pred_args <- list(
    fit, newdata = cond_data, allow_new_levels = TRUE,
    dpar = dpar, nlpar = nlpar, resp = if (nzchar(x$resp)) x$resp,
    incl_autocor = FALSE, ...
  )
  if (method != "posterior_predict") {
    # 'transform' creates problems in 'posterior_linpred'
    pred_args$transform <- NULL
  }
  out <- do_call(method, pred_args)
  rownames(cond_data) <- NULL

  if (categorical || ordinal) {
    if (method != "posterior_epred") {
      stop2("Can only use 'categorical' with method = 'posterior_epred'.")
    }
    if (!is_polytomous(x)) {
      stop2("Argument 'categorical' may only be used ",
            "for categorical or ordinal models.")
    }
    if (categorical && ordinal) {
      stop2("Please use argument 'categorical' instead of 'ordinal'.")
    }
    catscale <- str_if(is_multinomial(x), "Count", "Probability")
    cats <- dimnames(out)[[3]]
    if (is.null(cats)) cats <- seq_dim(out, 3)
    cond_data <- repl(cond_data, length(cats))
    cond_data <- do_call(rbind, cond_data)
    cond_data$cats__ <- factor(rep(cats, each = ncol(out)), levels = cats)
    effects[2] <- "cats__"
    types[2] <- "factor"
  } else {
    if (conv_cats_dpars(x$family) && is.null(dpar)) {
      stop2("Please set 'categorical' to TRUE.")
    }
    if (is_ordinal(x$family) && is.null(dpar) && method != "posterior_linpred") {
      warning2(
        "Predictions are treated as continuous variables in ",
        "'conditional_effects' by default which is likely invalid ",
        "for ordinal families. Please set 'categorical' to TRUE."
      )
      if (method == "posterior_epred") {
        out <- ordinal_probs_continuous(out)
      }
    }
  }

  cond_data <- add_effects__(cond_data, effects)
  first_numeric <- types[1] %in% "numeric"
  second_numeric <- types[2] %in% "numeric"
  both_numeric <- first_numeric && second_numeric
  if (second_numeric && !surface) {
    # only convert 'effect2__' to factor so that the original
    # second effect variable remains unchanged in the data
    mde2 <- round(cond_data[[effects[2]]], 2)
    levels2 <- sort(unique(mde2), TRUE)
    cond_data$effect2__ <- factor(mde2, levels = levels2)
    labels2 <- names(int_conditions[[effects[2]]])
    if (length(labels2) == length(levels2)) {
      levels(cond_data$effect2__) <- labels2
    }
  }

  spag <- NULL
  if (first_numeric && spaghetti) {
    if (surface) {
      stop2("Cannot use 'spaghetti' and 'surface' at the same time.")
    }
    spag <- out
    if (categorical) {
      spag <- do_call(cbind, array2list(spag))
    }
    sample <- rep(seq_rows(spag), each = ncol(spag))
    if (length(types) == 2L) {
      # draws should be unique across plotting groups
      sample <- paste0(sample, "_", cond_data[[effects[2]]])
    }
    spag <- data.frame(as.numeric(t(spag)), factor(sample))
    colnames(spag) <- c("estimate__", "sample__")
    # ensures that 'cbind' works even in the presence of matrix columns
    cond_data_spag <- repl(cond_data, nrow(spag) / nrow(cond_data))
    cond_data_spag <- Reduce(rbind, cond_data_spag)
    spag <- cbind(cond_data_spag, spag)
  }

  out <- posterior_summary(out, probs = probs, robust = robust)
  if (categorical || ordinal) {
    out <- do_call(rbind, array2list(out))
  }
  colnames(out) <- c("estimate__", "se__", "lower__", "upper__")
  out <- cbind(cond_data, out)
  if (!is.null(dpar)) {
    response <- dpar
  } else if (!is.null(nlpar)) {
    response <- nlpar
  } else {
    response <- as.character(x$formula[2])
  }
  attr(out, "effects") <- effects
  attr(out, "response") <- response
  attr(out, "surface") <- unname(both_numeric && surface)
  attr(out, "categorical") <- categorical
  attr(out, "catscale") <- catscale
  attr(out, "ordinal") <- ordinal
  attr(out, "spaghetti") <- spag
  attr(out, "points") <- make_point_frame(x, fit$data, effects, ...)
  name <- paste0(usc(x$resp, "suffix"), paste0(effects, collapse = ":"))
  setNames(list(out), name)
}

# get combinations of variables used in predictor terms
# @param ... character vectors or formulas
# @param alist a list of character vectors or formulas
get_var_combs <- function(..., alist = list()) {
  dots <- c(list(...), alist)
  for (i in seq_along(dots)) {
    if (is.formula(dots[[i]])) {
      dots[[i]] <- attr(terms(dots[[i]]), "term.labels")
    }
    dots[[i]] <- lapply(dots[[i]], all_vars)
  }
  unique(unlist(dots, recursive = FALSE))
}

# extract combinations of predictor variables
get_all_effects <- function(x, ...) {
  UseMethod("get_all_effects")
}

#' @export
get_all_effects.default <- function(x, ...) {
  NULL
}

#' @export
get_all_effects.mvbrmsterms <- function(x, ...) {
  out <- lapply(x$terms, get_all_effects, ...)
  unique(unlist(out, recursive = FALSE))
}

# get all effects for use in conditional_effects
# @param bterms object of class brmsterms
# @param rsv_vars character vector of reserved variables
# @param comb_all include all main effects and two-way interactions?
# @return a list with one element per valid effect / effects combination
#   excludes all 3-way or higher interactions
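# Example (hypothetical variable names, for illustration only):
#   for a model whose only predictor variables are 'a' and 'b' and that
#   has no grouping variables, 'comb_all = TRUE' should yield
#   list("a", "b", c("a", "b")), even if the interaction 'a:b' was
#   not part of the model formula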
#' @export
get_all_effects.brmsterms <- function(x, rsv_vars = NULL, comb_all = FALSE, ...) {
  stopifnot(is_atomic_or_null(rsv_vars))
  out <- list()
  for (dp in names(x$dpars)) {
    out <- c(out, get_all_effects(x$dpars[[dp]]))
  }
  for (nlp in names(x$nlpars)) {
    out <- c(out, get_all_effects(x$nlpars[[nlp]]))
  }
  out <- rmNULL(lapply(out, setdiff, y = rsv_vars))
  if (comb_all) {
    # allow to combine all variables with each other
    out <- unique(unlist(out))
    out <- c(out, get_group_vars(x))
    if (length(out)) {
      int <- expand.grid(out, out, stringsAsFactors = FALSE)
      int <- int[int[, 1] != int[, 2], ]
      int <- as.list(as.data.frame(t(int), stringsAsFactors = FALSE))
      int <- unique(unname(lapply(int, sort)))
      out <- c(as.list(out), int)
    }
  }
  unique(out[lengths(out) <= 2L])
}

#' @export
get_all_effects.btl <- function(x, ...) {
  c(get_var_combs(x[["fe"]], x[["cs"]]),
    get_all_effects_type(x, "sp"),
    get_all_effects_type(x, "sm"),
    get_all_effects_type(x, "gp"))
}

# extract variable combinations from special terms
get_all_effects_type <- function(x, type) {
  stopifnot(is.btl(x))
  type <- as_one_character(type)
  regex_type <- regex_sp(type)
  terms <- all_terms(x[[type]])
  out <- named_list(terms)
  for (i in seq_along(terms)) {
    # some special terms can appear within interactions
    # we did not allow ":" within these terms so we can use it for splitting
    term_parts <- unlist(strsplit(terms[i], split = ":"))
    vars <- vector("list", length(term_parts))
    for (j in seq_along(term_parts)) {
      matches <- get_matches_expr(regex_type, term_parts[j])
      for (k in seq_along(matches)) {
        # evaluate special terms to extract variables
        tmp <- eval2(matches[[k]])
        c(vars[[j]]) <- setdiff(unique(c(tmp$term, tmp$by)), "NA")
      }
      # extract all variables not part of any special term
      c(vars[[j]]) <- setdiff(all_vars(term_parts[j]), all_vars(matches))
    }
    vars <- unique(unlist(vars))
    out[[i]] <- str2formula(vars, collapse = "*")
  }
  get_var_combs(alist = out)
}

#' @export
get_all_effects.btnl <- function(x, ...) {
  covars <- all_vars(rhs(x$covars))
  out <- as.list(covars)
  if (length(covars) > 1L) {
    c(out) <- utils::combn(covars, 2, simplify = FALSE)
  }
  unique(out)
}

# extract names of predictor variables
get_pred_vars <- function(x) {
  unique(unlist(get_all_effects(x)))
}

# extract names of variables treated as integers
get_int_vars <- function(x, ...) {
  UseMethod("get_int_vars")
}

#' @export
get_int_vars.mvbrmsterms <- function(x, ...) {
  unique(ulapply(x$terms, get_int_vars))
}

#' @export
get_int_vars.brmsterms <- function(x, ...) {
  adterms <- c("trials", "thres", "vint")
  advars <- ulapply(rmNULL(x$adforms[adterms]), all_vars)
  unique(c(advars, get_sp_vars(x, "mo")))
}

# transform posterior draws of ordinal probabilities to a
# continuous scale assuming equidistance between adjacent categories
# @param x an ndraws x nobs x ncat array of posterior draws
# @return an ndraws x nobs matrix of posterior draws
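# Worked example (made-up probabilities, for illustration only):
#   each category probability is weighted by its category index, so that
#   probabilities c(0.2, 0.5, 0.3) give 1 * 0.2 + 2 * 0.5 + 3 * 0.3 = 2.1
#   and probabilities c(0.1, 0.3, 0.6) give 1 * 0.1 + 2 * 0.3 + 3 * 0.6 = 2.5
#   A sketch with two draws and one observation:
#     p <- array(c(0.2, 0.1, 0.5, 0.3, 0.3, 0.6), dim = c(2, 1, 3))
#     ordinal_probs_continuous(p)  # 2 x 1 matrix with values 2.1 and 2.5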
ordinal_probs_continuous <- function(x) {
  stopifnot(length(dim(x)) == 3)
  for (k in seq_dim(x, 3)) {
    x[, , k] <- x[, , k] * k
  }
  x <- lapply(seq_dim(x, 2), function(s) rowSums(x[, s, ]))
  do_call(cbind, x)
}

#' Prepare Fully Crossed Conditions
#'
#' This is a helper function to prepare fully crossed conditions primarily
#' for use with the \code{conditions} argument of \code{\link{conditional_effects}}.
#' Automatically creates labels for each row in the \code{cond__} column.
#'
#' @param x An \R object from which to extract the variables
#'   that should be part of the conditions.
#' @param vars Names of the variables that should be part of the conditions.
#' @param ... Arguments passed to \code{\link{rows2labels}}.
#'
#' @return A \code{data.frame} where each row indicates a condition.
#'
#' @details For factor-like variables, all levels are used as conditions.
#'   For numeric variables, \code{mean + (-1:1) * SD} are used as conditions.
#'
#' @seealso \code{\link{conditional_effects}}, \code{\link{rows2labels}}
#'
#' @examples
#' df <- data.frame(x = c("a", "b"), y = rnorm(10))
#' make_conditions(df, vars = c("x", "y"))
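#'
#' ## arguments in '...' are passed on to rows2labels(), for instance
#' ## to round the values shown in the 'cond__' labels to one decimal place
#' make_conditions(df, vars = c("x", "y"), digits = 1)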
#'
#' @export
make_conditions <- function(x, vars, ...) {
  # rev ensures that the last variable varies fastest in expand.grid
  vars <- rev(as.character(vars))
  if (!is.data.frame(x) && "data" %in% names(x)) {
    x <- x$data
  }
  x <- as.data.frame(x)
  out <- named_list(vars)
  for (v in vars) {
    tmp <- get(v, x)
    if (is_like_factor(tmp)) {
      tmp <- levels(as.factor(tmp))
    } else {
      tmp <- mean(tmp, na.rm = TRUE) + (-1:1) * sd(tmp, na.rm = TRUE)
    }
    out[[v]] <- tmp
  }
  out <- rev(expand.grid(out))
  out$cond__ <- rows2labels(out, ...)
  out
}

# extract the cond__ variable used for faceting
get_cond__ <- function(x) {
  out <- x[["cond__"]]
  if (is.null(out)) {
    out <- rownames(x)
  }
  as.character(out)
}

#' Convert Rows to Labels
#'
#' Convert information in rows to labels for each row.
#'
#' @param x A \code{data.frame} for which to extract labels.
#' @param digits Minimal number of decimal places shown in
#'   the labels of numeric variables.
#' @param sep A single character string defining the separator
#'   between variables used in the labels.
#' @param incl_vars Indicates if variable names should
#'   be part of the labels. Defaults to \code{TRUE}.
#' @param ... Currently unused.
#'
#' @return A character vector of the same length as the number
#'   of rows of \code{x}.
#'
#' @seealso \code{\link{make_conditions}}, \code{\link{conditional_effects}}
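#'
#' @examples
#' ## a minimal sketch using a small made-up data.frame
#' df <- data.frame(x = c("a", "b"), y = c(0.123, 4.567))
#' ## labels should be of the form "x = a & y = 0.12"
#' rows2labels(df)
#' ## omit variable names, round to one decimal place,
#' ## and use a different separator
#' rows2labels(df, digits = 1, sep = " | ", incl_vars = FALSE)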
#'
#' @export
rows2labels <- function(x, digits = 2, sep = " & ", incl_vars = TRUE, ...) {
  x <- as.data.frame(x)
  incl_vars <- as_one_logical(incl_vars)
  out <- x
  for (i in seq_along(out)) {
    if (!is_like_factor(out[[i]])) {
      out[[i]] <- round(out[[i]], digits)
    }
    if (incl_vars) {
      out[[i]] <- paste0(names(out)[i], " = ", out[[i]])
    }
  }
  paste_sep <- function(..., sep__ = sep) {
    paste(..., sep = sep__)
  }
  Reduce(paste_sep, out)
}

# prepare conditions for use in conditional_effects
# @param fit an object of class 'brmsfit'
# @param conditions optional data.frame containing user defined conditions
# @param effects see conditional_effects
# @param re_formula see conditional_effects
# @param rsv_vars names of reserved variables
# @return a data.frame with (possibly updated) conditions
prepare_conditions <- function(fit, conditions = NULL, effects = NULL,
                               re_formula = NA, rsv_vars = NULL) {
  mf <- model.frame(fit)
  new_formula <- update_re_terms(fit$formula, re_formula = re_formula)
  bterms <- brmsterms(new_formula)
  if (any(grepl_expr("^(as\\.)?factor(.+)$", bterms$allvars))) {
    # conditions are chosen based on the variables stored in the data
    # this approach cannot take into account possible transformations
    # to factors happening inside the model formula
    warning2(
      "Using 'factor' or 'as.factor' in the model formula ",
      "might lead to problems in 'conditional_effects'. ",
      "Please convert your variables to factors beforehand."
    )
  }
  req_vars <- all_vars(rhs(bterms$allvars))
  req_vars <- setdiff(req_vars, rsv_vars)
  req_vars <- setdiff(req_vars, names(fit$data2))
  if (is.null(conditions)) {
    conditions <- as.data.frame(as.list(rep(NA, length(req_vars))))
    names(conditions) <- req_vars
  } else {
    conditions <- as.data.frame(conditions)
    if (!nrow(conditions)) {
      stop2("Argument 'conditions' must have at least one row.")
    }
    conditions <- unique(conditions)
    if (any(duplicated(get_cond__(conditions)))) {
      stop2("Condition labels should be unique.")
    }
    req_vars <- setdiff(req_vars, names(conditions))
  }
  # special treatment for 'trials' addition variables
  trial_vars <- all_vars(bterms$adforms$trials)
  trial_vars <- trial_vars[!vars_specified(trial_vars, conditions)]
  if (length(trial_vars)) {
    message("Setting all 'trials' variables to 1 by ",
            "default if not specified otherwise.")
    req_vars <- setdiff(req_vars, trial_vars)
    for (v in trial_vars) {
      conditions[[v]] <- 1L
    }
  }
  # use sensible default values for unspecified variables
  subset_vars <- get_ad_vars(bterms, "subset")
  int_vars <- get_int_vars(bterms)
  group_vars <- get_group_vars(bterms)
  req_vars <- setdiff(req_vars, group_vars)
  for (v in req_vars) {
    if (is_like_factor(mf[[v]])) {
      # factor-like variable
      if (v %in% subset_vars) {
        # avoid unintentional subsetting of newdata (#755)
        conditions[[v]] <- TRUE
      } else {
        # use reference category for factors
        levels <- levels(as.factor(mf[[v]]))
        ordered <- is.ordered(mf[[v]])
        conditions[[v]] <- factor(levels[1], levels, ordered = ordered)
      }
    } else {
      # numeric-like variable
      if (v %in% subset_vars) {
        # avoid unintentional subsetting of newdata (#755)
        conditions[[v]] <- 1
      } else if (v %in% int_vars) {
        # ensure valid integer values
        conditions[[v]] <- round(median(mf[[v]], na.rm = TRUE))
      } else {
        conditions[[v]] <- mean(mf[[v]], na.rm = TRUE)
      }
    }
  }
  all_vars <- c(all_vars(bterms$allvars), "cond__")
  unused_vars <- setdiff(names(conditions), all_vars)
  if (length(unused_vars)) {
    warning2(
      "The following variables in 'conditions' are not ",
      "part of the model:\n", collapse_comma(unused_vars)
    )
  }
  cond__ <- conditions$cond__
  conditions <- validate_newdata(
    conditions, fit, re_formula = re_formula,
    allow_new_levels = TRUE, check_response = FALSE,
    incl_autocor = FALSE
  )
  conditions$cond__ <- cond__
  conditions
}
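
# Illustrative sketch (not part of brms): the default-value rule applied in
# prepare_conditions() above, reduced to a single vector. The helper name and
# the 'is_int' argument are hypothetical, and factor-likeness is approximated
# with base R checks, which may differ from the internal is_like_factor().
sketch_default_value <- function(x, is_int = FALSE) {
  if (is.factor(x) || is.character(x) || is.logical(x)) {
    # reference category: first level, keeping all levels and the ordering
    lev <- levels(as.factor(x))
    return(factor(lev[1], lev, ordered = is.ordered(x)))
  }
  if (is_int) {
    # integer-valued predictors get a rounded median
    return(round(median(x, na.rm = TRUE)))
  }
  # all other numeric predictors are set to their mean
  mean(x, na.rm = TRUE)
}
# possible usage:
#   sketch_default_value(c(1, 2, 6))                 # mean of the values
#   sketch_default_value(c(1, 2, 6), is_int = TRUE)  # rounded median
#   sketch_default_value(c("b", "a", "b"))           # first level as a factor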

# prepare data to be used in conditional_effects
# @param data data.frame containing only data of the predictors of interest
# @param conditions see argument 'conditions' of conditional_effects
# @param int_conditions see argument 'int_conditions' of conditional_effects
# @param int_vars names of variables being treated as integers
# @param group_vars names of grouping variables
# @param surface generate surface plots later on?
# @param resolution number of distinct points at which to evaluate
#   the predictors of interest
# @param reorder reorder predictors so that numeric ones come first?
prepare_cond_data <- function(data, conditions, int_conditions = NULL,
                              int_vars = NULL, group_vars = NULL,
                              surface = FALSE, resolution = 100,
                              reorder = TRUE) {
  effects <- names(data)
  stopifnot(length(effects) %in% c(1L, 2L))
  is_factor <- ulapply(data, is_like_factor) | names(data) %in% group_vars
  types <- ifelse(is_factor, "factor", "numeric")
  # numeric effects should come first
  if (reorder) {
    new_order <- order(types, decreasing = TRUE)
    effects <- effects[new_order]
    types <- types[new_order]
  }
  # handle first predictor
  if (effects[1] %in% names(int_conditions)) {
    # first predictor has pre-specified conditions
    int_cond <- int_conditions[[effects[1]]]
    if (is.function(int_cond)) {
      int_cond <- int_cond(data[[effects[1]]])
    }
    values <- int_cond
  } else if (types[1] == "factor") {
    # first predictor is factor-like
    values <- factor(unique(data[[effects[1]]]))
  } else {
    # first predictor is numeric
    min1 <- min(data[[effects[1]]], na.rm = TRUE)
    max1 <- max(data[[effects[1]]], na.rm = TRUE)
    if (effects[1] %in% int_vars) {
      values <- seq(min1, max1, by = 1)
    } else {
      values <- seq(min1, max1, length.out = resolution)
    }
  }
  if (length(effects) == 2L) {
    # handle second predictor
    values <- setNames(list(values, NA), effects)
    if (effects[2] %in% names(int_conditions)) {
      # second predictor has pre-specified conditions
      int_cond <- int_conditions[[effects[2]]]
      if (is.function(int_cond)) {
        int_cond <- int_cond(data[[effects[2]]])
      }
      values[[2]] <- int_cond
    } else if (types[2] == "factor") {
      # second predictor is factor-like
      values[[2]] <- factor(unique(data[[effects[2]]]))
    } else {
      # second predictor is numeric
      if (surface) {
        min2 <- min(data[[effects[2]]], na.rm = TRUE)
        max2 <- max(data[[effects[2]]], na.rm = TRUE)
        if (effects[2] %in% int_vars) {
          values[[2]] <- seq(min2, max2, by = 1)
        } else {
          values[[2]] <- seq(min2, max2, length.out = resolution)
        }
      } else {
        if (effects[2] %in% int_vars) {
          median2 <- median(data[[effects[2]]], na.rm = TRUE)
          mad2 <- mad(data[[effects[2]]], na.rm = TRUE)
          values[[2]] <- round((-1:1) * mad2 + median2)
        } else {
          mean2 <- mean(data[[effects[2]]], na.rm = TRUE)
          sd2 <- sd(data[[effects[2]]], na.rm = TRUE)
          values[[2]] <- (-1:1) * sd2 + mean2
        }
      }
    }
    data <- do_call(expand.grid, values)
  } else {
    stopifnot(length(effects) == 1L)
    data <- structure(data.frame(values), names = effects)
  }
  # no need to have the same value combination more than once
  data <- unique(data)
  data <- data[do_call(order, unname(as.list(data))), , drop = FALSE]
  data <- replicate(nrow(conditions), data, simplify = FALSE)
  cond_vars <- setdiff(names(conditions), effects)
  cond__ <- get_cond__(conditions)
  for (j in seq_rows(conditions)) {
    data[[j]] <- fill_newdata(data[[j]], cond_vars, conditions, n = j)
    data[[j]]$cond__ <- cond__[j]
  }
  data <- do_call(rbind, data)
  data$cond__ <- factor(data$cond__, cond__)
  structure(data, effects = effects, types = types)
}
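
# Illustrative sketch (not part of brms): how prepare_cond_data() above builds
# the grid for two continuous, non-integer predictors when 'surface' is FALSE.
# The first predictor is spread over its observed range, the second is fixed
# at mean - sd, mean, and mean + sd. The helper and its column names 'x1' and
# 'x2' are hypothetical.
sketch_effect_grid <- function(x1, x2, resolution = 100) {
  values <- list(
    x1 = seq(min(x1, na.rm = TRUE), max(x1, na.rm = TRUE),
             length.out = resolution),
    x2 = (-1:1) * sd(x2, na.rm = TRUE) + mean(x2, na.rm = TRUE)
  )
  # all combinations, without duplicates, sorted by x1 and then x2
  grid <- unique(do.call(expand.grid, values))
  grid[do.call(order, unname(as.list(grid))), , drop = FALSE]
}
# possible usage:
#   sketch_effect_grid(c(0, 1), c(1, 2, 3), resolution = 5)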

# which variables in 'vars' are specified in 'data'?
vars_specified <- function(vars, data) {
  .fun <- function(v) isTRUE(v %in% names(data)) && any(!is.na(data[[v]]))
  as.logical(ulapply(vars, .fun))
}

# prepare data points based on the provided conditions
# allows adding data points to conditional effects plots
# @return a data.frame containing the data points to be plotted
make_point_frame <- function(bterms, mf, effects, conditions,
                             select_points = 0, transform = NULL, ...) {
  stopifnot(is.brmsterms(bterms), is.data.frame(mf))
  effects <- intersect(effects, names(mf))
  points <- mf[, effects, drop = FALSE]
  points$resp__ <- model.response(
    model.frame(bterms$respform, mf, na.action = na.pass)
  )
  req_vars <- names(mf)
  groups <- get_re_groups(bterms)
  if (length(groups)) {
    c(req_vars) <- unlist(strsplit(groups, ":"))
  }
  req_vars <- unique(setdiff(req_vars, effects))
  req_vars <- intersect(req_vars, names(conditions))
  if (length(req_vars)) {
    # find out which data point is valid for which condition
    cond__ <- get_cond__(conditions)
    mf <- mf[, req_vars, drop = FALSE]
    conditions <- conditions[, req_vars, drop = FALSE]
    points$cond__ <- NA
    points <- replicate(nrow(conditions), points, simplify = FALSE)
    for (i in seq_along(points)) {
      cond <- conditions[i, , drop = FALSE]
      # ensures correct handling of matrix columns
      not_na <- function(x) !any(is.na(x) | x %in% "zero__")
      not_na <- ulapply(cond, not_na)
      cond <- cond[, not_na, drop = FALSE]
      mf_tmp <- mf[, not_na, drop = FALSE]
      if (ncol(mf_tmp)) {
        is_num <- sapply(mf_tmp, is.numeric)
        is_num <- is_num & !names(mf_tmp) %in% groups
        if (sum(is_num)) {
          # handle numeric variables
          stopifnot(select_points >= 0)
          if (select_points > 0) {
            for (v in names(mf_tmp)[is_num]) {
              min <- min(mf_tmp[, v], na.rm = TRUE)
              max <- max(mf_tmp[, v], na.rm = TRUE)
              unit <- scale_unit(mf_tmp[, v], min, max)
              unit_cond <- scale_unit(cond[, v], min, max)
              unit_diff <- abs(unit - unit_cond)
              close_enough <- unit_diff <= select_points
              mf_tmp[[v]][close_enough] <- cond[, v]
              mf_tmp[[v]][!close_enough] <- NA
            }
          } else {
            # take all numeric values if select_points is zero
            cond <- cond[, !is_num, drop = FALSE]
            mf_tmp <- mf_tmp[, !is_num, drop = FALSE]
          }
        }
      }
      if (ncol(mf_tmp)) {
        # handle factors and grouping variables
        # do it like base::duplicated
        K <- do_call("paste", c(mf_tmp, sep = "\r")) %in%
          do_call("paste", c(cond, sep = "\r"))
      } else {
        K <- seq_rows(mf)
      }
      # cond__ allows assigning points to conditions
      points[[i]]$cond__[K] <- cond__[i]
    }
    points <- do_call(rbind, points)
    points <- points[!is.na(points$cond__), , drop = FALSE]
    points$cond__ <- factor(points$cond__, cond__)
  }
  points <- add_effects__(points, effects)
  if (!is.numeric(points$resp__)) {
    points$resp__ <- as.numeric(as.factor(points$resp__))
    if (is_binary(bterms$family)) {
      points$resp__ <- points$resp__ - 1
    }
  }
  if (!is.null(transform)) {
    points$resp__ <- do_call(transform, list(points$resp__))
  }
  points
}
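
# Illustrative sketch (not part of brms): the 'select_points' rule used in
# make_point_frame() above for a single numeric variable. A data point is
# kept for a condition if, after scaling to the unit interval, it lies within
# 'select_points' of the condition value. The helper name is hypothetical and
# the min-max scaling is an assumption about what scale_unit() does.
sketch_close_enough <- function(x, cond_value, select_points) {
  stopifnot(select_points >= 0)
  lower <- min(x, na.rm = TRUE)
  upper <- max(x, na.rm = TRUE)
  to_unit <- function(v) (v - lower) / (upper - lower)
  abs(to_unit(x) - to_unit(cond_value)) <= select_points
}
# possible usage:
#   sketch_close_enough(c(0, 5, 10), cond_value = 5, select_points = 0.1)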

# add effect<i>__ variables to the data
add_effects__ <- function(data, effects) {
  for (i in seq_along(effects)) {
    data[[paste0("effect", i, "__")]] <- eval2(effects[i], data)
  }
  data
}

#' @export
print.brms_conditional_effects <- function(x, ...) {
  plot(x, ...)
}

#' @rdname conditional_effects.brmsfit
#' @method plot brms_conditional_effects
#' @importFrom rlang .data
#' @export
plot.brms_conditional_effects <- function(
  x, ncol = NULL, points = getOption("brms.plot_points", FALSE),
  rug = getOption("brms.plot_rug", FALSE), mean = TRUE,
  jitter_width = 0, stype = c("contour", "raster"),
  line_args = list(), cat_args = list(), errorbar_args = list(),
  surface_args = list(), spaghetti_args = list(), point_args = list(),
  rug_args = list(), facet_args = list(), theme = NULL, ask = TRUE,
  plot = TRUE, ...
) {
  dots <- list(...)
  plot <- use_alias(plot, dots$do_plot)
  stype <- match.arg(stype)
  smooths_only <- isTRUE(attr(x, "smooths_only"))
  if (points && smooths_only) {
    stop2("Argument 'points' is invalid for objects ",
          "returned by 'conditional_smooths'.")
  }
  if (!is_equal(jitter_width, 0)) {
    warning2("'jitter_width' is deprecated. Please use ",
             "'point_args = list(width = <width>)' instead.")
  }
  if (!is.null(theme) && !is.theme(theme)) {
    stop2("Argument 'theme' should be a 'theme' object.")
  }
  if (plot) {
    default_ask <- devAskNewPage()
    on.exit(devAskNewPage(default_ask))
    devAskNewPage(ask = FALSE)
  }
  dont_replace <- c("mapping", "data", "inherit.aes")
  plots <- named_list(names(x))
  for (i in seq_along(x)) {
    response <- attr(x[[i]], "response")
    effects <- attr(x[[i]], "effects")
    ncond <- length(unique(x[[i]]$cond__))
    df_points <- attr(x[[i]], "points")
    categorical <- isTRUE(attr(x[[i]], "categorical"))
    catscale <- attr(x[[i]], "catscale")
    surface <- isTRUE(attr(x[[i]], "surface"))
    # deprecated as of brms 2.4.3
    ordinal <- isTRUE(attr(x[[i]], "ordinal"))
    if (surface || ordinal) {
      # surface plots for two dimensional interactions or ordinal plots
      plots[[i]] <- ggplot(x[[i]]) +
        aes(.data[["effect1__"]], .data[["effect2__"]]) +
        labs(x = effects[1], y = effects[2])
      if (ordinal) {
        width <- ifelse(is_like_factor(x[[i]]$effect1__), 0.9, 1)
        .surface_args <- nlist(
          mapping = aes(fill = .data[["estimate__"]]),
          height = 0.9, width = width
        )
        replace_args(.surface_args, dont_replace) <- surface_args
        plots[[i]] <- plots[[i]] +
          do_call(geom_tile, .surface_args) +
          scale_fill_gradientn(colors = viridis6(), name = catscale) +
          ylab(response)
      } else if (stype == "contour") {
        .surface_args <- nlist(
          mapping = aes(
            z = .data[["estimate__"]],
            colour = after_stat(.data[["level"]])
          ),
          bins = 30, linewidth = 1.3
        )
        replace_args(.surface_args, dont_replace) <- surface_args
        plots[[i]] <- plots[[i]] +
          do_call(geom_contour, .surface_args) +
          scale_color_gradientn(colors = viridis6(), name = response)
      } else if (stype == "raster") {
        .surface_args <- nlist(mapping = aes(fill = .data[["estimate__"]]))
        replace_args(.surface_args, dont_replace) <- surface_args
        plots[[i]] <- plots[[i]] +
          do_call(geom_raster, .surface_args) +
          scale_fill_gradientn(colors = viridis6(), name = response)
      }
    } else {
      # plot effects of single predictors or two-way interactions
      gvar <- if (length(effects) == 2L) "effect2__"
      spaghetti <- attr(x[[i]], "spaghetti")
      aes_tmp <- aes(x = .data[["effect1__"]], y = .data[["estimate__"]])
      if (!is.null(gvar)) {
        aes_tmp$colour <- aes(colour = .data[[gvar]])$colour
      }
      plots[[i]] <- ggplot(x[[i]]) + aes_tmp +
        labs(x = effects[1], y = response, colour = effects[2])
      if (is.null(spaghetti)) {
        aes_tmp <- aes(ymin = .data[["lower__"]], ymax = .data[["upper__"]])
        if (!is.null(gvar)) {
          aes_tmp$fill <- aes(fill = .data[[gvar]])$fill
        }
        plots[[i]] <- plots[[i]] + aes_tmp +
          labs(fill = effects[2])
      }
      # extract suggested colors for later use
      colors <- ggplot_build(plots[[i]])
      colors <- unique(colors$data[[1]][["colour"]])
      if (points && !categorical && !surface) {
        # add points first so that they appear behind the predictions
        .point_args <- list(
          mapping = aes(x = .data[["effect1__"]], y = .data[["resp__"]]),
          data = df_points, inherit.aes = FALSE,
          size = 2 / ncond^0.25, height = 0, width = jitter_width
        )
        if (is_like_factor(df_points[, gvar])) {
          .point_args$mapping[c("colour", "fill")] <-
            aes(colour = .data[[gvar]], fill = .data[[gvar]])
        }
        replace_args(.point_args, dont_replace) <- point_args
        plots[[i]] <- plots[[i]] +
          do_call(geom_jitter, .point_args)
      }
      if (!is.null(spaghetti)) {
        # add a regression line for each sample separately
        .spaghetti_args <- list(
          aes(group = .data[["sample__"]]),
          data = spaghetti, stat = "identity", linewidth = 0.5
        )
        if (!is.null(gvar)) {
          .spaghetti_args[[1]]$colour <- aes(colour = .data[[gvar]])$colour
        }
        if (length(effects) == 1L) {
          .spaghetti_args$colour <- alpha("blue", 0.1)
        } else {
          # workaround to get transparent lines
          plots[[i]] <- plots[[i]] +
            scale_color_manual(values = alpha(colors, 0.1))
        }
        replace_args(.spaghetti_args, dont_replace) <- spaghetti_args
        plots[[i]] <- plots[[i]] +
          do_call(geom_smooth, .spaghetti_args)
      }
      if (is.numeric(x[[i]]$effect1__)) {
        # line plots for numeric predictors
        .line_args <- list(stat = "identity")
        if (!is.null(spaghetti)) {
          # display a white mean regression line
          if (!is.null(gvar)) {
            .line_args$mapping <- aes(group = .data[[gvar]])
          }
          .line_args$colour <- alpha("white", 0.8)
        }
        replace_args(.line_args, dont_replace) <- line_args
        if (mean || is.null(spaghetti)) {
          plots[[i]] <- plots[[i]] +
            do_call(geom_smooth, .line_args)
        }
        if (rug) {
          .rug_args <- list(
            aes(x = .data[["effect1__"]]), sides = "b",
            data = df_points, inherit.aes = FALSE
          )
          if (is_like_factor(df_points[, gvar])) {
            # colour the rug by the grouping variable; modify '.rug_args'
            # since '.point_args' only exists if 'points' is TRUE
            .rug_args[[1]]$colour <- aes(colour = .data[[gvar]])$colour
          }
          replace_args(.rug_args, dont_replace) <- rug_args
          plots[[i]] <- plots[[i]] +
            do_call(geom_rug, .rug_args)
        }
      } else {
        # points and errorbars for factors
        .cat_args <- list(
          position = position_dodge(width = 0.4),
          size = 4 / ncond^0.25
        )
        .errorbar_args <- list(
          position = position_dodge(width = 0.4),
          width = 0.3
        )
        replace_args(.cat_args, dont_replace) <- cat_args
        replace_args(.errorbar_args, dont_replace) <- errorbar_args
        plots[[i]] <- plots[[i]] +
          do_call(geom_point, .cat_args) +
          do_call(geom_errorbar, .errorbar_args)
      }
      if (categorical) {
        plots[[i]] <- plots[[i]] + ylab(catscale) +
          labs(fill = response, color = response)
      }
    }
    if (ncond > 1L) {
      # one plot per row of conditions
      if (is.null(ncol)) {
        ncol <- max(floor(sqrt(ncond)), 3)
      }
      .facet_args <- nlist(facets = "cond__", ncol)
      replace_args(.facet_args, dont_replace) <- facet_args
      plots[[i]] <- plots[[i]] +
        do_call(facet_wrap, .facet_args)
    }
    plots[[i]] <- plots[[i]] + theme
    if (plot) {
      plot(plots[[i]])
      if (i == 1) {
        devAskNewPage(ask = ask)
      }
    }
  }
  invisible(plots)
}

# the name 'marginal_effects' is deprecated as of brms 2.10.3
# do not remove it, as it has been used in the brms papers
#' @export
marginal_effects <- function(x, ...) {
  UseMethod("marginal_effects")
}

#' @export
marginal_effects.brmsfit <- function(x, ...) {
  warning2("Method 'marginal_effects' is deprecated. ",
           "Please use 'conditional_effects' instead.")
  conditional_effects.brmsfit(x, ...)
}

#' @export
print.brmsMarginalEffects <- function(x, ...) {
  class(x) <- "brms_conditional_effects"
  print(x, ...)
}

#' @export
plot.brmsMarginalEffects <- function(x, ...) {
  class(x) <- "brms_conditional_effects"
  plot(x, ...)
}

brms documentation built on Sept. 23, 2024, 5:08 p.m.