R/import_lift_method.R

Defines functions import_lift_method

Documented in import_lift_method

#' @title Import lift_method for lift modeling with rpart
#'
#' @description \code{import_lift_method} returns a list of functions that
#' serves as a user-defined splitting method for the \code{rpart} function.
#' See the example below for details on its usage.
#'
#' @return A list with elements \code{eval}, \code{split} and \code{init}.
#' @example examples/RCTree_example.R
#' @details The rpart function accepts a user-defined list in its
#' \code{method} argument; this function returns such a list implementing a
#' causal tree. The response \code{y} must be a two-column binary matrix:
#' the outcome in the first column and the treatment indicator in the second.
#' In addition to the method, the user must supply the \code{parms} argument
#' as a list containing \code{baseline_lift}, the lift in the whole
#' population (required), and optionally \code{alpha}, the significance
#' level of the lift confidence intervals returned in the \code{yval2}
#' object (default 0.05). Observation weights are not used.
#' See the example below for more details.
#' @seealso \code{\link{extract_segments}}
#'
#' @export

import_lift_method <- function() {
  # Response rate positives / cases, vectorised; an empty group (0 / 0)
  # gets rate 0 instead of NaN.
  rate <- function(positives, cases) {
    out <- positives / cases
    out[is.nan(out)] <- 0
    out
  }

  # Treatment response rate minus control response rate, from the counts of
  # treated positives, treated cases, all positives and all cases.
  lift_from_counts <- function(positive_treated, treated, positive, cases) {
    rate(positive_treated, treated) -
      rate(positive - positive_treated, cases - treated)
  }

  # Scores every candidate cut of a node. The *_left arguments hold the
  # counts sent to the left child by each cut; the right child receives the
  # node totals minus those counts. Returns the goodness of each cut and the
  # lift of both children.
  score_cuts <- function(y, baseline_lift, cases_left, treated_left,
                         positive_left, positive_treated_left) {
    n <- nrow(y)
    cases_right <- n - cases_left
    treated_right <- sum(y[, 2]) - treated_left
    positive_right <- sum(y[, 1]) - positive_left
    positive_treated_right <- sum(y[, 1] * y[, 2]) - positive_treated_left

    lift_left <- lift_from_counts(positive_treated_left, treated_left,
                                  positive_left, cases_left)
    lift_right <- lift_from_counts(positive_treated_right, treated_right,
                                   positive_right, cases_right)

    # A cut is worth the larger size-weighted excess of a child's lift over
    # the baseline lift, normalised by the node size.
    excess_left <- (lift_left - baseline_lift) * cases_left
    excess_right <- (lift_right - baseline_lift) * cases_right
    list(
      goodness = pmax(excess_left, excess_right) / n,
      lift_left = lift_left,
      lift_right = lift_right
    )
  }

  list(
    # Node summary: lift with a normal-approximation confidence interval.
    # Deviance is 1 - lift, so higher-lift nodes have lower deviance.
    # y: column 1 is the binary response, column 2 the binary treatment.
    eval = function(y, wt, parms) {
      treated <- y[, 2] == 1
      treatment_cases <- sum(treated)
      control_cases <- nrow(y) - treatment_cases
      or_treatment <- rate(sum(y[treated, 1]), treatment_cases)
      or_control <- rate(sum(y[!treated, 1]), control_cases)
      lift <- or_treatment - or_control

      # Standard error of a difference of two proportions; NaN when either
      # arm is empty (0 / 0).
      lift_sd <- sqrt(or_treatment * (1 - or_treatment) / treatment_cases +
                        or_control * (1 - or_control) / control_cases)
      z <- qnorm(parms$alpha / 2, lower.tail = FALSE)

      list(
        label = c(lift = lift,
                  lift_lower = lift - z * lift_sd,
                  lift_upper = lift + z * lift_sd),
        deviance = 1 - lift
      )
    },
    # Candidate splits of a node. wt is accepted for the rpart interface but
    # is not used.
    split = function(y, wt, x, parms, continuous) {
      if (continuous) {
        # Rows are taken to be ordered by x: cut k sends rows 1..k left.
        cut <- seq_len(nrow(y) - 1)
        scores <- score_cuts(
          y, parms$baseline_lift,
          cases_left = cut,
          treated_left = cumsum(y[, 2])[cut],
          positive_left = cumsum(y[, 1])[cut],
          positive_treated_left = cumsum(y[, 1] * y[, 2])[cut]
        )
        # Sign of (left lift - right lift) for each cut.
        list(goodness = scores$goodness,
             direction = sign(scores$lift_left - scores$lift_right))
      } else {
        # Per-category counts; tapply groups by the sorted unique values of
        # x, so these line up with ux.
        ux <- sort(unique(x))
        cases_x <- tapply(y[, 1], x, length)
        treated_x <- tapply(y[, 2], x, sum)
        positive_x <- tapply(y[, 1], x, sum)
        positive_treated_x <- tapply(y[, 1] * y[, 2], x, sum)

        # Order categories by lift (not |lift|) so that high-lift categories
        # are contiguous and can be separated by a single cut.
        ord <- order(lift_from_counts(positive_treated_x, treated_x,
                                      positive_x, cases_x))
        cut <- seq_len(length(ux) - 1)
        scores <- score_cuts(
          y, parms$baseline_lift,
          cases_left = cumsum(cases_x[ord])[cut],
          treated_left = cumsum(treated_x[ord])[cut],
          positive_left = cumsum(positive_x[ord])[cut],
          positive_treated_left = cumsum(positive_treated_x[ord])[cut]
        )
        list(goodness = scores$goodness, direction = ux[ord])
      }
    },
    # Validates the inputs, fills in the default alpha and returns the setup
    # list rpart expects from a user method.
    init = function(y, offset, parms, wt) {
      # || short-circuits, so ncol() is never evaluated on a non-matrix.
      if (!is.matrix(y) || ncol(y) != 2) {
        stop("y has to be a 2 column matrix")
      }
      if (!all(y %in% c(0, 1))) {
        stop("y input columns must be binary")
      }
      # parms may be missing entirely when the caller did not supply it.
      if (missing(parms) || is.null(parms$baseline_lift)) {
        stop("Must specify baseline_lift in parms argument")
      }
      if (is.null(parms$alpha)) {
        parms$alpha <- 0.05
      }
      if (!missing(offset) && length(offset) > 0) {
        warning("offset argument ignored")
      }
      sfun <- function(yval, dev, wt, ylevel, digits) {
        paste0(" lift=", format(signif(yval, digits)),
               ", deviance=", format(signif(dev, digits)))
      }
      # Presumably detaches sfun from this closure's environment so the
      # stored summary function does not capture it.
      environment(sfun) <- .GlobalEnv
      list(y = y, parms = parms, numresp = 3, numy = 2, summary = sfun)
    }
  )
}
IyarLin/RCTree documentation built on April 13, 2020, 12:37 a.m.