R/functions_for_processing.R

Defines functions .get_glm_starting_values .make_link .get_glm_psi .p2o .gradientRich .gradientFD .gradient generalized_inverse verbosely add_missing_indicators replace_na_with neg_ent plot_density .dw_dp_multi .dw_dp_bin .get_w_from_ps_internal_array .get_w_from_ps_internal_multi .get_w_from_ps_internal_bin .get_dens_fun stabilize_w .subclass_scoot .subclass_ps_bin .subclass_ps_multi .check_estimated_weights get.s.d.denom.cont.weightit get.s.d.denom.weightit .quantile_f .int_poly_f .make_closer_to_1 .process_by get_treated_level .get_control_and_treated_levels .process_focal_and_estimand .process_ps .check_user_method .process_missing2 .missing_to_phrase .process_missing .process_MSM_method .process_moments_int_quantile .check_subclass .process_estimand .method_to_phrase .check_method_s.weights .process.s.weights .check_required_packages .check_method_treat.type .check_acceptable_method .method_to_proper_method

.method_to_proper_method <- function(method) {
  if (is_null(method)) {
    return(NULL)
  }

  if (!is.character(method)) {
    return(method)
  }

  method <- tolower(method)

  if (method %nin% unlist(grab(.weightit_methods, "alias"))) {
    return(method)
  }

  .allowable.methods <- unlist(lapply(names(.weightit_methods), function(m) {
    aliases <- .weightit_methods[[m]]$alias
    setNames(rep(m, length(aliases)), aliases)
  }))

  unname(.allowable.methods[method])
}

.check_acceptable_method <- function(method, msm = FALSE, force = FALSE) {

  if (missing(method)) {
    method <- "glm"
  }
  else if (is_null(method)) {
    return(invisible(NULL))
  }

  if (identical(method, "twang")) {
    .err('"twang" is no longer an acceptable argument to `method`. Please use "gbm" for generalized boosted modeling')
  }

  if ((!is.character(method) && !is.function(method)) ||
      (is.character(method) && (length(method) > 1L ||
                                !utils::hasName(.weightit_methods,
                                                .method_to_proper_method(method))))) {
    .err(sprintf("`method` must be a string of length 1 containing the name of an acceptable weighting method or a function that produces weights. Allowable methods:\n%s",
                 word_list(names(.weightit_methods), and.or = FALSE, quotes = 2L)),
         tidy = FALSE)
  }

  if (msm && !force && is.character(method)) {
    m <- .method_to_proper_method(method)
    if (!.weightit_methods[[m]]$msm_valid) {
      .err(sprintf("the use of %s with longitudinal treatments has not been validated. Set `weightit.force = TRUE` to bypass this error",
                   .method_to_phrase(m)))
    }
  }
}

.check_method_treat.type <- function(method, treat.type) {
  if (is_not_null(method) && is.character(method) &&
      utils::hasName(.weightit_methods, method) &&
      (treat.type %nin% .weightit_methods[[method]]$treat_type)) {
    .err(sprintf("%s can only be used with a %s treatment",
                 .method_to_phrase(method),
                 word_list(.weightit_methods[[method]]$treat_type, and.or = "or")))
  }
}

.check_required_packages <- function(method) {
  if (is_not_null(method) && is.character(method) &&
      utils::hasName(.weightit_methods, method)) {

    pkgs <- .weightit_methods[[method]]$packages_needed

    if (is_not_null(pkgs)) {
      rlang::check_installed(pkgs)
    }
  }

  invisible(NULL)
}

.process.s.weights <- function(s.weights, data = NULL) {
  #Process s.weights
  if (is_null(s.weights)) {
    return(NULL)
  }

  if (is.numeric(s.weights)) {
    return(s.weights)
  }

  if (!chk::vld_string(s.weights)) {
    .err("the argument to `s.weights` must be a vector or data frame of sampling weights or the (quoted) names of the variable in `data` that contains sampling weights")
  }

  if (is_null(data)) {
    .err("`s.weights` was specified as a string but there was no argument to `data`")
  }

  if (!utils::hasName(data, s.weights)) {
    .err("the name supplied to `s.weights` is not the name of a variable in `data`")
  }

  data[[s.weights]]
}

.check_method_s.weights <- function(method, s.weights) {
  if (is_not_null(method) &&
      !is.function(method) &&
      !.weightit_methods[[method]]$s.weights_ok &&
      !all_the_same(s.weights)) {
    .err(sprintf("sampling weights cannot be used with %s", .method_to_phrase(method)))
  }
}

.method_to_phrase <- function(method) {

  if (is_null(method)) {
    return("no weighting")
  }

  if (is.function(method)) {
    return("a user-defined method")
  }

  method <- .method_to_proper_method(method)

  if (!utils::hasName(.weightit_methods, method)) {
    return("the chosen method of weighting")
  }

  .weightit_methods[[method]]$description
}

.process_estimand <- function(estimand, method, treat.type) {

  if (is.function(method)) {
    chk::chk_null_or(estimand, vld = chk::vld_string)
    return(toupper(estimand))
  }

  if (treat.type == "continuous") {
    if (is_not_null(estimand) && !identical(toupper(estimand), "ATE")) {
      .wrn("`estimand` is ignored for continuous treatments")
    }

    return("ATE")
  }

  chk::chk_string(estimand)

  allowable_estimands <- {
    if (is_null(method)) unique(unlist(grab(.weightit_methods, "estimand")))
    else .weightit_methods[[method]]$estimand
  }

  if (treat.type == "multi-category") {
    allowable_estimands <- setdiff(allowable_estimands, "ATOS")
  }

  if (toupper(estimand) %nin% allowable_estimands) {
    .err(sprintf("%s is not an allowable estimand for %s with a %s treatment. Only %s allowed",
                 add_quotes(estimand), .method_to_phrase(method), treat.type,
                 word_list(allowable_estimands, quotes = TRUE, and.or = "and", is.are = TRUE)))
  }

  toupper(estimand)
}

.check_subclass <- function(method, treat.type) {
  if (is_not_null(method) && !is.function(method)) {

    subclass_ok <- .weightit_methods[[method]]$subclass_ok

    if (treat.type == "continuous" || !subclass_ok) {
      .err(sprintf("subclasses are not compatible with %s with a %s treatment",
                   .method_to_phrase(method), treat.type))
    }
  }
}

.process_moments_int_quantile <- function(moments, int, quantile = NULL, method = NULL) {
  if (is.function(method)) {
    return(list(moments = moments, int = int, quantile = quantile))
  }

  if (is_null(method) || !.weightit_methods[[method]]$moments_int_ok) {
    if (is_not_null(method)) {
      mi0 <- c(is_not_null(moments), is_not_null(int) && !isFALSE(int), is_not_null(quantile))

      if (any(mi0)) {
        .wrn(sprintf("%s not compatible with %s. Ignoring %s",
                     word_list(c("moments", "int", "quantile")[mi0], and.or = "and", is.are = TRUE, quotes = "`"),
                     .method_to_phrase(method),
                     word_list(c("moments", "int", "quantile")[mi0], and.or = "and", quotes = "`")))
      }
    }

    return(list(moments = integer(), int = FALSE, quantile = list()))
  }

  chk::chk_flag(int)

  if (is_not_null(quantile)) {
    .vld_qu <- function(x) {
      is.numeric(x) && all(x >= 0) && all(x <= 1)
    }

    bad.q <- FALSE
    if (is.numeric(quantile) && .vld_qu(quantile)) {
      if (length(quantile) == 1L || (is_not_null(names(quantile)) && all(nzchar(names(quantile))))) {
        quantile <- as.list(quantile)
      }
      else {
        bad.q <- TRUE
      }
    }
    else if (is.list(quantile)) {
      if ((length(quantile) > 1L && (is_null(names(quantile)) || !all(nzchar(names(quantile))))) ||
          !all_apply(quantile, .vld_qu)) {
        bad.q <- TRUE
      }
    }
    else {
      bad.q <- TRUE
    }

    if (bad.q) {
      .err("`quantile` must be a number between 0 and 1, a named list or vector of such values, or a named list of vectors of such values")
    }
  }

  if (is_not_null(moments)) {
    chk::chk_whole_number(moments)

    chk::chk_gte(moments,
                 if (is_null(quantile)) .weightit_methods[[method]]$moments_default
                 else 0)

    if (int && moments < 1) {
      .wrn("when `int = TRUE`, `moments` must be greater than or equal to 1. Setting `moments = 1`")
      moments <- 1L
    }
    else {
      moments <- as.integer(moments)
    }
  }
  else {
    moments <- {
      if (int) 1L
      else .weightit_methods[[method]]$moments_default
    }
  }

  list(moments = moments, int = int, quantile = quantile)
}

