R/aps.R

Defines functions drawAPS computeAPS estimate_aps

Documented in estimate_aps

.datatable.aware=T

#' Main APS estimation function.
#' @import data.table
#' @import stats
#' @param data Dataset containing ML input variables.
#' @param ml ML function for treatment recommendation
#' @param ml_type String indicating the ML object source package name, or "custom" for a user-defined
#' algorithm that runs prediction upon function call with a single input data.table object.
#' See details for supported packages. Defaults to "stats".
#' @param xc Character vector of column names of the continuous variables.
#' @param xd Character vector of column names of the discrete variables.
#' @param infer Boolean whether to infer continuous/discrete variables from remaining columns of data. Defaults to False.
#' @param s Number of draws for each APS estimation. Defaults to 100.
#' @param delta Radius of sampling ball. Can be either numeric or numeric vector. Defaults to 0.8.
#' @param L Named list where the names correspond to the names of the mixed
#'   variables in the data, and the values are numeric vectors indicating the
#'   set of discrete values for the variable.
#' @param seed Random seed
#' @param fcn Function to apply to output of ML function
#' @param parallel Boolean indicator for whether to parallelize the APS
#'   estimation. Defaults to FALSE.
#' @param cores Integer number of cores for parallelization. If NA, then detectCores() is called.
#' @param \dots Additional inputs to be passed to \code{fcn}
#' @return If a single \code{delta} value is passed, then the function returns a vector of
#' estimated Approximate Propensity Scores of the same length as the input data. If multiple
#' \code{delta} are passed, then a list of estimated APS vectors are returned, where the
#' keys are each \code{delta} value.
#'
#' @examples
#' data("iris")
#' # Iris examples
#' model <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width, data=iris)
#' estimate_aps(iris, model, xc = names(iris)[2:3], xd = names(iris)[4],
#'             infer=FALSE, s=50, delta=0.1)
#'
#' # Estimate APS while applying decision function assign_cutoff to model output with input cutoff=0.5
#' assign_cutoff <- function(X, cutoff){
#'   ret <- as.integer(X > cutoff)
#'   return(ret)
#' }
#' estimate_aps(iris, model, xc = names(iris)[2:3], xd = names(iris)[4],
#'              infer=FALSE, s=50, delta=0.1, fcn=assign_cutoff, cutoff=0.5)
#'
#' # Multiple deltas
#' estimate_aps(iris, model, xc = names(iris)[2:3], xd = names(iris)[4],
#'              infer=FALSE, s=50, delta=c(0.1,0.5,1))
#'
#' # Define mixed continuous/discrete variables
#' estimate_aps(iris, model, xc = names(iris)[2:3], xd = names(iris)[4],
#'              infer=FALSE, s=50, delta=0.1,
#'              L = list("Sepal.Width" = c(2, 3), "Petal.Length" = c(3, 4)))
#' @details
#' ML packages currently supported:
#' "mlr3", "caret", "stats" (base), "randomForest", "e1071", "bestridge", "rpart", "tree",
#' "custom" (for user-defined functions)
#'
#' Approximate propensity score estimation involves taking draws \eqn{X_c^1,\ldots,X_c^s}
#' from the uniform distribution on
#' \eqn{N(X_{ci}, \delta)}, where \eqn{N(X_{ci},\delta)} is the
#' \eqn{p_c} dimensional ball centered at \eqn{X_{ci}} with radius \eqn{\delta}.
#' \eqn{X_c^1, \ldots,X_c^s} are destandardized before passed for ML
#' inference. The estimation equation is \eqn{p^s(X_i;\delta) = \frac{1}{s}\sum_{s=1}^{s} ML(X_c^s, X_{di})}.
#'
#' If neither \code{xc} nor \code{xd} are passed and \code{infer=T}, then \code{data} is assumed to
#' be all relevant continuous inputs. If only one is passed, then the remaining
#' variables in \code{data} are assumed to be inputs of the other type. It is
#' recommended to pass both \code{xc} and \code{xd} in the case that not all the
#' variables in \code{data} are relevant for the ML input, instead of relying on inference.
#' @export
estimate_aps <- function(data, ml, ml_type = "stats", xc = NA, xd = NA, infer=FALSE, s = 100, delta = 0.8,
                         L = NA, seed = NA, fcn = NA, parallel = F, cores = NA, ...){
  data.table::setDT(data)
  if (infer == T){
    if (all(is.na(xc)) & all(is.na(xd))){
      warning("(Inference on) Neither `xc` nor `xd` passed. Assuming all variables in `data` are continuous...")
      X_c <- data
      X_d <- NA
    } else if(all(is.na(xc))){
      warning("(Inference on) `xd` passed but `xc` not passed. Assuming all remaining variables in `data` are continuous...")
      xc <- setdiff(names(data), xd)
      X_d <- data[, xd, with = F]
      X_c <- data[, xc, with = F]
      if (length(X_c) < 1){ # If no cts columns then raise error
        stop(paste0("No continuous variables left in input data after removing discrete.",
                    " APS estimation requires at least one continuous variable."))
      }
    } else if(all(is.na(xd))){
      warning("(Inference on) `xc` passed but `xd` not passed. Assuming all remaining variables in `data` are discrete...")
      xd <- setdiff(names(data), xc)
      X_d <- data[, xd, with = F]
      X_c <- data[, xc, with = F]
    } else{
      X_c <- data[, xc, with = F]
      X_d <- data[, xd, with = F]
    }
  } else{
    if (all(is.na(xc))|length(xc) == 0){
      stop(paste0("No continuous variables passed and variable inference is off.",
                  " APS estimation requres at least one continuous variable."))
    }
    X_c <- data[, xc, with = F]
    X_d <- NA
    if (!all(is.na(xd)) & length(xd) > 0){
      X_d <- data[, xd, with = F]
    }
  }

  # Preprocess mixed variables
  mixed_inds <- list()
  mixed_vals <- list()
  if (length(names(L)) != 0){
    for (var in names(L)){
      tmp_inds <- which(X_c[, get(var)] %in% L[[var]])
      tmp_vals <- X_c[tmp_inds, get(var)]
      mixed_inds[[var]] <- tmp_inds
      mixed_vals[[var]] <- tmp_vals
      X_c[tmp_inds, (var) := NA]
    }
  }

  # Standardized continuous variables
  cts_names <- names(X_c)
  mu <- unlist(X_c[, lapply(.SD, function(x) mean(x, na.rm=T))])
  sigma <- unlist(X_c[, lapply(.SD, function(x) sd(x, na.rm=T))])
  X_c[, (cts_names) := lapply(.SD, function(x) (x - mean(x,na.rm=T))/sd(x, na.rm=T)), .SDcols = cts_names]

  # Set seed
  if (!is.na(seed)){
    set.seed(seed, kind = "L'Ecuyer-CMRG")
  }
  aps <- computeAPS(X_c, X_d, ml, ml_type, s, delta, mu, sigma, mixed_inds, mixed_vals, fcn, parallel,
                    cores, ...)
  return(aps)
}

