# tests/testthat/test-ps_trunc.R

test_that("ps_trunc() - PS method uses fixed bounds", {
  set.seed(123)
  raw_ps <- c(0.01, 0.1, 0.3, 0.8, 0.95)
  clamped <- ps_trunc(raw_ps, method = "ps", lower = 0.2, upper = 0.8)

  expect_s3_class(clamped, "ps_trunc")

  # Fixed-bound ("ps") truncation clamps every score into [lower, upper]:
  # 0.01 and 0.1 rise to 0.2, 0.95 drops to 0.8, while 0.3 and the value
  # sitting exactly on the upper bound (0.8) are left as they are.
  expected <- c(0.2, 0.2, 0.3, 0.8, 0.8)
  expect_equal(as.numeric(clamped), expected)
})

test_that("ps_trunc() - pctl method uses quantiles", {
  set.seed(1)
  ps <- plogis(rnorm(50, 0, 1.2))
  tol <- 1e-8

  # Every truncated value must lie between the sample quantiles of `ps` at
  # probabilities `lo` and `hi`, up to a small numeric tolerance.
  check_within_quantiles <- function(trunc_obj, lo, hi) {
    vals <- as.numeric(trunc_obj)
    bounds <- quantile(ps, probs = c(lo, hi))
    expect_true(all(vals >= bounds[[1]] - tol))
    expect_true(all(vals <= bounds[[2]] + tol))
  }

  # No bounds given: the percentiles default to [0.05, 0.95].
  out_default <- ps_trunc(ps, method = "pctl")
  meta_default <- ps_trunc_meta(out_default)
  expect_equal(meta_default$lower_pctl, 0.05)
  expect_equal(meta_default$upper_pctl, 0.95)
  check_within_quantiles(out_default, 0.05, 0.95)

  # User-supplied percentiles [0.2, 0.8].
  out_custom <- ps_trunc(ps, method = "pctl", lower = 0.2, upper = 0.8)
  meta_custom <- ps_trunc_meta(out_custom)
  expect_equal(meta_custom$lower_pctl, 0.2)
  expect_equal(meta_custom$upper_pctl, 0.8)
  check_within_quantiles(out_custom, 0.2, 0.8)
})

test_that("ps_trunc() - cr method uses min(ps_treat)/max(ps_untrt)", {
  set.seed(2)
  n <- 30
  x <- rnorm(n)
  z <- rbinom(n, 1, plogis(0.4 * x))
  ps <- predict(glm(z ~ x, family = binomial), type = "response")

  # Common-range bounds: the smallest score among the treated (z == 1) and
  # the largest score among the untreated (z == 0).
  expected_lower <- min(ps[z == 1])
  expected_upper <- max(ps[z == 0])

  out_cr <- ps_trunc(ps, .exposure = z, method = "cr", .focal_level = 1)
  expect_s3_class(out_cr, "ps_trunc")

  meta_cr <- ps_trunc_meta(out_cr)
  expect_equal(meta_cr$lower_bound, expected_lower)
  expect_equal(meta_cr$upper_bound, expected_upper)

  # All returned values fall inside the common range.
  truncated_vals <- as.numeric(out_cr)
  expect_true(all(truncated_vals >= expected_lower - 1e-8))
  expect_true(all(truncated_vals <= expected_upper + 1e-8))
})

test_that("ps_trunc() errors on invalid usage or .exposure", {
  # The common-range method needs an exposure vector to split the scores.
  expect_propensity_error(ps_trunc(runif(10), method = "cr"))

  # An exposure with five distinct values (1..5) is not binary.
  bad_exposure <- 1:5
  expect_propensity_error(
    ps_trunc(runif(5), .exposure = bad_exposure, method = "cr")
  )

  # Fixed bounds given in the wrong order (lower > upper) are rejected.
  expect_propensity_error(
    ps_trunc(runif(5), method = "ps", lower = 0.8, upper = 0.3)
  )
})

