R/weightitMSM.R

#' Generate Balancing Weights for Longitudinal Treatments
#'
#' @description
#' `weightitMSM()` generates balancing weights for marginal structural models
#' (MSMs) with time-varying treatments, using a variety of methods for binary,
#' continuous, and multi-category treatments. Many of these methods are
#' implemented in other packages, which [weightit()] calls; those packages must
#' be installed to use the corresponding method.
#'
#' @inheritParams weightit
#' @param formula.list a list of formulas, one for each time point, with the
#' time-specific treatment variable on the left-hand side and the pre-treatment
#' covariates to be balanced on the right-hand side. The formulas must be in
#' temporal order and must contain all covariates to be balanced at that time
#' point (i.e., treatments and covariates featured in earlier formulas should
#' also appear in later ones). Interactions and functions of covariates are
#' allowed.
#' @param data an optional data set in the form of a data frame that contains
#' the variables in the formulas in `formula.list`. This must be a wide
#' data set with exactly one row per unit.
#' @param method a string of length 1 containing the name of the method that
#' will be used to estimate weights. See [weightit()] for allowable options.
#' The default is `"glm"`, which estimates the weights using generalized
#' linear models.
#' @param stabilize `logical`; whether or not to stabilize the weights.
#' Stabilizing the weights involves fitting a model predicting treatment at
#' each time point from treatment status at prior time points. If `TRUE`,
#' a fully saturated model is fit (i.e., one including all interactions among
#' the treatments up to each time point), which essentially uses the observed
#' treatment probabilities in the numerator (for binary and multi-category
#' treatments). This may yield an error if some combinations of prior
#' treatments are not observed, and with many time points, saturated models
#' may be time-consuming or impossible to fit. Default is `FALSE`. To specify
#' the stabilization models manually, e.g., to use non-saturated models, use
#' `num.formula`.
#' @param num.formula optional; a one-sided formula with the stabilization
#' factors (other than the previous treatments) on the right-hand side. For
#' each time point, these factors are added to a model saturated with the
#' previous treatments. Can also be a list of one-sided formulas, one for each
#' time point; in that case, each formula is used as the entire right-hand
#' side of that time point's stabilization model, so any previous treatments
#' must be included explicitly. See Cole & Hernán (2008) for a discussion of
#' how to specify this model; including stabilization factors can change the
#' estimand without proper adjustment and should be done with caution.
#' Supplying `num.formula` sets `stabilize = TRUE`. Unless you know what you
#' are doing, we recommend setting `stabilize = TRUE` and ignoring
#' `num.formula`.
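#'
#' A minimal base-R sketch of what stabilization computes for a binary
#' treatment, assuming simulated data and a hypothetical helper `p_obs()`
#' (neither is part of \pkg{WeightIt}, and this is not the package's internal
#' code). Each time point contributes the probability of the observed
#' treatment under a numerator model divided by the same probability under
#' the full-history (denominator) model. Mirroring `stabilize = TRUE` with no
#' `num.formula`, the numerator model here is intercept-only at the first time
#' point and saturated in prior treatments afterwards.
#'
#' ```r
#' # Simulated wide data for three time points (hypothetical)
#' set.seed(42)
#' n <- 400
#' d <- data.frame(X_0 = rnorm(n))
#' d$A_1 <- rbinom(n, 1, plogis(0.5 * d$X_0))
#' d$X_1 <- rnorm(n, 0.3 * d$A_1 + 0.5 * d$X_0)
#' d$A_2 <- rbinom(n, 1, plogis(0.4 * d$X_1 + 0.3 * d$A_1))
#' d$X_2 <- rnorm(n, 0.3 * d$A_2 + 0.5 * d$X_1)
#' d$A_3 <- rbinom(n, 1, plogis(0.4 * d$X_2 + 0.3 * d$A_2))
#'
#' # Hypothetical helper: fitted probability of each unit's observed
#' # (binary) treatment under the logistic model `f`
#' p_obs <- function(f, data) {
#'   ps <- fitted(glm(f, data = data, family = binomial))
#'   a <- model.response(model.frame(f, data))
#'   ifelse(a == 1, ps, 1 - ps)
#' }
#'
#' # Denominator: full covariate and treatment history at each time point
#' den <- p_obs(A_1 ~ X_0, d) *
#'   p_obs(A_2 ~ X_1 + A_1 + X_0, d) *
#'   p_obs(A_3 ~ X_2 + A_2 + X_1 + A_1 + X_0, d)
#'
#' # Numerator: intercept-only, then saturated in prior treatments
#' num <- p_obs(A_1 ~ 1, d) *
#'   p_obs(A_2 ~ A_1, d) *
#'   p_obs(A_3 ~ A_1 * A_2, d)
#'
#' sw <- num / den  # stabilized weights
#' ```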
#' @param include.obj whether to include in the output a list of the fit
#' objects created in the process of estimating the weights at each time point.
#' For example, with `method = "glm"`, a list of the `glm` objects
#' containing the propensity score models at each time point will be included.
#' See the help pages for each method for information on what object will be
#' included if `TRUE`.
#' @param is.MSM.method whether the method estimates weights for multiple time
#' points all at once (`TRUE`) or by estimating weights at each time point
#' and then multiplying them together (`FALSE`). This is only relevant for user-specified functions.
#' @param weightit.force several methods are not valid for estimating weights
#' with longitudinal treatments, and will produce an error message if
#' attempted. Set to `TRUE` to bypass this error message.
#' @param ...  other arguments for functions called by `weightit()` that
#' control aspects of fitting that are not covered by the above arguments. See
#' Details at [weightit()].
#'
#' @returns
#' A `weightitMSM` object with the following elements:
#' \item{weights}{The estimated weights, one for each unit.}
#' \item{treat.list}{A list of the values of the time-varying treatment variables.}
#' \item{covs.list}{A list of the covariates used in the fitting at each time point. Only includes the raw covariates as supplied, not any transformed versions used in fitting.}
#' \item{estimand}{"ATE", currently the only estimand for MSMs with binary or multi-category treatments.}
#' \item{method}{The weight estimation method specified.}
#' \item{s.weights}{The provided sampling weights.}
#' \item{by}{A data.frame containing the `by` variable when specified.}
#' \item{stabilization}{The stabilization factors, if any.}
#' \item{call}{The original function call.}
#' \item{formula.list}{The list of formulas supplied to `formula.list`.}
#' \item{obj}{When `include.obj = TRUE`, the fit objects created when estimating the weights.}
#'
#' When `keep.mparts` is `TRUE` (the default) and the chosen method is compatible with M-estimation, the components related to M-estimation for use in [glm_weightit()] are stored in the `"Mparts.list"` attribute. When `by` is specified, `keep.mparts` is set to `FALSE`.
#'
#' @details
#' Currently only "wide" data sets, where each row corresponds to a unit's
#' entire variable history, are supported. You can use [reshape()] (with
#' `direction = "wide"`) or other functions to transform your data into this
#' format.
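#'
#' For example, a minimal sketch assuming a hypothetical long data set
#' `long_df` with columns `id`, `time`, `A` (treatment), and `X` (a
#' time-varying covariate). Calling `reshape()` with `direction = "wide"` and
#' `sep = "_"` yields one row per unit, with time-suffixed columns in the same
#' `<variable>_<time>` style used by `msmdata`.
#'
#' ```r
#' # Hypothetical long data: one row per unit (id) per time point (time)
#' long_df <- data.frame(id   = rep(1:4, each = 2),
#'                       time = rep(1:2, times = 4),
#'                       A    = c(0, 1, 1, 1, 0, 0, 1, 0),
#'                       X    = c(2.1, 1.7, 0.4, 0.9, 1.3, 1.1, 0.8, 0.2))
#'
#' # One row per unit; time-varying columns become A_1, X_1, A_2, X_2
#' wide_df <- reshape(long_df, idvar = "id", timevar = "time",
#'                    v.names = c("A", "X"), direction = "wide", sep = "_")
#' ```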
#'
#' In general, `weightitMSM()` splits weight estimation into a separate
#' procedure for each time point based on the corresponding formula. For each
#' formula, it calls `weightit.fit()`, the same fitting function underlying
#' `weightit()`, collects the weights for each time point, and multiplies them
#' together to arrive at longitudinal balancing weights. Methods that estimate
#' weights for all time points at once (see `is.MSM.method`) are the
#' exception.
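#'
#' A minimal base-R sketch of that multiply-across-time logic for a binary
#' treatment, assuming simulated data and a hypothetical helper `ipw_time()`
#' standing in for a per-time-point `"glm"` fit. It omits sampling weights,
#' `by`, missing data, and the other methods the package handles, so it should
#' not be read as the package's internal code.
#'
#' ```r
#' # Simulated wide data for two time points (hypothetical)
#' set.seed(123)
#' n <- 500
#' d <- data.frame(X_0 = rnorm(n))
#' d$A_1 <- rbinom(n, 1, plogis(0.5 * d$X_0))
#' d$X_1 <- rnorm(n, 0.3 * d$A_1 + 0.5 * d$X_0)
#' d$A_2 <- rbinom(n, 1, plogis(0.4 * d$X_1 + 0.3 * d$A_1))
#'
#' # Hypothetical helper: ATE-style inverse probability weights for a binary
#' # treatment at one time point, from a logistic regression
#' ipw_time <- function(f, data) {
#'   ps <- fitted(glm(f, data = data, family = binomial))
#'   a <- model.response(model.frame(f, data))
#'   ifelse(a == 1, 1 / ps, 1 / (1 - ps))
#' }
#'
#' f.list <- list(A_1 ~ X_0,
#'                A_2 ~ X_1 + A_1 + X_0)
#'
#' # Time-specific weights, multiplied across time points
#' w.list <- lapply(f.list, ipw_time, data = d)
#' w <- Reduce("*", w.list, init = 1)
#' ```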
#'
#' Each formula should contain all the covariates to be balanced on. For
#' example, the formula corresponding to the second time point should contain
#' all the baseline covariates, the treatment variable at the first time
#' point, and the time-varying covariates measured after the first treatment
#' and before the second.
#'
#' The `"cbps"` method, which calls `CBPS()` in \pkg{CBPS}, will
#' yield different results from `CBMSM()` in \pkg{CBPS} because
#' `CBMSM()` takes a different approach to generating weights than simply
#' estimating several time-specific models.
#'
#' @seealso
#' [weightit()] for information on the allowable methods
#'
#' [summary.weightitMSM()] for summarizing the weights
#'
#' @references
#' Cole, S. R., & Hernán, M. A. (2008). Constructing Inverse
#' Probability Weights for Marginal Structural Models. American Journal of
#' Epidemiology, 168(6), 656–664. \doi{10.1093/aje/kwn164}
#'
#' @examples
#'
#' library("cobalt")
#'
#' data("msmdata")
#' (W1 <- weightitMSM(list(A_1 ~ X1_0 + X2_0,
#'                         A_2 ~ X1_1 + X2_1 +
#'                           A_1 + X1_0 + X2_0,
#'                         A_3 ~ X1_2 + X2_2 +
#'                           A_2 + X1_1 + X2_1 +
#'                           A_1 + X1_0 + X2_0),
#'                    data = msmdata,
#'                    method = "glm"))
#' summary(W1)
#' bal.tab(W1)
#'
#' # Stabilized weights with manually specified numerator models
#' W2 <- weightitMSM(list(A_1 ~ X1_0 + X2_0,
#'                         A_2 ~ X1_1 + X2_1 +
#'                           A_1 + X1_0 + X2_0,
#'                         A_3 ~ X1_2 + X2_2 +
#'                           A_2 + X1_1 + X2_1 +
#'                           A_1 + X1_0 + X2_0),
#'                    data = msmdata,
#'                    method = "glm",
#'                    stabilize = TRUE,
#'                    num.formula = list(~ 1,
#'                                       ~ A_1,
#'                                       ~ A_1 + A_2))
#'
#' # Same as above, but with fully saturated numerator models
#' # (i.e., making the last entry in `num.formula` ~ A_1 * A_2)
#' W3 <- weightitMSM(list(A_1 ~ X1_0 + X2_0,
#'                         A_2 ~ X1_1 + X2_1 +
#'                           A_1 + X1_0 + X2_0,
#'                         A_3 ~ X1_2 + X2_2 +
#'                           A_2 + X1_1 + X2_1 +
#'                           A_1 + X1_0 + X2_0),
#'                    data = msmdata,
#'                    method = "glm",
#'                    stabilize = TRUE)

#' @export
weightitMSM <- function(formula.list, data = NULL, method = "glm", stabilize = FALSE, by = NULL,
                        s.weights = NULL,
                        num.formula = NULL, moments = NULL, int = FALSE, missing = NULL,
                        verbose = FALSE, include.obj = FALSE, keep.mparts = TRUE,
                        is.MSM.method, weightit.force = FALSE, ...) {

  A <- list(...)

  call <- match.call()

  ## Checks and processing ----

  #Checks

  ##Process method
  .check_acceptable_method(method, msm = TRUE, force = weightit.force)

  if (is.character(method)) {
    method <- .method_to_proper_method(method)
    attr(method, "name") <- method
    if (missing(is.MSM.method)) is.MSM.method <- NULL
    is.MSM.method <- .process_MSM_method(is.MSM.method, method)
  }
  else if (is.function(method)) {
    method.name <- paste(deparse(substitute(method)))
    .check_user_method(method)
    if (missing(is.MSM.method)) is.MSM.method <- NULL
    is.MSM.method <- .process_MSM_method(is.MSM.method, method)
    attr(method, "name") <- method.name
  }

  #Process moments and int
  moments.int <- .process_moments_int(moments, int, method)
  moments <- moments.int[["moments"]]; int <- moments.int[["int"]]

  s.weights <- process.s.weights(s.weights, data)

  if (is_not_null(num.formula)) {
    if (!stabilize) {
      .msg("setting `stabilize` to `TRUE` based on `num.formula` input")
    }
    stabilize <- TRUE
  }
  if (stabilize) {
    if (is_not_null(num.formula)) {
      if (rlang::is_formula(num.formula)) {
        if (!rlang::is_formula(num.formula, lhs = FALSE)) {
          .err("the argument to `num.formula` must have right hand side variables but not a response variable (e.g., ~ V1 + V2)")
        }

        rhs.vars.mentioned.lang <- attr(terms(num.formula), "variables")[-1]
        rhs.vars.mentioned <- vapply(rhs.vars.mentioned.lang, deparse1, character(1L))
        rhs.vars.failed <- vapply(rhs.vars.mentioned.lang, function(v) {
          null_or_error(try(eval(v, c(data, .GlobalEnv)), silent = TRUE))
        }, logical(1L))

        if (any(rhs.vars.failed)) {
          .err(paste0("All variables in `num.formula` must be variables in `data` or objects in the global environment.\nMissing variables: ",
                      paste(rhs.vars.mentioned[rhs.vars.failed], collapse = ", ")), tidy = FALSE)
        }
      }
      else if (is.list(num.formula)) {
        if (length(num.formula) != length(formula.list)) {
          .err("when supplied as a list, `num.formula` must have as many entries as `formula.list`")
        }
        if (!all(vapply(num.formula, rlang::is_formula, logical(1L), lhs = FALSE))) {
          .err("`num.formula` must be a single formula with no response variable and with the stabilization factors on the right hand side or a list thereof")
        }
        rhs.vars.mentioned.lang.list <- lapply(num.formula, function(nf) attr(terms(nf), "variables")[-1])
        rhs.vars.mentioned <- unique(unlist(lapply(rhs.vars.mentioned.lang.list, function(r) vapply(r, deparse1, character(1L)))))
        rhs.vars.failed <- vapply(rhs.vars.mentioned, function(v) {
          null_or_error(try(eval(parse(text=v), c(data, .GlobalEnv)), silent = TRUE))
        }, logical(1L))

        if (any(rhs.vars.failed)) {
          .err(paste0("All variables in `num.formula` must be variables in `data` or objects in the global environment.\nMissing variables: ",
                      paste(rhs.vars.mentioned[rhs.vars.failed], collapse = ", ")), tidy = FALSE)
        }
      }
      else {
        .err("`num.formula` must be a single formula with no response variable and with the stabilization factors on the right hand side or a list thereof")
      }
    }
  }

  ##Process by
  if (is_not_null(A[["exact"]])) {
    .msg("`by` has replaced `exact` in the `weightit()` syntax, but `exact` will always work")
    # by.name <- deparse(A[["exact"]])
    by <- A[["exact"]]
    by.arg <- "exact"
  }
  else {
    by.arg <- "by"
  }

  reported.covs.list <- covs.list <- treat.list <- w.list <- ps.list <-
    stabout <- sw.list <- Mparts.list <- stab.Mparts.list <- make_list(length(formula.list))

  if (is_null(formula.list) || !is.list(formula.list) ||
      !all(vapply(formula.list, rlang::is_formula, logical(1L), lhs = TRUE))) {
    .err("`formula.list` must be a list of formulas")
  }

  for (i in seq_along(formula.list)) {

    #Process treat and covs from formula and data
    t.c <- get_covs_and_treat_from_formula(formula.list[[i]], data)
    reported.covs.list[[i]] <- t.c[["reported.covs"]]
    covs.list[[i]] <- t.c[["model.covs"]]
    treat.list[[i]] <- t.c[["treat"]]
    treat.name <- t.c[["treat.name"]]
    names(treat.list)[i] <- treat.name
    names(reported.covs.list)[i] <- treat.name

    if (is_null(covs.list[[i]])) .err(sprintf("no covariates were specified in the %s formula", ordinal(i)))
    if (is_null(treat.list[[i]])) .err(sprintf("no treatment variable was specified in the %s formula", ordinal(i)))

    n <- length(treat.list[[i]])

    if (nrow(covs.list[[i]]) != n) {
      .err("treatment and covariates must have the same number of units")
    }
    if (anyNA(treat.list[[i]])) {
      .err(sprintf("no missing values are allowed in the treatment variable. Missing values found in %s", treat.name))
    }

    treat.list[[i]] <- assign_treat_type(treat.list[[i]])

    #`by` is processed at each time point, but only the last one is used for by.factor.
    processed.by <- .process_by(by, data = data,
                               treat = treat.list[[i]],
                               treat.name = treat.name,
                               by.arg = by.arg)

    #Process missing
    if (anyNA(reported.covs.list[[i]])) {
      missing <- .process_missing(missing, method, get_treat_type(treat.list[[i]]))
    }
    else if (i == length(formula.list)) {
      missing <- ""
    }
  }

  if (is_null(s.weights)) s.weights <- rep(1, n)

  if (is.MSM.method) {
    #Returns weights (w)

    A[["covs.list"]] <- covs.list
    A[["treat.list"]] <- treat.list
    A[["s.weights"]] <- s.weights
    A[["by.factor"]] <- attr(processed.by, "by.factor")
    A[["stabilize"]] <- stabilize
    A[["method"]] <- method
    A[["moments"]] <- moments
    A[["int"]] <- int
    A[["subclass"]] <- numeric()
    A[["missing"]] <- missing
    A[["verbose"]] <- verbose
    A[["include.obj"]] <- include.obj

    obj <- do.call("weightitMSM.fit", A)

    w <- obj[["weights"]]
    stabout <- NULL
    obj.list <- obj[["fit.obj"]]
    Mparts.list <- attr(obj, "Mparts")
  }
  else {
    if (length(A[["link"]]) %nin% c(0, 1, length(formula.list))) {
      .err(sprintf("the argument to `link` must have length 1 or %s", length(formula.list)))
    }
    if (length(A[["link"]]) == 1) {
      A[["link"]] <- rep(A[["link"]], length(formula.list))
    }
    # if (length(A[["family"]]) %nin% c(0, 1, length(formula.list))) stop(paste0("The argument to link must have length 1 or ", length(formula.list), "."), call. = FALSE)
    # if (length(A[["family"]]) == 1) A[["family"]] <- rep(A[["family"]], length(formula.list))

    obj.list <- make_list(length(formula.list))

    A[["s.weights"]] <- s.weights
    A[["by.factor"]] <- attr(processed.by, "by.factor")
    A[["estimand"]] <- "ATE"
    A[["focal"]] <- character()
    A[["stabilize"]] <- FALSE
    A[["method"]] <- method
    A[["moments"]] <- moments
    A[["int"]] <- int
    A[["subclass"]] <- numeric()
    A[["ps"]] <- numeric()
    A[["missing"]] <- missing
    A[["verbose"]] <- verbose
    A[["is.MSM.method"]] <- FALSE
    A[["include.obj"]] <- include.obj

    for (i in seq_along(formula.list)) {
      A_i <- A
      if (length(A[["link"]]) == length(formula.list)) {
        A_i[["link"]] <- A[["link"]][[i]]
      }

      A_i[["covs"]] <- covs.list[[i]]
      A_i[["treat"]] <- treat.list[[i]]
      A_i[["treat.type"]] <- get_treat_type(treat.list[[i]])
      A_i[[".data"]] <- data
      A_i[[".covs"]] <- reported.covs.list[[i]]

      ## Running models ----

      #Returns weights (w) and propensity score (ps)
      obj <- do.call("weightit.fit", A_i)

      w.list[[i]] <- obj[["weights"]]
      ps.list[[i]] <- obj[["ps"]]
      obj.list[[i]] <- obj[["fit.obj"]]
      Mparts.list[[i]] <- attr(obj, "Mparts")

      if (stabilize) {
        #Process stabilization formulas and get stab weights
        if (rlang::is_formula(num.formula)) {
          if (i == 1) {
            stab.f <- update.formula(as.formula(paste(names(treat.list)[i], "~ 1")), as.formula(paste(paste(num.formula, collapse = ""), "+ .")))
          }
          else {
            stab.f <- update.formula(as.formula(paste(names(treat.list)[i], "~", paste(names(treat.list)[seq_along(names(treat.list)) < i], collapse = " * "))), as.formula(paste(paste(num.formula, collapse = ""), "+ .")))
          }
        }
        else if (is.list(num.formula)) {
          stab.f <- update.formula(as.formula(paste(names(treat.list)[i], "~ 1")), as.formula(paste(paste(num.formula[[i]], collapse = ""), "+ .")))
        }
        else {
          if (i == 1) {
            stab.f <- as.formula(paste(names(treat.list)[i], "~ 1"))
          }
          else {
            stab.f <- as.formula(paste(names(treat.list)[i], "~", paste(names(treat.list)[seq_along(names(treat.list)) < i], collapse = " * ")))
          }
        }

        stab.t.c_i <- get_covs_and_treat_from_formula(stab.f, data)

        A_i[["covs"]] <- stab.t.c_i[["model.covs"]]
        A_i[["method"]] <- "glm"
        A_i[["moments"]] <- numeric()
        A_i[["int"]] <- FALSE

        sw_obj <- do.call("weightit.fit", A_i)

        sw.list[[i]] <- 1/sw_obj[["weights"]]
        stabout[[i]] <- stab.f[-2]

        stab.Mparts.list[[i]] <- attr(sw_obj, "Mparts")

        if (is_not_null(stab.Mparts.list[[i]])) {
          stab.Mparts.list[[i]]$wfun <- Invert(stab.Mparts.list[[i]]$wfun)
        }
      }
    }

    w <- Reduce("*", w.list, init = 1)

    if (stabilize) {
      sw <- Reduce("*", sw.list, init = 1)
      w <- w * sw

      unique.stabout <- unique(stabout)
      if (length(unique.stabout) <= 1) stabout <- unique.stabout
    }
    else {
      stabout <- NULL
    }

    if (include.obj) names(obj.list) <- names(treat.list)
  }

  if (all_the_same(w)) .err(sprintf("all weights are %s", w[1]))


  ## Assemble output object----
  out <- list(weights = w,
              treat.list = treat.list,
              covs.list = reported.covs.list,
              estimand = "ATE",
              method = method,
              # ps.list = ps.list,
              s.weights = s.weights,
              #discarded = NULL,
              by = processed.by,
              call = call,
              formula.list = formula.list,
              stabilization = stabout,
              env = parent.frame(),
              obj = obj.list
  )

  out <- clear_null(out)

  if (keep.mparts && all(lengths(Mparts.list) > 0)) {
    attr(out, "Mparts.list") <- clear_null(c(Mparts.list, stab.Mparts.list))
  }

  class(out) <- c("weightitMSM", "weightit")

  out
}

#' @exportS3Method print weightitMSM
print.weightitMSM <- function(x, ...) {
  treat.types <- vapply(x[["treat.list"]], get_treat_type, character(1L))
  trim <- attr(x[["weights"]], "trim")

  cat("A " %+% italic("weightitMSM") %+% " object\n")
  cat(paste0(" - method: \"", attr(x[["method"]], "name"), "\" (", .method_to_phrase(x[["method"]]), ")\n"))
  cat(paste0(" - number of obs.: ", length(x[["weights"]]), "\n"))
  cat(paste0(" - sampling weights: ", ifelse(all_the_same(x[["s.weights"]]), "none", "present"), "\n"))
  cat(paste0(" - number of time points: ", length(x[["treat.list"]]), " (", paste(names(x[["treat.list"]]), collapse = ", "), ")\n"))
  cat(paste0(" - treatment: \n",
             paste0(vapply(seq_along(x$covs.list), function(i) {
               paste0("    + time ", i, ": ",
                      if (treat.types[i] == "continuous") "continuous"
                      else paste0(nunique(x[["treat.list"]][[i]]), "-category",
                                  if (treat.types[i] == "multinomial") paste0(" (", paste(levels(x[["treat.list"]][[i]]), collapse = ", "), ")")
                                  else ""
                      ), "\n")
             }, character(1L)), collapse = ""), collapse = "\n"))
  cat(paste0(" - covariates: \n",
             paste0(vapply(seq_along(x$covs.list), function(i) {
               if (i == 1) {
                 paste0("    + baseline: ", if (is_null(x$covs.list[[i]])) "(none)" else paste(names(x$covs.list[[i]]), collapse = ", "), "\n")
               }
               else {
                 paste0("    + after time ", i-1, ": ", paste(names(x$covs.list[[i]]), collapse = ", "), "\n")
               }
             }, character(1L)), collapse = ""), collapse = "\n"))
  if (is_not_null(x[["by"]])) {
    cat(paste0(" - by: ", paste(names(x[["by"]]), collapse = ", "), "\n"))
  }
  if (is_not_null(x$stabilization)) {
    cat(" - stabilized")
    if (any(vapply(x$stabilization, function(s) is_not_null(all.vars(s)), logical(1L)))) {
      cat(paste0("; stabilization factors:\n", if (length(x$stabilization) == 1) paste0("      ", paste0(attr(terms(x[["stabilization"]][[1]]), "term.labels"), collapse = ", "))
                 else {
                   paste0(vapply(seq_along(x$stabilization), function(i) {
                     if (i == 1) {
                       paste0("    + baseline: ", if (is_null(attr(terms(x[["stabilization"]][[i]]), "term.labels"))) "(none)" else paste(attr(terms(x[["stabilization"]][[i]]), "term.labels"), collapse = ", "))
                     }
                     else {
                       paste0("    + after time ", i-1, ": ", paste(attr(terms(x[["stabilization"]][[i]]), "term.labels"), collapse = ", "))
                     }
                   }, character(1L)), collapse = "\n")
                 }))
    }
    cat("\n")
  }

  if (is_not_null(trim)) {
    if (trim < 1) {
      if (attr(x[["weights"]], "trim.lower")) trim <- c(1 - trim, trim)
      cat(paste(" - weights trimmed at", word_list(paste0(round(100*trim, 2), "%")), "\n"))
    }
    else {
      t.b <- if (attr(x[["weights"]], "trim.lower")) "top and bottom" else "top"
      cat(paste(" - weights trimmed at the", t.b, trim, "\n"))
    }
  }
  invisible(x)
}

WeightIt documentation built on May 29, 2024, 9:48 a.m.