.process_MSM_method <- function(is.MSM.method, method) {
  if (is_null(method)) {
    return(FALSE)
  }

  if (is.function(method)) {
    if (isTRUE(is.MSM.method)) {
      .err("currently, only user-defined methods that work with `is.MSM.method = FALSE` are allowed")
    }

    return(FALSE)
  }

  if (.weightit_methods[[method]]$msm_method_available) {
    if (is_null(is.MSM.method)) {
      return(TRUE)
    }

    chk::chk_flag(is.MSM.method)

    if (!is.MSM.method) {
      .msg(sprintf("%s can be used with a single model when multiple time points are present. Using a seperate model for each time point. To use a single model, set `is.MSM.method` to `TRUE`",
                   .method_to_phrase(method)))
    }

    return(is.MSM.method)
  }

  if (is_not_null(is.MSM.method)) {

    chk::chk_flag(is.MSM.method)

    if (is.MSM.method) {
      .wrn(sprintf("%s cannot be used with a single model when multiple time points are present. Using a seperate model for each time point",
                   .method_to_phrase(method)))
    }
  }

  FALSE
}

.process_missing <- function(missing, method) {
  if (is_null(method)) {
    return("")
  }

  allowable.missings <- .weightit_methods[[method]]$missing

  if (is_null(missing)) {
    .wrn(sprintf("missing values are present in the covariates. See `?WeightIt::method_%s` for information on how these are handled",
                 method))
    return(allowable.missings[1L])
  }

  chk::chk_string(missing)

  if (missing %nin% allowable.missings) {
    .err(sprintf("only %s allowed for `missing` with %s",
                 word_list(allowable.missings, quotes = 2L, is.are = TRUE),
                 .method_to_phrase(method)))
  }

  missing
}

.missing_to_phrase <- function(missing) {
  switch(missing,
         ind = "missingness indicators",
         saem = "SAEM",
         surr = "surrogate splitting",
         missing)
}

.process_missing2 <- function(missing, covs) {
  if (is_null(missing) || identical(missing, "") || !anyNA(covs)) {
    return("")
  }

  missing
}

.check_user_method <- function(method) {
  #Check to make sure it accepts treat and covs
  if (all(c("covs", "treat") %in% names(formals(method)))) {
  }
  # else if (all(c("covs.list", "treat.list") %in% names(formals(method)))) {
  # }
  else {
    .err("the user-provided function to `method` must contain `covs` and `treat` as named parameters")
  }
}

.process_ps <- function(ps, data = NULL, treat = NULL) {
  if (is_null(ps)) {
    return(NULL)
  }

  if (chk::vld_string(ps)) {
    if (is_null(data)) {
      .err("`ps` was specified as a string but there was no argument to `data`")
    }

    if (!utils::hasName(data, ps)) {
      .err("the name supplied to `ps` is not the name of a variable in `data`")
    }

    ps <- data[[ps]]

    if (!is.numeric(ps)) {
      .err("the name supplied to `ps` must correspond to a numeric variable in `data`")
    }
  }
  else if (is.numeric(ps)) {
    if (length(ps) != length(treat)) {
      .err("`ps` must have the same number of units as the treatment")
    }
  }
  else {
    .err("the argument to `ps` must be a vector of propensity scores or the (quoted) name of a numeric variable in `data` that contains propensity scores")
  }

  ps
}

.process_focal_and_estimand <- function(focal, estimand, treat, treated = NULL) {
  reported.estimand <- estimand

  if (!has_treat_type(treat)) treat <- assign_treat_type(treat)
  treat.type <- get_treat_type(treat)

  if (treat.type == "continuous") {
    return(list(focal = NULL,
                estimand = "ATE",
                reported.estimand = "ATE"))
  }

  if (estimand %nin% c("ATT", "ATC") && is_not_null(focal)) {
    .wrn(sprintf('`estimand = %s` is not compatible with `focal`. Setting `estimand` to "ATT"',
                 add_quotes(estimand)))
    reported.estimand <- estimand <- "ATT"
  }

  if (treat.type == "binary") {
    ct <- .get_control_and_treated_levels(treat, estimand, focal, treated)

    focal <- switch(estimand,
                    "ATT" = ct["treated"],
                    "ATC" = ct["control"],
                    NULL)

    treated <- ct["treated"]
  }
  else {
    unique.vals <- {
      if (chk::vld_character_or_factor(treat))
        levels(factor(treat, nmax = ceiling(length(treat) / 4)))
      else
        sort(unique(treat, nmax = ceiling(length(treat) / 4)))
    }

    #Check focal
    if (is_not_null(focal) && (length(focal) > 1L || focal %nin% unique.vals)) {
      .err("the argument supplied to `focal` must be the name of a level of treatment")
    }

    if (estimand == "ATT") {
      if (is_null(focal)) {
        if (is_null(treated) || treated %nin% unique.vals) {
          .err('when `estimand = "ATT"` for multi-category treatments, an argument must be supplied to `focal`')
        }
        focal <- treated
      }
    }
    else if (estimand == "ATC") {
      if (is_null(focal)) {
        .err('when `estimand = "ATC"` for multi-category treatments, an argument must be supplied to `focal`')
      }
      estimand <- "ATT"
    }
  }

  list(focal = unname(focal),
       estimand = estimand,
       reported.estimand = reported.estimand,
       treated = switch(treat.type, "binary" = unname(treated), NULL))
}

.get_control_and_treated_levels <- function(treat, estimand, focal = NULL, treated = NULL) {

  if (is_not_null(attr(treat, "control")) &&
      is_not_null(attr(treat, "treated"))) {
    return(setNames(c(attr(treat, "control"), attr(treat, "treated")),
                    c("control", "treated")))
  }

  control <- NULL
  throw_message <- FALSE

  unique.vals <- {
    if (chk::vld_character_or_factor(treat))
      levels(factor(treat, nmax = 2L))
    else
      sort(unique(treat, nmax = 2L))
  }

  if (is_not_null(focal)) {
    if (length(focal) > 1L || focal %nin% unique.vals) {
      .err("the argument supplied to `focal` must be the name of a level of treatment")
    }

    if (estimand == "ATC") {
      control <- focal
      treated <- NULL
    }
    else {
      treated <- focal
    }
  }
  else if (is_not_null(attr(treat, "treated", TRUE))) {
    treated <- attr(treat, "treated", TRUE)
  }
  else if (is_not_null(attr(treat, "control", TRUE))) {
    control <- attr(treat, "control", TRUE)
  }
  else if (is_not_null(treated)) {
    if (length(treated) > 1L || treated %nin% unique.vals) {
      .err("the argument supplied to `treated` must be the name of a level of treatment")
    }
  }
  else if (is.logical(treat)) {
    treated <- TRUE
    control <- FALSE
  }
  else if (is.numeric(unique.vals)) {
    control <- unique.vals[unique.vals == 0]

    if (is_null(control)) {
      control <- unique.vals[1L]

      throw_message <- TRUE
    }
  }
  else if (can_str2num(unique.vals)) {
    unique.vals.numeric <- str2num(unique.vals)

    control <- unique.vals[unique.vals.numeric == 0]

    if (is_null(control)) {
      control <- unique.vals[which.min(unique.vals.numeric)]

      throw_message <- TRUE
    }
  }
  else {
    treated_options <- c("t", "tr", "treat", "treated", "exposed")
    control_options <- c("c", "co", "ctrl", "control", "unexposed")

    t_match <- which(unique.vals %in% treated_options)
    c_match <- which(unique.vals %in% control_options)

    if (length(t_match) == 1L) {
      treated <- unique.vals[t_match]
    }
    else if (length(c_match) == 1L) {
      control <- unique.vals[c_match]
    }
  }

  if (is_null(control) && is_null(treated)) {
    control <- unique.vals[1L]
    treated <- unique.vals[2L]

    throw_message <- TRUE
  }
  else if (is_null(control)) {
    control <- setdiff(unique.vals, treated)
  }
  else if (is_null(treated)) {
    treated <- setdiff(unique.vals, control)
  }

  if (throw_message) {
    if (estimand == "ATT") {
      .msg(sprintf("assuming %s is the treated level. If not, supply an argument to `focal`",
                   add_quotes(treated, !is.numeric(unique.vals))))

    }
    else if (estimand == "ATC") {
      .msg(sprintf("assuming %s is the control level. If not, supply an argument to `focal`",
                   add_quotes(control, !is.numeric(unique.vals))))
    }
    else {
      .msg(sprintf("assuming %s is the treated level. If not, recode the treatment so that 1 is treated and 0 is control",
                   add_quotes(treated, !is.numeric(unique.vals))))
    }
  }

  setNames(c(control, treated),
           c("control", "treated"))
}

