R/get.w.R

Defines functions get.w.sbwcau get.w.wimids get.w.mimids get.w.designmatch get.w.weightit get.w.cem.match get.w.optmatch get.w.ebalance get.w.CBMSM get.w.CBPS get.w.Match get.w.iptw get.w.ps.cont get.w.mnps get.w.ps get.w.matchit get.w

Documented in get.w get.w.CBMSM get.w.CBPS get.w.cem.match get.w.designmatch get.w.ebalance get.w.iptw get.w.Match get.w.matchit get.w.mimids get.w.mnps get.w.optmatch get.w.ps get.w.ps.cont get.w.sbwcau get.w.weightit get.w.wimids

#' @title Extract Weights from Preprocessing Objects
#' 
#' @description Extracts weights from the outputs of preprocessing functions.
#' 
#' @param x output from the corresponding preprocessing packages.
#' @param stop.method the name of the stop method used in the original call to `ps()` or `mnps()` in \pkg{twang}, e.g., `"es.mean"`. If empty, will return weights from all available stop methods in a data frame. Abbreviations allowed.
#' @param estimand if weights are computed using the propensity score (i.e., for the `ps` and `CBPS` methods), which estimand to use to compute the weights. If `"ATE"`, weights will be computed as `1/ps` for the treated group and `1/(1-ps)` for the control group. If `"ATT"`, weights will be computed as `1` for the treated group and `ps/(1-ps)` for the control group. If not specified, `get.w()` will try to figure out which estimand is desired based on the object.
#' 
#' If weights are computed using subclasses/matching strata (i.e., for the `cem` and `designmatch` methods), which estimand to use to compute the weights. First, a subclass propensity score is computed as the proportion of treated units in each subclass, and then one of the formulas above will be used based on the estimand requested. If not specified, `"ATT"` is assumed.
#' @param treat a vector of treatment status for each unit. This is required for methods that include `treat` as an argument. The treatment variable that was used in the original preprocessing function call should be used.
#' @param s.weights whether the sampling weights included in the original call to the fitting function should be included in the weights. If `TRUE`, the returned weights will be the product of the balancing weights estimated by the fitting function and the sampling weights. If `FALSE`, only the balancing weights will be returned.
#' @param ... arguments passed to other methods.
#' 
#' @returns A vector or data frame of weights for each unit. These may be matching weights or balancing weights.
#' 
#' @details The output of `get.w()` can be used in calls to the formula and data frame methods of [bal.tab()] (see example below). In this way, the output of multiple preprocessing packages can be viewed simultaneously and compared. The weights can also be used in `weights` statements in regression methods to compute weighted effects.
#' 
#' \pkg{twang} has a function called `get.weights()` that performs the same function on `ps` objects but offers slightly finer control. Note that the weights generated by `get.w()` for `ps` objects do not include sampling weights by default.
#' 
#' When sampling weights are used with `CBPS()` in \pkg{CBPS}, the returned weights will already have the sampling weights incorporated. To retrieve the balancing weights on their own, divide the returned weights by the original sampling weights. For other packages, the balancing weights are returned separately unless `s.weights = TRUE`, which means they must be multiplied by the sampling weights for effect estimation.
#' 
#' When `Match()` in \pkg{Matching} is used with `CommonSupport = TRUE`, the returned weights will be incorrect. This option is not recommended by the package authors.
#' 
#' @examplesIf all(sapply(c("WeightIt", "MatchIt"), requireNamespace, quietly = TRUE))
#' data("lalonde", package = "cobalt")
#' 
#' m.out <- MatchIt::matchit(treat ~ age + educ + race,
#'                           data = lalonde,
#'                           estimand = "ATT") 
#' 
#' w.out <- WeightIt::weightit(treat ~ age + educ + race,
#'                             data = lalonde,
#'                             estimand = "ATT")
#' 
#' bal.tab(treat ~ age + educ + race, data = lalonde,
#'         weights = data.frame(matched = get.w(m.out),
#'                              weighted = get.w(w.out)),
#'         method = c("matching", "weighting"), 
#'         estimand = "ATT")

