Nothing
# Evaluate the user-supplied `newdata` expression, filling in the model (or its
# data) when the expression is a call to a known grid-building function.
#
# Arguments:
#   scall   A quosure capturing what the user passed as `newdata`.
#   newdata The value returned as-is when `scall` does not wrap a call.
#   model   Model object injected into grid-building calls that lack one.
#
# Returns the result of evaluating the (possibly modified) call, or `newdata`
# unchanged when `scall` is not a call.
sanitize_newdata_call <- function(scall, newdata = NULL, model) {
  if (!rlang::quo_is_call(scall)) {
    return(newdata)
  }

  # `rlang::call_name()` returns NULL for calls that are not simple function
  # calls (e.g. `obj$make_grid()`); the NULL-safe tests below keep such calls
  # from crashing the `if` conditions, and they are evaluated untouched.
  fn_name <- rlang::call_name(scall)
  grid_funs <- c("datagrid", "datagridcf", "typical", "counterfactual")

  if (isTRUE(fn_name %in% grid_funs)) {
    if (!"model" %in% rlang::call_args_names(scall)) {
      scall <- rlang::call_modify(scall, model = model)
    }
  } else if (identical(fn_name, "visualisation_matrix")) {
    if (!"x" %in% rlang::call_args_names(scall)) {
      # Supply the model data itself, not the `get_modeldata` function object.
      # NOTE(review): assumes `get_modeldata()` takes the model as its first
      # argument -- confirm against its definition elsewhere in the package.
      scall <- rlang::call_modify(scall, x = get_modeldata(model))
    }
  }

  rlang::eval_tidy(scall)
}
# Turn the `newdata` argument into a data frame.
#
# `newdata` may be NULL (the model data is used), a data frame (used as
# supplied), or one of the shortcut strings "mean", "median", "tukey", "grid"
# or "marginalmeans", each of which is expanded through `datagrid()`.
#
# Returns a list with three elements: `newdata`, `newdata_explicit` (FALSE only
# when `newdata` was NULL) and `modeldata`.
build_newdata <- function(model, newdata, by, modeldata) {
  # Reduce `by` to column names: a data frame contributes its columns other
  # than "by", a TRUE/FALSE flag contributes nothing.
  if (isTRUE(checkmate::check_data_frame(by))) {
    by <- setdiff(colnames(by), "by")
  } else if (isTRUE(checkmate::check_flag(by))) {
    by <- NULL
  }

  # Arguments for `datagrid()`: the model, plus `unique` for each `by` column.
  grid_args <- list(model = model)
  for (by_col in by) {
    grid_args[[by_col]] <- unique
  }

  # Extra `datagrid()` arguments implied by each shortcut string.
  med <- function(x) stats::median(x, na.rm = TRUE)
  five <- function(x) stats::fivenum(x, na.rm = TRUE)
  shortcuts <- list(
    mean = list(),
    median = list(FUN_logical = med, FUN_integer = med, FUN_numeric = med),
    tukey = list(FUN_numeric = five),
    grid = list(
      FUN_numeric = five,
      FUN_logical = unique,
      FUN_character = unique,
      FUN_factor = unique
    ),
    # all unique values of categorical variables, numerics left at defaults
    marginalmeans = list(
      FUN_logical = unique,
      FUN_character = unique,
      FUN_factor = unique
    )
  )
  shortcut <- Find(function(key) identical(newdata, key), names(shortcuts))

  newdata_explicit <- !is.null(newdata)
  if (!newdata_explicit) {
    # NULL -> modeldata
    newdata <- modeldata
  } else if (!is.null(shortcut)) {
    # string -> datagrid()
    extra <- shortcuts[[shortcut]]
    for (arg_name in names(extra)) {
      grid_args[[arg_name]] <- extra[[arg_name]]
    }
    newdata <- do.call("datagrid", grid_args)

    if (identical(shortcut, "marginalmeans")) {
      # Issue #580: outcome should not duplicate grid rows
      dv <- hush(insight::find_response(model))
      if (isTRUE(dv %in% colnames(newdata))) {
        newdata[[dv]] <- get_mean_or_mode(newdata[[dv]])
        newdata <- unique(newdata)
      }
    }
  }

  if (!inherits(newdata, "data.frame")) {
    msg <- sprintf(
      "Unable to extract the data from model of class `%s`. This can happen in a variety of cases, such as when a `marginaleffects` package function is called from inside a user-defined function, or using an `*apply()`-style operation on a list. Please supply a data frame explicitly via the `newdata` argument.",
      class(model)[1]
    )
    insight::format_error(msg)
  }

  # otherwise we get a warning in setDT()
  if (inherits(model, "mlogit") && isTRUE(inherits(modeldata[["idx"]], "idx"))) {
    modeldata$idx <- NULL
  }

  # required by some model-fitting functions
  data.table::setDT(modeldata)

  # required for the type of column indexing to follow
  data.table::setDF(newdata)

  list(
    newdata = newdata,
    newdata_explicit = newdata_explicit,
    modeldata = modeldata
  )
}
# Store the weights in the `marginaleffects_wts_internal` column of `newdata`.
#
# `wts` may be NULL, a string naming a column of `newdata`, or a numeric vector
# with one element per row of `newdata`; anything else raises an error.
#
# Weights must be available in the `comparisons()` function, NOT in `tidy()`,
# because comparisons will often duplicate newdata for multivariate outcomes
# and the like. We need to track which row matches which.
add_wts_column <- function(wts, newdata) {
  wts_col <- "marginaleffects_wts_internal"

  # No weights: assigning NULL drops the internal column if it is present.
  if (is.null(wts)) {
    newdata[[wts_col]] <- NULL
    return(newdata)
  }

  wts_is_string <- isTRUE(checkmate::check_string(wts))
  names_a_column <- wts_is_string && isTRUE(wts %in% colnames(newdata))
  is_row_vector <- isTRUE(checkmate::check_numeric(wts, len = nrow(newdata)))

  if (!(names_a_column || is_row_vector)) {
    msg <- sprintf("The `wts` argument must be a numeric vector of length %s, or a string which matches a column name in `newdata`. If you did not supply a `newdata` explicitly, `marginaleffects` extracted it automatically from the model object, and the `wts` variable may not have been available. The easiest strategy is often to supply a data frame such as the original data to `newdata` explicitly, and to make sure that it includes an appropriate column of weights, identified by the `wts` argument.", nrow(newdata))
    stop(msg, call. = FALSE)
  }

  # weights: before sanitize_variables
  if (wts_is_string) {
    newdata[[wts_col]] <- newdata[[wts]]
  } else {
    newdata[[wts_col]] <- wts
  }
  newdata
}
# Attach bookkeeping attributes to `newdata`: the explicit-newdata flag, column
# information taken from `modeldata`, any attributes already carried by
# `newdata` (re-stored under a "newdata_" prefix), the model data itself, and a
# variable-class attribute when one is not already present.
#
# `modeldata` is indexed with data.table's `..` syntax below, so it must be a
# data.table when this function is called.
set_newdata_attributes <- function(model, modeldata, newdata, newdata_explicit) {
  attr(newdata, "newdata_explicit") <- newdata_explicit

  # column classes
  md_cols <- colnames(modeldata)
  matrix_cols <- Filter(function(nm) is.matrix(modeldata[[nm]]), md_cols)
  chr_cols <- Filter(function(nm) is.character(modeldata[[nm]]), md_cols)
  chr_levels <- lapply(modeldata[, ..chr_cols], unique)
  var_class <- attributes(modeldata)$marginaleffects_variable_class

  newdata <- set_marginaleffects_attributes(
    newdata,
    list(
      matrix_columns = matrix_cols,
      character_levels = chr_levels,
      variable_class = var_class
    ),
    prefix = "newdata_"
  )

  # {modelbased} sometimes attaches useful attributes
  skip <- c("class", "row.names", "names", "data", "reference")
  carried <- get_marginaleffects_attributes(newdata, exclude = skip)
  newdata <- set_marginaleffects_attributes(newdata, carried, prefix = "newdata_")

  # original data
  attr(newdata, "newdata_modeldata") <- modeldata

  has_var_class <- !is.null(attr(newdata, "marginaleffects_variable_class"))
  if (!has_var_class) {
    newdata <- set_variable_class(newdata, model = model)
  }

  newdata
}
# Normalise `newdata` before it is used downstream:
#   * matrix columns are unpacked (for "gam" models) or dropped (otherwise);
#   * a `rowid` column is added when missing;
#   * for "mlogit" models with an `idx` column of class "idx", rows are sorted
#     by the first two columns of that index;
#   * when the response column is absent, a placeholder column holding the
#     first observed response value is added.
#
# Returns the cleaned `newdata`.
clean_newdata <- function(model, newdata) {
  # rbindlist breaks on matrix columns.
  # vapply() always yields a logical vector, including for a zero-column
  # `newdata`, where sapply() would return a list and make any() fail.
  is_matrix_col <- vapply(
    newdata,
    function(x) class(x)[1] == "matrix",
    logical(1)
  )
  if (any(is_matrix_col)) {
    # Issue #363
    # unpacking matrix columns works with {mgcv} but breaks {mclogit}
    if (inherits(model, "gam")) {
      newdata <- unpack_matrix_cols(newdata)
    } else {
      newdata <- newdata[, !is_matrix_col, drop = FALSE]
    }
  }

  # we will need this to merge the original data back in, and it is better to
  # do it in a centralized upfront way.
  if (!"rowid" %in% colnames(newdata)) {
    newdata$rowid <- seq_len(nrow(newdata))
  }

  # mlogit: each row is an individual-choice, but the index is not easily
  # trackable, so we pre-sort it here; the matching sort is expected to happen
  # in `get_predict()` (not visible here). We need to cross our fingers, but
  # this probably works.
  if (inherits(model, "mlogit") && isTRUE(inherits(newdata[["idx"]], "idx"))) {
    row_order <- order(newdata[["idx"]][, 1], newdata[["idx"]][, 2])
    newdata <- newdata[row_order, ]
  }

  # placeholder response
  resp <- insight::find_response(model)
  if (isTRUE(checkmate::check_character(resp, len = 1)) &&
    !resp %in% colnames(newdata)) {
    y <- hush(insight::get_response(model))
    # protect df or matrix response
    if (isTRUE(checkmate::check_atomic_vector(y))) {
      newdata[[resp]] <- y[1]
    }
  }

  newdata
}
# Validate `newdata`, build the prediction data frame, clean it, add the
# internal weights column and bookkeeping attributes, and return a sorted copy.
#
# `newdata` must be NULL, a data frame, or one of "mean", "median", "tukey",
# "grid", "marginalmeans".
sanitize_newdata <- function(model, newdata, by, modeldata, wts) {
  checkmate::assert(
    checkmate::check_data_frame(newdata, null.ok = TRUE),
    checkmate::check_choice(newdata, choices = c("mean", "median", "tukey", "grid", "marginalmeans")),
    combine = "or"
  )

  built <- build_newdata(
    model = model,
    newdata = newdata,
    by = by,
    modeldata = modeldata
  )
  modeldata <- built[["modeldata"]]

  newdata <- clean_newdata(model, built[["newdata"]])
  newdata <- add_wts_column(wts = wts, newdata = newdata)
  newdata <- set_newdata_attributes(
    model = model,
    modeldata = modeldata,
    newdata = newdata,
    newdata_explicit = built[["newdata_explicit"]]
  )

  # Sort the rows of the output when the user explicitly supplies `by` or
  # calls `datagrid()`. Otherwise the data frame keeps its original row order;
  # in these two cases sorting gives a cleaner output.
  sort_by <- attr(newdata, "newdata_variables_datagrid")
  if (isTRUE(checkmate::check_character(by))) {
    sort_by <- c(by, sort_by)
  }
  sort_by <- intersect(sort_by, colnames(newdata))

  out <- data.table::copy(newdata)
  if (length(sort_by) > 0) {
    data.table::setorderv(out, cols = sort_by)
  }
  out
}
# Collapse identical rows of `newdata` into unique rows, recording how many
# original rows each one stands for in `marginaleffects_wts_internal`.
#
# `newdata` is returned unchanged when:
#   * `comparison` is not a string matching "avg" AND any of: `by` is FALSE,
#     `wts` or `byfun` is supplied, `cross` is not FALSE, or the option
#     `marginaleffects_dedup` is FALSE; or
#   * some entry of the `marginaleffects_variable_class` attribute (ignoring
#     the response) is not one of the categorical classes listed below.
#
# Otherwise returns a data frame without the response and `rowid` columns, with
# the count column, a `rowid_dedup` column, and the (response-free)
# `marginaleffects_variable_class` attribute.
dedup_newdata <- function(model, newdata, by, wts, comparison = "difference", cross = FALSE, byfun = NULL) {
  is_avg <- isTRUE(checkmate::check_string(comparison, pattern = "avg"))
  if (!is_avg && (
    isFALSE(by) || # weights only make sense when we are marginalizing
      !is.null(wts) ||
      !is.null(byfun) ||
      !isFALSE(cross) ||
      isFALSE(getOption("marginaleffects_dedup", default = TRUE)))) {
    return(newdata)
  }

  vclass <- attr(newdata, "marginaleffects_variable_class")

  # The response column does not take part in deduplication.
  dv <- hush(unlist(insight::find_response(model), use.names = FALSE))
  drop_dv <- isTRUE(checkmate::check_string(dv)) && dv %in% colnames(newdata)
  if (drop_dv) {
    vclass <- vclass[names(vclass) != dv]
  }

  # Decide whether to bail out BEFORE copying: the check only needs the class
  # attribute, so the copy is skipped whenever `newdata` is returned as-is.
  categ <- c("factor", "character", "logical", "strata", "cluster", "binary")
  if (!all(vclass %in% categ)) {
    return(newdata)
  }

  # copy to allow mod by reference later without overwriting newdata
  out <- data.table::data.table(newdata)
  if (drop_dv) {
    out[, (dv) := NULL]
  }
  if ("rowid" %in% colnames(out)) {
    out[, "rowid" := NULL]
  }

  # One row per unique combination of the remaining columns, with its count.
  cols <- colnames(out)
  out <- out[, .("marginaleffects_wts_internal" = .N), by = cols]
  data.table::setDF(out)
  out[["rowid_dedup"]] <- seq_len(nrow(out))
  attr(out, "marginaleffects_variable_class") <- vclass
  out
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.