get_treated_level <- function(treat, estimand, focal = NULL) {
  ct <- .get_control_and_treated_levels(treat, estimand, focal)

  unname(ct["treated"])
}

.process_by <- function(by, data, treat, treat.name = NULL, by.arg = "by") {

  ##Process by
  bad.by <- FALSE
  n <- length(treat)

  if (!has_treat_type(treat)) treat <- assign_treat_type(treat)
  treat.type <- get_treat_type(treat)

  if (missing(by)) {
    bad.by <- TRUE
  }
  else if (is_null(by)) {
    by <- NULL
    by.name <- NULL
  }
  else if (chk::vld_string(by) && utils::hasName(data, by)) {
    by.name <- by
    by <- data[[by]]
  }
  else if (length(dim(by)) == 2L && len(by) == n) {
    by.name <- colnames(by)[1L]
    by <- drop(by[, 1L])
  }
  else if (rlang::is_formula(by, lhs = FALSE)) {
    t.c <- get_covs_and_treat_from_formula(by, data)
    by <- t.c[["reported.covs"]]
    if (NCOL(by) != 1L) {
      .err(sprintf("only one variable can be on the right-hand side of the formula for `%s`",
                   by.arg))
    }
    by.name <- colnames(by)
  }
  else {
    bad.by <- TRUE
  }

  if (bad.by) {
    .err(sprintf("`%s` must be a string containing the name of the variable in data for which weighting is to occur within strata or a one-sided formula with the stratifying variable on the right-hand side",
                 by.arg))
  }

  if (anyNA(by)) {
    .err(sprintf("the variable supplied to `%s` cannot contain any missing (NA) values",
                 by.arg))
  }

  by.components <- data.frame(by)

  names(by.components) <- {
    if (is_not_null(colnames(by))) colnames(by)
    else by.name
  }

  by.factor <- {
    if (is_null(by)) factor(rep.int(1L, n), levels = 1L)
    else factor(by.components[[1L]], levels = sort(unique(by.components[[1L]])),
                labels = paste(names(by.components), "=", sort(unique(by.components[[1L]]))))
  }

  if (treat.type != "continuous" &&
      any_apply(levels(by.factor), function(x) nunique(treat) != nunique(treat[by.factor == x]))) {
    .err(sprintf("not all the groups formed by `%s` contain all treatment levels%s. Consider coarsening `%s`",
                 by.arg,
                 if (is_not_null(treat.name)) sprintf(" in %s", treat.name) else "",
                 by.arg))
  }

  attr(by.components, "by.factor") <- by.factor

  by.components
}

.make_closer_to_1 <- function(x) {
  if (chk::vld_character_or_factor(x) || all_the_same(x)) {
    return(x)
  }

  if (is_binary(x)) {
    return(as.numeric(x == max(x, na.rm = TRUE)))
  }

  (x - mean_fast(x, TRUE)) / sd(x, na.rm = TRUE)
}

.int_poly_f <- function(d, ex = NULL, int = FALSE, poly = 1, center = TRUE, orthogonal_poly = TRUE) {
  #Adds to data frame interactions and polynomial terms
  #d=matrix input
  #ex=names of variables to exclude in interactions and polynomials; a subset of df
  #int=whether to include interactions or not; currently only 2-way are supported
  #poly=degree of polynomials to include; will also include all below poly. If 1, no polynomial will be included

  if (!is.matrix(d)) {
    if (!is.numeric(d)) {
      .err("an error occurred, probably a bug")
    }

    matrix(d, ncol = 1L, dimnames = list(NULL, "x"))
  }

  if (is_null(ex)) ex <- rep.int(FALSE, ncol(d))

  binary.vars <- is_binary_col(d)

  if (center && (int || !orthogonal_poly)) {
    d[, !binary.vars] <- center(d[, !binary.vars, drop = FALSE])
  }

  nd <- NCOL(d)

  if (poly == 0 || nd == 0L) {
    poly_terms <- poly_co.names <- list()
  }
  else if (poly == 1) {
    poly_terms <- list(d)
    poly_co.names <- list(colnames(d))
  }
  else {
    poly_terms <- poly_co.names <- make_list(nd)

    for (i in seq_col(d)) {
      if (ex[i] || binary.vars[i]) {
        poly_terms[[i]] <- d[, i]
        poly_co.names[[i]] <- colnames(d)[i]
      }
      else {
        poly_terms[[i]] <- poly(d[, i], degree = poly, raw = !orthogonal_poly, simple = TRUE)
        poly_co.names[[i]] <- sprintf("%s%s%s",
                                      if (orthogonal_poly) "orth_" else "",
                                      colnames(d)[i],
                                      num_to_superscript(seq_len(poly)))
      }
    }
  }

  if (int && nd > 1L) {
    int_terms <- int_co.names <- make_list(1L)
    ints_to_make <- utils::combn(colnames(d)[!ex], 2L, simplify = FALSE)

    if (is_not_null(ints_to_make)) {
      int_terms[[1L]] <- do.call("cbind", lapply(ints_to_make, function(i) d[, i[1L]] * d[, i[2L]]))

      int_co.names[[1L]] <- vapply(ints_to_make, paste, character(1L), collapse = " * ")
    }
  }
  else {
    int_terms <- int_co.names <- list()
  }

  if (is_null(poly_terms) && is_null(int_terms)) {
    return(matrix(ncol = 0L, nrow = nrow(d), dimnames = list(rownames(d), NULL)))
  }

  out <- do.call("cbind", c(poly_terms, int_terms))
  out_co.names <- c(unlist(poly_co.names), unlist(int_co.names))

  colnames(out) <- out_co.names

  #Remove single values
  if (is_not_null(out)) {
    single_value <- apply(out, 2L, all_the_same)
    out <- out[, !single_value, drop = FALSE]
  }

  out
}