#' @rdname get.w
#' @export 
get.w <- function(x, ...) {
    # Generic entry point. Objects already tagged with the
    # "cobalt.processed.obj" class are dispatched to their S3 method directly;
    # anything else is first passed through process_obj() and the generic is
    # called again on the result.
    if (inherits(x, "cobalt.processed.obj")) {
        UseMethod("get.w")
    }
    else {
        get.w(process_obj(x), ...)
    }
}

#' @rdname get.w
#' @exportS3Method get.w matchit
get.w.matchit <- function(x, ...) {
    # The weights are stored directly on the object; return them as they are.
    matching_weights <- x$weights
    matching_weights
}

#' @rdname get.w
#' @exportS3Method get.w ps
get.w.ps <- function(x, stop.method = NULL, estimand, s.weights = FALSE, ...) {
    # Computes weights from the propensity scores stored in `x$ps`.
    #
    # x:           object with `ps`, `w`, `treat`, and `sampw` components.
    # stop.method: character prefixes (matched case-insensitively) or numeric
    #              positions selecting among `names(x$w)`; NULL selects all.
    # estimand:    "ATT", "ATE", or "ATC" (any case), either a single value or
    #              one per selected set of weights. When missing, it is read
    #              from the last three characters of each selected name.
    # s.weights:   if TRUE, the weights are multiplied by `x$sampw`.
    #
    # Returns the single column of weights when one set is selected, otherwise
    # a data frame with one column per selected set.
    
    estimand <- if (missing(estimand)) NULL else tolower(estimand)
    
    available <- names(x$w)
    
    if (is_not_null(stop.method)) {
        if (is.character(stop.method)) {
            rule1 <- available[vapply(tolower(available),
                                      function(nm) any(startsWith(nm, tolower(stop.method))),
                                      logical(1L))]
            if (is_null(rule1)) {
                .wrn(sprintf("`stop.method` should be %s.\nUsing all available stop methods instead",
                             word_list(available, and.or = "or", quotes = 2)))
                rule1 <- available
            }
        }
        else if (is.numeric(stop.method) && any(stop.method %in% seq_along(available))) {
            valid <- stop.method %in% seq_along(available)
            if (!all(valid)) {
                .wrn(sprintf("there are %s stop methods available, but you requested %s", 
                             length(available),
                             word_list(stop.method[!valid], and.or = "and")))
            }
            # Index by the requested positions themselves; indexing by the
            # logical `valid` would recycle it over `available` and select the
            # wrong entries (e.g., everything when a single position is given).
            rule1 <- available[stop.method[valid]]
        }
        else {
            .wrn(sprintf("`stop.method` should be %s.\nUsing all available stop methods instead",
                         word_list(available, and.or = "or", quotes = 2)))
            rule1 <- available
        }
    }
    else {
        rule1 <- available
    }
    
    s <- available[match(tolower(rule1), tolower(available))]
    # Names are parsed as "<criterion>" + one separator character + a
    # three-letter estimand.
    criterion <- substr(tolower(s), 1, nchar(s)-4)
    allowable.estimands <- c("ATT", "ATE", "ATC")
    
    if (is_null(estimand)) estimand <- setNames(substr(toupper(s), nchar(s)-2, nchar(s)), s)
    else if (!all(toupper(estimand) %in% allowable.estimands)) {
        .err(sprintf("`estimand` must be %s", word_list(allowable.estimands, "or", quotes = 1)))
    }
    else {
        if (length(estimand) == 1) estimand <- setNames(toupper(rep(estimand, length(s))), s)
        else if (length(estimand) >= length(s)) estimand <- setNames(toupper(estimand[seq_along(s)]), s)
        else .err("`estimand` must be the same length as the number of sets of weights requested")
    }
    
    w <- setNames(as.data.frame(matrix(1, nrow = nrow(x$ps), ncol = length(s))), s)
    for (p in s) {
        ps_p <- x$ps[,p]
        if (estimand[p] == "ATT") w[[p]] <- x$treat + (1-x$treat)*ps_p/(1-ps_p)
        else if (estimand[p] == "ATE") w[[p]] <- x$treat/ps_p + (1-x$treat)/(1-ps_p)
        # ATC: control units get weight 1 and treated units get the odds of
        # being a control, (1 - ps)/ps -- the same formula get.w.CBPS() uses.
        else if (estimand[p] == "ATC") w[[p]] <- (1-x$treat) + x$treat*(1-ps_p)/ps_p
        # Unrecognized estimand suffix: fall back to the weights stored in the object
        else w[[p]] <- x$w[,p]
        if (s.weights) w[[p]] <- w[[p]] * x$sampw
    }
    
    # Label a column with the bare criterion when the estimand used agrees with
    # the one encoded in the original name; otherwise append the estimand.
    names(w) <- ifelse(toupper(substr(s, nchar(s)-2, nchar(s))) == estimand,
                       criterion,
                       sprintf("%s (%s)", criterion, estimand))
    if (ncol(w) == 1) w <- w[[1]]
    
    w
}

