#' Shapley Value Approximation
#'
#' Approximates Shapley Values for a set of items. Shapley Values measure
#' the relative contribution each item has to the overall potential reach
#' if every item was included.
#'
#' @param data A data frame.
#' @param items Columns on which to analyze. Must contain only ones, zeros, or
#'   `NA`. Suggest using [is_onezero][onezero::is_onezero] ahead of time to
#'   check.
#' @param case_weights An optional column of case weights to use in the
#'   calculations. Rows with `NA` will be removed from the base.
#' @param item_weights An optional named vector of non-zero weights to associate
#'   with each item. Items not specified will be given a default weight of 1.
#'
#'   Examples could be profit, revenue, or simply relative weights.
#' @param depth Number of `items` needed in order to be considered "reached."
#'   Can be any number between 1 and the number of `items`. Non-integer values
#'   are rounded down. Default is 1.
#' @param return One of `"vector"` (default) or `"tibble"` specifying the
#'   type of object to return.
#'
#' @return When `return = "vector"`, a vector of approximate Shapley values
#'   with one element per item. When `return = "tibble"`, a tibble with the
#'   columns `item` and `shapley_value`. The values are each item's share of
#'   the (weighted) proportion of rows reached, so they sum to that
#'   proportion; they are all zero when no row is reached.
#'
#' @importFrom dplyr select pull left_join coalesce
#' @importFrom purrr map_df map_dbl
#' @importFrom tibble enframe
#' @importFrom collapse fmean fsum
#' @importFrom Rfast rowsums
#' @importFrom rlang abort
#' @examples
#' shapley_approx(
#'   data = FoodSample,
#'   items = Bisque:Chili,
#'   case_weights = weight,
#'   item_weights = c(Bisque = 9.99, Chicken = 10.29, Tofu = 10.99, Chili = 7.49),
#'   depth = 1,
#'   return = "tibble"
#' )
#'
#' @export
shapley_approx <- function(
  data,
  items,
  case_weights, item_weights,
  depth = 1,
  return = "vector"
) {
  # Preliminary data checks -------------------------------------------------
  if (!is.data.frame(data)) {
    abort("Input to `data` must be a data frame.")
  }

  # `||` short-circuits, so `is.na()` only ever sees a single numeric value.
  if (!is.numeric(depth) || length(depth) != 1 || is.na(depth)) {
    abort("Input to `depth` must be a single numeric value.")
  }
  depth <- floor(depth)

  # An unrecognised `return` would otherwise match neither branch at the end
  # of the function and silently produce NULL.
  if (length(return) != 1 || !(return %in% c("vector", "tibble"))) {
    abort("Input to `return` must be one of \"vector\" or \"tibble\".")
  }

  # Get names of things -----------------------------------------------------
  # `items`
  item.names <- names(eval_select(expr = enquo(items), data = data))
  n.items <- length(item.names)

  # Validate `depth` --------------------------------------------------------
  # Done as soon as the number of items is known; this also stops an empty
  # `items` selection before any per-item work is attempted.
  if (!between(depth, 1, n.items)) {
    abort(
      glue(
        "Input to `depth` must be a value between 1 and number of `items` ",
        "({n.items})."
      )
    )
  }

  # `case_weights`
  if (!missing(case_weights)) {
    case.weights.names <- names(
      eval_select(
        expr = enquo(case_weights),
        data = data
      )
    )
  }

  # Set up item weights -----------------------------------------------------
  if (!missing(item_weights)) {
    item.wgt.names <- names(item_weights)

    if (is.null(item.wgt.names)) {
      abort("Input to `item_weights` must be a named vector.")
    }
    # Checked before the `== ""` comparison, which would yield NA on NA names.
    if (anyNA(item.wgt.names)) {
      abort("Every element of `item_weights` must be named.")
    }
    if (any(item.wgt.names == "")) {
      abort("Cannot have empty characters as names in `item_weights`.")
    }
    if (anyDuplicated(item.wgt.names) > 0) {
      abort("There cannot be duplicate names in the names of `item_weights`.")
    }
    if (!is.numeric(item_weights)) {
      abort("Input to `item_weights` must be a numeric vector.")
    }
    if (anyNA(item_weights)) {
      abort("`item_weights` cannot contain missing values.")
    }
    if (!all(item_weights > 0)) {
      abort("All `item_weights` must be positive and non-zero.")
    }

    # Weights for columns that are not part of `items` are dropped, loudly.
    bad <- setdiff(item.wgt.names, item.names)
    if (length(bad) > 0) {
      bad.string <- paste(bad, collapse = ", ")
      msg <- glue(
        "The following items specified in `item_weights` were not included ",
        "in `items` and will be ignored:\n{bad.string}"
      )
      warn(msg)
      item_weights <- item_weights[names(item_weights) %in% item.names]
    }

    # Start every item at weight 1, then overlay the user-supplied weights.
    # The left join keeps the order of `item.names`.
    item.wgt <- rep(1, times = n.items)
    names(item.wgt) <- item.names

    item.wgt.default <- enframe(
      x = item.wgt,
      name = "item",
      value = "default"
    )
    item.wgt.new <- enframe(item_weights, name = "item", value = "new")

    item.wgt <-
      item.wgt.default %>%
      left_join(item.wgt.new, by = "item") %>%
      mutate(wgt = coalesce(new, default)) %>%
      pull(wgt, name = item)
  }

  # Set up weights ----------------------------------------------------------
  if (missing(case_weights)) {
    wgt.vec <- rep(1, times = nrow(data))
  } else {
    # A selection that matches nothing would otherwise reach `if ()` below
    # with a zero-length condition.
    if (length(case.weights.names) == 0) {
      abort("Input to `case_weights` did not select any column.")
    }
    if (length(case.weights.names) > 1) {
      abort("Can only provide one column of weights to `case_weights`.")
    }
    if (case.weights.names %in% item.names) {
      abort(
        glue(
          "Column '{case.weights.names}' cannot be used both in `items` ",
          "and `case_weights`."
        )
      )
    }

    wgt.vec <- data[[case.weights.names]]
    if (!is.numeric(wgt.vec)) {
      abort("Input to `case_weights` must be a numeric column.")
    }
  }

  # Make sure data is onezero -----------------------------------------------
  item.df <- data[item.names]
  oz.check <- dapply(item.df, is_onezero)

  if (any(!oz.check)) {
    bad.names <- names(oz.check[!oz.check])
    # `collapse` (not `sep`) is what joins a vector into a single string.
    bad.names.string <- paste(bad.names, collapse = ", ")
    msg <- glue(
      "All variables in `items` must contain only 0/1 data, the following ",
      "do not:\n{bad.names.string}"
    )
    abort(msg)
  }

  # Check and warn about all missing rows -----------------------------------
  # Row positions where every selected item is NA
  all.miss <-
    dapply(
      X = item.df,
      FUN = function(x) all(is.na(x)),
      MARGIN = 1
    ) %>%
    which()

  if (length(all.miss) > 0) {
    msg <- glue(
      "{length(all.miss)} rows in `data` have 100% missing values for the ",
      "items specified in `items`. They will still be retained in the ",
      "analysis and treated as \"unreached\". If you do not want those ",
      "rows in the TURF analysis, please remove them ahead of time."
    )
    warn(msg)
  }

  # Replace NA, apply item weights ------------------------------------------
  # Replace NA with zero, makes sense since we are operating row-wise
  # for reach. This radically improves the speed of Rfast::rowsums().
  item.df[is.na(item.df)] <- 0

  if (!missing(item_weights)) {
    # Swap each 1 for the item's weight; zeros stay zero.
    for (nm in names(item.wgt)) {
      now <- item.df[[nm]]
      now[now == 1] <- item.wgt[[nm]]
      item.df[[nm]] <- now
    }
  }

  # Shapley approx calculations ---------------------------------------------
  # (weighted) number of items reaching each row
  x_reach <- rowsums(as.matrix(item.df))

  # was the row reached? i.e. at least `depth` non-zero items
  is_reached <- dapply(
    X = item.df,
    MARGIN = 1,
    FUN = function(x) sum(x != 0) >= depth
  )

  # total proportion reached
  p_reach <- fmean(is_reached, w = wgt.vec)

  # split each reached row's case weight across the items that reached it
  prop <- dapply(
    X = item.df,
    MARGIN = 2,
    FUN = function(x) {
      ifelse(is_reached, (x * wgt.vec) / x_reach, 0)
    }
  )

  # sum the proportioned out data
  sums <- dapply(prop, fsum)

  # share the sums; when nothing was reached every sum is zero and 0 / 0
  # would give NaN, so report zero shares instead
  total <- sum(sums)
  sums1 <- if (isTRUE(total == 0)) sums * 0 else sums / total

  # multiply by total prop to get SV
  sv <- sums1 * p_reach

  # Return ------------------------------------------------------------------
  # `return` was validated above, so only these two cases remain.
  if (return == "vector") {
    sv
  } else {
    enframe(
      x = sv,
      name = "item",
      value = "shapley_value"
    )
  }
}
# Column names that `shapley_approx()` refers to through non-standard
# evaluation (inside `mutate()` and `pull()`). Declaring them stops
# `R CMD check` from reporting "no visible binding for global variable".
# `item` is included because it is used unquoted in `pull(wgt, name = item)`.
utils::globalVariables(c(
  "new", "default", "wgt", "item"
))
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.