# _dev/new.R

#New methods and functions

#------Preliminary template----
# Minimal placeholder for a new method: it accepts the usual arguments (extras
# via `...`) but always stops, because the method is not implemented yet.
weightit2XXX <- function(covs, treat, ...) {
  stop("method = \"XXX\" isn't ready to use yet.", call. = FALSE)
}
#------Template----
# Template for a new weightit2*() method that wraps an external fitting
# function. Replace PACKAGE / FUNCTION (including the "PACKAGE" string in the
# dependency check) with the real package and function names.
#
# Arguments mirror the other weightit2*() functions in this file. Extra
# options in `...` are forwarded to PACKAGE::FUNCTION, with that function's
# own defaults filled in for anything not supplied.
# Returns a list with the weights (`w`, read from the fit's "weights" element)
# and the raw fit (`fit.obj`).
weightit2XXX <- function(covs, treat, s.weights, subset, estimand, focal, missing, moments, int, ...) {
  # Check the dependency before formals(PACKAGE::FUNCTION) is evaluated below,
  # so a missing package gives a clear error rather than a namespace failure
  check.package("PACKAGE")

  A <- list(...)

  covs <- covs[subset, , drop = FALSE]
  treat <- factor(treat[subset])

  # Replace missing values with 0 and append missingness indicators, but only
  # when some covariate actually has missing values
  if (missing == "ind") {
    missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  covs <- cbind(covs, int.poly.f(covs, poly = moments, int = int, center = TRUE))
  covs <- apply(covs, 2, make.closer.to.1)

  new.data <- data.frame(treat, covs)
  new.formula <- formula(new.data)

  # Fill in FUNCTION's defaults for any argument the user did not supply, then
  # drop names that belong to this wrapper rather than to FUNCTION
  for (f in names(formals(PACKAGE::FUNCTION))) {
    if (is_null(A[[f]])) A[[f]] <- formals(PACKAGE::FUNCTION)[[f]]
  }
  A[names(A) %in% names(formals(weightit2XXX))] <- NULL

  A[["formula"]] <- new.formula
  A[["data"]] <- new.data
  A[["estimand"]] <- estimand
  A[["s.weights"]] <- s.weights[subset]
  A[["focal"]] <- focal
  A[["verbose"]] <- TRUE

  out <- do.call(PACKAGE::FUNCTION, A, quote = TRUE)
  obj <- list(w = out[["weights"]], fit.obj = out)
  return(obj)
}

#------Under construction----

# Estimate propensity score weights for a binary or multinomial treatment
# using an elastic-net propensity score model fit with glmnet.
#
# The penalty (lambda) and any tuning parameters are chosen in one of two ways:
# by minimizing a balance criterion (`stop.method` naming one of
# bal_criterion()'s statistics), or by K-fold cross-validation
# (`stop.method = "cv{K}"`, with K >= 3).
#
# Options read from `...`: stop.method, trim.at, alpha, thresh, maxit, relax,
# reg.covs, type.multinomial, nlambda, lambda, start.lambda, n.grid,
# balance.covs, and type.measure. Supplying several values of alpha, relax,
# type.multinomial, or reg.covs tunes over every combination of them.
#
# Returns a list with:
#   w       - the weights
#   ps      - the propensity scores (NULL for multinomial treatments)
#   info    - tuning information
#   fit.obj - the selected glmnet / cv.glmnet fit
weightit2enet <- function(covs, treat, s.weights, subset, estimand, focal, stabilize, subclass, missing, moments, int, ...) {
  A <- list(...)

  covs <- covs[subset, , drop = FALSE]
  treat <- factor(treat[subset])
  s.weights <- s.weights[subset]

  # Replace missing values with 0 and append missingness indicators
  if (anyNA(covs) && missing == "ind") {
    missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
    covs[is.na(covs)] <- 0
    covs <- cbind(covs, missing.ind)
  }

  covs <- apply(covs, 2, make.closer.to.1)
  # Model matrix: covariates plus any requested polynomial/interaction terms
  model.covs <- cbind(covs, int.poly.f(covs, int = int, poly = moments))
  model.covs <- apply(model.covs, 2, make.closer.to.1)

  if (!has_treat_type(treat)) treat <- assign_treat_type(treat)
  treat.type <- get_treat_type(treat)

  if (is_null(A[["stop.method"]])) {
    warning("No stop.method was provided. Using \"es.mean\".",
            call. = FALSE, immediate. = TRUE)
    A[["stop.method"]] <- "es.mean"
  }
  else if (length(A[["stop.method"]]) > 1) {
    warning("Only one stop.method is allowed at a time. Using just the first stop.method.",
            call. = FALSE, immediate. = TRUE)
    A[["stop.method"]] <- A[["stop.method"]][1]
  }

  # stop.method is either a balance criterion or "cv{K}" for K-fold CV;
  # cv stays 0 when a balance criterion is used
  cv <- 0
  available.stop.methods <- bal_criterion(treat.type, list = TRUE)
  s.m.matches <- charmatch(A[["stop.method"]], available.stop.methods)
  if (is.na(s.m.matches) || s.m.matches == 0L) {
    if (startsWith(A[["stop.method"]], "cv") && can_str2num(numcv <- substr(A[["stop.method"]], 3, nchar(A[["stop.method"]])))) {
      cv <- round(str2num(numcv))
      if (cv < 3) stop("At least 3 CV-folds must be specified in stop.method.", call. = FALSE)
    }
    else stop(paste0("'stop.method' must be one of ", word_list(c(available.stop.methods, "cv{#}"), "or", quotes = TRUE), "."), call. = FALSE)
  }
  else stop.method <- available.stop.methods[s.m.matches]

  tunable <- c("alpha", "relax", "type.multinomial", "reg.covs")

  trim.at <- if_null_then(A[["trim.at"]], 0)
  if (is_null(A[["alpha"]])) A[["alpha"]] <- 1 - .0001
  if (is_null(A[["thresh"]])) A[["thresh"]] <- 1e-7
  if (is_null(A[["maxit"]])) A[["maxit"]] <- 10^5
  if (is_null(A[["relax"]])) A[["relax"]] <- FALSE
  if (is_null(A[["reg.covs"]])) A[["reg.covs"]] <- TRUE
  nlambda <- if_null_then(A[["nlambda"]], 5000)

  if (moments == 1 && !int && any(!A[["reg.covs"]])) {
    stop("If moments = 1 and int = FALSE (the default), 'reg.covs' cannot be FALSE.", call. = FALSE)
  }

  # Default path: nlambda - 1 log-spaced values from 1/ncol(model.covs) down to
  # 1/(ncol(model.covs) * nlambda), followed by 0 (no penalty)
  if (is_null(A[["lambda"]])) {
    lambda <- c(exp(seq(log(1/ncol(model.covs)), log(1/ncol(model.covs)/nlambda), length.out = nlambda - 1)), 0)
  }
  else {
    if (is.numeric(A[["lambda"]])) {
      lambda <- sort.int(A[["lambda"]], decreasing = TRUE)
      nlambda <- length(lambda)
    }
    else {
      stop("'lambda' must be a numeric vector.")
    }
  }

  if (treat.type == "binary")  {
    family <- "binomial"
    # Binarize once here; after this, focal (if any) is the level "1"
    treat <- binarize(treat, one = focal)
    if (is_not_null(focal)) focal <- "1"
    A[["type.multinomial"]] <- NULL
  }
  else {
    family <- "multinomial"
    A[["type.multinomial"]] <- if_null_then(A[["type.multinomial"]], "ungrouped")
  }

  # One row per combination of the tunable options the user supplied
  tune <- do.call("expand.grid", c(A[names(A) %in% tunable],
                                   list(stringsAsFactors = FALSE, KEEP.OUT.ATTRS = FALSE)))
  if (cv == 0) {
    start.lambda <- if_null_then(A[["start.lambda"]], 1)
    if (is_null(A[["n.grid"]])) {
      n.grid <- round(1 + sqrt(2*(nlambda-start.lambda+1)))
    }
    else if (!is_(A[["n.grid"]], "numeric") || length(A[["n.grid"]]) > 1 ||
             !between(A[["n.grid"]], c(2, nlambda))) {
      stop(paste0("'n.grid' must be a numeric value between 2 and ", nlambda, "."), call. = FALSE)
    }
    else n.grid <- round(A[["n.grid"]])

    # A grid this dense is no cheaper than checking every lambda
    if (n.grid >= nlambda/3) n.grid <- nlambda

    balance.covs <- if_null_then(A[["balance.covs"]], covs)
    crit <- bal_criterion(treat.type, stop.method)
    init <- crit$init(balance.covs, treat, estimand = estimand, s.weights = s.weights, focal = focal)
  }
  else {
    foldid <- sample(rep(seq_len(cv), length = length(treat)))
    type.measure <- if_null_then(A[["type.measure"]], "default")
    # formals() holds the unevaluated default `c(...)`; evaluate it to get the
    # character vector of allowed choices
    A[["type.measure"]] <- match_arg(type.measure, eval(formals(glmnet::cv.glmnet)[["type.measure"]]))
  }

  current.best.loss <- Inf

  for (i in seq_row(tune)) {

    # Penalize every model term, or leave the main effects unpenalized when
    # reg.covs = FALSE
    A[["penalty.factor"]] <- rep(1, ncol(model.covs))
    if (!tune[["reg.covs"]][i]) A[["penalty.factor"]][seq_col(covs)] <- 0
    # gamma = 0 asks predict() for the fully relaxed fit when relax = TRUE
    gamma <- if (tune[["relax"]][i]) 0 else 1

    if (cv == 0) {
      fit <- do.call(glmnet::glmnet, list(model.covs, treat, family = family, standardize = FALSE,
                                          lambda = lambda, alpha = tune[["alpha"]][i], thresh = A[["thresh"]],
                                          maxit = A[["maxit"]], relax = tune[["relax"]][i], weights = s.weights,
                                          penalty.factor = A[["penalty.factor"]],
                                          type.multinomial = tune[["type.multinomial"]][i]))

      # Coarse-to-fine search over fit$lambda. The length of fit$lambda can
      # differ from the requested `lambda`, so all lambda lookups below index
      # fit$lambda, never `lambda`.
      iters <- seq_along(fit$lambda)
      iters.grid <- round(seq(1, length(fit$lambda), length.out = n.grid))

      ps <- predict(fit, newx = model.covs, type = "response", s = fit$lambda[iters.grid], gamma = gamma)
      w <- get.w.from.ps(ps, treat = treat, estimand = estimand, focal = focal, stabilize = stabilize, subclass = subclass)
      if (trim.at != 0) w <- suppressMessages(apply(w, 2, trim, at = trim.at, treat = treat))

      iter.grid.balance <- apply(w, 2, function(w_) {
        crit$fun(init = init, weights = w_)
      })

      if (n.grid == nlambda) {
        best.lambda.index <- which.min(iter.grid.balance)
        best.loss <- iter.grid.balance[best.lambda.index]
        best.lambda <- fit$lambda[iters.grid[best.lambda.index]]
        lambda.val <- setNames(data.frame(fit$lambda[iters.grid],
                                          iter.grid.balance),
                               c("lambda", stop.method))
      }
      else {
        # Refine: check every lambda between the grid neighbors of the best
        # grid point
        it <- which.min(iter.grid.balance) + c(-1, 1)
        it[1] <- iters.grid[max(1, it[1])]
        it[2] <- iters.grid[min(length(iters.grid), it[2])]
        iters.to.check <- iters[between(iters, iters[it])]

        if (is_null(iters.to.check) || anyNA(iters.to.check) || any(iters.to.check > nlambda)) stop("A problem has occurred")

        ps <- predict(fit, newx = model.covs, type = "response", s = fit$lambda[iters.to.check], gamma = gamma)
        w <- get.w.from.ps(ps, treat = treat, estimand = estimand, focal = focal, stabilize = stabilize, subclass = subclass)
        if (trim.at != 0) w <- suppressMessages(apply(w, 2, trim, at = trim.at, treat = treat))

        iter.grid.balance.fine <- apply(w, 2, function(w_) {
          crit$fun(init = init, weights = w_)
        })

        best.lambda.index <- which.min(iter.grid.balance.fine)
        best.loss <- iter.grid.balance.fine[best.lambda.index]
        best.lambda <- fit$lambda[iters.to.check[best.lambda.index]]
        lambda.val <- setNames(data.frame(c(fit$lambda[iters.grid], fit$lambda[iters.to.check]),
                                          c(iter.grid.balance, iter.grid.balance.fine)),
                               c("lambda", stop.method))
      }

      lambda.val <- unique(lambda.val[order(lambda.val$lambda),])
      w <- w[,best.lambda.index]
      ps <- if (treat.type == "binary") ps[,best.lambda.index] else NULL

      tune[[paste.("best", stop.method)]][i] <- best.loss
      tune[["best.lambda"]][i] <- best.lambda

      if (best.loss < current.best.loss) {
        best.fit <- fit
        best.w <- w
        best.ps <- ps
        current.best.loss <- best.loss
        best.tune.index <- i

        info <- list(best.lambda = best.lambda,
                     lambda.val = lambda.val,
                     coef = predict(fit, type = "coef", s = best.lambda, gamma = gamma))
      }

    }
    else {
      fit <- do.call(glmnet::cv.glmnet, list(model.covs, treat, family = family, standardize = FALSE,
                                             lambda = lambda, alpha = tune[["alpha"]][i], thresh = A[["thresh"]],
                                             maxit = A[["maxit"]], relax = tune[["relax"]][i], weights = s.weights,
                                             penalty.factor = A[["penalty.factor"]],
                                             type.multinomial = tune[["type.multinomial"]][i],
                                             type.measure = A[["type.measure"]],
                                             foldid = foldid))

      best.lambda.index <- which.min(fit$cvm)
      best.lambda <- fit$lambda[best.lambda.index]
      best.loss <- fit$cvm[best.lambda.index]

      tune[[paste.("best", names(fit$name))]][i] <- best.loss
      tune[["best.lambda"]][i] <- best.lambda

      if (best.loss < current.best.loss) {
        best.fit <- fit
        best.ps <- predict(fit, newx = model.covs, type = "response", s = best.lambda, gamma = gamma)
        best.w <- drop(get.w.from.ps(best.ps, treat = treat, estimand = estimand, focal = focal, stabilize = stabilize, subclass = subclass))
        # Apply trim.at here too, as in the balance-criterion path
        if (trim.at != 0) best.w <- suppressMessages(trim(best.w, at = trim.at, treat = treat))
        current.best.loss <- best.loss
        best.tune.index <- i

        lambda.val <- setNames(data.frame(fit$lambda,
                                          fit$cvm),
                               c("lambda", names(fit$name)))

        info <- list(best.lambda = best.lambda,
                     lambda.val = lambda.val,
                     coef = predict(fit, type = "coef", s = best.lambda, gamma = gamma))

        best.ps <- if (treat.type == "binary") drop(best.ps) else NULL
      }
    }
  }

  # Drop tuning columns that had a single value (i.e., were not tuned)
  tune[tunable[vapply(tunable, function(x) length(A[[x]]) == 1, logical(1L))]] <- NULL

  if (ncol(tune) > 2) {
    info[["tune"]] <- tune
    info[["best.tune"]] <- tune[best.tune.index,]
  }

  obj <- list(w = best.w, ps = best.ps, info = info, fit.obj = best.fit)
  return(obj)

}
# Estimate generalized propensity score weights for a continuous treatment,
# using an elastic-net (glmnet, gaussian family) model for the treatment.
#
# The penalty (lambda) and any tuning parameters are chosen either by a
# balance criterion (`stop.method`) or by K-fold cross-validation
# (`stop.method = "cv{K}"`, with K >= 3).
#
# Weights come from get_cont_weights(), using the model predictions and a
# numerator density of the centered treatment. That density is a kernel
# density estimate when use.kernel = TRUE; otherwise it is `density`, which
# defaults to dnorm.
#
# Returns a list with the weights (`w`), tuning information (`info`), and the
# selected glmnet / cv.glmnet fit (`fit.obj`).
weightit2enet.cont <- function(covs, treat, s.weights, subset, estimand, focal, stabilize, subclass, missing, moments, int, ...) {
  A <- list(...)

  covs <- covs[subset, , drop = FALSE]
  treat <- treat[subset]
  s.weights <- s.weights[subset]

  # Replace missing values with 0 and append missingness indicators
  if (anyNA(covs) && missing == "ind") {
    missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
    covs[is.na(covs)] <- 0
    covs <- cbind(covs, missing.ind)
  }

  covs <- apply(covs, 2, make.closer.to.1)
  # Model matrix: covariates plus any requested polynomial/interaction terms
  model.covs <- cbind(covs, int.poly.f(covs, int = int, poly = moments))
  model.covs <- apply(model.covs, 2, make.closer.to.1)

  treat <- make.closer.to.1(treat)

  if (is_null(A[["stop.method"]])) {
    warning("No stop.method was provided. Using \"p.mean\".",
            call. = FALSE, immediate. = TRUE)
    A[["stop.method"]] <- "p.mean"
  }
  else if (length(A[["stop.method"]]) > 1) {
    warning("Only one stop.method is allowed at a time. Using just the first stop.method.",
            call. = FALSE, immediate. = TRUE)
    A[["stop.method"]] <- A[["stop.method"]][1]
  }

  # stop.method is either a balance criterion or "cv{K}" for K-fold CV;
  # cv stays 0 when a balance criterion is used
  cv <- 0
  available.stop.methods <- bal_criterion("continuous", list = TRUE)
  s.m.matches <- charmatch(A[["stop.method"]], available.stop.methods)
  if (is.na(s.m.matches) || s.m.matches == 0L) {
    if (startsWith(A[["stop.method"]], "cv") && can_str2num(numcv <- substr(A[["stop.method"]], 3, nchar(A[["stop.method"]])))) {
      cv <- round(str2num(numcv))
      if (cv < 3) stop("At least 3 CV-folds must be specified in stop.method.", call. = FALSE)
    }
    else stop(paste0("'stop.method' must be one of ", word_list(c(available.stop.methods, "cv{#}"), "or", quotes = TRUE), "."), call. = FALSE)
  }
  else stop.method <- available.stop.methods[s.m.matches]

  tunable <- c("alpha", "relax", "reg.covs")

  trim.at <- if_null_then(A[["trim.at"]], 0)
  if (is_null(A[["alpha"]])) A[["alpha"]] <- 1 - .0001
  if (is_null(A[["thresh"]])) A[["thresh"]] <- 1e-14
  if (is_null(A[["maxit"]])) A[["maxit"]] <- 10^5
  if (is_null(A[["relax"]])) A[["relax"]] <- FALSE
  if (is_null(A[["reg.covs"]])) A[["reg.covs"]] <- TRUE
  nlambda <- if_null_then(A[["nlambda"]], 5000)

  if (moments == 1 && !int && any(!A[["reg.covs"]])) {
    stop("If moments = 1 and int = FALSE (the default), 'reg.covs' cannot be FALSE.", call. = FALSE)
  }

  # Default path: nlambda - 1 log-spaced values from 1/ncol(model.covs) down to
  # 1/(ncol(model.covs) * nlambda), followed by 0 (no penalty)
  if (is_null(A[["lambda"]])) {
    lambda <- c(exp(seq(log(1/ncol(model.covs)), log(1/ncol(model.covs)/nlambda), length.out = nlambda - 1)), 0)
  }
  else {
    if (is.numeric(A[["lambda"]])) {
      lambda <- sort.int(A[["lambda"]], decreasing = TRUE)
      nlambda <- length(lambda)
    }
    else {
      stop("'lambda' must be a numeric vector.")
    }
  }

  family <- "gaussian"

  if (cv == 0) {
    start.lambda <- if_null_then(A[["start.lambda"]], 1)
    if (is_null(A[["n.grid"]])) {
      n.grid <- round(1 + sqrt(2*(nlambda-start.lambda+1)))
    }
    else if (!is_(A[["n.grid"]], "numeric") || length(A[["n.grid"]]) > 1 ||
             !between(A[["n.grid"]], c(2, nlambda))) {
      stop(paste0("'n.grid' must be a numeric value between 2 and ", nlambda, "."), call. = FALSE)
    }
    else n.grid <- round(A[["n.grid"]])

    # A grid this dense is no cheaper than checking every lambda
    if (n.grid >= nlambda/3) n.grid <- nlambda

    crit <- bal_criterion("continuous", stop.method)
    init <- crit$init(covs, treat, s.weights = s.weights)
  }
  else {
    foldid <- sample(rep(seq_len(cv), length = length(treat)))
    type.measure <- if_null_then(A[["type.measure"]], "default")
    # formals() holds the unevaluated default `c(...)`; evaluate it to get the
    # character vector of allowed choices
    A[["type.measure"]] <- match_arg(type.measure, eval(formals(glmnet::cv.glmnet)[["type.measure"]]))
  }

  # One row per combination of the tunable options the user supplied
  tune <- do.call("expand.grid", c(A[names(A) %in% tunable],
                                   list(stringsAsFactors = FALSE, KEEP.OUT.ATTRS = FALSE)))

  #Process density params
  if (isTRUE(A[["use.kernel"]])) {
    if (is_null(A[["bw"]])) A[["bw"]] <- "nrd0"
    if (is_null(A[["adjust"]])) A[["adjust"]] <- 1
    if (is_null(A[["kernel"]])) A[["kernel"]] <- "gaussian"
    if (is_null(A[["n"]])) A[["n"]] <- 10*length(treat)
    use.kernel <- TRUE
    densfun <- NULL
  }
  else {
    if (is_null(A[["density"]])) densfun <- dnorm
    else if (is.function(A[["density"]])) densfun <- A[["density"]]
    else if (is.character(A[["density"]]) && length(A[["density"]]) == 1) {
      # "fun_a_b" means fun(x, a, b), with the extra arguments parsed as numbers
      splitdens <- strsplit(A[["density"]], "_", fixed = TRUE)[[1]]
      if (exists(splitdens[1], mode = "function", envir = parent.frame())) {
        if (length(splitdens) > 1 && !can_str2num(splitdens[-1])) {
          stop(paste(A[["density"]], "is not an appropriate argument to 'density' because",
                     word_list(splitdens[-1], and.or = "or", quotes = TRUE), "cannot be coerced to numeric."), call. = FALSE)
        }
        densfun <- function(x) {
          tryCatch(do.call(get(splitdens[1]), c(list(x), as.list(str2num(splitdens[-1])))),
                   error = function(e) stop(paste0("Error in applying density:\n  ", conditionMessage(e)), call. = FALSE))
        }
      }
      else {
        stop(paste(A[["density"]], "is not an appropriate argument to 'density' because",
                   splitdens[1], "is not an available function."), call. = FALSE)
      }
    }
    else stop("The argument to 'density' cannot be evaluated as a density function.", call. = FALSE)
    use.kernel <- FALSE
  }

  #Stabilization - get dens.num
  p.num <- treat - mean(treat)

  if (use.kernel) {
    d.n <- density(p.num, n = A[["n"]],
                   weights = s.weights/sum(s.weights), give.Rkern = FALSE,
                   bw = A[["bw"]], adjust = A[["adjust"]], kernel = A[["kernel"]])
    dens.num <- with(d.n, approxfun(x = x, y = y))(p.num)
  }
  else {
    dens.num <- densfun(p.num/sd(treat))
    if (is_null(dens.num) || !is.atomic(dens.num) || anyNA(dens.num)) {
      stop("There was a problem with the output of 'density'. Try another density function or leave it blank to use the normal density.", call. = FALSE)
    }
    else if (any(dens.num <= 0)) {
      stop("The input to 'density' may not accept the full range of treatment values.", call. = FALSE)
    }
  }

  current.best.loss <- Inf

  for (i in seq_row(tune)) {

    # Penalize every model term, or leave the main effects unpenalized when
    # reg.covs = FALSE
    A[["penalty.factor"]] <- rep(1, ncol(model.covs))
    if (!tune[["reg.covs"]][i]) A[["penalty.factor"]][seq_col(covs)] <- 0
    # gamma = 0 asks predict() for the fully relaxed fit when relax = TRUE;
    # it must follow this row's relax value, not a single global setting
    gamma <- if (tune[["relax"]][i]) 0 else 1

    if (cv == 0) {
      fit <- do.call(glmnet::glmnet, list(model.covs, treat, family = family, standardize = FALSE,
                                          lambda = lambda, alpha = tune[["alpha"]][i], thresh = A[["thresh"]],
                                          maxit = A[["maxit"]], relax = tune[["relax"]][i], weights = s.weights,
                                          penalty.factor = A[["penalty.factor"]]))

      # Coarse-to-fine search over fit$lambda. The length of fit$lambda can
      # differ from the requested `lambda`, so all lambda lookups below index
      # fit$lambda, never `lambda`.
      iters <- seq_along(fit$lambda)
      iters.grid <- round(seq(1, length(fit$lambda), length.out = n.grid))

      gps <- predict(fit, newx = model.covs, s = fit$lambda[iters.grid], gamma = gamma)
      w <- get_cont_weights(gps, treat = treat, s.weights = s.weights, dens.num = dens.num,
                            densfun = densfun, use.kernel = use.kernel, densControl = A)
      if (trim.at != 0) w <- suppressMessages(apply(w, 2, trim, at = trim.at, treat = treat))

      iter.grid.balance <- apply(w, 2, function(w_) {
        crit$fun(init = init, weights = w_)
      })

      if (n.grid == nlambda) {
        best.lambda.index <- which.min(iter.grid.balance)
        best.loss <- iter.grid.balance[best.lambda.index]
        best.lambda <- fit$lambda[iters.grid[best.lambda.index]]
        lambda.val <- setNames(data.frame(fit$lambda[iters.grid],
                                          iter.grid.balance),
                               c("lambda", stop.method))
      }
      else {
        # Refine: check every lambda between the grid neighbors of the best
        # grid point
        it <- which.min(iter.grid.balance) + c(-1, 1)
        it[1] <- iters.grid[max(1, it[1])]
        it[2] <- iters.grid[min(length(iters.grid), it[2])]
        iters.to.check <- iters[between(iters, iters[it])]

        if (is_null(iters.to.check) || anyNA(iters.to.check) || any(iters.to.check > nlambda)) stop("A problem has occurred")

        gps <- predict(fit, newx = model.covs, s = fit$lambda[iters.to.check], gamma = gamma)
        w <- get_cont_weights(gps, treat = treat, s.weights = s.weights, dens.num = dens.num,
                              densfun = densfun, use.kernel = use.kernel, densControl = A)
        if (trim.at != 0) w <- suppressMessages(apply(w, 2, trim, at = trim.at, treat = treat))

        iter.grid.balance.fine <- apply(w, 2, function(w_) {
          crit$fun(init = init, weights = w_)
        })

        best.lambda.index <- which.min(iter.grid.balance.fine)
        best.loss <- iter.grid.balance.fine[best.lambda.index]
        best.lambda <- fit$lambda[iters.to.check[best.lambda.index]]
        lambda.val <- setNames(data.frame(c(fit$lambda[iters.grid], fit$lambda[iters.to.check]),
                                          c(iter.grid.balance, iter.grid.balance.fine)),
                               c("lambda", stop.method))
      }

      lambda.val <- unique(lambda.val[order(lambda.val$lambda),])
      w <- w[,best.lambda.index]
      gps <- gps[,best.lambda.index]

      tune[[paste.("best", stop.method)]][i] <- best.loss
      tune[["best.lambda"]][i] <- best.lambda

      if (best.loss < current.best.loss) {
        best.fit <- fit
        best.w <- w
        best.gps <- gps
        current.best.loss <- best.loss
        best.tune.index <- i

        info <- list(best.lambda = best.lambda,
                     lambda.val = lambda.val,
                     coef = predict(fit, type = "coef", s = best.lambda, gamma = gamma))
      }

    }
    else {
      fit <- do.call(glmnet::cv.glmnet, list(model.covs, treat, family = family, standardize = FALSE,
                                             lambda = lambda, alpha = tune[["alpha"]][i], thresh = A[["thresh"]],
                                             maxit = A[["maxit"]], relax = tune[["relax"]][i], weights = s.weights,
                                             penalty.factor = A[["penalty.factor"]],
                                             type.measure = A[["type.measure"]],
                                             foldid = foldid))

      best.lambda.index <- which.min(fit$cvm)
      best.lambda <- fit$lambda[best.lambda.index]
      best.loss <- fit$cvm[best.lambda.index]

      tune[[paste.("best", names(fit$name))]][i] <- best.loss
      tune[["best.lambda"]][i] <- best.lambda

      if (best.loss < current.best.loss) {
        best.fit <- fit
        best.gps <- predict(fit, newx = model.covs, s = best.lambda, gamma = gamma)
        # drop() returns a plain vector, matching the balance-criterion path
        best.w <- drop(get_cont_weights(best.gps, treat = treat, s.weights = s.weights, dens.num = dens.num,
                                        densfun = densfun, use.kernel = use.kernel, densControl = A))
        # Apply trim.at here too, as in the balance-criterion path
        if (trim.at != 0) best.w <- suppressMessages(trim(best.w, at = trim.at, treat = treat))
        current.best.loss <- best.loss
        best.tune.index <- i

        lambda.val <- setNames(data.frame(fit$lambda,
                                          fit$cvm),
                               c("lambda", names(fit$name)))

        info <- list(best.lambda = best.lambda,
                     lambda.val = lambda.val,
                     coef = predict(fit, type = "coef", s = best.lambda, gamma = gamma))
      }
    }
  }

  if (use.kernel && isTRUE(A[["plot"]])) {
    d.d <- density(treat - best.gps, n = A[["n"]],
                   weights = s.weights/sum(s.weights), give.Rkern = FALSE,
                   bw = A[["bw"]], adjust = A[["adjust"]],
                   kernel = A[["kernel"]])
    plot_density(d.n, d.d)
  }

  # Drop tuning columns that had a single value (i.e., were not tuned)
  tune[tunable[vapply(tunable, function(x) length(A[[x]]) == 1, logical(1L))]] <- NULL

  if (ncol(tune) > 2) {
    info[["tune"]] <- tune
    info[["best.tune"]] <- tune[best.tune.index,]
  }

  obj <- list(w = best.w, info = info, fit.obj = best.fit)
  return(obj)

}

# Compute weights by solving a quadratic program with osqp::solve_osqp().
#
# The objective is built from a pairwise distance matrix `d` between units:
# Euclidean distances of the standardized covariates, or a user-supplied
# `dist.mat`. Its quadratic/linear terms resemble an energy-distance
# objective, and the non-convergence warning points to ?method_energy.
# NOTE(review): confirm the intended "kernel balancing" objective.
#
# Constraints:
#   - every weight is at least `min.w`;
#   - within each treatment level, the sampling-weighted weights sum to the
#     level's size;
#   - units with zero sampling weight are pinned to `min.w`;
#   - for non-ATE estimands, focal units with nonzero sampling weight are
#     pinned to 1;
#   - when `moments != 0` or `int` is TRUE, the covariates (plus any
#     int.poly.f() terms) are balanced exactly to the target: the whole
#     sample for ATE, the focal group otherwise.
#
# Options read from `...`: dist.mat, min.w, improved, lambda, eps, and any
# osqp::osqpSettings() arguments. Returns a list with the weights (`w`) and
# the solver output (`fit.obj`).
weightit2kernbal <- function(covs, treat, s.weights, subset, estimand, focal, missing, moments, int, ...) {
  check.package("osqp")

  A <- list(...)

  # Full-sample size, used to validate dist.mat before subsetting
  n <- length(treat)
  covs <- covs[subset, , drop = FALSE]
  treat <- factor(treat[subset])
  s.weights <- s.weights[subset]

  if (!has_treat_type(treat)) treat <- assign_treat_type(treat)
  treat.type <- get_treat_type(treat)

  # Replace missing values with 0 and append missingness indicators
  if (missing == "ind") {
    missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  # Standardize covariates with sampling-weighted means and variances
  covs <- mat_div(center(covs, at = col.w.m(covs, s.weights)),
                  sqrt(col.w.v(covs, s.weights)))

  # Use a supplied full-sample distance matrix (restricted to the subset) or
  # Euclidean distances between standardized covariates
  if (is_not_null(A[["dist.mat"]])) {
    if (inherits(A[["dist.mat"]], "dist")) A[["dist.mat"]] <- as.matrix(A[["dist.mat"]])

    if (is.matrix(A[["dist.mat"]]) && all(dim(A[["dist.mat"]]) == n) &&
        all(check_if_zero(diag(A[["dist.mat"]]))) && !any(A[["dist.mat"]] < 0) &&
        isSymmetric(unname(A[["dist.mat"]]))) {
      d <- unname(A[["dist.mat"]][subset, subset])
    }
    else stop("'dist.mat' must be a square, symmetric distance matrix with a value for all pairs of units.", call. = FALSE)
  }
  else d <- as.matrix(dist(covs))

  # From here on, n is the size of the subset
  n <- length(treat)
  levels_treat <- levels(treat)
  diagn <- diag(n)

  min.w <- if_null_then(A[["min.w"]], 1e-8)
  if (!is.numeric(min.w) || length(min.w) != 1 || min.w < 0) {
    warning("'min.w' must be a nonnegative number. Setting min.w = 1e-8.", call. = FALSE, immediate. = TRUE)
    min.w <- 1e-8
  }

  # Rescale sampling weights to mean 1 within each treatment level
  for (t in levels_treat) s.weights[treat == t] <- s.weights[treat == t]/mean(s.weights[treat == t])

  # tmat: n x (number of levels) logical membership matrix; nt: level sizes
  tmat <- vapply(levels_treat, function(t) treat == t, logical(n))
  nt <- colSums(tmat)

  # J[[t]]: s.weights / n_t for members of level t, 0 for everyone else
  J <- setNames(lapply(levels_treat, function(t) s.weights*tmat[,t]/nt[t]), levels_treat)

  if (estimand == "ATE") {
    # Target distribution is the whole sample (J0 = s.weights / n)
    J0 <- as.matrix(s.weights/n)

    # Quadratic (M2) and linear (M1) objective terms, summed over levels
    M2_array <- vapply(levels_treat, function(t) -2 * tcrossprod(J[[t]]) * d, diagn)
    M1_array <- vapply(levels_treat, function(t) 2 * J[[t]] * d %*% J0, J0)

    M2 <- rowSums(M2_array, dims = 2)
    M1 <- rowSums(M1_array)

    # Unless improved = FALSE, also add a term for every pair of levels
    if (!isFALSE(A[["improved"]])) {
      all_pairs <- combn(levels_treat, 2, simplify = FALSE)
      M2_pairs_array <- vapply(all_pairs, function(p) -2 * tcrossprod(J[[p[1]]]-J[[p[2]]]) * d, diagn)
      M2 <- M2 + rowSums(M2_pairs_array, dims = 2)
    }

    #Constraints for positivity and sum of weights
    Amat <- rbind(diagn, t(s.weights * tmat))
    lvec <- c(rep(min.w, n), nt)
    uvec <- c(ifelse(check_if_zero(s.weights), min.w, Inf), nt)
  }
  else {

    # Target distribution is the focal group; only non-focal levels enter
    # the objective
    J0_focal <- as.matrix(J[[focal]])
    clevs <- levels_treat[levels_treat != focal]

    M2_array <- vapply(clevs, function(t) -2 * tcrossprod(J[[t]]) * d, diagn)
    M1_array <- vapply(clevs, function(t) 2 * J[[t]] * d %*% J0_focal, J0_focal)

    M2 <- rowSums(M2_array, dims = 2)
    M1 <- rowSums(M1_array)

    #Constraints for positivity and sum of weights
    # (focal units with nonzero sampling weight are fixed at 1)
    Amat <- rbind(diagn, t(s.weights*tmat))
    lvec <- c(ifelse_(check_if_zero(s.weights), min.w, treat == focal, 1, min.w), nt)
    uvec <- c(ifelse_(check_if_zero(s.weights), min.w, treat == focal, 1, Inf), nt)
  }

  #Add weight penalty
  # (adds lambda / n to the diagonal of M2, i.e., an L2 penalty on the weights)
  if (is_not_null(A[["lambda"]])) diag(M2) <- diag(M2) + A[["lambda"]] / n

  if (moments != 0 || int) {
    #Exactly balance moments and/or interactions
    covs <- cbind(covs, int.poly.f(covs, poly = moments, int = int))

    if (estimand == "ATE") targets <- col.w.m(covs, s.weights)
    else targets <- col.w.m(covs[treat == focal, , drop = FALSE], s.weights[treat == focal])

    # One equality block (lower = upper = targets) per non-focal level
    Amat <- do.call("rbind", c(list(Amat),
                               lapply(levels_treat, function(t) {
                                 if (is_null(focal) || t != focal) t(covs * J[[t]])
                               })))
    lvec <- do.call("c", c(list(lvec),
                           lapply(levels_treat, function(t) {
                             if (is_null(focal) || t != focal) targets
                           })))
    uvec <- do.call("c", c(list(uvec),
                           lapply(levels_treat, function(t) {
                             if (is_null(focal) || t != focal) targets
                           })))
  }

  # A single `eps` sets both eps_abs and eps_rel unless they were supplied
  if (is_not_null(A[["eps"]])) {
    if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- A[["eps"]]
    if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- A[["eps"]]
  }
  # Keep only arguments accepted by osqp::osqpSettings()
  A[names(A) %nin% names(formals(osqp::osqpSettings))] <- NULL
  if (is_null(A[["max_iter"]])) A[["max_iter"]] <- 2E3L
  if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- 1E-8
  if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- 1E-8
  A[["verbose"]] <- TRUE

  options.list <- do.call(osqp::osqpSettings, A)

  opt.out <- do.call(osqp::solve_osqp, list(P = M2, q = M1, A = Amat, l = lvec, u = uvec,
                                            pars = options.list),
                     quote = TRUE)

  if (identical(opt.out$info$status, "maximum iterations reached")) {
    warning("The optimization failed to converge. See Notes section at ?method_energy for information.", call. = FALSE)
  }

  w <- opt.out$x

  # Focal units keep weight 1
  if (estimand == "ATT") w[treat == focal] <- 1

  # Floor solver output at min.w
  w[w <= min.w] <- min.w

  obj <- list(w = w, fit.obj = opt.out)
  return(obj)

}

#------Ready for use, but not ready for CRAN----
#KBAL
# Kernel balancing weights via KBAL::kbal().
#
# Each kbal() call balances one treatment level against a target:
#   - estimand = "ATT": each non-focal level against the focal group;
#   - estimand = "ATE": each level against the full sample, stacked as the
#     D = 1 group.
# Target-group covariate rows are multiplied by sampling weights rescaled to
# mean 1. The kbal() weights for the balanced level are then divided by those
# units' sampling weights.
#
# `kbal.method` in `...` is passed to kbal() as `method`. Other arguments in
# `...` that match formals of KBAL::kbal() (except X and D) are forwarded,
# with kbal()'s defaults filled in for the rest.
# Returns a list with the weights (`w`) and the kbal() fit(s) (`fit.obj`; a
# single fit rather than a list when only one was run).
weightit2kbal <- function(covs, treat, s.weights, subset, estimand, focal, ...) {
  # Check the dependency before formals(KBAL::kbal) is evaluated below
  check.package("KBAL")

  A <- list(...)

  covs <- covs[subset, , drop = FALSE]
  treat <- factor(treat[subset])
  s.weights <- s.weights[subset]

  covs <- apply(covs, 2, make.closer.to.1)

  # Replace missing values with 0 and append missingness indicators
  if (any(vars.w.missing <- anyNA_col(covs))) {
    missing.ind <- apply(covs[, vars.w.missing, drop = FALSE], 2, function(x) as.numeric(is.na(x)))
    covs[is.na(covs)] <- 0
    covs <- cbind(covs, missing.ind)
  }

  if ("kbal.method" %in% names(A)) {
    names(A)[names(A) == "kbal.method"] <- "method"
  }
  # Fill in kbal()'s defaults, then keep only kbal() arguments other than the
  # data (X, D), which are supplied per call below
  for (f in names(formals(KBAL::kbal))) {
    if (is_null(A[[f]])) A[[f]] <- formals(KBAL::kbal)[[f]]
  }
  A[names(A) %nin% setdiff(names(formals(KBAL::kbal)), c("X", "D"))] <- NULL

  if (hasName(A, "method")) {
    if (A[["method"]] == "el") check.package(c("glmc", "emplik"))
  }

  if (estimand == "ATT") {
    w <- rep(1, length(treat))
    control.levels <- levels(treat)[levels(treat) != focal]
    fit.list <- setNames(vector("list", length(control.levels)), control.levels)

    # Scale focal covariate rows by focal sampling weights rescaled to mean 1
    covs[treat == focal,] <- covs[treat == focal, , drop = FALSE] * s.weights[treat == focal] * sum(treat == focal)/sum(s.weights[treat == focal])

    for (i in control.levels) {
      treat.in.i.focal <- treat %in% c(focal, i)
      treat_ <- ifelse(treat[treat.in.i.focal] == i, 0L, 1L)
      covs_ <- covs[treat.in.i.focal, , drop = FALSE]

      colinear.covs.to.remove <- colnames(covs_)[colnames(covs_) %nin% colnames(make_full_rank(covs_[treat_ == 0, , drop = FALSE]))]

      covs_ <- covs_[, colnames(covs_) %nin% colinear.covs.to.remove, drop = FALSE]

      kbal.out <- do.call(KBAL::kbal, c(list(X = covs_, D = treat_), A))

      # kbal.out$w is aligned with the focal + level-i rows; its treat_ == 0
      # entries are exactly the level-i units, in the same order
      w[treat == i] <- kbal.out$w[treat_ == 0L] / s.weights[treat == i]
      fit.list[[i]] <- kbal.out
    }
  }
  else if (estimand == "ATE") {
    w <- rep(1, length(treat))
    fit.list <- setNames(vector("list", nlevels(treat)), levels(treat))

    for (i in levels(treat)) {
      # Stack the full sample (D = 1, the target) on top of level i (D = 0)
      covs_i <- rbind(covs, covs[treat==i, , drop = FALSE])
      treat_i <- c(rep(1, nrow(covs)), rep(0, sum(treat==i)))

      colinear.covs.to.remove <- colnames(covs_i)[colnames(covs_i) %nin% colnames(make_full_rank(covs_i[treat_i == 0, , drop = FALSE]))]

      covs_i <- covs_i[, colnames(covs_i) %nin% colinear.covs.to.remove, drop = FALSE]

      covs_i[treat_i == 1,] <- covs_i[treat_i == 1, , drop = FALSE] * s.weights * sum(treat_i == 1) / sum(s.weights)

      kbal.out_i <- do.call(KBAL::kbal, c(list(X = covs_i, D = treat_i), A))

      w[treat == i] <- kbal.out_i$w[treat_i == 0] / s.weights[treat == i]
      fit.list[[i]] <- kbal.out_i
    }
  }
  else {
    stop("method = \"kbal\" only supports estimand = \"ATT\" or \"ATE\".", call. = FALSE)
  }

  if (length(fit.list) == 1) fit.list <- fit.list[[1]]
  obj <- list(w = w, fit.obj = fit.list)
  return(obj)

}

#Empirical Balancing Calibration weights with ATE
weightit2ebcw <- function(covs, treat, s.weights, subset, estimand, focal, missing, moments, int, ...) {
  # Empirical balancing calibration weights computed with ATE::ATE().
  #   - estimand = "ATT": each non-focal level (Ti = 0) is calibrated to the
  #     focal group (Ti = 1).
  #   - estimand = "ATE": each level is calibrated to the full sample, which
  #     is stacked as the Ti = 1 group.
  # Target rows are multiplied by sampling weights rescaled to mean 1, and the
  # resulting weights.q are divided by the calibrated units' sampling weights.
  # Returns list(w, fit.obj); fit.obj is a single fit when only one was run.
  check.package("ATE")

  A <- list(...)

  covs <- covs[subset, , drop = FALSE]
  treat <- factor(treat[subset])
  s.weights <- s.weights[subset]

  if (missing == "ind") {
    missing.ind <- apply(covs[, anyNA_col(covs), drop = FALSE], 2, function(x) as.numeric(is.na(x)))
    if (is_not_null(missing.ind)) {
      covs[is.na(covs)] <- 0
      covs <- cbind(covs, missing.ind)
    }
  }

  covs <- cbind(covs, int.poly.f(covs, poly = moments, int = int))
  for (j in seq_col(covs)) {
    covs[, j] <- make.closer.to.1(covs[, j])
  }

  # Use ATE::ATE()'s own defaults for any option not given in `...`
  ate.formals <- formals(ATE::ATE)
  for (arg.name in names(ate.formals)) {
    if (is_null(A[[arg.name]])) A[[arg.name]] <- ate.formals[[arg.name]]
  }

  # Calibrate the Ti == 0 rows of X to the Ti == 1 rows. Y is a dummy all-zero
  # outcome, since only the weights are used.
  run_ate <- function(Ti, X) {
    Y <- rep(0, length(Ti))
    ATE::ATE(Y = Y, Ti = Ti, X = X,
             ATT = TRUE,
             theta = A[["theta"]],
             verbose = TRUE,
             max.iter = A[["max.iter"]],
             tol = A[["tol"]],
             initial.values = A[["initial.values"]],
             backtrack = A[["backtrack"]],
             backtrack.alpha = A[["backtrack.alpha"]],
             backtrack.beta = A[["backtrack.beta"]])
  }

  # Keep only the columns that make_full_rank() retains among the Ti == 0 rows
  keep_full_rank <- function(X, Ti) {
    kept <- colnames(make_full_rank(X[Ti == 0, , drop = FALSE]))
    X[, colnames(X) %in% kept, drop = FALSE]
  }

  if (estimand == "ATT") {
    w <- rep(1, length(treat))
    control.levels <- levels(treat)[levels(treat) != focal]
    fit.list <- make_list(control.levels)

    for (lev in control.levels) {
      in.pair <- treat %in% c(focal, lev)
      Ti <- as.integer(treat[in.pair] != lev)
      X <- keep_full_rank(covs[in.pair, , drop = FALSE], Ti)

      X[Ti == 1,] <- X[Ti == 1,] * s.weights[treat == focal] * sum(treat == focal)/ sum(s.weights[treat == focal])

      fit <- run_ate(Ti, X)
      w[treat == lev] <- fit$weights.q[Ti == 0] / s.weights[treat == lev]
      fit.list[[lev]] <- fit
    }
  }
  else if (estimand == "ATE") {
    w <- rep(1, length(treat))
    fit.list <- make_list(levels(treat))

    for (lev in levels(treat)) {
      in.lev <- treat == lev
      Ti <- c(rep(1, nrow(covs)), rep(0, sum(in.lev)))
      X <- keep_full_rank(rbind(covs, covs[in.lev, , drop = FALSE]), Ti)

      X[Ti == 1,] <- X[Ti == 1,] * s.weights * sum(Ti == 1) / sum(s.weights)

      fit <- run_ate(Ti, X)
      w[in.lev] <- fit$weights.q[Ti == 0] / s.weights[in.lev]
      fit.list[[lev]] <- fit
    }
  }

  if (length(fit.list) == 1) fit.list <- fit.list[[1]]
  list(w = w, fit.obj = fit.list)
}
# ngreifer/WeightIt documentation built on March 6, 2025, 2:04 a.m.