#' @rdname get.w
#' @exportS3Method get.w mnps
get.w.mnps <- function(x, stop.method = NULL, s.weights = FALSE, ...) {
    # Assembles weights for an `mnps` fit from its per-level fits in `x$psList`.
    #
    # x:           object with `psList`, `stopMethods`, `estimand`, `treatVar`,
    #              `treatLev`, `levExceptTreatATT`, and `sampw` components.
    # stop.method: character prefixes (matched case-insensitively) or numeric
    #              positions selecting among `x$stopMethods`; NULL selects all.
    # s.weights:   if TRUE, the weights are multiplied by `x$sampw`.
    #
    # Returns the single column of weights when one stop method is selected,
    # otherwise a data frame with one column per selected stop method.
    
    available <- x$stopMethods
    
    if (is_not_null(stop.method)) {
        if (is.character(stop.method)) {
            # Match against `x$stopMethods`, the same set every other branch
            # and the lookup below use.
            rule1 <- available[vapply(tolower(available),
                                      function(m) any(startsWith(m, tolower(stop.method))),
                                      logical(1L))]
            if (is_null(rule1)) {
                .wrn(sprintf("`stop.method` should be %s.\nUsing all available stop methods instead",
                             word_list(available, and.or = "or", quotes = 2)))
                rule1 <- available
            }
        }
        else if (is.numeric(stop.method) && any(stop.method %in% seq_along(available))) {
            valid <- stop.method %in% seq_along(available)
            if (!all(valid)) {
                .wrn(sprintf("there are %s stop methods available, but you requested %s",
                             length(available), 
                             word_list(stop.method[!valid], and.or = "and")))
            }
            # Index by the requested positions, not by the logical `valid`
            rule1 <- available[stop.method[valid]]
        }
        else {
            .wrn(sprintf("`stop.method` should be %s.\nUsing all available stop methods instead",
                         word_list(available, and.or = "or", quotes = 2)))
            rule1 <- available
        }
    }
    else {
        rule1 <- available
    }
    
    criterion <- available[match(tolower(rule1), tolower(available))]
    estimand <- x$estimand
    s <- paste.(criterion, estimand)
    
    w <- setNames(as.data.frame(matrix(1, nrow = length(x$treatVar), ncol = length(s))),
                  criterion)
    
    # Weights from one per-level fit, restricted to the units whose `treat`
    # equals `arm`. get.w.ps() returns a bare vector when the fit has a single
    # set of weights and a data frame otherwise, so handle both shapes rather
    # than keying on how many criteria were requested.
    level_weights <- function(fit, arm) {
        fit_w <- get.w.ps(fit)
        keep <- fit$treat == arm
        if (is.data.frame(fit_w)) fit_w[keep, criterion]
        else fit_w[keep]
    }
    
    if (estimand == "ATT") {
        for (i in x$levExceptTreatATT) {
            w[x$treatVar == i, criterion] <- level_weights(x$psList[[i]], FALSE)
        }
    }
    else if (estimand == "ATE") {
        for (i in x$treatLev) {
            w[x$treatVar == i, criterion] <- level_weights(x$psList[[i]], TRUE)
        }
    }
    
    if (s.weights) {
        w <- w * x$sampw
    }
    
    names(w) <- ifelse(toupper(substr(s, nchar(s)-2, nchar(s))) == estimand, criterion, paste0(criterion, " (", estimand, ")"))
    
    if (ncol(w) == 1) w <- w[[1]]
    
    w
}