.quantile_f <- function(d, qu = NULL, s.weights = NULL, focal = NULL, treat = NULL, const = 2000) {
  # Creates new variables for use in balance quantiles. `qu` is a list of quantiles for each
  # continuous variable in `d`, and returns a matrix with a column for each requested quantile
  # of each variable, taking on 0 for values less than the quantile, .5 for values at the quantile,
  # and 1 for values greater than the quantile. The mean of each variable is equal to the quantile.

  if (is_null(qu)) {
    return(matrix(ncol = 0L, nrow = nrow(d), dimnames = list(rownames(d), NULL)))
  }

  vld_qu <- function(x) {
    is.numeric(x) && all(x >= 0) && all(x <= 1)
  }

  binary.vars <- is_binary_col(d)

  if (length(qu) == 1L && is_null(names(qu))) {
    qu <- setNames(qu[rep.int(1L, sum(!binary.vars))],
                   colnames(d)[!binary.vars])
  }

  if (!all(names(qu) %in% colnames(d)[!binary.vars])) {
    .err("all names of `quantile` must refer to continuous covariates")
  }

  for (i in qu) {
    if (!vld_qu(i)) {
      .err("`quantile` must be a number between 0 and 1 or a named list thereof")
    }
  }

  if (is_not_null(focal) && is_not_null(s.weights)) {
    s.weights <- s.weights[treat == focal]
  }

  do.call("cbind", lapply(names(qu), function(i) {
    target <- if (is_null(focal)) d[, i] else d[treat == focal, i]
    out <- do.call("cbind", lapply(qu[[i]], function(q) {
      plogis(const * (d[, i] - w.quantile(target, q, s.weights)))
    }))

    colnames(out) <- paste0(i, "_", qu[[i]])
    out
  }))
}

get.s.d.denom.weightit <- function(s.d.denom = NULL, estimand = NULL, weights = NULL, treat = NULL, focal = NULL) {
  s.d.denom.specified <- is_not_null(s.d.denom)
  estimand.specified <- is_not_null(estimand)
  if (!is.factor(treat)) treat <- factor(treat)

  if (s.d.denom.specified) {
    allowable.s.d.denoms <- c("treated", "control", "pooled", "all", "weighted", "hedges")

    try.s.d.denom <- try(match_arg(s.d.denom, allowable.s.d.denoms), silent = TRUE)

    if (!null_or_error(try.s.d.denom)) {
      return(try.s.d.denom)
    }
  }

  if (estimand.specified) {
    allowable.estimands <- c("ATT", "ATC", "ATE", "ATO", "ATM")

    try.estimand <- try(match_arg(toupper(estimand), allowable.estimands), silent = TRUE)

    if (!null_or_error(try.estimand) && try.estimand %nin% c("ATC", "ATT")) {
      s.d.denom <- switch(try.estimand,
                          ATO = "weighted",
                          ATM = "weighted",
                          "pooled")
      return(s.d.denom)
    }
  }

  if (is_not_null(focal)) {
    return(focal)
  }

  if (is_null(weights) || all_the_same(weights)) {
    return("pooled")
  }

  for (tv in levels(treat)) {
    if (all_the_same(weights[treat == tv]) &&
        !all_the_same(weights[treat != tv])) {
      return(tv)
    }
  }

  "pooled"
}

get.s.d.denom.cont.weightit <- function(s.d.denom = NULL) {
  s.d.denom.specified <- is_not_null(s.d.denom)

  if (!s.d.denom.specified) {
    return("all")
  }

  allowable.s.d.denoms <- c("all", "weighted")

  try.s.d.denom <- try(match_arg(s.d.denom, allowable.s.d.denoms), silent = TRUE)

  if (!null_or_error(try.s.d.denom)) {
    return(try.s.d.denom)
  }

  "all"
}

.check_estimated_weights <- function(w, treat, treat.type, s.weights) {

  tw <- w * s.weights

  extreme.warn <- FALSE
  if (all_the_same(w)) {
    .wrn(sprintf("all weights are %s, possibly indicating an estimation failure", w[1L]))
  }
  else if (treat.type == "continuous") {
    w.cv <- sd(tw, na.rm = TRUE) / mean(tw, na.rm = TRUE)
    if (!is.finite(w.cv) || w.cv > 4) extreme.warn <- TRUE
  }
  else {
    t.levels <- unique(treat)
    bad.treat.groups <- setNames(rep.int(FALSE, length(t.levels)), t.levels)
    for (i in t.levels) {
      ti <- which(treat == i)
      if (all(is.na(w[ti])) || all(check_if_zero(w[ti]))) {
        bad.treat.groups[as.character(i)] <- TRUE
      }
      else if (!extreme.warn && sum(is.finite(tw[ti])) > 1L) {
        w.cv <- sd(tw[ti], na.rm = TRUE) / mean(tw[ti], na.rm = TRUE)
        if (!is.finite(w.cv) || w.cv > 4) {
          extreme.warn <- TRUE
        }
      }
    }

    if (any(bad.treat.groups)) {
      n <- sum(bad.treat.groups)
      .wrn(sprintf("all weights are `NA` or 0 in treatment %s %s",
                   ngettext(n, "group", "groups"),
                   word_list(t.levels[bad.treat.groups], quotes = TRUE)))
    }
  }

  if (extreme.warn) {
    .wrn("some extreme weights were generated. Examine them with `summary()` and maybe trim them with `trim()`")
  }

  if (any(tw < 0)) {
    .wrn("some weights are negative; these cannot be used in most model fitting functions")
  }
}

.subclass_ps_multi <- function(ps_mat, treat, estimand = "ATE", focal = NULL, subclass) {
  chk::chk_count(subclass)
  subclass <- round(subclass)

  estimand <- toupper(estimand)

  if (estimand %nin% c("ATE", "ATT")) {
    .err("only the ATE, ATT, and ATC are compatible with stratification weights")
  }

  if (is_not_null(focal)) {
    ps_mat <- ps_mat[, c(focal, setdiff(colnames(ps_mat), focal))]
  }

  ps_sub <- sub_mat <- ps_mat * 0

  for (i in colnames(ps_mat)) {
    if (estimand == "ATE") {
      sub <- as.integer(findInterval(ps_mat[, as.character(i)],
                                     quantile(ps_mat[, as.character(i)],
                                              seq(0, 1, length.out = subclass + 1L)),
                                     all.inside = TRUE))
    }
    else if (estimand == "ATT") {
      if (i != focal) {
        ps_mat[, as.character(i)] <- 1 - ps_mat[, as.character(i)]
      }

      sub <- as.integer(findInterval(ps_mat[, as.character(i)],
                                     quantile(ps_mat[treat == focal, as.character(i)],
                                              seq(0, 1, length.out = subclass + 1L)),
                                     all.inside = TRUE))
    }

    sub_tab <- table(treat, sub)

    if (any(sub_tab == 0L)) {
      sub <- .subclass_scoot(sub, treat, ps_mat[, i])
      sub_tab <- table(treat, sub)
    }

    sub <- as.character(sub)

    sub_totals <- colSums(sub_tab)
    sub_ps <- setNames(sub_tab[as.character(i), ] / sub_totals,
                       colnames(sub_tab))

    ps_sub[, i] <- sub_ps[sub]
    sub_mat[, i] <- sub

    if (ncol(ps_sub) == 2L) {
      ps_sub[, colnames(ps_sub) != i] <- 1 - ps_sub[, i]
      sub_mat[, colnames(sub_mat) != i] <- sub
      break
    }
  }

  attr(ps_sub, "sub_mat") <- sub_mat

  ps_sub
}

.subclass_ps_bin <- function(ps, treat, estimand = "ATE", subclass) {
  chk::chk_count(subclass)
  subclass <- round(subclass)

  estimand <- toupper(estimand)

  if (estimand %nin% c("ATE", "ATT", "ATC")) {
    .err("only the ATE, ATT, and ATC are compatible with stratification weights")
  }

  sub <- as.integer(findInterval(ps,
                                 quantile(switch(estimand,
                                                 "ATE" = ps,
                                                 "ATT" = ps[treat == 1],
                                                 "ATC" = ps[treat == 0]),
                                          seq(0, 1, length.out = subclass + 1L)),
                                 all.inside = TRUE))

  max_sub <- max(sub)
  sub_tab1 <- tabulate(sub[treat == 1], max_sub)
  sub_tab0 <- tabulate(sub[treat == 0], max_sub)

  if (any(sub_tab1 == 0L) || any(sub_tab0 == 0L)) {
    sub <- .subclass_scoot(sub, treat, ps)
    sub_tab1 <- tabulate(sub[treat == 1], max_sub)
    sub_tab0 <- tabulate(sub[treat == 0], max_sub)
  }

  sub_totals <- sub_tab1 + sub_tab0
  sub1_prop <- sub_tab1 / sub_totals

  sub_ps <- sub1_prop[sub]

  attr(sub_ps, "sub") <- sub

  sub_ps
}

