R/shapley_approx.R

Defines functions shapley_approx

Documented in shapley_approx

#' Shapley Value Approximation
#'
#' Approximates Shapley Values for a set of items. Shapley Values measure
#' the relative contribution each item has to the overall potential reach
#' if every item was included.
#'
#' @param data A data frame.
#' @param items Columns on which to analyze. Must contain only ones, zeros, or
#' `NA`. Suggest using [is_onezero][onezero::is_onezero] ahead of time to check.
#' @param case_weights An optional column of case weights to use in the
#' calculations. Rows with `NA` will be removed from the base.
#' @param item_weights An optional named vector of non-zero weights to associate
#' with each item. Items not specified will be given a default weight of 1.
#' Names that do not match any column in `items` are dropped with a warning.
#'
#' Examples could be profit, revenue, or simply relative weights.
#' @param depth Number of `items` needed in order to be considered "reached."
#' Can be any number between 1 to number of `items`. Default is 1.
#' Non-integer values are rounded down.
#' @param return One of `"vector"` (default) or `"tibble"` specifying the
#' type of object to return.
#'
#' @return When `return = "vector"`, a numeric vector with one Shapley Value
#' per item. When `return = "tibble"`, a tibble with the columns `item` and
#' `shapley_value`. In both cases the values are shares of the overall
#' (case-weighted) proportion of rows reached.
#'
#' @importFrom dplyr select pull left_join coalesce
#' @importFrom purrr map_df map_dbl
#' @importFrom tibble enframe
#' @importFrom collapse fmean fsum
#' @importFrom Rfast rowsums
#' @importFrom rlang abort
#' @examples
#' shapley_approx(
#'     data = FoodSample,
#'     items = Bisque:Chili,
#'     case_weights = weight,
#'     item_weights = c(Bisque = 9.99, Chicken = 10.29, Tofu = 10.99, Chili = 7.49),
#'     depth = 1,
#'     return = "tibble"
#' )
#'
#' @export
shapley_approx <- function(
        data,
        items,
        case_weights, item_weights,
        depth = 1,
        return = "vector"
) {


    # Preliminary data checks -------------------------------------------------

    if (!is.data.frame(data)) {
        abort("Input to `data` must be a data frame.")
    }

    # `||` short-circuits, so `is.na()` only ever sees a single numeric value.
    if (!is.numeric(depth) || length(depth) != 1 || is.na(depth)) {
        abort("Input to `depth` must be a single numeric value.")
    }
    depth <- floor(depth)

    # Reject unsupported output types up front instead of silently
    # returning NULL after all of the work has been done.
    if (!is.character(return) || length(return) != 1 ||
        !return %in% c("vector", "tibble")) {
        abort("Input to `return` must be one of \"vector\" or \"tibble\".")
    }


    # Get names of things -----------------------------------------------------

    # `items`
    item.names <- names(eval_select(expr = enquo(items), data = data))
    n.items <- length(item.names)

    # `case_weights`
    if (!missing(case_weights)) {

        case.weights.names <- names(
            eval_select(
                expr = enquo(case_weights),
                data = data
            )
        )

    }


    # Set up item weights -----------------------------------------------------

    if (!missing(item_weights)) {

        item.wgt.names <- names(item_weights)

        if (is.null(item.wgt.names)) {
            abort("Input to `item_weights` must be a named vector.")
        }

        # Must run before the `== ""` comparison below: an NA name would make
        # that comparison NA and crash the `if`.
        if (anyNA(item.wgt.names)) {
            abort("Every element of `item_weights` must be named.")
        }

        if (any(item.wgt.names == "")) {
            abort("Cannot have empty characters as names in `item_weights`.")
        }

        # Missing weights would otherwise surface as an NA condition in the
        # positivity check further down.
        if (anyNA(item_weights)) {
            abort("Input to `item_weights` cannot contain missing values.")
        }

        if (anyDuplicated(item.wgt.names) > 0) {
            abort("There cannot be duplicate names in the names of `item_weights`.")
        }

        pos.check <- all(sign(item_weights) == 1)

        if (!pos.check) {
            abort("All `item_weights` must be positive and non-zero.")
        }

        # Weights naming columns that are not part of `items` are dropped.
        bad <- setdiff(item.wgt.names, item.names)

        if (length(bad) > 0) {
            bad.string <- paste(bad, collapse = ", ")
            msg <- glue(
                "The following items specified in `item_weights` were not included in `items` and will be ignored:\n{bad.string}"
            )
            warn(msg)
            item_weights <- item_weights[names(item_weights) %in% item.names]

        }

        # Every item starts with a weight of 1; user supplied weights then
        # override the default. The left join keeps the order of `item.names`.
        item.wgt <- rep(1, times = n.items)
        names(item.wgt) <- item.names

        item.wgt.default <- enframe(
            x = item.wgt,
            name = "item",
            value = "default"
        )

        item.wgt.new <- enframe(item_weights, name = "item", value = "new")

        item.wgt <-
            item.wgt.default %>%
            left_join(item.wgt.new, by = "item") %>%
            mutate(wgt = coalesce(new, default)) %>%
            pull(wgt, name = item)

    }


    # Set up weights ----------------------------------------------------------

    if (missing(case_weights)) {

        # Unweighted analysis: every row counts equally.
        wgt.vec <- rep(1, times = nrow(data))

    } else {

        # `!= 1` also catches a selection that matched no column at all, which
        # would otherwise fail with a zero-length `if` condition below.
        if (length(case.weights.names) != 1) {
            abort("Can only provide one column of weights to `case_weights`.")
        }

        if (case.weights.names %in% item.names) {
            abort(glue("Column '{case.weights.names}' cannot be used both in `items` and `case_weights`."))
        }

        wgt.vec <- data[[case.weights.names]]

        if (!is.numeric(wgt.vec)) {
            abort("Input to `case_weights` must be a numeric column.")
        }

    }


    # Make sure data is onezero -----------------------------------------------

    item.df <- data[item.names]

    oz.check <- dapply(item.df, is_onezero)

    if (any(!oz.check)) {
        bad.names <- names(oz.check[!oz.check])
        # `collapse` (not `sep`) yields one string listing every offender.
        bad.names.string <- paste(bad.names, collapse = ", ")
        msg <- glue(
            "All variables in `items` must contain only 0/1 data, the following do not:\n{bad.names.string}"
        )
        abort(msg)
    }


    # Check and warn about all missing rows -----------------------------------

    # Check and see if any rows have 100% missing data
    all.miss <-
        dapply(
            X = item.df,
            FUN = function(x) all(is.na(x)),
            MARGIN = 1
        ) %>%
        which()

    if (length(all.miss) > 0) {
        msg <- glue(
            "{length(all.miss)} rows in `data` have 100% missing values for the items specified in `items`. They will still be retained in the analysis and treated as \"unreached\". If you do not want those rows in the TURF analysis, please remove them ahead of time."
        )
        warn(msg)
    }


    # Replace NA, apply item weights ------------------------------------------

    # Replace NA with zero, makes sense since we are operating row-wise
    # for reach. This radically improves the speed of Rfast::rowsums().
    item.df[is.na(item.df)] <- 0

    if (!missing(item_weights)) {

        # Swap each 1 for the weight of its item; zeros stay zero.
        # `item.wgt` is ordered like `item.names`, so position `i` lines up.
        for (i in seq_len(n.items)) {

            now <- item.df[[item.names[i]]]
            now[now == 1] <- item.wgt[[i]]
            item.df[[i]] <- now

        }

    }


    # Validate `depth` --------------------------------------------------------

    if (!between(depth, 1, n.items)) {
        # glue() is required for `{n.items}` to be interpolated.
        abort(glue("Input to `depth` must be a value between 1 and number of `items` ({n.items})."))
    }


    # Shapley approx calculations ---------------------------------------------

    # row totals of the (item weighted) data
    x_reach <- rowsums(as.matrix(item.df))

    # was the row reached? i.e. does it have at least `depth` non-zero items
    is_reached <- dapply(
        X = item.df,
        MARGIN = 1,
        FUN = function(x) sum(x != 0) >= depth
    )

    # total proportion reached
    p_reach <- fmean(is_reached, w = wgt.vec)

    # proportion out the data: each reached row splits its case weight across
    # its items relative to the row total, unreached rows contribute nothing
    prop <- dapply(
        X = item.df,
        MARGIN = 2,
        FUN = function(x)
            ifelse(is_reached, (x * wgt.vec) / x_reach, 0)
    )

    # sum the proportioned out data
    sums <- dapply(prop, fsum)

    # share the sums
    sums1 <- sums / sum(sums)

    # multiply by total prop to get SV
    sv <- sums1 * p_reach


    # Return ------------------------------------------------------------------

    # `return` was validated at the top, so it is either "tibble" or "vector".
    if (return == "tibble") {
        return(
            enframe(
                x = sv,
                name = "item",
                value = "shapley_value"
            )
        )
    }

    sv

}

# Column names that shapley_approx() references as bare symbols inside its
# dplyr pipeline (`mutate(wgt = coalesce(new, default))`,
# `pull(wgt, name = item)`). Declaring them keeps R CMD check from reporting
# "no visible binding for global variable" notes.
utils::globalVariables(c(
    "new", "default", "wgt", "item"
))
ttrodrigz/onezero documentation built on May 9, 2023, 2:59 p.m.