R/sanitize_newdata.R

Defines functions dedup_newdata sanitize_newdata add_wts_column build_newdata sanitize_newdata_call

# Evaluate a captured `newdata` expression, filling in the model or the model
# data when the user left them out of the call.
#
# Arguments:
#   scall   A quosure capturing the user's `newdata` expression.
#   newdata Value returned unchanged when `scall` does not wrap a call.
#   model   Fitted model object, injected into `datagrid*()` calls that lack a
#           `model` argument and used to retrieve the model data for
#           `subset()`, `filter()` and `visualisation_matrix()` calls.
#   by      Forwarded to `get_modeldata(additional_variables = )`.
#
# Returns the evaluated call (for `data.frame()` calls, every column is also
# recorded in the "implicit" attribute), or `newdata` when `scall` is not a
# call.
sanitize_newdata_call <- function(scall, newdata = NULL, model, by = NULL) {
    if (!rlang::quo_is_call(scall)) {
        return(newdata)
    }

    # call_name() returns NULL for calls such as `f()()` or `(function(x) x)()`;
    # normalize to "" so the scalar string tests below never see length zero.
    fun <- rlang::call_name(scall)
    if (is.null(fun)) {
        fun <- ""
    }
    argnames <- rlang::call_args_names(scall)

    from_data_frame <- FALSE
    if (grepl("^datagrid", fun)) {
        # supply the fitted model when the user did not pass one
        if (!"model" %in% argnames) {
            scall <- rlang::call_modify(scall, model = model)
        }
    } else if (fun == "data.frame") {
        from_data_frame <- TRUE
    } else if (fun == "subset") {
        # only a single (condition) argument was given: use the model data as `x`
        if (!"x" %in% argnames && length(argnames) == 1) {
            tmp <- get_modeldata(model, additional_variables = by)
            scall <- rlang::call_modify(scall, x = tmp)
        }
    } else if (fun == "filter") {
        # only a single (condition) argument was given: use the model data as `.data`
        if (!".data" %in% argnames && length(argnames) == 1) {
            tmp <- get_modeldata(model, additional_variables = by)
            scall <- rlang::call_modify(scall, .data = tmp)
        }
    } else if (fun == "visualisation_matrix") {
        # inject the model *data*; previously the `get_modeldata` function
        # object itself was passed as `x`
        if (!"x" %in% argnames) {
            tmp <- get_modeldata(model, additional_variables = by)
            scall <- rlang::call_modify(scall, x = tmp)
        }
    }

    out <- rlang::eval_tidy(scall)

    # newdata=data.frame(): all columns must be printed as explicit in print.R,
    # so record every column name in the "implicit" attribute
    if (from_data_frame) {
        attr(out, "implicit") <- unique(c(attr(out, "implicit"), colnames(out)))
    }
    return(out)
}


# Turn the user-facing `newdata` argument into a concrete data frame.
#
# `newdata` may be:
#   * NULL          -> `modeldata` is used (and flagged as not explicit);
#   * a data frame  -> used as is;
#   * one of "mean", "median", "tukey", "grid", "balanced"
#                   -> expanded through `datagrid()`, with every variable named
#                      in `by` held at all of its unique values.
#
# Returns a list with elements `newdata`, `explicit` (FALSE only when
# `newdata` was NULL) and `modeldata`. Raises an error when the result is not
# a data frame.
build_newdata <- function(model, newdata, by, modeldata) {
    # normalize `by` to a vector of variable names (or NULL)
    if (isTRUE(checkmate::check_data_frame(by))) {
        by <- setdiff(colnames(by), "by")
    } else if (isTRUE(checkmate::check_flag(by))) {
        by <- NULL
    }

    # arguments shared by every datagrid() preset
    grid_args <- list(model = model)
    for (variable in by) {
        grid_args[[variable]] <- unique
    }

    is_explicit <- !is.null(newdata)
    preset <- Find(
        function(choice) identical(newdata, choice),
        c("mean", "median", "tukey", "grid", "balanced")
    )

    if (!is_explicit) {
        newdata <- modeldata
    } else if (!is.null(preset)) {
        median_na <- function(x) stats::median(x, na.rm = TRUE)
        fivenum_na <- function(x) stats::fivenum(x, na.rm = TRUE)

        # preset-specific arguments layered on top of `grid_args`
        extra_args <- switch(
            preset,
            mean = list(),
            median = list(
                FUN_logical = median_na,
                FUN_integer = median_na,
                FUN_numeric = median_na
            ),
            tukey = list(FUN_numeric = fivenum_na),
            # all unique values of categorical variables, numerics at fivenum()
            grid = list(
                FUN_numeric = fivenum_na,
                FUN_logical = unique,
                FUN_character = unique,
                FUN_factor = unique
            ),
            balanced = list(grid_type = "balanced")
        )
        for (arg_name in names(extra_args)) {
            grid_args[[arg_name]] <- extra_args[[arg_name]]
        }
        newdata <- do.call("datagrid", grid_args)

        if (preset == "balanced") {
            # Issue #580: outcome should not duplicate grid rows
            dv <- hush(insight::find_response(model))
            if (isTRUE(dv %in% colnames(newdata))) {
                newdata[[dv]] <- get_mean_or_mode(newdata[[dv]])
                newdata <- unique(newdata)
            }
        }
    }

    if (!inherits(newdata, "data.frame")) {
        template <- "Unable to extract the data from model of class `%s`. This can happen in a variety of cases, such as when a `marginaleffects` package function is called from inside a user-defined function, or using an `*apply()`-style operation on a list. Please supply a data frame explicitly via the `newdata` argument."
        insight::format_error(sprintf(template, class(model)[1]))
    }

    list(
        "newdata" = newdata,
        "explicit" = is_explicit,
        "modeldata" = modeldata
    )
}


