tests/testthat/test-ps-weights.R

# Tests for propensity score estimation and weight calculation
# These tests call internal functions directly

test_that("PS estimates are in [0, 1] for binary treatment", {
  data <- make_test_data_binary(n = 150)

  ps_result <- estimate_ps(
    data = data,
    treatment_var = "Z",
    ps_formula = Z ~ X1 + X2 + B1,
    ps_control = list()
  )

  # Check all PS values in [0, 1]
  expect_true(all(ps_result$ps_matrix >= 0))
  expect_true(all(ps_result$ps_matrix <= 1))

  # Check observed PS in [0, 1]
  expect_true(all(ps_result$ps >= 0))
  expect_true(all(ps_result$ps <= 1))
})

test_that("PS rows sum to 1 for multiple treatment groups", {
  data <- make_test_data_multi(n = 180)

  ps_result <- estimate_ps(
    data = data,
    treatment_var = "Z",
    ps_formula = Z ~ X1 + X2 + B1,
    ps_control = list()
  )

  # Check rows sum to 1 (ignore names from rowSums)
  row_sums <- rowSums(ps_result$ps_matrix)
  expect_equal(as.numeric(row_sums), rep(1, nrow(data)), tolerance = 1e-6)
})

test_that("Binary treatment uses glm for PS model", {
  data <- make_test_data_binary(n = 150)

  ps_result <- estimate_ps(
    data = data,
    treatment_var = "Z",
    ps_formula = Z ~ X1 + X2 + B1,
    ps_control = list()
  )

  # Check model object is glm
  expect_s3_class(ps_result$ps_model, "glm")
})

test_that("Multiple treatment uses multinom for PS model", {
  data <- make_test_data_multi(n = 180)

  ps_result <- estimate_ps(
    data = data,
    treatment_var = "Z",
    ps_formula = Z ~ X1 + X2 + B1,
    ps_control = list()
  )

  # Check model object is multinom (from nnet package)
  expect_s3_class(ps_result$ps_model, "multinom")
})

test_that("Overlap weights are in [0, 1]", {
  data <- make_test_data_binary(n = 150)

  # First estimate PS
  ps_result <- estimate_ps(
    data = data,
    treatment_var = "Z",
    ps_formula = Z ~ X1 + X2 + B1,
    ps_control = list()
  )

  # Then estimate overlap weights
  weight_result <- estimate_weights(
    ps_result = ps_result,
    data = data,
    treatment_var = "Z",
    estimand = "overlap",
    att_group = NULL,
    trim = NULL,
    delta = NULL,
    alpha = NULL
  )

  # Check all weights in [0, 1]
  expect_true(all(weight_result$weights >= 0))
  expect_true(all(weight_result$weights <= 1))
})

test_that("ATT weights: treated group has weight = 1", {
  data <- make_test_data_binary(n = 150)

  # First estimate PS
  ps_result <- estimate_ps(
    data = data,
    treatment_var = "Z",
    ps_formula = Z ~ X1 + X2 + B1,
    ps_control = list()
  )

  # Estimate ATT weights targeting group B
  weight_result <- estimate_weights(
    ps_result = ps_result,
    data = data,
    treatment_var = "Z",
    estimand = "ATT",
    att_group = "B",
    trim = NULL,
    delta = NULL,
    alpha = NULL
  )

  # Check weights for treated group (B) are all 1
  is_treated <- data$Z == "B"
  expect_equal(weight_result$weights[is_treated], rep(1, sum(is_treated)), tolerance = 1e-6)

  # Check att_group recorded
  expect_equal(weight_result$att_group, "B")
})

Try the PSsurvival package in your browser

Any scripts or data that you put into this service are public.

PSsurvival documentation built on Dec. 9, 2025, 9:07 a.m.