R/conditional_effects.R

Defines functions .apply_link_transform_inner .apply_link_transform .compute_multinomial_ce_grids .compute_multinomial_conditional_effects .compute_softmax_ce_inner .compute_softmax_conditional_effects .ce_summarize_draws .ce_build_grids_for_vars .ce_build_grids .ce_prediction_grid .extract_re_grouping_vars .filter_internal_effects .brms_conditional_effects .ce_compute_and_transform .ce_single_parameter .ce_all_parameters conditional_effects.bmmfit

Documented in conditional_effects.bmmfit

#' Conditional Effects for BMM Models
#'
#' @description
#' Compute conditional effects for the parameters of a bmmfit object.
#' Compared with calling [brms::conditional_effects()] on a bmmfit object
#' directly, this method:
#' \itemize{
#'   \item Accepts model parameter names directly (e.g., `"kappa"`, `"thetat"`)
#'   \item Works out by itself whether a parameter is distributional or non-linear
#'   \item Optionally applies inverse link transformations so that parameters
#'     are shown on their natural scale
#' }
#'
#' @param x A bmmfit object (created by [bmm()])
#' @param par Character string. Name of the model parameter to compute effects for.
#'   This should be one of the parameter names from the original model specification
#'   (see `names(x$bmm$model$parameters)`). If `NULL` (the default), conditional
#'   effects are computed for every estimated (non-fixed) parameter.
#' @param scale Character. Scale on which to show the parameter:
#'   \describe{
#'     \item{`"native"` (default)}{Show on the natural scale using the inverse link
#'       transformation. For example, `kappa` with a log link is shown on the exp
#'       scale, and `thetat` with a logit (mixture2p) or softmax (mixture3p) link
#'       is shown on the probability scale.}
#'     \item{`"sampling"`}{Show on the sampling scale (as used during MCMC).
#'       For example, `kappa` with a log link is shown on the log scale.}
#'   }
#' @param ... Additional arguments passed to [brms::conditional_effects()].
#'   Common arguments include:
#'   \itemize{
#'     \item `effects`: Character vector specifying which predictor effects to plot
#'     \item `conditions`: Named list for setting values of covariates
#'     \item `int_conditions`: Conditions for interactions
#'     \item `prob`: Probability mass to include in credible intervals (default 0.95)
#'     \item `spaghetti`: Logical, whether to add spaghetti lines
#'     \item `method`: Method for computing effects ("posterior_predict" or "posterior_epred")
#'   }
#'
#' @return A `brms_conditional_effects` object (from brms), which can be:
#' \itemize{
#'   \item Plotted directly using [plot()]
#'   \item Converted to a data frame for custom plotting
#'   \item Combined with other conditional effects plots
#' }
#' When `par = NULL`, the effects of all estimated parameters are combined into
#' a single object. Each element name is prefixed with its parameter name, for
#' example `"kappa.setsize"`.
#'
#' @details
#' ## Parameter Types
#'
#' bmm models use two types of parameters internally:
#' \itemize{
#'   \item **Non-linear parameters (`nlpar`)**: Core model parameters like `kappa`, `c`, `a`, `thetat`
#'   \item **Distributional parameters (`dpar`)**: Derived parameters used in brms mixture distributions
#' }
#'
#' Users do not need to know about this distinction: `conditional_effects.bmmfit()`
#' automatically routes each request to the correct parameter type.
#'
#' ## Scale Transformations
#'
#' By default (`scale = "native"`), parameters are shown on their natural scale by
#' applying inverse link transformations:
#' \itemize{
#'   \item `log` link → exp transformation
#'   \item `logit` link → inverse logit (probability scale)
#'   \item `tan_half` link → 2*atan transformation (radians)
#'   \item `identity` link → no transformation
#'   \item softmax (`thetat`/`thetant` in `mixture3p`) → softmax over `thetat`,
#'     `thetant` and a reference category whose linear predictor is fixed at 0
#' }
#'
#' If a parameter uses a softmax transformation but the model does not support
#' native-scale display, a warning is issued and the sampling scale is shown instead.
#'
#' Use `scale = "sampling"` to see parameters on the scale used during MCMC sampling.
#'
#' For `m3` models, effects are computed from [brms::posterior_linpred()] on a
#' prediction grid rather than through [brms::conditional_effects()]. In that case
#' only the `effects`, `prob`, `robust`, `re_formula`, `ndraws` and `resolution`
#' arguments in `...` are used.
#'
#' @seealso [brms::conditional_effects()] for the underlying brms function
#'
#' @aliases conditional_effects
#' @method conditional_effects bmmfit
#' @export
#'
#' @examples
#' \dontrun{
#' # Fit a mixture model with set size effect on kappa
#' fit <- bmm(
#'   formula = bmf(kappa ~ 0 + setsize, thetat ~ 1),
#'   data = zhang_luck_2008,
#'   model = mixture3p(
#'     resp_error = "response_error",
#'     nt_features = paste0("col_lure", 1:5),
#'     set_size = "setsize"
#'   )
#' )
#'
#' # Get conditional effects for kappa on natural scale (exp of log)
#' ce_kappa <- conditional_effects(fit, par = "kappa")
#' plot(ce_kappa)
#'
#' # Get conditional effects for kappa on log scale (sampling scale)
#' ce_kappa_log <- conditional_effects(fit, par = "kappa", scale = "sampling")
#' plot(ce_kappa_log)
#'
#' # Get effects for thetat (memory probability)
#' ce_thetat <- conditional_effects(fit, par = "thetat")
#' plot(ce_thetat)
#'
#' # Specify which effects to plot
#' ce_specific <- conditional_effects(fit, par = "kappa", effects = "setsize")
#'
#' # Combine with other brms options
#' ce_detailed <- conditional_effects(
#'   fit,
#'   par = "kappa",
#'   effects = "setsize",
#'   spaghetti = TRUE,
#'   ndraws = 100
#' )
#' }
conditional_effects.bmmfit <- function(x,
                                       par = NULL,
                                       scale = c("native", "sampling"),
                                       ...) {
  x <- restructure(x)
  scale <- match.arg(scale)

  # Without an explicit parameter, collect effects for every estimated one
  if (is.null(par)) {
    return(.ce_all_parameters(x, scale, ...))
  }

  is_single_string <- is.character(par) && length(par) == 1
  stopif(!is_single_string, "Argument 'par' must be a single character string")
  .ce_single_parameter(x, par, scale, ...)
}