# Attach the weights column `marginaleffects_wts_internal` to `newdata`.
#
# Weights must be available in the `comparisons()` function, NOT in `tidy()`,
# because comparisons often duplicate newdata (multivariate outcomes and the
# like) and each weight has to stay with its row.
#
# Arguments:
#   wts     FALSE (no weights), TRUE (use the weights variable reported by
#           insight::find_weights()), a string naming a column of `newdata`,
#           or a numeric vector with one value per row of `newdata`.
#   newdata Data frame to augment.
#   model   Fitted model; only consulted when `wts` is TRUE.
#
# Returns `newdata`, with the weights column added unless `wts` is FALSE.
# Raises an error when the weights cannot be resolved.
add_wts_column <- function(wts, newdata, model) {
    if (isFALSE(wts)) {
        return(newdata)
    }

    target <- "marginaleffects_wts_internal"

    if (isTRUE(wts)) {
        wtsname <- insight::find_weights(model)
        found <- is.character(wtsname) &&
            length(wtsname) == 1 &&
            wtsname %in% colnames(newdata)
        if (!found) {
            insight::format_error("Unable to retrieve weights automatically from the model. Please specify `wts` argument explicitly.")
        }
        newdata[[target]] <- newdata[[wtsname]]
        return(newdata)
    }

    names_column <- isTRUE(checkmate::check_string(wts)) &&
        isTRUE(wts %in% colnames(newdata))
    is_weight_vector <- isTRUE(checkmate::check_numeric(wts, len = nrow(newdata)))
    if (!(names_column || is_weight_vector)) {
        stop(
            sprintf(
                "The `wts` argument must be a numeric vector of length %s, or a string which matches a column name in `newdata`. If you did not supply a `newdata` explicitly, `marginaleffects` extracted it automatically from the model object, and the `wts` variable may not have been available. The easiest strategy is often to supply a data frame such as the original data to `newdata` explicitly, and to make sure that it includes an appropriate column of weights, identified by the `wts` argument.",
                nrow(newdata)
            ),
            call. = FALSE
        )
    }

    # weights: before sanitize_variables
    newdata[[target]] <- if (names_column) newdata[[wts]] else wts
    newdata
}


# Validate `newdata` and prepare it for the downstream computations.
#
# Steps, in order:
#   1. `newdata` must be NULL, a data frame, or one of the preset strings;
#      build_newdata() turns it into a concrete data frame.
#   2. Matrix columns are unpacked (single-column ones always, all of them for
#      "gam" models).
#   3. A placeholder response column is added when missing.
#   4. A `rowid` column and (optionally) the weights column are added.
#   5. The result is converted to a data.table in place and annotated with the
#      attributes "explicit", "newdata_modeldata" and, when `modeldata` is
#      available, column metadata via set_marginaleffects_attributes().
sanitize_newdata <- function(model, newdata, by, modeldata, wts) {
    checkmate::assert(
        checkmate::check_data_frame(newdata, null.ok = TRUE),
        checkmate::check_choice(
            newdata,
            choices = c("mean", "median", "tukey", "grid", "balanced")
        ),
        combine = "or"
    )

    built <- build_newdata(
        model = model,
        newdata = newdata,
        by = by,
        modeldata = modeldata
    )
    out <- built[["newdata"]]
    modeldata <- built[["modeldata"]]

    # Issue #1327: matrix columns with single column breaks rbindlist(). See `scale()`
    out <- unpack_matrix_1col(out)

    # Issue #363
    # unpacking matrix columns works with {mgcv} but breaks {mclogit}
    if (inherits(model, "gam")) out <- unpack_matrix_cols(out)

    # placeholder response: only for a single, missing, atomic response
    resp <- insight::find_response(model)
    single_resp <- isTRUE(checkmate::check_character(resp, len = 1))
    if (single_resp && !(resp %in% colnames(out))) {
        y <- hush(insight::get_response(model))
        # protect df or matrix response
        if (isTRUE(checkmate::check_atomic_vector(y))) out[[resp]] <- y[1]
    }

    # we will need this to merge the original data back in, and it is better to
    # do it in a centralized upfront way.
    if (!("rowid" %in% colnames(out))) out$rowid <- seq_len(nrow(out))

    # add weights column if available
    out <- add_wts_column(
        newdata = out,
        wts = if (is.null(wts)) FALSE else wts,
        model = model
    )

    # otherwise we get a warning in setDT()
    mlogit_idx <- inherits(model, "mlogit") &&
        isTRUE(inherits(modeldata[["idx"]], "idx"))
    if (mlogit_idx) {
        modeldata$idx <- NULL
        out$idx <- NULL
    }

    data.table::setDT(out)

    # attributes: misc
    attr(out, "explicit") <- attr(built[["newdata"]], "explicit")
    attr(out, "newdata_modeldata") <- modeldata

    if (is.null(modeldata)) {
        return(out)
    }

    # attributes: column classes
    matrix_cols <- Filter(function(v) is.matrix(modeldata[[v]]), colnames(modeldata))
    char_cols <- Filter(function(v) is.character(modeldata[[v]]), colnames(modeldata))
    char_data <- subset(modeldata, select = char_cols)
    column_attributes <- list(
        "matrix_columns" = matrix_cols,
        "character_levels" = lapply(char_data, unique),
        "variable_class" = attributes(char_data)$marginaleffects_variable_class
    )
    set_marginaleffects_attributes(out, column_attributes)
}