.subclass_scoot <- function(sub, treat, x, min.n = 1L) {
  #Reassigns subclasses so there are no empty subclasses
  #for each treatment group. min.n is the smallest a
  #subclass is allowed to be.
  treat <- as.character(treat)
  unique.treat <- unique(treat, nmax = 2L)

  names(x) <- seq_along(x)
  names(sub) <- seq_along(sub)
  original.order <- names(x)

  nsub <- nunique(sub)

  #Turn subs into a contiguous sequence
  sub <- setNames(setNames(seq_len(nsub), sort(unique(sub)))[as.character(sub)],
                  original.order)

  if (any(table(treat) < nsub * min.n)) {
    .err("too many subclasses were requested")
  }

  for (t in unique.treat) {
    if (length(x[treat == t]) == nsub) {
      sub[treat == t] <- seq_len(nsub)
    }
  }

  sub_tab <- table(treat, sub)

  if (all(sub_tab > 0L)) {
    return(sub)
  }

  .soft_thresh <- function(x, minus = 1) {
    x <- x - minus
    x[x < 0] <- 0
    x
  }

  for (t in unique.treat) {
    for (n in seq_len(min.n)) {
      while (any(sub_tab[t, ] == 0L)) {
        first_0 <- which(sub_tab[t, ] == 0L)[1L]

        if (first_0 == nsub ||
            (first_0 != 1L &&
             sum(.soft_thresh(sub_tab[t, seq_len(first_0 - 1L)]) / abs(first_0 - seq_len(first_0 - 1L))) >=
             sum(.soft_thresh(sub_tab[t, seq(first_0 + 1L, nsub)]) / abs(first_0 - seq(first_0 + 1L, nsub))))) {
          #If there are more and closer nonzero subs to the left...
          first_non0_to_left <- max(seq_len(first_0 - 1L)[sub_tab[t, seq_len(first_0 - 1L)] > 0L])

          name_to_move <- names(sub)[which(x == max(x[treat == t & sub == first_non0_to_left]) &
                                             treat == t & sub == first_non0_to_left)[1L]]

          sub[name_to_move] <- first_0
          sub_tab[t, first_0] <- 1L
          sub_tab[t, first_non0_to_left] <- sub_tab[t, first_non0_to_left] - 1L

        }
        else {
          #If there are more and closer nonzero subs to the right...
          first_non0_to_right <- min(seq(first_0 + 1L, nsub)[sub_tab[t, seq(first_0 + 1L, nsub)] > 0L])
          name_to_move <- names(sub)[which(x == min(x[treat == t & sub == first_non0_to_right]) &
                                             treat == t & sub == first_non0_to_right)[1L]]
          sub[name_to_move] <- first_0
          sub_tab[t, first_0] <- 1L
          sub_tab[t, first_non0_to_right] <- sub_tab[t, first_non0_to_right] - 1L
        }
      }

      sub_tab[t, ] <- sub_tab[t, ] - 1L
    }
  }

  #Unsort
  sub[names(sub)]
}

stabilize_w <- function(weights, treat) {
  t.levels <- {
    if (is.factor(treat)) levels(treat)
    else unique(treat)
  }

  w.names <- names(weights)

  tab <- setNames(vapply(t.levels, function(x) mean_fast(treat == x), numeric(1L)),
                  t.levels)

  setNames(weights * tab[as.character(treat)], w.names)
}

.get_dens_fun <- function(use.kernel = FALSE, bw = NULL, adjust = NULL, kernel = NULL,
                          n = NULL, treat = NULL, density = NULL, weights = NULL) {
  if (is_null(n)) n <- 10L * length(treat)
  if (is_null(adjust)) adjust <- 1

  if (!isFALSE(use.kernel)) {
    if (isTRUE(use.kernel)) {
      .wrn('`use.kernel` is deprecated; use `density = "kernel"` instead. Setting `density = "kernel"`')
      density <- "kernel"
    }
    else {
      .wrn("`use.kernel` is deprecated")
    }
  }

  if (identical(density, "kernel")) {
    if (is_null(bw)) bw <- "nrd0"
    if (is_null(kernel)) kernel <- "gaussian"

    densfun <- function(p, log = FALSE) {
      d <- stats::density(p, n = n,
                          weights = weights / sum(weights),
                          give.Rkern = FALSE,
                          bw = bw,
                          adjust = adjust,
                          kernel = kernel)

      out <- with(d, approxfun(x = x, y = y))(p)

      if (log) out <- log(out)

      attr(out, "density") <- d
      out
    }
  }
  else {
    if (is_null(density)) .density <- function(x, log = FALSE) dnorm(x, log = log)
    else if (is.function(density)) .density <- function(x, log = FALSE) {
      if (utils::hasName(formals(density), "log")) density(x, log = log)
      else if (log) log(density(x))
      else density(x)
    }
    else if (identical(density, "dlaplace")) .density <- function(x, log = FALSE) {
      mu <- 0
      b <- 1
      if (log)
        -abs(x - mu) / b - log(2 * b)
      else
        exp(-abs(x - mu) / b) / (2 * b)
    }
    else if (is.character(density) && length(density) == 1L) {
      splitdens <- strsplit(density, "_", fixed = TRUE)[[1L]]

      splitdens1 <- get0(splitdens[1L], mode = "function", envir = parent.frame())

      if (is_null(splitdens1)) {
        .err(sprintf("%s is not an appropriate argument to `density` because %s is not an available function",
                     density, splitdens[1L]))
      }

      if (length(splitdens) > 1L && !can_str2num(splitdens[-1L])) {
        .err(sprintf("%s is not an appropriate argument to `density` because %s cannot be coerced to numeric",
                     density, word_list(splitdens[-1L], and.or = "or", quotes = TRUE)))
      }

      .density <- function(x, log = FALSE) {
        if (utils::hasName(formals(splitdens1), "log")) {
          out <- tryCatch(do.call(splitdens1, c(list(x, log = log), as.list(str2num(splitdens[-1L])))),
                          error = function(e) {
                            .err(sprintf("Error in applying density:\n  %s",
                                         conditionMessage(e)),
                                 tidy = FALSE)
                          })
        }
        else {
          out <- tryCatch(do.call(splitdens1, c(list(x), as.list(str2num(splitdens[-1L])))),
                          error = function(e) {
                            .err(sprintf("Error in applying density:\n  %s",
                                         conditionMessage(e)),
                                 tidy = FALSE)
                          })

          if (log) out <- log(out)
        }

        out
      }
    }
    else {
      .err("the argument to `density` cannot be evaluated as a density function")
    }

    densfun <- function(p, log = FALSE) {
      # sd <- sd(p)
      # sd <- sqrt(col.w.v(p, s.weights))
      dens <- .density(p, log = log)
      if (is_null(dens) || !is.numeric(dens) || anyNA(dens)) {
        .err("there was a problem with the output of `density`. Try another density function or leave it blank to use the Gaussian density")
      }

      if ((log && !all(is.finite(dens))) ||
          (!log && !all(dens > 0))) {
        .err("the input to density may not accept the full range of standardized treatment values or residuals")
      }

      x <- seq.int(min(p) - 3 * adjust * bw.nrd0(p),
                   max(p) + 3 * adjust * bw.nrd0(p),
                   length.out = n)
      attr(dens, "density") <- data.frame(x = x,
                                          y = .density(x, log = log))
      dens
    }
  }

  densfun
}

