Nothing
# predict_counterfactual_mean
test_that("`predict_counterfactual_mean` predicts correctly", {
  # Y equals X when A == 0 and X + effect when A == 1, so the linear model
  # recovers the shift exactly and the two counterfactual predictions must
  # differ by `effect`.
  effect <- 10
  dat <- data.frame(
    Y = seq_len(2 * effect),
    X = rep(seq_len(effect), times = 2),
    A = rep(c(0, 1), each = effect)
  )
  fit <- glm(Y ~ X + A, data = dat)

  psi_untreated <- predict_counterfactual_mean(
    model = fit,
    exposure_indicator_name = "A",
    group_val = 0
  )
  psi_treated <- predict_counterfactual_mean(
    model = fit,
    exposure_indicator_name = "A",
    group_val = 1
  )

  expect_equal(psi_untreated, psi_treated - effect)
})
test_that("`predict_counterfactual_mean` gives error when `exposure_indicator_name` not in model or data", {
  # "test" is neither a term of the fitted model nor a column of the data,
  # so the call must fail with a message containing "is not in".
  fit_dat <- data.frame(
    Y = seq_len(10),
    X = seq_len(10),
    A = rep(c(0, 1), each = 5)
  )
  fit <- glm(Y ~ X + A, data = fit_dat)

  expect_error(
    predict_counterfactual_mean(
      model = fit,
      exposure_indicator_name = "test",
      group_val = 0
    ),
    regexp = "is not in"
  )
})
test_that("`predict_counterfactual_mean` works with and without data specification", {
  dat_fit <- data.frame(
    Y = seq_len(10),
    X = seq_len(10),
    A = rep(c(0, 1), each = 5)
  )
  fit <- glm(Y ~ X + A, data = dat_fit)

  # Leaving `data` unspecified must match explicitly passing the fitting data.
  implicit_data <- predict_counterfactual_mean(
    model = fit,
    exposure_indicator_name = "A",
    group_val = 0
  )
  explicit_data <- predict_counterfactual_mean(
    model = fit,
    exposure_indicator_name = "A",
    group_val = 0,
    data = dat_fit
  )
  expect_equal(implicit_data, explicit_data)

  # New data with no `A` column and X values outside the fitting range
  # should still yield a double-typed result.
  dat_new <- data.frame(
    Y = seq_len(11),
    X = -5:5
  )
  new_data <- predict_counterfactual_mean(
    model = fit,
    exposure_indicator_name = "A",
    group_val = 0,
    data = dat_new
  )
  expect_type(new_data, "double")
})
# predict_counterfactual_means
test_that("`predict_counterfactual_means` predicts correctly", {
  # The description names the plural function under test. Reusing the
  # singular `predict_counterfactual_mean` description would make failures
  # here indistinguishable from failures of that function's test.
  #
  # Y equals X when A == 0 and X + treat_diff when A == 1, so the fitted
  # model recovers the shift exactly and the counterfactual means in
  # columns `psi0` and `psi1` must differ by `treat_diff`.
  treat_diff <- 10
  dat <- data.frame(
    Y = 1:(2 * treat_diff),
    X = rep(1:treat_diff, 2),
    A = c(rep(0, treat_diff), rep(1, treat_diff))
  )
  mod <- glm(Y ~ X + A, data = dat)
  preds <- predict_counterfactual_means(
    model = mod,
    exposure_indicator_name = "A"
  )
  expect_s3_class(preds, "data.frame")
  # Check the columns exist before comparing them; `[[` avoids the
  # partial-name matching that `$` performs on data.frames.
  expect_true(all(c("psi0", "psi1") %in% names(preds)))
  expect_equal(preds[["psi0"]], preds[["psi1"]] - treat_diff)
})
# default_estimand_funs
test_that("`default_estimand_funs` switches correctly", {
  # Calling with no argument must give the very same function as "ate".
  ate_implicit <- default_estimand_funs()
  ate_explicit <- default_estimand_funs("ate")
  expect_identical(ate_implicit, ate_explicit)
  expect_equal(ate_implicit(2, 1), 1)

  # "rate_ratio" maps (2, 0.5) to 4.
  rate_ratio <- default_estimand_funs("rate_ratio")
  expect_equal(rate_ratio(2, 0.5), 4)
})
test_that("`default_estimand_funs` error when giving non-legal default", {
  # An unsupported estimand name must be rejected with a
  # "should be one of" message.
  expect_error(
    default_estimand_funs("test"),
    regexp = "should be one of"
  )
})
# oos_fitted.values_counterfactual
test_that("`oos_fitted.values_counterfactual` snapshot test", {
  # The snapshot is keyed on this test's description and records the
  # expression `oos`. Both must stay unchanged to keep the stored snapshot.
  dat <- data.frame(
    Y = seq_len(10),
    X = seq_len(10),
    A = rep(c(0, 1), each = 5)
  )
  glm_args <- list(formula = formula(Y ~ X + A))

  # NOTE(review): no seed is set here. If the function splits the data
  # randomly, snapshot stability depends on that function -- confirm.
  oos <- oos_fitted.values_counterfactual(
    data = dat,
    exposure_indicator_name = "A",
    full_model.args_glm = glm_args
  )

  expect_named(oos, c("psi0", "psi1", "rowname"))
  expect_s3_class(oos, "data.frame")
  expect_snapshot(oos)
})
# extract_train_test
test_that("`extract_train_test` returns list of train and test data", {
  dat <- data.frame(
    Y = seq_len(10),
    X = seq_len(10),
    A = rep(c(0, 1), each = 5)
  )

  # Fix the RNG (reverted when the test exits) so the fold assignment, and
  # therefore the snapshot, is reproducible.
  withr::local_seed(42)
  cv_folds <- rsample::vfold_cv(dat, v = 2)
  first_split <- cv_folds$splits[[1]]

  # The variable name below is recorded in the snapshot; keep it unchanged.
  list_of_train_test <- extract_train_test(first_split)

  expect_type(list_of_train_test, "list")
  expect_named(list_of_train_test, c("train", "test"))
  expect_snapshot(list_of_train_test)
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.