# test-code/turf2.R

library(rlang)
library(glue)
library(tidyselect)
library(tidyverse)
library(collapse)
library(arrangements)
library(Rfast)
library(cli)
library(onezero)

#' TURF analysis with optional constraints and greedy search
#'
#' Evaluates every combination of `k` items drawn from `turf_cols` (after
#' constraints are applied) and reports, for each combination, its reach
#' (weighted share of rows having at least `depth` of the combination's items
#' equal to 1) and frequency (weighted mean number of the combination's items
#' equal to 1 per row). Rows whose `turf_cols` are all missing are dropped;
#' remaining missing values are treated as 0. Progress and a summary of the
#' settings are printed to the console.
#'
#' @param data A data frame.
#' @param turf_cols <tidy-select> Columns holding 0/1 item data.
#' @param case_weights <tidy-select> Optional single numeric weight column.
#' @param k Set sizes to evaluate; rounded, sorted and de-duplicated. Every
#'   value must lie between 1 and the number of `turf_cols`.
#' @param depth Minimum number of a combination's items a row must have for it
#'   to count as reached.
#' @param force_in List of tidyselect expressions (presumably built with
#'   `constraints()` -- TODO confirm); these items appear in every combination.
#' @param force_in_together List of tidyselect expressions; each group of 2+
#'   items must appear all together or not at all.
#' @param force_out List of tidyselect expressions; these items are removed
#'   before any combination is generated.
#' @param force_out_together List of tidyselect expressions; each group of 2+
#'   items may never appear all together.
#' @param greedy_begin Set size from which the search becomes greedy: only
#'   combinations containing the items picked by `greedy_entry` are evaluated.
#' @param greedy_entry `"shapley"` (fix the top items ranked by
#'   `shapley_approx()`) or `"reach"` (fix the best combination of the previous
#'   set size).
#'
#' @return A list with `reach` (one tibble per evaluated set size, named
#'   `k_XX`, sorted by reach then freq, descending; skipped sizes are omitted),
#'   `items` (item names after forced exclusions), `options` and `constraints`.
turf2 <- function(
        data, turf_cols, case_weights,
        k = 1, depth = 1,
        force_in, force_in_together,
        force_out, force_out_together,
        greedy_begin = 20, greedy_entry = "shapley"
) {

    # Preliminary error checks ------------------------------------------------

    if (!is.data.frame(data)) {
        abort("Input to `data` must be a data frame.")
    }

    if (!is.numeric(k)) {
        abort("Input to `k` must be one or more numeric values.")
    }

    # `||` short-circuits and always yields a single TRUE/FALSE for `if`
    if (!is.numeric(depth) || length(depth) != 1) {
        abort("Input to `depth` must be a single numeric value.")
    }

    if (!is.numeric(greedy_begin) || length(greedy_begin) != 1) {
        abort("Input to `greedy_begin` must be a single numeric value.")
    }

    if (length(greedy_entry) != 1 || !greedy_entry %in% c("shapley", "reach")) {
        abort("Input to `greedy_entry` must be one of 'shapley' or 'reach'.")
    }


    # Resolve column selections -----------------------------------------------

    item.names <- names(eval_select(expr = enquo(turf_cols), data = data))

    has.weights <- !missing(case_weights)

    if (has.weights) {
        case.weights.names <- names(
            eval_select(expr = enquo(case_weights), data = data)
        )
    }

    force.in.names <- character()
    if (!missing(force_in)) {
        force.in.names <- .turf_select_names(force_in, data)
    }

    force.in.together.names <- list()
    if (!missing(force_in_together)) {
        force.in.together.names <- .turf_select_groups(force_in_together, data)
    }

    force.out.names <- character()
    if (!missing(force_out)) {
        force.out.names <- .turf_select_names(force_out, data)
    }

    force.out.together.names <- list()
    if (!missing(force_out_together)) {
        force.out.together.names <- .turf_select_groups(force_out_together, data)
    }


    # Set up weights ----------------------------------------------------------

    if (has.weights) {

        if (length(case.weights.names) == 0) {
            abort("Input to `case_weights` must select a column.")
        }

        if (length(case.weights.names) > 1) {
            abort("Can only provide one column of weights to `case_weights`.")
        }

        if (case.weights.names %in% item.names) {
            abort(glue("Column '{case.weights.names}' cannot be used both in `turf_cols` and `case_weights`."))
        }

        wgt.vec <- data[[case.weights.names]]

        if (!is.numeric(wgt.vec)) {
            abort("Input to `case_weights` must be a numeric column.")
        }

    } else {

        wgt.vec <- rep(1, times = nrow(data))

    }


    # Validate single-item constraints ----------------------------------------

    # names not found in `turf_cols` are dropped with a warning
    force.in.names  <- .turf_keep_known(force.in.names, item.names, "force_in")
    force.out.names <- .turf_keep_known(force.out.names, item.names, "force_out")

    do.force.in  <- length(force.in.names) > 0
    do.force.out <- length(force.out.names) > 0

    if (do.force.in && do.force.out) {

        bad.names <- intersect(force.in.names, force.out.names)

        if (length(bad.names) > 0) {
            bad.names.string <- paste(bad.names, collapse = ", ")
            abort(glue("Cannot have items that appear both in `force_in` and `force_out`. The following variables are in both:\n{bad.names.string}"))
        }

    }


    # Validate group constraints ----------------------------------------------

    force.in.together.names <- .turf_keep_known_groups(
        force.in.together.names, item.names, "force_in_together"
    )
    force.out.together.names <- .turf_keep_known_groups(
        force.out.together.names, item.names, "force_out_together"
    )

    # an in-together and an out-together group sharing 2+ items are treated as
    # conflicting; both groups are dropped with a warning
    if (length(force.in.together.names) > 0 && length(force.out.together.names) > 0) {

        drop.fi <- integer()
        drop.fo <- integer()

        for (fi in seq_along(force.in.together.names)) {
            for (fo in seq_along(force.out.together.names)) {

                bad.names <- intersect(
                    force.in.together.names[[fi]],
                    force.out.together.names[[fo]]
                )

                if (length(bad.names) >= 2) {
                    bad.names.string <- paste(bad.names, collapse = ", ")
                    warn(glue("The following items appeared both in `force_in_together` and `force_out_together`, constraints with these items will be ignored:\n{bad.names.string}"))
                    drop.fi <- c(drop.fi, fi)
                    drop.fo <- c(drop.fo, fo)
                }

            }
        }

        if (length(drop.fi) > 0) {
            force.in.together.names  <- force.in.together.names[-unique(drop.fi)]
            force.out.together.names <- force.out.together.names[-unique(drop.fo)]
        }

    }

    do.force.in.together  <- length(force.in.together.names) > 0
    do.force.out.together <- length(force.out.together.names) > 0


    # Build constraints list --------------------------------------------------

    constraints <- list(
        "force_in"           = force.in.names,
        "force_in_together"  = force.in.together.names,
        "force_out"          = force.out.names,
        "force_out_together" = force.out.together.names
    )


    # Get TURF data, validate `k` ---------------------------------------------

    item.df <- select(data, all_of(item.names))

    # remove cases with 100% missing data; the weights are subset with the same
    # mask so they stay aligned with the rows of the item matrix
    keep.rows <- rowSums(!is.na(item.df)) > 0
    item.df <- item.df[keep.rows, , drop = FALSE]
    wgt.vec <- wgt.vec[keep.rows]

    # denominator for the frequency calculation later
    wgt.base <- sum(wgt.vec, na.rm = TRUE)

    # validate `k` before any items are dropped through exclusions
    n.items <- ncol(item.df)

    # note: sort() silently drops NA values
    k <- unique(sort(round(k)))

    if (length(k) == 0 || !all(between(k, 1, n.items))) {
        abort(glue(
            "Input to `k` must contain values between 1 and number of columns in `turf_cols` ({n.items})."
        ))
    }

    # Shapley greedy entry fixes items without checking mutual exclusions, so
    # the two cannot be combined whenever greedy search can kick in
    if (greedy_entry == "shapley" && do.force.out.together && greedy_begin <= max(k)) {
        abort("'shapley' greedy entry cannot be combined with mutual exclusions.")
    }

    # Drop excluded items here so {arrangements} never generates combinations
    # containing them (no per-combination exclusion filter is needed later)
    if (do.force.out) {
        item.df <- select(item.df, -all_of(force.out.names))
    }

    item.mat <- as.matrix(item.df)

    # these need to be updated after any forced exclusions
    item.names <- colnames(item.mat)
    n.items <- ncol(item.mat)

    oz.check <- dapply(item.mat, is_onezero)

    if (any(!oz.check)) {
        bad.names.string <- paste(names(oz.check)[!oz.check], collapse = ", ")
        abort(glue(
            "All variables in `turf_cols` must contain only 0/1 data, the following do not:\n{bad.names.string}"
        ))
    }

    # Replace NA with zero, makes sense since we are operating row-wise
    # for reach.
    item.mat[is.na(item.mat)] <- 0


    # Initialize objects for storing results ----------------------------------

    reach.list <- vector(mode = "list", length = length(k))
    names(reach.list) <- paste0(
        "k_", str_pad(k, width = max(nchar(k)), side = "left", pad = "0")
    )


    # Shapley approximation ---------------------------------------------------

    # only needed when shapley greedy entry can actually be used
    shap.approx <- NULL

    if (greedy_entry == "shapley" && greedy_begin <= max(k)) {

        if (has.weights) {

            shap.approx <-
                shapley_approx(
                    data = data,
                    cols = all_of(item.names),
                    weight = {{case_weights}},
                    need = 1,
                    tidy = TRUE
                ) %>%
                arrange(desc(shapley_approx)) %>%
                rename(item = 1, sv = 2)

        } else {

            shap.approx <-
                shapley_approx(
                    data = data,
                    cols = all_of(item.names),
                    need = 1,
                    tidy = TRUE
                ) %>%
                arrange(desc(shapley_approx)) %>%
                rename(item = 1, sv = 2)

        }

    }


    # Begin turf message ------------------------------------------------------

    k.range <- range(k)

    # k is sorted and unique, so it is a contiguous run iff every step is 1
    if (length(k) > 1 && all(diff(k) == 1)) {
        k.string <- glue("{k.range[1]}-{k.range[2]}")
    } else {
        k.string <- paste(k, collapse = ", ")
    }

    cat_line(cli::rule("TURF", line = 2))
    cat_line(style_italic(" Sample size: "), scales::comma(nrow(item.mat)))
    cat_line(style_italic("  # of items: "), n.items)
    cat_line(style_italic("   Set sizes: "), k.string)
    cat_line(style_italic("       Depth: "), depth)

    if (greedy_begin <= max(k)) {
        cat_line(style_italic("   Greedy at: "), greedy_begin)
        cat_line(style_italic("Greedy entry: "), greedy_entry)
    }
    cat_line()


    # Inclusion/exclusion messages --------------------------------------------

    if (any(do.force.in, do.force.in.together, do.force.out, do.force.out.together)) {

        cat_line(rule("Constraints"))

        if (do.force.in) {
            cat_line("Items included in every combination")
            cat_line(paste("\U2022", paste(force.in.names, collapse = ", ")))
            cat_line()
        }

        if (do.force.in.together) {
            bullets <- paste0("\U2022 ", map_chr(force.in.together.names, paste, collapse = ", "))
            cat_line("Items that must all appear together within a combination")
            cat_line(paste(bullets, collapse = "\n"))
            cat_line()
        }

        if (do.force.out) {
            cat_line("Items excluded from every combination")
            cat_line(paste("\U2022", paste(force.out.names, collapse = ", ")))
            cat_line()
        }

        if (do.force.out.together) {
            bullets <- paste0("\U2022 ", map_chr(force.out.together.names, paste, collapse = ", "))
            cat_line("Items that cannot all appear together within any combination")
            cat_line(paste(bullets, collapse = "\n"))
            cat_line()
        }

    }


    # Begin TURF --------------------------------------------------------------

    cat_line(cli::rule("Analysis", line = 1))

    # number of top-Shapley items fixed so far under greedy "shapley" entry
    shap.keep <- 0

    for (i in seq_along(k)) {

        cat_line(glue("Running best of {k[i]}\n"))

        # k cannot exceed number of items - this happens when there are
        # forced exclusions and running all k from 1 to # orig items
        if (k[i] > n.items) {
            cat_line(col_yellow(
                "i Skipping this set size, cannot perform due to forced exclusions"
            ))
            next
        }

        # depth cannot exceed k
        if (depth > k[i]) {
            cat_line(col_yellow(
                "i Skipping this set size, `depth` cannot exceed `k`"
            ))
            next
        }

        # generate full set of combinations ---
        combos <- combinations(x = item.names, k = k[i])
        colnames(combos) <- paste0(
            "i_", str_pad(seq_len(k[i]), width = 2, side = "left", pad = "0")
        )

        # reduce by forced inclusions ---
        if (do.force.in) {

            keep.combos <- dapply(
                X = combos,
                FUN = function(x) all(force.in.names %in% x),
                MARGIN = 1
            )

            combos <- combos[keep.combos, , drop = FALSE]

            if (nrow(combos) == 0) {
                cat_line(col_yellow(
                    "i Skipping this set size, no more combinations after forced inclusions"
                ))
                next
            }

        }

        # reduce by forced in together ---
        if (do.force.in.together) {

            for (j in seq_along(force.in.together.names)) {

                keep.names <- force.in.together.names[[j]]

                # no need to continue if the number of forced mutual
                # inclusions exceeds the current set size
                if (length(keep.names) > k[i]) {
                    next
                }

                # either none or all of the group may be present
                keep.combos <- dapply(
                    X = combos,
                    FUN = function(x) any(x %in% keep.names) == all(keep.names %in% x),
                    MARGIN = 1
                )

                combos <- combos[keep.combos, , drop = FALSE]

                if (nrow(combos) == 0) {
                    cat_line(col_yellow(
                        "i Skipping this set size, no combinations left after items forced in together"
                    ))
                    break
                }

            }

        }

        # the group loop above breaks out early; skip this set size if emptied
        if (nrow(combos) == 0) next

        # reduce by forced out together ---
        if (do.force.out.together) {

            for (j in seq_along(force.out.together.names)) {

                drop.names <- force.out.together.names[[j]]

                # a group larger than the set size can never be complete
                if (length(drop.names) > k[i]) {
                    next
                }

                # keep only combos that do not contain the whole group
                keep.combos <- dapply(
                    X = combos,
                    MARGIN = 1,
                    FUN = function(x) !all(drop.names %in% x)
                )

                combos <- combos[keep.combos, , drop = FALSE]

                if (nrow(combos) == 0) {
                    cat_line(col_yellow(
                        "i Skipping this set size, no combinations left after items forced out together"
                    ))
                    break
                }

            }

        }

        # the group loop above breaks out early; skip this set size if emptied
        if (nrow(combos) == 0) next

        # If greedying ---
        # set size has to be when greedy kicks in
        # can't be the first iteration
        # the set size has to be larger than the number of force inclusions
        if (k[i] >= greedy_begin && i > 1 && k[i] > length(force.in.names)) {

            prev.stats <- reach.list[[i - 1]]

            if (greedy_entry == "shapley") {

                shap.keep <- shap.keep + 1

                item.keep.greedy <-
                    shap.approx %>%
                    # force-ins are already in every combination; skip them
                    filter(!item %in% force.in.names) %>%
                    head(shap.keep) %>%
                    pull(1)

            } else if (is.null(prev.stats)) {

                # previous set size was skipped: nothing to build on, so keep
                # every combination (full search for this size)
                item.keep.greedy <- character()

            } else {

                # NOTE(review): "freq" is rejected by the `greedy_entry`
                # validation above, so this re-sort is currently unreachable
                if (greedy_entry == "freq") {
                    prev.stats <- roworderv(prev.stats, cols = "freq", decreasing = TRUE)
                }

                # best combination (first row) from the previous set size
                item.keep.greedy <-
                    prev.stats %>%
                    head(1) %>%
                    select(starts_with("i_")) %>%
                    unlist() %>%
                    unname()

            }

            keep.combos <- dapply(
                X = combos,
                MARGIN = 1,
                FUN = function(x) all(item.keep.greedy %in% x)
            )

            combos <- combos[keep.combos, , drop = FALSE]

            if (nrow(combos) == 0) {
                cat_line(col_yellow(
                    "i Skipping this set size, no combinations left after greedy entry"
                ))
                next
            }

        }

        # placeholder matrix that receives the reach and freq calcs ---
        n.combos <- nrow(combos)

        fill <- matrix(
            data = NA_real_,
            ncol = 2,
            nrow = n.combos,
            dimnames = list(NULL, c("reach", "freq"))
        )

        # calculate reach ---
        for (j in seq_len(n.combos)) {

            n.reached <- rowsums(item.mat[, combos[j, ], drop = FALSE])
            is.reached <- n.reached >= depth
            reach <- fmean(x = is.reached, w = wgt.vec)

            # missing weights are skipped, matching `wgt.base`
            freq <- sum(wgt.vec * n.reached, na.rm = TRUE) / wgt.base

            fill[j, ] <- c(reach, freq)

        }

        # Join the reach measures to the combinations
        reach.list[[i]] <-
            combos %>%
            as_tibble() %>%
            bind_cols(as_tibble(fill)) %>%
            rowid_to_column("combo") %>%
            roworderv(cols = c("reach", "freq"), decreasing = TRUE) %>% # faster than arrange
            add_column(k = k[i], .before = 1)

    }

    # drop set sizes that were skipped
    reach.list <- discard(reach.list, is.null)


    # Return ------------------------------------------------------------------

    cat_line()

    list(
        reach = reach.list,
        items = item.names,
        options = list(
            k = k,
            depth = depth,
            weighted = has.weights,
            greedy_begin = greedy_begin,
            greedy_entry = greedy_entry
        ),
        constraints = constraints
    )

}