# Compute conditional effects for every estimated (non-fixed) parameter of a
# bmmfit and merge them into one brms_conditional_effects object. Each
# element name is prefixed with its parameter name ("<par>.<effect>").
# Parameters that yield no effects are skipped.
.ce_all_parameters <- function(x, scale, ...) {
  model <- x$bmm$model
  fixed <- names(model$fixed_parameters)
  collected <- list()

  for (par_name in setdiff(names(model$parameters), fixed)) {
    effects <- .ce_single_parameter(x, par = par_name, scale = scale, ...)
    if (length(effects) == 0) next
    names(effects) <- paste0(par_name, ".", names(effects))
    collected <- c(collected, effects)
  }

  structure(collected, class = c("brms_conditional_effects"))
}


# Conditional effects for one named parameter. A softmax-linked parameter
# requested on the native scale is routed to the dedicated softmax
# computation. If that computation is unavailable for the model (it returns
# NULL), a warning is raised and the sampling scale is used instead. Every
# other request goes through .ce_compute_and_transform().
.ce_single_parameter <- function(x, par, scale, ...) {
  par_info <- .get_parameter_info(x, par)

  wants_softmax <- par_info$softmax && scale == "native"
  if (!wants_softmax) {
    return(.ce_compute_and_transform(x, par, par_info, scale, ...))
  }

  softmax_result <- .compute_softmax_conditional_effects(x, par, ...)
  if (is.null(softmax_result)) {
    warning2(
      "Parameter '{par}' uses softmax transformation.\n",
      "Native scale display not available for this model configuration.\n",
      "Showing on sampling scale instead."
    )
    return(.ce_compute_and_transform(x, par, par_info, "sampling", ...))
  }

  .filter_internal_effects(softmax_result, x)
}