.get_w_from_ps_internal_bin <- function(ps, treat, estimand = "ATE",
                                        subclass = NULL, stabilize = FALSE) {

  estimand <- toupper(estimand)
  w <- rep.int(1, length(treat))

  #Assume treat is binary
  if (is_not_null(subclass)) {
    #Get MMW subclass propensity scores
    ps <- .subclass_ps_bin(ps, treat, estimand, subclass)
  }

  i0 <- which(treat == 0)

  if (estimand == "ATE") {
    w[i0] <- 1 / (1 - ps[i0])
    w[-i0] <- 1 / ps[-i0]
  }
  else if (estimand == "ATT") {
    w[i0] <- .p2o(ps[i0])
  }
  else if (estimand == "ATC") {
    w[-i0] <- .p2o(1 - ps[-i0])
  }
  else if (estimand == "ATO") {
    w[i0] <- ps[i0]
    w[-i0] <- 1 - ps[-i0]
  }
  else if (estimand == "ATM") {
    w[i0][ps[i0] < .5] <- .p2o(ps[i0][ps[i0] < .5])
    w[-i0][ps[-i0] > .5] <- .p2o(1 - ps[-i0][ps[-i0] > .5])
  }
  else if (estimand == "ATOS") {
    w[i0] <- 1 / (1 - ps[i0])
    w[-i0] <- 1 / ps[-i0]

    ps.sorted <- sort(c(ps, 1 - ps))
    z <- ps * (1 - ps)
    alpha.opt <- 0
    for (i in seq_len(sum(ps < .5))) {
      if (i == 1L || !check_if_zero(ps.sorted[i] - ps.sorted[i - 1L])) {
        alpha <- ps.sorted[i]
        a <- alpha * (1 - alpha)
        if (2 * a * sum(1 / z[z >= a]) / sum(z >= a) >= 1) {
          alpha.opt <- alpha
          break
        }
      }
    }
    w[!between(ps, c(alpha.opt, 1 - alpha.opt))] <- 0
  }

  names(w) <- if_null_then(names(treat), NULL)

  if (stabilize) {
    w <- stabilize_w(w, treat)
  }

  w
}

.get_w_from_ps_internal_multi <- function(ps, treat, estimand = "ATE", focal = NULL,
                                          subclass = NULL, stabilize = FALSE) {

  estimand <- toupper(estimand)
  w <- rep.int(0.0, length(treat))

  ps_mat <- ps

  if (is_not_null(subclass)) {
    #Get MMW subclass propensity scores
    ps_mat <- .subclass_ps_multi(ps_mat, treat, estimand, focal, subclass)
  }

  for (i in colnames(ps_mat)) {
    w[treat == i] <- 1 / ps_mat[treat == i, i]
  }

  if (estimand == "ATE") {
    # w <- w
  }
  else if (estimand %in% c("ATT", "ATC")) {
    in_f <- which(treat == focal)
    w[in_f] <- 1
    w[-in_f] <- w[-in_f] * ps_mat[-in_f, as.character(focal)]
  }
  else if (estimand == "ATO") {
    w <- w / rowSums(1 / ps_mat) #Li & Li (2019)
  }
  else if (estimand == "ATM") {
    w <- w * do.call("pmin", lapply(seq_col(ps_mat), function(x) ps_mat[, x]), quote = TRUE)
  }
  else if (estimand == "ATOS") {
    #Crump et al. (2009)
    ps.sorted <- sort(c(ps_mat[, 2L], 1 - ps_mat[, 2L]))
    z <- ps_mat[, 2L] * (1 - ps_mat[, 2L])
    alpha.opt <- 0
    for (i in seq_len(sum(ps_mat[, 2L] < .5))) {
      if (i == 1L || !check_if_zero(ps.sorted[i] - ps.sorted[i - 1L])) {
        alpha <- ps.sorted[i]
        a <- alpha * (1 - alpha)
        if (2 * a * sum(1 / z[z >= a]) / sum(z >= a) >= 1) {
          alpha.opt <- alpha
          break
        }
      }
    }
    w[!between(ps_mat[, 2L], c(alpha.opt, 1 - alpha.opt))] <- 0
  }
  else {
    return(numeric(0L))
  }

  if (stabilize) {
    w <- stabilize_w(w, treat)
  }

  names(w) <- if_null_then(rownames(ps_mat), names(treat), NULL)

  w
}

.get_w_from_ps_internal_array <- function(ps, treat, estimand = "ATE", focal = NULL,
                                          subclass = NULL, stabilize = FALSE) {
  #Batch turn PS into weights; primarily for output of predict.gbm
  # Assumes a (0,1) treatment if binary
  if (is_null(dim(ps))) {
    ps <- matrix(ps, ncol = 1L)
  }

  eps <- 1e-8

  if (length(dim(ps)) == 2L) {
    #Binary treatment, vector ps

    w <- ps
    w[] <- 0

    if (is_not_null(subclass)) {
      #Get MMW subclass propensity scores
      for (p in seq_col(ps)) {
        ps[, p] <- .subclass_ps_bin(ps[, p], treat, estimand, subclass)
      }
    }

    t1 <- which(treat == 1)
    t0 <- which(treat == 0)

    if (estimand == "ATE") {
      ps[t1, ][ps[t1, ] < eps] <- eps
      ps[t0, ][ps[t0, ] > 1 - eps] <- 1 - eps

      w[t1, ] <- 1 / ps[t1, ]
      w[t0, ] <- 1 / (1 - ps[t0, ])
    }
    else if (estimand == "ATT") {
      ps[t0, ][ps[t0, ] > 1 - eps] <- 1 - eps

      w[t1, ] <- 1
      w[t0, ] <- .p2o(ps[t0, ])
    }
    else if (estimand == "ATC") {
      ps[t1, ][ps[t1, ] < eps] <- eps

      w[t1, ] <- .p2o(1 - ps[t1, ])
      w[t0, ] <- 1
    }
    else if (estimand == "ATO") {
      w[t1, ] <- 1 - ps[t1, ]
      w[t0, ] <- ps[t0, ]
    }
    else if (estimand == "ATM") {
      pslt.5 <- ps < .5
      w[t1, ][pslt.5[t1, ]] <- 1
      w[t1, ][!pslt.5[t1, ]] <- .p2o(1 - ps[t1, ][!pslt.5[t1, ]])

      w[t0, ][pslt.5[t0, ]] <- .p2o(ps[t0, ][pslt.5[t0, ]])
      w[t0, ][!pslt.5[t0, ]] <- 1
    }

    if (stabilize) {
      w[t1] <- w[t1] * length(t1) / length(treat)
      w[t0] <- w[t0] * length(t0) / length(treat)
    }
  }
  else if (length(dim(ps)) == 3L) {
    #Multi-category treatment, matrix PS

    if (is_not_null(subclass)) {
      #Get MMW subclass propensity scores
      for (p in seq_len(last(dim(ps))))
        ps[, , p] <- .subclass_ps_multi(ps[, , p], treat, estimand, focal, subclass)
    }

    ps <- squish(ps, eps)

    w <- matrix(0.0, ncol = dim(ps)[3L], nrow = dim(ps)[1L])
    t.levs <- unique(treat)

    for (i in t.levs) {
      w[treat == i, ] <- 1 / ps[treat == i, as.character(i), ]
    }

    if (estimand == "ATE") {
      #Do nothing
    }
    else if (estimand %in% c("ATT", "ATC")) {
      not_focal <- which(treat != focal)
      w[-not_focal, ] <- 1
      w[not_focal, ] <- w[not_focal, ] * ps[not_focal, as.character(focal), ]
    }
    else if (estimand == "ATO") {
      w <- w / colSums(aperm(1 / ps, c(2L, 1L, 3L)))
    }
    else if (estimand == "ATM") {
      treat <- as.integer(treat)

      for (p in seq_len(dim(ps)[3L])) {
        ps_p <- ps[, , p]
        min_ind <- max.col(-ps_p, ties.method = "first")
        no_match <- which(ps_p[cbind(seq_along(treat), treat)] != ps_p[cbind(seq_along(treat), min_ind)])

        if (length(no_match) < length(treat)) {
          w[-no_match, p] <- 1
        }

        if (is_not_null(no_match)) {
          w[no_match, p] <- w[no_match, p] * ps_p[cbind(no_match, min_ind[no_match])]
        }
      }
    }

    if (stabilize) {
      for (i in t.levs) {
        w[treat == i, ] <- mean_fast(treat == i) * w[treat == i, ]
      }
    }
  }
  else {
    .err("don't know how to process more than 3 dims (likely a bug)")
  }

  w
}

