Nothing
#' Subgroup Balancing Propensity Score
#'
#' @description
#' Implements the subgroup balancing propensity score (SBPS), which
#' is an algorithm that attempts to achieve balance in subgroups by sharing
#' information from the overall sample and subgroups (Dong, Zhang, Zeng, & Li,
#' 2020; DZZL). Each subgroup can use either weights estimated using the whole
#' sample, weights estimated using just that subgroup, or a combination of the
#' two. The optimal combination is chosen as that which minimizes an imbalance
#' criterion that includes subgroup as well as overall balance.
#'
#' @param obj a `weightit` object containing weights estimated in the overall
#' sample.
#' @param obj2 a `weightit` object containing weights estimated in the
#' subgroups. Typically this has been estimated by including `by` in the call
#' to [weightit()]. Either `obj2` or `moderator` must be specified.
#' @param moderator optional; a string containing the name of the variable in
#' `data` for which weighting is to be done within subgroups or a one-sided
#' formula with the subgrouping variable on the right-hand side. This argument
#' is analogous to the `by` argument in `weightit()`, and in fact it is passed
#' on to `by`. Either `obj2` or `moderator` must be specified.
#' @param formula an optional formula with the covariates for which balance is
#' to be optimized. If not specified, the formula in `obj$call` will be used.
#' @param data an optional data set in the form of a data frame that contains
#' the variables in `formula` or `moderator`.
#' @param smooth `logical`; whether the smooth version of the SBPS should be
#' used. This is only compatible with `weightit` methods that return a
#' propensity score.
#' @param full.search `logical`; when `smooth = FALSE`, whether every
#' combination of subgroup and overall weights should be evaluated. If
#' `FALSE`, a stochastic search as described in DZZL will be used instead. If
#' `TRUE`, all \eqn{2^R} combinations will be checked, where \eqn{R} is the
#' number of subgroups, which can take a long time with many subgroups. If
#' unspecified, will default to `TRUE` if \eqn{R \le 8} and `FALSE` otherwise.
#'
#' @returns
#' A `weightit.sbps` object, which inherits from `weightit`. This
#' contains all the information in `obj` with the weights, propensity scores,
#' call, and possibly covariates updated from `sbps()`. In addition, the
#' `prop.subgroup` component contains the values of the coefficients \eqn{C} for the
#' subgroups (which are either 0 or 1 for the standard SBPS), and the
#' `moderator` component contains a data.frame with the moderator.
#'
#' This object has its own summary method and is compatible with \pkg{cobalt}
#' functions. The `cluster` argument should be used with \pkg{cobalt} functions
#' to accurately reflect the performance of the weights in balancing the
#' subgroups.
#'
#' @details
#' The SBPS relies on two sets of weights: one estimated in the overall
#' sample and one estimated within each subgroup. The algorithm decides whether
#' each subgroup should use the weights estimated in the overall sample or those
#' estimated in the subgroup. There are \eqn{2^R} combinations of overall and subgroup
#' weights, where \eqn{R} is the number of subgroups. The optimal combination is
#' chosen as that which minimizes a balance criterion as described in DZZL. The
#' balance criterion used here is, for binary and multi-category treatments, the
#' sum of the squared standardized mean differences within subgroups and
#' overall, which are computed using [cobalt::col_w_smd()], and for continuous
#' treatments, the sum of the squared correlations between each covariate and
#' treatment within subgroups and overall, which are computed using
#' [cobalt::col_w_corr()].
#'
#' The smooth version estimates weights that determine the relative contribution
#' of the overall and subgroup propensity scores to a weighted average
#' propensity score for each subgroup. If \eqn{P_O} are the propensity scores
#' estimated in the overall sample and \eqn{P_S} are the propensity scores estimated
#' in each subgroup, the smooth SBPS finds \eqn{R} coefficients \eqn{C} so that for each
#' subgroup, the final propensity score is \eqn{C P_S + (1 - C) P_O}, and
#' weights are computed from this propensity score. The coefficients are
#' estimated using [optim()] with `method = "L-BFGS-B"`. When \eqn{C} is estimated to
#' be 1 or 0 for each subgroup, the smooth SBPS coincides with the standard
#' SBPS.
#'
#' If `obj2` is not specified and `moderator` is, `sbps()` will attempt to refit
#' the model specified in `obj` with the `moderator` in the `by` argument. This
#' relies on the environment in which `obj` was created to be intact and can
#' take some time if `obj` was hard to fit. It's safer to estimate `obj` and
#' `obj2` (the latter simply by including the moderator in the `by` argument)
#' and supply these to `sbps()`.
#'
#' @seealso [weightit()], [summary.weightit()]
#'
#' @references
#' Dong, J., Zhang, J. L., Zeng, S., & Li, F. (2020). Subgroup balancing propensity score. *Statistical Methods in Medical Research*, 29(3), 659–676. \doi{10.1177/0962280219870836}
#'
#' @examples
#' library("cobalt")
#' data("lalonde", package = "cobalt")
#'
#' #Balancing covariates between treatment groups within races
#' (W1 <- weightit(treat ~ age + educ + married +
#' nodegree + race + re74, data = lalonde,
#' method = "glm", estimand = "ATT"))
#'
#' (W2 <- weightit(treat ~ age + educ + married +
#' nodegree + race + re74, data = lalonde,
#' method = "glm", estimand = "ATT",
#' by = "race"))
#'
#' S <- sbps(W1, W2)
#'
#' print(S)
#'
#' summary(S)
#'
#' bal.tab(S, cluster = "race")
#'
#' #Could also have run
#' #S <- sbps(W1, moderator = "race")
#'
#' S_ <- sbps(W1, W2, smooth = TRUE)
#'
#' print(S_)
#'
#' summary(S_)
#'
#' bal.tab(S_, cluster = "race")
#' @export
sbps <- function(obj, obj2 = NULL, moderator = NULL, formula = NULL,
                 data = NULL, smooth = FALSE, full.search) {
  # ---- Validate inputs -------------------------------------------------------
  if (!inherits(obj, "weightit")) {
    arg::err("{.arg obj} must be a {.cls weightit} object")
  }
  if (is_null(obj2) && is_null(moderator)) {
    arg::err("either {.arg obj2} or {.arg moderator} must be specified")
  }

  treat <- obj[["treat"]]
  treat.type <- get_treat_type(treat)
  focal <- obj[["focal"]]
  estimand <- obj[["estimand"]]

  # Data in which the moderator and the balance formula are looked up: the
  # user-supplied `data` (if any) plus the covariates stored in both fits
  data.list <- list(data, obj2[["covs"]], obj[["covs"]])
  combined.data <- do.call("data.frame", clear_null(data.list))

  processed.moderator <- .process_by(moderator, data = clear_null(combined.data),
                                     treat = obj[["treat"]], treat.name = NULL,
                                     by.arg = "moderator")
  moderator.factor <- .attr(processed.moderator, "by.factor")

  # ---- Obtain the subgroup-specific fit (`obj2`) -----------------------------
  if (is_not_null(obj2)) {
    if (!inherits(obj2, "weightit")) {
      arg::err("{.arg obj2} must be a {.cls weightit} object, ideally with a {.field by} component")
    }
    if (is_not_null(obj2[["by"]])) {
      # When only `obj2` was fit with `by`, its `by` defines the subgroups
      if (is_null(obj[["by"]])) {
        processed.moderator <- obj2[["by"]]
        moderator.factor <- .attr(processed.moderator, "by.factor")
      }
      else if (is_null(processed.moderator)) {
        arg::err("cannot figure out moderator. Please supply a value to {.arg moderator}")
      }
    }
    else if (is_null(processed.moderator)) {
      arg::err("no moderator was specified")
    }
  }
  else {
    # Refit `obj` within subgroups by re-evaluating its call with `by` set to
    # the moderator (crossed with any `by` variable `obj` already used)
    call <- obj[["call"]]
    if (is_not_null(obj[["by"]])) {
      call[["by"]] <- setNames(data.frame(factor(paste(processed.moderator[[1L]],
                                                       obj[["by"]][[1L]], sep = " | "))),
                               paste(names(processed.moderator), names(obj[["by"]]), sep = " | "))
    }
    else {
      call[["by"]] <- processed.moderator
    }
    obj2 <- eval(call, obj[["env"]])
  }

  if (smooth && (is_null(obj[["ps"]]) || is_null(obj2[["ps"]]))) {
    arg::err("smooth SBPS can only be used with methods that produce a propensity score")
  }

  # ---- Covariates entering the balance criterion -----------------------------
  if (is_null(formula)) {
    formula <- obj[["formula"]]
  }
  t.c <- terms(formula) |>
    delete.response() |>
    get_covs_and_treat_from_formula2(combined.data)
  if (is_null(t.c[["reported.covs"]])) {
    arg::err("no covariates were found")
  }
  covs <- t.c[["model.covs"]]
  s.weights <- obj[["s.weights"]]

  # Drop covariate columns that are equivalent to a dummy of the moderator;
  # such columns are constant within each subgroup
  mod.split <- cobalt::splitfactor(moderator.factor, drop.first = "if2")
  same.as.moderator <- apply(covs, 2L, function(.c) {
    any_apply(mod.split, equivalent.factors, .c)
  })
  covs <- covs[, !same.as.moderator, drop = FALSE]
  bin.vars <- is_binary_col(covs)
  s.d.denom <- get.s.d.denom.weightit(estimand = obj[["estimand"]],
                                      weights = obj[["weights"]],
                                      treat = treat)
  R <- levels(moderator.factor)

  # ---- Balance criterion (shared by the smooth and standard searches) --------
  # Returns the sum of squared absolute standardized mean differences (binary
  # and multi-category treatments) or squared treatment-covariate correlations
  # (continuous treatments), computed in the full sample plus within each
  # subgroup, for the unit-level weights `w_`.
  balance_loss <- function(w_) {
    if (treat.type == "binary") {
      F0_o <- cobalt::col_w_smd(covs, treat, w_, std = TRUE, s.d.denom = s.d.denom,
                                abs = TRUE, s.weights = s.weights, bin.vars = bin.vars)
      F0_s <- unlist(lapply(R, function(g) {
        in_g <- moderator.factor == g
        cobalt::col_w_smd(covs[in_g, , drop = FALSE], treat[in_g], w_[in_g],
                          std = TRUE, s.d.denom = s.d.denom,
                          abs = TRUE, s.weights = s.weights[in_g],
                          bin.vars = bin.vars)
      }))
    }
    else if (treat.type %in% c("multinomial", "multi-category")) {
      if (is_not_null(focal)) {
        # Compare the focal group with each other treatment group in turn
        bin.treat <- as.numeric(treat == focal)
        s.d.denom.focal <- switch(estimand, ATT = "treated", ATC = "control", "all")
        non.focal <- levels(treat)[levels(treat) != focal]
        F0_o <- unlist(lapply(non.focal, function(t) {
          cobalt::col_w_smd(covs, bin.treat, w_, std = TRUE, s.d.denom = s.d.denom.focal,
                            abs = TRUE, s.weights = s.weights, bin.vars = bin.vars,
                            subset = treat %in% c(t, focal))
        }))
        F0_s <- unlist(lapply(non.focal, function(t) {
          unlist(lapply(R, function(g) {
            in_g <- moderator.factor == g
            cobalt::col_w_smd(covs[in_g, , drop = FALSE], bin.treat[in_g], w_[in_g],
                              std = TRUE, s.d.denom = s.d.denom.focal,
                              abs = TRUE, s.weights = s.weights[in_g],
                              bin.vars = bin.vars,
                              subset = treat[in_g] %in% c(t, focal))
          }))
        }))
      }
      else {
        # No focal group: compare each weighted treatment group against the
        # full, unweighted sample (stacked as a pseudo "treated" group)
        F0_o <- unlist(lapply(levels(treat), function(t) {
          in_t <- treat == t
          covs_i <- rbind(covs, covs[in_t, , drop = FALSE])
          treat_i <- c(rep.int(1, nrow(covs)), rep.int(0, sum(in_t)))
          w_i <- c(rep.int(1, nrow(covs)), w_[in_t])
          s.weights_i <- if (is_not_null(s.weights)) c(s.weights, s.weights[in_t])
          cobalt::col_w_smd(covs_i, treat_i, w_i, std = TRUE, s.d.denom = "treated",
                            abs = TRUE, s.weights = s.weights_i, bin.vars = bin.vars)
        }))
        F0_s <- unlist(lapply(levels(treat), function(t) {
          in_t <- treat == t
          covs_i <- rbind(covs, covs[in_t, , drop = FALSE])
          treat_i <- c(rep.int(1, nrow(covs)), rep.int(0, sum(in_t)))
          w_i <- c(rep.int(1, nrow(covs)), w_[in_t])
          moderator.factor_i <- c(moderator.factor, moderator.factor[in_t])
          s.weights_i <- if (is_not_null(s.weights)) c(s.weights, s.weights[in_t])
          unlist(lapply(R, function(g) {
            in_g <- moderator.factor_i == g
            cobalt::col_w_smd(covs_i[in_g, , drop = FALSE], treat_i[in_g], w_i[in_g],
                              std = TRUE, s.d.denom = "treated",
                              abs = TRUE, s.weights = s.weights_i[in_g],
                              bin.vars = bin.vars)
          }))
        }))
      }
    }
    else if (treat.type == "continuous") {
      F0_o <- cobalt::col_w_corr(covs, treat, w_, abs = TRUE, s.weights = s.weights, bin.vars = bin.vars)
      F0_s <- unlist(lapply(R, function(g) {
        in_g <- moderator.factor == g
        cobalt::col_w_corr(covs[in_g, , drop = FALSE], treat[in_g], w_[in_g],
                           abs = TRUE, s.weights = s.weights[in_g],
                           bin.vars = bin.vars)
      }))
    }
    sum(F0_o^2) + sum(F0_s^2)
  }

  if (smooth) {
    # ---- Smooth SBPS: optimize a per-subgroup mixing coefficient in [0, 1] ---
    if (!missing(full.search)) {
      arg::wrn("{.arg full.search} is ignored when {.code smooth = TRUE}")
    }
    ps_o <- obj[["ps"]]
    ps_s <- obj2[["ps"]]

    # Unit-level propensity score (1 - C) * P_O + C * P_S; `coefs` holds one
    # coefficient per subgroup and is indexed by the integer codes of
    # `moderator.factor`
    smooth_ps <- function(coefs) {
      ind.coefs <- coefs[moderator.factor]
      (1 - ind.coefs) * ps_o + ind.coefs * ps_s
    }

    opt.out <- optim(rep_with(.5, R),
                     fn = function(coefs) {
                       balance_loss(get_w_from_ps(smooth_ps(coefs), treat, estimand))
                     },
                     lower = 0, upper = 1,
                     method = "L-BFGS-B")
    s_min <- setNames(opt.out$par, R) #coef is proportion subgroup vs. overall
    ps <- smooth_ps(s_min)
    weights <- get_w_from_ps(ps, treat, estimand)
  }
  else {
    # ---- Standard SBPS: choose overall (0) or subgroup (1) per subgroup ------
    w_o <- obj[["weights"]]
    w_s <- obj2[["weights"]]
    if (missing(full.search)) {
      full.search <- (length(R) <= 8)
    }
    else {
      arg::arg_flag(full.search)
    }

    # For a 0/1 assignment `s` named by subgroup, take values from `x_o` in
    # subgroups with s = 0 and from `x_s` in subgroups with s = 1
    select_by_group <- function(s, x_o, x_s) {
      out_ <- numeric(length(moderator.factor))
      for (g in R) {
        in_g <- moderator.factor == g
        if (s[g] == 0) out_[in_g] <- x_o[in_g]
        else if (s[g] == 1) out_[in_g] <- x_s[in_g]
      }
      out_
    }
    loss_of <- function(s) balance_loss(select_by_group(s, w_o, w_s))

    if (full.search) {
      # Exhaustively evaluate all 2^R assignments
      S <- as.matrix(setNames(do.call("expand.grid", replicate(length(R), 0:1, simplify = FALSE)),
                              R))
      F_min <- Inf
      for (i in seq_row(S)) {
        s_try <- S[i, ]
        F_try <- loss_of(s_try)
        if (F_try < F_min) {
          F_min <- F_try
          s_min <- s_try
        }
      }
    }
    else {
      # Stochastic search described by DZZL: random starting assignments, each
      # refined by single-subgroup flips in a random order until no flip lowers
      # the loss. Runs at least `L1` starts and stops once `L2` consecutive
      # starts have failed to improve on the best assignment found.
      s_min <- setNames(rep_with(0, R), R)
      F_min <- loss_of(s_min)
      L1 <- 25L
      L2 <- 10L
      k <- 0L
      iters_since_change <- 0L
      while (k < L1 || iters_since_change < L2) {
        s_try <- setNames(sample(0:1, length(R), replace = TRUE), R)
        F_try <- loss_of(s_try)
        Ar <- sample(R)
        #Optimize s_try for given Ar
        repeat {
          s_try_prev <- s_try
          for (i in Ar) {
            s_alt <- s_try
            s_alt[i] <- if (s_try[i] == 0) 1 else 0
            F_alt <- loss_of(s_alt)
            if (F_alt < F_try) {
              s_try <- s_alt
              F_try <- F_alt
            }
          }
          if (identical(s_try_prev, s_try)) {
            break
          }
        }
        if (F_try < F_min) {
          F_min <- F_try
          s_min <- s_try
          iters_since_change <- 0L
        }
        else {
          iters_since_change <- iters_since_change + 1L
        }
        k <- k + 1L
      }
    }
    weights <- select_by_group(s_min, w_o, w_s)
    ps <- {
      if (is_null(obj[["ps"]]) || is_null(obj2[["ps"]])) NULL
      else select_by_group(s_min, obj[["ps"]], obj2[["ps"]])
    }
  }

  # ---- Assemble output: `obj` with weights/ps/covs/call replaced -------------
  out <- obj
  out[["covs"]] <- t.c[["simple.covs"]]
  out[["weights"]] <- weights
  out[["ps"]] <- ps
  out[["moderator"]] <- processed.moderator
  out[["prop.subgroup"]] <- s_min
  out[["call"]] <- match.call()
  attr(out, "Mparts") <- NULL
  attr(out, "Mparts.list") <- NULL
  out <- clear_null(out)
  class(out) <- c("weightit.sbps", "weightit")
  out
}
#' @exportS3Method summary weightit.sbps
summary.weightit.sbps <- function(object, top = 5L, ignore.s.weights = FALSE, weight.range = TRUE, ...) {
  # Builds one `summary.weightit()` result per moderator subgroup, plus a
  # 2-row matrix ("Overall", "Subgroup") of each subgroup's mixing proportions
  # stored in the "prop.subgroup" attribute.
  arg::arg_count(top)
  arg::arg_flag(ignore.s.weights)
  arg::arg_flag(weight.range)

  # Unit sampling weights when they are ignored or were never supplied
  no.sw <- ignore.s.weights || is_null(object$s.weights)
  sw <- if (no.sw) rep.int(1, nobs(object)) else object$s.weights

  mod_factor <- .attr(object$moderator, "by.factor")
  groups <- levels(mod_factor)
  out.list <- make_list(groups)

  for (grp in groups) {
    rows <- which(mod_factor == grp)
    sub.fit <- as.weightit(object$weights[rows],
                           treat = object$treat[rows],
                           s.weights = sw[rows])
    out.list[[grp]] <- summary.weightit(sub.fit, top = top,
                                        ignore.s.weights = ignore.s.weights,
                                        weight.range = weight.range, ...)
  }

  p <- object$prop.subgroup
  attr(out.list, "prop.subgroup") <- matrix(c(1 - p, p),
                                            nrow = 2L, byrow = TRUE,
                                            dimnames = list(c("Overall", "Subgroup"),
                                                            names(p)))
  class(out.list) <- "summary.weightit.sbps"
  out.list
}
#' @exportS3Method print summary.weightit.sbps
print.summary.weightit.sbps <- function(x, ...) {
  # Prints the overall-vs-subgroup proportion table followed by the weight
  # summary of each subgroup; returns `x` invisibly.
  cli::cat_line(space(18L), .ul("Summary of weights"), "\n")

  cli::cat_line("- ", .it("Overall vs. subgroup proportion contribution"), ":\n")
  prop_tab <- round_df_char(.attr(x, "prop.subgroup"), 2L, pad = " ")
  print.data.frame(prop_tab)

  group_names <- names(x)
  for (i in seq_along(x)) {
    rule <- .st(space(19L))
    cat("\n")
    cli::cat_line(rule, .it(sprintf(" Subgroup: %s ", group_names[i])), rule)
    cat("\n")
    .print_summary_weightit_internal(x[[i]], ...)
  }

  invisible(x)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.