# R/predict.FindIt.R

#' Computing predicted values for each sample in the data.
#'
#' \code{predict.FindIt} takes an output from \code{FindIt} and returns
#' estimated treatment effects when \code{treat.type="single"} and predicted
#' outcomes for each treatment combination when \code{treat.type="multiple"}.
#'
#' Useful for computing estimated treatment effects or predicted outcomes for
#' each treatment combination. By using \code{newdata}, researchers can compute
#' them for any samples.
#'
#' When \code{newdata} is supplied, it must have exactly the same column names
#' (in the same order) as the data stored in \code{object}. The stored data is
#' appended below \code{newdata} before the design matrices are rebuilt, and
#' rows with missing values in any model variable are dropped. For
#' \code{treat.type="single"} only the rows that came from \code{newdata} are
#' returned.
#'
#' @param object An output object from \code{FindIt}.
#' @param newdata An optional data frame in which to look for variables with
#' which to predict. If omitted, the data used in \code{FindIt} is used.
#' @param sort Whether to sort samples according to estimated treatment
#' effects. When \code{treat.type="multiple"} and \code{unique=TRUE} the
#' output is always sorted.
#' @param decreasing When \code{sort=TRUE}, whether to sort the output in
#' descending order or not.
#' @param wts Weights. Only used when \code{newdata} is supplied: the rebuilt
#' design matrices are multiplied by \code{wts^0.5}.
#' @param unique If \code{unique=TRUE}, \code{predict} returns estimated
#' treatment effects or predicted outcomes for unique samples. Only used when
#' \code{treat.type="multiple"}.
#' @param \dots further arguments passed to or from other methods.
#' @return An object of class \code{"PredictFindIt"}, returned invisibly. It is
#' a list with the following elements.
#' \item{treat.type}{The \code{treat.type} stored in \code{object}.}
#' \item{data}{A data frame of estimated treatment effects when
#' \code{treat.type="single"} and predicted outcomes for each treatment
#' combination when \code{treat.type="multiple"}. The first column is
#' \code{Treatment.effect}.}
#' \item{coefs}{The coefficients used to compute the predictions.}
#' \item{orig.coef}{The coefficients before the final rescaling step.}
#' \item{internal}{When \code{treat.type="multiple"}, a list with elements
#' \code{Y_t} and \code{Y_c}; otherwise \code{NULL}.}
#' @author Naoki Egami, Marc Ratkovic and Kosuke Imai.
#' @examples
#'
#' ## See the help page for FindIt() for an example.
#'
#' @export
predict.FindIt <- function(object, newdata, sort = TRUE, decreasing = TRUE,
                           wts = 1, unique = FALSE, ...) {
  treat.type <- object$treat.type
  type <- object$type
  threshold <- object$threshold
  make.twoway <- object$make.twoway
  make.allway <- object$make.allway
  main <- object$main
  nway <- object$nway
  model.treat <- object$model.treat
  if (main) {
    model.main <- object$model.main
    model.int <- object$model.int
  }

  # different == 0: reuse the design stored in `object`.
  # different == 1: rebuild the design from `newdata`.
  different <- if (missing(newdata)) 0 else 1
  if (different == 1) {
    # identical() also rejects a differing number of columns or missing
    # column names, which an element-wise `==` would silently recycle over.
    if (!identical(colnames(newdata), colnames(object$data))) {
      stop(" `newdata` should have the same variables as `data` used in `FindIt`. ")
    }
    # Positions of the caller's rows within the stacked data built below.
    new_id <- seq_len(nrow(newdata))
    newdata <- rbind(newdata, object$data)
  }

  if (different == 0) {
    coefs2 <- object$coefs.orig
    y <- object$y
    treat <- object$treat
    y.orig <- object$y.orig
    treat.orig <- object$treat.orig
    X.t <- object$X.t
    if (main) {
      X.c <- object$X.c
      X.int <- object$X.int
      X.int.orig <- object$X.int.orig
      scale.c <- object$scale.c
      scale.int <- object$scale.int
      X.c.orig <- object$X.c.orig
    }
  }

  if (different == 1) {
    ## ---- Rebuild model frames from the stacked data --------------------
    terms.treat <- terms(model.treat)
    terms.treat2 <- delete.response(terms.treat)
    treat.frame <- model.frame(terms.treat2, data = newdata,
                               na.action = NULL)
    if (main) {
      terms.main <- terms(model.main)
      terms.int <- terms(model.int)
      main.frame <- model.frame(terms.main, data = newdata,
                                na.action = NULL)
      int.frame <- model.frame(terms.int, data = newdata,
                               na.action = NULL)
    }

    # Flag rows that have a missing value in any frame that is in use.
    row.has.na <- function(frame) rowSums(is.na(frame)) > 0
    keep.row <- !row.has.na(treat.frame)
    if (main) {
      keep.row <- keep.row & !row.has.na(main.frame) & !row.has.na(int.frame)
    }
    # After incomplete rows are dropped, the surviving rows of `newdata` are
    # the first sum(keep.row[new_id]) rows of every filtered frame.
    new_id <- seq_len(sum(keep.row[new_id]))

    treat.frame <- treat.frame[keep.row, ]
    treat <- treat.orig <- treat.frame
    if (main == TRUE) {
      X.c <- main.frame[keep.row, ]
      X.int <- int.frame[keep.row, ]
      X.int.orig <- X.int
      if (make.twoway == TRUE) {
        frame.meanPreC <- object$frame.meanC
        XC <- maketwoway(X.c, wts = wts, center = TRUE,
                         frame.meanPre = frame.meanPreC, predict = TRUE)
        X.c <- XC$X
        scale.c <- XC$scale.X
        frame.meanPreInt <- object$frame.meanInt
        XInt <- maketwoway(X.int, wts = wts, center = TRUE,
                           frame.meanPre = frame.meanPreInt, predict = TRUE)
        X.int <- XInt$X
        scale.int <- XInt$scale.X
        X.c <- as.matrix(X.c)
        X.int <- as.matrix(X.int)
      }
      # NOTE(review): `scale.c` (and `scale.int` below) are only assigned in
      # the make.twoway branch above, so this path fails with
      # "object not found" when make.twoway is FALSE -- confirm the intended
      # fallback. X.c.m is a copy of X.c here, so the comparison is TRUE for
      # every named column.
      X.c.m <- X.c
      scale.c <- scale.c[colnames(X.c.m) == colnames(X.c)]
      if (all(X.c[, 1] == 1) == FALSE) {
        X.c <- cbind(1, X.c)
        colnames(X.c)[1] <- "Intercept"
      }
      X.c <- X.c * wts^0.5
      if (treat.type == "single") {
        X.int.m <- X.int
        scale.int <- scale.int[colnames(X.int.m) == colnames(X.int)]
        if (all(X.int[, 1] == 1) == FALSE) {
          X.int <- cbind(1, X.int)
          colnames(X.int)[1] <- "Intercept"
        }
        # The treatment column is carried at 10000x scale; its coefficient
        # is multiplied by 1e-04 further down.
        X.t <- cbind(treat * 10000, (treat > 0) * X.int[, -1])
        # Centre the interaction columns within the treated rows.
        X.t[treat != 0, -1] <- apply(X.t[treat != 0, -1], 2,
                                     FUN = function(x) x - mean(x))
        colnames(X.t) <- c("treat", paste("treat", colnames(X.int)[-1],
                                          sep = ":"))
      }
    }
    if (treat.type == "multiple") {
      if (make.allway == TRUE) {
        treat <- as.matrix(treat)
        Allway <- makeallway(treat, threshold, make.reference = FALSE,
                             nway = nway)
        X.t <- Allway$FinalData
        X.t <- as.matrix(X.t)
        reference.main <- Allway$reference
      }
      if (make.allway == FALSE) {
        X.t <- as.matrix(treat)
        reference.main <- "No Reference"
      }
    }
    X.t <- X.t * wts^0.5

    ## ---- Align stored coefficients with the rebuilt design columns -----
    # Names are compared after replacing ":" and " " with ".". A coefficient
    # whose name matches the design column at the same position is taken
    # directly; otherwise the non-matching positions are searched by name.
    # The search set is built with which(!equal) rather than -which(equal),
    # because a negative empty index selects nothing and would skip the
    # search entirely when no position matches.
    if (treat.type == "single") {
      coefs.orig <- object$coefs.orig
      names(coefs.orig) <- c(object$name.c, "treat", object$name.int[-1])
      coefs.orig.c <- coefs.orig[1:ncol(object$X.c)]
      coefs.orig.int <- coefs.orig[-c(1:ncol(object$X.c))]
      coefs.orig.c <- coefs.orig.c[is.element(object$name.c,
                                              colnames(X.c))]
      coefs.orig.int <- coefs.orig.int[is.element(object$name.int,
                                                  colnames(X.int))]

      col2.int <- colnames(X.int)
      col2.int <- gsub(":", ".", col2.int)
      col2.int <- gsub(" ", ".", col2.int)
      col2.int[1] <- "treat"
      coef.int.name <- names(coefs.orig.int)
      coef.int.name <- gsub(":", ".", coef.int.name)
      coef.int.name <- gsub(" ", ".", coef.int.name)
      equal.int <- (col2.int == coef.int.name)
      search.int <- which(!equal.int)
      coefs.orig.int.new <- c()
      for (i in seq_along(coefs.orig.int)) {
        if (equal.int[i] == TRUE) {
          coefs.orig.int.new[i] <- coefs.orig.int[i]
        }
        else {
          for (j in search.int) {
            if (col2.int[i] == coef.int.name[j]) {
              coefs.orig.int.new[i] <- coefs.orig.int[j]
            }
          }
        }
      }
      names(coefs.orig.int.new) <- colnames(X.int)
      names(coefs.orig.int.new)[1] <- "treat"

      col2.c <- colnames(X.c)
      col2.c <- gsub(":", ".", col2.c)
      col2.c <- gsub(" ", ".", col2.c)
      coef.c.name <- names(coefs.orig.c)
      coef.c.name <- gsub(":", ".", coef.c.name)
      coef.c.name <- gsub(" ", ".", coef.c.name)
      equal.c <- (col2.c == coef.c.name)
      search.c <- which(!equal.c)
      coefs.orig.c.new <- c()
      for (i in seq_along(coefs.orig.c)) {
        if (equal.c[i] == TRUE) {
          coefs.orig.c.new[i] <- coefs.orig.c[i]
        }
        else {
          for (j in search.c) {
            if (col2.c[i] == coef.c.name[j]) {
              coefs.orig.c.new[i] <- coefs.orig.c[j]
            }
          }
        }
      }
      names(coefs.orig.c.new) <- colnames(X.c)
      coefs2 <- c(coefs.orig.c.new, coefs.orig.int.new)
    }
    if (treat.type == "multiple") {
      coefs.orig <- object$coefs.orig
      if (main == FALSE) {
        names(coefs.orig) <- c("Intercept", object$name.t)
        coefs.orig.t <- coefs.orig[-1]
        coefs.orig.t <- coefs.orig.t[is.element(object$name.t,
                                                colnames(X.t))]
        col2 <- colnames(X.t)
        col2 <- gsub(":", ".", col2)
        col2 <- gsub(" ", ".", col2)
        coef.t.name <- names(coefs.orig.t)
        coef.t.name <- gsub(":", ".", coef.t.name)
        coef.t.name <- gsub(" ", ".", coef.t.name)
        col2 <- col2[is.element(col2, coef.t.name)]
        equal <- (col2 == coef.t.name)
        search <- which(!equal)
        coefs.orig.new <- c()
        name.orig.new <- c()
        for (i in seq_along(coefs.orig.t)) {
          if (equal[i] == TRUE) {
            coefs.orig.new[i] <- coefs.orig.t[i]
            name.orig.new[i] <- names(coefs.orig.t)[i]
          }
          else {
            for (j in search) {
              if (col2[i] == coef.t.name[j]) {
                coefs.orig.new[i] <- coefs.orig.t[j]
                name.orig.new[i] <- names(coefs.orig.t)[j]
              }
            }
          }
        }
        coefs2s <- c(coefs.orig[1], coefs.orig.new)
        coefs.t2 <- coefs.orig.new
        names(coefs.t2) <- name.orig.new
        scale.out <- c(rep(1, length(coefs2s)))
        coefs2 <- coefs2s * scale.out
        names(coefs2) <- c("Intercept", name.orig.new)
      }
      if (main) {
        names(coefs.orig) <- c(object$name.c, object$name.t)
        coefs.orig.c <- coefs.orig[1:ncol(object$X.c)]
        coefs.orig.t <- coefs.orig[-c(1:ncol(object$X.c))]
        coefs.orig.c <- coefs.orig.c[is.element(object$name.c,
                                                colnames(X.c))]
        coefs.orig.t <- coefs.orig.t[is.element(object$name.t,
                                                colnames(X.t))]

        col2.t <- colnames(X.t)
        col2.t <- gsub(":", ".", col2.t)
        col2.t <- gsub(" ", ".", col2.t)
        coef.t.name <- names(coefs.orig.t)
        coef.t.name <- gsub(":", ".", coef.t.name)
        coef.t.name <- gsub(" ", ".", coef.t.name)
        col2.t <- col2.t[is.element(col2.t, coef.t.name)]
        equal.t <- (col2.t == coef.t.name)
        search.t <- which(!equal.t)
        coefs.orig.t.new <- c()
        name.orig.new <- c()
        for (i in seq_along(coefs.orig.t)) {
          if (equal.t[i] == TRUE) {
            coefs.orig.t.new[i] <- coefs.orig.t[i]
            name.orig.new[i] <- names(coefs.orig.t)[i]
          }
          else {
            for (j in search.t) {
              if (col2.t[i] == coef.t.name[j]) {
                coefs.orig.t.new[i] <- coefs.orig.t[j]
                name.orig.new[i] <- names(coefs.orig.t)[j]
              }
            }
          }
        }
        names(coefs.orig.t.new) <- name.orig.new

        col2.c <- colnames(X.c)
        col2.c <- gsub(":", ".", col2.c)
        col2.c <- gsub(" ", ".", col2.c)
        coef.c.name <- names(coefs.orig.c)
        coef.c.name <- gsub(":", ".", coef.c.name)
        coef.c.name <- gsub(" ", ".", coef.c.name)
        equal.c <- (col2.c == coef.c.name)
        search.c <- which(!equal.c)
        coefs.orig.c.new <- c()
        for (i in seq_along(coefs.orig.c)) {
          if (equal.c[i] == TRUE) {
            coefs.orig.c.new[i] <- coefs.orig.c[i]
          }
          else {
            for (j in search.c) {
              if (col2.c[i] == coef.c.name[j]) {
                coefs.orig.c.new[i] <- coefs.orig.c[j]
              }
            }
          }
        }
        names(coefs.orig.c.new) <- colnames(X.c)
        coefs2s <- c(coefs.orig.c.new, coefs.orig.t.new)
        scale.out <- c(1, scale.c, rep(1, dim(X.t)[2]))
        coefs2 <- coefs2s * scale.out
        coefs.t2 <- coefs.orig.t.new
        names(coefs.t2) <- name.orig.new
      }
      # Keep only treatment columns that have a coefficient. drop = FALSE
      # keeps X.t a matrix when a single column remains; nrow(X.t) is used
      # further down.
      X.t <- X.t[, is.element(colnames(X.t), names(coefs.t2)), drop = FALSE]
    }
  }

  ## ---- Warn about design columns that the fitted model did not have ----
  if (main) {
    if (any(is.element(colnames(X.c), object$name.c) == FALSE)) {
      warning("The new data set has more variation\n                         than the original model.")
    }
    if (any(is.element(colnames(X.int), object$name.int) ==
            FALSE)) {
      warning("The new data set has more variation\n                     than the original model.")
    }
  }
  if (any(is.element(colnames(X.t), object$name.t) == FALSE)) {
    warning("The new data set has more variation than the original model.")
  }

  ## ---- Single treatment: effect = prediction(treated) - prediction(control)
  if (treat.type == "single") {
    scale.out <- c(1, scale.c, 1e-04, scale.int)
    coefs <- coefs2 * scale.out
    preds.treat <- cbind(X.c, cbind(10000, X.int[, -1])) %*%
      coefs
    preds.control <- cbind(X.c, 0 * cbind(10000, X.int[, -1])) %*% coefs
    preds <- cbind(X.c, X.t) %*% coefs
    if (type == "binary") {
      # Predictions are clipped to [-1, 1] and the difference is halved.
      preds.treat <- sign(preds.treat) * pmin(abs(preds.treat),
                                              1)
      preds.control <- sign(preds.control) * pmin(abs(preds.control),
                                                  1)
      preds <- sign(preds) * pmin(abs(preds), 1)
      ATE <- mean(preds.treat - preds.control)/2
      preds.diff <- (preds.treat - preds.control)/2
    }
    if (type == "continuous") {
      preds.diff <- preds.treat - preds.control
      ATE <- mean(preds.treat - preds.control)
    }
    if (different == 0) {
      pred.data <- cbind(preds.diff, y.orig, treat.orig,
                         X.int.orig)
      pred.data <- as.data.frame(pred.data)
      colnames(pred.data) <- c("Treatment.effect", "outcome",
                               "treatment", colnames(X.int.orig))
    }
    if (different == 1) {
      pred.data <- cbind(preds.diff, treat.orig, X.int.orig)
      pred.data <- as.data.frame(pred.data)
      colnames(pred.data) <- c("Treatment.effect", "treatment",
                               colnames(X.int.orig))

      # Return only the rows that came from the caller's `newdata`.
      pred.data <- pred.data[new_id, , drop = FALSE]
    }
    if (sort == TRUE) {
      pred.data.out <- pred.data[order(pred.data$Treatment.effect,
                                       decreasing = decreasing), ]
    }
    else {
      pred.data.out <- pred.data
    }
  }

  ## ---- Multiple treatments ---------------------------------------------
  # NOTE(review): unlike the single-treatment branch, this branch does not
  # subset to `new_id`, so with `newdata` the output also covers the appended
  # training rows -- confirm whether that is intended.
  if (treat.type == "multiple") {
    if (different == 0) {
      if (main) {
        scale.out <- c(1, scale.c, rep(1, dim(X.t)[2]))
        coefs <- coefs2 * scale.out
      }
      if (main == FALSE) {
        scale.out <- c(1, rep(1, ncol(X.t)))
        coefs <- coefs2 * scale.out
      }
    }
    else {
      # With newdata, coefs2 was already rescaled during alignment above.
      coefs <- coefs2
    }
    if (type == "binary") {
      if (main) {
        X.t1 <- as.matrix(X.t)
        preds.treat <- cbind(X.c, X.t1) %*% coefs
        preds.control <- X.c %*% coefs[c(1:ncol(X.c))]
        preds.treat <- sign(preds.treat) * pmin(abs(preds.treat),
                                                1)
        preds.control <- sign(preds.control) * pmin(abs(preds.control),
                                                    1)
        preds.diff <- (preds.treat - preds.control)/2

        # Map the clipped [-1, 1] predictions onto [0, 1].
        Y_t <- (preds.treat/2 + 0.5)
        Y_c <- (preds.control/2 + 0.5)
      }
      if (main == FALSE) {
        preds.treat <- cbind(1, X.t) %*% coefs
        preds.treat <- sign(preds.treat) * pmin(abs(preds.treat), 1)

        preds.control <- rep(1, nrow(X.t)) %*% coefs[1]
        preds.control <- sign(preds.control) * pmin(abs(preds.control), 1)

        preds.diff <- (preds.treat - preds.control)/2

        Y_t <- (preds.treat/2 + 0.5)
        Y_c <- (preds.control/2 + 0.5)
      }
      ATE <- mean(preds.diff)
    }
    if (type == "continuous") {
      if (main) {
        preds.diff <- X.t %*% coefs[-c(1:(dim(X.c)[2]))]

        Y_t <- cbind(X.c, X.t) %*% coefs
        Y_c <- X.c %*% coefs[c(1:ncol(X.c))]
      }
      if (main == FALSE) {
        preds.diff <- X.t %*% coefs[-1]

        Y_t <- cbind(1, X.t) %*% coefs
        Y_c <- rep(1, nrow(X.t)) %*% coefs[1]
      }
      ATE <- mean(preds.diff)
    }
    if (unique == FALSE) {
      pred.data <- cbind(preds.diff, treat.orig)
      colnames(pred.data) <- c("Treatment.effect", colnames(treat.orig))
      pred.data <- as.data.frame(pred.data)
      if (sort == TRUE) {
        pred.data.out <- pred.data[order(pred.data$Treatment.effect,
                                         decreasing = decreasing), ]
      }
      else {
        pred.data.out <- pred.data
      }
    }
    if (unique == TRUE) {
      # One row per distinct treatment design row; `sort` is not consulted
      # here, the result is always ordered by Treatment.effect.
      X.t <- as.matrix(X.t)
      X.t.u <- X.t
      preds.diff.u <- as.data.frame(preds.diff)
      rownames(preds.diff.u) <- rownames(X.t.u) <- seq(1:nrow(X.t.u))
      treat.unique <- unique(treat.orig, MARGIN = 1)
      X.t.unique <- unique(X.t.u, MARGIN = 1)
      X.t.unique2 <- unique(X.t, MARGIN = 1)
      preds.diff <- preds.diff.u[rownames(preds.diff.u) %in%
                                   rownames(X.t.unique), ]
      preds.diff <- as.data.frame(preds.diff)
      rownames(X.t.unique) <- rownames(X.t.unique2)
      rownames(preds.diff) <- rownames(X.t.unique2)
      pred.data <- cbind(preds.diff, treat.unique)
      colnames(pred.data) <- c("Treatment.effect", colnames(treat.orig))
      pred.data <- as.data.frame(pred.data)
      pred.data.out <- pred.data[order(pred.data$Treatment.effect,
                                       decreasing = decreasing), ]
    }
  }

  pred.data.out <- as.data.frame(pred.data.out)
  if (treat.type == "multiple") {
    internal <- list("Y_t" = Y_t, "Y_c" = Y_c)
  } else {
    internal <- NULL
  }
  # ATE is computed above but is currently not part of the returned object.
  out <- list(treat.type = treat.type, # ATE = ATE,
              data = pred.data.out,
              coefs = coefs, orig.coef = coefs2, internal = internal)
  class(out) <- "PredictFindIt"
  invisible(out)
}
# kosukeimai/FindIt documentation built on March 19, 2024, 11:25 a.m.