# tests/testthat/test-psw.R

library(testthat)
library(vctrs)

test_that("new_psw creates valid psw objects", {
  # The low-level constructor is expected to produce the complete class
  # vector and to retain both the raw values and the estimand label.
  values <- c(0.1, 0.2, 0.3)
  expected_classes <- c("psw", "causal_wts", "vctrs_vctr", "double")

  wts <- new_psw(values, estimand = "ate")

  expect_s3_class(wts, expected_classes, exact = TRUE)
  expect_equal(vec_data(wts), values)
  expect_equal(estimand(wts), "ate")
})

test_that("psw helper function works correctly", {
  values <- c(0.1, 0.2, 0.3)
  expected_classes <- c("psw", "causal_wts", "vctrs_vctr", "double")

  wts <- psw(values, estimand = "att")

  # Class vector, underlying data and estimand label
  expect_s3_class(wts, expected_classes, exact = TRUE)
  expect_equal(vec_data(wts), values)
  expect_equal(estimand(wts), "att")

  # None of the modification flags are set on this object
  expect_false(is_stabilized(wts))
  expect_false(is_ps_trimmed(wts))
  expect_false(is_ps_truncated(wts))

  # The estimand can be overwritten through the replacement form
  estimand(wts) <- "ATT!"
  expect_equal(estimand(wts), "ATT!")
})

test_that("is_psw identifies psw objects", {
  raw_values <- c(0.1, 0.2)

  # TRUE for a psw vector, FALSE for the same values as a bare double
  expect_true(is_psw(psw(raw_values, estimand = "ato")))
  expect_false(is_psw(raw_values))
})

test_that("vec_ptype_abbr and vec_ptype_full return correct type labels", {
  values <- c(0.1, 0.2)

  # Explicit estimand: it appears in both the short and the long label
  atm_wts <- psw(values, estimand = "atm")
  expect_equal(vec_ptype_abbr(atm_wts), "psw{atm}")
  expect_equal(vec_ptype_full(atm_wts), "psw{estimand = atm}")

  # No estimand supplied: bare abbreviation, "unknown" in the full label
  unlabelled_wts <- psw(values)
  expect_equal(vec_ptype_abbr(unlabelled_wts), "psw")
  expect_equal(vec_ptype_full(unlabelled_wts), "psw{estimand = unknown}")

  # Product of weights with different estimands: both labels are listed
  cens_wts <- psw(values, estimand = "cens")
  product_wts <- atm_wts * cens_wts
  expect_equal(vec_ptype_abbr(product_wts), "psw{atm, cens}")
  expect_equal(vec_ptype_full(product_wts), "psw{estimand = atm, cens}")
})

test_that("vec_cast works for psw and double", {
  values <- c(0.1, 0.2, 0.3)
  wts <- psw(values, estimand = "ate")

  # psw -> double yields the plain numeric values
  as_dbl <- vec_cast(wts, double())
  expect_equal(as_dbl, values)
  expect_true(is.numeric(as_dbl))

  # double -> psw round trip restores class, data and estimand
  round_trip <- as_psw(as_dbl, estimand = "ate")
  expect_s3_class(
    round_trip,
    c("psw", "causal_wts", "vctrs_vctr", "double"),
    exact = TRUE
  )
  expect_equal(vec_data(round_trip), as_dbl)
  expect_equal(estimand(round_trip), "ate")
})

test_that("vec_cast works for psw and integer with precision checks", {
  whole_wts <- psw(c(1, 2, 3), estimand = "ate")

  # psw -> integer succeeds when every value is a whole number
  as_int <- vec_cast(whole_wts, integer())
  expect_equal(as_int, c(1L, 2L, 3L))
  expect_true(is.integer(as_int))

  # integer -> psw round trip restores class, data and estimand
  round_trip <- as_psw(as_int, estimand = "ate")
  expect_s3_class(
    round_trip,
    c("psw", "causal_wts", "vctrs_vctr", "double"),
    exact = TRUE
  )
  expect_equal(vec_data(round_trip), as.numeric(as_int))
  expect_equal(estimand(round_trip), "ate")

  # Casting fractional weights to integer would lose precision and must error
  fractional_wts <- psw(c(1.1, 2.2, 3.3), estimand = "ate")
  expect_propensity_error(
    vec_cast(fractional_wts, integer())
  )
})

test_that("vec_ptype2 combines psw and other types correctly", {
  ate_wts <- psw(c(0.1, 0.2), estimand = "ate")
  ate_wts_other <- psw(c(0.3, 0.4), estimand = "ate")
  att_wts <- psw(c(0.5, 0.6), estimand = "att")

  # Matching estimands: the common type is an empty psw with that estimand
  expect_equal(
    vec_ptype2(ate_wts, ate_wts_other),
    new_psw(estimand = "ate")
  )

  # Mismatched estimands: a warning is raised and the common type is double
  expect_propensity_warning(
    mixed_ptype <- vec_ptype2(ate_wts, att_wts)
  )
  expect_identical(mixed_ptype, double())

  # psw with double, in either argument order: warning, common type double
  dbl_ptype <- double()
  expect_propensity_warning(
    expect_equal(vec_ptype2(ate_wts, dbl_ptype), double())
  )
  expect_propensity_warning(
    expect_equal(vec_ptype2(dbl_ptype, ate_wts), double())
  )

  # psw with integer, in either argument order: warning, common type integer
  int_ptype <- integer()
  expect_propensity_warning(
    expect_equal(vec_ptype2(ate_wts, int_ptype), integer())
  )
  expect_propensity_warning(
    expect_equal(vec_ptype2(int_ptype, ate_wts), integer())
  )
})