#' @rdname get.w
#' @exportS3Method get.w ps.cont
get.w.ps.cont <- function(x, s.weights = FALSE, ...) {
    # Returns `x$w`, multiplied by `x$sampw` only when `s.weights` is exactly TRUE.
    if (!isTRUE(s.weights)) {
        return(x$w)
    }
    
    x$w * x$sampw
}

#' @rdname get.w
#' @exportS3Method get.w iptw
get.w.iptw <- function(x, stop.method = NULL, s.weights = FALSE, ...) {
    # Builds weights for an `iptw` fit as the product, across the fits in
    # `x$psList`, of the weights get.w.ps() returns for each stop method.
    #
    # x:           object with a `psList` component; the first fit supplies the
    #              available stop method names (column names of its `ps`) and
    #              the sampling weights (`sampw`).
    # stop.method: character prefixes (matched case-insensitively) or numeric
    #              positions selecting among the available names; NULL selects
    #              all of them.
    # s.weights:   if TRUE, the weights are multiplied by the first fit's `sampw`.
    #
    # Returns a data frame with one column per selected stop method.
    
    available <- names(x$psList[[1]]$ps)
    
    if (is_not_null(stop.method)) {
        if (is.character(stop.method)) {
            rule1 <- available[vapply(tolower(available),
                                      function(nm) any(startsWith(nm, tolower(stop.method))),
                                      logical(1L))]
            if (is_null(rule1)) {
                .wrn(sprintf("`stop.method` should be %s.\nUsing all available stop methods instead.",
                             word_list(available, and.or = "or", quotes = 2)))
                rule1 <- available
            }
        }
        else if (is.numeric(stop.method) && any(stop.method %in% seq_along(available))) {
            valid <- stop.method %in% seq_along(available)
            if (!all(valid)) {
                # The format arguments belong inside sprintf(); passing them to
                # .wrn() instead makes sprintf() fail with "too few arguments".
                .wrn(sprintf("there are %s stop methods available, but you requested %s",
                             length(available), 
                             word_list(stop.method[!valid], and.or = "and")))
            }
            # Index by the requested positions, not by the logical `valid`
            rule1 <- available[stop.method[valid]]
        }
        else {
            .wrn(sprintf("`stop.method` should be %s.\nUsing all available stop methods instead",
                         word_list(available, and.or = "or", quotes = 2)))
            rule1 <- available
        }
    }
    else {
        rule1 <- available
    }
    
    w <- setNames(as.data.frame(matrix(NA, nrow = nrow(x$psList[[1]]$ps),
                                       ncol = length(rule1))),
                  rule1)
    for (i in rule1) {
        # Multiply the weights from every fit in `psList` for this stop method
        w[i] <- Reduce("*", lapply(x$psList, get.w.ps, stop.method = i))
    }
    
    if (s.weights) {
        w <- w * x$psList[[1]]$sampw
    }
    
    w
}

#' @rdname get.w
#' @exportS3Method get.w Match
get.w.Match <- function(x, ...) {
    # For each of the `x$orig.nobs` original units, add up the weights of every
    # matched row in which the unit appears, either in `x$index.treated` or in
    # `x$index.control`. Units that appear in no row get 0.
    # Disabled alternative kept for reference:
    # x$weights <- x$weights / ave(x$weights, x$index.treated, FUN = sum)
    unit_total <- function(unit) {
        appears <- x$index.treated == unit | x$index.control == unit
        sum(x$weights[appears])
    }
    
    vapply(seq_len(x$orig.nobs), unit_total, numeric(1L))
}