# Resolve a list of tidyselect expressions into a unique character vector of
# column names found in `data` (character(0) when nothing is selected).
.turf_select_names <- function(selections, data) {
    selections %>%
        map(eval_select, data = data) %>%
        flatten_int() %>%
        names() %>%
        unique() %>%
        as.character()
}

# Resolve a list of tidyselect expressions into a list of sorted, de-duplicated
# groups of column names; groups with fewer than two columns are discarded.
.turf_select_groups <- function(selections, data) {
    map(
        .x = selections,
        .f = ~eval_select(expr = .x, data = data)
    ) %>%
        map(names) %>%
        discard(~length(.x) <= 1) %>%
        # this removes accidental duplicates
        map(sort) %>%
        unique()
}

# Keep only names present in `item.names`, warning (naming argument `arg`)
# about any that are dropped.
.turf_keep_known <- function(nms, item.names, arg) {
    bad.names <- setdiff(nms, item.names)
    if (length(bad.names) > 0) {
        bad.names.string <- paste(bad.names, collapse = ", ")
        warn(glue("All items in `{arg}` should appear in `turf_cols`. The following items are missing in `turf_cols` and will be ignored:\n{bad.names.string}"))
        nms <- nms[!nms %in% bad.names]
    }
    nms
}

# Apply `.turf_keep_known()` to each group, then drop groups left with fewer
# than two items and any duplicates created by the pruning.
.turf_keep_known_groups <- function(groups, item.names, arg) {
    groups %>%
        map(.turf_keep_known, item.names = item.names, arg = arg) %>%
        discard(~length(.x) < 2) %>%
        unique()
}

# NOTE: greedy entry is still not working as intended -- this call reproduces
# the problem (Shapley greedy entry from the second set size onwards, with one
# item forced into every combination).
out <- turf2(
    data = FoodSample[-1],
    turf_cols = 1:10,
    k = 1:10,
    greedy_entry = "shapley",
    greedy_begin = 1,
    # alternative: force_in = constraints(Chili, Turkey),
    force_in = constraints(Chicken)
)
# ttrodrigz/onezero documentation built on May 9, 2023, 2:59 p.m.