R/sanitize_newdata.R

Defines functions dedup_newdata sanitize_newdata clean_newdata set_newdata_attributes add_wts_column build_newdata sanitize_newdata_call

# Evaluate a user-supplied `newdata` call such as `newdata = datagrid(x = 1)`,
# injecting the fitted model when the call does not already supply it.
#
# @param scall A quosure capturing the expression given to `newdata`.
# @param newdata The evaluated `newdata` value; returned unchanged when
#   `scall` does not wrap a call.
# @param model The fitted model object.
# @return The result of evaluating the (possibly modified) call, or `newdata`
#   when `scall` is not a call.
sanitize_newdata_call <- function(scall, newdata = NULL, model) {
    if (!rlang::quo_is_call(scall)) {
        return(newdata)
    }

    # `call_name()` returns NULL when the called function is not a plain
    # symbol (e.g. an anonymous function). `NULL %in% x` is logical(0), which
    # `if` rejects, so wrap the membership tests in isTRUE().
    fun_name <- rlang::call_name(scall)
    arg_names <- rlang::call_args_names(scall)

    if (isTRUE(fun_name %in% c("datagrid", "datagridcf", "typical", "counterfactual"))) {
        if (!"model" %in% arg_names) {
            scall <- rlang::call_modify(scall, model = model)
        }
    } else if (isTRUE(fun_name %in% "visualisation_matrix")) {
        if (!"x" %in% arg_names) {
            # Call `get_modeldata()` on the model: injecting the bare function
            # object would hand `visualisation_matrix()` a closure, not data.
            scall <- rlang::call_modify(scall, x = get_modeldata(model))
        }
    }

    rlang::eval_tidy(scall)
}


# Resolve the `newdata` argument into a concrete data frame.
#
# `newdata` may be NULL (fall back on `modeldata`), a data frame (used as-is),
# or one of the shortcut strings "mean", "median", "tukey", "grid" or
# "marginalmeans", which are expanded through `datagrid()`.
#
# @param model Fitted model object, forwarded to `datagrid()`.
# @param newdata NULL, a data frame, or a shortcut string.
# @param by Grouping specification: a data frame (its column names other than
#   "by" are used), a logical flag (treated as no grouping), or column names.
# @param modeldata Data associated with the model.
# @return A list with elements `newdata`, `newdata_explicit` (FALSE only when
#   `newdata` was NULL) and `modeldata` (after `data.table::setDT()`).
build_newdata <- function(model, newdata, by, modeldata) {
    if (isTRUE(checkmate::check_data_frame(by))) {
        by <- setdiff(colnames(by), "by")
    } else if (isTRUE(checkmate::check_flag(by))) {
        by <- NULL
    }

    # arguments for datagrid(): every grouping column keeps all unique values
    grid_args <- list(model = model)
    for (grp in by) {
        grid_args[[grp]] <- unique
    }

    newdata_explicit <- !is.null(newdata)

    if (!newdata_explicit) {
        # nothing supplied: use the model data
        newdata <- modeldata
    } else {
        shortcut <- Find(
            function(s) identical(newdata, s),
            c("mean", "median", "tukey", "grid", "marginalmeans"))

        if (!is.null(shortcut)) {
            # numeric summaries ("mean" keeps the datagrid() defaults)
            if (shortcut == "median") {
                grid_args[["FUN_numeric"]] <- grid_args[["FUN_integer"]] <-
                    grid_args[["FUN_logical"]] <- function(x) stats::median(x, na.rm = TRUE)
            } else if (shortcut %in% c("tukey", "grid")) {
                grid_args[["FUN_numeric"]] <- function(x) stats::fivenum(x, na.rm = TRUE)
            }

            # categorical variables: every unique value appears in the grid
            if (shortcut %in% c("grid", "marginalmeans")) {
                grid_args[["FUN_factor"]] <- grid_args[["FUN_character"]] <-
                    grid_args[["FUN_logical"]] <- unique
            }

            newdata <- do.call("datagrid", grid_args)

            if (shortcut == "marginalmeans") {
                # Issue #580: outcome should not duplicate grid rows
                dv <- hush(insight::find_response(model))
                if (isTRUE(dv %in% colnames(newdata))) {
                    newdata[[dv]] <- get_mean_or_mode(newdata[[dv]])
                    newdata <- unique(newdata)
                }
            }
        }
    }

    if (!inherits(newdata, "data.frame")) {
        msg <- "Unable to extract the data from model of class `%s`. This can happen in a variety of cases, such as when a `marginaleffects` package function is called from inside a user-defined function, or using an `*apply()`-style operation on a list. Please supply a data frame explicitly via the `newdata` argument."
        insight::format_error(sprintf(msg, class(model)[1]))
    }

    # drop the mlogit index column, which otherwise triggers a warning in setDT()
    if (inherits(model, "mlogit") && isTRUE(inherits(modeldata[["idx"]], "idx"))) {
        modeldata$idx <- NULL
    }

    # Both conversions happen by reference. `modeldata` becomes a data.table
    # (required by some model-fitting functions); `newdata` becomes a
    # data.frame (required for the column indexing done downstream).
    data.table::setDT(modeldata)
    data.table::setDF(newdata)

    list(
        "newdata" = newdata,
        "newdata_explicit" = newdata_explicit,
        "modeldata" = modeldata)
}