# Compute conditional effects for one parameter, convert them to the
# requested scale, and drop effects of internal preprocessing variables.
# brms reports nlpars on the sampling scale and dpars on the native scale.
# So an nlpar needs the inverse link for scale = "native", and a dpar needs
# the forward link for scale = "sampling".
.ce_compute_and_transform <- function(x, par, par_info, scale, ...) {
  # m3 models require categorical = TRUE in brms, which breaks nlpar-level
  # computation; they bypass brms via posterior_linpred() instead
  is_m3 <- inherits(x$bmm$model, "m3")

  if (is_m3) {
    ce_result <- .compute_multinomial_conditional_effects(x, par, ...)
  } else if (par_info$type == "dpar") {
    ce_result <- .brms_conditional_effects(x, dpar = par_info$brms_name, ...)
  } else if (par_info$type == "nlpar") {
    ce_result <- .brms_conditional_effects(x, nlpar = par_info$brms_name, ...)
  } else {
    stop2("Internal error: parameter type must be 'dpar' or 'nlpar'")
  }

  if (par_info$link == "identity") {
    return(.filter_internal_effects(ce_result, x))
  }

  if (par_info$type == "nlpar" && scale == "native") {
    ce_result <- .apply_link_transform(ce_result, par_info$link, inverse = TRUE)
  } else if (par_info$type == "dpar" && scale == "sampling") {
    ce_result <- .apply_link_transform(ce_result, par_info$link, inverse = FALSE)
  }

  .filter_internal_effects(ce_result, x)
}


#' Call brms conditional_effects without infinite recursion
#'
#' @description
#' Removes `"bmmfit"` from the class vector of `x` and then calls the
#' `conditional_effects()` generic. S3 dispatch therefore skips
#' `conditional_effects.bmmfit()` and reaches the next method in the class
#' chain, so the call does not recurse back into this package.
#'
#' @param x A bmmfit object
#' @param ... Arguments forwarded to [brms::conditional_effects()]
#'
#' @return Whatever the dispatched method returns (a `brms_conditional_effects`
#'   object for brms fits)
#'
#' @keywords internal
#' @noRd
.brms_conditional_effects <- function(x, ...) {
  current <- class(x)
  class(x) <- current[!(current %in% "bmmfit")]
  conditional_effects(x, ...)
}


#' Filter internal variables from conditional_effects results
#'
#' @description
#' Drops conditional-effects elements whose effect name, or any
#' `:`-separated component of it, matches a variable created by bmm during
#' preprocessing (`LureIdx*`, `Idx_*`, `inv_ss`, `Item<n>_*`, `expS`) or
#' names one of the model's `nt_features` / `nt_distances` columns exactly.
#' If every element would be removed, the input is returned unchanged rather
#' than as an empty object.
#'
#' @param ce_result A brms_conditional_effects object
#' @param bmmfit A bmmfit object
#'
#' @return The filtered object with class `"brms_conditional_effects"`, or
#'   `ce_result` unchanged when nothing would be kept
#'
#' @keywords internal
#' @noRd
.filter_internal_effects <- function(ce_result, bmmfit) {
  other_vars <- bmmfit$bmm$model$other_vars

  # Variables that bmm itself adds to the data
  patterns <- c("^LureIdx", "^Idx_", "^inv_ss$", "^Item[0-9]+_", "^expS$")

  # User-named non-target columns: escape regex metacharacters and anchor
  # so that each name is matched literally and in full
  for (user_cols in list(other_vars$nt_features, other_vars$nt_distances)) {
    if (is.null(user_cols)) next
    escaped <- gsub("([][(){}^$*+?.|\\\\])", "\\\\\\1", user_cols)
    patterns <- c(patterns, paste0("^", escaped, "$"))
  }

  any_internal <- paste(patterns, collapse = "|")
  is_internal <- vapply(
    names(ce_result),
    function(effect) any(grepl(any_internal, strsplit(effect, ":")[[1]])),
    logical(1)
  )

  if (all(is_internal)) {
    return(ce_result)
  }

  filtered <- ce_result[!is_internal]
  class(filtered) <- c("brms_conditional_effects")
  filtered
}