test_that("Truncation workflow yields truncated psw with no refit logic", {
  set.seed(888)
  n <- 10
  x <- rnorm(n)
  z <- rbinom(n, 1, plogis(0.4 * x))

  # Propensity scores from a logistic model, winsorized into [0.2, 0.8].
  ps_model <- glm(z ~ x, family = binomial)
  winsorized <- predict(ps_model, type = "response") |>
    ps_trunc(method = "ps", lower = 0.2, upper = 0.8)
  expect_s3_class(winsorized, "ps_trunc")

  # ATE weights computed from the truncated scores.
  ate_wts <- wt_ate(
    winsorized,
    .exposure = z,
    exposure_type = "binary",
    .focal_level = 1
  )
  expect_s3_class(ate_wts, "psw")

  # The weights record truncation only: no trimming, no model refit, and
  # an estimand label ending in "; truncated".
  expect_true(is_ps_truncated(ate_wts))
  expect_false(is_ps_trimmed(ate_wts))
  expect_false(is_refit(ate_wts))
  expect_match(estimand(ate_wts), "; truncated$")
})

test_that("is_ps_truncated.default() -> FALSE, is_ps_truncated.ps_trunc() -> TRUE", {
  # A plain double goes to the default method, which reports FALSE.
  plain <- runif(5)
  expect_false(is_ps_truncated(plain))

  # An object built directly with new_ps_trunc() reports TRUE.
  trunc_meta <- list(method = "ps", lower_bound = 0.2, upper_bound = 0.8)
  constructed <- new_ps_trunc(x = c(0.2, 0.6, 0.8), meta = trunc_meta)
  expect_true(is_ps_truncated(constructed))
})

test_that("Arithmetic on ps_trunc returns numeric", {
  obj <- new_ps_trunc(
    x = c(0.2, 0.7, 0.9),
    meta = list(method = "ps", lower_bound = 0.2, upper_bound = 0.8)
  )
  obj2 <- new_ps_trunc(
    x = c(0.1, 0.1, 0.3),
    meta = list(method = "ps", lower_bound = 0.1, upper_bound = 0.5)
  )

  # Every arithmetic result is a plain double. This holds whichever side the
  # ps_trunc operand is on, and also when both operands are ps_trunc objects.
  arith_results <- list(obj + 1, 1 + obj, obj * 2, 1 / obj, obj * obj2)
  for (res in arith_results) {
    expect_type(res, "double")
  }

  # The values are computed on the underlying data.
  expect_equal(obj + 1, c(1.2, 1.7, 1.9))
  expect_equal(1 / obj, c(5, 10 / 7, 10 / 9))
  expect_equal(obj * obj2, c(0.02, 0.07, 0.27))
})

test_that("Combining & casting ps_trunc => correct ptype2, cast behavior", {
  obj <- new_ps_trunc(
    x = c(0.2, 0.6, 0.8),
    meta = list(method = "ps", lower_bound = 0.2, upper_bound = 0.8)
  )
  # Combining behaviour is covered by the dedicated c() tests further down:
  # identical parameters keep the class; mismatched parameters warn and
  # return a double.

  # 1) Casting ps_trunc -> double drops the class and keeps the data.
  out_cast <- vctrs::vec_cast(obj, double())
  expect_type(out_cast, "double")
  expect_identical(out_cast, c(0.2, 0.6, 0.8))

  # 2) Casting double -> ps_trunc keeps the values but attaches placeholder
  #    metadata ("unknown" method, NA bounds), because a plain double carries
  #    no truncation history. Fixed inputs keep the test reproducible.
  new_vals <- c(0.1, 0.5, 0.9)
  out_ps_trunc <- vctrs::vec_cast(new_vals, to = obj)
  expect_s3_class(out_ps_trunc, "ps_trunc")
  expect_equal(vctrs::vec_data(out_ps_trunc), new_vals)

  meta_new <- ps_trunc_meta(out_ps_trunc)
  expect_equal(meta_new$method, "unknown")
  expect_true(is.na(meta_new$lower_bound))
  expect_true(is.na(meta_new$upper_bound))
})

