#' @title Create a reference grid
#' @name get_datagrid
#'
#' @description
#' Create a reference matrix, useful for visualisation, with evenly spread and
#' combined values. Usually used to generate predictions using [get_predicted()].
#' See this
#' [vignette](https://easystats.github.io/modelbased/articles/visualisation_matrix.html)
#' for a tutorial on how to create a visualisation matrix using this function.
#'
#' Alternatively, `get_datagrid()` can also be used to extract the "grid"
#' columns from objects generated by **emmeans** and **marginaleffects** (see
#' those [methods][get_datagrid.emmGrid] for more info).
#'
#' @param x An object from which to construct the reference grid.
#' @param by Indicates the _focal predictors_ (variables) for the reference grid
#' and at which values focal predictors should be represented. If not specified
#' otherwise, representative values for numeric variables or predictors are
#' evenly distributed from the minimum to the maximum, with a total number of
#' `length` values covering that range (see 'Examples'). Possible options for
#' `by` are:
#' - **Select variables only:**
#' - `"all"`, which will include all variables or predictors.
#' - a character vector of one or more variable or predictor names, like
#' `c("Species", "Sepal.Width")`, which will create a grid of all
#' combinations of unique values.
#'
#' **Note:** If `by` specifies only variable names, without associated
#' values, the following occurs: factor variables use all their levels,
#' numeric variables use a range of `length` equally spaced values between
#' their minimum and maximum, and character variables use all their unique
#' values.
#'
#' - **Select variables and values:**
#' - `by` can be a list of named elements, indicating focal predictors and
#' their representative values, e.g. `by = list(mpg = 10:20)`,
#' `by = list(Sepal.Length = c(2, 4), Species = "setosa")`, or
#' `by = list(Sepal.Length = seq(2, 5, 0.5))`.
#' - Instead of a list, it is possible to write a string representation, or
#' a character vector of such strings, e.g. `by = "mpg = 10:20"`,
#' `by = c("Sepal.Length = c(2, 4)", "Species = 'setosa'")`, or
#' `by = "Sepal.Length = seq(2, 5, 0.5)"`. Note the usage of single and
#' double quotes to assign strings within strings.
#' - In general, any expression after a `=` will be evaluated as R code, which
#' allows using your own functions, e.g.
#' ```
#' fun <- function(x) x^2
#' get_datagrid(iris, by = "Sepal.Width = fun(2:5)")
#' ```
#'
#' **Note:** If `by` specifies variables *with* their associated values,
#' argument `length` is ignored.
#'
#' Assignments with _brackets_, i.e. values defined inside `[` and `]`, receive
#' special handling: they create summaries for *numeric* variables. The
#' following "tokens", which create pre-defined representative values, are
#' possible:
#'
#' - for mean and -/+ 1 SD around the mean: `"x = [sd]"`
#' - for median and -/+ 1 MAD around the median: `"x = [mad]"`
#' - for Tukey's five number summary (minimum, lower-hinge, median,
#' upper-hinge, maximum): `"x = [fivenum]"`
#' - for quartiles: `"x = [quartiles]"` (same as `"x = [fivenum]"`, but
#' *excluding* minimum and maximum)
#' - for terciles: `"x = [terciles]"`
#' - for terciles, *including* minimum and maximum: `"x = [terciles2]"`
#' - for a pretty value range: `"x = [pretty]"`
#' - for minimum and maximum value: `"x = [minmax]"`
#' - for 0 and the maximum value: `"x = [zeromax]"`
#' - for a random sample from all values: `"x = [sample <number>]"`, where
#' `<number>` should be a positive integer, e.g. `"x = [sample 15]"`.
#'
#' **Note:** the `length` argument will be ignored when using bracket tokens.
#'
#' The remaining variables not specified in `by` will be fixed (see also arguments
#' `factors` and `numerics`).
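#'
#' As a minimal sketch (using the built-in `iris` data), variable selection,
#' value assignment and a bracket token can be combined like this:
#' ```
#' get_datagrid(iris,
#'   by = c("Sepal.Length = [sd]", "Species = c('setosa', 'versicolor')")
#' )
#' ```
#'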
#' @param length Length of numeric target variables selected in `by` (if no
#' representative values are additionally specified). This argument controls
#' the number of (equally spread) values that will be taken to represent the
#' continuous (non-integer-like) variables. A larger `length` will increase
#' precision, but can also substantially increase the size of the data grid
#' (especially in case of interactions). If `NA`, all unique values are
#' returned.
#'
#' In case of multiple continuous target variables, `length` can also be a
#' vector of different values (see 'Examples'). In this case, `length` must be
#' of same length as numeric target variables. If `length` is a named vector,
#' values are matched against the names of the target variables.
#'
#' When `range = "range"` (the default), `length` is ignored for integer type
#' variables when `length` is larger than the number of unique values *and*
#' `protect_integers` is `TRUE` (default). Set `protect_integers = FALSE` to
#' create a spread of `length` number of values from minimum to maximum for
#' integers, including fractions (i.e., to treat integer variables as regular
#' numeric variables).
#'
#' `length` is furthermore ignored if "tokens" (in brackets `[` and `]`) are
#' used in `by`, or if representative values are additionally specified in
#' `by`.
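#'
#' As a brief illustration of these options (see also 'Examples'):
#' ```
#' # 2 values for mpg, 3 values for hp (matched by name)
#' get_datagrid(mtcars, by = c("mpg", "hp"), length = c(hp = 3, mpg = 2))
#' # all unique values of mpg
#' get_datagrid(mtcars, by = "mpg", length = NA)
#' ```
#'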
#' @param range Option to control the representative values given in `by`, if no
#' specific values were provided. Use in combination with the `length`
#' argument to control the number of values within the specified range.
#' `range` can be one of the following:
#' - `"range"` (default), will use the minimum and maximum of the original
#' data vector as end-points (min and max). For integer variables, the
#' `length` argument will be ignored, and `"range"` will only use values
#' that appear in the data. Set `protect_integers = FALSE` to override this
#' behaviour for integer variables.
#' - if an interval type is specified, such as [`"iqr"`][IQR()],
#' [`"ci"`][bayestestR::ci()], [`"hdi"`][bayestestR::hdi()] or
#' [`"eti"`][bayestestR::eti()], it will spread the values within that range
#' (the default CI width is `95%` but this can be changed by adding for
#' instance `ci = 0.90`). See [`IQR()`] and [`bayestestR::ci()`]. This can
#' be useful to obtain a more robust range that skips extreme values.
#' - if [`"sd"`][sd()] or [`"mad"`][mad()], it will spread by this dispersion
#' index around the mean or the median, respectively. If the `length`
#' argument is an even number (e.g., `4`), it will have one more step on the
#' positive side (i.e., `-1, 0, +1, +2`). The result is a named vector. See
#' 'Examples.'
#' - `"grid"` will create a reference grid that is useful when plotting
#' predictions, by choosing representative values for numeric variables
#' based on their position in the reference grid. If a numeric variable is
#' the first predictor in `by`, values from minimum to maximum of the same
#' length as indicated in `length` are generated. For numeric predictors not
#' specified at first in `by`, mean and -1/+1 SD around the mean are
#' returned. For factors, all levels are returned.
#'
#' `range` can also be a vector of different values (see 'Examples'). In this
#' case, `range` must be of same length as numeric target variables. If
#' `range` is a named vector, values are matched against the names of the
#' target variables.
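#'
#' A short sketch of some of these options (mirroring calls from 'Examples'):
#' ```
#' # spread values within the 90% CI of Sepal.Length
#' get_datagrid(iris, by = "Sepal.Length", range = "ci", ci = 0.90)
#' # -1 SD, mean and +1 SD of Sepal.Length
#' get_datagrid(iris, by = "Sepal.Length", range = "sd", length = 3)
#' ```
#'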
#' @param factors Type of summary for factors *not* specified in `by`. Can be
#' `"reference"` (set at the reference level), `"mode"` (set at the most
#' common level) or `"all"` to keep all levels.
#' @param numerics Type of summary for numeric values *not* specified in `by`.
#' Can be `"all"` (will duplicate the grid for all unique values), any
#' function (`"mean"`, `"median"`, ...) or a value (e.g., `numerics = 0`).
#' @param preserve_range In the case of combinations between numeric variables
#' and factors, setting `preserve_range = TRUE` will drop rows of the grid
#' where the value of the numeric variable lies outside the range observed for
#' the respective factor level in the original data. This leads to an
#' unbalanced grid. Also, if you want the minimum and the maximum to closely
#' match the actual ranges, you should increase the `length` argument.
#' @param reference The reference vector from which to compute the mean and SD.
#' Used when standardizing or unstandardizing the grid using `effectsize::standardize`.
#' @param include_smooth If `x` is a model object, decide whether smooth terms
#' should be included in the data grid or not.
#' @param include_random If `x` is a mixed model object, decide whether random
#' effect terms should be included in the data grid or not. If
#' `include_random` is `FALSE`, but `x` is a mixed model with random effects,
#' these will still be included in the returned grid, but set to their
#' "population level" value (e.g., `NA` for *glmmTMB* or `0` for *merMod*).
#' This ensures that common `predict()` methods work properly, as these
#' usually need data with all variables in the model included.
#' @param include_response If `x` is a model object, decide whether the response
#' variable should be included in the data grid or not.
#' @param protect_integers Defaults to `TRUE`. Indicates whether integers (whole
#' numbers) should be treated as integers (i.e., prevent adding any in-between
#' round number values), or - if `FALSE` - as regular numeric variables. Only
#' applies when `range = "range"` (the default), or if `range = "grid"` and the
#' first predictor in `by` is an integer.
#' @param data Optional, the data frame that was used to fit the model. Usually,
#' the data is retrieved via `get_data()`.
#' @param digits Number of digits used for rounding numeric values specified in
#' `by`. E.g., `x = [sd]` will round the mean and the -/+ 1 SD values in the
#' data grid to `digits` decimal places.
#' @param verbose Toggle warnings.
#' @param ... Arguments passed to or from other methods (for instance, `length`
#' or `range` to control the spread of numeric variables).
#'
#' @return Reference grid data frame.
#'
#' @details
#' Data grids are an (artificial or theoretical) representation of the sample.
#' They consist of predictors of interest (so-called focal predictors), and
#' meaningful values, at which the sample characteristics (focal predictors)
#' should be represented. The focal predictors are selected in `by`. To select
#' meaningful (or representative) values, either use `by`, or use a combination
#' of the arguments `length` and `range`.
#'
#' @seealso [get_predicted()] to extract predictions, for which the data grid
#' is useful, and see the [methods][get_datagrid.emmGrid] for objects generated
#' by **emmeans** and **marginaleffects** to extract the "grid" columns.
#'
#' @examplesIf require("bayestestR", quietly = TRUE) && require("datawizard", quietly = TRUE)
#' # Datagrids of variables and dataframes =====================================
#' data(iris)
#' data(mtcars)
#'
#' # Single variable is of interest; all others are "fixed" ------------------
#'
#' # Factors, returns all the levels
#' get_datagrid(iris, by = "Species")
#' # Specify an expression
#' get_datagrid(iris, by = "Species = c('setosa', 'versicolor')")
#'
#' # Numeric variables, default spread length = 10
#' get_datagrid(iris, by = "Sepal.Length")
#' # change length
#' get_datagrid(iris, by = "Sepal.Length", length = 3)
#'
#' # change non-targets fixing
#' get_datagrid(iris[2:150, ],
#' by = "Sepal.Length",
#' factors = "mode", numerics = "median"
#' )
#'
#' # change min/max of target
#' get_datagrid(iris, by = "Sepal.Length", range = "ci", ci = 0.90)
#'
#' # Manually change min/max
#' get_datagrid(iris, by = "Sepal.Length = c(0, 1)")
#'
#' # -1 SD, mean and +1 SD
#' get_datagrid(iris, by = "Sepal.Length = [sd]")
#'
#' # rounded to 1 digit
#' get_datagrid(iris, by = "Sepal.Length = [sd]", digits = 1)
#'
#' # identical to previous line: -1 SD, mean and +1 SD
#' get_datagrid(iris, by = "Sepal.Length", range = "sd", length = 3)
#'
#' # quartiles
#' get_datagrid(iris, by = "Sepal.Length = [quartiles]")
#'
#' # Standardization and unstandardization
#' data <- get_datagrid(iris, by = "Sepal.Length", range = "sd", length = 3)
#'
#' # It is a named vector (extract names with `names(out$Sepal.Length)`)
#' data$Sepal.Length
#' datawizard::standardize(data, select = "Sepal.Length")
#'
#' # Manually specify values
#' data <- get_datagrid(iris, by = "Sepal.Length = c(-2, 0, 2)")
#' data
#' datawizard::unstandardize(data, select = "Sepal.Length")
#'
#'
#' # Multiple variables are of interest, creating a combination --------------
#'
#' get_datagrid(iris, by = c("Sepal.Length", "Species"), length = 3)
#' get_datagrid(iris, by = c("Sepal.Length", "Petal.Length"), length = c(3, 2))
#' get_datagrid(iris, by = c(1, 3), length = 3)
#' get_datagrid(iris, by = c("Sepal.Length", "Species"), preserve_range = TRUE)
#' get_datagrid(iris, by = c("Sepal.Length", "Species"), numerics = 0)
#' get_datagrid(iris, by = c("Sepal.Length = 3", "Species"))
#' get_datagrid(iris, by = c("Sepal.Length = c(3, 1)", "Species = 'setosa'"))
#'
#' # specify length individually for each focal predictor
#' # values are matched by names
#' get_datagrid(mtcars[1:4], by = c("mpg", "hp"), length = c(hp = 3, mpg = 2))
#'
#' # Numeric and categorical variables, generating a grid for plots
#' # default spread when numerics are first: length = 10
#' get_datagrid(iris, by = c("Sepal.Length", "Species"), range = "grid")
#'
#' # default spread when numerics are not first: length = 3 (-1 SD, mean and +1 SD)
#' get_datagrid(iris, by = c("Species", "Sepal.Length"), range = "grid")
#'
#' # range of values
#' get_datagrid(iris, by = c("Sepal.Width = 1:5", "Petal.Width = 1:3"))
#'
#' # With list-style by-argument
#' get_datagrid(
#' iris,
#' by = list(Sepal.Length = 1:3, Species = c("setosa", "versicolor"))
#' )
#'
#'
#' # With models ===============================================================
#'
#' # Fit a linear regression
#' model <- lm(Sepal.Length ~ Sepal.Width * Petal.Length, data = iris)
#'
#' # Get datagrid of predictors
#' data <- get_datagrid(model, length = c(20, 3), range = c("range", "sd"))
#' # same as: get_datagrid(model, range = "grid", length = 20)
#'
#' # Add predictions
#' data$Sepal.Length <- get_predicted(model, data = data)
#'
#' # Visualize relationships (each color is at -1 SD, Mean, and + 1 SD of Petal.Length)
#' plot(data$Sepal.Width, data$Sepal.Length,
#' col = data$Petal.Length,
#' main = "Relationship at -1 SD, Mean, and + 1 SD of Petal.Length"
#' )
#' @export
get_datagrid <- function(x, ...) {
UseMethod("get_datagrid")
}
# Functions for data.frames -----------------------------------------------
#' @rdname get_datagrid
#' @export
get_datagrid.data.frame <- function(x,
by = "all",
factors = "reference",
numerics = "mean",
length = 10,
range = "range",
preserve_range = FALSE,
protect_integers = TRUE,
digits = 3,
reference = x,
...) {
# find numerics that were coerced to factor in-formula
numeric_factors <- colnames(x)[vapply(x, function(i) isTRUE(attributes(i)$factor), logical(1))]
specs <- NULL
if (is.null(by)) {
targets <- data.frame()
} else {
# check for interactions in "by"
by <- .extract_at_interactions(by)
# Validate by argument ============================
# if list, convert to character
if (is.list(by)) {
by <- unname(vapply(names(by), function(i) {
if (is.numeric(by[[i]])) {
paste0(i, " = c(", toString(by[[i]]), ")")
} else {
paste0(i, " = c(", toString(sprintf("'%s'", by[[i]])), ")")
}
}, character(1)))
}
# if by is "all" or numeric or logical indices, extract related
# column names from data frame and use these as by-variables
if (all(by == "all")) {
by <- colnames(x)
}
if (is.numeric(by) || is.logical(by)) {
by <- colnames(x)[by]
}
# Deal with factor in-formula transformations ============================
# something like `y ~ as.factor(x)`. These attributes are available
# when data from fitted models is retrieved using `get_data()`
x[] <- lapply(x, function(i) {
if (isTRUE(attributes(i)$factor)) {
as.factor(i)
} else {
i
}
})
# Deal with logical in-formula transformations ============================
# something like `y ~ as.logical(x)`
x[] <- lapply(x, function(i) {
if (isTRUE(attributes(i)$logical)) {
as.logical(i)
} else {
i
}
})
# Deal with targets =======================================================
# Find any user-defined specifications for each target. Here we parse the
# `by` argument for user-specified values or tokens, e.g. `by="mpg=c(40,50)"`
# or `by="mpg=[sd]"`.
specs <- do.call(
rbind,
lapply(by, .get_datagrid_clean_target, x = x, digits = digits)
)
# information about specification in our data frame should be a string
specs$varname <- as.character(specs$varname) # make sure it's a string, not a factor
specs <- specs[!duplicated(specs$varname), ] # Drop duplicates
# check and mark which focal predictors are factors/characters
specs$is_factor <- vapply(
x[specs$varname],
function(x) is.factor(x) || is.character(x),
TRUE
)
# Create target list of factors -----------------------------------------
facs <- list()
for (fac in specs[specs$is_factor, "varname"]) {
facs[[fac]] <- get_datagrid(
x[[fac]],
by = specs[specs$varname == fac, "expression"]
)
}
# Create target list of numerics ----------------------------------------
nums <- list()
numvars <- specs[!specs$is_factor, "varname"]
# dealing with numeric targets is a bit more complex than for factors. We
# may have a range of representative values, or certain meaningful values
# like mean/sd. Furthermore, the range of numeric representative values can
# be controlled with `length` and `range`, which we don't have/need for
# factors. thus, we must process these arguments, too.
if (length(numvars)) {
# Sanitize 'length' argument
if (length(length) == 1L) {
length <- rep(length, length(numvars))
} else if (length(length) != length(numvars)) {
format_error(
"The number of elements in `length` must match the number of numeric target variables (n = ",
length(numvars), ")."
)
}
# Sanitize 'range' argument
if (length(range) == 1) {
range <- rep(range, length(numvars))
} else if (length(range) != length(numvars)) {
format_error(
"The number of elements in `range` must match the number of numeric target variables (n = ",
length(numvars), ")."
)
}
# sanity check - do we have a named vector for `length`, and do all names
# match the numeric variables? If yes, match order
if (length(length) > 1 && !is.null(names(length)) && all(nzchar(names(length)))) {
if (!all(names(length) %in% numvars)) {
suggestion <- .misspelled_string(
numvars,
names(length),
default_message = "Please check the spelling."
)
format_error(paste0(
"Names of `length` do not match names of numeric variables specified in `by`. ",
suggestion$msg
))
}
length <- unname(length[match(names(length), numvars)])
}
# sanity check - do we have a named vector for `range`, and do all names
# match the numeric variables? If yes, match order
if (length(range) > 1 && !is.null(names(range)) && all(nzchar(names(range)))) {
if (!all(names(range) %in% numvars)) {
suggestion <- .misspelled_string(
numvars,
names(range),
default_message = "Please check the spelling."
)
format_error(paste0(
"Names of `range` do not match names of numeric variables specified in `by`. ",
suggestion$msg
))
}
range <- unname(range[match(names(range), numvars)])
}
# Get datagrids
for (i in seq_along(numvars)) {
num <- numvars[i]
nums[[num]] <- get_datagrid(x[[num]],
by = specs[specs$varname == num, "expression"],
reference = reference[[num]],
length = length[i],
range = range[i],
digits = digits,
protect_integers = protect_integers,
is_first_predictor = specs$varname[1] == num,
...
)
}
}
# Assemble the two - the goal is to have two named lists, where variable
# names are the names of the list-elements: one list contains elements of
# numeric variables, the other one factors.
targets <- expand.grid(c(nums, facs))
# sort targets data frame according to order specified in "by"
targets <- .safe(targets[specs$varname], targets)
# Preserve range ---------------------------------------------------------
if (preserve_range && length(facs) > 0 && length(nums) > 0L) {
# Loop through the combinations of factors
facs_combinations <- expand.grid(facs)
for (i in seq_len(nrow(facs_combinations))) {
# Query subset of original dataset
data_subset <- x[.data_match(x, to = facs_combinations[i, , drop = FALSE]), , drop = FALSE]
idx <- .data_match(targets, to = facs_combinations[i, , drop = FALSE])
# Skip if no instance of factor combination, drop the chunk
if (nrow(data_subset) == 0) {
targets <- targets[-idx, ]
break
}
# Else, filter given the range of numerics as they appear in the data
rows_to_remove <- NULL
for (num in names(nums)) {
mini <- min(data_subset[[num]], na.rm = TRUE)
maxi <- max(data_subset[[num]], na.rm = TRUE)
rows_to_remove <- c(rows_to_remove, which(targets[[num]] < mini | targets[[num]] > maxi))
}
if (length(rows_to_remove) > 0) {
targets <- targets[-idx[idx %in% rows_to_remove], ] # Drop incompatible rows
row.names(targets) <- NULL # Reset row.names
}
}
if (nrow(targets) == 0) {
format_error("No data left was left after range preservation. Try increasing `length` or setting `preserve_range` to `FALSE`.") # nolint
}
}
}
# Deal with the rest =========================================================
rest_vars <- names(x)[!names(x) %in% names(targets)]
if (length(rest_vars) >= 1) {
# set non-focal terms to mean/reference/...
rest_df <- lapply(
x[rest_vars],
.get_datagrid_summary,
numerics = numerics,
factors = factors,
...
)
rest_df <- expand.grid(rest_df, stringsAsFactors = FALSE)
if (nrow(targets) == 0) {
targets <- rest_df # If by = NULL
} else {
targets <- merge(targets, rest_df, sort = FALSE)
}
} else {
rest_vars <- NA
}
# Prepare output =============================================================
# Reset row names
row.names(targets) <- NULL
# convert factors back to numeric, if these variables were actually
# numeric in the original data
if (!is.null(numeric_factors) && length(numeric_factors)) {
for (i in numeric_factors) {
targets[[i]] <- .factor_to_numeric(targets[[i]])
}
}
# Attributes
attr(targets, "adjusted_for") <- rest_vars
attr(targets, "at_specs") <- specs
attr(targets, "at") <- by
attr(targets, "by") <- by
attr(targets, "preserve_range") <- preserve_range
attr(targets, "reference") <- reference
attr(targets, "data") <- x
# Printing decorations
attr(targets, "table_title") <- c("Visualisation Grid", "blue")
if (!(length(rest_vars) == 1 && is.na(rest_vars)) && length(rest_vars) >= 1) {
attr(targets, "table_footer") <- paste0("\nMaintained constant: ", toString(rest_vars))
}
if (!is.null(attr(targets, "table_footer"))) {
attr(targets, "table_footer") <- c(attr(targets, "table_footer"), "blue")
}
class(targets) <- unique(c("datagrid", "visualisation_matrix", class(targets)))
targets
}
# Functions that work on a vector (a single column) ----------------------
# See tests/test-get_datagrid.R for examples
## Numeric ------------------------------------
#' @rdname get_datagrid
#' @export
get_datagrid.numeric <- function(x,
length = 10,
range = "range",
protect_integers = TRUE,
digits = 3,
...) {
# Check and clean the target argument
specs <- .get_datagrid_clean_target(x, digits = digits, ...)
# If an expression is detected, run it and return it - we don't need
# to create any spread of values to cover the range; spread is user-defined
if (!is.na(specs$expression)) {
return(eval(parse(text = specs$expression)))
}
# If NA, return all unique
if (is.na(length)) {
return(sort(unique(x)))
}
# validation check
if (!is.numeric(length)) {
format_error("`length` argument should be an number.")
}
# Create a spread
.create_spread(
x,
length = length,
range = range,
digits = digits,
protect_integers = protect_integers,
...
)
}
#' @export
get_datagrid.double <- get_datagrid.numeric
## Factors & Characters ----------------------------------------------------
#' @rdname get_datagrid
#' @export
get_datagrid.factor <- function(x, ...) {
# Check and clean the target argument
specs <- .get_datagrid_clean_target(x, ...)
if (is.na(specs$expression)) {
# Keep only unique levels
if (is.factor(x)) {
out <- factor(levels(droplevels(x)), levels = levels(droplevels(x)))
} else {
out <- unique(x)
}
} else {
# Run the expression cleaned from target
out <- eval(parse(text = specs$expression))
}
out
}
#' @export
get_datagrid.character <- get_datagrid.factor
#' @export
get_datagrid.logical <- get_datagrid.character
# Functions that work on statistical models -------------------------------
#' @rdname get_datagrid
#' @export
get_datagrid.default <- function(x,
by = "all",
factors = "reference",
numerics = "mean",
preserve_range = TRUE,
reference = x,
include_smooth = TRUE,
include_random = FALSE,
include_response = FALSE,
data = NULL,
digits = 3,
verbose = TRUE,
...) {
# validation check
if (!is_model(x)) {
format_error("`x` must be a statistical model.")
}
# Retrieve data from model
data <- .get_model_data_for_grid(x, data)
data_attr <- attributes(data)
# save response - might be necessary to include
response <- find_response(x)
# check some exceptions here: logistic regression models with factor response
# usually require the response to be included in the model, else `get_modelmatrix()`
# fails, which is required to compute SE/CI for `get_predicted()`
minfo <- model_info(x, response = 1)
# check which response variables are possibly a factor. for multivariate
# models, "response" might be a vector, so we iterate with vapply() here
factor_response <- vapply(response, function(i) is.factor(data[[i]]), logical(1))
# any factor response for binomial?
if (minfo$is_binomial && minfo$is_logit && any(factor_response) && !include_response && verbose) {
format_warning(
"Logistic regression model has a categorical response variable. You may need to set `include_response=TRUE` to make it work for predictions." # nolint
)
}
# Deal with intercept-only models
if (isFALSE(include_response)) {
data <- data[!colnames(data) %in% response]
if (ncol(data) < 1L) {
format_error("Model only seems to be an intercept-only model. Use `include_response=TRUE` to create the reference grid.") # nolint
}
}
# check for interactions in "by"
by <- .extract_at_interactions(by)
# Drop random factors
random_factors <- find_random(x, flatten = TRUE, split_nested = TRUE)
if (!is.null(random_factors)) {
if (isFALSE(include_random)) {
# drop random factors, if these should not be included
keep <- c(find_predictors(x, effects = "fixed", flatten = TRUE), response)
if (!is.null(keep)) {
if (all(by != "all")) {
keep <- c(keep, by[by %in% random_factors])
random_factors <- setdiff(random_factors, by)
}
data <- data[colnames(data) %in% keep]
}
} else {
# make sure random factors are not numeric, else, wrong "levels" will be returned
data[random_factors] <- lapply(data[random_factors], as.factor)
}
}
# user wants to include all predictors?
if (all(by == "all")) by <- colnames(data)
# exclude smooth terms?
if (isFALSE(include_smooth) || identical(include_smooth, "fixed")) {
s <- find_smooth(x, flatten = TRUE)
if (!is.null(s)) {
by <- colnames(data)[!colnames(data) %in% clean_names(s)]
}
}
# set back custom attributes
data <- .replace_attr(data, data_attr)
vm <- get_datagrid(
data,
by = by,
factors = factors,
numerics = numerics,
preserve_range = preserve_range,
reference = data,
digits = digits,
...
)
# we still need random factors in data grid. we set these to
# "population level" if not conditioned on via "by"
if (isFALSE(include_random) && !is.null(random_factors)) {
if (inherits(x, c("glmmTMB", "brmsfit", "MixMod"))) {
vm[random_factors] <- NA
} else if (inherits(x, c("merMod", "rlmerMod", "lme"))) {
vm[random_factors] <- 0
}
}
# if model has weights, we need to add a dummy for certain classes, e.g. glmmTMB
w <- insight::find_weights(x)
if (!inherits(x, "brmsfit") && !is.null(w)) {
# for lme, can't be NA
if (inherits(x, c("lme", "gls"))) {
vm[w] <- 1
} else {
vm[w] <- NA_real_
}
}
if (isFALSE(include_smooth)) {
vm[colnames(vm) %in% clean_names(find_smooth(x, flatten = TRUE))] <- NULL
}
attr(vm, "model") <- x
vm
}
#' @export
get_datagrid.logitr <- function(x, ...) {
datagrid <- get_datagrid.default(x, ...)
obsID <- parse(text = safe_deparse(get_call(x)))[[1]]$obsID
datagrid[[obsID]] <- x$data[[obsID]][1]
datagrid
}
#' @export
get_datagrid.wbm <- function(x,
by = "all",
factors = "reference",
numerics = "mean",
preserve_range = TRUE,
reference = x,
include_smooth = TRUE,
include_random = FALSE,
data = NULL,
...) {
# Retrieve data from model
data <- .get_model_data_for_grid(x, data)
# add id and time variables
data[[x@call_info$id]] <- levels(stats::model.frame(x)[[x@call_info$id]])[1]
wave <- stats::model.frame(x)[[x@call_info$wave]]
if (is.factor(wave)) {
data[[x@call_info$wave]] <- levels(wave)[1]
} else {
data[[x@call_info$wave]] <- mean(wave)
}
# clean variable names
colnames(data) <- clean_names(colnames(data))
get_datagrid.default(
x = x, by = by, factors = factors, numerics = numerics,
preserve_range = preserve_range, reference = reference,
include_smooth = include_smooth, include_random = include_random,
include_response = TRUE, data = data, ...
)
}
# Functions that work on get_datagrid -------------------------------------
#' @export
get_datagrid.visualisation_matrix <- function(x,
reference = attributes(x)$reference,
...) {
datagrid <- get_datagrid(as.data.frame(x), reference = reference, ...)
if ("model" %in% names(attributes(x))) {
attr(datagrid, "model") <- attributes(x)$model
}
datagrid
}
#' @export
get_datagrid.datagrid <- get_datagrid.visualisation_matrix
# Functions for emmeans/marginaleffects ---------------
#' Extract a reference grid from objects created by `{emmeans}` and `{marginaleffects}`
#'
#' @param x An object created by a function such as [emmeans::emmeans()],
#' [marginaleffects::slopes()], etc.
#' @param ... Currently not used
#'
#' @details
#' Note that for `{emmeans}` inputs the result is a proper grid (all
#' combinations of values are represented), except when a nesting structure is
#' detected. Additionally, when the input is an `emm_list` object, the function
#' will `rbind()` the data-grids of all the elements in the input.
#'
#' For `{marginaleffects}` inputs, the output may very well be a non-grid
#' result. See examples.
#'
#' @return A `data.frame` with key columns that identify the rows in `x`.
#'
#' @examples
#' data("mtcars")
#' mtcars$cyl <- factor(mtcars$cyl)
#'
#' mod <- glm(am ~ cyl + hp + wt,
#' family = binomial("logit"),
#' data = mtcars
#' )
#'
#' @examplesIf insight::check_if_installed("emmeans", quietly = TRUE)
#' em1 <- emmeans::emmeans(mod, ~ cyl + hp, at = list(hp = c(100, 150)))
#' get_datagrid(em1)
#'
#' contr1 <- emmeans::contrast(em1, method = "consec", by = "hp")
#' get_datagrid(contr1)
#'
#' eml1 <- emmeans::emmeans(mod, pairwise ~ cyl | hp, at = list(hp = c(100, 150)))
#' get_datagrid(eml1) # not a "true" grid
#'
#' @examplesIf insight::check_if_installed("marginaleffects", quietly = TRUE, minimum_version = "0.25.0")
#' mfx1 <- marginaleffects::slopes(mod, variables = "hp")
#' get_datagrid(mfx1) # not a "true" grid
#'
#' mfx2 <- marginaleffects::slopes(mod, variables = c("hp", "wt"), by = "am")
#' get_datagrid(mfx2)
#'
#' contr2 <- marginaleffects::avg_comparisons(mod)
#' get_datagrid(contr2) # not a "true" grid
#' @export
get_datagrid.emmGrid <- function(x, ...) {
suppressWarnings({
s <- as.data.frame(x)
})
# We want all the columns *before* the estimate column
est_col_idx <- which(colnames(s) == attr(s, "estName"))
which_cols <- seq_len(est_col_idx - 1)
data.frame(s)[, which_cols, drop = FALSE]
}
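# For emm_list objects: build the grid for each list element, fill columns that
# are missing in some elements with NA, rbind the grids, and move columns
# without any NA to the front.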
#' @export
get_datagrid.emm_list <- function(x, ...) {
k <- length(x)
res <- vector("list", length = k)
for (i in seq_len(k)) {
res[[i]] <- get_datagrid(x[[i]])
}
all_cols <- Reduce(lapply(res, colnames), f = union)
for (i in seq_len(k)) {
res[[i]][, setdiff(all_cols, colnames(res[[i]]))] <- NA
}
out <- do.call("rbind", res)
clear_cols <- colnames(out)[sapply(out, Negate(anyNA))] # these should be first
out[, c(clear_cols, setdiff(colnames(out), clear_cols)), drop = FALSE]
}
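# For marginaleffects objects: the grid consists of the columns of the original
# `newdata`, any "contrast" columns, and the term/by/hypothesis columns (if
# present in `x`).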
#' @export
get_datagrid.slopes <- function(x, ...) {
cols_newdata <- colnames(attr(x, "newdata"))
cols_contrast <- colnames(x)[grep("^contrast_?", colnames(x))]
cols_misc <- c("term", "by", "hypothesis")
cols_grid <- union(union(cols_newdata, cols_contrast), cols_misc)
data.frame(x)[, intersect(colnames(x), cols_grid), drop = FALSE]
}
#' @export
get_datagrid.predictions <- get_datagrid.slopes
#' @export
get_datagrid.comparisons <- get_datagrid.slopes
# Utilities -----------------------------------------------------------------
# This function extracts representative values specified in the `by` argument,
# e.g. `by="mpg=c(20,30,40)"` or `by="mpg=[sd]"`
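# Illustrative sketch (assuming the built-in `iris` data): a call like
#   .get_datagrid_clean_target(iris, by = "Sepal.Length = [sd]")
# would return a one-row data frame with columns `varname` ("Sepal.Length") and
# `expression` (roughly "c(5.015,5.843,6.671)", i.e. mean -/+ 1 SD).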
#' @keywords internal
.get_datagrid_clean_target <- function(x, by = NULL, digits = 3, ...) {
by_expression <- NA
varname <- NA
original_target <- by
if (!is.null(by)) {
if (is.data.frame(x) && by %in% names(x)) {
return(data.frame(varname = by, expression = NA))
}
# If there is an equal sign
if (grepl("length.out =", by, fixed = TRUE)) {
by_expression <- by # This is an edge case
} else if (grepl("=", by, fixed = TRUE)) {
parts <- trim_ws(unlist(strsplit(by, "=", fixed = TRUE), use.names = FALSE)) # Split and clean
varname <- parts[1] # left-hand part is probably the name of the variable
by <- parts[2] # right-hand part is the real target
}
# we have no edge case (by_expression = NA), thus we expect to have a variable
# name in `varname` and the values in `by`
if (is.na(by_expression) && is.data.frame(x)) {
# if data grid should be made for a data frame (and not a model object),
# check if the specified variables names are in the data frame.
if (is.na(varname) || !varname %in% colnames(x)) {
# we can either just have a variable name in `by`, then we evaluate `by`
# here. Or we can have an equal-sign, so the variable name is saved in
# `varname` - in this case, overwrite `by` with `varname` for informative
# error message.
if (!is.na(varname)) by <- varname
suggestion <- .misspelled_string(
colnames(x),
by,
default_message = "Please check the spelling."
)
format_error(paste0(
"Variable `", by, "` was not found in the data. ", # nolint
suggestion$msg
))
} else {
x <- x[[varname]]
}
}
# Tokens: If brackets are detected [a, b] --------------------
# ------------------------------------------------------------
if (is.na(by_expression) && grepl("\\[.*\\]", by)) {
# Clean --------------------
# Keep the content
parts <- trim_ws(unlist(regmatches(by, gregexpr("\\[.+?\\]", by)), use.names = FALSE))
# Drop the brackets
parts <- gsub("\\[|\\]", "", parts)
# do we have a range, indicated by colon? If yes, we just want these two
# values (i.e. we replace : by ,) and set the range-indicator to TRUE
if (grepl(":", parts, fixed = TRUE)) {
parts <- gsub(":", ",", parts, fixed = TRUE)
is_range <- TRUE
} else {
is_range <- FALSE
}
# Split by a separator like ','
parts <- trim_ws(unlist(strsplit(parts, ",", fixed = TRUE), use.names = FALSE))
# If the elements have quotes around them, drop them
if (all(grepl("\\'.*\\'", parts))) parts <- gsub("'", "", parts, fixed = TRUE)
if (all(grepl('\\".*\\"', parts))) parts <- gsub('"', "", parts, fixed = TRUE)
# Make expression ----------
shortcuts <- c(
"meansd", "sd", "mad", "quartiles", "zeromax", "minmax", "terciles",
"terciles2", "fivenum", "pretty"
)
if ((is.factor(x) && all(parts %in% levels(x))) || (is.character(x) && all(parts %in% x))) {
# Factor ----------------
# -----------------------
# Add quotes around them
parts <- paste0("'", parts, "'")
# Convert to character
by_expression <- paste0("as.factor(c(", toString(parts), "))")
} else if (length(parts) == 1) {
# If only one value, it might be a shortcut or a sampling request ----
# ------------------------------------------------------------------
if (grepl("sample", parts, fixed = TRUE)) {
n_to_sample <- suppressWarnings(as.numeric(trim_ws(gsub("sample", "", parts, fixed = TRUE))))
# do we have a proper definition of the sample size? If not, error
if (is.null(n_to_sample) || is.na(n_to_sample) || !length(n_to_sample)) {
format_error("The token `sample` must be followed by the number of samples to be drawn, e.g. `[sample 15]`.") # nolint
}
by_expression <- paste0("c(", paste(sample(x, n_to_sample), collapse = ","), ")")
} else if (parts %in% shortcuts) {
if (parts %in% c("meansd", "sd")) {
center <- mean(x, na.rm = TRUE)
spread <- stats::sd(x, na.rm = TRUE)
by_expression <- paste0("c(", round(center - spread, digits), ",", round(center, digits), ",", round(center + spread, digits), ")") # nolint
} else if (parts == "mad") {
center <- stats::median(x, na.rm = TRUE)
spread <- stats::mad(x, na.rm = TRUE)
by_expression <- paste0("c(", round(center - spread, digits), ",", round(center, digits), ",", round(center + spread, digits), ")") # nolint
} else if (parts == "fivenum") {
by_expression <- paste0("c(", paste(round(as.vector(stats::fivenum(x, na.rm = TRUE)), digits), collapse = ","), ")") # nolint
} else if (parts == "quartiles") {
by_expression <- paste0("c(", paste(round(as.vector(stats::quantile(x, na.rm = TRUE))[2:4], digits), collapse = ","), ")") # nolint
} else if (parts == "terciles") {
by_expression <- paste0("c(", paste(round(as.vector(stats::quantile(x, probs = (1:2) / 3, na.rm = TRUE)), digits), collapse = ","), ")") # nolint
} else if (parts == "terciles2") {
by_expression <- paste0("c(", paste(round(as.vector(stats::quantile(x, probs = (0:3) / 3, na.rm = TRUE)), digits), collapse = ","), ")") # nolint
} else if (parts == "pretty") {
by_expression <- paste0("c(", paste(as.vector(pretty(x, na.rm = TRUE)), collapse = ","), ")")
} else if (parts == "zeromax") {
by_expression <- paste0("c(0,", round(max(x, na.rm = TRUE), digits), ")")
} else if (parts == "minmax") {
by_expression <- paste0("c(", round(min(x, na.rm = TRUE), digits), ",", round(max(x, na.rm = TRUE), digits), ")")
}
} else if (is.numeric(parts)) {
# if value in brackets is not a character, it must be numeric ---
# ---------------------------------------------------------------
by_expression <- parts
} else {
by_expression <- NULL
}
} else if (is.numeric(x)) {
# Target variable is a numeric ------------
# -----------------------------------------
if (is_range && length(parts) == 2) {
# If we have two values and the range indicator is TRUE, we have a range
by_expression <- paste0("seq(", parts[1], ", ", parts[2], ", length.out = length)")
} else {
# Else we have single values
parts <- as.numeric(parts)
by_expression <- paste0("c(", toString(parts), ")")
}
} else {
by_expression <- NULL
}
if (is.null(by_expression)) {
format_error(
paste0(
"The `by` argument (", by, ") should either indicate a valid factor level, the minimum and the maximum value of a vector, or one of the following options: ", # nolint
toString(shortcuts),
"."
)
)
}
# Else, try to directly eval the content --------
# -----------------------------------------------
} else {
by_expression <- by
# Try to eval and make sure it works
tryCatch(
{
# This is just to make sure that an expression with `length` in
# it doesn't fail because of this undefined var
length <- 10 # nolint
.dynEval(by)
},
error = function(r) {
format_error(
paste0("The `by` argument (`", original_target, "`) cannot be read and could be mispecified.")
)
}
)
}
}
data.frame(varname = varname, expression = by_expression, stringsAsFactors = FALSE)
}
# This function deals with the non-focal predictors, i.e. sets numerics
# to their mean (or other value), factors to reference etc.
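# Illustrative sketch: with the defaults, .get_datagrid_summary(iris$Sepal.Length)
# returns the mean (~5.84), while .get_datagrid_summary(iris$Species) returns the
# reference level ("setosa", as a factor that keeps all original levels).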
#' @keywords internal
.get_datagrid_summary <- function(x,
numerics = "mean",
factors = "reference",
remove_na = TRUE,
...) {
if (remove_na) x <- stats::na.omit(x)
if (is.numeric(x)) {
if (is.numeric(numerics)) {
# numerics set to a specific value
out <- numerics
} else if (numerics %in% c("all", "combination")) {
# all values in the variable are preserved
out <- unique(x)
} else {
# we have a function in "numerics", which is applied here
out <- eval(parse(text = paste0(numerics, "(x)")))
}
} else if (factors %in% c("all", "combination")) {
out <- unique(x)
} else if (factors == "mode") {
# Get mode
out <- as.character(.mode_value(x))
} else {
# Get reference
if (is.factor(x)) {
all_levels <- levels(x)
} else if (is.character(x) || is.logical(x)) {
all_levels <- unique(x)
} else {
format_error(paste0(
"Argument is not numeric nor factor but ", class(x), ".",
"Please report the bug at https://github.com/easystats/insight/issues"
))
}
# see "get_modelmatrix()" and #626. Reference level is currently
# a character vector, which causes the error
# > Error in `contrasts<-`(`*tmp*`, value = contr.funs[1 + isOF[nn]]) :
# > contrasts can be applied only to factors with 2 or more levels
# this is usually avoided by calling ".pad_modelmatrix()", but this
# function ignores character vectors. so we need to make sure that this
# factor level is also of class factor.
out <- factor(all_levels[1])
# although we have reference level only, we still need information
# about all levels, see #695
levels(out) <- all_levels
}
out
}
#' @keywords internal
.mode_value <- function(x) {
uniqv <- unique(x)
tab <- tabulate(match(x, uniqv))
idx <- which.max(tab)
uniqv[idx]
}
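# Creates the spread of representative values for a numeric vector, either
# evenly spaced across a range (min/max, IQR, CI, HDI, ETI) or centered around
# the mean/median (-/+ SD or MAD). Illustrative sketch:
#   .create_spread(iris$Sepal.Length, length = 3, range = "sd")
# returns a named vector of roughly c(`-1 SD` = 5.015, Mean = 5.843, `+1 SD` = 6.671).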
#' @keywords internal
.create_spread <- function(x,
length = 10,
range = "range",
ci = 0.95,
digits = 3,
protect_integers = TRUE,
...) {
range <- validate_argument(
tolower(range),
c("range", "iqr", "ci", "hdi", "eti", "sd", "mad", "grid")
)
# bayestestR only for some options
if (range %in% c("ci", "hdi", "eti")) {
check_if_installed("bayestestR")
}
# check if range = "grid" - then use mean/sd for every numeric that
# is not first predictor...
if (range == "grid") {
# if not first predictor, we want range = "sd" and length 3, i.e. numerics
# at 2nd or 3rd position should represent 3 values (mean +/- 1 SD). We set
# range to "sd" later, because we do *not* want SD when the first predictor
# is of type integer
if (isFALSE(list(...)$is_first_predictor)) {
length <- 3
range <- "sd"
}
# if we want a representative grid, and have integers as first focal
# predictors, we want at maximum all valid / unique values, but *not* a
# spread with fractions. This behaviour can only be overridden by setting
# protect_integers = FALSE
if (isTRUE(list(...)$is_first_predictor) && all(.is_integer(x)) && n_unique(x) < length && protect_integers) {
length <- n_unique(x)
}
}
# for integer values, we don't want a range with fractions, so we shorten
# length if necessary. This means, for numerics with, say, two or three values,
# we still have these two or three values after creating the spread. This
# behaviour can only be overridden by setting protect_integers = FALSE
if (all(.is_integer(x)) && n_unique(x) < length && range == "range" && protect_integers) {
length <- n_unique(x)
}
# If Range is a dispersion (e.g., SD or MAD)
if (range %in% c("sd", "mad")) {
spread <- -floor((length - 1) / 2):ceiling((length - 1) / 2)
if (range == "sd") {
disp <- stats::sd(x, na.rm = TRUE)
center <- mean(x, na.rm = TRUE)
labs <- ifelse(sign(spread) == -1, paste(spread, "SD"),
ifelse(sign(spread) == 1, paste0("+", spread, " SD"), "Mean") # nolint
)
} else {
disp <- stats::mad(x, na.rm = TRUE)
center <- stats::median(x, na.rm = TRUE)
labs <- ifelse(sign(spread) == -1, paste(spread, "MAD"),
ifelse(sign(spread) == 1, paste0("+", spread, " MAD"), "Median") # nolint
)
}
out <- round(center + spread * disp, digits)
names(out) <- labs
return(out)
}
# If Range is an interval
if (range == "iqr") { # nolint
mini <- stats::quantile(x, (1 - ci) / 2, ...)
maxi <- stats::quantile(x, (1 + ci) / 2, ...)
} else if (range == "ci") {
out <- bayestestR::ci(x, ci = ci, verbose = FALSE, ...)
mini <- out$CI_low
maxi <- out$CI_high
} else if (range == "eti") {
out <- bayestestR::eti(x, ci = ci, verbose = FALSE, ...)
mini <- out$CI_low
maxi <- out$CI_high
} else if (range == "hdi") {
out <- bayestestR::hdi(x, ci = ci, verbose = FALSE, ...)
mini <- out$CI_low
maxi <- out$CI_high
} else {
mini <- min(x, na.rm = TRUE)
maxi <- max(x, na.rm = TRUE)
}
round(seq(mini, maxi, length.out = length), digits)
}
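# Returns the (numeric) row names of those rows in `x` whose values match the
# value combinations given in `to`; used to filter the grid when
# `preserve_range = TRUE`.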
#' @keywords internal
.data_match <- function(x, to, ...) {
if (!is.data.frame(to)) {
to <- as.data.frame(to)
}
idx <- seq_len(nrow(x))
for (col in names(to)) {
if (col %in% names(x)) {
idx <- idx[x[[col]][idx] %in% to[[col]]]
}
}
.to_numeric(row.names(x)[idx])
}
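# Retrieves the data used to fit the model (unless `data` is provided),
# restricted to the model's variables, and marks columns that were coerced to
# factor or logical in-formula (e.g. `y ~ as.factor(x)`).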
#' @keywords internal
.get_model_data_for_grid <- function(x, data) {
# Retrieve data, based on variable names
if (is.null(data)) {
data <- get_data(x, verbose = FALSE)
# make sure we only have variables from original data
all_vars <- find_variables(x, effects = "all", component = "all", flatten = TRUE)
if (!is.null(all_vars)) {
data <- .safe(data[intersect(all_vars, colnames(data))], data)
}
}
# still found no data - stop here
if (is.null(data)) {
format_error(
"Can't access data that was used to fit the model in order to create the reference grid.",
"Please use the `data` argument."
)
}
# find variables that were coerced on-the-fly
model_terms <- find_terms(x, flatten = TRUE)
# something like `y ~ as.factor(x)`. These attributes are available when data
# from fitted models is retrieved using `get_data(source = "mf")`. Since we
# do not retrieve data from the model frame, we do this step here manually
factors <- grepl("^(as\\.factor|as_factor|factor|as\\.ordered|ordered)\\((.*)\\)", model_terms)
if (any(factors)) {
factor_expressions <- lapply(model_terms[factors], str2lang)
cleaned_terms <- vapply(factor_expressions, all.vars, character(1))
for (i in cleaned_terms) {
if (is.numeric(data[[i]])) {
attr(data[[i]], "factor") <- TRUE
}
}
attr(data, "factors") <- cleaned_terms
}
logicals <- grepl("^(as\\.logical|as_logical|logical)\\((.*)\\)", model_terms)
if (any(logicals)) {
logical_expressions <- lapply(model_terms[logicals], str2lang)
cleaned_terms <- vapply(logical_expressions, all.vars, character(1))
for (i in cleaned_terms) {
if (is.numeric(data[[i]])) {
attr(data[[i]], "logical") <- TRUE
}
}
attr(data, "logicals") <- cleaned_terms
}
data
}
#' @keywords internal
.extract_at_interactions <- function(by) {
# don't process when by is a list
if (is.list(by)) {
return(by)
}
# get interaction terms, but only if these are not inside brackets (like "[4:8]")
# or parentheses (like "c(1:3)"). Furthermore, "interaction terms" only refer
# to a value without equal-sign, i.e. `by = "a:b"` is an interaction, but
# `by = "mpg=10:20"` is not.
pattern <- "(:|\\*)(?![^\\[]*\\])(?![^\\(]*\\))"
interaction_terms <- grepl(pattern, by, perl = TRUE) & !grepl("=", by, fixed = TRUE)
if (any(interaction_terms)) {
by <- unique(clean_names(trim_ws(compact_character(c(
by[!interaction_terms],
unlist(strsplit(by[interaction_terms], "(:|\\*)"))
)))))
}
by
}
#' @keywords internal
.replace_attr <- function(data, custom_attr) {
for (nm in setdiff(names(custom_attr), names(attributes(data.frame())))) {
attr(data, which = nm) <- custom_attr[[nm]]
}
data
}
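# Tries to evaluate a (deparsed) expression by walking up the call stack, frame
# by frame, until the evaluation succeeds; stops with an error if no frame can
# evaluate the expression.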
#' @keywords internal
.dynEval <- function(x, minframe = 1L, remove_n_top_env = 0) {
n <- sys.nframe() - remove_n_top_env
x <- safe_deparse(x)
while (n > minframe) {
n <- n - 1L
env <- sys.frame(n)
r <- try(eval(str2lang(x), envir = env), silent = TRUE)
if (!inherits(r, "try-error") && !is.null(r)) {
return(r)
}
}
stop()
}