#' Extract grouping variable names from random effects in a formula
#'
#' @description
#' Parses the RHS of a formula to identify random-effects grouping variables
#' that should be excluded from conditional effects. Handles these brms
#' grouping specifications:
#' \itemize{
#'   \item Bare names: `(1 | id)`, `(1 || id)`
#'   \item Correlation IDs: `(1 |ID1| id)` — excludes both `ID1` and `id`
#'   \item `gr()`: `(1 | gr(id, by = exp))` — extracts `id`, not `exp`
#'   \item `mm()`: `(1 | mm(g1, g2))` — extracts all positional args
#'   \item Crossed: `(1 | id:group)` — extracts both `id` and `group`
#'   \item Nested: `(1 | id/group)` — extracts both `id` and `group`
#' }
#'
#' @param formula A formula object (one- or two-sided)
#'
#' @return Character vector of grouping variable names to exclude
#'   (`character(0)` when there are none)
#'
#' @keywords internal
#' @noRd
.extract_re_grouping_vars <- function(formula) {
  rhs_str <- paste(deparse(formula[[length(formula)]]), collapse = " ")

  # Match text after each | that is not itself | or ).
  # This captures bare grouping terms, correlation IDs and gr()/mm() calls.
  bar_parts <- regmatches(
    rhs_str, gregexpr("(?<=\\|)[^|)]+", rhs_str, perl = TRUE)
  )[[1]]
  bar_parts <- trimws(bar_parts)
  bar_parts <- bar_parts[nzchar(bar_parts)]

  grouping <- lapply(bar_parts, function(part) {
    if (grepl("^gr\\s*\\(", part)) {
      # gr(id, ...) — first argument is the grouping variable
      inner <- sub("^gr\\s*\\(\\s*", "", part)
      trimws(sub("[,)]+.*", "", inner))
    } else if (grepl("^mm\\s*\\(", part)) {
      # mm(g1, g2, ...) — positional args (before named args) are grouping vars
      inner <- sub("^mm\\s*\\(\\s*", "", part)
      args <- trimws(strsplit(inner, ",")[[1]])
      args[!grepl("=", args)]
    } else {
      # Bare name(s) or correlation ID; crossed (a:b) and nested (a/b)
      # terms contribute every component variable
      vars <- trimws(strsplit(part, "[:/]")[[1]])
      vars[nzchar(vars)]
    }
  })

  # as.character() keeps the return type stable when nothing is found
  as.character(unlist(grouping))
}


#' Build a prediction grid for conditional effects
#'
#' @description
#' Looks up the user formula for `par` in `bmmfit$bmm$user_formula` and
#' builds one prediction data frame per effect variable. Each data frame is
#' intended for [brms::posterior_linpred()]. The effect variable varies over
#' its range (numeric) or its observed values (factor/character). Other
#' columns are held at reference values (mean for double, rounded median
#' for integer, first level for factor, first value otherwise).
#'
#' @param bmmfit A bmmfit object
#' @param par Character string. Parameter name whose formula determines the
#'   predictor variables.
#' @param effects Character vector. Specific effect variables to include. If
#'   `NULL`, all RHS variables from the parameter's formula are used.
#' @param resolution Integer. Number of points for numeric predictors (default
#'   100).
#'
#' @return A named list of data frames, one per effect variable. Empty list if
#'   the parameter has no user formula or no effects are found.
#'
#' @keywords internal
#' @noRd
.ce_prediction_grid <- function(bmmfit, par, effects = NULL, resolution = 100) {
  par_formula <- bmmfit$bmm$user_formula[[par]]
  if (is.null(par_formula)) {
    return(list())
  }
  .ce_build_grids(bmmfit, par_formula, effects, resolution)
}


# Determine the effect variables of one parameter formula and build their
# prediction grids from bmmfit$data.
#
# Candidate predictors are the variables on the formula's RHS, minus
# random-effect grouping variables. Only names that are actual columns of
# the data are kept; other names (for example non-linear parameters used in
# the formula) cannot be varied in a grid. When `effects` is supplied, it is
# split on ":" and restricted to those candidates. Returns list() when no
# effect variable remains.
.ce_build_grids <- function(bmmfit, par_formula, effects, resolution) {
  f <- stats::formula(par_formula)
  re_groups <- .extract_re_grouping_vars(f)

  # The RHS is always the last element, for one- and two-sided formulas alike
  rhs_vars <- all.vars(f[[length(f)]])
  rhs_vars <- setdiff(rhs_vars, c("0", "1", re_groups))
  rhs_vars <- intersect(rhs_vars, names(bmmfit$data))

  effect_vars <- if (is.null(effects)) {
    rhs_vars
  } else {
    intersect(unlist(strsplit(as.character(effects), ":")), rhs_vars)
  }

  if (length(effect_vars) == 0) {
    list()
  } else {
    .ce_build_grids_for_vars(bmmfit$data, effect_vars, resolution)
  }
}


