#Turn a vector into a 0/1 vector. 'zero' and 'one' can be supplied to make it clear which is
#which; otherwise, a guess is used.
binarize <- function(variable, zero = NULL, one = NULL) {
var.name <- deparse1(substitute(variable))
if (length(unique(variable)) > 2) {
stop(sprintf("Cannot binarize %s: more than two levels.", var.name), call. = FALSE)
}
if (is.character(variable) || is.factor(variable)) {
variable <- factor(variable, nmax = 2)
unique.vals <- levels(variable)
}
else {
unique.vals <- unique(variable, nmax = 2)
}
if (is.null(zero)) {
if (is.null(one)) {
if (can_str2num(unique.vals)) {
variable.numeric <- str2num(variable)
}
else {
variable.numeric <- as.numeric(variable)
}
if (0 %in% variable.numeric) zero <- 0
else zero <- min(variable.numeric, na.rm = TRUE)
out <- setNames(as.integer(variable.numeric != zero), names(variable))
}
else {
if (one %in% unique.vals) out <- setNames(as.integer(variable == one), names(variable))
else stop("The argument to 'one' is not the name of a level of variable.", call. = FALSE)
}
}
else {
if (zero %in% unique.vals) out <- setNames(as.integer(variable != zero), names(variable))
else stop("The argument to 'zero' is not the name of a level of variable.", call. = FALSE)
}
return(out)
}
#Get covariates (RHS) vars from formula
get.covs.matrix <- function(formula = NULL, data = NULL) {
if (is.null(formula)) {
fnames <- colnames(data)
fnames[!startsWith(fnames, "`")] <- paste0("`", fnames[!startsWith(fnames, "`")], "`")
formula <- reformulate(fnames)
}
else formula <- update(terms(formula, data = data), NULL ~ . + 1)
mf <- model.frame(terms(formula, data = data), data,
na.action = na.pass)
chars.in.mf <- vapply(mf, is.character, logical(1L))
mf[chars.in.mf] <- lapply(mf[chars.in.mf], factor)
X <- model.matrix(formula, data = mf,
contrasts.arg = lapply(Filter(is.factor, mf),
contrasts, contrasts = FALSE))
assign <- attr(X, "assign")[-1]
X <- X[,-1, drop = FALSE]
attr(X, "assign") <- assign
return(X)
}
#Add missing indicators (is.na(x)) to formula
add_miss_ind_to_formula <- function(formula, data = NULL) {
mf <- model.frame(delete.response(terms(formula, data = data)),
data = data, na.action = "na.pass")
missing_vars <- names(mf)[vapply(mf, anyNA, logical(1L))]
s <- do.call("paste",
c(list(". ~ ."),
lapply(missing_vars,
sprintf,
fmt = "+ is.na(%s)")))
update(formula, s)
}
#treat.type processing
assign_treat_type <- function(treat, use.multi = FALSE) {
#Returns treat with treat.type attribute
nunique.treat <- nunique(treat)
if (nunique.treat < 2) {
stop("The treatment must have at least two unique values.", call. = FALSE)
}
else if (!use.multi && nunique.treat == 2) {
treat.type <- "binary"
}
else if (use.multi || is_(treat, c("factor", "character"))) {
treat.type <- "multinomial"
if (!is_(treat, "processed.treat")) treat <- factor(treat)
}
else {
treat.type <- "continuous"
}
attr(treat, "treat.type") <- treat.type
return(treat)
}
get_treat_type <- function(treat) {
return(attr(treat, "treat.type"))
}
has_treat_type <- function(treat) {
is_not_null(get_treat_type(treat))
}
get_treated_level <- function(treat) {
if (!is_binary(treat)) stop("'treat' must be a binary variable.")
if (is.character(treat) || is.factor(treat)) {
treat <- factor(treat, nmax = 2)
unique.vals <- levels(treat)
}
else {
unique.vals <- unique(treat, nmax = 2)
}
if (can_str2num(unique.vals)) {
unique.vals.numeric <- str2num(unique.vals)
}
else {
unique.vals.numeric <- seq_along(unique.vals)
}
if (0 %in% unique.vals.numeric) treated <- unique.vals[unique.vals.numeric != 0]
else treated <- unique.vals[which.max(unique.vals.numeric)]
return(treated)
}
#Converting string to numeric
can_str2num <- function(x) {
if (is.numeric(x) || is.logical(x)) return(TRUE)
nas <- is.na(x)
suppressWarnings(x_num <- as.numeric(as.character(x[!nas])))
return(!anyNA(x_num))
}
str2num <- function(x) {
nas <- is.na(x)
if (!is_(x, c("numeric", "logical"))) x <- as.character(x)
suppressWarnings(x_num <- as.numeric(x))
is.na(x_num[nas]) <- TRUE
return(x_num)
}
#Weighted mean faster than weighted.mean()
w.m <- function(x, w = NULL, na.rm = TRUE) {
if (is.null(w)) {
if (anyNA(x)) {
if (!na.rm) return(NA_real_)
nas <- which(is.na(x))
x <- x[-nas]
}
return(sum(x)/length(x))
}
else {
if (anyNA(x) || anyNA(w)) {
if (!na.rm) return(NA_real_)
nas <- which(is.na(x) | is.na(w))
x <- x[-nas]
w <- w[-nas]
}
return(sum(x*w)/sum(w))
}
}
#Input processing
method.to.proper.method <- function(method) {
method <- tolower(method)
if (method %in% c("ps")) return("ps")
else if (method %in% c("gbm", "gbr")) return("gbm")
else if (method %in% c("cbps", "cbgps")) return("cbps")
else if (method %in% c("npcbps", "npcbgps")) return("npcbps")
else if (method %in% c("entropy", "ebal", "ebalance")) return("ebal")
else if (method %in% c("ebcw", "ate")) return("ebcw")
else if (method %in% c("optweight", "opt", "sbw")) return("optweight")
else if (method %in% c("super", "superlearner")) return("super")
else if (method %in% c("bart")) return("bart")
else if (method %in% c("energy")) return("energy")
# else if (method %in% c("kbal")) return("kbal")
else return(method)
}
check.acceptable.method <- function(method, msm = FALSE, force = FALSE) {
bad.method <- FALSE
acceptable.methods <- c("ps"
, "gbm", "gbr"
, "cbps", "cbgps"
, "npcbps", "npcbgps"
, "ebal", "entropy", "ebalance"
, "sbw"
, "ebcw", "ate"
, "optweight", "opt", "sbw"
, "super", "superlearner"
, "energy"
, "bart"
# "kbal",
)
if (missing(method)) method <- "ps"
else if (is_null(method) || length(method) > 1) bad.method <- TRUE
else if (is.character(method)) {
if (tolower(method) %nin% acceptable.methods) bad.method <- TRUE
}
else if (!is.function(method)) bad.method <- TRUE
if (bad.method) {
if (identical(method, "twang")) stop('"twang" is no longer an acceptable argument to \'method\'. Please use "gmb" for generalized boosted modeling.', call. = FALSE)
stop("'method' must be a string of length 1 containing the name of an acceptable weighting method or a function that produces weights.", call. = FALSE)
}
if (msm && !force && is.character(method)) {
m <- method.to.proper.method(method)
if (m %in% c("nbcbps", "ebal", "ebcw", "optweight", "energy", "kbal")) {
stop(paste0("The use of ", method.to.phrase(m), " with longitudinal treatments has not been validated. Set weightit.force = TRUE to bypass this error."), call. = FALSE)
}
}
}
check.user.method <- function(method) {
#Check to make sure it accepts treat and covs
if (all(c("covs", "treat") %in% names(formals(method)))) {
}
# else if (all(c("covs.list", "treat.list") %in% names(formals(method)))) {
# }
else {
stop("The user-provided function to 'method' must contain \"covs\" and \"treat\" as named parameters.", call. = FALSE)
}
}
method.to.phrase <- function(method) {
if (is.function(method)) return("a user-defined method")
else {
method <- method.to.proper.method(method)
if (method %in% c("ps")) return("propensity score weighting")
else if (method %in% c("gbm")) return("propensity score weighting with GBM")
else if (method %in% c("cbps")) return("covariate balancing propensity score weighting")
else if (method %in% c("npcbps")) return("non-parametric covariate balancing propensity score weighting")
else if (method %in% c("ebal")) return("entropy balancing")
else if (method %in% c("ebcw")) return("empirical balancing calibration weighting")
else if (method %in% c("optweight")) return("targeted stable balancing weights")
else if (method %in% c("super")) return("propensity score weighting with SuperLearner")
else if (method %in% c("bart")) return("propensity score weighting with BART")
else if (method %in% c("energy")) return("energy balancing")
# else if (method %in% c("kbal")) return("kernel balancing")
else return("the chosen method of weighting")
}
}
process.estimand <- function(estimand, method, treat.type) {
#Allowable estimands
AE <- list(
binary = list(ps = c("ATT", "ATC", "ATE", "ATO", "ATM", "ATOS")
, gbm = c("ATT", "ATC", "ATE", "ATO", "ATM")
, cbps = c("ATT", "ATC", "ATE")
, npcbps = c("ATE")
, ebal = c("ATT", "ATC", "ATE")
, ebcw = c("ATT", "ATC", "ATE")
, optweight = c("ATT", "ATC", "ATE")
, super = c("ATT", "ATC", "ATE", "ATO", "ATM")
, energy = c("ATT", "ATC", "ATE")
, bart = c("ATT", "ATC", "ATE", "ATO", "ATM")
# , kbal = c("ATT", "ATC", "ATE")
),
multinomial = list(ps = c("ATT", "ATC", "ATE", "ATO", "ATM")
, gbm = c("ATT", "ATC", "ATE", "ATO", "ATM")
, cbps = c("ATT", "ATC", "ATE")
, npcbps = c("ATE")
, ebal = c("ATT", "ATC", "ATE")
, ebcw = c("ATT", "ATC", "ATE")
, optweight = c("ATT", "ATC", "ATE")
, super = c("ATT", "ATC", "ATE", "ATO", "ATM")
, energy = c("ATT", "ATC", "ATE")
, bart = c("ATT", "ATC", "ATE", "ATO", "ATM")
# , kbal = c("ATT", "ATE")
))
if (treat.type != "continuous" && !is.function(method)) {
if (is_null(estimand)) stop(paste0("estimand must be one of ", word_list(AE[[treat.type]][[method]], quotes = TRUE, and.or = "or"), "."), call. = FALSE)
else if (toupper(estimand) %nin% AE[[treat.type]][[method]]) {
stop(paste0("\"", estimand, "\" is not an allowable estimand for ", method.to.phrase(method),
" with ", treat.type, " treatments. Only ", word_list(AE[[treat.type]][[method]], quotes = TRUE, and.or = "and", is.are = TRUE),
" allowed."), call. = FALSE)
}
}
return(toupper(estimand))
}
check.subclass <- function(method, treat.type) {
#Allowable estimands
AE <- list(
binary = list(ps = TRUE
, gbm = TRUE
, cbps = TRUE
, npcbps = FALSE
, ebal = FALSE
, ebcw = FALSE
, optweight = FALSE
, super = TRUE
, energy = FALSE
, bart = TRUE
# , kbal = FALSE
),
multinomial = list(ps = TRUE
, gbm = TRUE
, cbps = FALSE
, npcbps = FALSE
, ebal = FALSE
, ebcw = FALSE
, optweight = FALSE
, super = TRUE
, energy = FALSE
, bart = TRUE
))
if (treat.type != "continuous" && !is.function(method) &&
!AE[[treat.type]][[method]]) {
stop(paste0("subclasses are not compatible with ", method.to.phrase(method),
" with ", treat.type, " treatments."), call. = FALSE)
}
}
process.ps <- function(ps, data = NULL, treat) {
if (is_not_null(ps)) {
if (is.character(ps) && length(ps)==1) {
if (is_null(data)) {
stop("'ps' was specified as a string but there was no argument to 'data'.", call. = FALSE)
}
else if (utils::hasName(data, ps)) {
ps <- data[[ps]]
}
else stop("The name supplied to 'ps' is not the name of a variable in 'ps'.", call. = FALSE)
}
else if (is.numeric(ps)) {
if (length(ps) != length(treat)) {
stop("'ps' must have the same number of units as the treatment.", call. = FALSE)
}
}
else {
stop("The argument to 'ps' must be a vector of propensity scores or the (quoted) names of the variable in 'data' that contains sampling weights.", call. = FALSE)
}
}
else ps <- NULL
return(ps)
}
process.focal.and.estimand <- function(focal, estimand, treat, treated = NULL) {
reported.estimand <- estimand
if (!has_treat_type(treat)) treat <- assign_treat_type(treat)
treat.type <- get_treat_type(treat)
unique.treat <- unique(treat, nmax = switch(treat.type, "binary" = 2, "multinomial" = length(treat)/4))
#Check focal
if (is_not_null(focal) && (length(focal) > 1L || focal %nin% unique.treat)) {
stop("The argument supplied to 'focal' must be the name of a level of treatment.", call. = FALSE)
}
if (treat.type == "multinomial") {
if (estimand %nin% c("ATT", "ATC") && is_not_null(focal)) {
warning(paste(estimand, "is not compatible with 'focal'. Setting 'estimand' to \"ATT\"."), call. = FALSE, immediate. = TRUE)
reported.estimand <- estimand <- "ATT"
}
if (estimand == "ATT") {
if (is_null(focal)) {
if (is_null(treated) || treated %nin% unique.treat) {
stop("When estimand = \"ATT\" for multinomial treatments, an argument must be supplied to 'focal'.", call. = FALSE)
}
focal <- treated
}
}
else if (estimand == "ATC") {
if (is_null(focal)) {
stop("When estimand = \"ATC\" for multinomial treatments, an argument must be supplied to 'focal'.", call. = FALSE)
}
}
}
else if (treat.type == "binary") {
unique.treat.bin <- unique(binarize(treat), nmax = 2)
if (estimand %nin% c("ATT", "ATC") && is_not_null(focal)) {
warning(paste(estimand, "is not compatible with 'focal'. Setting 'estimand' to \"ATT\"."), call. = FALSE, immediate. = TRUE)
reported.estimand <- estimand <- "ATT"
}
if (is_null(treated) || treated %nin% unique.treat) {
if (is_null(focal)) {
if (all(as.character(unique.treat.bin) == as.character(unique.treat))) {
treated <- unique.treat[unique.treat.bin == 1]
}
else {
if (is.factor(treat)) treated <- levels(treat)[2]
else treated <- unique.treat[unique.treat.bin == 1]
if (estimand == "ATT") {
message(paste0("Assuming ", word_list(treated, quotes = !is.numeric(treat), is.are = TRUE),
" the treated level. If not, supply an argument to 'focal'."))
}
else if (estimand == "ATC") {
message(paste0("Assuming ", word_list(unique.treat[unique.treat %nin% treated], quotes = !is.numeric(treat), is.are = TRUE),
" the control level. If not, supply an argument to 'focal'."))
}
}
if (estimand == "ATT")
focal <- treated
else if (estimand == "ATC")
focal <- unique.treat[unique.treat %nin% treated]
}
else {
if (estimand == "ATT")
treated <- focal
else if (estimand == "ATC")
treated <- unique.treat[unique.treat %nin% focal]
}
if (estimand == "ATC") estimand <- "ATT"
}
else {
if (is_null(focal)) {
if (estimand == "ATT")
focal <- treated
else if (estimand == "ATC")
focal <- unique.treat[unique.treat %nin% treated]
}
if (estimand == "ATC") estimand <- "ATT"
}
}
return(list(focal = as.character(focal),
estimand = estimand,
reported.estimand = reported.estimand,
treated = if (is.factor(treated)) as.character(treated) else treated))
}
process.by <- function(by, data, treat, treat.name = NULL, by.arg = "by") {
##Process by
bad.by <- FALSE
n <- length(treat)
if (!has_treat_type(treat)) treat <- assign_treat_type(treat)
treat.type <- get_treat_type(treat)
if (missing(by)) {
bad.by <- TRUE
}
else if (is_null(by)) {
by <- NULL
by.name <- NULL
}
else if (is.character(by) && length(by) == 1 && by %in% names(data)) {
by.name <- by
by <- data[[by]]
}
else if (length(dim(by)) == 2L && len(by) == n) {
by.name <- colnames(by)[1]
by <- drop(by[, 1])
}
else if (is.formula(by, 1)) {
t.c <- get_covs_and_treat_from_formula(by, data)
by <- t.c[["reported.covs"]]
if (NCOL(by) != 1) stop(paste0("Only one variable can be on the right-hand side of the formula for '", by.arg, "'."), call. = FALSE)
else by.name <- colnames(by)
}
else bad.by <- TRUE
if (!bad.by) {
by.components <- data.frame(by)
if (is_not_null(colnames(by))) names(by.components) <- colnames(by)
else names(by.components) <- by.name
if (is_null(by)) by.factor <- factor(rep(1L, n), levels = 1L)
else by.factor <- factor(by.components[[1]], levels = unique(by.components[[1]]),
labels = paste(names(by.components), "=", unique(by.components[[1]])))
# by.vars <- acceptable.bys[vapply(acceptable.bys, function(x) equivalent.factors(by, data[[x]]), logical(1L))]
}
else {
stop(paste0("'",by.arg, "' must be a string containing the name of the variable in data for which weighting is to occur within strata or a one-sided formula with the stratifying variable on the right-hand side."), call. = FALSE)
}
if (treat.type != "continuous" && any(vapply(levels(by.factor), function(x) nunique(treat) != nunique(treat[by.factor == x]), logical(1L)))) {
stop(paste0("Not all the groups formed by '", by.arg, "' contain all treatment levels", if (is_not_null(treat.name)) paste("in", treat.name) else "", ". Consider coarsening ", by.arg, "."), call. = FALSE)
}
attr(by.components, "by.factor") <- by.factor
return(by.components)
}
process.moments.int <- function(moments, int, method) {
if (is.function(method) || method %in% c("npcbps", "ebal", "ebcw", "optweight", "energy")) {
if (length(int) != 1 || !is.logical(int)) {
stop("int must be a logical (TRUE/FALSE) of length 1.", call. = FALSE)
}
if (is_not_null(moments)) {
if (length(moments) != 1 || !is.numeric(moments) ||
!check_if_zero(moments - round(moments))) {
if (method == "energy") {
if (moments < 0) stop("'moments' must be a nonnegative integer of length 1.", call. = FALSE)
}
else if (method %in% c("npcbps", "ebal", "ebcw", "optweight")) {
if (moments < 1) stop("'moments' must be a positive integer of length 1.", call. = FALSE)
}
moments <- as.integer(moments)
}
}
else {
if (!is.function(method) && method == "energy") moments <- 0L
else moments <- 1L
}
}
else if (is_not_null(moments) && any(mi0 <- c(as.integer(moments) != 1L, int))) {
warning(paste0(word_list(c("moments", "int")[mi0], and.or = "and", is.are = TRUE, quotes = 1),
" not compatible with ", method.to.phrase(method), ". Ignoring ", word_list(c("moments", "int")[mi0], and.or = "and", quotes = 1), "."), call. = FALSE, immediate. = TRUE)
moments <- NULL
int <- FALSE
}
moments <- as.integer(moments)
return(list(moments = moments, int = int))
}
process.MSM.method <- function(is.MSM.method, method) {
methods.with.MSM <- c("optweight")
if (is.function(method)) {
if (isTRUE(is.MSM.method)) stop("Currently, only user-defined methods that work with is.MSM.method = FALSE are allowed.", call. = FALSE)
is.MSM.method <- FALSE
}
else if (method %in% methods.with.MSM) {
if (is_null(is.MSM.method)) is.MSM.method <- TRUE
else if (!isTRUE(is.MSM.method)) {
message(paste0(method.to.phrase(method), " can be used with a single model when multiple time points are present.\nUsing a seperate model for each time point. To use a single model, set 'is.MSM.method' to TRUE."))
}
}
else {
if (isTRUE(is.MSM.method)) warning(paste0(method.to.phrase(method), " cannot be used with a single model when multiple time points are present.\nUsing a seperate model for each time point."),
call. = FALSE, immediate. = TRUE)
is.MSM.method <- FALSE
}
return(is.MSM.method)
}
process.missing <- function(missing, method, treat.type) {
#Allowable estimands
AE <- list(binary = list(ps = c("ind"
, "saem"
)
, gbm = c("ind", "surr")
, cbps = c("ind")
, npcbps = c("ind")
, ebal = c("ind")
, ebcw = c("ind")
, optweight = c("ind")
, super = c("ind")
, bart = c("ind")
, energy = c("ind")
# , kbal = c("ind")
),
multinomial = list(ps = c("ind")
, gbm = c("ind", "surr")
, cbps = c("ind")
, npcbps = c("ind")
, ebal = c("ind")
, ebcw = c("ind")
, optweight = c("ind")
, super = c("ind")
, bart = c("ind")
, energy = c("ind")
# , kbal = c("ind")
),
continuous = list(ps = c("ind"
, "saem"
)
, gbm = c("ind", "surr")
, cbps = c("ind")
, npcbps = c("ind")
, ebal = c("ind")
, ebcw = c("ind")
, optweight = c("ind")
, super = c("ind")
, bart = c("ind")
, energy = c("ind")
# , kbal = c("ind")
))
allowable.missings <- AE[[treat.type]][[method]]
if (is_null(missing)) {
missing <- allowable.missings[1]
warning(paste0("Missing values are present in the covariates. See ?WeightIt::method_",
method, " for information on how these are handled."), call. = FALSE, immediate. = TRUE)
}
else {
if (!is.character(missing) || length(missing) != 1) stop("'missing' must be a string of length 1.", call. = FALSE)
if (missing %pin% allowable.missings) {
missing <- allowable.missings[pmatch(missing, allowable.missings)]
}
else {
missing <- allowable.missings[1]
warning(paste0("Only ", word_list(allowable.missings, quotes = 2, is.are = TRUE), " allowed for 'missing' with ",
treat.type,
" treatments. Using link = ", word_list(allowable.missings[1], quotes = 2), "."),
call. = FALSE, immediate. = TRUE)
}
}
return(missing)
}
process.bin.vars <- function(bin.vars, mat) {
if (missing(bin.vars)) bin.vars <- is_binary_col(mat)
else if (is_null(bin.vars)) bin.vars <- rep(FALSE, ncol(mat))
else {
if (is.logical(bin.vars)) {
bin.vars[is.na(bin.vars)] <- FALSE
if (length(bin.vars) != ncol(mat)) stop("If 'bin.vars' is logical, it must have length equal to the number of columns of 'mat'.")
}
else if (is.numeric(bin.vars)) {
bin.vars <- bin.vars[!is.na(bin.vars) & bin.vars != 0]
if (any(bin.vars < 0) && any(bin.vars > 0)) stop("Positive and negative indices cannot be mixed with 'bin.vars'.")
if (any(abs(bin.vars) > ncol(mat))) stop("If 'bin.vars' is numeric, none of its values can exceed the number of columns of 'mat'.")
logical.bin.vars <- rep(any(bin.vars < 0), ncol(mat))
logical.bin.vars[abs(bin.vars)] <- !logical.bin.vars[abs(bin.vars)]
bin.vars <- logical.bin.vars
}
else if (is.character(bin.vars)) {
bin.vars <- bin.vars[!is.na(bin.vars) & bin.vars != ""]
if (is_null(colnames(mat))) stop("If 'bin.vars' is character, 'mat' must have column names.")
if (any(bin.vars %nin% colnames(mat))) stop("If 'bin.vars' is character, all its values must be column names of 'mat'.")
bin.vars <- colnames(mat) %in% bin.vars
}
else stop("'bin.vars' must be a logical, numeric, or character vector.")
}
return(bin.vars)
}
int.poly.f <- function(d, ex = NULL, int = FALSE, poly = 1, center = TRUE, orthogonal_poly = TRUE) {
#Adds to data frame interactions and polynomial terms
#d=matrix input
#ex=names of variables to exclude in interactions and polynomials; a subset of df
#int=whether to include interactions or not; currently only 2-way are supported
#poly=degree of polynomials to include; will also include all below poly. If 1, no polynomial will be included
if (is_null(ex)) ex <- rep(FALSE, ncol(d))
binary.vars <- is_binary_col(d)
if (center) {
d[,!binary.vars] <- center(d[, !binary.vars, drop = FALSE])
}
nd <- NCOL(d)
if (poly > 1) {
make.poly <- which(!binary.vars & !ex)
npol <- length(make.poly)
poly_terms <- poly_co.names <- make_list(npol)
if (npol > 0) {
for (i in seq_along(make.poly)) {
poly_terms[[i]] <- poly(d[, make.poly[i]], degree = poly, raw = !orthogonal_poly, simple = TRUE)[,-1, drop = FALSE]
poly_co.names[[i]] <- paste0(if (orthogonal_poly) "orth_", colnames(d)[make.poly[i]], num_to_superscript(2:poly))
}
}
}
else poly_terms <- poly_co.names <- list()
if (int && nd > 1) {
int_terms <- int_co.names <- make_list(1)
ints_to_make <- combn(colnames(d)[!ex], 2, simplify = FALSE)
if (is_not_null(ints_to_make)) {
int_terms[[1]] <- do.call("cbind", lapply(ints_to_make, function(i) d[,i[1]]*d[,i[2]]))
int_co.names[[1]] <- vapply(ints_to_make, paste, character(1L), collapse = " * ")
}
}
else int_terms <- int_co.names <- list()
if (is_not_null(poly_terms) || is_not_null(int_terms)) {
out <- do.call("cbind", c(poly_terms, int_terms))
out_co.names <- c(do.call("c", poly_co.names), do.call("c", int_co.names))
colnames(out) <- unlist(out_co.names)
#Remove single values
if (is_not_null(out)) {
single_value <- apply(out, 2, all_the_same)
out <- out[, !single_value, drop = FALSE]
}
}
else out <- NULL
return(out)
}
get.s.d.denom.weightit <- function(s.d.denom = NULL, estimand = NULL, weights = NULL, treat = NULL, focal = NULL) {
check.estimand <- check.weights <- check.focal <- FALSE
s.d.denom.specified <- is_not_null(s.d.denom)
estimand.specified <- is_not_null(estimand)
if (!is.factor(treat)) treat <- factor(treat)
if (s.d.denom.specified) {
allowable.s.d.denoms <- c("treated", "control", "pooled", "all", "weighted", "hedges")
try.s.d.denom <- tryCatch(match_arg(s.d.denom, allowable.s.d.denoms),
error = function(cond) NA_character_)
if (anyNA(try.s.d.denom)) {
check.estimand <- TRUE
}
else {
s.d.denom <- try.s.d.denom
}
}
else {
check.estimand <- TRUE
}
if (check.estimand) {
if (estimand.specified) {
allowable.estimands <- c("ATT", "ATC", "ATE", "ATO", "ATM")
try.estimand <- tryCatch(match_arg(toupper(estimand), allowable.estimands),
error = function(cond) NA_character_)
if (anyNA(try.estimand) || try.estimand %in% c("ATC", "ATT")) {
check.focal <- TRUE
}
else {
s.d.denom <- vapply(try.estimand, switch, FUN.VALUE = character(1L),
ATO = "weighted", ATM = "weighted", "pooled")
}
}
else {
check.focal <- TRUE
}
}
if (check.focal) {
if (is_not_null(focal)) {
s.d.denom <- focal
}
else check.weights <- TRUE
}
if (check.weights) {
if (is_null(weights)) {
s.d.denom <- "pooled"
}
else {
for (tv in levels(treat)) {
if (all_the_same(weights[treat == tv]) &&
!all_the_same(weights[treat != tv])) {
s.d.denom <- tv
}
else if (tv == last(levels(treat))) {
s.d.denom <- "pooled"
}
}
}
}
return(s.d.denom)
}
get.s.d.denom.cont.weightit <- function(s.d.denom = NULL) {
s.d.denom.specified <- is_not_null(s.d.denom)
if (s.d.denom.specified) {
allowable.s.d.denoms <- c("all", "weighted")
try.s.d.denom <- tryCatch(match_arg(s.d.denom, allowable.s.d.denoms),
error = function(cond) NA_character_)
if (anyNA(try.s.d.denom)) {
s.d.denom <- "all"
}
else {
s.d.denom <- try.s.d.denom
}
}
else {
s.d.denom <- "all"
}
return(s.d.denom)
}
compute_s.d.denom <- function(mat, treat, s.d.denom = "pooled", s.weights = NULL, bin.vars = NULL, subset = NULL, weighted.weights = NULL, to.sd = rep(TRUE, ncol(mat)), na.rm = TRUE) {
denoms <- setNames(rep(1, ncol(mat)), colnames(mat))
if (is.character(s.d.denom) && length(s.d.denom) == 1L) {
if (is_null(bin.vars)) {
bin.vars <- rep(FALSE, ncol(mat))
bin.vars[to.sd] <- is_binary_col(mat[subset, to.sd,drop = FALSE])
}
else if (!is.atomic(bin.vars) || length(bin.vars) != ncol(mat) ||
anyNA(as.logical(bin.vars))) {
stop("'bin.vars' must be a logical vector with length equal to the number of columns of 'mat'.")
}
possibly.supplied <- c("mat", "treat", "weighted.weights", "s.weights", "subset")
lengths <- setNames(vapply(mget(possibly.supplied), len, integer(1L)),
possibly.supplied)
supplied <- lengths > 0
if (!all_the_same(lengths[supplied])) {
stop(paste(word_list(possibly.supplied[supplied], quotes = 1), "must have the same number of units."))
}
if (lengths["weighted.weights"] == 0) weighted.weights <- rep(1, NROW(mat))
if (lengths["s.weights"] == 0) s.weights <- rep(1, NROW(mat))
if (lengths["subset"] == 0) subset <- rep(TRUE, NROW(mat))
else if (anyNA(as.logical(subset))) stop("'subset' must be a logical vector.")
if (!has_treat_type(treat)) treat <- assign_treat_type(treat)
cont.treat <- get_treat_type(treat) == "continuous"
if (!cont.treat) {
treat <- as.character(treat)
unique.treats <- unique(treat)
}
else unique.treats <- NULL
if (s.d.denom %in% unique.treats)
denom.fun <- function(mat, treat, s.weights, weighted.weights, bin.vars,
unique.treats, na.rm) {
sqrt(col.w.v(mat[treat == s.d.denom, , drop = FALSE],
w = s.weights[treat == s.d.denom],
bin.vars = bin.vars, na.rm = na.rm))
}
else if (s.d.denom == "pooled")
denom.fun <- function(mat, treat, s.weights, weighted.weights, bin.vars,
unique.treats, na.rm) {
sqrt(Reduce("+", lapply(unique.treats,
function(t) col.w.v(mat[treat == t, , drop = FALSE],
w = s.weights[treat == t],
bin.vars = bin.vars, na.rm = na.rm))) / length(unique.treats))
}
else if (s.d.denom == "all")
denom.fun <- function(mat, treat, s.weights, weighted.weights, bin.vars,
unique.treats, na.rm) {
sqrt(col.w.v(mat, w = s.weights, bin.vars = bin.vars, na.rm = na.rm))
}
else if (s.d.denom == "weighted")
denom.fun <- function(mat, treat, s.weights, weighted.weights, bin.vars,
unique.treats, na.rm) {
sqrt(col.w.v(mat, w = weighted.weights * s.weights, bin.vars = bin.vars, na.rm = na.rm))
}
else if (s.d.denom == "hedges")
denom.fun <- function(mat, treat, s.weights, weighted.weights, bin.vars,
unique.treats, na.rm) {
(1 - 3/(4*length(treat) - 9))^-1 * sqrt(Reduce("+", lapply(unique.treats,
function(t) (sum(treat == t) - 1) * col.w.v(mat[treat == t, , drop = FALSE],
w = s.weights[treat == t],
bin.vars = bin.vars, na.rm = na.rm))) / (length(treat) - 2))
}
else stop("s.d.denom is not an allowed value.")
denoms[to.sd] <- denom.fun(mat = mat[, to.sd, drop = FALSE], treat = treat, s.weights = s.weights,
weighted.weights = weighted.weights, bin.vars = bin.vars[to.sd],
unique.treats = unique.treats, na.rm = na.rm)
if (any(zero_sds <- check_if_zero(denoms[to.sd]))) {
denoms[to.sd][zero_sds] <- sqrt(col.w.v(mat[, to.sd, drop = FALSE][, zero_sds, drop = FALSE],
w = s.weights,
bin.vars = bin.vars[to.sd][zero_sds], na.rm = na.rm))
}
if (cont.treat) {
treat.sd <- denom.fun(mat = treat, s.weights = s.weights,
weighted.weights = weighted.weights, bin.vars = FALSE,
na.rm = na.rm)
denoms[to.sd] <- denoms[to.sd]*treat.sd
}
}
else {
if (is.numeric(s.d.denom)) {
if (is_not_null(names(s.d.denom)) && any(colnames(mat) %in% names(s.d.denom))) {
denoms[colnames(mat)[colnames(mat) %in% names(s.d.denom)]] <- s.d.denom[names(s.d.denom)[names(s.d.denom) %in% colnames(mat)]]
}
else if (length(s.d.denom) == sum(to.sd)) {
denoms[to.sd] <- s.d.denom
}
else if (length(s.d.denom) == ncol(mat)) {
denoms[] <- s.d.denom
}
else {
stop("'s.d.denom' must be an allowable value or a numeric vector of with length equal to the number of columns of 'mat'. See ?cobalt::col_w_smd for allowable values.")
}
}
else {
stop("'s.d.denom' must be an allowable value or a numeric vector of with length equal to the number of columns of 'mat'. See ?cobalt::col_w_smd for allowable values.")
}
}
return(denoms)
}
check_estimated_weights <- function(w, treat, treat.type, s.weights) {
tw <- w*s.weights
extreme.warn <- FALSE
if (treat.type == "continuous") {
if (all_the_same(w)) {
warning(paste0("All weights are ", w[1], ", possibly indicating an estimation failure."), call. = FALSE)
}
else if (sd(tw, na.rm = TRUE)/mean(tw, na.rm = TRUE) > 4) extreme.warn <- TRUE
}
else {
if (all_the_same(w)) {
warning(paste0("All weights are ", w[1], ", possibly indicating an estimation failure."), call. = FALSE)
}
else {
t.levels <- unique(treat)
bad.treat.groups <- setNames(rep(FALSE, length(t.levels)), t.levels)
for (i in t.levels) {
ti <- which(treat == i)
if (all(is.na(w[ti])) || all(w[ti] == 0)) bad.treat.groups[as.character(i)] <- TRUE
else if (!extreme.warn && sum(!is.na(tw[ti])) > 1 && sd(tw[ti], na.rm = TRUE)/mean(tw[ti], na.rm = TRUE) > 4) extreme.warn <- TRUE
}
if (any(bad.treat.groups)) {
n <- sum(bad.treat.groups)
warning(paste0("All weights are NA or 0 in treatment ", ngettext(n, "group ", "groups "),
word_list(t.levels[bad.treat.groups], quotes = TRUE), "."), call. = FALSE)
}
}
}
if (extreme.warn) warning("Some extreme weights were generated. Examine them with summary() and maybe trim them with trim().", call. = FALSE)
}
ps_to_ps_mat <- function(ps, treat, assumed.treated = NULL, treat.type = NULL, treated = NULL, estimand = NULL) {
if (is_(ps, c("matrix", "data.frame"))) {
ps.names <- rownames(ps)
}
else if (is_(ps, "numeric")) {
ps.names <- names(ps)
ps <- matrix(ps, ncol = 1)
}
if (is.factor(treat)) t.levels <- levels(treat)
else t.levels <- unique(treat, nmax = length(treat)/4)
if (treat.type == "binary") {
if ((is.matrix(ps) && is.numeric(ps)) ||
(is.data.frame(ps) && all(vapply(ps, is.numeric, logical(1L))))) {
if (ncol(ps) == 1) {
if (can_str2num(treat) &&
all(check_if_zero(binarize(treat) - str2num(treat)))) treated.level <- 1
else if (is_not_null(treated)) {
if (treated %in% treat) treated.level <- treated
else stop("The argument to 'treated' must be a value in treat.", call. = FALSE)
}
else if (is_not_null(assumed.treated)) {
treated.level <- assumed.treated
}
else if (is_not_null(colnames(ps)) && colnames(ps) %in% as.character(t.levels)) {
treated.level <- colnames(ps)
}
else {
stop("If the treatment has two non-0/1 levels and 'ps' is a vector or has only one column, an argument to 'treated' must be supplied.", call. = FALSE)
}
t.levels <- c(treated.level, t.levels[t.levels != treated.level])
ps <- matrix(c(ps[,1], 1-ps[,1]), ncol = 2, dimnames = list(ps.names, as.character(t.levels)))
}
else if (ncol(ps) == 2) {
if (!all(as.character(t.levels) %in% colnames(ps))) {
stop("If 'ps' has two columns, they must be named with the treatment levels.", call. = FALSE)
}
# else ps <- ps
}
else {
stop("'ps' cannot have more than two columns if the treatment is binary.", call. = FALSE)
}
}
else {
stop("'ps' must be a matrix, data frame, or vector of propensity scores.", call. = FALSE)
}
}
else if (treat.type == "multinomial") {
if ((is.matrix(ps) && is.numeric(ps)) ||
(is.data.frame(ps) && all(vapply(ps, is.numeric, logical(1L))))) {
if (ncol(ps) == 1) {
if (toupper(estimand) == "ATE") {
ps <- matrix(rep(ps, nunique(treat)), nrow = length(treat), dimnames = list(ps.names, t.levels))
}
else {
stop("With multinomial treatments, 'ps' can be a vector or have only one column only if the estimand is the ATE.", call. = FALSE)
}
}
else if (ncol(ps) == nunique(treat)) {
if (!all(t.levels %in% colnames(ps))) {
stop("The columns of 'ps' must be named with the treatment levels.", call. = FALSE)
}
else ps <- ps
}
else {
stop("'ps' must have as many columns as there are treatment levels.", call. = FALSE)
}
}
else {
stop("'ps' must be a matrix or data frame of propensity scores.", call. = FALSE)
}
}
return(ps)
}
subclass_ps <- function(ps_mat, treat, estimand = "ATE", focal = NULL, subclass) {
if (length(subclass) != 1 || !is.numeric(subclass)) {
stop("'subclass' must be a single number.", call. = FALSE)
}
else if (round(subclass) <= 1) {
stop("'subclass' must be greater than 1.", call. = FALSE)
}
subclass <- round(subclass)
if (is_not_null(focal)) ps_mat <- ps_mat[,c(focal, setdiff(colnames(ps_mat), focal))]
ps_sub <- sub_mat <- ps_mat * 0
for (i in colnames(ps_mat)) {
if (toupper(estimand) == "ATE") {
sub <- as.integer(findInterval(ps_mat[, as.character(i)],
quantile(ps_mat[, as.character(i)],
seq(0, 1, length.out = subclass + 1)),
all.inside = TRUE))
}
else if (toupper(estimand) == "ATT") {
if (i != focal) ps_mat[, as.character(i)] <- 1 - ps_mat[, as.character(i)]
sub <- as.integer(findInterval(ps_mat[, as.character(i)],
quantile(ps_mat[treat == focal, as.character(i)],
seq(0, 1, length.out = subclass + 1)),
all.inside = TRUE))
}
else {
stop("Only the ATE, ATT, and ATC are compatible with stratification weights.")
}
sub.tab <- table(treat, sub)
if (any(sub.tab == 0)) {
# stop("Too many subclasses were requested.", call. = FALSE)
sub <- subclass_scoot(sub, treat, ps_mat[,i])
sub.tab <- table(treat, sub)
}
sub <- as.character(sub)
sub.totals <- colSums(sub.tab)
sub.ps <- setNames(sub.tab[as.character(i), ] / sub.totals,
colnames(sub.tab))
ps_sub[,i] <- sub.ps[sub]
sub_mat[,i] <- sub
if (ncol(ps_sub) == 2) {
ps_sub[,colnames(ps_sub) != i] <- 1 - ps_sub[,i]
sub_mat[,colnames(sub_mat) != i] <- sub
break
}
}
attr(ps_sub, "sub_mat") <- sub_mat
return(ps_sub)
}
subclass_scoot <- function(sub, treat, x, min.n = 1) {
#Reassigns subclasses so there are no empty subclasses
#for each treatment group. min.n is the smallest a
#subclass is allowed to be.
treat <- as.character(treat)
unique.treat <- unique(treat, nmax = 2)
names(x) <- seq_along(x)
names(sub) <- seq_along(sub)
original.order <- names(x)
nsub <- nunique(sub)
#Turn subs into a contiguous sequence
sub <- setNames(setNames(seq_len(nsub), sort(unique(sub)))[as.character(sub)],
original.order)
if (any(table(treat) < nsub * min.n)) {
stop("Too many subclasses were requested.", call. = FALSE)
}
for (t in unique.treat) {
if (length(x[treat == t]) == nsub) {
sub[treat == t] <- seq_len(nsub)
}
}
if (any({sub_tab <- table(treat, sub)} == 0)) {
soft_thresh <- function(x, minus = 1) {
x <- x - minus
x[x < 0] <- 0
x
}
for (t in unique.treat) {
for (n in seq_len(min.n)) {
while (any(sub_tab[t,] == 0)) {
first_0 <- which(sub_tab[t,] == 0)[1]
if (first_0 == nsub ||
(first_0 != 1 &&
sum(soft_thresh(sub_tab[t, seq(1, first_0 - 1)]) / abs(first_0 - seq(1, first_0 - 1))) >=
sum(soft_thresh(sub_tab[t, seq(first_0 + 1, nsub)]) / abs(first_0 - seq(first_0 + 1, nsub))))) {
#If there are more and closer nonzero subs to the left...
first_non0_to_left <- max(seq(1, first_0 - 1)[sub_tab[t, seq(1, first_0 - 1)] > 0])
name_to_move <- names(sub)[which(x == max(x[treat == t & sub == first_non0_to_left]) & treat == t & sub == first_non0_to_left)[1]]
sub[name_to_move] <- first_0
sub_tab[t, first_0] <- 1L
sub_tab[t, first_non0_to_left] <- sub_tab[t, first_non0_to_left] - 1L
}
else {
#If there are more and closer nonzero subs to the right...
first_non0_to_right <- min(seq(first_0 + 1, nsub)[sub_tab[t, seq(first_0 + 1, nsub)] > 0])
name_to_move <- names(sub)[which(x == min(x[treat == t & sub == first_non0_to_right]) & treat == t & sub == first_non0_to_right)[1]]
sub[name_to_move] <- first_0
sub_tab[t, first_0] <- 1L
sub_tab[t, first_non0_to_right] <- sub_tab[t, first_non0_to_right] - 1L
}
}
sub_tab[t,] <- sub_tab[t,] - 1
}
}
#Unsort
sub <- sub[names(sub)]
}
return(sub)
}
stabilize_w <- function(weights, treat) {
if (is.factor(treat)) t.levels <- levels(treat)
else t.levels <- unique(treat)
w.names <- names(weights)
tab <- setNames(vapply(t.levels, function(x) mean_fast(treat == x), numeric(1L)), t.levels)
return(setNames(weights * tab[as.character(treat)], w.names))
}
`%+%` <- function(...) {
if (is_(..1, "atomic") && is_(..2, "atomic")) crayon::`%+%`(as.character(..1), as.character(..2))
else ggplot2::`%+%`(...)
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.