# test-code/turf-working/turf.R

library(tidyverse)
library(tidyselect)
library(arrangements)
library(onezero)
library(collapse)
library(Rfast)

# TODO: Once this is done, create S3 `print()` and `summary()` methods; see:
# https://njtierney.github.io/r/missing%20data/rbloggers/2016/11/06/simple-s3-methods/

#' Run a TURF-style reach and frequency analysis over "onezero" items
#'
#' For every set size in `k`, enumerates combinations of the selected item
#' columns (optionally filtered by forced inclusions, mutual exclusions, and
#' greedy selection) and computes, for each combination, the weighted share of
#' rows reached and the weighted mean number of the combination's items
#' selected per row.
#'
#' @param data A data frame holding the item columns (and optional weights).
#' @param cols <tidy-select> Item columns; each must satisfy `is_onezero()`.
#' @param weight <tidy-select> Optional single weight column. When missing,
#'   every row gets a weight of 1.
#' @param k Set sizes to evaluate; every value must lie between 1 and the
#'   number of columns selected by `cols`.
#' @param depth A row counts as reached by a combination when it has at least
#'   `depth` of that combination's items. Cannot exceed any value of `k`.
#' @param min.prop Optional minimum weighted proportion; items below it are
#'   dropped unless they are also selected by `force.in`.
#' @param force.in <tidy-select> Optional items (resolved against the `cols`
#'   columns) that must appear in every combination.
#' @param mutually.exclude Optional list of captured selections (e.g. from
#'   `exclusions()`), each resolved against `data`; any combination that
#'   contains all columns of one selection is dropped.
#' @param brute Set sizes greater than `brute` (after the first evaluated
#'   size) are built greedily instead of by full enumeration.
#' @param greedy.entry How greedy set sizes choose the items to keep:
#'   `"reach"` keeps the best combination from the previously evaluated set
#'   size, `"shapley"` keeps the top `k - 1` items from `shapley_approx()`.
#'
#' @return If neither Shapley values nor incremental reach apply, a named list
#'   of tibbles (one per evaluated set size, sorted by descending reach then
#'   freq). Otherwise a list with `reach` plus either `shapley_values` (checked
#'   first) or `incremental`.
turf <- function(
        data, cols, weight,
        k, depth = 1, min.prop,
        force.in, mutually.exclude,
        brute = Inf, greedy.entry = "reach"
) {

    # Validate up front: any other value would leave `item.keep.greedy`
    # undefined once the greedy branch is reached.
    greedy.entry <- match.arg(greedy.entry, choices = c("reach", "shapley"))

    # Parse out data and weights ----------------------------------------------

    # In this section, the data used in the actual turf analysis is parsed out
    # from the data set provided, and a vector of weights is either extracted
    # from the data, or is created if not provided.

    # Grab the data needed for the analysis
    item.df <- select(data, {{cols}})

    # Check and make sure the data is "onezero"
    oz.check <- sapply(item.df, is_onezero)

    bad.vars <- names(oz.check[!oz.check])

    if (length(bad.vars) > 0) {

        bad.vars.message <- paste0(
            "The following variables do not meet the requirements of `is_onezero`:\n",
            paste(bad.vars, collapse = ", ")
        )

        stop(bad.vars.message)

    }

    # Grab the names of the items
    item.names <- names(item.df)
    num.items <- length(item.names)

    # Do weights exist? If so, grab them, if not, make them.
    if (missing(weight)) {

        ss <- nrow(data)
        wgt.vec <- rep(1, times = ss)

    } else {

        wgt.df <- select(data, {{weight}})

        if (ncol(wgt.df) > 1) {
            stop("Can only provide one column of weights in `weight` argument.")
        }

        wgt.name <- names(wgt.df)

        if (wgt.name %in% item.names) {
            warning(paste0(
                "Column '",
                wgt.name,
                "' was supplied as an input to both `cols` and `weights` arguments, this is likely ill-advised."
            ))
        }

        wgt.vec <- pull(wgt.df, {{weight}})

    }

    # Denominator for `freq`; constant for the whole analysis.
    wgt.base <- sum(wgt.vec, na.rm = TRUE)


    # Error checks ------------------------------------------------------------

    # k can't be bigger than number of columns, or less than 1
    if (!all(between(k, 1, num.items))) {
        stop(paste0(
            "Input to `k` must contain values between 1 and number of columns provided in `cols` (",
            num.items,
            ")."
        ))
    }

    # depth can't be bigger than k
    if (any(depth > k)) {
        stop("Input to `depth` cannot exceed `k`. Doing so would result in a reach of zero.")
    }


    # Minimum proportion ------------------------------------------------------

    # Flag items whose weighted proportion falls below `min.prop`. Flagged
    # items are removed from `item.names` further down UNLESS they are also
    # selected by `force.in`; forced inclusions override this flag in the next
    # section.

    if (!missing(min.prop)) {

        item.props <- sapply(
            X = item.df,
            FUN = weighted.mean, w = wgt.vec, na.rm = TRUE
        )

        # Named logical vector; used later to subset `item.names` directly.
        item.meets.min.prop <- item.props >= min.prop

    } else {
        item.meets.min.prop <- rep(TRUE, times = num.items)
        names(item.meets.min.prop) <- item.names
    }


    # Forced inclusions -------------------------------------------------------

    # Identify the variables that must be present in every combo. Forced
    # inclusions trump `min.prop`: any forced item flagged above is flipped
    # back to TRUE in `item.meets.min.prop` (with a warning).

    if (!missing(force.in)) {

        .force.in <- enquo(force.in)
        force.index <- eval_select(
            expr = .force.in,
            data = item.df
        )

        force.names <- names(force.index)

        # if k is less than the number of items forced in then there will be no
        # valid combinations to run
        if (any(k < length(force.names))) {
            stop(paste0(
                "Input to `k` must be greater than or equal to the number of items being forced in (",
                length(force.names),
                "), otherwise no valid combinations will be available."
            ))
        }

        bad.names <- force.names[!force.names %in% item.names]

        if (length(bad.names) > 0) {
            stop(paste0(
                "Invalid input to `force.in`, columns supplied in `force.in` must also be present in input to `cols`. The following columns must be added to `cols` if you want to force them in:\n",
                paste(bad.names, collapse = ", ")
            ))
        }

        # Update `item.meets.min.prop` if needed
        if (!all(item.meets.min.prop)) {

            min.prop.drop.names <- names(item.meets.min.prop[!item.meets.min.prop])
            min.prop.override <- min.prop.drop.names[min.prop.drop.names %in% force.names]

            if (length(min.prop.override) > 0) {

                # `collapse` (not `sep`) so several names form ONE message.
                warning(paste0(
                    "Inputs to `force.in` override results of `min.prop`. The following variables did not meet minimum proportion requirements but will still be kept in because they were requested to be forced in:\n",
                    paste(min.prop.override, collapse = ", ")
                ))

                item.meets.min.prop[min.prop.override] <- TRUE

            }

        }

    }


    # Mutual exclusions -------------------------------------------------------

    # `mutually.exclude` holds an arbitrary number of captured tidyselect
    # expressions (e.g. from `exclusions()`); each is resolved to a vector of
    # column names. The original `data` is the evaluation context so that
    # positional selections refer to the data as it was supplied.

    if (!missing(mutually.exclude)) {

        me.list <-
            map(
                .x = mutually.exclude,
                .f = ~eval_select(
                    expr = .x,
                    data = data
                )
            ) %>%
            map(names)

    }


    # Prepare for turf --------------------------------------------------------

    # Keep only the final items (`item.meets.min.prop` may have been
    # overridden by forced inclusions).
    item.names <- item.names[item.meets.min.prop]
    num.items  <- length(item.names)
    item.df <- item.df[item.names]

    # `Rfast::rowsums()` only operates on matrices; missing values are
    # treated as "not selected".
    item.mat <- as.matrix(item.df)
    item.mat[is.na(item.mat)] <- 0


    # Shapley approximation ---------------------------------------------------

    # Items ranked by approximate Shapley value; used by
    # `greedy.entry = "shapley"`.
    if (missing(weight)) {
        shap.approx <-
            shapley_approx(
                data = data,
                cols = all_of(item.names),
                need = 1,
                tidy = TRUE
            ) %>%
            arrange(desc(shapley_approx))
    } else {

        shap.approx <-
            shapley_approx(
                data = data,
                cols = all_of(item.names),
                weight = {{weight}},
                need = 1,
                tidy = TRUE
            ) %>%
            arrange(desc(shapley_approx))
    }


    # Do the turf -------------------------------------------------------------

    # TRUE when `k` starts with every set size 1, 2, ..., num.items. Compared
    # by value so integer (1:4) and double (c(1, 2, 3, 4)) inputs both qualify.
    k.covers.all <-
        num.items > 0 &&
        length(k) >= num.items &&
        all(k[seq_len(num.items)] == seq_len(num.items))

    # Shapley values are only computed when:
    #   1. All sizes of k from 1:num.items are evaluated
    #   2. No forced inclusions
    #   3. No mutually exclusive items
    #   4. All combos are brute forced
    #   5. depth is 1
    do.shapley <-
        k.covers.all &&
        missing(force.in) && missing(mutually.exclude) &&
        brute >= max(k) &&
        isTRUE(depth == 1)

    # Incremental reach requires the same coverage, but a fully greedy run.
    do.incremental.reach <-
        k.covers.all &&
        missing(force.in) && missing(mutually.exclude) &&
        brute <= 1 &&
        isTRUE(depth == 1) &&
        greedy.entry %in% c("reach", "shapley")


    # Store results here: one slot per requested set size.
    reach.list <- vector(
        mode = "list",
        length = length(k)
    )
    names(reach.list) <- paste0("k", k)

    # For the with/without reach values for Shapley
    if (do.shapley) {
        with.without.reach.list <- reach.list
    }

    # For incremental reach
    if (do.incremental.reach) {
        inc.reach <- vector(
            mode = "double",
            length = num.items
        )
        names(inc.reach) <- item.names
    }

    # loop thru each `k`
    for (i in seq_along(k)) {

        # Items dropped by `min.prop` can leave fewer items than the requested
        # set size; stop evaluating larger sizes.
        if (k[i] > num.items) {

            warning(paste0(
                "Unable to continue past set size k = ",
                k[i],
                " due to items being dropped as a result of `min.prop`."
            ))
            break

        }

        # Full set of combinations --
        combos <- combinations(
            x = item.names,
            k = k[i]
        )

        # Set combo column names --
        colnames(combos) <- paste0("i", seq_len(k[i]))

        # Reduce combos by inclusions --
        if (!missing(force.in)) {

            keep.combos <- apply(combos, 1, function(x) all(force.names %in% x))

            combos <- combos[keep.combos, , drop = FALSE]

        }

        # Reduce combos by mutual exclusions --
        # Mutual exclusions trump force ins: drop any combo containing every
        # column of an exclusion set.
        if (!missing(mutually.exclude)) {

            for (me in seq_along(me.list)) {

                me.names <- me.list[[me]]

                drop.combos <- apply(
                    X = combos,
                    MARGIN = 1,
                    FUN = function(x) all(me.names %in% x)
                )

                combos <- combos[!drop.combos, , drop = FALSE]

            }

        }

        # Greedy step: only keep combos that contain the items carried over.
        if (k[i] > brute && i > 1) {

            if (greedy.entry == "reach") {

                # Best combo from the previously evaluated set size. Index by
                # loop position (i - 1), not by set size, so `k` need not
                # start at 1. Rows are already sorted by reach/freq.
                item.keep.greedy <-
                    reach.list[[i - 1]] %>%
                    head(1) %>%
                    select(matches("i\\d")) %>%
                    unlist()

            } else {

                # greedy.entry == "shapley": top k - 1 items by Shapley value
                item.keep.greedy <-
                    shap.approx %>%
                    head(k[i] - 1) %>%
                    pull(1)

            }

            keep.combos <- apply(combos, 1, function(x) all(item.keep.greedy %in% x))

            combos <- combos[keep.combos, , drop = FALSE]

        }

        # Total number of combinations after forced inclusions and
        # mutual exclusions --
        n.combos <- nrow(combos)

        rlang::inform(glue::glue(
            "Set size {k[i]}: {n.combos} combinations"
        ))

        # skip iteration if there are no combos
        if (n.combos == 0) {
            rlang::inform(paste0(
                "There are no remaining combinations for set size k = ",
                k[i],
                ", skipping this iteration."
            ))
            next
        }

        # This matrix receives the reach and freq calcs
        fill <- matrix(
            data = NA,
            ncol = 2,
            nrow = n.combos,
            dimnames = list(
                NULL,
                c("reach", "freq")
            )
        )

        # calculate reach
        for (j in seq_len(n.combos)) {

            n.reached <- rowsums(item.mat[, combos[j, ], drop = FALSE])
            is.reached <- n.reached >= depth
            reach <- fmean(x = is.reached, w = wgt.vec)

            # `na.rm = TRUE` matches how the denominator `wgt.base` was built.
            freq  <- sum(wgt.vec * n.reached, na.rm = TRUE) / wgt.base

            fill[j, ] <- c(reach, freq)

        }

        # Join the reach measures to the combinations
        reach.stats <-
            combos %>%
            as_tibble() %>%
            bind_cols(as_tibble(fill)) %>%
            rowid_to_column("combo") %>%
            arrange(desc(reach), desc(freq)) %>%
            add_column(
                k = k[i],
                .before = 1
            )

        # With every combination of every size available, Shapley values can
        # be computed. First step: the mean reach of size-`k` combos with and
        # without each item.
        if (do.shapley) {

            item.cols <- paste0("i", seq_len(i))
            only.items <- reach.stats[item.cols]

            # Mean reach with vs. without item `cn` present in the combination.
            with_without_reach <- function(cn) {
                tapply(
                    X = reach.stats$reach,
                    INDEX = apply(only.items, 1, function(x) ifelse(cn %in% x, "with", "without")),
                    FUN = mean
                )
            }

            with.without.reach <-
                item.names %>%
                lapply(with_without_reach) %>%
                do.call(rbind, .) %>%
                as_tibble() %>%
                add_column(
                    item = item.names,
                    .before = 1
                ) %>%
                add_column(
                    k = i,
                    .before = 1
                )

            # "k choose k" has a single combo, so no "without" entry exists.
            if (i == num.items) {
                with.without.reach$without <- 0
            }

            with.without.reach.list[[i]] <- with.without.reach

        }

        # store in the list
        reach.list[[i]] <- reach.stats

    }

    # Drop slots for sizes that were skipped or never reached.
    reach.list <- discard(reach.list, is.null)


    # Calculate Shapley values and return -------------------------------------

    if (!do.shapley && !do.incremental.reach) {

        return(reach.list)

    }

    # Shapley value = mean over set sizes of (mean reach with the item minus
    # mean reach without it).
    if (do.shapley) {

        shapley_values <-
            with.without.reach.list %>%
            reduce(bind_rows) %>%
            fmutate(gap = with - without) %>%
            fgroup_by(item) %>%
            fsummarise(shapley_value = fmean(gap)) %>%
            fungroup() %>%
            arrange(desc(shapley_value))

        return(list(
            reach = reach.list,
            shapley_values = shapley_values
        ))

    }

    # do.incremental.reach is TRUE here
    inc <-

        # only need first row, they are already
        # sorted from high to low
        reach.list %>%
        map(head, 1) %>%

        # bind everything and grab cols needed
        bind_rows() %>%
        select(k, matches("i\\d"), reach) %>%

        # stack on the item
        pivot_longer(
            cols = matches("i\\d"),
            names_to = "item_num",
            values_to = "item"
        ) %>%
        drop_na() %>%

        # keep only the first (smallest k) appearance of each item
        group_by(item) %>%
        filter(row_number() == 1) %>%
        ungroup() %>%
        select(k, item, reach) %>%

        # incremental gain from adding that item to the list
        mutate(
            incremental = reach - lag(reach)
        )

    list(
        reach = reach.list,
        incremental = inc
    )

}



# Capture any number of (unevaluated) tidyselect expressions, e.g. column
# ranges, positions, or helpers like `contains()`, as quosures. The result is
# meant for the `mutually.exclude` argument of `turf()`, which resolves each
# captured expression to a set of column names.
# Defined BEFORE the example below so the script can be sourced top-to-bottom.
exclusions <- function(...) quos(...)


# Example run. `FoodSample` is presumably supplied by an attached package
# (onezero?) -- NOTE(review): confirm.
x <- turf(
    data = FoodSample,
    cols = Bisque:NYStrip,
    k = 1:4,
    min.prop = 0.2,
    mutually.exclude = exclusions(
        1:2,
        contains("e")
    ),
    brute = Inf,
    greedy.entry = "shapley"
)
# ttrodrigz/onezero documentation built on May 9, 2023, 2:59 p.m.