# Build one prediction data frame per effect variable.
#
# The focal variable takes its sorted unique values (factor/character) or
# `resolution` evenly spaced points over its observed range (otherwise).
# Every other column of `orig_data` is held constant at a reference value,
# except matrix columns, which are left out. Returns a list of data frames
# named by variable.
.ce_build_grids_for_vars <- function(orig_data, effect_vars, resolution) {
  # Constant value used for a non-focal column
  reference_value <- function(column) {
    if (is.factor(column)) {
      factor(levels(column)[1], levels = levels(column))
    } else if (is.character(column)) {
      column[1]
    } else if (is.integer(column)) {
      as.integer(round(stats::median(column, na.rm = TRUE)))
    } else if (is.numeric(column)) {
      mean(column, na.rm = TRUE)
    } else {
      column[1]
    }
  }

  grids <- list()
  for (focal in effect_vars) {
    focal_col <- orig_data[[focal]]
    is_discrete <- is.factor(focal_col) || is.character(focal_col)
    values <- if (is_discrete) {
      sort(unique(focal_col))
    } else {
      bounds <- range(focal_col, na.rm = TRUE)
      seq(bounds[1], bounds[2], length.out = resolution)
    }

    grid <- data.frame(x__ = values)
    names(grid) <- focal

    for (other in setdiff(names(orig_data), focal)) {
      other_col <- orig_data[[other]]
      if (!is.matrix(other_col)) {
        grid[[other]] <- reference_value(other_col)
      }
    }

    grids[[focal]] <- grid
  }
  grids
}


#' Summarize posterior draws into conditional-effect statistics
#'
#' @description
#' Summarises each column of a draws matrix (rows = draws, columns = grid
#' points). The point estimate is the mean (or the median when `robust`),
#' the spread is the SD (or the MAD when `robust`), and the interval bounds
#' are the equal-tailed quantiles holding `prob` of the mass.
#'
#' @param draws Matrix. Posterior draws (rows = draws, columns = grid points).
#' @param prob Numeric. Probability mass for credible intervals (default 0.95).
#' @param robust Logical. If `TRUE`, use median/MAD instead of mean/SD.
#'
#' @return A list with elements `estimate`, `lower`, `upper`, `se` — each a
#'   numeric vector of length `ncol(draws)`.
#'
#' @keywords internal
#' @noRd
.ce_summarize_draws <- function(draws, prob = 0.95, robust = FALSE) {
  tail_mass <- (1 - prob) / 2
  by_column <- function(fun, ...) apply(draws, 2, fun, ...)

  center <- if (robust) by_column(stats::median) else colMeans(draws)
  spread <- by_column(if (robust) stats::mad else stats::sd)

  list(
    estimate = center,
    lower = by_column(stats::quantile, probs = tail_mass),
    upper = by_column(stats::quantile, probs = 1 - tail_mass),
    se = spread
  )
}


#' Compute softmax transformation for multinomial parameters
#'
#' @description
#' Native-scale conditional effects for softmax-linked mixing parameters.
#' Only `mixture3p` models are supported. For any other model class this
#' returns `NULL`, which callers treat as "not available" and use to fall
#' back to the sampling scale.
#'
#' @param bmmfit A bmmfit object
#' @param par Character string. Parameter name to return (thetat or thetant)
#' @param ... Additional arguments passed to brms::conditional_effects()
#'
#' @return A brms_conditional_effects object with softmax-transformed values,
#'   or `NULL` for unsupported models
#'
#' @keywords internal
#' @noRd
.compute_softmax_conditional_effects <- function(bmmfit, par, ...) {
  if (inherits(bmmfit$bmm$model, "mixture3p")) {
    .compute_softmax_ce_inner(bmmfit, par, ...)
  } else {
    NULL
  }
}