# Attach observation weights to `newdata` in the internal column
# "marginaleffects_wts_internal".
#
# Weights are attached in `comparisons()`, NOT in `tidy()`, because
# comparisons often duplicates newdata (e.g. for multivariate outcomes), so
# every row has to carry its own weight to keep rows matched.
#
# @param wts NULL, a numeric vector with one value per row of `newdata`, or the
#   name of a column of `newdata`.
# @param newdata Data frame that receives the weights column.
# @return `newdata` with the weights column set. When `wts` is NULL the
#   column is assigned NULL, i.e. dropped if it was present.
add_wts_column <- function(wts, newdata) {
    # FALSE for NULL as well, since check_string() rejects NULL
    wts_is_name <- isTRUE(checkmate::check_string(wts))

    if (!is.null(wts)) {
        names_a_column <- wts_is_name && isTRUE(wts %in% colnames(newdata))
        matches_nrow <- isTRUE(checkmate::check_numeric(wts, len = nrow(newdata)))
        if (!(names_a_column || matches_nrow)) {
            msg <- sprintf("The `wts` argument must be a numeric vector of length %s, or a string which matches a column name in `newdata`. If you did not supply a `newdata` explicitly, `marginaleffects` extracted it automatically from the model object, and the `wts` variable may not have been available. The easiest strategy is often to supply a data frame such as the original data to `newdata` explicitly, and to make sure that it includes an appropriate column of weights, identified by the `wts` argument.", nrow(newdata))
            stop(msg, call. = FALSE)
        }
    }

    # weights must be attached before sanitize_variables
    newdata[["marginaleffects_wts_internal"]] <- if (wts_is_name) newdata[[wts]] else wts
    newdata
}