test_that("wt_atm.numeric calls atm_binary() for binary .exposure, returns psw", {
  set.seed(101)
  n <- 8
  x <- rnorm(n)
  z <- rbinom(n, 1, plogis(0.3 * x))
  ps <- plogis(0.4 * x) # plain numeric score, no class attached

  # Calling the numeric method directly with a binary exposure should give a
  # psw object labelled with the "atm" estimand.
  out_atm <- wt_atm.numeric(
    .propensity = ps,
    .exposure = z,
    exposure_type = "binary",
    .focal_level = 1
  )
  expect_s3_class(out_atm, "psw")
  expect_equal(estimand(out_atm), "atm")
})

test_that("atm_binary() logic with transform_exposure_binary() is triggered", {
  # Binary ATM weight:
  #   pmin(ps, 1 - ps) / (.exposure * ps + (1 - .exposure) * (1 - ps))
  # i.e. min(ps, 1 - ps) divided by the probability of the exposure the
  # unit actually received.
  ps_vec <- c(0.2, 0.8, 0.5)
  z_vec <- c(0, 1, 1)
  # Hand-computed: 0.2 / 0.8, 0.2 / 0.8, 0.5 / 0.5
  expected <- c(0.25, 0.25, 1)

  w <- atm_binary(
    .propensity = ps_vec,
    .exposure = z_vec,
    .focal_level = 1
  )
  expect_length(w, 3)
  expect_equal(as.numeric(w), expected)

  # A factor exposure must be recoded to 0/1 before weighting. With "T" as
  # the focal level and "C" as the reference, c("C", "T", "T") becomes
  # c(0, 1, 1), so the weights must match the numeric case above.
  w2 <- atm_binary(
    .propensity = ps_vec,
    .exposure = factor(c("C", "T", "T")),
    .focal_level = "T",
    .reference_level = "C"
  )
  expect_length(w2, 3)
  expect_equal(as.numeric(w2), expected)
})

test_that("wt_ato.numeric calls ato_binary() for binary .exposure, returns psw", {
  set.seed(202)
  n <- 6
  x <- rnorm(n)
  z <- rbinom(n, 1, plogis(0.6 * x))
  ps <- plogis(0.3 * x)

  # The numeric method with a binary exposure should give a psw object
  # labelled with the "ato" estimand.
  out_ato <- wt_ato.numeric(
    .propensity = ps,
    .exposure = z,
    exposure_type = "binary",
    .focal_level = 1
  )
  expect_s3_class(out_ato, "psw")
  expect_equal(estimand(out_ato), "ato")
})

test_that("ato_binary() logic matches hand-computed overlap weights", {
  # Binary ATO weight: (1 - ps) * .exposure + ps * (1 - .exposure)
  # Treated units get 1 - ps; untreated units get ps.
  ps_vec <- c(0.1, 0.9, 0.5)
  z_vec <- c(0, 1, 1)

  w <- ato_binary(
    .propensity = ps_vec,
    .exposure = z_vec,
    .focal_level = 1
  )
  expect_length(w, 3)
  # Unit 1 (untreated): 0.1; unit 2 (treated): 1 - 0.9; unit 3 (treated): 1 - 0.5
  expect_equal(as.numeric(w), c(0.1, 0.1, 0.5))
})

test_that("wt_atm.ps_trunc synergy with truncated object yields truncated psw", {
  set.seed(303)
  n <- 6
  x <- rnorm(n)
  z <- rbinom(n, 1, plogis(0.2 * x))

  # Winsorize the score into [0.2, 0.8]. Because the result is a ps_trunc,
  # wt_atm() dispatches to its ps_trunc method.
  trunc_obj <- plogis(0.7 * x) |>
    ps_trunc(method = "ps", lower = 0.2, upper = 0.8)
  w_atm <- wt_atm(
    trunc_obj,
    .exposure = z,
    exposure_type = "binary",
    .focal_level = 1
  )

  # The weights are a psw labelled "atm; truncated" that reports truncation.
  expect_s3_class(w_atm, "psw")
  expect_match(estimand(w_atm), "atm; truncated")
  expect_true(is_ps_truncated(w_atm))
})

