#User-defined weighting function

# Runs a user-supplied weighting function `Fun` and validates what it returns.
#
# `covs`, `treat`, `s.weights` and `ps` are restricted to `subset` (when
# non-empty) before being offered to `Fun`. Each formal argument of `Fun` is
# then filled by name: first from an object of that name in this function's
# own frame, otherwise from `...`; arguments found in neither place keep the
# entry taken from `formals(Fun)`. When `Fun` itself has `...`, the entries of
# `...` that were not consumed this way are forwarded as well.
#
# `Fun` must return a numeric vector of weights, or a list with an entry
# named "w" or "weights". The result is a list with the weights stored under
# "w"; an error is raised if the weights are absent, are not a numeric
# vector, are all NA, or differ in length from the subsetted `treat`.
weightit2user <- function(Fun, covs, treat, s.weights, subset, estimand, focal, stabilize, subclass, missing, ps, moments, int, ...) {
  A <- list(...)

  if (is_not_null(covs)) {
    covs <- covs[subset, , drop = FALSE]
  }
  if (is_not_null(treat)) {
    treat <- treat[subset]
  }
  if (is_not_null(s.weights)) {
    s.weights <- s.weights[subset]
  }
  if (is_not_null(ps)) {
    ps <- ps[subset]
  }

  #Get a list of function args for the user-defined function Fun
  Fun_formal <- as.list(formals(Fun))
  if (has_dots <- ("..." %in% names(Fun_formal))) {
    Fun_formal[["..."]] <- NULL
  }

  fun_args <- Fun_formal
  #NOTE: exists()/get0() look at every local of this frame, not only at the
  #arguments above, so the names of the locals here are part of the behavior.
  for (i in names(fun_args)) {
    if (exists(i, inherits = FALSE)) fun_args[i] <- list(get0(i, inherits = FALSE))
    else if (i %in% names(A)) {
      fun_args[i] <- A[i]
      A[[i]] <- NULL
    }
    #else just use Fun default
  }

  if (has_dots) fun_args <- c(fun_args, A)

  obj <- do.call(Fun, fun_args)

  if (is.numeric(obj)) {
    #A bare numeric vector is taken to be the weights
    obj <- list(w = obj)
  }
  else if (!is.list(obj) || !any(c("w", "weights") %in% names(obj))) {
    #Reject anything that is not a list carrying at least one of the two names
    stop("The output of the user-provided function must be a list with an entry named \"w\" or \"weights\" containing the estimated weights.", call. = FALSE)
  }
  else {
    names(obj)[names(obj) == "weights"] <- "w"
  }

  if (is_null(obj[["w"]])) stop("No weights were estimated.", call. = FALSE)
  if (!is.vector(obj[["w"]], mode = "numeric")) stop("The \"w\" or \"weights\" entry of the output of the user-provided function must be a numeric vector of weights.", call. = FALSE)
  #`call.` (with the dot) is the argument of stop(); `call = FALSE` would be
  #pasted into the message instead.
  if (all(is.na(obj[["w"]]))) stop("All weights were generated as NA.", call. = FALSE)
  if (length(obj[["w"]]) != length(treat)) {
    stop(paste(length(obj[["w"]]), "weights were estimated, but there are", length(treat), "units."), call. = FALSE)
  }

  return(obj)
}
# Longitudinal (MSM) counterpart of the user-defined weighting function: runs
# a user-supplied `Fun` on per-time-point lists and validates what it returns.
#
# Every element of `covs.list` and `treat.list` is restricted to `subset`
# (when the list is non-empty); the subsetted lists are also stored in the
# locals `covs` and `treat`, so a `Fun` whose arguments use either naming
# receives them. Each formal argument of `Fun` is filled by name: first from
# an object of that name in this function's own frame, otherwise from `...`;
# arguments found in neither place keep the entry taken from `formals(Fun)`.
# When `Fun` itself has `...`, the unconsumed entries of `...` are forwarded.
#
# `Fun` must return a numeric vector of weights, or a list with an entry
# named "w" or "weights". The result is a list with the weights stored under
# "w"; an error is raised if the weights are absent, are not a numeric
# vector, are all NA, or differ in length from the first element of the
# subsetted `treat.list`.
weightitMSM2user <- function(Fun, covs.list, treat.list, s.weights, subset, stabilize, missing, moments, int, ...) {
  A <- list(...)

  if (is_not_null(covs.list)) {
    covs.list <- covs <- lapply(covs.list, function(c) c[subset, , drop = FALSE])
  }
  if (is_not_null(treat.list)) {
    treat.list <- treat <- lapply(treat.list, function(t) t[subset])
  }
  if (is_not_null(s.weights)) {
    s.weights <- s.weights[subset]
  }

  #Get a list of function args for the user-defined function Fun
  Fun_formal <- as.list(formals(Fun))
  if (has_dots <- any(names(Fun_formal) == "...")) {
    Fun_formal[["..."]] <- NULL
  }

  fun_args <- Fun_formal
  #NOTE: exists()/get0() look at every local of this frame, not only at the
  #arguments above, so the names of the locals here are part of the behavior.
  for (i in names(fun_args)) {
    if (exists(i, inherits = FALSE)) fun_args[i] <- list(get0(i, inherits = FALSE))
    else if (i %in% names(A)) {
      #Match on the name (as weightit2user does) so that an explicitly
      #supplied NULL is honored and is not passed a second time through `A`.
      fun_args[i] <- A[i]
      A[[i]] <- NULL
    }
    #else just use Fun default
  }

  if (has_dots) fun_args <- c(fun_args, A)

  obj <- do.call(Fun, fun_args)

  if (is.numeric(obj)) {
    #A bare numeric vector is taken to be the weights
    obj <- list(w = obj)
  }
  else if (!is.list(obj) || !any(c("w", "weights") %in% names(obj))) {
    #Reject anything that is not a list carrying at least one of the two names
    stop("The output of the user-provided function must be a list with an entry named \"w\" or \"weights\" containing the estimated weights.", call. = FALSE)
  }
  else {
    names(obj)[names(obj) == "weights"] <- "w"
  }

  if (is_null(obj[["w"]])) stop("No weights were estimated.", call. = FALSE)
  if (!is.vector(obj[["w"]], mode = "numeric")) stop("The \"w\" or \"weights\" entry of the output of the user-provided function must be a numeric vector of weights.", call. = FALSE)
  #`call.` (with the dot) is the argument of stop(); `call = FALSE` would be
  #pasted into the message instead.
  if (all(is.na(obj[["w"]]))) stop("All weights were generated as NA.", call. = FALSE)
  if (length(obj[["w"]]) != length(treat.list[[1]])) {
    stop(paste(length(obj[["w"]]), "weights were estimated, but there are", length(treat.list[[1]]), "units."), call. = FALSE)
  }

  return(obj)
}
#Propensity score estimation with regression
# weightit2ps <- function(covs, treat, s.weights, subset, estimand, focal, stabilize, subclass, missing, ps, .data, ...) {

