# tests/testthat/test-rctglm_utils.R

# predict_counterfactual_mean
test_that("`predict_counterfactual_mean` predicts correctly", {
  # In the A = 1 arm every outcome is exactly `shift` larger than in the
  # A = 0 arm for the same X, so the linear model recovers that shift and the
  # two counterfactual means must differ by exactly `shift`.
  shift <- 10
  trial_data <- data.frame(
    Y = seq_len(2 * shift),
    X = rep(seq_len(shift), times = 2),
    A = rep(c(0, 1), each = shift)
  )
  fit <- glm(Y ~ X + A, data = trial_data)

  # Counterfactual mean with every observation assigned to `arm`.
  counterfactual_mean <- function(arm) {
    predict_counterfactual_mean(
      model = fit,
      exposure_indicator_name = "A",
      group_val = arm
    )
  }
  control_mean <- counterfactual_mean(0)
  treated_mean <- counterfactual_mean(1)

  expect_equal(control_mean, treated_mean - shift)
})

test_that("`predict_counterfactual_mean` gives error when `exposure_indicator_name` not in model or data", {
  toy_data <- data.frame(
    Y = 1:10,
    X = 1:10,
    A = rep(c(0, 1), each = 5)
  )
  fit <- glm(Y ~ X + A, data = toy_data)

  # "test" is neither a model term nor a column of the data, so the call
  # must fail with a message containing "is not in".
  expect_error(
    predict_counterfactual_mean(
      model = fit,
      exposure_indicator_name = "test",
      group_val = 0
    ),
    regexp = "is not in"
  )
})

test_that("`predict_counterfactual_mean` works with and without data specification", {
  fit_data <- data.frame(
    Y = 1:10,
    X = 1:10,
    A = rep(c(0, 1), each = 5)
  )
  fit <- glm(Y ~ X + A, data = fit_data)

  # Counterfactual mean under A = 0; extra arguments (e.g. `data`) are
  # forwarded unchanged.
  predict_control <- function(...) {
    predict_counterfactual_mean(
      model = fit,
      exposure_indicator_name = "A",
      group_val = 0,
      ...
    )
  }

  # Omitting `data` must agree with explicitly passing the fitting data.
  implicit <- predict_control()
  explicit <- predict_control(data = fit_data)
  expect_equal(implicit, explicit)

  # New data without an "A" column is accepted and yields a double.
  new_data <- data.frame(Y = 1:11, X = -5:5)
  expect_type(predict_control(data = new_data), "double")
})

# predict_counterfactual_means
# Description names the plural function under test; the previous text
# duplicated the singular `predict_counterfactual_mean` test's name.
test_that("`predict_counterfactual_means` predicts correctly", {
  # In the A = 1 arm every outcome is exactly `treat_diff` larger than in the
  # A = 0 arm for the same X, so the counterfactual means in the returned
  # data.frame must differ by exactly `treat_diff`.
  treat_diff <- 10
  dat <- data.frame(
    Y = 1:(2 * treat_diff),
    X = rep(1:treat_diff, 2),
    A = c(rep(0, treat_diff), rep(1, treat_diff))
  )
  mod <- glm(Y ~ X + A, data = dat)
  preds <- predict_counterfactual_means(
    model = mod,
    exposure_indicator_name = "A"
  )

  expect_s3_class(preds, "data.frame")
  # Fail with a clear message if either counterfactual column is missing,
  # rather than comparing NULL against numeric(0).
  expect_true(all(c("psi0", "psi1") %in% names(preds)))
  # `[[` avoids data.frame `$` partial name matching.
  expect_equal(preds[["psi0"]], preds[["psi1"]] - treat_diff)
})

# default_estimand_funs
test_that("`default_estimand_funs` switches correctly", {
  ate_implicit <- default_estimand_funs()
  ate_explicit <- default_estimand_funs("ate")
  rate_ratio <- default_estimand_funs("rate_ratio")

  # Calling without an argument returns the same function as "ate".
  expect_identical(ate_implicit, ate_explicit)
  # "ate" combines its two arguments as a difference: 2 - 1 = 1.
  expect_equal(ate_implicit(2, 1), 1)
  # "rate_ratio" combines its two arguments as a ratio: 2 / 0.5 = 4.
  expect_equal(rate_ratio(2, 0.5), 4)
})

test_that("`default_estimand_funs` error when giving non-legal default", {
  # An unrecognised estimand name is rejected with a "should be one of"
  # message listing the legal choices.
  expect_error(
    default_estimand_funs("test"),
    regexp = "should be one of"
  )
})

# oos_fitted.values_counterfactual
test_that("`oos_fitted.values_counterfactual` snapshot test", {
  # NOTE: the test description, the data values and the expression inside
  # expect_snapshot() are all part of the stored snapshot; keep them stable.
  dat <- data.frame(
    Y = 1:10,
    X = 1:10,
    A = rep(c(0, 1), each = 5)
  )

  oos <- oos_fitted.values_counterfactual(
    data = dat,
    exposure_indicator_name = "A",
    full_model.args_glm = list(formula = formula(Y ~ X + A))
  )

  expect_s3_class(oos, "data.frame")
  expect_named(oos, c("psi0", "psi1", "rowname"))
  expect_snapshot(oos)
})

# extract_train_test
test_that("`extract_train_test` returns list of train and test data", {
  dat <- data.frame(
    Y = 1:10,
    X = 1:10,
    A = rep(c(0, 1), each = 5)
  )

  # Fixing the seed makes the fold assignment, and hence the snapshot,
  # reproducible; the seed is restored when the test exits.
  withr::local_seed(42)
  first_split <- rsample::vfold_cv(dat, v = 2)$splits[[1]]

  # NOTE: `list_of_train_test` is recorded verbatim in the snapshot file.
  list_of_train_test <- extract_train_test(first_split)
  expect_type(list_of_train_test, "list")
  expect_named(list_of_train_test, c("train", "test"))
  expect_snapshot(list_of_train_test)
})

# Try the postcard package in your browser
#
# Any scripts or data that you put into this service are public.
#
# postcard documentation built on April 12, 2025, 1:57 a.m.