test_that("wt_ato.ps_trunc synergy with truncated object yields truncated psw", {
  set.seed(404)
  n <- 7
  x <- rnorm(n)
  z <- rbinom(n, 1, plogis(0.1 * x))

  # Winsorize into [0.1, 0.9]; wt_ato() then dispatches on the ps_trunc.
  trunc_obj <- plogis(0.5 * x) |>
    ps_trunc(method = "ps", lower = 0.1, upper = 0.9)
  w_ato <- wt_ato(
    trunc_obj,
    .exposure = z,
    exposure_type = "binary",
    .focal_level = 1
  )

  # The weights are a psw labelled "ato; truncated" that reports truncation.
  expect_s3_class(w_ato, "psw")
  expect_match(estimand(w_ato), "ato; truncated")
  expect_true(is_ps_truncated(w_ato))
})

test_that("is_unit_truncated.ps_trunc returns expected row-level booleans", {
  set.seed(101)
  ps_vec <- c(0.1, 0.2, 0.5, 0.85, 0.95)

  # Units 1 (0.1), 4 (0.85) and 5 (0.95) lie outside [0.2, 0.8]. Unit 2 sits
  # exactly on the lower bound, so it is not flagged.
  truncated_obj <- ps_trunc(ps_vec, method = "ps", lower = 0.2, upper = 0.8)
  expect_s3_class(truncated_obj, "ps_trunc")

  flags <- is_unit_truncated(truncated_obj)
  expect_type(flags, "logical")
  expect_length(flags, length(ps_vec))
  expect_equal(which(flags), c(1, 4, 5))

  expect_equal(as.numeric(truncated_obj), c(0.2, 0.2, 0.5, 0.8, 0.8))
})

test_that("ps_trunc objects can convert to character", {
  # as.character() on a ps_trunc should yield a character vector.
  trunc_obj <- c(0.01, 0.1, 0.3, 0.8, 0.95) |>
    ps_trunc(method = "ps", lower = 0.2, upper = 0.8)
  as_text <- as.character(trunc_obj)
  expect_type(as_text, "character")
})

test_that("ps_trunc works with summarize(mean = mean(ps))", {
  skip_if_not_installed("dplyr")
  # Use `dplyr::` rather than library(dplyr): attaching a package inside a
  # test leaves it on the search path for every test that runs afterwards.

  set.seed(200)
  n <- 600
  x <- rnorm(n)
  z <- rbinom(n, size = 1, prob = plogis(x + rnorm(n)))
  fit <- glm(z ~ x, family = binomial)

  ps <- predict(fit, type = "response") |>
    ps_trunc(method = "ps", lower = 0.3, upper = 0.7)

  out <- dplyr::tibble(x, z, ps) |>
    dplyr::group_by(truncated = is_unit_truncated(ps)) |>
    dplyr::summarize(mean = mean(ps), .groups = "drop")

  expect_s3_class(out, "tbl_df")
  expect_named(out, c("truncated", "mean"))
  expect_type(out$mean, "double")
  # Every truncated score lies in [0.3, 0.7], so every group mean must too.
  expect_true(all(out$mean >= 0.3 & out$mean <= 0.7))
})

test_that("ps_trunc vec_ptype_full output matches expected format", {
  set.seed(123)
  ps <- runif(20, 0.05, 0.95)

  # Each case gives fixed bounds and the type label that should be printed.
  # The last pair is wider than the range runif() draws from, so no value
  # is actually clamped.
  cases <- list(
    list(lower = 0.2, upper = 0.8, label = "ps_trunc{[0.2,0.8], method=ps}"),
    list(lower = 0.4, upper = 0.6, label = "ps_trunc{[0.4,0.6], method=ps}"),
    list(
      lower = 0.01,
      upper = 0.99,
      label = "ps_trunc{[0.01,0.99], method=ps}"
    )
  )

  for (case in cases) {
    obj <- ps_trunc(ps, method = "ps", lower = case$lower, upper = case$upper)
    expect_equal(vctrs::vec_ptype_full(obj), case$label)
  }
})