test_that("vec_arith performs arithmetic on psw objects", {
  ate_wts <- psw(c(1, 2, 3), estimand = "ate")
  ate_wts_other <- psw(c(0.5, 1.5, 2.5), estimand = "ate")
  cens_wts <- psw(c(3, 2, 1), estimand = "cens")

  # Shared checks for an arithmetic result: it keeps the complete psw class
  # vector, holds the expected raw values, and has the expected estimand.
  expect_psw_result <- function(result, values, label) {
    expect_s3_class(
      result,
      c("psw", "causal_wts", "vctrs_vctr", "double"),
      exact = TRUE
    )
    expect_equal(vec_data(result), values)
    expect_equal(estimand(result), label)
  }

  # psw + psw with the same estimand
  expect_psw_result(ate_wts + ate_wts_other, c(1.5, 3.5, 5.5), "ate")

  # psw * psw with different estimands: the labels are combined
  expect_psw_result(ate_wts * cens_wts, c(3, 4, 3), "ate, cens")

  # psw with a bare double
  expect_psw_result(ate_wts * 2, c(2, 4, 6), "ate")

  # psw with a bare integer
  expect_psw_result(ate_wts - 1L, c(0, 1, 2), "ate")

  # Operators accept a double on either side of the psw ...
  expect_no_error(ate_wts * 2.1)
  expect_no_error(2.1 * ate_wts)
  expect_no_error(ate_wts / 2.1)
  expect_no_error(2.1 / ate_wts)
  expect_no_error(ate_wts + 2.1)
  expect_no_error(2.1 + ate_wts)

  # ... and likewise an integer on either side
  expect_no_error(ate_wts * 2L)
  expect_no_error(2L * ate_wts)
  expect_no_error(ate_wts / 2L)
  expect_no_error(2L / ate_wts)
  expect_no_error(ate_wts + 2L)
  expect_no_error(2L + ate_wts)

  # Exponentiation by a scalar
  expect_no_error(ate_wts^2)

  # Operands of an unsupported type (foreign vctr class, list) must error
  foreign_vctr <- new_vctr(runif(10), class = "thing")
  expect_propensity_error(ate_wts * foreign_vctr)
  expect_propensity_error(ate_wts * list(runif(10)))
})

test_that("vec_math applies math functions to psw objects", {
  wts <- psw(c(1, 4, 9), estimand = "ate")

  # Neither an elementwise function nor a summary returns a psw object
  rooted <- sqrt(wts)
  total <- sum(wts)

  expect_false(is_psw(rooted))
  expect_false(is_psw(total))
})

test_that("Combination of arithmetic and math works correctly for psw", {
  # Toy data: an outcome, a binary indicator, and psw weights
  outcome <- c(2, 4, 6, 8)
  indicator <- c(0, 1, 0, 1)
  wts <- psw(c(0.5, 1.0, 1.5, 2.0), estimand = "ate")

  # Reference weights with the psw class stripped off
  plain_wts <- vec_data(wts)

  # Weighted mean of the outcome over the units selected by `selector`;
  # the same expression is evaluated for both plain and psw weights.
  weighted_mean <- function(selector, w) {
    sum(outcome * selector * w) / sum(selector * w)
  }

  # Reference values computed from the plain double weights
  reference_selected <- weighted_mean(indicator, plain_wts)
  reference_unselected <- weighted_mean(1 - indicator, plain_wts)

  # Same quantities computed directly from the psw weights
  psw_selected <- weighted_mean(indicator, wts)
  psw_unselected <- weighted_mean(1 - indicator, wts)

  expect_equal(psw_selected, reference_selected, tolerance = 1e-8)
  expect_equal(psw_unselected, reference_unselected, tolerance = 1e-8)

  # Recomputing must give the same answer as the first evaluation
  psw_unselected_again <- weighted_mean(1 - indicator, wts)
  expect_equal(psw_unselected_again, psw_unselected, tolerance = 1e-8)
})

test_that("Refit logic can be tracked if stored in attr", {
  # Build trimmed weights, then record the refit flag inside the
  # `ps_trim_meta` attribute; is_refit() is expected to report TRUE for it.
  trimmed_wts <- psw(c(0.1, 0.2), estimand = "ate", trimmed = TRUE)
  trim_meta <- list(refit = TRUE)
  attr(trimmed_wts, "ps_trim_meta") <- trim_meta

  expect_true(is_refit(trimmed_wts))
})

test_that("psw objects can convert to character", {
  # as.character() on a psw vector must produce a character vector
  wts <- new_psw(c(0.1, 0.2, 0.3), estimand = "ate")
  as_text <- as.character(wts)

  expect_type(as_text, "character")
})

test_that("psw works with ggplot2", {
  skip_if_not_installed("ggplot2")

  # Looking up the scale type of a psw vector must emit no output,
  # messages or warnings; the weights are built inside the silent check.
  expect_silent(
    scale_kind <- ggplot2::scale_type(psw(1))
  )

  # The scale type reported for psw weights is "continuous"
  expect_identical(scale_kind, "continuous")
})

# Try the propensity package in your browser
#
# Any scripts or data that you put into this service are public.
#
# propensity documentation built on March 3, 2026, 1:06 a.m.