# Record metadata about the model data as attributes of `newdata`.
#
# @param model Fitted model, passed to `set_variable_class()` when `newdata`
#   has no "marginaleffects_variable_class" attribute yet.
# @param modeldata Model data. It is indexed with data.table's `..` column
#   selector, so it is expected to be a data.table.
# @param newdata Data frame that receives the attributes.
# @param newdata_explicit Logical flag stored as attribute "newdata_explicit".
# @return `newdata` with the attributes attached.
set_newdata_attributes <- function(model, modeldata, newdata, newdata_explicit) {
    attr(newdata, "newdata_explicit") <- newdata_explicit

    # column information from the model data: names of matrix columns, the
    # unique values of character columns, and recorded variable classes
    matrix_cols <- Filter(function(v) is.matrix(modeldata[[v]]), colnames(modeldata))
    char_cols <- Filter(function(v) is.character(modeldata[[v]]), colnames(modeldata))
    char_levels <- lapply(modeldata[, ..char_cols], unique)
    newdata <- set_marginaleffects_attributes(
        newdata,
        list(
            "matrix_columns" = matrix_cols,
            "character_levels" = char_levels,
            "variable_class" = attributes(modeldata)$marginaleffects_variable_class),
        prefix = "newdata_")

    # {modelbased} sometimes attaches useful attributes
    modelbased_attributes <- get_marginaleffects_attributes(
        newdata,
        exclude = c("class", "row.names", "names", "data", "reference"))
    newdata <- set_marginaleffects_attributes(newdata, modelbased_attributes, prefix = "newdata_")

    # keep a reference to the original data
    attr(newdata, "newdata_modeldata") <- modeldata

    if (is.null(attr(newdata, "marginaleffects_variable_class"))) {
        newdata <- set_variable_class(newdata, model = model)
    }

    newdata
}


# Prepare `newdata` for prediction.
#
# Drops matrix columns (or unpacks them for `gam` models) and adds a `rowid`
# column when missing. For `mlogit` models it sorts rows by their index. When
# possible, it adds a placeholder response column if the response is absent.
#
# @param model Fitted model object.
# @param newdata Data frame to clean.
# @return The cleaned data frame.
clean_newdata <- function(model, newdata) {
    # rbindlist breaks on matrix columns.
    # vapply (not sapply): a zero-column `newdata` must give logical(0), not
    # list(), because any(list()) is an error.
    is_matrix_col <- vapply(newdata, function(x) class(x)[1] == "matrix", logical(1))
    if (any(is_matrix_col)) {
        # Issue #363
        # unpacking matrix columns works with {mgcv} but breaks {mclogit}
        if (inherits(model, "gam")) {
            newdata <- unpack_matrix_cols(newdata)
        } else {
            newdata <- newdata[, !is_matrix_col, drop = FALSE]
        }
    }

    # we will need this to merge the original data back in, and it is better to
    # do it in a centralized upfront way.
    if (!"rowid" %in% colnames(newdata)) {
        newdata$rowid <- seq_len(nrow(newdata))
    }

    # mlogit: each row is an individual-choice, but the index is not easily
    # trackable, so we pre-sort it here, matching the sort in `get_predict()`.
    if (inherits(model, "mlogit") && isTRUE(inherits(newdata[["idx"]], "idx"))) {
        newdata <- newdata[order(newdata[["idx"]][, 1], newdata[["idx"]][, 2]), ]
    }

    # placeholder response: only for a single response name whose values are
    # an atomic vector (data frame or matrix responses are left alone)
    resp <- insight::find_response(model)
    if (isTRUE(checkmate::check_character(resp, len = 1)) &&
        !resp %in% colnames(newdata)) {
        y <- hush(insight::get_response(model))
        if (isTRUE(checkmate::check_atomic_vector(y))) {
            newdata[[resp]] <- y[1]
        }
    }

    newdata
}


# Validate `newdata` and turn it into the data frame used for predictions.
#
# @param model Fitted model object.
# @param newdata NULL, a data frame, or one of "mean", "median", "tukey",
#   "grid", "marginalmeans".
# @param by Grouping specification, forwarded to build_newdata().
# @param modeldata Model data, forwarded to build_newdata().
# @param wts Weights, forwarded to add_wts_column().
# @return A copy of the processed `newdata`. Rows are sorted by the
#   character `by` columns, then by the "newdata_variables_datagrid" columns,
#   keeping only those present in the data.
sanitize_newdata <- function(model, newdata, by, modeldata, wts) {
    checkmate::assert(
        checkmate::check_data_frame(newdata, null.ok = TRUE),
        checkmate::check_choice(newdata, choices = c("mean", "median", "tukey", "grid", "marginalmeans")),
        combine = "or")

    built <- build_newdata(model = model, newdata = newdata, by = by, modeldata = modeldata)

    newdata <- clean_newdata(model, built[["newdata"]])
    newdata <- add_wts_column(newdata = newdata, wts = wts)
    newdata <- set_newdata_attributes(
        model = model,
        modeldata = built[["modeldata"]],
        newdata = newdata,
        newdata_explicit = built[["newdata_explicit"]])

    # When the user asks for `by` or calls `datagrid()`, sorting gives a
    # cleaner output; otherwise no sort columns remain and row order is kept.
    by_cols <- if (isTRUE(checkmate::check_character(by))) by else NULL
    sort_cols <- intersect(
        c(by_cols, attr(newdata, "newdata_variables_datagrid")),
        colnames(newdata))

    out <- data.table::copy(newdata)
    if (length(sort_cols) > 0) {
        data.table::setorderv(out, cols = sort_cols)
    }
    out
}