#' @rdname get.w
#' @exportS3Method get.w CBPS
get.w.CBPS <- function(x, estimand, ...) {
    # Returns the weights stored in the object unless `use.weights = FALSE` is
    # supplied through `...` for a fit that is not "CBPSContinuous", not
    # "npCBPS", and whose `y` is not a factor; in that case weights are
    # recomputed from the fitted values for the requested estimand.
    dots <- list(...)
    use.weights <- if_null_then(dots$use.weights, TRUE)
    
    estimand <- if (missing(estimand)) NULL else tolower(estimand)
    
    # Continuous, multi-category, or npCBPS fits: only the stored weights apply
    if (inherits(x, c("CBPSContinuous", "npCBPS")) || is.factor(x$y)) {
        return(x$weights)
    }
    
    if (use.weights) {
        return(x$weights)
    }
    
    prop_score <- x$fitted.values
    treated <- x$y
    
    if (is_null(estimand)) {
        # With no estimand given, constant stored weights among units with
        # y == 1 are taken to indicate ATT; anything else is treated as ATE.
        estimand <- if (all_the_same(x$weights[treated == 1])) "att" else "ate"
    }
    
    estimand <- match_arg(tolower(estimand), c("att", "atc", "ate"))
    
    if (estimand == "att") {
        treated + (1-treated)*prop_score/(1-prop_score)
    }
    else if (estimand == "atc") {
        treated*(1-prop_score)/prop_score + (1-treated)
    }
    else {
        treated/prop_score + (1-treated)/(1-prop_score)
    }
}

#' @rdname get.w
#' @exportS3Method get.w CBMSM
get.w.CBMSM <- function(x, ...) {
    # Pick the entries of `x$weights` at the sorted distinct values of `x$id`.
    distinct_ids <- unique(x$id)
    x$weights[sort(distinct_ids)]
}

#' @rdname get.w
#' @exportS3Method get.w ebalance
get.w.ebalance <- function(x, treat, ...) {
    # Expands the weights in `x$w`, which cover only the control units, to a
    # full-length vector: control units receive the weights from `x$w` in
    # order, and every other unit receives a weight of 1.
    #
    # x:     object whose `w` component holds one weight per control unit.
    # treat: treatment vector (required); processed with process_treat()
    #        unless it already has class "processed.treat".
    #
    # Errors when the number of control units in `treat` differs from the
    # number of weights in `x$w`.
    .chk_not_missing(treat, "`treat`")
    
    if (!inherits(treat, "processed.treat")) treat <- process_treat(treat)
    
    is_control <- treat == treat_vals(treat)["Control"]
    n_control <- sum(is_control)
    
    # The check fires on any mismatch, so report both counts instead of
    # claiming a particular direction.
    if (length(x$w) != n_control) {
        .err(sprintf("the number of control units in `treat` (%s) must equal the number of weights in the `ebalance` object (%s)",
                     n_control, length(x$w)))
    }
    
    weights <- rep(1, length(treat))
    weights[is_control] <- x$w
    
    weights
}

#' @rdname get.w
#' @exportS3Method get.w optmatch
get.w.optmatch <- function(x, estimand, ...) {
    # Converts the matching strata in `x` to weights via strata2weights(),
    # using the object's "contrast.group" attribute (coerced to numeric) as
    # the treatment. The estimand defaults to "ATT" when not supplied.
    if (missing(estimand) || is_null(estimand)) {
        estimand <- "ATT"
    }
    
    strata2weights(x,
                   treat = as.numeric(attr(x, "contrast.group")),
                   estimand = estimand)
}