# Softmax-transformed conditional effects for mixture3p mixing parameters.
#
# brms is first asked for the conditional effects of `par`, so its
# prediction grids (one data frame per effect) can be reused. For each grid,
# linear-predictor draws of `thetat` and `thetant` are obtained with
# brms::posterior_linpred(). They are combined through a softmax with a third
# category whose linear predictor is 0. The probability for `par` is then
# summarised into estimate__/lower__/upper__/se__.
#
# Options read from `...` follow brms::conditional_effects() defaults:
# prob = 0.95, robust = TRUE, re_formula = NA. `ndraws` is forwarded when
# given.
.compute_softmax_ce_inner <- function(bmmfit, par, ...) {
  ce_par <- .brms_conditional_effects(bmmfit, nlpar = par, ...)

  dots <- list(...)
  # `[[` matches names exactly, so e.g. a `probs` argument is never
  # partially matched as `prob`
  prob <- dots[["prob"]] %||% 0.95
  # brms summarises with the median by default; use the same default so
  # this path agrees with the plain brms path
  robust <- dots[["robust"]] %||% TRUE
  # An explicit `re_formula = NULL` (include all group-level effects) is a
  # meaningful value, so test for presence instead of using %||%
  re_formula <- if ("re_formula" %in% names(dots)) {
    dots[["re_formula"]]
  } else {
    NA
  }
  ndraws <- dots[["ndraws"]]

  result <- lapply(ce_par, function(df) {
    # brms helper columns end in "__"; the rest is the prediction grid
    internal_cols <- grep("__$", names(df), value = TRUE)
    newdata <- df[, !names(df) %in% internal_cols, drop = FALSE]

    linpred_args <- list(
      object = bmmfit,
      newdata = newdata,
      re_formula = re_formula,
      allow_new_levels = TRUE
    )
    if (!is.null(ndraws)) linpred_args$ndraws <- ndraws

    draws_t <- do.call(
      brms::posterior_linpred,
      c(linpred_args, list(nlpar = "thetat"))
    )
    draws_nt <- do.call(
      brms::posterior_linpred,
      c(linpred_args, list(nlpar = "thetant"))
    )

    # numerically stable softmax: subtract max before exponentiating
    shift <- pmax(draws_t, draws_nt, 0)
    exp_t <- exp(draws_t - shift)
    exp_nt <- exp(draws_nt - shift)
    exp_0 <- exp(-shift)
    denom <- exp_t + exp_nt + exp_0
    if (par == "thetat") {
      softmax_draws <- exp_t / denom
    } else {
      softmax_draws <- exp_nt / denom
    }

    summ <- .ce_summarize_draws(softmax_draws, prob = prob, robust = robust)
    df$estimate__ <- summ$estimate
    df$lower__ <- summ$lower
    df$upper__ <- summ$upper
    df$se__ <- summ$se

    df
  })

  names(result) <- names(ce_par)
  class(result) <- class(ce_par)
  result
}


#' Compute conditional effects for multinomial family models
#'
#' @description
#' For models using the `brms::multinomial()` family (e.g., m3), brms requires
#' `categorical = TRUE` even when a specific nlpar is requested, which
#' conflicts with nlpar-level computation. This helper bypasses that check by
#' calling `brms::posterior_linpred()` directly on a manually constructed
#' prediction grid.
#'
#' @param bmmfit A bmmfit object
#' @param par Character string. Parameter name (nlpar) to compute effects for
#' @param ... Additional arguments. Only `prob` (default 0.95), `robust`
#'   (default `TRUE`, as in brms), `re_formula` (default `NA`; an explicit
#'   `NULL` is honoured), `ndraws`, `effects` and `resolution` (default 100)
#'   are used. Names are matched exactly.
#'
#' @return A `brms_conditional_effects` object with one element per effect
#'   (empty when no effect variable is found)
#'
#' @keywords internal
#' @noRd
.compute_multinomial_conditional_effects <- function(bmmfit, par, ...) {
  dots <- list(...)
  # `[[` avoids partial matching (e.g. `probs` being taken as `prob`)
  prob <- dots[["prob"]] %||% 0.95
  # Match the brms::conditional_effects() default (median-based summary)
  robust <- dots[["robust"]] %||% TRUE
  # `re_formula = NULL` is a valid user choice; only default when absent
  re_formula <- if ("re_formula" %in% names(dots)) {
    dots[["re_formula"]]
  } else {
    NA
  }
  ndraws <- dots[["ndraws"]]
  resolution <- dots[["resolution"]] %||% 100
  effects <- dots[["effects"]]

  grids <- .ce_prediction_grid(bmmfit, par,
                               effects = effects,
                               resolution = resolution)
  if (length(grids) == 0) {
    structure(list(), class = c("brms_conditional_effects", "list"))
  } else {
    .compute_multinomial_ce_grids(bmmfit, par, grids, prob, robust,
                                  re_formula, ndraws)
  }
}


