#' Generate Balancing Weights with Minimal Input Processing
#'
#' @description `weightit.fit()` dispatches one of the weight estimation methods
#' determined by `method`. It is an internal function called by [weightit()] and
#' should probably not be used except in special cases. Unlike `weightit()`,
#' `weightit.fit()` does not accept a formula and data frame interface and
#' instead requires the covariates and treatment to be supplied as a numeric
#' matrix and atomic vector, respectively. In this way, `weightit.fit()` is to
#' `weightit()` what [lm.fit()] is to [lm()] - a thinner, slightly faster
#' interface that performs minimal argument checking.
#'
#' @inheritParams weightit
#' @param covs a numeric matrix of covariates.
#' @param treat a vector of treatment statuses.
#' @param method a string containing the name of the method that will be used to
#' estimate weights. See [weightit()] for allowable options. The default is
#' `"glm"` for propensity score weighting using a generalized linear model to
#' estimate the propensity score.
#' @param s.weights a numeric vector of sampling weights. See the individual
#' pages for each method for information on whether sampling weights can be
#' supplied.
#' @param by.factor a factor variable for which weighting is to be done within
#' levels. Corresponds to the `by` argument in [weightit()].
#' @param stabilize `logical`; whether or not to stabilize the weights. For the
#' methods that involve estimating propensity scores, this involves
#' multiplying each unit's weight by the proportion of units in their
#' treatment group. Default is `FALSE`. Note this differs from its use with
#' [weightit()].
#' @param focal when `estimand` is set to `"ATT"` or `"ATC"`, which group to
#' consider the "treated" or "control" group. This group will not be weighted,
#' and the other groups will be weighted to resemble the focal group. If
#' specified, `estimand` will automatically be set to `"ATT"` (with a warning
#' if `estimand` is not `"ATT"` or `"ATC"`). See section *`estimand` and
#' `focal`* in Details at [weightit()].
#' @param ps a vector of propensity scores. If specified, `method` will be
#' ignored and set to `"glm"`.
#' @param moments,int,subclass arguments to customize the weight estimation. See
#' [weightit()] for details.
#' @param missing `character`; how missing data should be handled. The options
#' depend on the `method` used. If `NULL`, `covs` will be checked for `NA`
#' values, and if present, `missing` will be set to `"ind"`. If `""`, `covs`
#' will not be checked for `NA` values; this can be faster when it is known
#' there are none.
#' @param ... other arguments for functions called by `weightit.fit()` that
#' control aspects of fitting that are not covered by the above arguments.
#'
#' @returns A `weightit.fit` object with the following elements:
#' \item{weights}{The estimated weights, one for each unit.} \item{treat}{The
#' values of the treatment variable.} \item{estimand}{The estimand requested.}
#' \item{method}{The weight estimation method specified.}
#' \item{ps}{The estimated or provided propensity scores. Estimated propensity scores are
#' returned for binary treatments and only when `method` is `"glm"`, `"gbm"`, `"cbps"`, `"ipt"`, `"super"`, or `"bart"`. The propensity score corresponds to the predicted probability of being treated; see section *`estimand` and `focal`* in Details at [weightit()] for how the treated group is determined.}
#' \item{s.weights}{The provided sampling weights.} \item{focal}{The focal
#' treatment level if the ATT or ATC was requested.} \item{fit.obj}{When
#' `include.obj = TRUE`, the fit object.} \item{info}{Additional information
#' about the fitting. See the individual methods pages for what is included.}
#'
#' The `weightit.fit` object does not have specialized `print()`, `summary()`,
#' or `plot()` methods. It is simply a list containing the above components. Use
#' [as.weightit()] to convert it to a `weightit` object, which does have these
#' methods. See Examples.
#'
#' @details `weightit.fit()` is called by [weightit()] after the arguments to
#' `weightit()` have been checked and processed. `weightit.fit()` dispatches the
#' function used to actually estimate the weights, passing on the supplied
#' arguments directly. `weightit.fit()` is not meant to be used by anyone other
#' than experienced users who have a specific use case in mind. The returned
#' object contains limited information about the supplied arguments or details
#' of the estimation method; all that is processed by `weightit()`.
#'
#' Less argument checking or processing occurs in `weightit.fit()` than does in
#' `weightit()`, which means supplying incorrect arguments can result in errors,
#' crashes, and invalid weights, and error and warning messages may not be
#' helpful in diagnosing the problem. `weightit.fit()` does check to make sure
#' weights were actually estimated, though.
#'
#' `weightit.fit()` may be most useful in speeding up simulation simulation
#' studies that use `weightit()` because the covariates can be supplied as a
#' numeric matrix, which is often how they are generated in simulations, without
#' having to go through the potentially slow process of extracting the
#' covariates and treatment from a formula and data frame. If the user is
#' certain the arguments are valid (e.g., by ensuring the estimated weights are
#' consistent with those estimated from `weightit()` with the same arguments),
#' less time needs to be spent on processing the arguments. Also, the returned
#' object is much smaller than a `weightit` object because the covariates are
#' not returned alongside the weights.
#'
#' @seealso [weightit()], which you should use for estimating weights unless you
#' know better.
#'
#' [as.weightit()] for converting a `weightit.fit` object to a `weightit`
#' object.
#'
#' @examples
#'
#' library("cobalt")
#' data("lalonde", package = "cobalt")
#'
#' # Balancing covariates between treatment groups (binary)
#' covs <- lalonde[c("age", "educ", "race", "married",
#' "nodegree", "re74", "re75")]
#' ## Create covs matrix, splitting any factors using
#' ## cobalt::splitfactor()
#' covs_mat <- as.matrix(splitfactor(covs))
#'
#' WF1 <- weightit.fit(covs_mat, treat = lalonde$treat,
#' method = "glm", estimand = "ATT")
#' str(WF1)
#'
#' # Converting to a weightit object for use with
#' # summary() and bal.tab()
#' W1 <- as.weightit(WF1, covs = covs)
#' W1
#' summary(W1)
#' bal.tab(W1)
#'
#' @export
weightit.fit <- function(covs, treat, method = "glm", s.weights = NULL, by.factor = NULL,
estimand = "ATE", focal = NULL, stabilize = FALSE,
ps = NULL, moments = NULL, int = FALSE,
subclass = NULL, missing = NULL,
verbose = FALSE, include.obj = FALSE, ...) {
A <- list(...)
#Checks
if (!check_if_call_from_fun(weightit) && !check_if_call_from_fun(weightitMSM)) {
chk::chk_not_missing(covs, "`covs`")
chk::chk_matrix(covs)
chk::chk_numeric(covs)
chk::chk_not_missing(treat, "`treat`")
chk::chk_vector(treat)
chk::chk_numeric(treat)
chk::chk_not_any_na(treat)
chk::chk_flag(stabilize)
chk::chk_flag(verbose)
chk::chk_flag(include.obj)
if (length(treat) != nrow(covs)) {
.err("`treat` and `covs` must contain the same number of units")
}
if (!has_treat_type(treat)) treat <- assign_treat_type(treat)
treat.type <- get_treat_type(treat)
.check_acceptable_method(method, msm = FALSE, force = FALSE)
.check_method_treat.type(method, treat.type)
if (is_not_null(ps)) {
chk::chk_vector(ps)
chk::chk_numeric(ps)
if (length(ps) != length(treat)) {
.err("`ps` and `treat` must be the same length")
}
if (is_not_null(ps) && is_not_null(method) &&
!is.function(method) &&
!identical(method, "glm")) {
.wrn("`ps` is supplied, so `method` will be ignored")
}
method <- NULL
}
if (is_null(s.weights)) {
s.weights <- rep.int(1.0, length(treat))
}
else {
.chk_basic_vector(s.weights)
chk::chk_numeric(s.weights)
.check_method_s.weights(method, s.weights)
if (length(s.weights) != length(treat)) {
.err("`s.weights` and `treat` must be the same length")
}
}
if (is_null(by.factor)) {
by.factor <- factor(rep.int(1L, length(treat)), levels = 1L)
}
else {
chk::chk_factor(by.factor)
if (length(by.factor) != length(treat)) {
.err("`by.factor` and `treat` must be the same length")
}
}
#Process estimand and focal
estimand <- .process_estimand(estimand, method, treat.type)
f.e.r <- .process_focal_and_estimand(focal, estimand, treat)
focal <- f.e.r[["focal"]]
estimand <- f.e.r[["estimand"]]
if (treat.type == "binary") {
attr(treat, "treated") <- f.e.r[["treated"]]
}
chk::chk_null_or(missing, vld = chk::vld_string)
if (is_null(missing)) {
if (is_not_null(method) && anyNA(covs)) missing <- "ind"
else missing <- ""
}
else if (nzchar(missing) && anyNA(covs)) {
missing <- .process_missing(missing, method)
}
else {
missing <- ""
}
#Check subclass
if (is_not_null(subclass)) {
.check_subclass(method, treat.type)
}
#Process moments and int
m.i.q <- .process_moments_int_quantile(moments = moments,
int = int,
quantile = A[["quantile"]],
method = method)
moments <- m.i.q[["moments"]]
int <- m.i.q[["int"]]
A["quantile"] <- m.i.q["quantile"]
}
else {
if (!has_treat_type(treat)) treat <- assign_treat_type(treat)
treat.type <- get_treat_type(treat)
}
missing <- .process_missing2(missing, covs)
out <- make_list(c("weights", "treat", "estimand", "method", "ps", "s.weights",
"focal", "missing", "fit.obj", "info"))
out$weights <- out$ps <- rep.int(NA_real_, length(treat))
if (include.obj) {
fit.obj <- make_list(levels(by.factor))
}
info <- make_list(levels(by.factor))
obj <- NULL
.check_required_packages(method)
if (is.function(method)) {
fun <- "weightit2user"
}
else {
fun <- {
if (is_not_null(ps)) "weightit2ps"
else if (is_null(method)) "weightit2null"
else sprintf("weightit2%s", method)
}
fun <- switch(treat.type,
"multi-category" = paste.(fun, "multi"),
"continuous" = paste.(fun, "cont"),
fun)
}
A["covs"] <- list(covs)
A["treat"] <- list(treat)
A["s.weights"] <- list(s.weights)
A["estimand"] <- list(estimand)
A["focal"] <- list(focal)
A["stabilize"] <- list(stabilize)
A["subclass"] <- list(subclass)
A["ps"] <- list(ps)
A["moments"] <- list(moments)
A["int"] <- list(int)
A["missing"] <- list(missing)
A["verbose"] <- list(verbose)
for (i in levels(by.factor)) {
A["subset"] <- list(by.factor == i)
#Run method
if (is.function(method)) {
A["Fun"] <- list(method)
}
obj <- do.call(fun, A)
#Extract weights
if (is_null(obj)) {
.err("no object was created. This is probably a bug, and you should report it at https://github.com/ngreifer/WeightIt/issues")
}
if (is_null(obj$w) || all(is.na(obj$w))) {
.wrn("no weights were estimated. This is probably a bug, and you should report it at https://github.com/ngreifer/WeightIt/issues")
}
if (!all(is.finite(obj$w))) {
.wrn("some weights were estimated as `NA`, which means a value was impossible to compute (e.g., Inf). Check for extreme values of the treatment or covariates and try removing them. Non-finite weights will be set to 0")
obj$w[!is.finite(obj$w)] <- 0
}
# else if (any(!is.finite(obj$w))) probably_a_bug()
out$weights[by.factor == i] <- obj$w
if (is_not_null(obj$ps)) {
out$ps[by.factor == i] <- obj$ps
}
if (include.obj && is_not_null(obj$fit.obj)) {
fit.obj[[i]] <- obj$fit.obj
}
info[[i]] <- obj$info
}
if (include.obj) {
if (any(lengths(fit.obj) > 0L)) {
if (nlevels(by.factor) == 1L) {
fit.obj <- fit.obj[[1L]]
}
out$fit.obj <- fit.obj
}
else {
.wrn("`include.obj` was set to `TRUE`, but no fit object was returned by the fitting function")
}
}
if (nlevels(by.factor) == 1L) {
attr(out, "Mparts") <- obj$Mparts
}
if (is_not_null(info) && nlevels(by.factor) == 1L) {
info <- info[[1L]]
}
out$info <- info
if (all(is.na(out$ps))) {
out$ps <- NULL
}
out$treat <- treat
out$estimand <- estimand
out$method <- method
out$s.weights <- s.weights
out$focal <- focal
out$missing <- missing
class(out) <- "weightit.fit"
out
}
weightitMSM.fit <- function(covs.list, treat.list, method = "glm", s.weights = NULL, by.factor = NULL,
estimand = "ATE", focal = NULL, stabilize = FALSE,
moments = NULL, int = FALSE,
subclass = NULL, is.MSM.method = FALSE, missing = NULL,
verbose = FALSE, include.obj = FALSE, ...) {
A <- list(...)
#Checks
if (check_if_call_from_fun(weightitMSM)) {
for (i in seq_along(treat.list)) {
if (!has_treat_type(treat.list[[i]])) {
treat.list[[i]] <- assign_treat_type(treat.list[[i]])
}
}
}
else {
chk::chk_not_missing(covs.list, "`covs.list`")
chk::chk_list(covs.list)
for (i in seq_along(covs.list)) {
chk::chk_matrix(covs.list[[i]], sprintf("`covs.list[[%s]]", i))
chk::chk_numeric(covs.list[[i]], sprintf("`covs.list[[%s]]", i))
}
chk::chk_not_missing(treat.list, "`treat.list`")
chk::chk_list(treat.list)
for (i in seq_along(treat.list)) {
chk::chk_vector(treat.list[[i]], sprintf("`treat.list[[%s]]", i))
chk::chk_numeric(treat.list[[i]], sprintf("`treat.list[[%s]]", i))
chk::chk_not_any_na(treat.list[[i]], sprintf("`treat.list[[%s]]", i))
}
chk::chk_flag(stabilize)
chk::chk_flag(verbose)
chk::chk_flag(include.obj)
for (i in seq_along(treat.list)) {
n <- length(treat.list[[i]])
if (i > 1L && n != length(treat.list[[1L]])) {
.err("the number of units must be the same for each time point")
}
if (nrow(covs.list[[i]]) != n) {
.err("treatment and covariates must have the same number of units")
}
if (anyNA(treat.list[[i]])) {
.err(sprintf("no missing values are allowed in the treatment variable. Missing values found in treat.list[[%s]]", i))
}
treat.list[[i]] <- assign_treat_type(treat.list[[i]])
}
.check_acceptable_method(method, msm = TRUE, force = FALSE)
missing <- {
if (is_null(method) || !any_apply(covs.list, anyNA)) ""
else .process_missing(missing, method)
}
if (is_null(s.weights)) {
s.weights <- rep.int(1, n)
}
else {
chk::chk_vector(s.weights)
chk::chk_numeric(s.weights)
.check_method_s.weights(method, s.weights)
if (length(s.weights) != n) {
.err("`s.weights` must have length equal to the number of units")
}
}
if (is_null(by.factor)) {
by.factor <- factor(rep.int(1, n), levels = 1L)
}
else {
chk::chk_factor(by.factor)
if (length(by.factor) != n) {
.err("`by.factor` must have length equal to the number of units")
}
}
#Process moments and int
m.i.q <- .process_moments_int_quantile(moments = moments,
int = int,
quantile = A[["quantile"]],
method = method)
moments <- m.i.q[["moments"]]
int <- m.i.q[["int"]]
A["quantile"] <- m.i.q["quantile"]
}
out <- make_list(c("weights", "treat.list", "method", "s.weights", "missing",
"fit.obj", "info"))
out$weights <- rep.int(NA_real_, length(treat.list[[1L]]))
if (include.obj) {
fit.obj <- make_list(levels(by.factor))
}
info <- make_list(levels(by.factor))
obj <- NULL
.check_required_packages(method)
if (is.function(method)) {
fun <- "weightitMSM2user"
}
else {
fun <- {
if (is_null(method)) "weightitMSM2null"
else sprintf("weightitMSM2%s", method)
}
}
A["covs.list"] <- list(covs.list)
A["treat.list"] <- list(treat.list)
A["s.weights"] <- list(s.weights)
A["estimand"] <- list(estimand)
A["focal"] <- list(focal)
A["moments"] <- list(moments)
A["int"] <- list(int)
A["stabilize"] <- list(stabilize)
A["subclass"] <- list(subclass)
A["missing"] <- list(missing)
A["verbose"] <- list(verbose)
for (i in levels(by.factor)) {
A["subset"] <- list(by.factor == i)
#Run method
if (is.function(method)) {
A["Fun"] <- list(method)
}
obj <- do.call(fun, A)
#Extract weights
if (is_null(obj)) {
.err("no object was created. This is probably a bug, and you should report it at https://github.com/ngreifer/WeightIt/issues")
}
if (is_null(obj$w) || all(is.na(obj$w))) {
.wrn("no weights were estimated. This is probably a bug, and you should report it at https://github.com/ngreifer/WeightIt/issues")
}
if (!all(is.finite(obj$w))) {
.wrn("some weights were estimated as `NA`, which means a value was impossible to compute (e.g., Inf). Check for extreme values of the treatment or covariates and try removing them. Non-finite weights will be set to 0")
obj$w[!is.finite(obj$w)] <- 0
}
# else if (any(!is.finite(obj$w))) probably_a_bug()
out$weights[by.factor == i] <- obj$w
if (include.obj && is_not_null(obj$fit.obj)) {
fit.obj[[i]] <- obj$fit.obj
}
info[i] <- list(obj$info)
}
if (include.obj) {
if (any(lengths(fit.obj) > 0L)) {
if (nlevels(by.factor) == 1L) {
fit.obj <- fit.obj[[1L]]
}
out$fit.obj <- fit.obj
}
else {
.wrn("`include.obj` was set to `TRUE`, but no fit object was returned by the fitting function")
}
}
if (nlevels(by.factor) == 1L) {
attr(out, "Mparts") <- obj$Mparts
}
if (is_not_null(info) && nlevels(by.factor) == 1L) {
info <- info[[1L]]
}
out$info <- info
out$treat.list <- treat.list
out$method <- method
out$s.weights <- s.weights
out$missing <- missing
class(out) <- "weightitMSM.fit"
out
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.