#' @rdname get.w
#' @exportS3Method get.w cem.match
get.w.cem.match <- function(x, estimand, ...) {
    # Returns weights for a `cem.match` object or a `cem.match.list`.
    # By default the stored `w` component is returned. When
    # `use.match.strata = TRUE` is passed through `...`, weights are instead
    # computed from the `mstrata` and `groups` components via strata2weights().
    # For a `cem.match.list`, results from its `cem.match` elements are
    # concatenated into one unnamed vector.
    dots <- list(...)
    if (missing(estimand) || is_null(estimand)) estimand <- "ATT"
    
    is_match_list <- inherits(x, "cem.match.list")
    
    if (isTRUE(dots[["use.match.strata"]])) {
        weights_from_strata <- function(cm) {
            strata2weights(cm[["mstrata"]], treat = cm[["groups"]], estimand = estimand)
        }
        
        if (!is_match_list) {
            return(weights_from_strata(x))
        }
        
        # Only the elements that are themselves `cem.match` objects contribute
        matches <- x[vapply(x, is_, logical(1L), "cem.match")]
        return(unlist(lapply(matches, weights_from_strata), use.names = FALSE))
    }
    
    if (is_match_list) {
        matches <- x[vapply(x, is_, logical(1L), "cem.match")]
        return(unlist(grab(matches, "w"), use.names = FALSE))
    }
    
    x[["w"]]
}

#' @rdname get.w
#' @exportS3Method get.w weightit
get.w.weightit <- function(x, s.weights = FALSE, ...) {
    # Returns `x$weights`, multiplied by `x$s.weights` only when the
    # `s.weights` argument is exactly TRUE.
    if (!isTRUE(s.weights)) {
        return(x$weights)
    }
    
    x$weights * x$s.weights
}

#' @rdname get.w
#' @exportS3Method get.w designmatch
get.w.designmatch <- function(x, treat, estimand, ...) {
    # Turns the matched groups in a `designmatch` result into weights.
    # `treat` is required; `estimand` defaults to "ATT". The object must have
    # exactly one `group_id` entry per id in `t_id` and `c_id` combined.
    .chk_not_missing(treat, "`treat`")
    if (missing(estimand) || is_null(estimand)) estimand <- "ATT"
    
    matched_ids <- c(x[["t_id"]], x[["c_id"]])
    if (length(x[["group_id"]]) != length(matched_ids)) {
        .err("`designmatch` objects without 1:1 matching cannot be used")
    }
    
    # Left-join the group labels onto every unit position in `treat`; units
    # that were not matched end up with a missing group.
    every_unit <- data.frame(id = seq_along(treat))
    matched_units <- data.frame(id = matched_ids,
                                group = factor(x[["group_id"]]))
    q <- merge(every_unit, matched_units, all.x = TRUE, by = "id")
    q <- q[order(q$id), , drop = FALSE]
    
    strata2weights(q$group, treat, estimand)
}

#' @rdname get.w
#' @exportS3Method get.w mimids
get.w.mimids <- function(x, ...) {
    # Concatenates the weights of every fit in `x$models` (extracted with
    # get.w.matchit()) and replaces missing weights with 0.
    fits <- x[["models"]]
    
    # Objects lacking any of the "object", "models", or "approach" components
    # are treated as the old format, in which the first entry of `models` is
    # skipped.
    new_format <- all(c("object", "models", "approach") %in% names(x))
    if (!new_format) {
        fits <- fits[-1]
    }
    
    weights <- unlist(lapply(fits, get.w.matchit))
    weights[is.na(weights)] <- 0
    
    weights
}

#' @rdname get.w
#' @exportS3Method get.w wimids
get.w.wimids <- function(x, ...) {
    # Concatenates the weights of every fit in `x$models` (extracted with
    # get.w.weightit()) and replaces missing weights with 0.
    fits <- x[["models"]]
    
    # Objects lacking any of the "object", "models", or "approach" components
    # are treated as the old format, in which the first entry of `models` is
    # skipped.
    new_format <- all(c("object", "models", "approach") %in% names(x))
    if (!new_format) {
        fits <- fits[-1]
    }
    
    weights <- unlist(lapply(fits, get.w.weightit))
    weights[is.na(weights)] <- 0
    
    weights
}

#' @rdname get.w
#' @exportS3Method get.w sbwcau
get.w.sbwcau <- function(x, ...) {
    # The weights are taken from the last column of `x$dat_weights`.
    dat <- x[["dat_weights"]]
    dat[[ncol(dat)]]
}

Try the cobalt package in your browser

Any scripts or data that you put into this service are public.

cobalt documentation built on Nov. 21, 2023, 1:06 a.m.