tests/testthat/helper-M_est.R

# Check that the M-estimation parts stored on a fitted object reproduce its
# coefficients and weights.
#
# `W` must carry either a single "Mparts" attribute or a "Mparts.list"
# attribute (a list of such parts). Each part supplies `psi_treat`, `wfun`,
# `Xtreat`, `A`, and `btreat`. The stacked estimating equations of all parts
# are re-solved with rootSolve::multiroot(), starting from the stored
# coefficients scaled by 1.01. The solved coefficients, and the weights
# recomputed from them, are then compared against the stored `btreat` values
# and `W$weights`.
#
# Arguments:
#   W          Object with an "Mparts" or "Mparts.list" attribute and
#              `weights` / `s.weights` components.
#   tolerance  Passed to both expect_equal() calls.
#   ...        Further arguments passed to both expect_equal() calls.
#
# Returns (invisibly) a list with the multiroot() result (`solve`), the
# solved stacked coefficient vector (`b`), and the recomputed weights
# (`weights`). Returns invisible(NULL) after recording a failure when `W`
# has no M-estimation parts.
expect_M_parts_okay <- function(W, tolerance = 1e-5, ...) {
  Mparts.list <- if (is_not_null(attr(W, "Mparts", exact = TRUE))) {
    list(attr(W, "Mparts", exact = TRUE))
  } else if (is_not_null(attr(W, "Mparts.list", exact = TRUE))) {
    attr(W, "Mparts.list", exact = TRUE)
  } else {
    NULL
  }

  # Record the failure, then stop. Without parts there is nothing to
  # re-solve, and continuing would only add a second, misleading error from
  # multiroot() being called with an empty starting vector.
  expect_false(is_null(Mparts.list))
  if (is_null(Mparts.list)) {
    return(invisible(NULL))
  }

  psi_treat.list <- lapply(Mparts.list, `[[`, "psi_treat")
  wfun.list <- lapply(Mparts.list, `[[`, "wfun")
  Xtreat.list <- lapply(Mparts.list, `[[`, "Xtreat")
  A.list <- lapply(Mparts.list, `[[`, "A")
  btreat.list <- lapply(Mparts.list, `[[`, "btreat")
  SW <- W$s.weights

  # Split a stacked coefficient vector into consecutive pieces whose lengths
  # (and list names) match `btreat.list`.
  split_B <- function(B) {
    Btreat.list <- btreat.list
    k <- 0
    for (i in seq_along(btreat.list)) {
      Btreat.list[[i]] <- B[k + seq_along(btreat.list[[i]])]
      k <- k + length(btreat.list[[i]])
    }
    Btreat.list
  }

  # Column-bind every part's estimating-function matrix.
  psi_treat <- function(Btreat.list, A.list, Xtreat.list, SW) {
    do.call("cbind", lapply(seq_along(Btreat.list), function(i) {
      psi_treat.list[[i]](Btreat.list[[i]], A.list[[i]], Xtreat.list[[i]], SW)
    }))
  }

  # Product of the weights produced by each part. The per-part `wfun` is
  # called as (Btreat, Xtreat, A), a different argument order from
  # `psi_treat`'s (Btreat, A, Xtreat, SW).
  wfun <- function(Btreat.list, A.list, Xtreat.list) {
    Reduce("*", lapply(seq_along(Btreat.list), function(i) {
      wfun.list[[i]](Btreat.list[[i]], Xtreat.list[[i]], A.list[[i]])
    }), init = 1)
  }

  # Objective for multiroot(): column sums of the stacked estimating
  # functions, evaluated at the stacked coefficient vector `B`.
  gradfun <- function(B, A, X, SW) {
    colSums(psi_treat(split_B(B), A, X, SW))
  }

  # Start from the stored coefficients scaled by 1.01, so the solver has to
  # move to find the root.
  start <- 1.01 * unlist(btreat.list)

  out <- rootSolve::multiroot(gradfun, start = start,
                              X = Xtreat.list, A = A.list, SW = SW,
                              maxiter = 1e5)

  Btreat.list <- split_B(out$root)

  w <- wfun(Btreat.list, A.list, Xtreat.list)

  expect_equal(unname(unlist(Btreat.list)), unname(unlist(btreat.list)),
               tolerance = tolerance, ...)

  expect_equal(unname(as.vector(w)), unname(as.vector(W$weights)),
               tolerance = tolerance, ...)

  invisible(list(solve = out,
                 b = out$root,
                 weights = w))
}

Try the WeightIt package in your browser

Any scripts or data that you put into this service are public.

WeightIt documentation built on May 29, 2024, 9:48 a.m.