Nothing
#' @title Generic S3 method for creating an initial values function
#'
#' @description Called by bmm() to create an initfun for models that require
#'   initial values to properly start sampling
#'
#' @param model The `bmmodel` object for which an initfun should be created
#' @param data The user supplied data.frame used to fit the model
#' @param formula The `brmsformula` generated by `configure_model` including the
#'   family for the `bmmodel` to be estimated.
#'
#' @return Usually an initfun with no arguments that generates initial values
#'   suitable for the Stan parameters generated by the respective model call.
#'   The default `bmmodel` method instead returns the constant `1` when the
#'   model defines no `init_ranges`.
#'
#' @export
#' @keywords internal developer
create_initfun <- function(model, data, formula) {
  UseMethod("create_initfun")
}
# Default method: build an init function from the model's `init_ranges`.
#
# Returns the constant 1 when `model$init_ranges` is NULL. Otherwise the Stan
# data and Stan code for `formula`/`data` are generated once, the declared
# parameters are extracted from the Stan `parameters` block, and a closure is
# returned that, on every call, builds a named list with one random initial
# value per Stan parameter:
#   - `real`: init_real_param()
#   - `vector`: init_vector_param()
#   - `matrix`: uniform draws in [-0.5, 0.5]
#   - cholesky_factor_corr / cholesky_factor_cov / cov_matrix / corr_matrix:
#     an identity matrix
# Any other Stan type raises an error.
#' @export
create_initfun.bmmodel <- function(model, data, formula) {
  if (is.null(model$init_ranges)) {
    return(1)
  }

  # Everything below is the same for every chain, so it is computed once here
  # rather than inside the returned closure. Evaluating these locals also
  # forces the argument promises, so no force() calls are needed.
  standata_list <- standata(formula, data, formula$family)
  stan_code <- stancode(formula, data, formula$family)
  stanpars_list <- extract_parameter_dimensions(
    extract_stan_blocks(stan_code)$parameters
  )
  bterms <- brms::brmsterms(formula)
  model_pars <- names(model$parameters)
  init_ranges <- model$init_ranges
  links <- model$links

  function() {
    inits <- list()
    for (spar in names(stanpars_list)) {
      # A Stan name tied to a model parameter contains "_<parameter>"
      # (e.g. "b_kappa", "Intercept_kappa"). Names are matched literally; when
      # several model parameters match (e.g. "kappa" and "kappa1" inside
      # "b_kappa1") the longest, most specific one wins, so `parameter` is
      # always a single string.
      hits <- model_pars[vapply(
        paste0("_", model_pars), grepl, logical(1L), x = spar, fixed = TRUE
      )]
      if (length(hits) == 0) {
        # Not tied to a model parameter (e.g. covariance matrices and z-values
        # for random effects over groups): use the prefix before the first "_"
        parameter <- strsplit(spar, "_")[[1]][1]
      } else {
        parameter <- hits[which.max(nchar(hits))]
      }
      # An unsuffixed "Intercept" is treated as the intercept of `mu`.
      # NOTE(review): an unsuffixed "b" is not remapped the same way, and
      # init_vector_param() only accepts names containing "b_" -- confirm
      # that such a parameter never reaches this loop.
      if (parameter == "Intercept") {
        parameter <- "mu"
      }
      type <- stanpars_list[[spar]]$type
      dim_names <- stanpars_list[[spar]]$dims
      dim <- unlist(standata_list[dim_names])
      range <- init_ranges[[parameter]]
      link <- links[[parameter]]
      inits[[spar]] <- switch(type,
        real = init_real_param(spar, range, link),
        vector = init_vector_param(
          spar, dim, range, link, bterms$dpars[[parameter]], data
        ),
        matrix = matrix(runif(prod(dim), min = -.5, max = .5), nrow = dim[1]),
        cholesky_factor_corr = ,
        cholesky_factor_cov = ,
        cov_matrix = ,
        corr_matrix = diag(nrow = dim),
        stop2("Unsupported parameter type: {type}")
      )
    }
    inits
  }
}
# Helper functions for initializing different parameter types ----------------
# Draw an initial value for a scalar (`real`) Stan parameter `par`.
#
# - Intercepts ("Intercept" anywhere in `par`): a uniform draw between
#   `init_range[1]` and `init_range[2]`, handed to `link_transform()` together
#   with `link` (presumably mapping the draw onto the linear-predictor scale --
#   confirm against link_transform()).
# - Standard deviations ("sd_" in `par`): a uniform draw in [0.05, 0.1].
# Any other real parameter is rejected with an error.
init_real_param <- function(par, init_range, link) {
  if (grepl("Intercept", par)) {
    draw <- runif(1, min = init_range[1], max = init_range[2])
    return(link_transform(draw, link))
  }
  if (!grepl("sd_", par)) {
    stop2("Initial values for reals are only specified for Intercepts and sd-parameters")
  }
  runif(1, min = 0.05, max = 0.1)
}
# Draw initial values for a `vector` Stan parameter `par` of length `dim`.
#
# `init_range` is a (lower, upper) range for the model parameter, `link` is
# passed to `link_transform()`, `bterms` are the brms terms of the matching
# distributional parameter (only `bterms$fe` is used) and `data` is used to
# count the coefficients of the first formula term.
# - Coefficients ("b_" in `par`):
#     * if `bterms$fe` has an intercept, all values are uniform in [-0.1, 0.1];
#     * otherwise the coefficients of the first formula term are drawn from
#       `init_range` and passed through `link_transform()`, and any remaining
#       coefficients are uniform in [-0.1, 0.1].
# - Standard deviations ("sd_"): uniform draws in [0.05, 0.1].
# - z-values ("z_"): uniform draws in [-0.5, 0.5].
# Any other vector parameter is rejected with an error.
init_vector_param <- function(par, dim, init_range, link, bterms, data) {
  if (grepl("b_", par)) {
    if (has_intercept(bterms$fe)) {
      return(runif(dim, min = -0.1, max = 0.1))
    }
    term_labels <- attr(terms(bterms$fe), "term.labels")
    # the first term may be an interaction such as "var1:var2"
    variables <- strsplit(term_labels[1], ":")[[1]]
    # Cap at the vector length: the level count taken from `data` can exceed
    # the number of coefficients brms actually estimates (e.g. if brms drops
    # levels), and a negative `dim - n_first` would make runif() fail.
    n_first <- min(count_term_levels(data, variables), dim)
    c(
      link_transform(
        runif(n_first, min = init_range[1], max = init_range[2]),
        link
      ),
      runif(dim - n_first, min = -0.1, max = 0.1)
    )
  } else if (grepl("sd_", par)) {
    array(runif(prod(dim), min = 0.05, max = 0.1), dim = dim)
  } else if (grepl("z_", par)) {
    array(runif(prod(dim), min = -.5, max = .5), dim = dim)
  } else {
    stop2("Initial values for vectors are only specified for b-coefficients, sd and z parameters")
  }
}
# Count the coefficients contributed by one model formula term; used to size
# the first block of coefficients when initializing models without an
# intercept.
#
# `vars` holds the variable names making up the term (several for an
# interaction such as "var1:var2"). Each variable contributes:
#   - factors: the number of levels actually present in `data` (unused levels
#     are dropped, as brms does by default via `drop_unused_levels = TRUE`);
#   - numeric predictors: a single coefficient;
#   - anything else (e.g. character): the number of distinct non-missing
#     values.
# The per-variable counts are multiplied (one coefficient per combination of
# levels in an interaction); the result is returned as a double by prod().
count_term_levels <- function(data, vars) {
  term_sizes <- vapply(data[vars], function(x) {
    if (is.factor(x)) {
      nlevels(droplevels(x))
    } else if (is.numeric(x)) {
      1L
    } else {
      length(unique(na.omit(x)))
    }
  }, integer(1L))
  prod(term_sizes)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.