R/helpers-inits.R

Defines functions count_term_levels init_vector_param init_real_param create_initfun.bmmodel create_initfun

Documented in create_initfun

#' @title Generic S3 method for creating an initial values function
#'
#' @description Called by bmm() to create an initfun for models that require
#'   initial values to properly start sampling. Dispatches on the class of
#'   `model`.
#'
#' @param model The `bmmodel` object for which an initfun should be created
#' @param data The user supplied data.frame used to fit the model
#' @param formula The `brmsformula` generated by `configure_model` including the
#'   family for the `bmmodel` to be estimated.
#'
#' @return An initfun with no arguments that generates
#'   initial values suitable for the Stan parameters generated by the respective
#'   model call. The `bmmodel` method instead returns `1` when the model does
#'   not define `init_ranges`.
#'
#' @export
#' @keywords internal developer
create_initfun <- function(model, data, formula) {
  UseMethod("create_initfun")
}

# S3 method of create_initfun() for `bmmodel` objects.
#
# Returns `1` when the model defines no `init_ranges`. Otherwise reads the
# Stan data and the parameters block of the Stan code generated for `formula`
# once, and returns a function without arguments. Each call of that function
# returns a named list of initial values, with one element per Stan parameter.
#
# NOTE(review): `init_ranges` and `links` are presumably lists keyed by model
# parameter name (they are indexed with `[[parameter]]` below); confirm
# against the model constructors.
#' @export
create_initfun.bmmodel <- function(model, data, formula) {
  if (is.null(model$init_ranges)) {
    return(1)
  }

  # extract information from the Stan code once, when the initfun is created
  standata_list <- standata(formula, data, formula$family)
  stan_code <- stancode(formula, data, formula$family)
  stanpars_list <- extract_parameter_dimensions(
    extract_stan_blocks(stan_code)$parameters
  )

  # These values do not change between calls of the returned initfun, so they
  # are computed once here and captured by the closure.
  bterms <- brms::brmsterms(formula)
  model_pars <- names(model$parameters)
  init_ranges <- model$init_ranges
  links <- model$links

  function() {
    inits <- list()

    for (spar in names(stanpars_list)) {
      # If the Stan parameter name contains "_<model parameter>" (e.g.
      # "b_kappa"), it belongs to that model parameter. `fixed = TRUE` so that
      # parameter names are not interpreted as regular expressions. Several
      # names can match (e.g. "c" and "corr" both match "b_corr"), which would
      # make the scalar `if` below fail; use the longest, most specific match.
      matches <- model_pars[vapply(
        paste0("_", model_pars), grepl, logical(1L),
        x = spar, fixed = TRUE
      )]
      parameter <- if (length(matches) > 0) {
        matches[which.max(nchar(matches))]
      } else {
        # Otherwise use the type prefix before the first underscore, e.g. for
        # covariance matrices and z-values of random effects over groups.
        strsplit(spar, "_", fixed = TRUE)[[1]][1]
      }
      parameter <- if (identical(parameter, "Intercept")) "mu" else parameter

      type <- stanpars_list[[spar]]$type
      dim <- unlist(standata_list[stanpars_list[[spar]]$dims])
      range <- init_ranges[[parameter]]
      link <- links[[parameter]]

      # Handle different parameter types
      inits[[spar]] <- switch(type,
        real = init_real_param(spar, range, link),
        vector = init_vector_param(
          spar, dim, range, link, bterms$dpars[[parameter]], data
        ),
        matrix = matrix(runif(prod(dim), min = -.5, max = .5), nrow = dim[1]),
        # Identity start value. For a single dimension this is the square
        # identity matrix; for `cholesky_factor_cov[M, N]` it gives an M x N
        # matrix with unit diagonal.
        cholesky_factor_corr = ,
        cholesky_factor_cov = ,
        cov_matrix = ,
        corr_matrix = diag(nrow = dim[1], ncol = dim[length(dim)]),
        stop2("Unsupported parameter type: {type}")
      )
    }

    inits
  }
}