#Derivative of weights wrt ps for different estimands
.dw_dp_bin <- function(p, treat, estimand = "ATE") {
  estimand <- toupper(estimand)

  dw <- rep.int(0, length(treat))

  i0 <- which(treat == 0)

  if (estimand == "ATE") {
    dw[i0] <- (1 - p[i0])^(-2)
    dw[-i0] <- -p[-i0]^(-2)
  }
  else if (estimand == "ATT") {
    dw[i0] <- (1 - p[i0])^(-2)
  }
  else if (estimand == "ATC") {
    dw[-i0] <- -p[-i0]^(-2)
  }
  else if (estimand == "ATO") {
    dw[i0] <- 1
    dw[-i0] <- -1
  }
  else if (estimand == "ATM") {
    dw[i0][p[i0] < .5] <- (1 - p[i0][p[i0] < .5])^(-2)
    dw[-i0][p[-i0] > .5] <- -p[-i0][p[-i0] > .5]^(-2)
  }

  dw
}

#Derivative of weights wrt ps for different estimands
.dw_dp_multi <- function(p, treat, estimand = "ATE", focal = NULL) {
  estimand <- toupper(estimand)

  dw <- array(0, dim = dim(p), dimnames = dimnames(p))

  pA <- numeric(nrow(p))

  for (k in levels(treat)) {
    pA[treat == k] <- p[treat == k, k]
  }

  if (is_not_null(focal)) {
    pF <- p[, focal]

    for (i in setdiff(levels(treat), focal)) {
      dw[treat == i, focal] <- 1 / pA[treat == i]
      dw[treat == i, i] <- -pF[treat == i] / pA[treat == i]^2
    }
  }
  else if (estimand == "ATE") {
    for (k in levels(treat)) {
      dw[treat == k, k] <- -1 / pA[treat == k]^2
    }
  }
  else if (estimand == "ATO") {
    S <- 1 / rowSums(1 / p)
    for (k in levels(treat)) {
      dw[treat == k, k] <- (S[treat == k] / pA[treat == k]^2) * (S[treat == k] / pA[treat == k] - 1)

      for (j in setdiff(levels(treat), k)) {
        dw[treat == k, j] <- (1 / pA[treat == k]) * (S[treat == k] / p[treat == k, j])^2
      }
    }
  }
  else if (estimand == "ATM") {
    M <- do.call("pmin", as.data.frame(p))

    for (k in levels(treat)) {
      m1 <- p[, k] == M & treat != k
      dw[m1, k] <- 1 / pA[m1]

      m2 <- p[, k] != M & treat == k
      dw[m2, k] <- -M[m2] / pA[m2]^2
    }
  }

  dw
}

plot_density <- function(d.n, d.d, log = FALSE) {
  d.d <- cbind(as.data.frame(d.d[c("x", "y")]), dens = "Denominator Density", stringsAsfactors = FALSE)
  d.n <- cbind(as.data.frame(d.n[c("x", "y")]), dens = "Numerator Density", stringsAsfactors = FALSE)
  d.all <- rbind(d.d, d.n)
  d.all$dens <- factor(d.all$dens, levels = c("Numerator Density", "Denominator Density"))

  if (log) {
    d.all$x <- exp(d.all$x)
  }

  pl <- ggplot(d.all, aes(x = .data$x, y = .data$y)) +
    geom_line() +
    labs(title = "Weight Component Densities", x = "E[Treat|X]", y = "Density") +
    facet_grid(rows = vars(.data$dens)) +
    theme(panel.background = element_rect(fill = "white"),
          panel.border = element_rect(fill = NA, color = "black"),
          axis.text.x = element_text(color = "black"),
          axis.text.y = element_text(color = "black"),
          panel.grid.major = element_blank(),
          panel.grid.minor = element_blank())
  print(pl)
}

neg_ent <- function(w) {
  w <- w[w > 0]
  w <- w / mean_fast(w)
  mean(w * log(w))
}

replace_na_with <- function(covs, with = "median") {
  if (is.na(with) || !anyNA(covs)) {
    return(covs)
  }

  if (is.character(with)) {
    .with <- match.fun(with)
    for (i in colnames(covs)[anyNA_col(covs)]) {
      if (all(is.na(covs[, i]))) covs <- covs[, colnames(covs) != i, drop = FALSE]
      else covs[is.na(covs[, i]), i] <- .with(covs[, i], na.rm = TRUE)
    }
  }
  else {
    covs[is.na(covs)] <- with
  }

  covs
}

add_missing_indicators <- function(covs, replace_with = "median") {
  covs_w_missing <- which(anyNA_col(covs))
  if (is_null(covs_w_missing)) {
    return(covs)
  }

  missing_ind <- apply(covs[, covs_w_missing, drop = FALSE], 2L, function(x) as.numeric(is.na(x)))

  colnames(missing_ind) <- paste0(colnames(missing_ind), ":<NA>")
  covs <- cbind(covs, missing_ind)

  if (is_null(replace_with) || is.na(replace_with)) {
    return(covs)
  }

  replace_na_with(covs, replace_with)
}

verbosely <- function(expr, verbose = TRUE) {
  if (verbose) {
    return(expr)
  }

  invisible(utils::capture.output({
    out <- invisible(expr)
  }))

  out
}

#Generalized matrix inverse (port of MASS::ginv)
generalized_inverse <- function(sigma, .try = TRUE) {

  if (!.try) {
    sigmasvd <- svd(sigma)
    pos <- sigmasvd$d > max(1e-9 * sigmasvd$d[1L], 0)
    sigma_inv <- sigmasvd$v[, pos, drop = FALSE] %*% (sigmasvd$d[pos]^(-1) * t(sigmasvd$u[, pos, drop = FALSE]))

    return(sigma_inv)
  }

  tryCatch(solve(sigma),
           error = function(e) {
             generalized_inverse(sigma, .try = FALSE)
           })
}

#Compute gradient numerically using centered difference
.gradient <- function(.f, .x, .eps = 1e-8, .parm = NULL, .direction = "center", .method = "fd", ...) {
  .method <- match_arg(.method, c("fd", "richardson"))

  if (.method == "fd") {
    .gradientFD(.f = .f, .x = .x, .eps = .eps, .parm = .parm, .direction = .direction, ...)
  }
  else if (.method == "richardson") {
    .gradientRich(.f = .f, .x = .x, .eps = .eps, .parm = .parm, .direction = .direction, ...)
  }

}

#Finite difference gradient
.gradientFD <- function(.f, .x, .eps = 1e-8, .parm = NULL, .direction = "center", ...) {

  .direction <- match_arg(.direction, c("center", "left", "right"))

  if (is_null(.parm)) {
    .parm <- seq_along(.x)
  }

  .x0 <- .x

  .eps <- squish(abs(.x) * .eps, lo = .eps, hi = Inf)

  if (.direction != "center") {
    .f0 <- .f(.x0, ...)
  }

  for (jj in seq_along(.parm)) {
    j <- .parm[jj]

    if (.direction == "center") {
      .x[j] <- .x0[j] + .eps[j] / 2

      f_new_r <- .f(.x, ...)
    }
    else if (.direction == "left") {
      f_new_r <- .f0
    }
    else if (.direction == "right") {
      .x[j] <- .x0[j] + .eps[j]

      f_new_r <- .f(.x, ...)
    }

    if (j == 1L) {
      jacob <- matrix(0, nrow = length(f_new_r), ncol = length(.parm),
                      dimnames = list(names(f_new_r), names(.x)[.parm]))
    }

    if (.direction == "center") {
      .x[j] <- .x0[j] - .eps[j] / 2

      f_new_l <- .f(.x, ...)
    }
    else if (.direction == "left") {
      x[j] <- .x0[j] - .eps[j]

      f_new_l <- .f(.x, ...)
    }
    else if (.direction == "right") {
      f_new_l <- .f0
    }

    jacob[, jj] <- (f_new_r - f_new_l) / .eps[j]

    .x[j] <- .x0[j]
  }

  jacob
}