# Collapse identical rows of `newdata` into unique rows carrying a weight.
#
# The response column and `rowid` are ignored when comparing rows. In the
# result, "marginaleffects_wts_internal" holds the number of merged rows. If
# `newdata` already has that weights column, it holds the sum of the merged
# weights instead. Each collapsed row gets a `rowid_dedup` identifier.
#
# @param model Fitted model, used to find the response variable.
# @param newdata Data frame to deduplicate.
# @param by,wts,cross,byfun Arguments of the calling function; for comparisons
#   that do not contain "avg", any of them can disable deduplication.
# @param comparison Comparison label; labels containing "avg" skip those guards.
# @return The deduplicated data frame, or `newdata` unchanged when
#   deduplication is skipped or some variable class is not categorical.
dedup_newdata <- function(model, newdata, by, wts, comparison = "difference", cross = FALSE, byfun = NULL) {
    wts_col <- "marginaleffects_wts_internal"

    flag <- isTRUE(checkmate::check_string(comparison, pattern = "avg"))
    # NOTE(review): for "avg" comparisons none of these guards (including the
    # `marginaleffects_dedup` option) is consulted -- confirm this is intended.
    if (!flag && (
        isFALSE(by) || # weights only make sense when we are marginalizing
        !is.null(wts) ||
        !is.null(byfun) ||
        !isFALSE(cross) ||
        isFALSE(getOption("marginaleffects_dedup", default = TRUE)))) {
        return(newdata)
    }

    vclass <- attr(newdata, "marginaleffects_variable_class")

    # copy to allow mod by reference later without overwriting newdata
    out <- data.table::data.table(newdata)

    dv <- hush(unlist(insight::find_response(model), use.names = FALSE))
    if (isTRUE(checkmate::check_string(dv)) && dv %in% colnames(out)) {
        out[, (dv) := NULL]
        vclass <- vclass[names(vclass) != dv]
    }

    if ("rowid" %in% colnames(out)) {
        out[, "rowid" := NULL]
    }

    categ <- c("factor", "character", "logical", "strata", "cluster", "binary")
    if (!all(vclass %in% categ)) {
        return(newdata)
    }

    if (wts_col %in% colnames(out)) {
        # Weights already present: group on the other columns and sum them.
        # Grouping on the weights column and also writing `.N` to it would
        # produce duplicated column names and keep per-row (not merged) weights.
        cols <- setdiff(colnames(out), wts_col)
        out <- out[, .("marginaleffects_wts_internal" = sum(.SD[[1L]])),
            by = cols, .SDcols = wts_col]
    } else {
        # No weights: each unique row is weighted by how many times it occurs.
        cols <- colnames(out)
        out <- out[, .("marginaleffects_wts_internal" = .N), by = cols]
    }
    data.table::setDF(out)

    out[["rowid_dedup"]] <- seq_len(nrow(out))
    attr(out, "marginaleffects_variable_class") <- vclass

    out
}

Try the marginaleffects package in your browser

Any scripts or data that you put into this service are public.

marginaleffects documentation built on Oct. 20, 2023, 1:07 a.m.