# Helper functions for initializing different parameter types ----------------
# Draw a single initial value for a Stan parameter declared as `real`.
#
# Intercepts (names containing "Intercept") are drawn uniformly from
# `init_range` and passed through `link_transform()` with `link`. sd
# parameters (names containing "sd_") are drawn uniformly from [0.05, 0.1].
# Any other name is an error.
init_real_param <- function(par, init_range, link) {
  kind <- if (grepl("Intercept", par, fixed = TRUE)) {
    "intercept"
  } else if (grepl("sd_", par, fixed = TRUE)) {
    "sd"
  } else {
    "other"
  }

  switch(kind,
    intercept = link_transform(
      runif(1, min = init_range[1], max = init_range[2]),
      link
    ),
    sd = runif(1, min = 0.05, max = 0.1),
    stop2("Initial values for reals are only specified for Intercepts and sd-parameters")
  )
}

# Draw initial values for a Stan parameter declared as a `vector`.
#
# par        Name of the Stan parameter, e.g. "b_kappa", "sd_1" or "z_1".
# dim        Size of the parameter, as looked up in the Stan data.
# init_range Length-2 range from which the coefficients of the first term are
#            drawn when the formula has no intercept. The draws are then
#            passed through `link_transform()` with `link`.
# link       Link passed on to `link_transform()`.
# bterms     brms terms of the parameter `par` belongs to; only its
#            population-level formula `bterms$fe` is used.
# data       The user data, used to count the coefficients of the first term.
#
# Returns numeric initial values. "b_" parameters get a vector of length
# `dim`; "sd_" and "z_" parameters get an array with dimensions `dim`. Any
# other name is an error.
init_vector_param <- function(par, dim, init_range, link, bterms, data) {
  if (grepl("b_", par)) {
    if (has_intercept(bterms$fe)) {
      return(runif(dim, min = -0.1, max = 0.1))
    }

    # For models without intercept, initialize the coefficients of the first
    # term with model-specific ranges; the rest (if any) with small random
    # values
    term_labels <- attr(terms(bterms$fe), "term.labels")
    if (length(term_labels) == 0) {
      # no term whose coefficients could take the model-specific range
      return(runif(dim, min = -0.1, max = 0.1))
    }
    # the first term may be an interaction (e.g. "var1:var2")
    variables <- strsplit(term_labels[1], ":", fixed = TRUE)[[1]]
    # The level count comes from the user data and is not guaranteed to match
    # the size of the Stan vector. Clamp it so that `runif()` below never
    # receives a negative count and the result always has `dim` elements.
    n_first <- min(count_term_levels(data, variables), dim)

    c(
      link_transform(
        runif(n_first, min = init_range[1], max = init_range[2]),
        link
      ),
      runif(dim - n_first, min = -0.1, max = 0.1)
    )
  } else if (grepl("sd_", par)) {
    array(runif(prod(dim), min = 0.05, max = 0.1), dim = dim)
  } else if (grepl("z_", par)) {
    array(runif(prod(dim), min = -.5, max = .5), dim = dim)
  } else {
    stop2("Initial values for vectors are only specified for b-coefficients, sd and z parameters")
  }
}


# Count levels for model formula terms - used for determining number of
# coefficients when initializing models without intercepts
count_term_levels <- function(data, vars) {
  # Number of coefficients a single variable contributes to the design:
  # all levels for factors, one for numeric (continuous) predictors, and the
  # number of distinct non-missing values for anything else (e.g. character).
  coefs_for <- function(column) {
    if (is.factor(column)) {
      return(nlevels(column))
    }
    if (is.numeric(column)) {
      return(1L)
    }
    length(unique(na.omit(column)))
  }

  # An interaction term has as many coefficients as the product of its parts.
  per_variable <- vapply(data[vars], coefs_for, FUN.VALUE = integer(1L))
  prod(per_variable)
}

Try the bmm package in your browser

Any scripts or data that you put into this service are public.

bmm documentation built on March 30, 2026, 5:08 p.m.