# Collapse identical rows of `newdata` into unique rows with a count weight.
#
# Deduplication is skipped, and `newdata` is returned (as is, or as a
# data.table copy), in these cases:
#   * average elasticities or a custom comparison function;
#   * unit-level elasticities combined with a non-FALSE `by`;
#   * non-"avg" comparisons when `by` is FALSE, weights, `byfun` or `cross`
#     are used, or option `marginaleffects_dedup` is FALSE;
#   * when any variable class is not categorical.
# Otherwise, returns a data.table of unique rows (response and `rowid`
# dropped) where `marginaleffects_wts_internal` holds each row's count and
# `rowid_dedup` numbers the rows.
dedup_newdata <- function(
    model,
    newdata,
    by,
    wts,
    comparison = "difference",
    cross = FALSE,
    byfun = NULL
) {
    # issue #1113: elasticities or custom functions should skip dedup because it is difficult to align x and y
    avg_elasticities <- c("eyexavg", "eydxavg", "dyexavg")
    custom_or_avg_elasticity <- isTRUE(checkmate::check_choice(comparison, avg_elasticities)) ||
        isTRUE(checkmate::check_function(comparison))
    if (custom_or_avg_elasticity) {
        return(newdata)
    }

    unit_elasticities <- c("eyex", "eydx", "dyex")
    if (!isFALSE(by) && isTRUE(checkmate::check_choice(comparison, unit_elasticities))) {
        return(data.table(newdata))
    }

    # only non-"avg" comparisons consult the skip conditions; keep them lazy so
    # `wts`, `cross` and `byfun` are not touched for "avg" comparisons
    is_avg <- isTRUE(checkmate::check_string(comparison, pattern = "avg"))
    if (!is_avg) {
        skip <- isFALSE(by) || # weights only make sense when we are marginalizing
            !isFALSE(wts) ||
            !is.null(byfun) ||
            !isFALSE(cross) ||
            isFALSE(getOption("marginaleffects_dedup", default = TRUE))
        if (skip) {
            return(newdata)
        }
    }

    variable_classes <- attr(newdata, "marginaleffects_variable_class")

    # copy to allow mod by reference later without overwriting newdata
    out <- data.table(newdata)

    dv <- hush(unlist(insight::find_response(model), use.names = FALSE))
    if (isTRUE(checkmate::check_string(dv)) && dv %in% colnames(out)) {
        out[, (dv) := NULL]
        variable_classes <- variable_classes[names(variable_classes) != dv]
    }

    # rowid is useless, except for intercept-only models, where we want to retain all rows
    if (ncol(out) > 1 && "rowid" %in% colnames(out)) {
        out[, "rowid" := NULL]
    }

    categorical <- c("factor", "character", "logical", "strata", "cluster", "binary")
    if (any(!(variable_classes %in% categorical))) {
        return(newdata)
    }

    # `cols` is kept as the grouping symbol passed to data.table's `by`
    cols <- colnames(out)
    out <- out[, .("marginaleffects_wts_internal" = .N), by = cols]

    out[["rowid_dedup"]] <- seq_len(nrow(out))
    attr(out, "marginaleffects_variable_class") <- variable_classes
    out
}

Try the marginaleffects package in your browser

Any scripts or data that you put into this service are public.

marginaleffects documentation built on June 8, 2025, 12:44 p.m.