# Estimates propensity scores by regressing the treatment on the right-hand
# side of `formula` and converts them to weights with get_w_from_ps().
#
# The model depends on the treatment and on `link` (taken from `...`):
#   * binary treatment: stats::glm() with a quasibinomial family, or
#     brglm2::brglmFit with a binomial family for "br." links;
#   * ordered treatment: MASS::polr();
#   * other treatments: brglm2::brmultinom() for "br." links; otherwise, for
#     "logit"/"probit", mclogit::mblogit() unless `use.mclogit = FALSE`, in
#     which case one quasibinomial glm() is fitted per treatment level.
# Entries of `...` whose names match arguments of the relevant control
# function are merged with `control` to build the fitting control object.
#
# The rows used for fitting are selected by passing `subset` to the fitting
# functions; `s.weights` is passed to them as `weights`.
#
# Returns a list with `w` (the weights), `ps` (the fitted probability of the
# treated level for binary treatments, NULL otherwise) and `fit.obj` (the
# fitted model, or a list of models for the per-level glm() fits).
#
# NOTE(review): the `ps` argument is overwritten in every fitting branch; a
# link that matches none of them (e.g. "bayes.probit") leaves it untouched
# and passes it straight to get_w_from_ps() - confirm that this is intended.
# NOTE(review): `ps` is pre-sized with length(subset) rows and `treat` is
# passed un-subsetted to get_w_from_ps(), whereas fitted values cover only
# the rows selected by `subset` - verify against the callers.
weightit2ps <- function(formula, data, treat, s.weights, subset, estimand, focal, stabilize, subclass, missing, ps, ...) {
  A <- list(...)

  fit.obj <- NULL

  bin.treat <- get_treat_type(treat) == "binary"
  ord.treat <- is.ordered(treat)

  #Name of the response (treatment) variable on the left-hand side of formula
  treat.name <- deparse1(attr(terms(update(formula, . ~ 0)), "variables")[[2]])

  if (missing == "ind") {
    formula <- add_miss_ind_to_formula(formula, data)
  }

  data[[treat.name]] <- treat

  # if (ncol(covs) > 1) {
  #   if (missing == "saem") {
  #     covs0 <- covs
  #     for (i in colnames(covs)[anyNA_col(covs)]) covs0[is.na(covs0[,i]),i] <- covs0[!is.na(covs0[,i]),i][1]
  #     colinear.covs.to.remove <- colnames(covs)[colnames(covs) %nin% colnames(make_full_rank(covs0))]
  #   }
  #   else colinear.covs.to.remove <- colnames(covs)[colnames(covs) %nin% colnames(make_full_rank(covs))]
  #   covs <- covs[, colnames(covs) %nin% colinear.covs.to.remove, drop = FALSE]
  # }

  #Build the control object for a fitting routine: the user's `control` list
  #plus any entries of `...` matching an argument of `ctrl_fun` that `control`
  #does not already set.
  make_control <- function(ctrl_fun) {
    ctrl_names <- names(formals(ctrl_fun))
    do.call(ctrl_fun, c(A[["control"]],
                        A[setdiff(ctrl_names[pmatch(names(A), ctrl_names, 0)],
                                  names(A[["control"]]))]))
  }

  #Process link
  #NOTE(review): this test reads `use.mlogit` while the fitting code below
  #reads `use.mclogit` - confirm which option name is intended.
  if (ord.treat) acceptable.links <- c("logit", "probit", "loglog", "cloglog", "cauchit")
  else if (bin.treat || isFALSE(A$use.mlogit)) {
    if (missing == "saem") acceptable.links <- "logit"
    else acceptable.links <- expand_grid_string(c("", "br."), c("logit", "probit", "cloglog", "identity", "log", "cauchit"))
  }
  else acceptable.links <- c("logit", "probit", "bayes.probit", "br.logit")

  if (is_null(A$link)) A$link <- acceptable.links[1]
  else {
    #Partial matching; an unmatched link indexes position 0 and yields NA
    which.link <- acceptable.links[pmatch(A$link, acceptable.links, nomatch = 0)][1]
    if (is.na(which.link)) {
      stop(paste0("Only ", word_list(acceptable.links, quotes = TRUE, is.are = TRUE), " allowed as the link for ",
                  if (bin.treat) "binary" else if (ord.treat) "ordinal" else "multinomial",
                  " treatments", if (missing == "saem") " with missing = \"saem\"", "."),
           call. = FALSE)
    }
    else A$link <- which.link
  }

  #A "br." prefix requests bias-reduced fitting; strip it to get the link
  use.br <- startsWith(A$link, "br.")
  # use.bayes <- startsWith(A$link, "bayes.")
  if (use.br) A$link <- substr(A$link, 4, nchar(A$link))
  # else if (use.bayes) A$link <- substr(A$link, 7, nchar(A$link))

  if (bin.treat) {

    t.lev <- get_treated_level(treat)
    ps <- make_df(levels(treat), nrow = length(subset))

    #Replace the treatment column by its binarized version on the rows in
    #`subset`, leaving NA elsewhere
    data[[treat.name]] <- NA_integer_
    data[[treat.name]][subset] <- binarize(treat[subset], one = t.lev)
    treat_bin <- data[[treat.name]]

    # if (missing == "saem") {
    #   check.package("misaem")
    #
    #   data <- data.frame(treat, covs)
    #
    #   withCallingHandlers({
    #     fit <- misaem::miss.glm(formula(data), data = data, control = as.list(A[["control"]]))
    #   },
    #   warning = function(w) {
    #     if (conditionMessage(w) != "one argument not used by format '%i '") warning(w)
    #     invokeRestart("muffleWarning")
    #   })
    #
    #   if (is_null(A[["saem.method"]])) A[["saem.method"]] <- "map"
    #
    #   ps[[t.lev]] <- p.score <- drop(predict(fit, newdata = covs, method = A[["saem.method"]]))
    #   ps[[c.lev]] <- 1 - ps[[t.lev]]
    # }
    # else {
    if (use.br) {
      check.package("brglm2")
      ctrl_fun <- brglm2::brglmControl
      glm_method <- brglm2::brglmFit
      family <- binomial(link = A[["link"]])
    }
    else {
      ctrl_fun <- stats::glm.control
      glm_method <- if_null_then(A[["glm.method"]], stats::glm.fit)
      family <- quasibinomial(link = A[["link"]])
    }

    control <- make_control(ctrl_fun)

    start <- mustart <- NULL

    if (family$link %in% c("log", "identity")) {
      #Need starting values because links are unbounded: the (weighted) mean
      #of the treatment on the link scale for the first coefficient and 0 for
      #every other column of the design matrix glm() will build.
      mf <- do.call(stats::model.frame, list(formula, data = data,
                                             subset = subset,
                                             drop.unused.levels = TRUE),
                    quote = TRUE)
      n.coefs <- ncol(stats::model.matrix(attr(mf, "terms"), mf))
      start <- c(family$linkfun(w.m(treat_bin[subset], s.weights[subset])),
                 rep(0, n.coefs - 1))
    }
    else {
      #Default starting values from glm.fit() without weights; these
      #work better with s.weights than usual default.
      #glm() places mustart in the model frame and subsets it together with
      #the data, so it must have one entry per row of `data`.
      mustart <- .25 + .5*treat_bin
    }

    withCallingHandlers({
      fit <- do.call(stats::glm, list(formula, data = data,
                                      weights = s.weights,
                                      mustart = mustart,
                                      start = start,
                                      family = family,
                                      method = glm_method,
                                      subset = subset,
                                      control = control), quote = TRUE)
    },
    warning = function(w) {
      #Re-issue every warning except the one about non-integer successes
      if (conditionMessage(w) != "non-integer #successes in a binomial glm!") warning(w)
      invokeRestart("muffleWarning")
    })

    ps[[t.lev]] <- p.score <- fit$fitted.values
    #The other column is selected by name; `[[` cannot take a logical mask
    ps[[names(ps)[names(ps) != t.lev]]] <- 1 - ps[[t.lev]]
    # }

    fit[["call"]] <- NULL
    fit.obj <- fit
  }
  else if (ord.treat) {
    check.package("MASS")
    #polr() calls the logit link "logistic"
    if (A[["link"]] == "logit") A[["link"]] <- "logistic"
    # message(paste("Using ordinal", A$link, "regression."))

    tryCatch({fit <- do.call(MASS::polr,
                             list(formula,
                                  data = data,
                                  weights = s.weights,
                                  Hess = FALSE,
                                  model = FALSE,
                                  method = A[["link"]],
                                  subset = subset,
                                  contrasts = NULL), quote = TRUE)},
             error = function(e) {stop(paste0("There was a problem fitting the ordinal ", A$link, " regressions with polr().\n       Try again with an un-ordered treatment.",
                                              "\n       Error message: ", conditionMessage(e)), call. = FALSE)})

    ps <- fit$fitted.values
    fit.obj <- fit
    p.score <- NULL
  }
  else {
    if (use.br) {
      check.package("brglm2")
      ctrl_fun <- brglm2::brglmControl
      control <- make_control(ctrl_fun)

      tryCatch({fit <- do.call(brglm2::brmultinom,
                               list(formula, data,
                                    weights = s.weights,
                                    subset = subset,
                                    control = control), quote = TRUE)},
               error = function(e) stop("There was a problem with the bias-reduced multinomial logit regression. Try a different link.", call. = FALSE))

      ps <- fit$fitted.values
      fit.obj <- fit
    }
    else if (A$link %in% c("logit", "probit")) {
      if (!isFALSE(A$use.mclogit)) {
        check.package("mclogit")

        #The weights are referred to by name inside `data` in the call below
        data[[".s.weights"]] <- s.weights

        if (is_not_null(A[["random"]])) {
          ctrl_fun <- mclogit::mmclogit.control
        }
        else {
          ctrl_fun <- mclogit::mclogit.control
        }
        control <- make_control(ctrl_fun)

        tryCatch({
          fit <- do.call(mclogit::mblogit, list(formula,
                                                data = data,
                                                weights = quote(.s.weights),
                                                random = A[["random"]],
                                                subset = subset,
                                                method = A[["mclogit.method"]],
                                                estimator = if_null_then(A[["estimator"]], eval(formals(mclogit::mclogit)[["estimator"]])),
                                                dispersion = if_null_then(A[["dispersion"]], eval(formals(mclogit::mclogit)[["dispersion"]])),
                                                groups = A[["groups"]],
                                                control = control))
        },
        error = function(e) {stop(paste0("There was a problem fitting the multinomial ", A$link, " regression with mblogit().\n       Try again with use.mclogit = FALSE.\nError message (from mclogit):\n       ", conditionMessage(e)), call. = FALSE)}
        )

        ps <- fitted(fit)
        colnames(ps) <- levels(treat)
        fit.obj <- fit
      }
      else {
        #One binary regression of "treatment == level" per treatment level
        ps <- make_df(levels(treat), nrow = length(subset))

        ctrl_fun <- stats::glm.control
        control <- make_control(ctrl_fun)

        fit.list <- make_list(levels(treat))

        for (i in levels(treat)) {
          if (isTRUE(A[["test1"]])) {
            #With test1, the last level gets the remaining probability
            if (i == last(levels(treat))) {
              ps[[i]] <- 1 - rowSums(ps[names(ps) != i])
              next
            }
          }

          #Both the treatment name and the level belong inside sprintf()
          form_i <- update(formula, sprintf("I(%s == '%s') ~ .", treat.name, i))
          fit.list[[i]] <- do.call(stats::glm, list(form_i, data = data,
                                                    family = quasibinomial(link = A$link),
                                                    subset = subset,
                                                    weights = s.weights,
                                                    control = control), quote = TRUE)
          ps[[i]] <- fit.list[[i]]$fitted.values
        }
        #With test2, rescale each row of probabilities to sum to 1
        if (isTRUE(A[["test2"]])) ps <- ps/rowSums(ps)
        fit.obj <- fit.list
      }
    }
    p.score <- NULL
  }

  #ps should be matrix of probs for each treat
  #Computing weights
  w <- get_w_from_ps(ps = ps, treat = treat, estimand, focal = focal,
                     stabilize = stabilize, subclass = subclass)

  obj <- list(w = w, ps = p.score, fit.obj = fit.obj)
  return(obj)
}
# Generalized propensity score weighting for a continuous treatment.
#
# After restricting `covs`, `treat` and `s.weights` to `subset`, the
# treatment is regressed on the covariates (misaem::miss.lm() when
# `missing == "saem"`, otherwise a gaussian glm() with the requested `link`)
# and the fitted values are handed to get_cont_weights() together with a
# density for the numerator. That density is either a kernel estimate
# (`use.kernel = TRUE`, tuned by `bw`, `adjust`, `kernel`, `n`) or the
# function given by `density`: a function, or a string "name_arg1_arg2..."
# naming a function and numeric arguments passed after its first one.
# With no `density`, dnorm is used.
#
# Returns a list with `w` (the weights) and `fit.obj` (the fitted model).
#
# NOTE(review): when `ps` is supplied (non-NULL) no model is fitted and
# `gp.score` is never assigned before it is used below - confirm how
# user-supplied scores are meant to be handled for continuous treatments.
weightit2ps.cont <- function(covs, treat, s.weights, subset, stabilize, missing, ps, ...) {
  A <- list(...)

  fit.obj <- NULL

  covs <- covs[subset, , drop = FALSE]
  treat <- treat[subset]
  s.weights <- s.weights[subset]

  if (missing == "ind") {
    #Add one 0/1 missingness indicator per covariate with NAs and set the
    #missing covariate values themselves to 0
    missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  for (i in seq_col(covs)) covs[,i] <- make.closer.to.1(covs[,i])

  if (ncol(covs) > 1) {
    if (missing == "saem") {
      #Fill each NA with the first observed value of its column so that the
      #rank check can run on a complete matrix
      covs0 <- covs
      for (i in colnames(covs)[anyNA_col(covs)]) covs0[is.na(covs0[,i]),i] <- covs0[!is.na(covs0[,i]),i][1]
      colinear.covs.to.remove <- colnames(covs)[colnames(covs) %nin% colnames(make_full_rank(covs0))]
    }
    else colinear.covs.to.remove <- colnames(covs)[colnames(covs) %nin% colnames(make_full_rank(covs))]
    covs <- covs[, colnames(covs) %nin% colinear.covs.to.remove, drop = FALSE]
  }

  #formula() on a data frame uses the first column as the response
  data <- data.frame(treat, covs)
  formula <- formula(data)

  #Process density params
  if (isTRUE(A[["use.kernel"]])) {
    if (is_null(A[["bw"]])) A[["bw"]] <- "nrd0"
    if (is_null(A[["adjust"]])) A[["adjust"]] <- 1
    if (is_null(A[["kernel"]])) A[["kernel"]] <- "gaussian"
    if (is_null(A[["n"]])) A[["n"]] <- 10*length(treat)
    use.kernel <- TRUE
    densfun <- NULL
  }
  else {
    if (is_null(A[["density"]])) densfun <- dnorm
    else if (is.function(A[["density"]])) densfun <- A[["density"]]
    else if (is.character(A[["density"]]) && length(A[["density"]]) == 1) {
      #Only a single string is a valid density specification
      splitdens <- strsplit(A[["density"]], "_", fixed = TRUE)[[1]]
      if (is_not_null(splitdens1 <- get0(splitdens[1], mode = "function", envir = parent.frame()))) {
        if (length(splitdens) > 1 && !can_str2num(splitdens[-1])) {
          stop(paste(A[["density"]], "is not an appropriate argument to 'density' because",
                     word_list(splitdens[-1], and.or = "or", quotes = TRUE), "cannot be coerced to numeric."), call. = FALSE)
        }
        densfun <- function(x) {
          tryCatch(do.call(splitdens1, c(list(x), as.list(str2num(splitdens[-1])))),
                   error = function(e) stop(paste0("Error in applying density:\n  ", conditionMessage(e)), call. = FALSE))
        }
      }
      else {
        stop(paste(A[["density"]], "is not an appropriate argument to 'density' because",
                   splitdens[1], "is not an available function."), call. = FALSE)
      }
    }
    else stop("The argument to 'density' cannot be evaluated as a density function.", call. = FALSE)
    use.kernel <- FALSE
  }

  #Stabilization - get dens.num
  p.num <- treat - mean(treat)

  if (use.kernel) {
    d.n <- density(p.num, n = A[["n"]],
                   weights = s.weights/sum(s.weights), give.Rkern = FALSE,
                   bw = A[["bw"]], adjust = A[["adjust"]], kernel = A[["kernel"]])
    #Evaluate the kernel estimate at the observed values by interpolation
    dens.num <- with(d.n, approxfun(x = x, y = y))(p.num)
  }
  else {
    dens.num <- densfun(p.num/sd(treat))
    if (is_null(dens.num) || !is.atomic(dens.num) || anyNA(dens.num)) {
      stop("There was a problem with the output of density. Try another density function or leave it blank to use the normal density.", call. = FALSE)
    }
    else if (any(dens.num <= 0)) {
      stop("The input to density may not accept the full range of treatment values.", call. = FALSE)
    }
  }

  #Estimate GPS
  if (is_null(ps)) {

    if (missing == "saem") {
      check.package("misaem")

      withCallingHandlers({
        fit <- misaem::miss.lm(formula, data, control = as.list(A[["control"]]))
      },
      warning = function(w) {
        #Re-issue every warning except this specific formatting one
        if (conditionMessage(w) != "one argument not used by format '%i '") warning(w)
        invokeRestart("muffleWarning")
      })

      if (is_null(A[["saem.method"]])) A[["saem.method"]] <- "map"

      gp.score <- drop(predict(fit, newdata = covs, method = A[["saem.method"]]))
    }
    else {
      if (is_null(A[["link"]])) A[["link"]] <- "identity"
      else {
        #missing == "saem" is handled above, so all three links are allowed
        acceptable.links <- c("identity", "log", "inverse")

        #Partial matching; an unmatched link indexes position 0 and yields NA
        which.link <- acceptable.links[pmatch(A[["link"]], acceptable.links, nomatch = 0)][1]
        if (is.na(which.link)) {
          stop(paste0("Only ", word_list(acceptable.links, quotes = TRUE, is.are = TRUE),
                      " allowed as the link for continuous treatments", "."),
               call. = FALSE)
        }
        else A[["link"]] <- which.link
      }

      fit <- do.call("glm", c(list(formula, data = data,
                                   weights = s.weights,
                                   family = gaussian(link = A[["link"]]),
                                   control = as.list(A$control))),
                     quote = TRUE)

      gp.score <- fit$fitted.values
    }

    fit.obj <- fit
  }

  #Get weights
  w <- get_cont_weights(gp.score, treat = treat, s.weights = s.weights,
                        dens.num = dens.num, densfun = densfun,
                        use.kernel = use.kernel, densControl = A)

  if (use.kernel && isTRUE(A[["plot"]])) {
    #Kernel density of the residuals, plotted next to the numerator density
    d.d <- density(treat - gp.score, n = A[["n"]],
                   weights = s.weights/sum(s.weights), give.Rkern = FALSE,
                   bw = A[["bw"]], adjust = A[["adjust"]],
                   kernel = A[["kernel"]])
    plot_density(d.n, d.d)
  }

  obj <- list(w = w, fit.obj = fit.obj)
  return(obj)
}
#MABW with optweight

# Estimates weights for a point treatment by calling optweight::optweight().
#
# The covariates are restricted to `subset`, expanded with the terms returned
# by int.poly.f() (controlled by `moments` and `int`), rescaled with
# make.closer.to.1() and, when `missing == "ind"`, extended with missingness
# indicators (missing values set to 0). Arguments for optweight() are taken
# from `...`, completed with optweight()'s own formals, and then the formula,
# data, estimand, sampling weights, focal level and `verbose = TRUE` are set
# here. A user-supplied `targets` is dropped with a warning.
#
# Returns a list with `w` (the weights), `info` (holding the `duals` entry of
# the fit) and `fit.obj` (the optweight() output).
weightit2optweight <- function(covs, treat, s.weights, subset, estimand, focal, missing, moments, int, ...) {
  A <- list(...)
  check.package("optweight")

  covs <- covs[subset, , drop = FALSE]
  treat <- factor(treat[subset])
  s.weights <- s.weights[subset]

  covs <- cbind(covs, int.poly.f(covs, poly = moments, int = int))
  for (j in seq_col(covs)) {
    covs[, j] <- make.closer.to.1(covs[, j])
  }

  if (missing == "ind") {
    #One 0/1 indicator column per covariate that contains missing values
    missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2,
                         function(column) as.numeric(is.na(column)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  #formula() on a data frame uses the first column (treat) as the response
  opt.data <- data.frame(treat, covs)
  opt.formula <- formula(opt.data)

  #Complete the user's arguments with the formals of optweight()
  opt.formals <- formals(optweight::optweight)
  for (arg in names(opt.formals)) {
    if (is_null(A[[arg]])) A[[arg]] <- opt.formals[[arg]]
  }
  #Anything named like an argument of this function is set explicitly below
  A[names(A) %in% names(formals(weightit2optweight))] <- NULL

  if ("tols" %in% names(A)) {
    A[["tols"]] <- optweight::check.tols(opt.formula, opt.data, A[["tols"]], stop = TRUE)
  }
  if ("targets" %in% names(A)) {
    warning("targets cannot be used through WeightIt and will be ignored.", call. = FALSE, immediate. = TRUE)
    A[["targets"]] <- NULL
  }

  A[["formula"]] <- opt.formula
  A[["data"]] <- opt.data
  A[["estimand"]] <- estimand
  A[["s.weights"]] <- s.weights
  A[["focal"]] <- focal
  A[["verbose"]] <- TRUE

  fit <- do.call(optweight::optweight, A, quote = TRUE)

  list(w = fit[["weights"]], info = list(duals = fit$duals), fit.obj = fit)
}
# Estimates weights for a continuous treatment by calling
# optweight::optweight().
#
# The covariates are restricted to `subset`, expanded with the terms returned
# by int.poly.f() (controlled by `moments` and `int`), rescaled with
# make.closer.to.1() and, when `missing == "ind"`, extended with missingness
# indicators (missing values set to 0). Arguments for optweight() are taken
# from `...`, completed with optweight()'s own formals, and then the formula,
# data, sampling weights and `verbose = TRUE` are set here. A user-supplied
# `targets` is dropped with a warning.
#
# Returns a list with `w` (the weights), `info` (holding the `duals` entry of
# the fit) and `fit.obj` (the optweight() output).
weightit2optweight.cont <- function(covs, treat, s.weights, subset, missing, moments, int, ...) {
  A <- list(...)
  check.package("optweight")

  covs <- covs[subset, , drop = FALSE]
  treat <- treat[subset]
  s.weights <- s.weights[subset]

  covs <- cbind(covs, int.poly.f(covs, poly = moments, int = int))
  for (j in seq_col(covs)) {
    covs[, j] <- make.closer.to.1(covs[, j])
  }

  if (missing == "ind") {
    #One 0/1 indicator column per covariate that contains missing values
    missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2,
                         function(column) as.numeric(is.na(column)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  #formula() on a data frame uses the first column (treat) as the response
  opt.data <- data.frame(treat, covs)
  opt.formula <- formula(opt.data)

  #Complete the user's arguments with the formals of optweight()
  opt.formals <- formals(optweight::optweight)
  for (arg in names(opt.formals)) {
    if (is_null(A[[arg]])) A[[arg]] <- opt.formals[[arg]]
  }
  #Anything named like an argument of this function is set explicitly below
  A[names(A) %in% names(formals(weightit2optweight.cont))] <- NULL

  if ("tols" %in% names(A)) {
    A[["tols"]] <- optweight::check.tols(opt.formula, opt.data, A[["tols"]], stop = TRUE)
  }
  if ("targets" %in% names(A)) {
    warning("targets cannot be used through WeightIt and will be ignored.", call. = FALSE, immediate. = TRUE)
    A[["targets"]] <- NULL
  }

  A[["formula"]] <- opt.formula
  A[["data"]] <- opt.data
  A[["s.weights"]] <- s.weights
  A[["verbose"]] <- TRUE

  fit <- do.call(optweight::optweight, A, quote = TRUE)

  list(w = fit[["weights"]], info = list(duals = fit$duals), fit.obj = fit)
}
# Estimates weights for longitudinal treatments by calling
# optweight::optweight.fit() on per-time-point lists.
#
# Each element of `covs.list` is restricted to `subset`, expanded with the
# terms returned by int.poly.f(), rescaled with make.closer.to.1() and, when
# `missing == "ind"`, extended with missingness indicators (missing values
# set to 0). Each element of `treat.list` is restricted to `subset` and
# turned into a factor unless get_treat_type() reports it as "continuous".
# A user-supplied `tols` is checked against the first time point only, and a
# user-supplied `targets` is dropped with a warning.
#
# Returns a list with `w` (the weights) and `fit.obj` (the optweight.fit()
# output).
weightit2optweight.msm <- function(covs.list, treat.list, s.weights, subset, missing, moments, int, ...) {
  A <- list(...)
  check.package("optweight")

  #Prepare the covariates of a single time point
  prep_covs <- function(covs_t) {
    covs <- covs_t[subset, , drop = FALSE]
    covs <- cbind(covs, int.poly.f(covs, poly = moments, int = int))
    for (j in seq_col(covs)) {
      covs[, j] <- make.closer.to.1(covs[, j])
    }

    if (missing == "ind") {
      #One 0/1 indicator column per covariate that contains missing values
      missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2,
                           function(column) as.numeric(is.na(column)))
      if (is_not_null(missing.ind)) {
        covs[is.na(covs)] <- 0
        covs <- cbind(covs, missing.ind)
      }
    }
    covs
  }

  #Prepare the treatment of a single time point
  prep_treat <- function(treat_t) {
    treat <- treat_t[subset]
    #The treatment type is read from the un-subsetted vector
    if (get_treat_type(treat_t) != "continuous") treat <- factor(treat)
    treat
  }

  if (is_not_null(covs.list)) covs.list <- lapply(covs.list, prep_covs)
  if (is_not_null(treat.list)) treat.list <- lapply(treat.list, prep_treat)
  if (is_not_null(s.weights)) s.weights <- s.weights[subset]

  #Only the first time point is used to check `tols`
  baseline.data <- data.frame(treat.list[[1]], covs.list[[1]])
  baseline.formula <- formula(baseline.data)

  if ("tols" %in% names(A)) {
    A[["tols"]] <- optweight::check.tols(baseline.formula, baseline.data, A[["tols"]], stop = TRUE)
  }
  if ("targets" %in% names(A)) {
    warning("targets cannot be used through WeightIt and will be ignored.", call. = FALSE, immediate. = TRUE)
    A[["targets"]] <- NULL
  }

  fit.args <- c(list(treat = treat.list,
                     covs = covs.list,
                     s.weights = s.weights,
                     verbose = TRUE),
                A)
  fit <- do.call(optweight::optweight.fit, fit.args, quote = TRUE)

  list(w = fit$w, fit.obj = fit)
}
#Generalized boosted modeling with gbm and cobalt

# Estimates propensity-score weights with gbm::gbm.fit() for a binary or
# multi-category treatment.
#
# `covs`, `treat` and `s.weights` are restricted to `subset`. Options are
# read from `...`: `stop.method` (a balance criterion known to
# bal_criterion(), or "cv{#}" for #-fold cross-validation), `trim.at`,
# `start.tree`, `n.grid`, and any argument of gbm::gbm.fit(). Every
# combination of the tunable arguments (`interaction.depth`, `shrinkage`,
# `distribution`) is fitted; for each fit the number of trees is chosen by
# the balance criterion or by the cross-validation error, and the
# combination with the smallest loss is kept.
#
# Returns a list with `w` (weights of the best fit), `ps` (its propensity
# scores for binary treatments, NULL for multinomial ones), `info` (the
# chosen tree, the evaluated tree/loss pairs and, when more than one
# combination was tried, the tuning table) and `fit.obj` (the gbm fit).
weightit2gbm <- function(covs, treat, s.weights, estimand, focal, subset, stabilize, subclass, missing, ...) {
check.package("gbm")
A <- list(...)
covs <- covs[subset, , drop = FALSE]
treat <- treat[subset]
s.weights <- s.weights[subset]
if (!has_treat_type(treat)) treat <- assign_treat_type(treat)
treat.type <- get_treat_type(treat)
for (i in seq_col(covs)) covs[,i] <- make.closer.to.1(covs[,i])
if (missing == "ind") {
# Append one 0/1 missingness indicator per covariate that has NAs, named
# "<covariate>:<NA>"; the NAs in the covariates themselves are left in place.
missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
if (is_not_null(missing.ind)) {
colnames(missing.ind) <- paste0(colnames(missing.ind), ":<NA>")
covs <- cbind(covs, missing.ind)
}
}
# Exactly one stop.method is used: default to "es.mean", or keep the first.
if (is_null(A[["stop.method"]])) {
warning("No stop.method was provided. Using \"es.mean\".",
call. = FALSE, immediate. = TRUE)
A[["stop.method"]] <- "es.mean"
}
else if (length(A[["stop.method"]]) > 1) {
warning("Only one stop.method is allowed at a time. Using just the first stop.method.",
call. = FALSE, immediate. = TRUE)
A[["stop.method"]] <- A[["stop.method"]][1]
}
# cv == 0 means "select trees by balance"; cv >= 2 is the number of CV folds
# parsed from a stop.method of the form "cv{#}".
cv <- 0
available.stop.methods <- bal_criterion(treat.type, list = TRUE)
s.m.matches <- charmatch(A[["stop.method"]], available.stop.methods)
if (is.na(s.m.matches) || s.m.matches == 0L) {
if (startsWith(A[["stop.method"]], "cv") && can_str2num(numcv <- substr(A[["stop.method"]], 3, nchar(A[["stop.method"]])))) {
cv <- round(str2num(numcv))
if (cv < 2) stop("At least 2 CV-folds must be specified in stop.method.", call. = FALSE)
}
else stop(paste0("'stop.method' must be one of ", word_list(c(available.stop.methods, "cv{#}"), "or", quotes = TRUE), "."), call. = FALSE)
}
else stop.method <- available.stop.methods[s.m.matches]
# Arguments for which several candidate values may be supplied.
tunable <- c("interaction.depth", "shrinkage", "distribution")
trim.at <- if_null_then(A[["trim.at"]], 0)
# Fill every gbm.fit() argument the user did not supply: NULL for the ones
# set explicitly further down, the defaults listed in switch() for some, and
# gbm.fit()'s own default for the rest.
for (f in names(formals(gbm::gbm.fit))) {
if (is_null(A[[f]])) {
if (f %in% c("x", "y", "misc", "w", "verbose", "var.names",
"response.name", "group", "distribution")) A[f] <- list(NULL)
else A[f] <- list(switch(f, n.trees = 1e4,
interaction.depth = 3,
shrinkage = .01,
bag.fraction = 1,
keep.data = FALSE,
formals(gbm::gbm.fit)[[f]]))
}
}
if (treat.type == "binary") {
available.distributions <- c("bernoulli", "adaboost")
t.lev <- get_treated_level(treat)
# The treatment is binarized with `focal` as the "one" level; the
# propensity scores are flipped at the end if focal differs from t.lev.
treat <- binarize(treat, one = focal)
}
else {
available.distributions <- "multinomial"
treat <- factor(treat)
}
if (cv == 0) {
start.tree <- if_null_then(A[["start.tree"]], 1)
# Number of tree counts evaluated in the first (coarse) pass.
if (is_null(A[["n.grid"]])) n.grid <- round(1+sqrt(2*(A[["n.trees"]]-start.tree+1)))
else if (!is_(A[["n.grid"]], "numeric") || length(A[["n.grid"]]) > 1 ||
!between(A[["n.grid"]], c(2, A[["n.trees"]]))) {
stop("'n.grid' must be a numeric value between 2 and n.trees.", call. = FALSE)
}
else n.grid <- round(A[["n.grid"]])
# `crit` provides $init() (called once here) and $fun() (called for every
# candidate set of weights below).
crit <- bal_criterion(treat.type, stop.method)
init <- crit$init(covs, treat, estimand = estimand, s.weights = s.weights, focal = focal, ...)
}
A[["x"]] <- covs
A[["y"]] <- treat
A[["distribution"]] <- if (is_null(distribution <- A[["distribution"]])) {
available.distributions[1]} else {
match_arg(distribution, available.distributions, several.ok = TRUE)}
A[["w"]] <- s.weights
A[["verbose"]] <- FALSE
# One row per combination of the supplied values of the tunable arguments.
tune <- do.call("expand.grid", c(A[names(A) %in% tunable],
list(stringsAsFactors = FALSE, KEEP.OUT.ATTRS = FALSE)))
current.best.loss <- Inf
for (i in seq_row(tune)) {
A[["distribution"]] <- list(name = tune[["distribution"]][i])
tune_args <- as.list(tune[i, setdiff(tunable, "distribution")])
fit <- do.call(gbm::gbm.fit, c(A[names(A) %in% setdiff(names(formals(gbm::gbm.fit)), names(tune_args))], tune_args), quote = TRUE)
if (cv == 0) {
# Balance-based selection: evaluate the criterion on a coarse grid of tree
# counts first.
n.trees <- fit[["n.trees"]]
iters <- 1:n.trees
iters.grid <- round(seq(start.tree, n.trees, length.out = n.grid))
if (is_null(iters.grid) || anyNA(iters.grid) || any(iters.grid > n.trees)) stop("A problem has occurred")
ps <- gbm::predict.gbm(fit, n.trees = iters.grid, type = "response", newdata = covs)
w <- get.w.from.ps(ps, treat = treat, estimand = estimand, focal = focal, stabilize = stabilize, subclass = subclass)
if (trim.at != 0) w <- suppressMessages(apply(w, 2, trim, at = trim.at, treat = treat))
iter.grid.balance <- apply(w, 2, function(w_) {
crit$fun(init = init, weights = w_)
})
if (n.grid == n.trees) {
# The grid has as many points as there are trees: take its minimum directly.
best.tree.index <- which.min(iter.grid.balance)
best.loss <- iter.grid.balance[best.tree.index]
best.tree <- iters.grid[best.tree.index]
tree.val <- setNames(data.frame(iters.grid,
iter.grid.balance),
c("tree", stop.method))
}
else {
# Otherwise re-evaluate every tree count between the grid points on either
# side of the coarse minimum (clamped to the ends of the grid).
it <- which.min(iter.grid.balance) + c(-1, 1)
it[1] <- iters.grid[max(1, it[1])]
it[2] <- iters.grid[min(length(iters.grid), it[2])]
iters.to.check <- iters[between(iters, iters[it])]
if (is_null(iters.to.check) || anyNA(iters.to.check) || any(iters.to.check > n.trees)) stop("A problem has occurred")
ps <- gbm::predict.gbm(fit, n.trees = iters.to.check, type = "response", newdata = covs)
w <- get.w.from.ps(ps, treat = treat, estimand = estimand, focal = focal, stabilize = stabilize, subclass = subclass)
if (trim.at != 0) w <- suppressMessages(apply(w, 2, trim, at = trim.at, treat = treat))
iter.grid.balance.fine <- apply(w, 2, function(w_) {
crit$fun(init = init, weights = w_)
})
best.tree.index <- which.min(iter.grid.balance.fine)
best.loss <- iter.grid.balance.fine[best.tree.index]
best.tree <- iters.to.check[best.tree.index]
tree.val <- setNames(data.frame(c(iters.grid, iters.to.check),
c(iter.grid.balance, iter.grid.balance.fine)),
c("tree", stop.method))
}
# Sort the evaluated tree/loss pairs by tree count and drop duplicates.
tree.val <- unique(tree.val[order(tree.val$tree),])
w <- w[,best.tree.index]
ps <- if (treat.type == "binary") ps[,best.tree.index] else NULL
tune[[paste.("best", stop.method)]][i] <- best.loss
tune[["best.tree"]][i] <- best.tree
# Keep this combination if it beats the best loss seen so far.
if (best.loss < current.best.loss) {
best.fit <- fit
best.w <- w
best.ps <- ps
current.best.loss <- best.loss
best.tune.index <- i
info <- list(best.tree = best.tree,
tree.val = tree.val)
}
}
else {
# Cross-validation-based selection: gbm::gbmCrossVal() is run with the same
# data and tuning values, and the tree count with the smallest CV error wins.
A["data"] <- list(data.frame(treat, covs))
A[["cv.folds"]] <- cv
A["n.cores"] <- list(A[["n.cores"]])
A["var.names"] <- list(A[["var.names"]])
A["offset"] <- list(NULL)
A[["nTrain"]] <- floor(nrow(covs))
A[["class.stratify.cv"]] <- FALSE
A[["y"]] <- treat
A[["x"]] <- covs
A[["distribution"]] <- list(name = tune[["distribution"]][i])
A[["w"]] <- s.weights
tune_args <- as.list(tune[i, setdiff(tunable, "distribution")])
cv.results <- do.call(gbm::gbmCrossVal,
c(A[names(A) %in% setdiff(names(formals(gbm::gbmCrossVal)), names(tune_args))],
tune_args), quote = TRUE)
best.tree.index <- which.min(cv.results$error)
best.loss <- cv.results$error[best.tree.index]
best.tree <- best.tree.index
# NOTE(review): the column name is built from names(fit$name); confirm that
# the gbm fit actually carries a `name` element.
tune[[paste.("best", names(fit$name))]][i] <- best.loss
tune[["best.tree"]][i] <- best.tree
if (best.loss < current.best.loss) {
best.fit <- fit
# Scores and weights come from the full-data fit at the selected tree count.
best.ps <- gbm::predict.gbm(fit, n.trees = best.tree, type = "response", newdata = covs)
best.w <- drop(get.w.from.ps(best.ps, treat = treat, estimand = estimand, focal = focal, stabilize = stabilize, subclass = subclass))
# if (trim.at != 0) best.w <- suppressMessages(trim(best.w, at = trim.at, treat = treat))
current.best.loss <- best.loss
best.tune.index <- i
tree.val <- data.frame(tree = seq_along(cv.results$error),
error = cv.results$error)
info <- list(best.tree = best.tree,
tree.val = tree.val)
if (treat.type == "multinomial") best.ps <- NULL
}
}
if (treat.type == "multinomial") ps <- NULL
}
# Drop tuning columns that took a single value; report the tuning table only
# if something besides the two result columns remains.
tune[tunable[vapply(tune[tunable], all_the_same, logical(1L))]] <- NULL
if (ncol(tune) > 2) {
info[["tune"]] <- tune
info[["best.tune"]] <- tune[best.tune.index,]
}
if (is_not_null(best.ps)) {
# The model was fitted with `focal` coded as 1; flip the scores when focal
# is not the treated level.
if (is_not_null(focal) && focal != t.lev) best.ps <- 1 - best.ps
}
obj <- list(w = best.w, ps = best.ps, info = info, fit.obj = best.fit)
return(obj)
}
weightit2gbm.cont <- function(covs, treat, s.weights, estimand, focal, subset, stabilize, subclass, missing, ...) {
check.package("gbm")
A <- list(...)
covs <- covs[subset, , drop = FALSE]
treat <- treat[subset]
s.weights <- s.weights[subset]
for (i in seq_col(covs)) covs[,i] <- make.closer.to.1(covs[,i])
if (missing == "ind") {
missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
if (is_not_null(missing.ind)) {
colnames(missing.ind) <- paste0(colnames(missing.ind), ":<NA>")
covs <- cbind(covs, missing.ind)
}
}
if (is_null(A[["stop.method"]])) {
warning("No stop.method was provided. Using \"p.mean\".",
call. = FALSE, immediate. = TRUE)
A[["stop.method"]] <- "p.mean"
}
else if (length(A[["stop.method"]]) > 1) {
warning("Only one stop.method is allowed at a time. Using just the first stop.method.",
call. = FALSE, immediate. = TRUE)
A[["stop.method"]] <- A[["stop.method"]][1]
}
cv <- 0
available.stop.methods <- bal_criterion("continuous", list = TRUE)
s.m.matches <- charmatch(A[["stop.method"]], available.stop.methods)
if (is.na(s.m.matches) || s.m.matches == 0L) {
if (startsWith(A[["stop.method"]], "cv") && can_str2num(numcv <- substr(A[["stop.method"]], 3, nchar(A[["stop.method"]])))) {
cv <- round(str2num(numcv))
if (cv < 2) stop("At least 2 CV-folds must be specified in stop.method.", call. = FALSE)
}
else stop(paste0("'stop.method' must be one of ", word_list(c(available.stop.methods, "cv{#}"), "or", quotes = TRUE), "."), call. = FALSE)
}
else stop.method <- available.stop.methods[s.m.matches]
tunable <- c("interaction.depth", "shrinkage", "distribution")
trim.at <- if_null_then(A[["trim.at"]], 0)
for (f in names(formals(gbm::gbm.fit))) {
if (is_null(A[[f]])) {
if (f %in% c("x", "y", "misc", "w", "verbose", "var.names",
"response.name", "group", "distribution")) A[f] <- list(NULL)
else A[f] <- list(switch(f, n.trees = 2e4,
interaction.depth = 4,
shrinkage = 0.0005,
bag.fraction = 1,
formals(gbm::gbm.fit)[[f]]))
}
}
available.distributions <- c("gaussian", "laplace", "tdist", "poisson")
if (cv == 0) {
start.tree <- if_null_then(A[["start.tree"]], 1)
if (is_null(A[["n.grid"]])) n.grid <- round(1+sqrt(2*(A[["n.trees"]]-start.tree+1)))
else if (!is_(A[["n.grid"]], "numeric") || length(A[["n.grid"]]) > 1 ||
!between(A[["n.grid"]], c(2, A[["n.trees"]]))) {
stop("'n.grid' must be a numeric value between 2 and n.trees.", call. = FALSE)
}
else n.grid <- round(A[["n.grid"]])
crit <- bal_criterion("continuous", stop.method)
init <- crit$init(covs, treat, s.weights = s.weights, ...)
}
A[["x"]] <- covs
A[["y"]] <- treat
A[["distribution"]] <- if (is_null(distribution <- A[["distribution"]])) {
available.distributions[1]} else {
match_arg(distribution, available.distributions, several.ok = TRUE)}
A[["w"]] <- s.weights
A[["verbose"]] <- FALSE
if (!is.numeric(A[["n.trees"]]) || length(A[["n.trees"]]) > 1 || A[["n.trees"]] <= 1) {
stop("'n.trees' must be a number greater than 1.", call. = FALSE)
}
tune <- do.call("expand.grid", c(A[names(A) %in% tunable],
list(stringsAsFactors = FALSE, KEEP.OUT.ATTRS = FALSE)))
#Process density params
if (isTRUE(A[["use.kernel"]])) {
if (is_null(A[["bw"]])) A[["bw"]] <- "nrd0"
if (is_null(A[["adjust"]])) A[["adjust"]] <- 1
if (is_null(A[["kernel"]])) A[["kernel"]] <- "gaussian"
if (is_null(A[["n"]])) A[["n"]] <- 10*length(treat)
use.kernel <- TRUE
densfun <- NULL
}
else {
if (is_null(A[["density"]])) densfun <- dnorm
else if (is.function(A[["density"]])) densfun <- A[["density"]]
else if (is.character(A[["density"]]) && length(A[["density"]] == 1)) {
splitdens <- strsplit(A[["density"]], "_", fixed = TRUE)[[1]]
if (exists(splitdens[1], mode = "function", envir = parent.frame())) {
if (length(splitdens) > 1 && !can_str2num(splitdens[-1])) {
stop(paste(A[["density"]], "is not an appropriate argument to 'density' because",
word_list(splitdens[-1], and.or = "or", quotes = TRUE), "cannot be coerced to numeric."), call. = FALSE)
}
densfun <- function(x) {
tryCatch(do.call(get(splitdens[1]), c(list(x), as.list(str2num(splitdens[-1])))),
error = function(e) stop(paste0("Error in applying density:\n ", conditionMessage(e)), call. = FALSE))
}
}
else {
stop(paste(A[["density"]], "is not an appropriate argument to 'density' because",
splitdens[1], "is not an available function."), call. = FALSE)
}
}
else stop("The argument to 'density' cannot be evaluated as a density function.", call. = FALSE)
use.kernel <- FALSE
}
#Stabilization - get dens.num
p.num <- treat - mean(treat)
if (use.kernel) {
d.n <- density(p.num, n = A[["n"]],
weights = s.weights/sum(s.weights), give.Rkern = FALSE,
bw = A[["bw"]], adjust = A[["adjust"]], kernel = A[["kernel"]])
dens.num <- with(d.n, approxfun(x = x, y = y))(p.num)
}
else {
dens.num <- densfun(p.num/sd(treat))
if (is_null(dens.num) || !is.atomic(dens.num) || anyNA(dens.num)) {
stop("There was a problem with the output of 'density'. Try another density function or leave it blank to use the normal density.", call. = FALSE)
}
else if (any(dens.num <= 0)) {
stop("The input to 'density' may not accept the full range of treatment values.", call. = FALSE)
}
}
current.best.loss <- Inf
for (i in seq_row(tune)) {
fit <- do.call(gbm::gbm.fit, c(A[names(A) %in% setdiff(names(formals(gbm::gbm.fit)), tunable)],
tune[i, tunable[tunable %in% names(formals(gbm::gbm.fit))]]), quote = TRUE)
if (cv == 0) {
n.trees <- fit[["n.trees"]]
iters <- 1:n.trees
iters.grid <- round(seq(start.tree, n.trees, length.out = n.grid))
if (is_null(iters.grid) || anyNA(iters.grid) || any(iters.grid > n.trees)) stop("A problem has occurred")
gps <- gbm::predict.gbm(fit, n.trees = iters.grid, newdata = covs)
w <- get_cont_weights(gps, treat = treat, s.weights = s.weights, dens.num = dens.num,
densfun = densfun, use.kernel = use.kernel, densControl = A)
if (trim.at != 0) w <- suppressMessages(apply(w, 2, trim, at = trim.at, treat = treat))
iter.grid.balance <- apply(w, 2, function(w_) {
crit$fun(init = init, weights = w_)
})
if (n.grid == n.trees) {
best.tree.index <- which.min(iter.grid.balance)
best.loss <- iter.grid.balance[best.tree.index]
best.tree <- iters.grid[best.tree.index]
tree.val <- setNames(data.frame(iters.grid,
iter.grid.balance),
c("tree", stop.method))
}
else {
it <- which.min(iter.grid.balance) + c(-1, 1)
it[1] <- iters.grid[max(1, it[1])]
it[2] <- iters.grid[min(length(iters.grid), it[2])]
iters.to.check <- iters[between(iters, iters[it])]
if (is_null(iters.to.check) || anyNA(iters.to.check) || any(iters.to.check > n.trees)) stop("A problem has occurred")
gps <- gbm::predict.gbm(fit, n.trees = iters.to.check, newdata = covs)
w <- get_cont_weights(gps, treat = treat, s.weights = s.weights, dens.num = dens.num,
densfun = densfun, use.kernel = use.kernel, densControl = A)
if (trim.at != 0) w <- suppressMessages(apply(w, 2, trim, at = trim.at, treat = treat))
iter.grid.balance.fine <- apply(w, 2, function(w_) {
crit$fun(init = init, weights = w_)
})
best.tree.index <- which.min(iter.grid.balance.fine)
best.loss <- iter.grid.balance.fine[best.tree.index]
best.tree <- iters.to.check[best.tree.index]
tree.val <- setNames(data.frame(c(iters.grid, iters.to.check),
c(iter.grid.balance, iter.grid.balance.fine)),
c("tree", stop.method))
}
tree.val <- unique(tree.val[order(tree.val$tree),])
w <- w[,best.tree.index]
gps <- gps[,as.character(best.tree)]
tune[[paste.("best", stop.method)]][i] <- best.loss
tune[["best.tree"]][i] <- best.tree
if (best.loss < current.best.loss) {
best.fit <- fit
best.w <- w
best.gps <- gps
current.best.loss <- best.loss
best.tune.index <- i
info <- list(best.tree = best.tree,
tree.val = tree.val)
}
}
else {
A["data"] <- list(data.frame(treat, covs))
A[["cv.folds"]] <- cv
A["n.cores"] <- list(A[["n.cores"]])
A["var.names"] <- list(A[["var.names"]])
A["offset"] <- list(NULL)
A[["nTrain"]] <- floor(nrow(covs))
A[["class.stratify.cv"]] <- FALSE
A[["y"]] <- treat
A[["x"]] <- covs
A[["distribution"]] <- list(name = A[["distribution"]])
A[["w"]] <- s.weights
cv.results <- do.call(gbm::gbmCrossVal,
c(A[names(A) %in% setdiff(names(formals(gbm::gbmCrossVal)), tunable)],
tune[i, tunable[tunable %in% names(formals(gbm::gbmCrossVal))]]), quote = TRUE)
best.tree.index <- which.min(cv.results$error)
best.loss <- cv.results$error[best.tree.index]
best.tree <- best.tree.index
tune[[paste.("best", "error")]][i] <- best.loss
tune[["best.tree"]][i] <- best.tree
if (best.loss < current.best.loss) {
best.fit <- fit
best.gps <- gbm::predict.gbm(fit, n.trees = best.tree, newdata = covs)
best.w <- get_cont_weights(best.gps, treat = treat, s.weights = s.weights, dens.num = dens.num,
densfun = densfun, use.kernel = use.kernel, densControl = A)
# if (trim.at != 0) best.w <- suppressMessages(trim(best.w, at = trim.at, treat = treat))
current.best.loss <- best.loss
best.tune.index <- i
tree.val <- data.frame(tree = seq_along(cv.results$error),
error = cv.results$error)
info <- list(best.tree = best.tree,
tree.val = tree.val)
}
}
}
if (use.kernel && isTRUE(A[["plot"]])) {
d.d <- density(treat - best.gps, n = A[["n"]],
weights = s.weights/sum(s.weights), give.Rkern = FALSE,
bw = A[["bw"]], adjust = A[["adjust"]],
kernel = A[["kernel"]])
plot_density(d.n, d.d)
}
tune[tunable[vapply(tunable, function(x) length(A[[x]]) == 1, logical(1L))]] <- NULL
if (ncol(tune) > 2) {
info[["tune"]] <- tune
info[["best.tune"]] <- tune[best.tune.index,]
}
obj <- list(w = best.w, info = info, fit.obj = best.fit)
return(obj)
}
#CBPS
# Covariate balancing propensity score (CBPS) weights for binary and
# multi-category treatments, estimated with CBPS::CBPS().
#
# Arguments:
#   covs, treat, s.weights  covariates, treatment and sampling weights; each is
#                           restricted to `subset` before use.
#   estimand, focal         with estimand "ATT" one binary model is fit for each
#                           non-focal level against `focal`; otherwise models
#                           are fit on the whole (subsetted) sample.
#   stabilize, subclass     forwarded to get_w_from_ps().
#   missing                 "ind" zero-fills missing covariate values and
#                           appends missingness-indicator columns.
#   ...                     forwarded to CBPS::CBPS(); `over = FALSE` selects
#                           method = "exact", anything else method = "over".
#
# Returns list(w, ps, fit.obj). `ps` is NULL unless the treatment type is
# "binary"; `fit.obj` is the CBPS fit, or a list of fits when several models
# were needed.
weightit2cbps <- function(covs, treat, s.weights, estimand, focal, subset, stabilize, subclass, missing, ...) {
  check.package("CBPS")
  A <- list(...)

  covs <- covs[subset, , drop = FALSE]
  treat <- factor(treat[subset])
  s.weights <- s.weights[subset]

  if (!has_treat_type(treat)) treat <- assign_treat_type(treat)
  treat.type <- get_treat_type(treat)

  if (missing == "ind") {
    missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  for (i in seq_col(covs)) covs[,i] <- make.closer.to.1(covs[,i])

  # Drop columns that make_full_rank() does not retain
  colinear.covs.to.remove <- colnames(covs)[colnames(covs) %nin% colnames(make_full_rank(covs))]
  covs <- covs[, colnames(covs) %nin% colinear.covs.to.remove, drop = FALSE]

  # CBPS words its errors in terms of its own `method` argument; reword them in
  # terms of the `over` argument exposed here before re-signalling.
  cbps_error <- function(e) {
    e. <- conditionMessage(e)
    e. <- gsub("method = \"exact\"", "over = FALSE", e., fixed = TRUE)
    stop(e., call. = FALSE)
  }

  if (estimand == "ATT") {
    ps <- make_df(levels(treat), length(treat))

    control.levels <- levels(treat)[levels(treat) != focal]
    fit.list <- make_list(control.levels)

    for (i in control.levels) {
      # Binary problem: focal group (coded 1) vs. level i (coded 0)
      treat.in.i.focal <- which(treat %in% c(focal, i))
      treat_ <- as.integer(treat[treat.in.i.focal] != i)
      covs_ <- covs[treat.in.i.focal, , drop = FALSE]
      new.data <- data.frame(treat_, covs_)

      # isFALSE() (as in the other branches) rather than `A$over == FALSE`, so
      # NA, numeric or length > 1 values of `over` are treated the same way for
      # every estimand instead of erroring or differing here only.
      tryCatch({fit.list[[i]] <- CBPS::CBPS(formula(new.data),
                                            data = new.data,
                                            method = if (isFALSE(A$over)) "exact" else "over",
                                            standardize = FALSE,
                                            sample.weights = s.weights[treat.in.i.focal],
                                            ATT = 1,
                                            ...)},
               error = cbps_error)

      ps[[focal]][treat.in.i.focal] <- fit.list[[i]][["fitted.values"]]
      ps[[i]][treat.in.i.focal] <- 1 - ps[[focal]][treat.in.i.focal]
    }
  }
  else {
    new.data <- data.frame(treat, covs)

    if (treat.type == "binary" || !nunique.gt(treat, 4)) {
      # A single CBPS call on the whole sample
      tryCatch({fit.list <- CBPS::CBPS(formula(new.data),
                                       data = new.data,
                                       method = if (isFALSE(A$over)) "exact" else "over",
                                       standardize = FALSE,
                                       sample.weights = s.weights,
                                       ATT = 0,
                                       ...)},
               error = cbps_error)

      ps <- fit.list[["fitted.values"]]
    }
    else {
      # One binary "level i vs. everything else" model per level; only the
      # fitted values of the units actually in level i are kept.
      ps <- rep(NA_real_, length(treat))
      fit.list <- make_list(levels(treat))

      for (i in levels(treat)) {
        new.data[[1]] <- as.integer(treat == i)
        tryCatch({fit.list[[i]] <- CBPS::CBPS(formula(new.data), data = new.data,
                                              method = if (isFALSE(A$over)) "exact" else "over",
                                              standardize = FALSE,
                                              sample.weights = s.weights,
                                              ATT = 0, ...)},
                 error = cbps_error)

        ps[treat==i] <- fit.list[[i]][["fitted.values"]][treat==i]
      }
    }
  }

  w <- get_w_from_ps(ps, treat, estimand = estimand, subclass = subclass,
                     focal = focal, stabilize = stabilize)

  # Only a binary treatment has a single propensity score to report
  if (treat.type != "binary") {
    p.score <- NULL
  }
  else if (is_not_null(dim(ps)) && length(dim(ps)) == 2) {
    p.score <- ps[[get_treated_level(treat)]]
  }
  else p.score <- ps

  obj <- list(w = w, ps = p.score, fit.obj = fit.list)
  return(obj)
}
# CBPS weights for a continuous treatment.
#
# covs, treat and s.weights are restricted to `subset`; missing == "ind"
# zero-fills missing covariate values and appends indicator columns; `...` is
# forwarded to CBPS::CBPS(), with `over = FALSE` selecting method = "exact".
# Returns list(w, fit.obj), where w is the CBPS weights divided by s.weights.
weightit2cbps.cont <- function(covs, treat, s.weights, subset, missing, ...) {
  check.package("CBPS")
  A <- list(...)

  # Keep only the requested units
  covs <- covs[subset, , drop = FALSE]
  treat <- treat[subset]
  s.weights <- s.weights[subset]

  if (missing == "ind") {
    covs.with.na <- covs[, anyNA_col(covs), drop = FALSE]
    missing.ind <- apply(covs.with.na, 2, function(column) as.numeric(is.na(column)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  for (j in seq_col(covs)) {
    covs[, j] <- make.closer.to.1(covs[, j])
  }

  # Remove the columns that make_full_rank() does not retain
  all.names <- colnames(covs)
  dropped.names <- all.names[all.names %nin% colnames(make_full_rank(covs))]
  covs <- covs[, all.names %nin% dropped.names, drop = FALSE]

  new.data <- data.frame(treat = treat, covs)

  fit <- tryCatch(
    CBPS::CBPS(formula(new.data),
               data = new.data,
               method = if (isFALSE(A$over)) "exact" else "over",
               standardize = FALSE,
               sample.weights = s.weights,
               ...),
    error = function(e) {
      # Reword CBPS's message in terms of the `over` argument
      stop(gsub("method = \"exact\"", "over = FALSE", conditionMessage(e), fixed = TRUE),
           call. = FALSE)
    }
  )

  list(w = fit$weights / s.weights, fit.obj = fit)
}
# Placeholder for CBPS weights with longitudinal treatments (CBMSM). It is not
# implemented, so every call signals an error. `call. = FALSE` matches the other
# errors in this file and keeps the internal call out of the message shown to
# the user.
weightit2cbps.msm <- function(covs.list, treat.list, s.weights, subset, missing, ...) {
  stop("CBMSM doesn't work yet.", call. = FALSE)
}
# Nonparametric CBPS weights (CBPS::npCBPS) for a categorical treatment.
#
# Sampling weights are not supported: unless all_the_same(s.weights) holds, an
# error is raised. covs and treat are restricted to `subset`; `moments` and
# `int` are handed to int.poly.f() to append extra covariate terms; `...` is
# appended to the npCBPS() argument list.
# Returns list(w, fit.obj), with w rescaled to mean 1 within each treatment level.
weightit2npcbps <- function(covs, treat, s.weights, subset, moments, int, missing, ...) {
  check.package("CBPS")
  A <- list(...)

  if (!all_the_same(s.weights)) {
    stop("Sampling weights cannot be used with method = \"npcbps\".", call. = FALSE)
  }

  covs <- covs[subset, , drop = FALSE]
  treat <- factor(treat[subset])

  if (missing == "ind") {
    covs.with.na <- covs[, anyNA_col(covs), drop = FALSE]
    missing.ind <- apply(covs.with.na, 2, function(column) as.numeric(is.na(column)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  # Append polynomial/interaction terms, then rescale every column
  extra.terms <- int.poly.f(covs, poly = moments, int = int)
  covs <- cbind(covs, extra.terms)
  for (j in seq_col(covs)) {
    covs[, j] <- make.closer.to.1(covs[, j])
  }

  # Remove the columns that make_full_rank() does not retain
  all.names <- colnames(covs)
  dropped.names <- all.names[all.names %nin% colnames(make_full_rank(covs))]
  covs <- covs[, all.names %nin% dropped.names, drop = FALSE]

  new.data <- data.frame(treat = treat, covs)
  npcbps.args <- c(list(formula(new.data), data = new.data, print.level = 1), A)
  fit <- do.call(CBPS::npCBPS, npcbps.args, quote = TRUE)

  # Normalise the weights to average 1 inside each treatment group
  w <- fit$weights
  for (lvl in levels(treat)) {
    in.group <- treat == lvl
    w[in.group] <- w[in.group]/mean(w[in.group])
  }

  list(w = w, fit.obj = fit)
}
# Nonparametric CBPS weights (CBPS::npCBPS) for a continuous treatment.
#
# Sampling weights are not supported: unless all_the_same(s.weights) holds, an
# error is raised. covs and treat are restricted to `subset`; columns are
# rescaled first and the int.poly.f() terms are appended afterwards; `...` is
# appended to the npCBPS() argument list.
# Returns list(w, fit.obj), with w rescaled to have mean 1.
weightit2npcbps.cont <- function(covs, treat, s.weights, subset, moments, int, missing, ...) {
  check.package("CBPS")
  A <- list(...)

  if (!all_the_same(s.weights)) {
    stop("Sampling weights cannot be used with method = \"npcbps\".", call. = FALSE)
  }

  covs <- covs[subset, , drop = FALSE]
  treat <- treat[subset]

  if (missing == "ind") {
    covs.with.na <- covs[, anyNA_col(covs), drop = FALSE]
    missing.ind <- apply(covs.with.na, 2, function(column) as.numeric(is.na(column)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  # Rescale the columns, then append polynomial/interaction terms
  for (j in seq_col(covs)) {
    covs[, j] <- make.closer.to.1(covs[, j])
  }
  extra.terms <- int.poly.f(covs, poly = moments, int = int)
  covs <- cbind(covs, extra.terms)

  # Remove the columns that make_full_rank() does not retain
  all.names <- colnames(covs)
  dropped.names <- all.names[all.names %nin% colnames(make_full_rank(covs))]
  covs <- covs[, all.names %nin% dropped.names, drop = FALSE]

  new.data <- data.frame(treat = treat, covs)
  npcbps.args <- c(list(formula(new.data), data = new.data, print.level = 1), A)
  fit <- do.call(CBPS::npCBPS, npcbps.args, quote = TRUE)

  raw.w <- fit$weights

  list(w = raw.w/mean(raw.w), fit.obj = fit)
}
#Entropy balancing
# Entropy balancing weights for binary and multi-category treatments.
#
# covs, treat and s.weights are restricted to `subset`. With estimand "ATT"
# every non-focal group is weighted towards the s.weights-weighted covariate
# means of the focal group; with "ATE" every group is weighted towards the
# means of the full sample. `...` may supply `base.weight` (alias
# `base.weights`), `reltol` and `maxit`.
# Returns list(w, fit.obj), where fit.obj holds one optim() result per
# weighted group.
weightit2ebal <- function(covs, treat, s.weights, subset, estimand, focal, stabilize, missing, moments, int, ...) {
  A <- list(...)

  covs <- covs[subset, , drop = FALSE]
  treat <- factor(treat[subset])
  s.weights <- s.weights[subset]

  if (missing == "ind") {
    covs.with.na <- covs[, anyNA_col(covs), drop = FALSE]
    missing.ind <- apply(covs.with.na, 2, function(column) as.numeric(is.na(column)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  covs <- cbind(covs, int.poly.f(covs, poly = moments, int = int, center = TRUE))
  for (j in seq_col(covs)) {
    covs[, j] <- make.closer.to.1(covs[, j])
  }

  # `base.weights` is accepted as an alias and takes precedence when given
  if (is_not_null(A[["base.weights"]])) A[["base.weight"]] <- A[["base.weights"]]

  n.units <- length(treat)
  if (is_null(A[["base.weight"]])) {
    bw <- rep(1, n.units)
  }
  else if (is.numeric(A[["base.weight"]]) && length(A[["base.weight"]]) == n.units) {
    bw <- A[["base.weight"]]
  }
  else {
    stop("The argument to base.weight must be a numeric vector with length equal to the number of units.", call. = FALSE)
  }

  # Solve the entropy balancing dual for one group.
  #   group.covs    covariates of the group being weighted
  #   target.means  covariate means the weighted group should reproduce
  #   group.sw      sampling weights of the group
  #   group.bw      base weights of the group
  fit_group <- function(group.covs, target.means, group.sw, group.bw) {
    weights_at <- function(Z) {
      drop(group.bw * exp(-group.covs %*% Z))
    }
    dual <- function(Z) {
      log(sum(weights_at(Z))) + sum(target.means * Z)
    }
    dual_gradient <- function(Z) {
      wz <- weights_at(Z)
      drop(target.means - wz %*% group.covs/sum(wz))
    }

    opt.out <- optim(par = rep(0, ncol(group.covs)),
                     fn = dual,
                     gr = dual_gradient,
                     method = "BFGS",
                     control = list(trace = 0,
                                    reltol = if_null_then(A[["reltol"]], sqrt(.Machine$double.eps)),
                                    maxit = if_null_then(A[["maxit"]], 200)))

    solved.w <- weights_at(opt.out$par)

    list(Z = setNames(opt.out$par, colnames(group.covs)),
         w = solved.w/(mean(solved.w) * group.sw),
         opt.out = opt.out)
  }

  # Groups that are not re-weighted keep a weight of 1
  w <- rep(1, n.units)

  if (estimand == "ATT") {
    groups_to_weight <- levels(treat)[levels(treat) != focal]
    targets <- cobalt::col_w_mean(covs, s.weights = s.weights, subset = treat == focal)
  }
  else if (estimand == "ATE") {
    groups_to_weight <- levels(treat)
    targets <- cobalt::col_w_mean(covs, s.weights = s.weights)
  }

  fit.list <- make_list(groups_to_weight)
  for (g in groups_to_weight) {
    in.group <- treat == g
    fit.list[[g]] <- fit_group(covs[in.group, , drop = FALSE], targets,
                               s.weights[in.group], bw[in.group])
    w[in.group] <- fit.list[[g]]$w
  }

  list(w = w, fit.obj = lapply(fit.list, `[[`, "opt.out"))
}
# Entropy balancing weights for a continuous treatment (the optimisation step
# is credited below to Stefan Tubbicke).
#
# covs, treat and s.weights are restricted to `subset`. `moments` and `int`
# control which covariate terms enter the balance constraints. `...` may supply
# `d.moments`, `base.weight` (alias `base.weights`), `reltol` and `maxit`.
# Returns list(w = <weights>, fit.obj = <optim() output>).
weightit2ebal.cont <- function(covs, treat, s.weights, subset, moments, int, missing, ...) {
A <- list(...)
covs <- covs[subset, , drop = FALSE]
treat <- treat[subset]
s.weights <- s.weights[subset]
# missing == "ind": zero-fill missing covariate values and append one 0/1
# indicator column per covariate that had missing values
if (missing == "ind") {
missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
if (is_not_null(missing.ind)) {
covs[is.na(covs)] <- 0
covs <- cbind(covs, missing.ind)
}
}
# d.moments is never smaller than `moments` (presumably if_null_then() falls
# back to 1 when d.moments was not supplied)
d.moments <- max(if_null_then(A[["d.moments"]], 1), moments)
# k: number of covariate columns before any polynomial/interaction terms
k <- ncol(covs)
poly.covs <- int.poly.f(covs, poly = d.moments)
int.covs <- int.poly.f(covs, int = int)
# Column layout from here on: [original k columns | poly.covs | int.covs]
covs <- cbind(covs, poly.covs, int.covs)
for (i in seq_col(covs)) covs[,i] <- make.closer.to.1(covs[,i])
# colinear.covs.to.remove <- colnames(covs)[colnames(covs) %nin% colnames(make_full_rank(covs))]
# covs <- covs[, colnames(covs) %nin% colinear.covs.to.remove, drop = FALSE]
# Base weights q: all 1 unless base.weight/base.weights was supplied
if (is_not_null(A[["base.weights"]])) A[["base.weight"]] <- A[["base.weights"]]
if (is_null(A[["base.weight"]])) {
q <- rep(1, length(treat))
}
else {
if (!is.numeric(A[["base.weight"]]) || length(A[["base.weight"]]) != length(treat)) {
stop("The argument to base.weight must be a numeric vector with length equal to the number of units.", call. = FALSE)
}
else q <- A[["base.weight"]]
}
# Polynomial terms of the treatment up to degree d.moments
t.mat <- poly(treat, degree = d.moments)
# center() and mat_div() are helpers defined elsewhere; presumably this
# standardises each column with s.weights-weighted means and SDs - TODO confirm
treat_sc <- mat_div(center(t.mat, at = cobalt::col_w_mean(t.mat, s.weights)),
cobalt::col_w_sd(t.mat, s.weights))
covs_sc <- mat_div(center(covs, at = cobalt::col_w_mean(covs, s.weights)),
cobalt::col_w_sd(covs, s.weights))
# kp is only used when moments > 1 (which implies d.moments > 1); when
# d.moments == 1 the division below is by zero but the result is never read
kp <- ncol(poly.covs)/(d.moments-1)
# Columns of covs_sc that are multiplied by the first treatment term in gTX:
# the k original columns, a selection from the poly.covs block when
# moments > 1, and every column after the poly.covs block when `int` is TRUE.
# NOTE(review): the index arithmetic assumes a specific column ordering in the
# output of int.poly.f() - confirm against that helper.
cov_include <- c(seq_len(k),
if (moments > 1) k + unlist(lapply(seq_len(moments - 1), function(i) i + (d.moments - 1)*(seq_len(kp)-1))),
if (int) seq_col(covs)[-seq_len(k + ncol(poly.covs))])
# Constraint matrix: treatment terms, covariate terms, and the products of the
# first treatment term with the selected covariate terms
gTX <- do.call("cbind", c(list(treat_sc, covs_sc, treat_sc[,1] * covs_sc[,cov_include])))
#----Code written by Stefan Tubbicke---#
#define objective function (Lagrange dual)
objective.EBCT <- function(theta) {
f <- log(mean(q*s.weights*exp(gTX %*% theta)))*nrow(gTX)
return(f)
}
#define gradient function (LHS of equations 8 in Tubbicke (2020))
gradient.EBCT<- function(theta) {
g <- t(gTX) %*% (q*s.weights*exp(gTX %*% theta)/(mean(q*s.weights*exp(gTX %*% theta))))
return(g)
}
# NOTE(review): trace = TRUE makes optim() print tracing output, whereas
# weightit2ebal() uses trace = 0 - confirm this is intended.
opt.out <- optim(par = rep(0, ncol(gTX)),
fn = objective.EBCT,
gr = gradient.EBCT,
method = "BFGS",
control = list(trace = TRUE,
reltol = if_null_then(A[["reltol"]], sqrt(.Machine$double.eps)),
maxit = if_null_then(A[["maxit"]], 200)))
# s.weights enter the objective and gradient but not this final expression.
# Because of `%*%`, w is a one-column matrix rather than a plain vector.
w <- q*exp(gTX %*% opt.out$par)/(mean(q*exp(gTX %*% opt.out$par)))
#--------------------------------------#
obj <- list(w = w, fit.obj = opt.out)
return(obj)
}
#Empirical balancing calibration weights (EBCW), computed with the ATE package
# Empirical balancing calibration weights computed with ATE::ATE().
#
# covs, treat and s.weights are restricted to `subset`; `moments` and `int` are
# handed to int.poly.f() to append extra covariate terms. `...` may supply
# ATE::ATE() arguments (theta, max.iter, tol, initial.values, backtrack,
# backtrack.alpha, backtrack.beta).
# Returns list(w, fit.obj); fit.obj is a single ATE fit when only one was
# needed, otherwise a list of fits named by treatment level.
# Only estimand "ATT" and "ATE" are handled; any other value leaves `fit.list`
# and `w` undefined and fails at the end of the function.
weightit2ebcw <- function(covs, treat, s.weights, subset, estimand, focal, missing, moments, int, ...) {
check.package("ATE")
A <- list(...)
covs <- covs[subset, , drop = FALSE]
treat <- factor(treat[subset])
s.weights <- s.weights[subset]
# missing == "ind": zero-fill missing covariate values and append one 0/1
# indicator column per covariate that had missing values
if (missing == "ind") {
missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
if (is_not_null(missing.ind)) {
covs[is.na(covs)] <- 0
covs <- cbind(covs, missing.ind)
}
}
covs <- cbind(covs, int.poly.f(covs, poly = moments, int = int))
for (i in seq_col(covs)) covs[,i] <- make.closer.to.1(covs[,i])
# Any ATE::ATE() argument not supplied through ... takes that function's
# formal default
for (f in names(formals(ATE::ATE))) {
if (is_null(A[[f]])) A[[f]] <- formals(ATE::ATE)[[f]]
}
if (estimand == "ATT") {
# Non-focal groups are weighted one at a time; the focal group keeps weight 1
w <- rep(1, length(treat))
control.levels <- levels(treat)[levels(treat) != focal]
fit.list <- make_list(control.levels)
for (i in control.levels) {
# Two-group problem: focal units coded 1, level-i units coded 0
treat.in.i.focal <- treat %in% c(focal, i)
treat_ <- as.integer(treat[treat.in.i.focal] != i)
covs_ <- covs[treat.in.i.focal, , drop = FALSE]
# Collinearity is assessed within the group being weighted (treat_ == 0)
colinear.covs.to.remove <- colnames(covs_)[colnames(covs_) %nin% colnames(make_full_rank(covs_[treat_ == 0, , drop = FALSE]))]
covs_ <- covs_[, colnames(covs_) %nin% colinear.covs.to.remove, drop = FALSE]
# Multiply each focal row by its sampling weight rescaled to average 1 over
# the focal group - presumably so that the covariate means ATE::ATE() targets
# are the s.weights-weighted focal means (TODO confirm)
covs_[treat_ == 1,] <- covs_[treat_ == 1,] * s.weights[treat == focal] * sum(treat == focal)/ sum(s.weights[treat == focal])
# The outcome is a vector of zeros; only weights.q from the fit is used below
Y <- rep(0, length(treat_))
ate.out <- ATE::ATE(Y = Y, Ti = treat_, X = covs_,
ATT = TRUE,
theta = A[["theta"]],
verbose = TRUE,
max.iter = A[["max.iter"]],
tol = A[["tol"]],
initial.values = A[["initial.values"]],
backtrack = A[["backtrack"]],
backtrack.alpha = A[["backtrack.alpha"]],
backtrack.beta = A[["backtrack.beta"]])
# Divide by the sampling weights of the units in level i
w[treat == i] <- ate.out$weights.q[treat_ == 0] / s.weights[treat == i]
fit.list[[i]] <- ate.out
}
}
else if (estimand == "ATE") {
w <- rep(1, length(treat))
fit.list <- make_list(levels(treat))
for (i in levels(treat)) {
# Stack the full sample (coded 1) on top of a copy of group i (coded 0) and
# fit with ATT = TRUE, so group i is weighted relative to the full sample
covs_i <- rbind(covs, covs[treat==i, , drop = FALSE])
treat_i <- c(rep(1, nrow(covs)), rep(0, sum(treat==i)))
colinear.covs.to.remove <- colnames(covs_i)[colnames(covs_i) %nin% colnames(make_full_rank(covs_i[treat_i == 0, , drop = FALSE]))]
covs_i <- covs_i[, colnames(covs_i) %nin% colinear.covs.to.remove, drop = FALSE]
# Multiply the full-sample rows by their sampling weight rescaled to average 1
covs_i[treat_i == 1,] <- covs_i[treat_i == 1,] * s.weights * sum(treat_i == 1) / sum(s.weights)
# The outcome is a vector of zeros; only weights.q from the fit is used below
Y <- rep(0, length(treat_i))
ate.out <- ATE::ATE(Y = Y, Ti = treat_i, X = covs_i,
ATT = TRUE,
theta = A[["theta"]],
verbose = TRUE,
max.iter = A[["max.iter"]],
tol = A[["tol"]],
initial.values = A[["initial.values"]],
backtrack = A[["backtrack"]],
backtrack.alpha = A[["backtrack.alpha"]],
backtrack.beta = A[["backtrack.beta"]])
w[treat == i] <- ate.out$weights.q[treat_i == 0] / s.weights[treat == i]
fit.list[[i]] <- ate.out
}
}
# A single fit is returned bare rather than wrapped in a list
if (length(fit.list) == 1) fit.list <- fit.list[[1]]
obj <- list(w = w, fit.obj = fit.list)
return(obj)
}
#PS weights using SuperLearner
# Propensity score weights for binary and multi-category treatments, with the
# propensity scores estimated by SuperLearner::SuperLearner().
#
# covs, treat and s.weights are restricted to `subset`. `...` may supply
# SuperLearner() arguments (the combination method is read from `SL.method`,
# not `method`), `discrete`, and `stop.method` (used only with
# SL.method = "method.balance").
# Returns list(w, ps, info, fit.obj): `ps` is NULL unless the treatment type is
# "binary"; `info` holds the SuperLearner coefficients and cvRisk.
weightit2super <- function(covs, treat, s.weights, subset, estimand, focal, stabilize, subclass, missing, ...) {
A <- list(...)
check.package("SuperLearner")
covs <- covs[subset, , drop = FALSE]
treat <- factor(treat[subset])
s.weights <- s.weights[subset]
if (!has_treat_type(treat)) treat <- assign_treat_type(treat)
treat.type <- get_treat_type(treat)
# missing == "ind": zero-fill missing covariate values and append one 0/1
# indicator column per covariate that had missing values
if (missing == "ind") {
missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
if (is_not_null(missing.ind)) {
covs[is.na(covs)] <- 0
covs <- cbind(covs, missing.ind)
}
}
for (i in seq_col(covs)) covs[,i] <- make.closer.to.1(covs[,i])
covs <- as.data.frame(covs)
# Drop the columns that make_full_rank() does not retain
if (ncol(covs) > 1) {
colinear.covs.to.remove <- colnames(covs)[colnames(covs) %nin% colnames(make_full_rank(covs))]
covs <- covs[, colnames(covs) %nin% colinear.covs.to.remove, drop = FALSE]
}
# Fill A with SuperLearner()'s formal defaults for anything not supplied.
# `method` is stored under "SL.method", and `env` defaults to the environment
# of SuperLearner::SuperLearner rather than to the formal default.
for (f in names(formals(SuperLearner::SuperLearner))) {
if (f == "method") {if (is_null(A[["SL.method"]])) A[["SL.method"]] <- formals(SuperLearner::SuperLearner)[["method"]]}
else if (f == "env") {if (is_null(A[["env"]])) A[["env"]] <- environment(SuperLearner::SuperLearner)}
else if (is_null(A[[f]])) A[[f]] <- formals(SuperLearner::SuperLearner)[[f]]
}
# discrete = TRUE: use the single library learner with the smallest cvRisk
# instead of the ensemble prediction (see the loop below)
discrete <- if_null_then(A[["discrete"]], FALSE)
if (length(discrete) != 1 || !is_(discrete, "logical")) stop("'discrete' must be TRUE or FALSE.", call. = FALSE)
# SL.method = "method.balance": combine learners using a balance criterion
# chosen by `stop.method` (binary treatments only)
if (identical(A[["SL.method"]], "method.balance")) {
if (treat.type != "binary") stop("\"method.balance\" cannot be used with multi-category treatments.", call. = FALSE)
if (is_null(A[["stop.method"]])) {
warning("No stop.method was provided. Using \"es.mean\".",
call. = FALSE, immediate. = TRUE)
A[["stop.method"]] <- "es.mean"
}
else if (length(A[["stop.method"]]) > 1) {
warning("Only one stop.method is allowed at a time. Using just the first stop.method.",
call. = FALSE, immediate. = TRUE)
A[["stop.method"]] <- A[["stop.method"]][1]
}
# Partial matching of stop.method against the available criteria
available.stop.methods <- bal_criterion(treat.type, list = TRUE)
s.m.matches <- charmatch(A[["stop.method"]], available.stop.methods)
if (is.na(s.m.matches) || s.m.matches == 0L) {
stop(paste0("'stop.method' must be one of ", word_list(available.stop.methods, "or", quotes = TRUE), "."), call. = FALSE)
}
else stop.method <- available.stop.methods[s.m.matches]
crit <- bal_criterion("binary", stop.method)
init <- crit$init(covs, treat, estimand = estimand, s.weights = s.weights, focal = focal, ...)
bal_fun <- crit$fun
# The balance objects travel as an attribute on control$trimLogit, replacing
# any user-supplied `control`; presumably method.balance() reads them back -
# TODO confirm against that helper
sneaky <- 0
attr(sneaky, "vals") <- list(init = init, bal_fun = bal_fun, estimand = estimand)
A[["control"]] <- list(trimLogit = sneaky)
A[["SL.method"]] <- method.balance(stop.method)
}
fit.list <- info <- make_list(levels(treat))
ps <- make_df(levels(treat), nrow = length(treat))
# One "level i vs. rest" model per level. For a binary treatment the last
# level is not fitted: its score is the complement of the first level's, and
# fit.list/info collapse to the single fitted model.
for (i in levels(treat)) {
if (treat.type == "binary" && i == last(levels(treat))) {
ps[[i]] <- 1 - ps[[1]]
fit.list <- fit.list[[1]]
info <- info[[1]]
next
}
treat_i <- as.numeric(treat == i)
# verbose and id are fixed here regardless of what was passed through ...
fit.list[[i]] <- do.call(SuperLearner::SuperLearner, list(Y = treat_i,
X = as.data.frame(covs),
family = binomial(),
SL.library = A[["SL.library"]],
verbose = FALSE,
method = A[["SL.method"]],
id = NULL,
obsWeights = s.weights,
control = A[["control"]],
cvControl = A[["cvControl"]],
env = A[["env"]]))
if (discrete) ps[[i]] <- fit.list[[i]]$library.predict[,which.min(fit.list[[i]]$cvRisk)]
else ps[[i]] <- fit.list[[i]]$SL.predict
info[[i]] <- list(coef = fit.list[[i]]$coef,
cvRisk = fit.list[[i]]$cvRisk)
}
#ps should be matrix of probs for each treat
#Computing weights
# NOTE(review): estimand and focal are passed positionally here - confirm they
# line up with the formals of get_w_from_ps()
w <- get_w_from_ps(ps = ps, treat = treat, estimand, focal, stabilize = stabilize, subclass = subclass)
p.score <- if (treat.type == "binary") ps[[get_treated_level(treat)]] else NULL
obj <- list(w = w, ps = p.score, info = info, fit.obj = fit.list)
return(obj)
}
# Generalized propensity score weights for a continuous treatment, with the
# conditional mean of the treatment estimated by SuperLearner::SuperLearner().
#
# covs, treat and s.weights are restricted to `subset`. `...` may supply
#   - density settings: `use.kernel`, `bw`, `adjust`, `kernel`, `n`, `plot`,
#     or `density` (a function, or a single string "name_arg1_arg2..." naming
#     a density function followed by numeric arguments);
#   - SuperLearner() arguments (the combination method is read from
#     `SL.method`), `discrete`, and `stop.method` (used only with
#     SL.method = "method.balance").
# Returns list(w, info, fit.obj); `info` holds the SuperLearner coefficients
# and cvRisk.
weightit2super.cont <- function(covs, treat, s.weights, subset, stabilize, missing, ps, ...) {
  # A carries the density settings (and is passed on as densControl);
  # B carries the SuperLearner settings.
  A <- B <- list(...)

  # Same up-front dependency check as weightit2super()
  check.package("SuperLearner")

  covs <- covs[subset, , drop = FALSE]
  treat <- treat[subset]
  s.weights <- s.weights[subset]

  if (missing == "ind") {
    missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  for (i in seq_col(covs)) covs[,i] <- make.closer.to.1(covs[,i])

  if (ncol(covs) > 1) {
    colinear.covs.to.remove <- colnames(covs)[colnames(covs) %nin% colnames(make_full_rank(covs))]
    covs <- covs[, colnames(covs) %nin% colinear.covs.to.remove, drop = FALSE]
  }

  #Process density params
  if (isTRUE(A[["use.kernel"]])) {
    if (is_null(A[["bw"]])) A[["bw"]] <- "nrd0"
    if (is_null(A[["adjust"]])) A[["adjust"]] <- 1
    if (is_null(A[["kernel"]])) A[["kernel"]] <- "gaussian"
    if (is_null(A[["n"]])) A[["n"]] <- 10*length(treat)
    use.kernel <- TRUE
    densfun <- NULL
  }
  else {
    if (is_null(A[["density"]])) densfun <- dnorm
    else if (is.function(A[["density"]])) densfun <- A[["density"]]
    # The length test must be on the argument itself: `length(x == 1)` is just
    # length(x), which let multi-element character vectors through.
    else if (is.character(A[["density"]]) && length(A[["density"]]) == 1) {
      splitdens <- strsplit(A[["density"]], "_", fixed = TRUE)[[1]]
      if (exists(splitdens[1], mode = "function", envir = parent.frame())) {
        if (length(splitdens) > 1 && !can_str2num(splitdens[-1])) {
          stop(paste(A[["density"]], "is not an appropriate argument to 'density' because",
                     word_list(splitdens[-1], and.or = "or", quotes = TRUE), "cannot be coerced to numeric."), call. = FALSE)
        }
        densfun <- function(x) {
          # mode = "function" mirrors the exists() check above, so a
          # non-function object of the same name is never picked up
          tryCatch(do.call(get(splitdens[1], mode = "function"), c(list(x), as.list(str2num(splitdens[-1])))),
                   error = function(e) stop(paste0("Error in applying density:\n  ", conditionMessage(e)), call. = FALSE))
        }
      }
      else {
        stop(paste(A[["density"]], "is not an appropriate argument to 'density' because",
                   splitdens[1], "is not an available function."), call. = FALSE)
      }
    }
    else stop("The argument to 'density' cannot be evaluated as a density function.", call. = FALSE)
    use.kernel <- FALSE
  }

  #Stabilization - get dens.num
  p.num <- treat - mean(treat)

  if (use.kernel) {
    d.n <- density(p.num, n = A[["n"]],
                   weights = s.weights/sum(s.weights), give.Rkern = FALSE,
                   bw = A[["bw"]], adjust = A[["adjust"]], kernel = A[["kernel"]])
    dens.num <- with(d.n, approxfun(x = x, y = y))(p.num)
  }
  else {
    dens.num <- densfun(p.num/sd(treat))
    if (is_null(dens.num) || !is.atomic(dens.num) || anyNA(dens.num)) {
      stop("There was a problem with the output of density. Try another density function or leave it blank to use the normal density.", call. = FALSE)
    }
    else if (any(dens.num <= 0)) {
      stop("The input to density may not accept the full range of treatment values.", call. = FALSE)
    }
  }

  #Estimate GPS
  # Fill B with SuperLearner()'s formal defaults for anything not supplied;
  # `method` is stored under "SL.method".
  for (f in names(formals(SuperLearner::SuperLearner))) {
    if (f == "method") {if (is_null(B[["SL.method"]])) B[["SL.method"]] <- formals(SuperLearner::SuperLearner)[["method"]]}
    else if (f == "env") {if (is_null(B[["env"]])) B[["env"]] <- environment(SuperLearner::SuperLearner)}
    else if (is_null(B[[f]])) B[[f]] <- formals(SuperLearner::SuperLearner)[[f]]
  }

  discrete <- if_null_then(A[["discrete"]], FALSE)
  if (length(discrete) != 1 || !is_(discrete, "logical")) stop("discrete must be TRUE or FALSE.", call. = FALSE)

  if (identical(B[["SL.method"]], "method.balance")) {
    if (is_null(B[["stop.method"]])) {
      warning("No stop.method was provided. Using \"p.mean\".",
              call. = FALSE, immediate. = TRUE)
      B[["stop.method"]] <- "p.mean"
    }
    else if (length(B[["stop.method"]]) > 1) {
      warning("Only one stop.method is allowed at a time. Using just the first stop.method.",
              call. = FALSE, immediate. = TRUE)
      B[["stop.method"]] <- B[["stop.method"]][1]
    }

    available.stop.methods <- bal_criterion("continuous", list = TRUE)
    s.m.matches <- charmatch(B[["stop.method"]], available.stop.methods)
    if (is.na(s.m.matches) || s.m.matches == 0L) {
      stop(paste0("'stop.method' must be one of ", word_list(available.stop.methods, "or", quotes = TRUE), "."), call. = FALSE)
    }
    else stop.method <- available.stop.methods[s.m.matches]

    crit <- bal_criterion("continuous", stop.method)
    init <- crit$init(covs, treat, s.weights = s.weights, ...)
    bal_fun <- crit$fun

    # The balance and density objects travel as an attribute on
    # control$trimLogit; presumably method.balance.cont() reads them back -
    # TODO confirm against that helper
    sneaky <- 0
    attr(sneaky, "vals") <- list(init = init,
                                 bal_fun = bal_fun,
                                 dens.num = dens.num,
                                 densfun = densfun,
                                 use.kernel = use.kernel,
                                 densControl = A)
    B[["control"]] <- list(trimLogit = sneaky)

    B[["SL.method"]] <- method.balance.cont(stop.method)
  }

  fit <- do.call(SuperLearner::SuperLearner, list(Y = treat,
                                                  X = as.data.frame(covs),
                                                  family = gaussian(),
                                                  SL.library = B[["SL.library"]],
                                                  verbose = FALSE,
                                                  method = B[["SL.method"]],
                                                  id = NULL,
                                                  obsWeights = s.weights,
                                                  control = B[["control"]],
                                                  cvControl = B[["cvControl"]],
                                                  env = B[["env"]]))

  # discrete = TRUE: use the single library learner with the smallest cvRisk
  if (discrete) gp.score <- fit$library.predict[,which.min(fit$cvRisk)]
  else gp.score <- fit$SL.predict

  #Get weights
  w <- get_cont_weights(gp.score, treat = treat, s.weights = s.weights,
                        dens.num = dens.num, densfun = densfun,
                        use.kernel = use.kernel, densControl = A)

  if (use.kernel && isTRUE(A[["plot"]])) {
    d.d <- density(treat - gp.score, n = A[["n"]],
                   weights = s.weights/sum(s.weights), give.Rkern = FALSE,
                   bw = A[["bw"]], adjust = A[["adjust"]],
                   kernel = A[["kernel"]])
    plot_density(d.n, d.d)
  }

  info <- list(coef = fit$coef,
               cvRisk = fit$cvRisk)

  obj <- list(w = w, info = info, fit.obj = fit)
  return(obj)
}
#PS weights using BART
# Propensity score weights for binary and multi-category treatments, with the
# propensity scores estimated by dbarts::bart2().
#
# covs, treat and s.weights are restricted to `subset`. Sampling weights are
# not supported: unless all_the_same(s.weights) holds, an error is raised.
# Entries of `...` whose names match formals of bart2() or dbartsControl()
# (other than offset.test, weights, subset and test) are passed to bart2().
# Returns list(w, ps, info, fit.obj): `ps` is NULL unless the treatment type is
# "binary"; `info` is an empty list.
weightit2bart <- function(covs, treat, s.weights, subset, estimand, focal, stabilize, subclass, missing, ...) {
  A <- list(...)
  check.package("dbarts")

  covs <- covs[subset, , drop = FALSE]
  treat <- factor(treat[subset])
  s.weights <- s.weights[subset]

  if (!all_the_same(s.weights)) {
    stop("Sampling weights cannot be used with method = \"bart\".", call. = FALSE)
  }

  if (!has_treat_type(treat)) {
    treat <- assign_treat_type(treat)
  }
  treat.type <- get_treat_type(treat)

  if (missing == "ind") {
    covs.with.na <- covs[, anyNA_col(covs), drop = FALSE]
    missing.ind <- apply(covs.with.na, 2, function(column) as.numeric(is.na(column)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  # Remove the columns that make_full_rank() does not retain, then rescale
  if (ncol(covs) > 1) {
    all.names <- colnames(covs)
    dropped.names <- all.names[all.names %nin% colnames(make_full_rank(covs))]
    covs <- covs[, all.names %nin% dropped.names, drop = FALSE]
  }
  for (j in seq_col(covs)) {
    covs[, j] <- make.closer.to.1(covs[, j])
  }

  ps <- make_df(levels(treat), nrow = length(treat))

  # Fixed bart2() settings; these override anything passed through ...
  A[["formula"]] <- covs
  A[["keepCall"]] <- FALSE
  A[["combineChains"]] <- TRUE
  A[["verbose"]] <- FALSE #necessary to prevent crash

  # Names of the entries of A that may be handed to bart2()
  bart.arg.names <- setdiff(c(names(formals(dbarts::bart2)),
                              names(formals(dbarts::dbartsControl))),
                            c("offset.test", "weights", "subset", "test"))

  fit.list <- make_list(levels(treat))

  for (lvl in levels(treat)) {
    complement.only <- treat.type == "binary" && lvl == last(levels(treat))

    if (!complement.only) {
      # "level vs. rest" model for this level
      A[["data"]] <- as.integer(treat == lvl)
      fit.list[[lvl]] <- do.call(dbarts::bart2, A[names(A) %in% bart.arg.names],
                                 quote = TRUE)
      ps[[lvl]] <- fitted(fit.list[[lvl]])
    }
    else {
      # Binary treatment: the last level is the complement of the first, and
      # the single fitted model is returned bare
      ps[[lvl]] <- 1 - ps[[1]]
      fit.list <- fit.list[[1]]
    }
  }

  #ps should be matrix of probs for each treat
  #Computing weights
  w <- get_w_from_ps(ps = ps, treat = treat, estimand, focal, stabilize = stabilize, subclass = subclass)

  p.score <- NULL
  if (treat.type == "binary") {
    p.score <- ps[[get_treated_level(treat)]]
  }

  list(w = w, ps = p.score, info = list(), fit.obj = fit.list)
}
# Generalized propensity score weights for a continuous treatment, with the
# conditional mean of the treatment estimated by dbarts::bart2().
#
# covs, treat and s.weights are restricted to `subset`. Sampling weights are
# not supported: unless all_the_same(s.weights) holds, an error is raised.
# `...` may supply
#   - density settings: `use.kernel`, `bw`, `adjust`, `kernel`, `n`, `plot`,
#     or `density` (a function, or a single string "name_arg1_arg2..." naming
#     a density function followed by numeric arguments);
#   - arguments matching formals of bart2() or dbartsControl() (other than
#     offset.test, weights, subset and test), which are passed to bart2().
# Returns list(w, info, fit.obj); `info` is an empty list.
weightit2bart.cont <- function(covs, treat, s.weights, subset, stabilize, missing, ps, ...) {
  A <- list(...)
  check.package("dbarts")

  covs <- covs[subset, , drop = FALSE]
  treat <- treat[subset]
  s.weights <- s.weights[subset]

  if (!all_the_same(s.weights)) stop("Sampling weights cannot be used with method = \"bart\".",
                                     call. = FALSE)

  if (missing == "ind") {
    missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  for (i in seq_col(covs)) covs[,i] <- make.closer.to.1(covs[,i])

  #Process density params
  if (isTRUE(A[["use.kernel"]])) {
    if (is_null(A[["bw"]])) A[["bw"]] <- "nrd0"
    if (is_null(A[["adjust"]])) A[["adjust"]] <- 1
    if (is_null(A[["kernel"]])) A[["kernel"]] <- "gaussian"
    if (is_null(A[["n"]])) A[["n"]] <- 10*length(treat)
    use.kernel <- TRUE
    densfun <- NULL
  }
  else {
    if (is_null(A[["density"]])) densfun <- dnorm
    else if (is.function(A[["density"]])) densfun <- A[["density"]]
    # The length test must be on the argument itself: `length(x == 1)` is just
    # length(x), which let multi-element character vectors through.
    else if (is.character(A[["density"]]) && length(A[["density"]]) == 1) {
      splitdens <- strsplit(A[["density"]], "_", fixed = TRUE)[[1]]
      if (exists(splitdens[1], mode = "function", envir = parent.frame())) {
        if (length(splitdens) > 1 && !can_str2num(splitdens[-1])) {
          stop(paste(A[["density"]], "is not an appropriate argument to 'density' because",
                     word_list(splitdens[-1], and.or = "or", quotes = TRUE), "cannot be coerced to numeric."), call. = FALSE)
        }
        densfun <- function(x) {
          # mode = "function" mirrors the exists() check above, so a
          # non-function object of the same name is never picked up
          tryCatch(do.call(get(splitdens[1], mode = "function"), c(list(x), as.list(str2num(splitdens[-1])))),
                   error = function(e) stop(paste0("Error in applying density:\n  ", conditionMessage(e)), call. = FALSE))
        }
      }
      else {
        stop(paste(A[["density"]], "is not an appropriate argument to 'density' because",
                   splitdens[1], "is not an available function."), call. = FALSE)
      }
    }
    else stop("The argument to 'density' cannot be evaluated as a density function.", call. = FALSE)
    use.kernel <- FALSE
  }

  #Stabilization - get dens.num
  p.num <- treat - mean(treat)

  if (use.kernel) {
    d.n <- density(p.num, n = A[["n"]],
                   weights = s.weights/sum(s.weights), give.Rkern = FALSE,
                   bw = A[["bw"]], adjust = A[["adjust"]], kernel = A[["kernel"]])
    dens.num <- with(d.n, approxfun(x = x, y = y))(p.num)
  }
  else {
    dens.num <- densfun(p.num/sd(treat))
    if (is_null(dens.num) || !is.atomic(dens.num) || anyNA(dens.num)) {
      stop("There was a problem with the output of density. Try another density function or leave it blank to use the normal density.", call. = FALSE)
    }
    else if (any(dens.num <= 0)) {
      stop("The input to density may not accept the full range of treatment values.", call. = FALSE)
    }
  }

  # Fixed bart2() settings; these override anything passed through ...
  A[["formula"]] <- covs
  A[["data"]] <- treat
  A[["keepCall"]] <- FALSE
  A[["combineChains"]] <- TRUE
  A[["verbose"]] <- FALSE #necessary to prevent crash

  #Estimate GPS
  fit <- do.call(dbarts::bart2, A[names(A) %in% setdiff(c(names(formals(dbarts::bart2)),
                                                          names(formals(dbarts::dbartsControl))),
                                                        c("offset.test", "weights", "subset", "test"))],
                 quote = TRUE)

  gp.score <- fitted(fit)

  #Get weights
  w <- get_cont_weights(gp.score, treat = treat, s.weights = s.weights,
                        dens.num = dens.num, densfun = densfun,
                        use.kernel = use.kernel, densControl = A)

  if (use.kernel && isTRUE(A[["plot"]])) {
    d.d <- density(treat - gp.score, n = A[["n"]],
                   weights = s.weights/sum(s.weights), give.Rkern = FALSE,
                   bw = A[["bw"]], adjust = A[["adjust"]],
                   kernel = A[["kernel"]])
    plot_density(d.n, d.d)
  }

  info <- list()

  obj <- list(w = w, info = info, fit.obj = fit)
  return(obj)
}
#Energy balancing
# Estimate energy balancing weights by solving a quadratic program with
# osqp::solve_osqp().
#
# Arguments:
#   covs, treat, s.weights  Covariates, treatment, and sampling weights for all
#                           units; restricted to `subset` after the distance
#                           matrix has been obtained.
#   subset                  Logical index of the units to use.
#   estimand                "ATE" targets the full sample; any other value
#                           targets the `focal` treatment group.
#   focal                   Focal treatment level (used when estimand != "ATE").
#   missing                 If "ind", missingness-indicator columns are added
#                           for covariates with NAs and the NAs are set to 0.
#   moments, int            If `moments != 0` or `int` is TRUE, the covariates
#                           and the terms from int.poly.f() are balanced
#                           exactly through added equality constraints.
#   ...                     `dist.mat` (name of a distance method, a "dist"
#                           object, or a square numeric distance matrix),
#                           `min.w`, `lambda`, `improved`, `eps`, and any
#                           argument of osqp::osqpSettings().
#
# Returns a list with `w` (weights) and `fit.obj` (the solver output). Warns
# if the solver reports that the maximum number of iterations was reached.
weightit2energy <- function(covs, treat, s.weights, subset, estimand, focal, missing, moments, int, ...) {
  check.package("osqp")

  A <- list(...)

  if (missing == "ind") {
    missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  dist.mat <- if_null_then(A[["dist.mat"]], "scaled_euclidean")
  A[["dist.mat"]] <- NULL

  if (is.character(dist.mat) && length(dist.mat) == 1L) {
    #A method name: the covariates are transformed by the named method and
    #pairwise distances are computed on the transformed covariates
    dist.covs <- transform_covariates(data = covs, method = dist.mat,
                                      s.weights = s.weights, discarded = !subset)
    #NOTE(review): unlike the matrix branch below, `d` is not indexed by
    #`subset` here - confirm that the helpers already drop discarded units
    d <- unname(eucdist_internal(dist.covs))
  }
  else {
    if (inherits(dist.mat, "dist")) dist.mat <- as.matrix(dist.mat)

    #Check type and missingness first so that the remaining comparisons can
    #never evaluate to NA inside `||`
    if (!is.matrix(dist.mat) || !is.numeric(dist.mat) || anyNA(dist.mat) ||
        !all(dim(dist.mat) == length(treat)) ||
        !all(check_if_zero(diag(dist.mat))) || any(dist.mat < 0) ||
        !isSymmetric(unname(dist.mat))) {
      stop("'dist.mat' must be one of \"mahalanobis\", \"scaled_euclidean\", or \"euclidean\" or a square, symmetric distance matrix with a value for all pairs of units.", call. = FALSE)
    }

    d <- unname(dist.mat[subset, subset])
  }

  covs <- covs[subset, , drop = FALSE]
  treat <- factor(treat[subset])
  s.weights <- s.weights[subset]

  n <- length(treat)
  levels_treat <- levels(treat)
  diagn <- diag(n)

  min.w <- if_null_then(A[["min.w"]], 1e-8)
  if (!is.numeric(min.w) || length(min.w) != 1) {
    warning("'min.w' must be a single number. Setting min.w = 1e-8.", call. = FALSE, immediate. = TRUE)
    min.w <- 1e-8
  }

  #Rescale sampling weights to have mean 1 within each treatment group
  for (t in levels_treat) s.weights[treat == t] <- s.weights[treat == t]/mean(s.weights[treat == t])

  #tmat: one logical column per treatment level; nt: group sizes
  tmat <- vapply(levels_treat, function(t) treat == t, logical(n))
  nt <- colSums(tmat)

  #J[[t]]: sampling weights of group t divided by its size, 0 for other units
  J <- setNames(lapply(levels_treat, function(t) s.weights*tmat[,t]/nt[t]), levels_treat)

  if (estimand == "ATE") {
    J0 <- as.matrix(s.weights/n)

    #Quadratic (M2) and linear (M1) terms, summed over treatment levels
    M2_array <- vapply(levels_treat, function(t) -2 * tcrossprod(J[[t]]) * d, diagn)
    M1_array <- vapply(levels_treat, function(t) 2 * J[[t]] * d %*% J0, J0)

    M2 <- rowSums(M2_array, dims = 2)
    M1 <- rowSums(M1_array)

    #Unless improved = FALSE, add a term for every pair of treatment levels
    if (!isFALSE(A[["improved"]])) {
      all_pairs <- combn(levels_treat, 2, simplify = FALSE)
      M2_pairs_array <- vapply(all_pairs, function(p) -2 * tcrossprod(J[[p[1]]]-J[[p[2]]]) * d, diagn)
      M2 <- M2 + rowSums(M2_pairs_array, dims = 2)
    }

    #Constraints for positivity and sum of weights
    #First n rows bound each weight; the remaining rows fix the
    #s.weights-weighted sum of the weights in each group at the group size
    Amat <- rbind(diagn, t(s.weights * tmat))
    lvec <- c(rep(min.w, n), nt)
    uvec <- c(ifelse(check_if_zero(s.weights), min.w, Inf), nt)
  }
  else {
    J0_focal <- as.matrix(J[[focal]])
    clevs <- levels_treat[levels_treat != focal]

    #Only the non-focal groups contribute to the objective
    M2_array <- vapply(clevs, function(t) -2 * tcrossprod(J[[t]]) * d, diagn)
    M1_array <- vapply(clevs, function(t) 2 * J[[t]] * d %*% J0_focal, J0_focal)

    M2 <- rowSums(M2_array, dims = 2)
    M1 <- rowSums(M1_array)

    #Constraints for positivity and sum of weights
    #NOTE(review): ifelse_() appears to be a chained ifelse (test, value,
    #test, value, ..., default), which would pin focal units at 1 - confirm
    Amat <- rbind(diagn, t(s.weights*tmat))
    lvec <- c(ifelse_(check_if_zero(s.weights), min.w, treat == focal, 1, min.w), nt)
    uvec <- c(ifelse_(check_if_zero(s.weights), min.w, treat == focal, 1, Inf), nt)
  }

  #Add weight penalty
  if (is_not_null(A[["lambda"]])) diag(M2) <- diag(M2) + A[["lambda"]] / n^2

  if (moments != 0 || int) {
    #Exactly balance moments and/or interactions
    covs <- cbind(covs, int.poly.f(covs, poly = moments, int = int))

    if (estimand == "ATE") targets <- col.w.m(covs, s.weights)
    else targets <- col.w.m(covs[treat == focal, , drop = FALSE], s.weights[treat == focal])

    #One set of equality rows (lower = upper = targets) per non-focal
    #treatment level; the NULL returned for the focal level drops out of
    #rbind()/c()
    Amat <- do.call("rbind", c(list(Amat),
                               lapply(levels_treat, function(t) {
                                 if (is_null(focal) || t != focal) t(covs * J[[t]])
                               })))
    lvec <- do.call("c", c(list(lvec),
                           lapply(levels_treat, function(t) {
                             if (is_null(focal) || t != focal) targets
                           })))
    uvec <- do.call("c", c(list(uvec),
                           lapply(levels_treat, function(t) {
                             if (is_null(focal) || t != focal) targets
                           })))
  }

  #`eps` is a shortcut that fills in both tolerances when they are not given
  if (is_not_null(A[["eps"]])) {
    if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- A[["eps"]]
    if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- A[["eps"]]
  }

  #Keep only arguments understood by osqpSettings(), then apply defaults
  A[names(A) %nin% names(formals(osqp::osqpSettings))] <- NULL
  if (is_null(A[["max_iter"]])) A[["max_iter"]] <- 2E3L
  if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- 1E-8
  if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- 1E-8
  #Always overrides a user-supplied `verbose`
  #NOTE(review): presumably the caller controls whether output is shown - confirm
  A[["verbose"]] <- TRUE

  options.list <- do.call(osqp::osqpSettings, A)

  opt.out <- do.call(osqp::solve_osqp, list(P = M2, q = M1, A = Amat, l = lvec, u = uvec,
                                            pars = options.list),
                     quote = TRUE)

  if (identical(opt.out$info$status, "maximum iterations reached")) {
    warning("The optimization failed to converge. See Notes section at ?method_energy for information.", call. = FALSE)
  }

  w <- opt.out$x

  #Reset focal weights to exactly 1 and floor all weights at min.w
  if (estimand == "ATT") w[treat == focal] <- 1

  w[w <= min.w] <- min.w

  obj <- list(w = w, fit.obj = opt.out)
  return(obj)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.