R/hypothesis_helper.R

Defines functions specify_hypothesis

Documented in specify_hypothesis

#' (EXPERIMENTAL) Complex aggregation and test functions for the `hypothesis` argument
#'
#' @description
#' Warning: This function is experimental. It may be renamed, the user interface may change, or the functionality may migrate to arguments in other `marginaleffects` functions.
#'
#' This function creates aggregation and test functions for use with the `hypothesis` argument in `marginaleffects` functions like `predictions()`, `slopes()`, and `comparisons()`. The benefit of this function is that it handles a lot of the "boilerplate" code such as label creation and transformations by subgroups.
#'
#' @param hypothesis String or Function. Compute a test statistic.
#' - String: "reference" or "sequential"
#'   - "reference": difference between each estimate and the first estimate of its subgroup.
#'   - "sequential": difference between each estimate and the estimate on the previous row of its subgroup.
#'   - A subgroup with a single row yields no comparison (zero rows) for either string.
#' - Function: Called with a single argument, the vector of estimates in the current subgroup, and returns a numeric vector.
#' @param by Character vector. Variable names which indicate subgroups in which the `hypothesis` function should be applied.
#' - Names among "term", "group", "contrast", and "rowid" are silently dropped when the corresponding column is absent. Any other name that is not a column raises an error.
#' - `NULL`: the columns "term" and "group", and every column whose name starts with "contrast", are used when present.
#' @param label Function. Accepts a vector of row labels and combines them to create hypothesis labels.
#' - `NULL` (default): labels of the form `"(b) - (a)"` are created when `hypothesis` is a string, and the label `"custom"` is used when `hypothesis` is a function.
#' @param label_columns Character vector. Column names to use for hypothesis labels. Default is `c("group", "term", "rowid", attr(x, "variables_datagrid"), attr(x, "by"))`. Columns already used in `by` (except "rowid") and columns absent from the data are ignored; if nothing is left, "rowid" is used.
#' @return `specify_hypothesis()` is a "function factory", which means that executing it will return a function suitable for use in the `hypothesis` argument of a `marginaleffects` function.
#' @export
specify_hypothesis <- function(
    hypothesis = "reference",
    label = NULL,
    label_columns = NULL,
    by = c("term", "group", "contrast")) {

    checkmate::assert_character(by, null.ok = TRUE)
    checkmate::assert_function(label, null.ok = TRUE)
    checkmate::assert_character(label_columns, null.ok = TRUE)
    checkmate::assert(
        checkmate::check_function(hypothesis),
        checkmate::check_choice(hypothesis, choices = c("reference", "sequential"))
    )

    # Translate string shortcuts into functions.
    # `[-1]` is used instead of `[2:length(x)]`: for a one-element vector `2:1`
    # counts down to `c(2, 1)` and would return a spurious `c(NA, 0)`.
    if (identical(hypothesis, "reference")) {
        hypothesis <- function(x) (x - x[1])[-1]
        default_label <- function(x) sprintf("(%s) - (%s)", x, x[1])[-1]
    } else if (identical(hypothesis, "sequential")) {
        hypothesis <- function(x) (x - data.table::shift(x))[-1]
        default_label <- function(x) sprintf("(%s) - (%s)", x, data.table::shift(x))[-1]
    } else {
        default_label <- function(x) "custom"
    }

    # A user-supplied `label` always wins; the defaults only fill in for NULL.
    if (is.null(label)) label <- default_label

    # The returned closure receives the table of estimates and returns one row
    # per hypothesis, computed within the subgroups defined by `by`.
    fun <- function(x) {

        # work on a copy because columns are added by reference below
        x <- data.table::copy(x)

        if (!"estimate" %in% colnames(x)) {
            insight::format_error("Missing column(s): estimate")
        }

        # automatic by argument
        if (is.null(by)) {
            by <- grep("^term$|^contrast|^group$", colnames(x), value = TRUE)
        } else {
            bad <- setdiff(by, c(colnames(x), "term", "group", "contrast", "rowid"))
            if (length(bad) > 0) {
                msg <- sprintf("Missing column(s): %s", paste(bad, collapse = ", "))
                insight::format_error(msg)
            }
            by <- intersect(by, colnames(x))
        }
        if (length(by) == 0) by <- NULL

        # row labels
        if (!"rowid" %in% colnames(x)) x[, "rowid" := seq_len(.N)]
        if (is.null(label_columns)) {
            label_columns <- c("group", "term", "rowid", attr(x, "variables_datagrid"), attr(x, "by"))
        }
        # grouping columns are constant within a subgroup, so they add nothing to a label
        label_columns <- setdiff(label_columns, setdiff(by, "rowid"))
        label_columns <- intersect(label_columns, colnames(x))
        if (length(label_columns) == 0) label_columns <- "rowid"

        # one "column[value]" piece per label column, joined row-wise with ", "
        pieces <- lapply(label_columns, function(col) sprintf("%s[%s]", col, x[[col]]))
        tmp <- do.call(paste, c(unname(pieces), sep = ", "))
        x[, marginaleffects_internal_label := tmp]

        if (is.null(by)) {
            out <- x[, list(
                hypothesis = label(marginaleffects_internal_label),
                estimate = hypothesis(estimate))]
        } else {
            out <- x[, list(
                hypothesis = label(marginaleffects_internal_label),
                estimate = hypothesis(estimate)),
            by = by]
        }

        # record which grouping columns were actually used
        attr(out, "hypothesis_function_by") <- by
        return(out)
    }

    return(fun)
}

Try the marginaleffects package in your browser

Any scripts or data that you put into this service are public.

marginaleffects documentation built on May 29, 2024, 4:03 a.m.