# Turn prediction grids into a brms_conditional_effects object.
#
# For each named grid, linear-predictor draws of nlpar `par` are obtained
# with brms::posterior_linpred() and summarised into estimate__, lower__,
# upper__ and se__. The grid also receives the effect1__/cond__ columns and
# the "effects"/"response" attributes that brms-style conditional effects
# carry.
.compute_multinomial_ce_grids <- function(bmmfit, par, grids, prob, robust,
                                          re_formula, ndraws) {
  # Summarise one grid whose varying column is `var`
  summarise_grid <- function(newdata, var) {
    call_args <- list(
      object = bmmfit,
      newdata = newdata,
      nlpar = par,
      re_formula = re_formula,
      allow_new_levels = TRUE
    )
    if (!is.null(ndraws)) call_args$ndraws <- ndraws

    draws <- do.call(brms::posterior_linpred, call_args)
    stats_list <- .ce_summarize_draws(draws, prob = prob, robust = robust)

    newdata$estimate__ <- stats_list$estimate
    newdata$lower__ <- stats_list$lower
    newdata$upper__ <- stats_list$upper
    newdata$se__ <- stats_list$se
    newdata$effect1__ <- newdata[[var]]
    newdata$cond__ <- factor("1")

    attr(newdata, "effects") <- var
    attr(newdata, "response") <- par
    newdata
  }

  result <- list()
  for (var in names(grids)) {
    result[[var]] <- summarise_grid(grids[[var]], var)
  }

  class(result) <- c("brms_conditional_effects", "list")
  result
}


#' Apply link transformation to conditional effects
#'
#' @description
#' Applies a link transformation to the estimates and interval bounds of a
#' conditional_effects object. The inverse link maps sampling → native and
#' the forward link maps native → sampling. The identity link returns the
#' object untouched.
#'
#' @param ce_object A brms_conditional_effects object from brms::conditional_effects()
#' @param link Character string. Link function name
#' @param inverse Logical. If TRUE, apply inverse link (sampling → native).
#'   If FALSE, apply forward link (native → sampling).
#'
#' @return Modified conditional_effects object with transformed values
#'
#' @keywords internal
#' @noRd
.apply_link_transform <- function(ce_object, link, inverse = TRUE) {
  if (!(link == "identity")) {
    return(.apply_link_transform_inner(ce_object, link, inverse))
  }
  ce_object
}


# Apply link_transform() to the estimate__, lower__ and upper__ columns of
# every data frame in a conditional_effects object.
#
# When brms has stored spaghetti draws (spaghetti = TRUE) as a data frame in
# the "spaghetti" attribute, their estimate__ column is transformed as well.
# Otherwise the lines would stay on the old scale while the ribbons moved.
# se__ is left untouched: a standard error cannot be mapped exactly through
# a non-linear link, so it remains on the input scale.
# Names and class of `ce_object` are preserved.
.apply_link_transform_inner <- function(ce_object, link, inverse) {
  apply_link <- function(values) link_transform(values, link, inverse = inverse)

  result <- lapply(ce_object, function(df) {
    df$estimate__ <- apply_link(df$estimate__)
    df$lower__ <- apply_link(df$lower__)
    df$upper__ <- apply_link(df$upper__)

    spaghetti <- attr(df, "spaghetti")
    if (is.data.frame(spaghetti) && "estimate__" %in% names(spaghetti)) {
      spaghetti$estimate__ <- apply_link(spaghetti$estimate__)
      attr(df, "spaghetti") <- spaghetti
    }
    df
  })

  names(result) <- names(ce_object)
  class(result) <- class(ce_object)
  result
}

Try the bmm package in your browser

Any scripts or data that you put into this service are public.

bmm documentation built on March 30, 2026, 5:08 p.m.