computeAPS <- function(X_c, X_d, ml, ml_type, s, delta, mu, sigma, mixed_inds, mixed_vals, fcn,
                       parallel, cores, ...){
  draws <- drawAPS(X_c, X_d, s, delta, mu, sigma, mixed_inds, mixed_vals, parallel, cores)
  nobs <- nrow(X_c)
  p_c <- length(names(X_c))

  # Run ML prediction depending on type of function
  if (!all(is.na(X_d)) & length(X_d) > 0){
    X_d_long <- X_d[rep(X_d[,.I], each=s)]
  } else{
    X_d_long <- NA
  }
  if (typeof(draws) == "list"){
    aps_list <- list()
    for (d in as.character(delta)){
      # Convert draws and discrete values back to data table
      tmp_draws <- aperm(draws[[d]], c(2,1,3))
      tmp_draws <- data.table::as.data.table(matrix(tmp_draws, s*nobs, p_c))
      data.table::setnames(tmp_draws, names(X_c))
      if (!all(is.na(X_d_long))){
        input <- cbind(tmp_draws, X_d_long)
      } else{
        input <- tmp_draws
      }
      if (ml_type %in% c("mlr3", "caret", "stats", "randomForest", "e1071", "bestridge")){ # All these packages extend 'predict' function
        pred_out <- predict(ml, input)
      } else if (ml_type %in% c("rpart", "tree")){
        pred_out <- predict(ml, input, type="vector")
      } else if (ml_type == "custom"){
          pred_out <- ml(input)
        } else {
        message(paste0("Package ", ml_type, " not found in internal library. Attempting to apply `predict()` syntax..."))
        pred_out <- tryCatch({
          predict(ml, input)
        },
        error=function(cond){
          message(paste0("`Predict` call raised error: ", cond))
          return(NA)
        },
        warning=function(cond){
          message(paste0("Warning: ", cond))
          return(NULL)
        })
      }
      # Optional: additional post-processing function
      if (typeof(fcn) != "logical"){
        pred_out <- fcn(pred_out, ...)
      }
      # Avg every s values of prediction output
      aps <- .colMeans(pred_out, s, length(pred_out)/s, na.rm=T)
      aps_list[[d]] <- aps
    }
    return(aps_list)
  } else{
    # Convert draws back to data table
    draws <- data.table::as.data.table(matrix(aperm(draws, c(2,1,3)), s*nobs, p_c))
    data.table::setnames(draws, names(X_c))
    if (!all(is.na(X_d_long))){
      input <- cbind(draws, X_d_long)
    } else{
      input <- draws
    }
    if (ml_type %in% c("mlr3", "caret", "stats", "randomForest", "e1071", "bestridge")){ # All these packages extend 'predict' function
      pred_out <- predict(ml, input)
    } else if (ml_type %in% c("rpart", "tree")){
      pred_out <- predict(ml, input, type="vector")
    } else if (ml_type == "custom"){
      pred_out <- ml(input)
    } else {
      message(paste0("Package ", ml_type, " not found in internal library. Attempting to apply `predict()` syntax..."))
      pred_out <- tryCatch({
        predict(ml, input)
      },
      error=function(cond){
        message(paste0("Function call raised error: ", cond))
        return(NA)
      },
      warning=function(cond){
        message(paste0("Warning: ", cond))
        return(NULL)
      })
    }
    # Optional: additional post-processing function
    if (typeof(fcn) != "logical"){
      pred_out <- fcn(pred_out, ...)
    }
    # Avg every s values of prediction output
    aps <- .colMeans(pred_out, s, length(pred_out)/s, na.rm=T)
    return(aps)
  }
}

