test-code/turf4.R

#' Total unduplicated reach and frequency (TURF) analysis
#'
#' Enumerates combinations of `items` for each requested set size in `k`,
#' filters them by the supplied inclusion/exclusion constraints, and computes
#' the (optionally weighted) reach and frequency of every remaining
#' combination.
#'
#' @param data A data frame.
#' @param items Tidyselect specification of the columns of `data` to analyse.
#'   The selected columns must pass the `is_onezero()` check (0/1 data,
#'   missing values allowed); missing values are treated as 0.
#' @param case_weights Optional tidyselect specification of exactly one numeric
#'   column of `data` holding case weights. It may not also be in `items`.
#'   When missing, every row gets a weight of 1.
#' @param k Numeric vector of set sizes. Values are floored, sorted,
#'   de-duplicated and restricted to the range 1 to the number of `items`.
#' @param depth Single numeric value (floored). A row counts as reached by a
#'   combination when it has at least `depth` of the combination's items.
#' @param force_in Optional tidyselect specification of items that must appear
#'   in every combination.
#' @param force_in_together Optional list of selection expressions (each is
#'   passed to `eval_select()`). For each group of two or more items, a
#'   combination must contain either all of the group or none of it.
#' @param force_out Optional tidyselect specification of items that are
#'   removed before combinations are generated.
#' @param force_out_together Optional list of selection expressions (each is
#'   passed to `eval_select()`). For each group of two or more items, a
#'   combination may not contain all items of the group at once.
#' @param greedy_begin Single numeric value (floored). Set sizes at or above
#'   this value (other than the first one run) only consider combinations that
#'   contain the items chosen by the greedy rule.
#' @param greedy_entry Either `"shapley"` (items enter in order of their
#'   approximate Shapley value) or `"reach"` (the best-reaching combination of
#'   the previous set size is carried forward). `"shapley"` cannot be used
#'   together with `force_out_together` constraints.
#' @param progress Logical; print a progress indicator while calculating?
#'
#' @return A list of class `"turf"` with elements `reach` (a tibble with one
#'   row per completed set size and a list-column of results), `info`,
#'   `constraints` and `clock` (timing information).
turf4 <- function(
        data, items, case_weights,
        k = 1, depth = 1,
        force_in, force_in_together,
        force_out, force_out_together,
        greedy_begin = 20, greedy_entry = "shapley",
        progress = FALSE
) {


    # Preliminary error checks ------------------------------------------------

    # Start total clock
    total.clock1 <- Sys.time()

    if (!is.data.frame(data)) {
        abort("Input to `data` must be a data frame.")
    }

    if (!is.numeric(k)) {
        abort("Input to `k` must be one or more numeric values.")
    }

    if (!is.numeric(depth) || length(depth) != 1) {
        abort("Input to `depth` must be a single numeric value.")
    }
    depth <- floor(depth)

    if (!is.numeric(greedy_begin) || length(greedy_begin) != 1) {
        abort("Input to `greedy_begin` must be a single numeric value.")
    }
    greedy_begin <- floor(greedy_begin)

    # Length is checked first so a vector input gets this message instead of
    # a "condition has length > 1" error
    if (length(greedy_entry) != 1 || !(greedy_entry %in% c("shapley", "reach"))) {
        abort("Input to `greedy_entry` must be one of 'shapley' or 'reach'.")
    }


    # Get names of things -----------------------------------------------------

    # `items`
    item.names <- names(eval_select(expr = enquo(items), data = data))

    # `case_weights`
    has.weights <- FALSE

    if (!missing(case_weights)) {

        case.weights.names <- names(
            eval_select(
                expr = enquo(case_weights),
                data = data
            )
        )

        has.weights <- TRUE

    }


    # `force_in`
    do.force.in <- FALSE
    force.in.names <- character()

    if (!missing(force_in)) {

        force.in.names <- names(
            eval_select(
                expr = enquo(force_in),
                data = data
            )
        )

        do.force.in <- TRUE

    }

    # `force_in_together`
    do.force.in.together <- FALSE
    force.in.together.names <- list()

    if (!missing(force_in_together)) {

        force.in.together.names <-
            map(
                .x = force_in_together,
                .f = ~eval_select(
                    expr = .x,
                    data = data
                )
            ) %>%
            map(names) %>%
            # a "group" of one item (or none) is not a joint constraint
            discard(~length(.x) <= 1) %>%
            # this removes accidental duplicates
            map(sort) %>%
            unique()

        do.force.in.together <- length(force.in.together.names) > 0

    }


    # `force_out`
    do.force.out <- FALSE
    force.out.names <- character()

    if (!missing(force_out)) {

        force.out.names <- names(
            eval_select(
                expr = enquo(force_out),
                data = data
            )
        )

        do.force.out <- TRUE
    }

    # `force_out_together`
    do.force.out.together <- FALSE
    force.out.together.names <- list()

    if (!missing(force_out_together)) {

        force.out.together.names <-
            map(
                .x = force_out_together,
                .f = ~eval_select(
                    expr = .x,
                    data = data
                )
            ) %>%
            map(names) %>%
            # a "group" of one item (or none) is not a joint constraint
            discard(~length(.x) <= 1) %>%
            # this removes accidental duplicates
            map(sort) %>%
            unique()

        do.force.out.together <- length(force.out.together.names) > 0

    }


    # Set up weights ----------------------------------------------------------

    if (missing(case_weights)) {

        wgt.vec <- rep(1, times = nrow(data))

    } else {

        # exactly one column is required; a zero-column selection would
        # otherwise break the `%in%` test below
        if (length(case.weights.names) != 1) {
            abort("Can only provide one column of weights to `case_weights`.")
        }

        if (case.weights.names %in% item.names) {
            abort(glue("Column '{case.weights.names}' cannot be used both in `items` and `case_weights`."))
        }

        wgt.vec <- data[[case.weights.names]]

        if (!is.numeric(wgt.vec)) {
            abort("Input to `case_weights` must be a numeric column.")
        }

    }

    # Use this as the denominator for reach calculations later
    wgt.base <- sum(wgt.vec, na.rm = TRUE)


    # Validate: all `force_in` in `items` -------------------------------------

    if (do.force.in) {

        # any items in `force_in` that are not in `items`?
        bad.names <- setdiff(force.in.names, item.names)
        fail <- length(bad.names) > 0

        if (fail) {

            # warn that mismatches will be ignored
            bad.names.string <- paste(glue("{bad.names}"), collapse = ", ")
            msg <- glue("All items in `force_in` should appear in `items`. The following items are missing in `items` and will be ignored:\n{bad.names.string}")
            warn(msg)

            # update the vector of names
            force.in.names <- force.in.names[!force.in.names %in% bad.names]

            # set it to false if there is nothing left
            if (length(force.in.names) == 0) {
                do.force.in <- FALSE
            }

        }

    }


    # Validate: all `force_out` in `items` ------------------------------------

    if (do.force.out) {

        # any items in `force_out` that are not in `items`?
        bad.names <- setdiff(force.out.names, item.names)
        fail <- length(bad.names) > 0

        if (fail) {

            # warn that mismatches will be ignored
            bad.names.string <- paste(glue("{bad.names}"), collapse = ", ")
            msg <- glue("All items in `force_out` should appear in `items`. The following items are missing in `items` and will be ignored:\n{bad.names.string}")
            warn(msg)

            # update the vector of names
            force.out.names <- force.out.names[!force.out.names %in% bad.names]

            # set it to false if there is nothing left
            if (length(force.out.names) == 0) {
                do.force.out <- FALSE
            }

        }

    }


    # Validate: no overlap between `force_in` and `force_out` -----------------

    if (do.force.in && do.force.out) {

        # any items overlap?
        bad.names <- intersect(force.in.names, force.out.names)
        fail <- length(bad.names) > 0

        if (fail) {

            bad.names.string <- paste(glue("{bad.names}"), collapse = ", ")
            msg <- glue("Cannot have items that appear both in `force_in` and `force_out`. The following variables are in both:\n{bad.names.string}")
            abort(msg)

        }

    }


    # Validate: all `force_in_together` in `items` ----------------------------

    if (do.force.in.together) {

        # index for dropping them if their lengths become 1
        drop.fi <- numeric()

        for (fi in seq_along(force.in.together.names)) {

            bad.names <- setdiff(force.in.together.names[[fi]], item.names)
            fail <- length(bad.names) > 0

            if (fail) {

                bad.names.string <- paste(glue("{bad.names}"), collapse = ", ")
                msg <- glue("All items in `force_in_together` should appear in `items`. The following items are missing in `items` and will be ignored:\n{bad.names.string}")
                warn(msg)

                # drop them
                force.in.together.names[[fi]] <-
                    force.in.together.names[[fi]][!force.in.together.names[[fi]] %in% bad.names]

                if (length(force.in.together.names[[fi]]) < 2) {
                    drop.fi <- c(drop.fi, fi)
                }

            }

        }

        # drop the combos if necessary
        if (length(drop.fi) > 0) force.in.together.names <- force.in.together.names[-drop.fi]

        # reset the "do" things if necessary
        if (length(force.in.together.names) == 0)  do.force.in.together  <- FALSE

    }


    # Validate: all `force_out_together` in `items` ---------------------------

    if (do.force.out.together) {

        # index for dropping them if their lengths become 1
        drop.fo <- numeric()

        for (fo in seq_along(force.out.together.names)) {

            bad.names <- setdiff(force.out.together.names[[fo]], item.names)
            fail <- length(bad.names) > 0

            if (fail) {

                bad.names.string <- paste(glue("{bad.names}"), collapse = ", ")
                msg <- glue("All items in `force_out_together` should appear in `items`. The following items are missing in `items` and will be ignored:\n{bad.names.string}")
                warn(msg)

                # drop them
                force.out.together.names[[fo]] <-
                    force.out.together.names[[fo]][!force.out.together.names[[fo]] %in% bad.names]

                if (length(force.out.together.names[[fo]]) < 2) {
                    drop.fo <- c(drop.fo, fo)
                }

            }

        }

        # drop the combos if necessary
        if (length(drop.fo) > 0) force.out.together.names <- force.out.together.names[-drop.fo]

        # reset the "do" things if necessary
        if (length(force.out.together.names) == 0)  do.force.out.together  <- FALSE

    }


    # Validate: no overlap btw `force_in_together` & `force_out_together` -----

    if (do.force.in.together && do.force.out.together) {

        # index of list items to drop
        drop.fi <- numeric()
        drop.fo <- numeric()

        for (fi in seq_along(force.in.together.names)) {
            for (fo in seq_along(force.out.together.names)) {

                bad.names <- intersect(
                    force.in.together.names[[fi]],
                    force.out.together.names[[fo]]
                )

                # a single shared item is not contradictory; two or more
                # shared items are required together and forbidden together
                fail <- length(bad.names) >= 2

                if (fail) {

                    bad.names.string <- paste(glue("{bad.names}"), collapse = ", ")
                    msg <- glue("The following items appeared both in `force_in_together` and `force_out_together`, constraints with these items will be ignored:\n{bad.names.string}")
                    warn(msg)

                    # update index of list items to drop
                    drop.fi <- c(drop.fi, fi)
                    drop.fo <- c(drop.fo, fo)

                }

            }
        }

        # drop the combos if necessary
        if (length(drop.fi) > 0) force.in.together.names <- force.in.together.names[-drop.fi]
        if (length(drop.fo) > 0) force.out.together.names <- force.out.together.names[-drop.fo]

        # reset the "do" things if necessary
        if (length(force.in.together.names) == 0)  do.force.in.together  <- FALSE
        if (length(force.out.together.names) == 0) do.force.out.together <- FALSE

    }


    # Build constraints list --------------------------------------------------

    constraints <- list(
        "force_in"           = force.in.names,
        "force_in_together"  = force.in.together.names,
        "force_out"          = force.out.names,
        "force_out_together" = force.out.together.names
    )


    # Get the turf data -------------------------------------------------------

    # Just grab the columns to be turf'd
    item.mat <- select(data, all_of(item.names))


    # Validate `k` ------------------------------------------------------------

    # Make sure user is only requesting sets of `k` that make sense.

    # These are the number of items before any forced exclusions take place
    n.items <- ncol(item.mat)

    # Establish final `k` --
    # Setting to floor will force into an integer
    # Need to sort to make sure it cycles thru the set sizes in an order
    #   that makes sense
    # Drop duplicates, no need to run set sizes more than once
    # Subset `k` to be between 1 and the number of items
    k <-
        k %>%
        floor() %>%
        sort() %>%
        unique() %>%
        .[between(., 1, n.items)]

    if (length(k) == 0) {
        abort(glue(
            "Input to `k` must contain integers between 1 and number of columns in `items` ({n.items})."
        ))
    }


    # Validate: greedy entry vs. mutual exclusions ----------------------------

    # Greedy entry is only relevant when some set size reaches `greedy_begin`
    greedy.active <- greedy_begin <= max(k)

    # Shapley entry picks items without regard to the mutual exclusions, so
    # the two cannot be combined. Checked here, once `k` and the constraints
    # are final, so that runs which never go greedy are unaffected.
    if (greedy_entry == "shapley" && greedy.active && do.force.out.together) {
        abort("'shapley' greedy entry cannot be combined with mutual exclusions.")
    }


    # Drop items due to exclusions --------------------------------------------

    # Doing this here because it will save from having to create unnecessary
    # combinations with {arrangements}
    if (do.force.out) {
        item.mat <- select(item.mat, -all_of(force.out.names))
    }


    # Finally prep turf data --------------------------------------------------

    # Rfast::rowsums() needs the data to be in a matrix
    item.mat <- as.matrix(item.mat)

    # this needs to be updated if there are any forced exclusions
    item.names <- colnames(item.mat)
    n.items <- ncol(item.mat)

    # Check and make sure that all of the columns contain only 1/0/NA
    oz.check <- dapply(item.mat, is_onezero)

    oz.fail <- any(!oz.check)

    if (oz.fail) {
        bad.names <- names(oz.check[!oz.check])
        # `collapse` (not `sep`) is what joins a vector into one string
        bad.names.string <- paste(bad.names, collapse = ", ")
        msg <- glue(
            "All variables in `items` must contain only 0/1 data, the following do not:\n{bad.names.string}"
        )
        abort(msg)
    }


    # Check and see if any rows have 100% missing data
    all.miss <-
        dapply(
            X = item.mat,
            FUN = function(x) all(is.na(x)),
            MARGIN = 1
        ) %>%
        which()

    if (length(all.miss) > 0) {
        msg <- glue(
            "{length(all.miss)} rows in `data` have 100% missing values for the items specified in `items`. They will still be retained in the analysis and treated as \"unreached\". If you do not want those rows in the TURF analysis, please remove them ahead of time."
        )
        warn(msg)
    }

    # Replace NA with zero, makes sense since we are operating row-wise
    # for reach. This radically improves the speed of Rfast::rowsums().
    item.mat[is.na(item.mat)] <- 0


    # Initialize objects for storing and displaying results -------------------

    # For reach values
    reach.list <- vector(
        mode = "list",
        length = length(k)
    )

    # how long does it take to make and constrain the combos?
    combo.clock <- vector("double", length = length(k))

    # how long does the turf calculation take?
    turf.clock <- vector("double", length = length(k))

    # Zero-pad the set sizes so the names of the reach list sort correctly
    pad <- max(nchar(k))
    names(reach.list) <- paste0("k_", str_pad(k, width = pad, side = "left", pad = "0"))


    # Shapley approximation ---------------------------------------------------

    # Shapley approximation used for shapley greedy entry
    # Only need if using that method and greedy will actually kick in
    if (greedy_entry == "shapley" && greedy.active) {
        if (missing(case_weights)) {

            shap.approx <-
                shapley_approx(
                    data = data,
                    cols = all_of(item.names),
                    need = 1,
                    tidy = TRUE
                ) %>%
                arrange(desc(shapley_approx)) %>%
                rename(item = 1, sv = 2)

        } else {

            shap.approx <-
                shapley_approx(
                    data = data,
                    cols = all_of(item.names),
                    weight = {{case_weights}},
                    need = 1,
                    tidy = TRUE
                ) %>%
                arrange(desc(shapley_approx))  %>%
                rename(item = 1, sv = 2)
        }
    }


    # Do TURF ----------------------------------------------------------------

    cat_line(rule("TURF", line = 1))

    # number of top Shapley items currently locked in by greedy entry
    shap.keep <- 0

    # Begin iteratin'
    for (i in seq_along(k)) {

        cat_line(glue("Running best of {k[i]}\n"))

        # k cannot exceed number of items - this happens when there are
        # forced exclusions and running all k from 1 to # orig items
        if (k[i] > n.items) {

            cat_line(
                col_yellow(
                    "i Skipping this set size, cannot perform due to forced exclusions"
                )
            )
            next
        }

        # depth cannot exceed k
        if (depth > k[i]) {
            cat_line(
                col_yellow(
                    glue(
                        "i Skipping this set size, `depth` cannot exceed `k`"
                    )
                )
            )
            next
        }


        # start combo clock
        combo.clock1 <- Sys.time()

        # generate full set of combinations
        combos <- combinations(
            x = item.names,
            k = k[i]
        )
        colnames(combos) <- paste0("i_", str_pad(1:k[i], width = 2, side = "left", pad = "0"))

        # reduce by forced inclusions
        if (do.force.in) {

            # index of rows to keep
            keep.combos <- dapply(
                X = combos,
                FUN = function(x) all(force.in.names %in% x),
                MARGIN = 1
            )

            # subset the combos
            combos <- combos[keep.combos, , drop = FALSE]

            # check and make sure rows remain
            if (nrow(combos) == 0) {

                cat_line(
                    col_yellow(
                        "i Skipping this set size, no more combinations after forced inclusions"
                    )
                )

                next

            }

        }

        # reduce by forced exclusions
        if (do.force.out) {

            # index of rows to keep
            # can't have any of the exclusions
            keep.combos <- dapply(
                X = combos,
                FUN = function(x) !any(force.out.names %in% x),
                MARGIN = 1
            )

            # subset the combos
            combos <- combos[keep.combos, , drop = FALSE]

            # check and make sure rows remain
            if (nrow(combos) == 0) {

                cat_line(
                    col_yellow(
                        "i Skipping this set size, no more combinations after forced exclusions"
                    )
                )

                next

            }

        }

        # reduce by forced in together
        if (do.force.in.together) {

            # loop thru the list
            for (j in seq_along(force.in.together.names)) {

                keep.names <- force.in.together.names[[j]]

                # no need to continue if the number of forced mutual
                # inclusions exceeds the current set size
                if (length(keep.names) > k[i]) {
                    next
                }

                # if any are present, all have to be present:
                # keep a combo when it has all of the group or none of it
                keep.combos <- dapply(
                    X = combos,
                    FUN = function(x) {
                        any(x %in% keep.names) == all(keep.names %in% x)
                    },
                    MARGIN = 1
                )

                combos <- combos[keep.combos, , drop = FALSE]

                if (nrow(combos) == 0) {

                    cat_line(
                        col_yellow(
                            "i Skipping this set size, no combinations left after items forced in together"
                        )
                    )

                    break # completely exit this part of the for loop
                }

            }

        }

        # the forced mutual inclusions will break out of its for loop
        # skip over to the next set size if there are no combinations left
        if (nrow(combos) == 0) next


        # reduce by forced out together
        if (do.force.out.together) {

            # loop thru the list
            for (j in seq_along(force.out.together.names)) {

                drop.names <- force.out.together.names[[j]]

                # no need to continue if the number of forced mutual
                # exclusions exceeds the current set size
                if (length(drop.names) > k[i]) {
                    next
                }


                # keep only combos where all of the drop names aren't in
                keep.combos <- dapply(
                    X = combos,
                    MARGIN = 1,
                    FUN = function(x) !all(drop.names %in% x)
                )

                combos <- combos[keep.combos, , drop = FALSE]

                if (nrow(combos) == 0) {

                    cat_line(
                        col_yellow(
                            "i Skipping this set size, no combinations left after items forced out together"
                        )
                    )

                    break # completely exit this part of the for loop
                }

            }

        }

        # the forced mutual exclusions will break out of its for loop
        # skip over to the next set size if there are no combinations left
        if (nrow(combos) == 0) next


        # If greedying
        # set size has to be when greedy kicks in
        # can't be the first iteration
        # the set size has to be larger than the number of force inclusions
        if (k[i] >= greedy_begin && i > 1 && k[i] > length(force.in.names)) {

            # with nothing to carry forward, every combination is kept
            item.keep.greedy <- character()

            # NOTE: "freq" is handled here but is currently rejected by the
            # `greedy_entry` validation at the top of the function
            if (greedy_entry %in% c("reach", "freq")) {

                # Earlier set sizes may have been skipped (their slot is
                # still NULL), so use the most recent one that has results
                done.idx <- which(
                    !map_lgl(reach.list[seq_len(i - 1)], is.null)
                )

                if (length(done.idx) > 0) {

                    # Find the best combo from that earlier set size
                    item.keep.greedy <-
                        reach.list[[max(done.idx)]] %>%
                        roworderv(cols = greedy_entry, decreasing = TRUE) %>%
                        head(1) %>%
                        # item columns are named i_01, i_02, ...
                        select(matches("^i_")) %>%
                        unlist() %>%
                        unname()

                }

            } else if (greedy_entry == "shapley") {

                # one more top-ranked item is locked in each time
                shap.keep <- shap.keep + 1

                item.keep.greedy <-
                    shap.approx %>%
                    # force-ins have to be there already
                    # this skips them
                    filter(!item %in% force.in.names) %>%
                    head(shap.keep) %>%
                    pull(1)

            }

            keep.combos <- dapply(
                X = combos,
                MARGIN = 1,
                FUN = function(x) all(item.keep.greedy %in% x)
            )

            combos <- combos[keep.combos, , drop = FALSE]

            # the greedy items can conflict with the other constraints
            if (nrow(combos) == 0) {

                cat_line(
                    col_yellow(
                        "i Skipping this set size, no combinations left after greedy entry"
                    )
                )

                next

            }

        }


        # number of combinations to evaluate
        n.combos <- nrow(combos)

        # end combo clock
        combo.clock2 <- Sys.time()
        combo.clock[i] <- as.numeric(
            difftime(combo.clock2, combo.clock1, units = "secs")
        )

        # this matrix receives the reach and freq calcs
        fill <- matrix(
            data = NA,
            ncol = 2,
            nrow = n.combos,
            dimnames = list(
                NULL,
                c("reach", "freq")
            )
        )

        # calculate reach
        # It should be noted that time was spent on optimizing this part of
        # the calculation. This for-loop is marginally faster than any of the
        # apply functions I have tried.

        # rows at which the progress indicator is refreshed
        if (progress) prog.seq <- seq(1, n.combos, length.out = 10) %>% floor()


        turf.clock1 <- Sys.time()

        for (j in seq_len(n.combos)) {

            if (progress) {
                if (any(prog.seq == j)) {
                    cat("\r")
                    cat("  Progress:", scales::percent(j/n.combos))
                    cat("\r")
                }
            }

            # this way of indexing via `[` is faster than collapse::ss()
            n.reached <- rowsums(item.mat[, combos[j, ], drop = FALSE])
            is.reached <- n.reached >= depth
            reach <- fmean(x = is.reached, w = wgt.vec)

            # denominator is now calculated outside the loop
            # since it does not change; missing weights are dropped from the
            # numerator just as they are from `wgt.base`
            freq  <- sum(wgt.vec * n.reached, na.rm = TRUE) / wgt.base

            fill[j, ] <- c(reach, freq)

        }

        # Join the reach measures to the combinations
        reach.stats <-
            combos %>%
            as_tibble() %>%
            bind_cols(as_tibble(fill), .) %>%
            rowid_to_column("combo")

        # update the list
        reach.list[[i]] <- reach.stats

        # log the clock
        turf.clock2 <- Sys.time()

        turf.clock[i] <- as.numeric(
            difftime(turf.clock2, turf.clock1, units = "secs")
        )

    }


    # update clocks and reach output
    # skipped set sizes never filled their slot, so drop them everywhere
    k.null       <- map_lgl(reach.list, is.null)
    combo.clock  <- combo.clock[!k.null]
    turf.clock   <- turf.clock[!k.null]
    reach.list   <- reach.list[!k.null]
    total.clock2 <- Sys.time()
    total.clock  <- as.numeric(
        difftime(total.clock2, total.clock1, units = "secs")
    )


    # Organize and return output ----------------------------------------------

    reach.list <-
        reach.list %>%
        enframe(name = "k", value = "reach") %>%
        mutate(k = parse_number(k))


    clock.list <- list(
        total = total.clock,
        by_k = tibble(
            k = reach.list$k,
            n_combos = map_int(reach.list$reach, nrow),
            combo_secs = combo.clock,
            turf_secs = turf.clock,
            combo_per_sec = n_combos / combo_secs,
            turf_per_sec =  n_combos / turf_secs
        )
    )

    info <- list(
        n = wgt.base,
        n_items = length(item.names),
        items = item.names,
        k = k,
        depth = depth,
        case_weights = list(
            weighted = has.weights,
            name = if (has.weights) case.weights.names else NA_character_
        ),
        greedy = list(
            begin = greedy_begin,
            entry = greedy_entry
        ),
        progress = progress
    )

    out <- list(
        reach = reach.list,
        info = info,
        constraints = constraints,
        clock = clock.list
    )

    class(out) <- "turf"

    out

}
ttrodrigz/onezero documentation built on May 9, 2023, 2:59 p.m.