test_that("ps_trunc index tracking works when combining objects", {
  set.seed(456)
  ps1 <- runif(10, 0.05, 0.95)
  ps2 <- runif(10, 0.05, 0.95)

  # Two ps_trunc objects with identical parameters, so they may be combined.
  ps_trunc1 <- ps_trunc(ps1, method = "ps", lower = 0.2, upper = 0.8)
  ps_trunc2 <- ps_trunc(ps2, method = "ps", lower = 0.2, upper = 0.8)

  meta1 <- ps_trunc_meta(ps_trunc1)
  meta2 <- ps_trunc_meta(ps_trunc2)

  combined <- c(ps_trunc1, ps_trunc2)
  expect_s3_class(combined, "ps_trunc")
  expect_length(combined, 20)

  combined_meta <- ps_trunc_meta(combined)

  # The total number of truncated units is the sum of both inputs.
  expect_equal(
    length(combined_meta$truncated_idx),
    length(meta1$truncated_idx) + length(meta2$truncated_idx)
  )

  # Indices from the second object must be shifted by the length of the
  # first object. This offset is what "index tracking" has to get right.
  expected_idx <- c(
    meta1$truncated_idx,
    meta2$truncated_idx + length(ps_trunc1)
  )
  expect_equal(sort(combined_meta$truncated_idx), sort(expected_idx))

  # Every truncated index holds one of the bound values. The bounds are
  # assigned exactly, so `==` is a safe comparison here.
  combined_data <- vctrs::vec_data(combined)
  truncated_values <- combined_data[combined_meta$truncated_idx]
  expect_true(all(
    truncated_values == combined_meta$lower_bound |
      truncated_values == combined_meta$upper_bound
  ))
})

test_that("ps_trunc warns when combining objects with different parameters", {
  draws_a <- runif(10, 0.05, 0.95)
  draws_b <- runif(10, 0.05, 0.95)

  # Mismatched truncation bounds: [0.2, 0.8] vs [0.3, 0.7].
  trunc_a <- ps_trunc(draws_a, method = "ps", lower = 0.2, upper = 0.8)
  trunc_b <- ps_trunc(draws_b, method = "ps", lower = 0.3, upper = 0.7)

  # The two sets of metadata cannot be reconciled, so c() warns and falls
  # back to a plain double vector.
  expect_propensity_warning(combined <- c(trunc_a, trunc_b))
  expect_type(combined, "double")
  expect_false(inherits(combined, "ps_trunc"))
})

test_that("ps_trunc index tracking works with subsetting and combining", {
  set.seed(789)
  ps <- runif(20, 0.05, 0.95)

  full_obj <- ps_trunc(ps, method = "ps", lower = 0.3, upper = 0.7)
  full_meta <- ps_trunc_meta(full_obj)

  # Split into two halves, then join them back in the original order.
  halves <- list(full_obj[1:10], full_obj[11:20])
  recombined <- c(halves[[1]], halves[[2]])
  expect_s3_class(recombined, "ps_trunc")

  # The round trip keeps the same number of truncated units.
  expect_equal(
    length(ps_trunc_meta(recombined)$truncated_idx),
    length(full_meta$truncated_idx)
  )

  # Positions that hold a bound value must match before and after.
  at_bound <- function(v) {
    which(v == full_meta$lower_bound | v == full_meta$upper_bound)
  }
  expect_equal(at_bound(vec_data(recombined)), at_bound(vec_data(full_obj)))
})

test_that("ps_trunc handles multiple combines correctly", {
  set.seed(321)

  # Three independent draws, each truncated with the same bounds.
  draws <- lapply(seq_len(3), function(i) runif(5, 0.05, 0.95))
  pieces <- lapply(draws, ps_trunc, method = "ps", lower = 0.25, upper = 0.75)

  combined <- do.call(c, pieces)
  expect_s3_class(combined, "ps_trunc")
  expect_equal(length(combined), 15)

  # Every index recorded as truncated must hold one of the two bound values.
  meta <- ps_trunc_meta(combined)
  values <- vec_data(combined)[meta$truncated_idx]
  on_bound <- values == meta$lower_bound | values == meta$upper_bound
  expect_true(all(on_bound))
})

# Try the propensity package in your browser
#
# Any scripts or data that you put into this service are public.
#
# propensity documentation built on March 3, 2026, 1:06 a.m.