drawAPS <- function(X_c, X_d, s, delta, mu, sigma, mixed_inds, mixed_vals, parallel, cores){
  nobs <- nrow(X_c)
  p_c <- length(X_c)
  draws <- array(rnorm(nobs * p_c * s), dim=c(nobs, s, p_c), dimnames=list(1:nobs, 1:s, names(X_c)))
  u <- array(runif(nobs * s), dim=c(nobs, s))
  multi_delta <- F
  # Set up parallel clusters if toggled
  if (parallel == T){
    if (is.na(cores)){
      cl <- parallel::makeCluster(parallel::detectCores())
    } else{
      cl <- parallel::makeCluster(cores)
    }
    doParallel::registerDoParallel(cl)
    `%dopar%` <- foreach::`%dopar%`
    `%:%` <- foreach::`%:%`
  }
  # Set relevant obs to NA
  if (length(names(mixed_inds)) > 0){
    if (parallel == T){
      old_names <- names(X_c)
      draws <- foreach::foreach(var = names(X_c), .combine = function(x,y) abind::abind(x,y,along=3)) %dopar%{
        temp <- draws[,,var]
        if (var %in% names(mixed_inds)){
          temp[mixed_inds[[var]],] <- NA
        }
        temp
      }
      # Recover old names
      dimnames(draws)[[3]] <- old_names
    } else{
      for (var in names(mixed_inds)){
        draws[mixed_inds[[var]],,var] <- NA
      }
    }
  }

  # Row-wise scale draws
  if (parallel == T){
    # Save dimnames
    old_dim <- dimnames(draws)
    # Nested dopar
    draws <- foreach::foreach(i = dimnames(draws)[[1]], .combine=function(x,y) abind::abind(x,y,along=1)) %:%
      foreach::foreach(j = dimnames(draws)[[3]], .combine=function(x,y) abind::abind(x,y,along=3)) %dopar% {
        tmp <- draws[i,,j]
        sd <- sqrt(sum(tmp^2, na.rm=T))
        tmp <- tmp/sd
        array(tmp, dim=c(1,length(tmp), 1))
      }
    # Recover dimnames
    dimnames(draws) <- old_dim
  } else{
    sd <- apply(draws, c(1,3), function(x) sqrt(sum(x^2, na.rm=T)))
    for (i in 1:dim(sd)[1]){
      for (j in 1:dim(sd)[2]){
        draws[i,,j] <- draws[i,,j]/sd[i,j]
      }
    }
  }

  # Set u -- check for all NA rows
  if (parallel == T){
    u <- foreach::foreach(i = 1:nrow(X_c), .combine=cbind) %dopar% {
      na_count <- sum(is.na(X_c[i,]))
      ct <- p_c - na_count
      if (ct == 0){
        rep(NA, s)
      } else{
        u[i,]^(1/ct)
      }
    }
  } else{
    na_counts <- as.array(apply(X_c, 1, function(x) sum(is.na(x))))
    ct <- p_c - na_counts
    for (i in 1:nrow(u)){
      if (ct[i] == 0){
        u[i,] <- NA
      }
      u[i,] <- u[i,]^(1/ct[i])
    }
  }
  u <- array(rep(u, p_c), dim=c(nobs, s, p_c), dimnames=list(1:nobs, 1:s, names(X_c))) # Increase dimension to be conformable with draws

  X_c <- array(unlist(X_c), dim=c(nobs,p_c), dimnames=list(1:nobs, names(X_c)))
  if (length(delta) > 1){ # Loop through deltas
    draws_list <- list()
    for (d in delta){
      tmp_draws <- draws
      if (parallel == T){
        # Save original colnames
        old_names <- dimnames(X_c)[[2]]
        tmp_draws <- foreach::foreach(j = dimnames(X_c)[[2]], .combine= function(x,y) abind::abind(x,y,along=3)) %dopar%{
          temp <- tmp_draws[,,j] * u[,,j] * delta + X_c[,j]
          temp <- temp * sigma[j] + mu[j]
          temp
        }
        dimnames(tmp_draws)[[3]] <- old_names
      }
      else{
        for (j in dimnames(X_c)[[2]]){
          # tmp_draws from uniform ball centered at standardized xc
          tmp_draws[,,j] <- tmp_draws[,,j] * u[,,j] * d + X_c[,j]
          # tmp_draws from ball centered at original xc
          tmp_draws[,,j] <- tmp_draws[,,j] * sigma[j] + mu[j]
        }
      }
      for (var in names(mixed_inds)){
        tmp_draws[mixed_inds[[var]],,var] <- mixed_vals[[var]]
      }
      draws_list[[as.character(d)]] <- tmp_draws
    }
    if (parallel == T){
      parallel::stopCluster(cl)
    }
    return(draws_list)
  } else{
    if (parallel == T){
      # Save original colnames
      old_names <- dimnames(X_c)[[2]]
      draws <- foreach::foreach(j = old_names, .combine= function(x,y) abind::abind(x,y,along=3)) %dopar%{
        temp <- draws[,,j] * u[,,j] * delta + X_c[,j]
        temp <- temp * sigma[j] + mu[j]
        temp
      }
      dimnames(draws)[[3]] <- old_names
    } else{
      for (j in dimnames(X_c)[[2]]){
        # draws from uniform ball centered at standardized xc
        draws[,,j] <- draws[,,j] * u[,,j] * delta + X_c[,j]
        # draws from ball centered at original xc
        draws[,,j] <- draws[,,j] * sigma[j] + mu[j]
      }
    }
    if (parallel == T){
      parallel::stopCluster(cl)
    }
    # add back discrete values
    for (var in names(mixed_inds)){
      draws[mixed_inds[[var]],,var] <- mixed_vals[[var]]
    }
    return(draws)
  }
}
factoryofthesun/r-IVaps documentation built on Dec. 20, 2021, 7:41 a.m.