## require(twang); data(AOD); mnps.AOD4 <- mnps2(treat ~ illact + crimjust + subprob + subdep + white, data = AOD, estimand = "ATT", stop.method = "es.max", n.trees = 1000, treatATT = "community")
# mnps2: fit a single multinomial gbm model for a treatment factor with at
# least three levels and, for each requested stop method, assemble propensity
# scores, weights and balance descriptions into a list given class "mnps".
#
# NOTE(review): this copy of the function is incomplete. Several lines are
# missing tokens (empty right-hand sides, a `for(` with no loop variable,
# `name =,` arguments with no value) and no closing braces are visible, so
# the text below does not parse as R as it stands. The comments describe only
# what the visible lines show; restore the missing pieces from the upstream
# source before relying on them.
#
# Arguments (as used in the visible code):
#   formula           model formula; its left-hand side is taken as the
#                     treatment variable name and its term labels as the
#                     covariate names.
#   data              data frame; a `sampw` column is appended to it.
#   n.trees, interaction.depth, shrinkage, bag.fraction
#                     passed straight through to gbm().
#   perm.test.iters, print.level, iterlim, ...
#                     not referenced anywhere in the visible body.
#   verbose           passed to gbm() and gates the "Optimized at" message.
#   estimand          compared against "ATE" and "ATT" to pick the branch.
#   stop.method       character vector; each element is checked against
#                     "ks.mean", "es.mean", "ks.max" and "es.max".
#   sampw             sampling weights; defaults to a vector of ones.
#   treatATT          level of the treatment factor used as the treated
#                     condition when estimand is "ATT".
#
# Value: the visible code builds `returnObj` (class "mnps") as its last step;
# NOTE(review): no explicit return or closing brace is visible in this copy.
mnps2 <- function(formula = formula(data), data, n.trees = 10000,
interaction.depth = 3,
shrinkage = 0.01,
bag.fraction = 1.0,
perm.test.iters = 0,
print.level = 2,
iterlim = 1000,
verbose =TRUE,
estimand = "ATE",
stop.method = "es.max",
sampw = NULL,
treatATT = NULL, ...){
stop.method <- levels(as.factor(stop.method)) ## alphabetizes for consistency in ordering of plots
# Keep the plain method names; stop.method itself is overwritten with a list below.
origStopMeth <- stop.method
multinom <- TRUE
# Default to unit sampling weights, then carry them along as a column of data.
if(is.null(sampw)) sampw <- rep(1, nrow(data))
data <- data.frame(data, sampw=sampw)
# NOTE(review): the right-hand side of this assignment is missing in this copy.
terms <-
# all this is just to extract the variable names
# NOTE(review): the next line has lost the expression between `<-` and
# `= FALSE)`; the assignment to mf[[1]] further down is likewise truncated.
mf <- = FALSE)
m <- match(c("formula", "data"), names(mf), 0)
mf <- mf[c(1, m)]
mf[[1]] <-"model.frame")
mf$na.action <- na.pass
# mf$na.action <- NULL
# mf$na.action <- "na.pass"
mf$subset <- rep(FALSE, nrow(data)) # drop all the data
mf <- eval(mf, parent.frame())
Terms <- attr(mf, "terms")
resp <- attr(mf, "response")
# Covariate names come from the formula's term labels; the treatment variable
# name is the formula's left-hand side.
var.names <- attributes(Terms)$term.labels
treat.var <- as.character(formula[[2]])
allowableStopMethods <- c("ks.mean", "es.mean","ks.max", "es.max")
# One multinomial gbm fit is shared by every stop method below.
multFit <- gbm(formula = formula, data = data, weights = sampw, n.trees = n.trees,
interaction.depth = interaction.depth, shrinkage = shrinkage, bag.fraction=bag.fraction,
verbose = verbose, distribution = "multinomial")
nMethod <- length(stop.method)
methodList <- vector(mode="list", length = nMethod)
# Validate each stop method and collect the corresponding objects in methodList.
# NOTE(review): no closing braces are visible for the branches in this loop,
# and the `else` below has no visible matching `if`; the intended nesting
# cannot be determined from this copy.
for(i in 1:nMethod){
if (!(stop.method[i] %in% allowableStopMethods)){
stop("Each element of stop.method must either be one of \nthe above character strings, or an object of the stop.method class.")
# Look up the object named by the method string and tag it with
# "<method>.<estimand>" as its name.
methodList[[i]] <- get(stop.method[i])
methodName <- paste(stop.method[i], ".", estimand, sep="")
methodList[[i]]$name <- methodName
else {
if (!inherits(stop.method[i], "stop.method")){
stop("Each element of stop.method must either be one of \nthe above character strings, or an object of the stop.method class.")
methodList[[i]] <- stop.method[i]
stop.method <- methodList
# w and ps get one row per observation and one column per stop method.
w <- ps <- data.frame(matrix(NA, nrow = nrow(data), ncol = nMethod))
stop.method.names <- sapply(stop.method,function(x){x$name})
names(w) <- names(ps) <- stop.method.names
# 25 rows: one per entry of the 25-point `iters` grid built further down.
balance <- matrix(NA, ncol = nMethod, nrow = 25)
# Pull the treatment column and require a factor with at least three levels.
respAll <- model.frame(formula, data = data, na.action = na.pass)[,1]
if(!is.factor(respAll)) stop("The treatment variable must be a factor variable with at least 3 levels.")
if(length(levels(respAll))<3) stop("The treatment variable must be a factor variable with at least 3 levels.")
respLev <- levels(respAll)
M <- length(respLev)
levExceptTreatATT <- NULL
# For ATT, treatATT must be a treatment level; the remaining levels are kept
# in levExceptTreatATT.
if(estimand == "ATT"){
if(!(treatATT %in% respLev)) stop("'treatATT' must be one of the levels of the treatment variable.")
else levExceptTreatATT <- respLev[respLev != treatATT]
# Template object of class "ps"; one copy per fit is placed in mnpsObj$psList
# and filled in later.
dummyPS <- list(gbm.obj = NA,treat = NULL, treat.var = treat.var,
desc = NULL, ps = NULL, w = NULL, sampw = NULL, estimand = estimand,
datestamp = date(), parameters = NULL, alerts = NULL,
iters = NA, balance = NULL, n.trees = NA, data = NA)
# desc has one slot for the unweighted description plus one per stop method.
desc <- vector(mode = 'list', length = nMethod + 1)
dummyPS$desc <- desc
# ATT yields M-1 entries (one per non-treatATT level); ATE yields M.
nFits <- ifelse(estimand == "ATT", M-1, M)
mnpsObj <- list(psList = vector(mode = 'list', length = nFits), nFits = nFits, estimand = estimand,
treatATT = treatATT, treatLev = respLev,
levExceptTreatATT = levExceptTreatATT, data = data, treatVar = respAll, treat.var = treat.var, stopMethods = stop.method,
sampw = sampw)
dummyPS$ps <- dummyPS$w <- matrix(NA, nrow = nrow(data), ncol = nMethod)
class(dummyPS) <- "ps"
for(i in 1:nFits) mnpsObj$psList[[i]] <- dummyPS
# Loop over stop methods.
# NOTE(review): the loop variable is missing from `for( in ...)` and from the
# empty subscripts (`[]`, `[,]`, `[[]]`, `[[ + 1]]`) throughout the loop body.
for( in 1:nMethod){
tp <- stop.method.names[]
# ES is TRUE for the es.* methods, FALSE otherwise; summaryFcn is max for the
# *.max methods and mean otherwise.
if(tp %in% c("es.max.ATT","es.mean.ATT","es.max.ATE","es.mean.ATE")) ES <- TRUE
else ES <- FALSE
if(tp %in% c("es.max.ATT","ks.max.ATT","es.max.ATE","ks.max.ATE")) summaryFcn <- max
else summaryFcn <- mean
compFcn <- max
# NOTE(review): treatATT is already used in the `%in%` test earlier in the
# function, before this NULL check is reached.
if(estimand == "ATT"){
if(is.null(treatATT)) stop("Must specify the 'treated' condition via the treatATT argument \n
when the estimand is set equal to ATT")
# Evaluate balance on a grid of 25 tree counts from 1 to multFit$n.trees.
iters <- round(seq(1,multFit$n.trees,length=25))
bal <- rep(0, length(iters))
# First grid point: comparison with unit weights (wtVec of ones). `[[1+ES]]`
# takes element 2 of the result when ES is TRUE and element 1 otherwise.
# NOTE(review): `compFcn =` and `summaryFcn =` have no values in this copy.
bal[1] <- pairwiseComp(wtVec = rep(1,nrow(data)), data = data, treat.var = treat.var,
vars = var.names, sampW = sampw, estimand = estimand,
compFcn =, summaryFcn =, treatATT = treatATT)[[1+ES]]
# Remaining grid points: comparison using the gbm fit at iters[j] trees.
for(j in 2:length(iters)){
hldComp <- pairwiseComp(gbm1 = multFit, data = data, treat.var = treat.var,
vars = var.names, sampW = sampw, i = iters[j], estimand = estimand,
compFcn =, summaryFcn =, treatATT = treatATT)
bal[j] <- hldComp[[1+ES]]
#bal[i] <- max(apply(hldComp[[1+ES]], 1,
balance[,] <- bal
# Grid positions on either side of the smallest balance value, clipped to the
# valid index range 1..length(iters).
interval <- which.min(bal) +c(-1,1)
interval[1] <- max(1,interval[1])
interval[2] <- min(length(iters),interval[2])
# NOTE(review): the lines below are arguments of a call whose opening line is
# not visible in this copy; `opt`, used afterwards, is not assigned anywhere in
# the visible code. `fun =`, `compFcn =` and `summaryFcn =` have no values.
maximum = FALSE,
tol = 1,
fun =,
compFcn =,
summaryFcn =,
vars = var.names,
treat.var = treat.var,
data = data,
sampW = sampw,
# na.action = stop.method[[]]$na.action,
na.action = "level",
gbm1 = multFit,
estimand = estimand,
treatATT = treatATT,
onlyES = (ES == 1), onlyKS = (ES == 0))
if(verbose) cat(" Optimized at",round(opt$minimum),"\n")
# Warn when the selected tree count is within 100 of the fitted maximum.
if(multFit$n.trees-opt$minimum < 100) warning("Optimal number of iterations is close to the specified n.trees. n.trees is likely set too small and better balance might be obtainable by setting n.trees to be larger.")
# Fitted probabilities at opt$minimum trees; `[,,1]` keeps the first slice of
# the third dimension of the prediction array.
fittedProbs <- predict(multFit, n.trees = opt$minimum, type = "response")[,,1]
# For each row, pick the fitted probability in the column given by the integer
# code of that row's observed treatment level.
ps[,] <- fittedProbs[cbind(1:nrow(fittedProbs), as.numeric(respAll))]
# Weights come from makeWeightsMNPS; the ATT call also passes the treated
# indicator and the fitted-probability column named treatATT.
if(estimand == "ATT") w[,] <- makeWeightsMNPS(ps[,], estimand = "ATT", sampW = sampw,
treatATT = (respAll == treatATT), treatATTps = fittedProbs[,treatATT])
else w[,] <- makeWeightsMNPS(ps[,], estimand = "ATE")
# ATE: for each of the M levels, recode treatment as a logical indicator for
# that level, compute unweighted and weighted descriptions with desc.wts, and
# store the results in the i-th element of mnpsObj$psList.
if(estimand == "ATE")
for(i in 1:M){
tempDt <- data
tempDt[,treat.var] <- data[,treat.var] == respLev[i]
# `1 + 0 * w[...]` produces a vector of ones shaped like the weights.
unwDesc <- desc.wts(tempDt,c(treat.var, var.names),
treat.var = treat.var,
w = 1 + 0 * w[,],
sampw = sampw,
tp = "Make descriptions",
na.action = stop.method[[]]$na.action,
alerts.stack = NULL,
estimand = estimand,
multinom = TRUE)
unwDesc$n.trees <- NA
hldDesc <- desc.wts(tempDt,c(treat.var, var.names),
treat.var = treat.var,
w = w[,],
sampw = sampw,
tp = "Make descriptions",
na.action = stop.method[[]]$na.action,
alerts.stack = NULL,
estimand = estimand,
multinom = TRUE)
hldDesc$n.trees <- round(opt$minimum)
mnpsObj$psList[[i]]$gbm.obj <- NULL
mnpsObj$psList[[i]]$treat <- tempDt[,treat.var]
mnpsObj$psList[[i]]$treat.var <- treat.var
mnpsObj$psList[[i]]$desc[[1]] <- unwDesc
mnpsObj$psList[[i]]$desc[[ + 1]] <- hldDesc
mnpsObj$psList[[i]]$ps[,] = fittedProbs[, i]
mnpsObj$psList[[i]]$w[,] = w[,]
mnpsObj$psList[[i]]$sampw = sampw
mnpsObj$psList[[i]]$iters = iters
mnpsObj$psList[[i]]$balance = balance
mnpsObj$psList[[i]]$n.trees = round(opt$minimum)
mnpsObj$psList[[i]]$data = NULL
names(mnpsObj$psList[[i]]$desc) <- c("unw", stop.method.names)
# ATT: for each non-treatATT level, restrict to rows in treatATT or that level,
# recode treatment as a logical indicator for treatATT, and store the
# descriptions and subset-sized results in the i-th element of mnpsObj$psList.
if(estimand == "ATT")
for(i in 1:(M-1)){
tempDt <- data
tempDt[,treat.var] <- data[,treat.var] == treatATT
currentSubset <- data[,treat.var] %in% c(treatATT, levExceptTreatATT[i])
tempDt <- subset(tempDt, currentSubset)
unwDesc <- desc.wts(tempDt,c(var.names),
treat.var = treat.var,
w = 1 + 0 * w[currentSubset,],
sampw = sampw[currentSubset],
tp = "Make descriptions",
na.action = stop.method[[]]$na.action,
alerts.stack = NULL,
estimand = estimand,
multinom = TRUE)
unwDesc$n.trees <- NA
hldDesc <- desc.wts(tempDt,c(var.names),
treat.var = treat.var,
w = w[currentSubset,],
sampw = sampw[currentSubset],
tp = "Make descriptions",
na.action = stop.method[[]]$na.action,
alerts.stack = NULL,
estimand = estimand,
multinom = TRUE)
hldDesc$n.trees <- round(opt$minimum)
# Re-allocate ps, w and sampw at the subset's row count.
# NOTE(review): the left operand of `== 1` is missing in this copy.
if( == 1){
mnpsObj$psList[[i]]$ps <- mnpsObj$psList[[i]]$w <- data.frame(matrix(NA, nrow = sum(currentSubset), ncol = nMethod))
names(mnpsObj$psList[[i]]$ps) <- names(mnpsObj$psList[[i]]$w) <- stop.method.names
mnpsObj$psList[[i]]$sampw <- rep(NA, sum(currentSubset))
mnpsObj$psList[[i]]$gbm.obj <- NULL
mnpsObj$psList[[i]]$treat <- tempDt[,treat.var]
mnpsObj$psList[[i]]$treat.var <- treat.var
mnpsObj$psList[[i]]$desc[[1]] <- unwDesc
mnpsObj$psList[[i]]$desc[[ + 1]] <- hldDesc
# NOTE(review): column `i` of fittedProbs is used here while `i` indexes
# levExceptTreatATT; confirm both refer to the same treatment level.
mnpsObj$psList[[i]]$ps[,] <- fittedProbs[currentSubset, i]
mnpsObj$psList[[i]]$w[,] <- w[currentSubset,]
mnpsObj$psList[[i]]$sampw <- sampw[currentSubset]
mnpsObj$psList[[i]]$iters <- iters
mnpsObj$psList[[i]]$balance <- balance
mnpsObj$psList[[i]]$n.trees <- round(opt$minimum)
mnpsObj$psList[[i]]$data <- data
names(mnpsObj$psList[[i]]$desc) <- c("unw", stop.method.names)
# Build "<method>.<estimand>" labels and use them as the column names of each
# psList element's ps data frame.
origStopMethLong <- origStopMeth
for(i in 1:length(origStopMeth)) origStopMethLong[i] <- paste(origStopMeth[i], ".", estimand, sep = "")
for(i in 1:length(mnpsObj$psList)){
mnpsObj$psList[[i]]$ps <- data.frame(mnpsObj$psList[[i]]$ps)
names(mnpsObj$psList[[i]]$ps) <- origStopMethLong
# Final object: restore the plain stop-method names, attach the gbm fit and
# set the class to "mnps".
returnObj <- mnpsObj
returnObj$stopMethods <- origStopMeth
returnObj$gbm.obj <- multFit
class(returnObj) <- "mnps"
#mnps(formula = treat ~ age + educ, data = lalonde2, estimand = "ATT", treatATT = "0")
#ft1 <- mnps(formula = treat ~ age + educ, data = lalonde2, estimand = "ATT", treatATT = "0")
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.