R/mnps.fast.R

Defines functions mnps.fast

mnps.fast <- function(formula,
                      data,
                      # boosting options -- gbm version first, then xgboost name
                      n.trees=10000, nrounds=n.trees,
                      interaction.depth=3, max_depth=interaction.depth,
                      shrinkage=0.01, eta=shrinkage , 
                      bag.fraction = 1.0, subsample=bag.fraction,
                      params=NULL,
                      perm.test.iters=0,
                      print.level=2,       
                      verbose=TRUE,
                      estimand="ATE", 
                      stop.method = c("es.max"), 
                      sampw = NULL, 
                      version="gbm",
                      ks.exact=NULL,
                      tree_method="hist",
                      n.keep = 1,
                      n.grid = 25,
                      treatATT = NULL,
                      ...){
   
   stop.method <- levels(as.factor(stop.method))  ## alphbetizes for consitency in ordering of plots
   multinom <- TRUE
   
   if(is.null(sampw)) sampw <- rep(1, nrow(data))
   
   
   terms <- match.call()
   
   # all this is just to extract the variable names
   mf <- match.call(expand.dots = FALSE)
   
   
   m <- match(c("formula", "data"), names(mf), 0)
   mf <- mf[c(1, m)]
   mf[[1]] <- as.name("model.frame")
   mf$na.action <- na.pass
   #   mf$na.action <- NULL   
   #   mf$na.action <- "na.pass"
   mf$subset <- rep(FALSE, nrow(data)) # drop all the data
   
   mf <- eval(mf, parent.frame())
   Terms <- attr(mf, "terms")
   resp <- attr(mf, "response")
   var.names <- attributes(Terms)$term.labels
   treat.var <- as.character(formula[[2]])
   
   if(estimand == "ATT"){
      if(is.null(treatATT)) stop("Must specify the 'treated' condition via the treatATT argument \n
                                 when the estimand is set equal to ATT")
   }
   
   
   
   respAll <- model.frame(formula, data = data, na.action = na.pass)[,1]
   
   if(!is.factor(respAll)) stop("The treatment variable must be a factor variable with at least 3 levels.")
   
   respLev <- levels(respAll)
   M <- length(respLev)
   if(M < 3) stop("The treatment variable must be a factor variable with at least 3 levels.")
   
   levExceptTreatATT <- NULL
   
   if(estimand == "ATT"){
      if(!(treatATT %in% respLev)) stop("'treatATT' must be one of the levels of the treatment variable.")
      else levExceptTreatATT <- respLev[respLev != treatATT]
   }
   
   
   
   if(estimand == "ATE"){
      nFits <- M
      
      if(length(nrounds) == 1) nrounds <- rep(nrounds, nFits)
      
      hldFts <- vector(mode = "list", length = nFits)
      
      for(i in 1:nFits){
         ## fit GBMs, etc
         currResp <- as.numeric(respAll == respLev[i])
         currDat <- data.frame(currResp = currResp, data)
         currFormula <- update(formula, currResp ~ .)
         currPs <- ps.fast(formula = currFormula, data = currDat, n.trees = nrounds[i], interaction.depth = max_depth,
                      shrinkage = eta, bag.fraction = subsample, perm.test.iters = perm.test.iters, print.level = print.level, 
                      #iterlim = iterlim,
                      verbose = verbose, estimand = "ATE", stop.method = stop.method, sampw = sampw, multinom = TRUE, 
                      ks.exact=ks.exact,
                      version=version,
                      tree_method=tree_method,
                      n.keep = n.keep,
                      n.grid = n.grid,
                      keep.data = FALSE
                      )
      
         hldFts[[i]] <- currPs
         
      }
      names(hldFts) <- respLev	
   }
   
   if(estimand == "ATT"){
      nFits <- M - 1
      if(length(nrounds) == 1) nrounds <- rep(nrounds, nFits)
      
      hldFts <- vector(mode = "list", length = nFits)		
      for(i in 1:nFits){
         ## subset data, do 
         currDat <- data[respAll == treatATT | respAll == levExceptTreatATT[i], ]
         currResp <- respAll[respAll == treatATT | respAll == levExceptTreatATT[i]]
         sampwCurr <- sampw[respAll == treatATT | respAll == levExceptTreatATT[i]]
         #			currResp <- currResp == levExceptTreatATT[i]
         currResp <- currResp == treatATT			
         currDat <- data.frame(currResp = currResp, currDat)
         currFormula <- update(formula, currResp ~ .)
         currPs <- ps.fast(formula = currFormula, data = currDat, n.trees = nrounds[i], interaction.depth = max_depth,
                      shrinkage = eta, bag.fraction = subsample, perm.test.iters = perm.test.iters, print.level = print.level, 
                      #iterlim = iterlim,
                      verbose = verbose, estimand = "ATT", stop.method = stop.method, sampw = sampwCurr, multinom = TRUE,
                      ks.exact=ks.exact,
                      version=version,
                      tree_method=tree_method,
                      n.keep = n.keep,
                      n.grid = n.grid,
                      keep.data = FALSE
                      )
         
         hldFts[[i]] <- currPs
         
         
      }
      names(hldFts) <- levExceptTreatATT
   }
   
   returnObj <- list(psList = hldFts, nFits = nFits, estimand = estimand, treatATT = treatATT, treatLev = respLev, levExceptTreatATT = levExceptTreatATT, data = data, treatVar = respAll, treat.var = treat.var, stopMethods = stop.method, sampw = sampw , balanceVars=var.names, ks.exact=ks.exact)
   
   class(returnObj) <- "mnps"
   
   
   return(returnObj)
}

Try the twang package in your browser

Any scripts or data that you put into this service are public.

twang documentation built on May 29, 2024, 4:40 a.m.