#Using Richardson extrapolation
.gradientRich <- function(.f, .x, .eps = 1e-8, .parm = NULL, .direction = "center", ...) {

  .direction <- match_arg(.direction, c("center", "left", "right"))

  if (is_null(.parm)) {
    .parm <- seq_along(.x)
  }

  if (.direction != "center") {
    .f0 <- .f(.x, ...)
  }

  n <- length(.x)

  d <- 1e-4
  r <- 4
  v <- 2
  a <- NULL
  h <- abs(d * .x) + .eps * (abs(.x) < 1e-5)

  for (k in seq_len(r)) {
    eps_i <- rep(0, length(.parm))

    for (ii in seq_along(.parm)) {
      i <- .parm[ii]

      eps_i[i] <- h[i]

      a_k_ii <- switch(.direction,
                       "center" = (.f(.x + eps_i, ...) - .f(.x - eps_i, ...)) / (2 * h[i]),
                       "right" = (.f(.x + 2 * eps_i, ...) - .f0) / (2 * h[i]),
                       "left" = (.f0 - .f(.x - 2 * eps_i, ...)) / (2 * h[i]))

      if (is_null(a)) {
        a <- array(NA_real_, dim = c(length(a_k_ii), r, n))
      }

      a[, k, ii] <- a_k_ii

      eps_i[i] <- 0
    }

    h <- h / v
  }

  for (m in seq_len(r - 1L)) {
    a <- (a[, 1L + seq_len(r - m), , drop = FALSE] * (4^m) - a[, seq_len(r - m), , drop = FALSE]) / (4^m - 1)
  }

  array(a, dim = dim(a)[c(1L, 3L)],
        dimnames = list(names(a_k_ii), names(.x)[.parm]))
}

#Convert probability to odds
.p2o <- function(p) {
  p / (1 - p)
}

#Get psi function (individual contributions to gradient) from glm fit
.get_glm_psi <- function(fit) {
  fam <- fit$family

  if (is_null(fam) ||
      identical(fam, gaussian(), ignore.environment = TRUE)) {
    psi <- function(B, X, y, weights, offset = 0) {
      p <- drop(X %*% B) + offset

      X * (weights * (y - p))
    }
  }
  else if (inherits(fit, "brglmFit") &&
           !identical(fit$type, "ML") &&
           !identical(fit$type, "correction")) {
    br_type <- fit$type

    if (is_null(fit$control[["a"]])) {
      rlang::check_installed("brglm2")
      fit$control[["a"]] <- eval(formals(brglm2::brglmControl)[["a"]])
    }

    br_psi <- function(X, W, d, p, XB, V) {
      DD <- fam$d2mu.deta(XB)
      Wt <- W * d^2 / V #"working weight"

      ## Compute hat values
      XWt <- sqrt(Wt) * X
      qrXWT <- qr(XWt)
      Q <- qr.Q(qrXWT)
      H <- rowSums(Q * Q)

      if (br_type %in% c("AS_mixed", "AS_mean")) {
        AA <- .5 * X * H * DD / d
        return(AA)
      }

      V1 <- fam$d1variance(p)

      if (br_type == "MPL_Jeffreys") {
        return(fit$control[["a"]] * X * H * (2 * DD / d - V1 * d / V))
      }

      #br_type == "AS_median"
      R_matrix <- qr.R(qrXWT)
      info_unscaled <- crossprod(R_matrix)
      inverse_info_unscaled <- chol2inv(R_matrix)

      b_vector <- vapply(seq_col(X), function(j) {
        inverse_info_unscaled_j <- inverse_info_unscaled[j, ]
        vcov_j <- tcrossprod(inverse_info_unscaled_j) / inverse_info_unscaled_j[j]
        hats_j <- rowSums((X %*% vcov_j) * X) * Wt

        inverse_info_unscaled_j %*% colSums(X * (hats_j * (d * V1 / (6 * V) - DD / (2 * d))))
      }, numeric(1L))

      AA <- .5 * X * H * DD / d
      sweep(AA, 2L, info_unscaled %*% b_vector / nrow(X), "+")
    }

    psi <-  function(B, X, y, weights, offset = 0) {
      XB <- drop(X %*% B) + offset
      p <- fam$linkinv(XB)
      d <- fam$mu.eta(XB)
      V <- fam$variance(p)

      .psi <- X * (weights * d * (y - p) / V)

      .psi + br_psi(X, weights, d, p, XB, V)
    }
  }
  else if (identical(fam, binomial(), ignore.environment = TRUE) ||
           identical(fam, quasibinomial(), ignore.environment = TRUE) ||
           identical(fam, poisson(), ignore.environment = TRUE) ||
           identical(fam, quasipoisson(), ignore.environment = TRUE)) {
    psi <- function(B, X, y, weights, offset = 0) {
      XB <- drop(X %*% B) + offset
      p <- fam$linkinv(XB)

      X * (weights * (y - p))
    }
  }
  else {
    psi <- function(B, X, y, weights, offset = 0) {
      XB <- drop(X %*% B) + offset
      p <- fam$linkinv(XB)
      d <- fam$mu.eta(XB)
      V <- fam$variance(p)

      X * (weights * d * (y - p) / V)
    }
  }

  psi
}

.make_link <- function(link) {
  link0 <- try(make.link(link), silent = TRUE)

  if (!null_or_error(link0)) {
    return(link)
  }

  if (!chk::vld_string(link) || !link %in% c("clog", "loglog")) {
    .err("link function not recognized")
  }

  if (link == "clog") {
    linkfun <- function(mu) -log(1 - mu)
    linkinv <- function(eta) squish(1 - exp(-eta), -Inf, 1 - .Machine$double.eps)
    mu.eta <- function(eta) squish(exp(-eta), .Machine$double.eps, Inf)
    valideta <- function(eta) TRUE
    name <- "clog"
  }
  else if (link == "loglog") {
    linkfun <- function(mu) -log(-log(mu))
    linkinv <- function(eta) squish(exp(-exp(-eta)), .Machine$double.eps)
    mu.eta <- function(eta) {
      eta <- squish(eta, -Inf, 700)
      squish(exp(-eta - exp(-eta)), .Machine$double.eps, Inf)
    }
    valideta <- function(eta) TRUE
    name <- "loglog"
  }

  out <- list(linkfun = linkfun,
              linkinv = linkinv,
              mu.eta = mu.eta,
              valideta = valideta,
              name = name)
  class(out) <- "link-glm"
  out
}

.get_glm_starting_values <- function(X, Y, w, family, offset = NULL) {

  if (is_null(w)) {
    w <- rep.int(1, length(Y))
  }

  if (is_null(offset)) {
    offset <- rep.int(0, length(Y))
  }

  mustart <- .25 + .5 * Y

  suppressWarnings({
    fit <- try(glm.fit(X, Y, weights = w, offset = offset, family = family,
                       mustart = mustart, control = list(maxit = 1e4L)), silent = TRUE)
  })

  if (!null_or_error(fit) && isTRUE(fit$converged)) {
    return(fit$coefficients)
  }

  coef_start <- c(family$linkfun(w.m(Y, w)), rep.int(0, ncol(X) - 1L))

  suppressWarnings({
    fit <- try(glm.fit(X, Y, weights = w, offset = offset, family = family,
                       start = coef_start, control = list(maxit = 1e4L)), silent = TRUE)
  })

  if (null_or_error(fit) || !isTRUE(fit$converged)) {
    return(coef_start)
  }

  fit$coefficients
}
ngreifer/WeightIt documentation built on March 6, 2025, 2:04 a.m.