tests/testthat/test-att_gt.R

#library(ggplot2)
#library(ggpubr)


## -----------------------------------------------------------------------------
#-----------------------------------------------------------------------------
# test each estimation method with panel data
# Expected results: treatment effects = 1, p-value for pre-test
# uniformly distributed, ipw model is incorrectly specified here
#-----------------------------------------------------------------------------
test_that("att_gt works w/o dynamics, time effects, or group effects", {
  # Panel data with sp$ipw switched off; only the dr and reg estimators are
  # checked against the simulated treatment effect of 1.
  set.seed(09142024)
  sp <- did::reset.sim()
  sp$ipw <- FALSE
  data <- did::build_sim_dataset(sp)

  # doubly robust
  res_dr <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                   gname="G", est_method="dr")
  # outcome regression
  res_reg <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                    gname="G", est_method="reg")

  # `tolerance` is spelled out: `tol` only worked through partial matching
  # (edition 2) or a deprecated alias (edition 3 / waldo).
  expect_equal(res_dr$att[1], 1, tolerance=.5)
  expect_equal(res_reg$att[1], 1, tolerance=.5)
})


test_that("att_gt works using ipw", {
  # Simulation with sp$reg switched off. The dr fit uses X; the ipw fit uses
  # no covariates (xformla = ~1).
  set.seed(09142024)
  sp <- did::reset.sim()
  sp$reg <- FALSE
  data <- did::build_sim_dataset(sp)

  # dr
  res_dr <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                   gname="G", est_method="dr")

  # ipw
  res_ipw <- att_gt(yname="Y", xformla=~1, data=data, tname="period", idname="id",
                    gname="G", est_method="ipw")

  expect_equal(res_dr$att[1], 1, tolerance=.5)
  expect_equal(res_ipw$att[1], 1, tolerance=.5)
})

test_that("two period case", {
  # Only two time periods; every aggregation type should still recover the
  # simulated effect of 1.
  set.seed(09142024)
  sp <- did::reset.sim(time.periods=2)
  sp$ipw <- FALSE
  sp$n <- 10000
  data <- did::build_sim_dataset(sp)

  # warnings raised in this minimal setting are not what is under test
  res <- suppressWarnings(
    att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
           gname="G", est_method="reg")
  )
  # A bare `res` is a no-op inside test_that(); assert on the object instead.
  expect_s3_class(res, "MP")

  agg_simple <- suppressWarnings(aggte(res, type="simple"))
  agg_group <- suppressWarnings(aggte(res, type="group"))
  agg_dynamic <- suppressWarnings(aggte(res, type="dynamic"))
  agg_calendar <- suppressWarnings(aggte(res, type="calendar"))

  expect_equal(agg_simple$overall.att, 1, tolerance=.5)
  expect_equal(agg_group$overall.att, 1, tolerance=.5)
  expect_equal(agg_dynamic$overall.att, 1, tolerance=.5)
  expect_equal(agg_calendar$overall.att, 1, tolerance=.5)
})

test_that("no covariates case", {
  set.seed(09142024)
  time.periods <- 4
  sp <- did::reset.sim(time.periods=time.periods)

  # no effect of covariates: zero out both covariate coefficient vectors
  sp$bett <- sp$betu <- rep(0,time.periods)
  data <- did::build_sim_dataset(sp)

  res_dr <- att_gt(yname="Y", xformla=~1, data=data, tname="period", idname="id",
                   gname="G", est_method="dr")

  res_reg <- att_gt(yname="Y", xformla=~1, data=data, tname="period", idname="id",
                    gname="G", est_method="reg")

  expect_equal(res_dr$att[1], 1, tolerance=.5)
  expect_equal(res_reg$att[1], 1, tolerance=.5)
})

test_that("repeated cross section", {
  # Repeated cross-section data (panel = FALSE in both simulation and fit).
  set.seed(09142024)
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp, panel=FALSE)

  # dr
  res_dr <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                   gname="G", est_method="dr", panel=FALSE)

  # reg
  res_reg <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                    gname="G", est_method="reg", panel=FALSE)

  expect_equal(res_dr$att[1], 1, tolerance=.5)
  expect_equal(res_reg$att[1], 1, tolerance=.5)
})


test_that("ipw repeated cross sections", {
  set.seed(09142024)
  sp <- did::reset.sim()
  sp$reg <- FALSE
  sp$n <- 20000 # these are noisy, so use a larger sample
  data <- did::build_sim_dataset(sp, panel=FALSE)

  # dr
  res_dr <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                   gname="G", est_method="dr", panel=FALSE)

  # ipw
  res_ipw <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                    gname="G", est_method="ipw", panel=FALSE)

  expect_equal(res_dr$att[1], 1, tolerance=.5)
  expect_equal(res_ipw$att[1], 1, tolerance=.5)
})


test_that("repeated cross sections dynamic effects", {
  # Dynamic effects (te.e = 1:time.periods) with repeated cross sections;
  # checks the event-study aggregate at e = 2.
  set.seed(09142024)
  time.periods <- 4
  sp <- did::reset.sim(time.periods=time.periods)
  sp$te.e <- 1:time.periods
  data <- did::build_sim_dataset(sp, panel=FALSE)

  # dr
  res_dr <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                   gname="G", est_method="dr", panel=FALSE)

  agg_dynamic <- aggte(res_dr, type="dynamic")
  agg_idx <- agg_dynamic$egt==2

  expect_equal(agg_dynamic$att.egt[agg_idx], 3, tolerance=.5)
})

test_that("unbalanced panel", {
  set.seed(09142024)
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)
  # drop second row to create unbalanced panel
  data <- data[-2,]

  # dr
  res_dr <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                   gname="G", est_method="dr", allow_unbalanced_panel=TRUE)

  expect_equal(res_dr$att[1], 1, tolerance=.5)

  # same data without allow_unbalanced_panel: a warning is expected
  expect_warning(att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                        gname="G", est_method="dr", allow_unbalanced_panel=FALSE))

  # ipw version
  set.seed(09142024)
  sp <- did::reset.sim()
  sp$reg <- FALSE
  data <- did::build_sim_dataset(sp)
  data <- data[-2,]

  res_ipw <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                    gname="G", est_method="ipw", allow_unbalanced_panel=TRUE)

  # assert on the ipw fit itself (this previously re-checked res_dr, which
  # left res_ipw untested)
  expect_equal(res_ipw$att[1], 1, tolerance=.5)

  # unbalanced panel without providing id, should error
  set.seed(09142024)
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)
  data <- data[sample(1:nrow(data),  size=floor(.9*nrow(data))),]

  expect_error(att_gt(yname="Y", xformla=~X, data=data, tname="period", idname=NULL,
                      gname="G", est_method="reg", panel=TRUE, allow_unbalanced_panel=TRUE))
})

test_that("not yet treated comparison group", {
  set.seed(09142024)
  sp <- did::reset.sim()
  sp$reg <- FALSE
  data <- did::build_sim_dataset(sp, panel=FALSE)

  # dr
  res <- att_gt(yname="Y", xformla=~X, data=data, tname="period",
                control_group="notyettreated",
                gname="G", est_method="dr", panel=FALSE)

  expect_equal(res$att[1], 1, tolerance=.5)

  # no never treated group
  set.seed(09142024)
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)
  data <- subset(data, G > 0) # drop nevertreated

  # dr
  res <- att_gt(yname="Y", xformla=~X, data=data, tname="period",
                control_group="notyettreated",
                gname="G", est_method="dr", panel=FALSE)
  expect_equal(res$att[1], 1, tolerance=.5)


  # try to use never treated group as comparison group, should warn
  expect_warning(nonev_orig <- att_gt(yname="Y", xformla=~X, data=data, tname="period",
                      control_group="nevertreated",
                      gname="G", est_method="dr", panel=FALSE, faster_mode = FALSE))

  # try to use never treated group as comparison group with faster mode, should warn
  expect_warning(nonev_faster <- att_gt(yname="Y", xformla=~X, data=data, tname="period",
                        control_group="nevertreated",
                        gname="G", est_method="dr", panel=FALSE, faster_mode = TRUE))

  # make sure both methods give same ATT(g,t) with no never treated group
  expect_equal(nonev_orig$att, nonev_faster$att)

})

test_that("aggregations", {
  set.seed(09142024)
  # dynamic effects: te = 0 and te.e = 1:time.periods. The other tests using
  # this DGP ("unequally spaced groups", "min and max length of exposures")
  # expect an effect of te.e[e + 1] at event time e, i.e. 3 at e = 2.
  time.periods <- 4
  sp <- did::reset.sim(time.periods=time.periods)
  sp$te <- 0
  sp$te.e <- 1:time.periods
  data <- did::build_sim_dataset(sp)

  res <- att_gt(yname="Y", xformla=~X, data=data, tname="period",
                control_group="nevertreated",
                gname="G", est_method="reg", panel=FALSE)

  agg_dynamic <- aggte(res, type="dynamic")
  agg_idx <- agg_dynamic$egt==2

  # e = 2 -> te.e[3] = 3 (the old expectation of 2 only passed because a
  # relative tolerance of .5 also accepts estimates near 3)
  expect_equal(agg_dynamic$att.egt[agg_idx], 3, tolerance=.5)


  # group effects
  set.seed(09142024)
  time.periods <- 4
  sp <- did::reset.sim(time.periods=time.periods)
  sp$te <- 0
  sp$te.bet.ind <- 1:time.periods
  sp$reg <- FALSE
  data <- did::build_sim_dataset(sp, panel=FALSE)

  res <- att_gt(yname="Y", xformla=~X, data=data, tname="period",
                control_group="notyettreated",
                gname="G", est_method="ipw", panel=FALSE)

  agg_group <- aggte(res, type="group")

  expect_equal(agg_group$att.egt[2], 2*2, tolerance=.5)


  # calendar time effects
  set.seed(09142024)
  time.periods <- 4
  sp <- did::reset.sim(time.periods=time.periods)
  sp$te <- 0
  sp$te.t <- sp$thet + 1:time.periods
  data <- did::build_sim_dataset(sp, panel=FALSE)

  res <- att_gt(yname="Y", xformla=~X, data=data, tname="period",
                control_group="nevertreated",
                gname="G", est_method="dr", panel=FALSE)

  agg_calendar <- aggte(res, type="calendar")
  expect_equal(agg_calendar$att.egt[2], 2, tolerance=.5)


  # balancing with respect to event time
  set.seed(09142024)
  # pass time.periods explicitly so the te.e / te.bet.ind vectors below are
  # guaranteed to match the simulated number of periods
  sp <- did::reset.sim(time.periods=time.periods)
  sp$te <- 0
  sp$te.e <- 1:time.periods
  sp$te.bet.ind <- 1:time.periods
  data <- did::build_sim_dataset(sp)

  res <- att_gt(yname="Y", xformla=~X, data=data, tname="period",
                control_group="nevertreated",
                gname="G", est_method="dr", panel=FALSE)

  # unbalanced event-study aggregation: only required to run
  agg_dynamic <- aggte(res, type="dynamic")
  agg_dynamic_balance <- aggte(res, type="dynamic", balance_e=1)

  # look up e = 0 and e = 1 by value rather than assuming they sit at
  # adjacent positions of the egt vector
  e0_idx <- which(agg_dynamic_balance$egt == 0)
  e1_idx <- which(agg_dynamic_balance$egt == 1)

  expect_equal(agg_dynamic_balance$att.egt[e1_idx] - agg_dynamic_balance$att.egt[e0_idx],
               1, tolerance=.5)
})

test_that("unequally spaced groups", {
  set.seed(09142024)
  time.periods <- 8
  sp <- did::reset.sim(time.periods=time.periods)
  sp$te <- 0
  sp$te.e <- 1:time.periods
  data <- did::build_sim_dataset(sp)
  # keep only groups/periods 1, 2, 5, 7 so time periods are unequally spaced
  keep.periods <- c(1,2,5,7)
  data <- subset(data, G %in% c(0, keep.periods))
  data <- subset(data, period %in% keep.periods)

  res <- att_gt(yname="Y", xformla=~X, data=data, tname="period",
                control_group="nevertreated",
                gname="G", est_method="reg", panel=FALSE)

  agg_dynamic <- aggte(res, type="dynamic")
  agg_idx <- agg_dynamic$egt==2

  expect_equal(agg_dynamic$att.egt[agg_idx], 3, tolerance=.5)

  agg_dynamic_balance <- aggte(res, type="dynamic", balance_e=0)
  agg_idx2 <- which(agg_dynamic_balance$egt==0)
  expect_equal(agg_dynamic_balance$att.egt[agg_idx2], 1, tolerance=.5)
})

test_that("some units treated in first period", {
  # After discarding period 1, some units are already treated in the first
  # remaining period (per the test name); att_gt is expected to warn.
  set.seed(09142024)
  sim_params <- did::reset.sim()
  sim_data <- did::build_sim_dataset(sim_params)
  late_data <- subset(sim_data, period >= 2)

  expect_warning(
    att_gt(
      yname = "Y", xformla = ~X, data = late_data, tname = "period",
      gname = "G", control_group = "nevertreated",
      est_method = "reg", panel = FALSE
    )
  )
})

test_that("min and max length of exposures", {
  set.seed(09142024)
  # define the horizon first and pass it to reset.sim() so the te.e / bett /
  # betu vectors below always match the simulated number of periods (the
  # old code relied on reset.sim()'s default happening to equal 4)
  time.periods <- 4
  sp <- did::reset.sim(time.periods=time.periods)
  sp$te <- 0
  sp$te.e <- 1:time.periods
  sp$bett <- sp$betu <- rep(0,time.periods)
  data <- did::build_sim_dataset(sp)

  res <- att_gt(yname="Y", xformla=~1, data=data, tname="period",
                idname="id",
                control_group="nevertreated",
                gname="G", est_method="reg", panel=TRUE,
                base_period="varying")

  # restrict the event-study aggregation to e in [-1, 1]
  agg_dynamic <- aggte(res, type="dynamic", min_e=-1, max_e=1)
  agg_idx <- which(agg_dynamic$egt == 1)
  expect_equal(agg_dynamic$att.egt[agg_idx], 2, tolerance=.5)
})


test_that("anticipation", {
  set.seed(09142024)
  time.periods <- 5
  sp <- did::reset.sim(time.periods=time.periods)
  sp$te <- 0
  sp$te.e <- -1:(time.periods-2)
  data <- did::build_sim_dataset(sp)
  data$G <- ifelse(data$G==0, 0, data$G + 1) # add anticipation
  data <- subset(data, G <= time.periods) # drop last period (due to way data is constructed)
  # this will have an anticipation effect=-1, no effect at exposure,
  # and treatment effects increasing by one in subsequent periods

  res <- att_gt(yname="Y", xformla=~X, data=data, tname="period",
                idname="id",
                control_group="nevertreated",
                gname="G", est_method="dr",
                anticipation=1
                )

  agg_dynamic <- aggte(res, type="dynamic")
  agg_idx <- which(agg_dynamic$egt==2)
  expect_equal(agg_dynamic$att.egt[agg_idx], 2, tolerance=.5)

  # incorrectly ignore anticipation
  # causes over-stating treatment effects
  # due to using incorrect base-period
  res <- att_gt(yname="Y", xformla=~X, data=data, tname="period",
                idname="id",
                control_group="nevertreated",
                gname="G", est_method="dr",
                anticipation=0
                )

  agg_dynamic <- aggte(res, type="dynamic")
  agg_idx <- which(agg_dynamic$egt==2)
  expect_equal(agg_dynamic$att.egt[agg_idx], 3, tolerance=.5)


  # test for previous bug when using anticipation and
  # a notyettreated comparison group
  data <- subset(data, G != 0)
  res <- att_gt(yname="Y", xformla=~X, data=data, tname="period",
                idname="id",
                control_group="notyettreated",
                gname="G", est_method="dr",
                anticipation=1
                )

  agg_dynamic <- aggte(res, type="dynamic")
  agg_idx <- which(agg_dynamic$egt==0)
  # expected value is 0, so the tolerance effectively acts as an absolute bound
  expect_equal(agg_dynamic$att.egt[agg_idx], 0, tolerance=.3)
})

test_that("significance level and uniform confidence bands", {
  set.seed(09142024)
  sim_params <- did::reset.sim()
  sim_data <- did::build_sim_dataset(sim_params)

  # Fit the dr estimator, re-seeding immediately beforehand so every fit
  # starts from the same RNG state.
  fit_dr <- function(...) {
    set.seed(1234)
    att_gt(yname="Y", xformla=~X, data=sim_data, tname="period", idname="id",
           gname="G", est_method="dr", ...)
  }

  # Upper end of the band for the first ATT(g,t): att + critical value * se.
  upper_end <- function(fit) fit$att[1] + fit$c * fit$se[1]

  upper_uniform_05 <- upper_end(fit_dr(alp=0.05))
  upper_uniform_01 <- upper_end(fit_dr(alp=0.01))
  upper_pointwise_05 <- upper_end(fit_dr(alp=0.05, cband=FALSE))

  # a 1% band is wider than a 5% band ...
  expect_lte(upper_uniform_05, upper_uniform_01)
  # ... and a uniform 5% band is wider than a pointwise 5% interval
  expect_gte(upper_uniform_05, upper_pointwise_05)

})

test_that("malformed data", {
  set.seed(09142024)

  # Case 1: simulate 7 periods but keep only the first 4, so some groups are
  # treated after the last observed period; additionally blank out G for 10
  # randomly chosen units. att_gt should warn.
  n_periods <- 7
  sim_params <- did::reset.sim(time.periods=n_periods)
  truncated <- did::build_sim_dataset(sim_params)
  truncated <- subset(truncated, period <= 4)
  ids_without_group <- sample(unique(truncated$id), size=10)
  truncated$G[truncated$id %in% ids_without_group] <- NA

  expect_warning(
    att_gt(yname="Y", xformla=~X, data=truncated, tname="period",
           idname="id", gname="G", est_method="dr")
  )


  #-----------------------------------------------------------------------------
  # Case 2: idname names a column that does not exist -> error.
  set.seed(09142024)
  sim_params <- did::reset.sim()
  clean_data <- did::build_sim_dataset(sim_params)

  expect_error(
    att_gt(yname="Y", xformla=~X, data=clean_data, tname="period",
           idname="brant", gname="G", est_method="dr")
  )
})

test_that("varying or universal base period", {
  set.seed(09142024)
  time.periods <- 8
  sp <- did::reset.sim(time.periods=time.periods)
  sp$te <- 0
  sp$te.e <- 1:time.periods
  data <- did::build_sim_dataset(sp)
  data <- subset(data, (G<=5) | G==0 )
  # add pre-treatment effects: recorded treatment dates are shifted 3
  # periods later than the simulated ones
  data$G <- ifelse(data$G==0, 0, data$G+3)


  # dr
  res_varying <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                        gname="G", est_method="dr", base_period="varying")

  agg_dynamic_varying <- aggte(res_varying, type="dynamic")

  res_universal <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                          gname="G", est_method="dr", base_period="universal")

  agg_dynamic_universal <- aggte(res_universal, type="dynamic")

  # look up e = -3 separately in each result rather than assuming both
  # aggregations share the same egt layout
  idx_varying <- which(agg_dynamic_varying$egt == -3)
  idx_universal <- which(agg_dynamic_universal$egt == -3)

  expect_equal(agg_dynamic_varying$att.egt[idx_varying], 1, tolerance=.5)
  expect_equal(agg_dynamic_universal$att.egt[idx_universal], -2, tolerance=.5)
})


test_that("small groups", {
  # code should still compute in this case (as comparison
  # group is large, but should give a warning about small groups)
  set.seed(09142024)
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)
  # keep only one unit from group 2
  G2_keep_id <- unique(subset(data, G==2)$id)[1]
  data <- subset(data, (G != 2) | (id == G2_keep_id))

  # dr
  expect_warning(res_dr <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
              gname="G", est_method="dr"), "very few observations")
  # reg
  expect_warning(res_reg <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
              gname="G", est_method="reg"), "very few observations")
  # ipw
  expect_warning(res_ipw <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
              gname="G", est_method="ipw"), "very few observations")

  # estimates will be imprecise for group 2, so check group 3 instead
  idx <- which(res_dr$group == 3 & res_dr$t==3)
  expect_equal(res_dr$att[idx], 1, tolerance=.5)
  expect_equal(res_reg$att[idx], 1, tolerance=.5)
})


test_that("small comparison group", {
  set.seed(09142024)
  # The never-treated group is cut down to a single unit. The code doesn't
  # run if the never-treated comparison group is used, but should run for
  # all groups except the last one when using the not-yet-treated
  # comparison group.
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)
  # keep only one unit from the untreated group
  G0_keep_id <- unique(subset(data, G==0)$id)[1]
  data <- subset(data, (G != 0) | (id == G0_keep_id))

  #-----------------------------------------------------------------------------
  # never treated comparison group
  #-----------------------------------------------------------------------------
  # dr
  expect_error(expect_warning(res_dr <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
              gname="G", est_method="dr"), "very few observations"), "never-treated group is too small")
  # reg
  expect_error(expect_warning(res_reg <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
              gname="G", est_method="reg"), "very few observations"), "never-treated group is too small")
  # ipw
  expect_error(expect_warning(res_ipw <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
              gname="G", est_method="ipw"), "very few observations"), "never-treated group is too small")

  #-----------------------------------------------------------------------------
  # not-yet-treated comparison group with faster_mode = TRUE
  #-----------------------------------------------------------------------------

  # dr
  expect_warning(res_dr <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", control_group="notyettreated",
                                               gname="G", est_method="dr", faster_mode = TRUE), "very few observations")
  # reg
  expect_warning(res_reg <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", control_group="notyettreated",
                                                gname="G", est_method="reg", faster_mode = TRUE), "very few observations")
  # ipw
  expect_warning(res_ipw <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", control_group="notyettreated",
                                   gname="G", est_method="ipw", faster_mode = TRUE), "very few observations")

  #-----------------------------------------------------------------------------
  # not-yet-treated comparison group
  #-----------------------------------------------------------------------------
  # dr
  expect_warning(res_dr <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", control_group="notyettreated",
              gname="G", est_method="dr", faster_mode = FALSE), "very few observations")
  # reg
  expect_warning(res_reg <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", control_group="notyettreated",
              gname="G", est_method="reg", faster_mode = FALSE), "singular or numerically ill-conditioned")
  # ipw
  expect_warning(res_ipw <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", control_group="notyettreated",
              gname="G", est_method="ipw", faster_mode = FALSE), "overlap condition violated")

  # code should still work for some (g,t)'s
  expect_equal(res_dr$att[1], 1, tolerance=.5)
  expect_equal(res_reg$att[1], 1, tolerance=.5)
  expect_equal(res_ipw$att[1], 1, tolerance=.5)

  #-----------------------------------------------------------------------------
  # aggregations
  #-----------------------------------------------------------------------------

  agg_dyn <- aggte(res_dr, type="dynamic", na.rm=TRUE)
  agg_group <- aggte(res_reg, type="group", na.rm=TRUE)

  expect_equal(agg_dyn$att.egt[3], 1, tolerance=.5)
  expect_equal(agg_group$att.egt[1], 1, tolerance=.5)

  # make sure that standard errors are computed too
  expect_false(is.na(agg_dyn$se.egt[3]))
  expect_false(is.na(agg_group$se.egt[1]))

  agg_cal <- aggte(res_ipw, type="calendar", na.rm=TRUE)
  expect_equal(agg_cal$att.egt[1], 1, tolerance=.5)
  expect_false(is.na(agg_cal$se.egt[1]))
})

test_that("custom estimation method", {
  # est_method may be a function; DRDID::drdid_imp_panel is passed directly.
  set.seed(09142024)
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)
  res <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id",
                gname="G", est_method=DRDID::drdid_imp_panel, panel=TRUE)
  expect_equal(res$att[1], 1, tolerance=.5)
})



test_that("sampling weights", {
  set.seed(09142024)
  # the idea here is that we can re-weight and should
  # get the same thing as if we subset: giving a unit weight 0 must be
  # equivalent to dropping it
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)
  all_ids <- unique(data$id)
  # Keep a strict subset of the ids. Sampling length(all_ids) ids (a full
  # permutation) would give every unit weight 1 and make the test vacuous.
  keepids <- sample(all_ids, size=floor(length(all_ids) / 2))
  data$w <- 1*(data$id %in% keepids) # weights shouldn't have to have mean/sum 1
  data2 <- subset(data, id %in% keepids)

  # analytical standard errors (bstrap = FALSE) are deterministic, so the two
  # fits can be compared without multiplier-bootstrap noise
  res_weights <- att_gt(yname="Y", xformla=~X, data=data, tname="period", idname="id", control_group="notyettreated",
                        gname="G", est_method="reg", weightsname="w", bstrap=FALSE)

  res_subset <- att_gt(yname="Y", xformla=~X, data=data2, tname="period", idname="id", control_group="notyettreated",
                       gname="G", est_method="reg", bstrap=FALSE)

  # test for same att's
  expect_equal(res_weights$att[1], res_subset$att[1])
  # test for same standard errors
  expect_equal(res_weights$se[1], res_subset$se[1], tolerance=1e-6)

})

# =============================================================================
# Column naming: user columns named gname/tname/idname should not crash
# =============================================================================

test_that("works when user column is literally named 'gname'", {
  set.seed(20260401)
  sim_params <- did::reset.sim()
  sim_data <- did::build_sim_dataset(sim_params)

  # Give the group/time/id columns exactly the names of att_gt's parameters.
  new_name_for <- c(G = "gname", period = "tname", id = "idname")
  to_rename <- names(sim_data) %in% names(new_name_for)
  names(sim_data)[to_rename] <- new_name_for[names(sim_data)[to_rename]]

  fit <- att_gt(yname="Y", xformla=~X, data=sim_data, tname="tname",
                idname="idname", gname="gname", est_method="reg",
                bstrap=FALSE)
  expect_false(all(is.na(fit$att)))

  # aggte should also work (this was the specific dreamerr bug)
  for (agg_type in c("simple", "dynamic")) {
    agg <- suppressWarnings(aggte(fit, type=agg_type))
    expect_false(is.na(agg$overall.att))
  }
})

test_that("works when user column is literally named 'gname' with faster_mode", {
  set.seed(20260401)
  sim_params <- did::reset.sim()
  sim_data <- did::build_sim_dataset(sim_params)

  # rename G/period/id to the literal parameter names, one column at a time
  renames <- c(G = "gname", period = "tname", id = "idname")
  for (old_name in names(renames)) {
    names(sim_data)[names(sim_data) == old_name] <- renames[[old_name]]
  }

  fast_fit <- att_gt(yname="Y", xformla=~X, data=sim_data, tname="tname",
                     idname="idname", gname="gname", est_method="reg",
                     bstrap=FALSE, faster_mode=TRUE)
  expect_false(all(is.na(fast_fit$att)))

  simple_agg <- aggte(fast_fit, type="simple")
  expect_false(is.na(simple_agg$overall.att))
})

# =============================================================================
# Time-varying weights: fix_weights tests
# =============================================================================

test_that("time-varying weights: faster_mode matches slow mode (default fix_weights=NULL)", {
  set.seed(20260401)
  sim_params <- did::reset.sim()
  sim_data <- did::build_sim_dataset(sim_params)
  # weight tracks the period (plus small noise), so it changes within a unit
  sim_data$tv_weight <- sim_data$period + runif(nrow(sim_data), -0.1, 0.1)

  # expand.grid varies its first column fastest, so rows run est_method
  # (outer) x base_period (inner)
  settings <- expand.grid(bp = c("varying", "universal"),
                          em = c("reg", "dr", "ipw"),
                          stringsAsFactors = FALSE)

  fit <- function(em, bp, fast) {
    att_gt(yname="Y", xformla=~X, data=sim_data, tname="period", idname="id",
           gname="G", est_method=em, weightsname="tv_weight",
           base_period=bp, faster_mode=fast, bstrap=FALSE)
  }

  for (row in seq_len(nrow(settings))) {
    em <- settings$em[row]
    bp <- settings$bp[row]
    slow_fit <- fit(em, bp, fast=FALSE)
    fast_fit <- fit(em, bp, fast=TRUE)
    expect_equal(slow_fit$att, fast_fit$att, tolerance=1e-10,
                 label=paste("ATT match:", em, bp))
  }
})

test_that("fix_weights options: faster_mode matches slow mode (balanced panel)", {
  set.seed(20260401)
  sim_params <- did::reset.sim()
  sim_data <- did::build_sim_dataset(sim_params)
  sim_data$tv_weight <- sim_data$period + runif(nrow(sim_data), -0.1, 0.1)

  # dr fit with time-varying weights for a given fix_weights / faster_mode
  fit_dr <- function(fw, fast) {
    att_gt(yname="Y", xformla=~X, data=sim_data, tname="period", idname="id",
           gname="G", est_method="dr", weightsname="tv_weight",
           fix_weights=fw, faster_mode=fast, bstrap=FALSE)
  }

  for (fw in c("varying", "base_period", "first_period")) {
    att_slow <- fit_dr(fw, fast=FALSE)$att
    att_fast <- fit_dr(fw, fast=TRUE)$att
    expect_equal(att_slow, att_fast, tolerance=1e-10,
                 label=paste("ATT match:", fw))
  }
})

test_that("time-invariant weights: all fix_weights options produce identical ATTs", {
  set.seed(20260401)
  sim_params <- did::reset.sim()
  sim_data <- did::build_sim_dataset(sim_params)
  # One weight per unit, repeated across that unit's rows. NOTE(review):
  # rep(each = ...) assumes rows are ordered by id and then period -- confirm.
  unit_count <- length(unique(sim_data$id))
  period_count <- length(unique(sim_data$period))
  sim_data$const_weight <- rep(runif(unit_count, 1, 10), each = period_count)

  fit_reg <- function(...) {
    att_gt(yname="Y", xformla=~X, data=sim_data, tname="period", idname="id",
           gname="G", est_method="reg", weightsname="const_weight",
           bstrap=FALSE, ...)
  }

  # default (fix_weights = NULL) is the reference
  reference_att <- fit_reg()$att

  for (fw in c("base_period", "first_period")) {
    alternative_att <- fit_reg(fix_weights=fw)$att
    expect_equal(reference_att, alternative_att, tolerance=1e-10,
                 label=paste("same ATT for", fw))
  }
})

test_that("message emitted for time-varying weights in balanced panel", {
  set.seed(20260401)
  sim_params <- did::reset.sim()
  panel_data <- did::build_sim_dataset(sim_params)
  # weight rises with the period (plus noise), so it varies within each unit
  panel_data$tv_weight <- panel_data$period * 1.0 + runif(nrow(panel_data), 0, 0.5)

  expect_message(
    att_gt(
      yname = "Y", xformla = ~X, data = panel_data,
      tname = "period", idname = "id", gname = "G",
      weightsname = "tv_weight", bstrap = FALSE
    ),
    regexp = "Time-varying weights detected"
  )
})

test_that("no message for time-invariant weights", {
  set.seed(20260401)
  sim_params <- did::reset.sim()
  panel_data <- did::build_sim_dataset(sim_params)
  # one constant weight per unit, repeated over its periods
  unit_count <- length(unique(panel_data$id))
  period_count <- length(unique(panel_data$period))
  unit_weights <- runif(unit_count, 1, 10)
  panel_data$const_weight <- rep(unit_weights, each = period_count)

  expect_no_message(
    att_gt(
      yname = "Y", xformla = ~X, data = panel_data,
      tname = "period", idname = "id", gname = "G",
      weightsname = "const_weight", bstrap = FALSE
    )
  )
})

test_that("notyettreated with time-varying weights: faster_mode matches", {
  set.seed(20260401)
  sim_params <- did::reset.sim()
  sim_data <- did::build_sim_dataset(sim_params)
  sim_data$tv_weight <- sim_data$period + runif(nrow(sim_data), 0, 0.5)

  # dr fit against the not-yet-treated comparison group
  fit_nyt <- function(fast) {
    att_gt(yname="Y", xformla=~X, data=sim_data, tname="period", idname="id",
           gname="G", est_method="dr", weightsname="tv_weight",
           control_group="notyettreated", faster_mode=fast, bstrap=FALSE)
  }

  slow_fit <- fit_nyt(fast=FALSE)
  fast_fit <- fit_nyt(fast=TRUE)

  expect_equal(slow_fit$att, fast_fit$att, tolerance=1e-10)
})

test_that("RC with time-varying weights: faster_mode matches", {
  set.seed(20260401)
  sim_params <- did::reset.sim()
  sim_data <- did::build_sim_dataset(sim_params)
  sim_data$tv_weight <- sim_data$period * 1.0 + runif(nrow(sim_data), 0, 0.5)

  # panel data deliberately analysed as repeated cross sections
  fit_rc <- function(fast) {
    att_gt(yname="Y", data=sim_data, tname="period", idname="id",
           gname="G", est_method="reg", weightsname="tv_weight",
           panel=FALSE, faster_mode=fast, bstrap=FALSE)
  }

  slow_fit <- fit_rc(fast=FALSE)
  fast_fit <- fit_rc(fast=TRUE)

  expect_equal(slow_fit$att, fast_fit$att, tolerance=1e-10)
})

test_that("fix_weights validation", {
  set.seed(20260401)
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)

  expect_error(
    att_gt(yname="Y", data=data, tname="period", idname="id",
           gname="G", fix_weights="invalid_option", bstrap=FALSE),
    "fix_weights must be NULL"
  )

  # base_period and first_period not supported for repeated cross sections
  expect_error(
    att_gt(yname="Y", data=data, tname="period", idname="id",
           gname="G", fix_weights="base_period", panel=FALSE, bstrap=FALSE),
    "not supported for repeated cross sections"
  )
  expect_error(
    att_gt(yname="Y", data=data, tname="period", idname="id",
           gname="G", fix_weights="first_period", panel=FALSE, bstrap=FALSE),
    "not supported for repeated cross sections"
  )

  # varying not supported with custom est_method when panel = TRUE
  my_panel_est <- function(y1, y0, D, covariates, i.weights, inffunc, ...) {
    list(ATT = mean(y1 - y0), att.inf.func = rep(0, length(y1)))
  }
  expect_error(
    att_gt(yname="Y", data=data, tname="period", idname="id",
           gname="G", fix_weights="varying", est_method=my_panel_est,
           panel=TRUE, bstrap=FALSE),
    "not currently supported with custom est_method"
  )

  # varying IS supported with custom est_method when panel = FALSE (RC signature)
  # The estimator returns a weighted post-minus-pre difference in control means
  # and a zero influence function of length n_obs.
  my_rc_est <- function(y, post, D, covariates, i.weights, inffunc, ...) {
    n_obs <- length(y)
    post_c <- post[D==0]
    y_c <- y[D==0]
    w_c <- i.weights[D==0]
    att <- mean(y_c[post_c==1] * w_c[post_c==1]) / mean(w_c[post_c==1]) -
           mean(y_c[post_c==0] * w_c[post_c==0]) / mean(w_c[post_c==0])
    list(ATT = att, att.inf.func = rep(0, n_obs))
  }
  # Wald pre-test warning is expected with this small sim dataset (group 2
  # has only one pre-treatment period), but the key thing is: no error and
  # no recycling warnings from mismatched influence-function length.
  rc_result <- expect_no_error(
    withCallingHandlers(
      att_gt(yname="Y", data=data, tname="period", idname="id",
             gname="G", fix_weights="varying", est_method=my_rc_est,
             panel=FALSE, bstrap=FALSE),
      warning = function(w) {
        # promote recycling warnings to errors; muffle everything else
        if (grepl("not a multiple of replacement length", conditionMessage(w)))
          stop("IF length mismatch: ", conditionMessage(w))
        invokeRestart("muffleWarning")
      }
    )
  )
  # expect_s3_class reports the actual class on failure, unlike
  # expect_true(inherits(...))
  expect_s3_class(rc_result, "MP")
  expect_false(anyNA(rc_result$att))
})

test_that("unbalanced panel fix_weights with units missing from reference period", {
  set.seed(20260401)
  sim_params <- did::reset.sim()
  sim_data <- did::build_sim_dataset(sim_params)

  # Remove the first-period row of the first ten treated units, so those
  # units have no observation in the reference period that
  # fix_weights = "first_period" would read their weight from.
  first_period <- min(sim_data$period)
  treated_ids <- unique(sim_data$id[sim_data$G > 0])
  ids_to_thin <- treated_ids[seq_len(10)]
  drop_row <- sim_data$id %in% ids_to_thin & sim_data$period == first_period
  sim_data <- sim_data[!drop_row, ]
  sim_data$w <- runif(nrow(sim_data), 1, 5)

  # warnings and messages from the unbalanced fit are not under test here
  quiet_fit <- function(fw, fast) {
    suppressWarnings(suppressMessages(
      att_gt(yname = "Y", data = sim_data, tname = "period", idname = "id",
             gname = "G", allow_unbalanced_panel = TRUE,
             fix_weights = fw, weightsname = "w",
             bstrap = FALSE, faster_mode = fast)
    ))
  }

  for (fw in c("first_period", "base_period")) {
    slow_fit <- quiet_fit(fw, fast = FALSE)
    fast_fit <- quiet_fit(fw, fast = TRUE)
    expect_equal(slow_fit$att, fast_fit$att, tolerance = 1e-10,
                 label = paste("unbalanced", fw, "ATT match"))
  }
})

# =============================================================================
# Influence function consistency: slow vs fast mode ATT AND SE must match
# =============================================================================

test_that("IF consistency: balanced panel, all fix_weights x est_method x base_period", {
  set.seed(20260401)
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)
  # time-varying weights: the period value plus a small uniform jitter
  data$tv_weight <- data$period + runif(nrow(data), -0.1, 0.1)

  # Fit a single configuration; only the four varying knobs are arguments.
  fit_att <- function(fix_weights, est_method, base_period, faster_mode) {
    att_gt(yname = "Y", xformla = ~X, data = data, tname = "period",
           idname = "id", gname = "G", est_method = est_method,
           weightsname = "tv_weight", fix_weights = fix_weights,
           base_period = base_period, faster_mode = faster_mode,
           bstrap = FALSE, cband = FALSE)
  }

  # NA in this vector stands for "leave fix_weights at NULL".
  for (fw in c(NA, "varying", "base_period", "first_period")) {
    weight_arg <- if (is.na(fw)) NULL else fw
    weight_tag <- if (is.null(weight_arg)) "NULL" else weight_arg
    for (em in c("dr", "ipw", "reg")) {
      for (bp in c("varying", "universal")) {
        tag <- paste("panel", weight_tag, em, bp)

        slow_fit <- fit_att(weight_arg, em, bp, faster_mode = FALSE)
        fast_fit <- fit_att(weight_arg, em, bp, faster_mode = TRUE)

        expect_equal(slow_fit$att, fast_fit$att, tolerance = 1e-10,
                     label = paste("ATT", tag))
        expect_equal(slow_fit$se, fast_fit$se, tolerance = 1e-10,
                     label = paste("SE", tag))
      }
    }
  }
})

test_that("IF consistency: balanced panel, notyettreated control group", {
  set.seed(20260401)
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)
  # time-varying weights: the period value plus a small uniform jitter
  data$tv_weight <- data$period + runif(nrow(data), -0.1, 0.1)

  # Iterating a list lets NULL (the default fix_weights) appear directly.
  for (weight_arg in list(NULL, "varying", "base_period", "first_period")) {
    tag <- paste("notyettreated",
                 if (is.null(weight_arg)) "NULL" else weight_arg)

    fits <- lapply(c(FALSE, TRUE), function(fast) {
      att_gt(yname = "Y", xformla = ~X, data = data, tname = "period",
             idname = "id", gname = "G", est_method = "dr",
             weightsname = "tv_weight", fix_weights = weight_arg,
             control_group = "notyettreated",
             faster_mode = fast, bstrap = FALSE, cband = FALSE)
    })
    slow_fit <- fits[[1]]
    fast_fit <- fits[[2]]

    expect_equal(slow_fit$att, fast_fit$att, tolerance = 1e-10,
                 label = paste("ATT", tag))
    expect_equal(slow_fit$se, fast_fit$se, tolerance = 1e-10,
                 label = paste("SE", tag))
  }
})

test_that("IF consistency: repeated cross-sections, default weights x est_method", {
  set.seed(20260401)
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)
  # time-varying weights: the period value plus a small uniform jitter
  data$tv_weight <- data$period + runif(nrow(data), -0.1, 0.1)

  # Only the default fix_weights (NULL) is used for RC data in this test.
  for (method in c("dr", "ipw", "reg")) {
    by_mode <- lapply(c(slow = FALSE, fast = TRUE), function(fm) {
      att_gt(yname = "Y", xformla = ~X, data = data, tname = "period",
             idname = "id", gname = "G", est_method = method,
             weightsname = "tv_weight",
             panel = FALSE, faster_mode = fm,
             bstrap = FALSE, cband = FALSE)
    })
    tag <- paste("RC NULL", method)

    expect_equal(by_mode$slow$att, by_mode$fast$att, tolerance = 1e-10,
                 label = paste("ATT", tag))
    expect_equal(by_mode$slow$se, by_mode$fast$se, tolerance = 1e-10,
                 label = paste("SE", tag))
  }
})

test_that("IF consistency: unbalanced panel, default weights x est_method", {
  set.seed(20260401)
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)

  # Unbalance the panel by deleting a random 5% of rows. The dedicated seed
  # makes the deleted rows independent of the draws above.
  set.seed(42)
  n_rows <- nrow(data)
  data_unbal <- data[-sample(n_rows, size = floor(n_rows * 0.05)), ]

  # Only default weights (fix_weights = NULL) are compared: fixed-weight
  # options on unbalanced panels have known edge cases with unit
  # availability across periods.
  # NOTE(review): est_method = "ipw" is not exercised here and no reason is
  # given -- confirm whether that omission is intentional.
  for (method in c("dr", "reg")) {
    tag <- paste("unbalanced NULL", method)

    by_mode <- lapply(c(FALSE, TRUE), function(fm) {
      att_gt(yname = "Y", xformla = ~X, data = data_unbal, tname = "period",
             idname = "id", gname = "G", est_method = method,
             allow_unbalanced_panel = TRUE,
             faster_mode = fm, bstrap = FALSE, cband = FALSE)
    })

    expect_equal(by_mode[[1]]$att, by_mode[[2]]$att, tolerance = 1e-10,
                 label = paste("ATT", tag))
    expect_equal(by_mode[[1]]$se, by_mode[[2]]$se, tolerance = 1e-10,
                 label = paste("SE", tag))
  }
})

test_that("IF consistency: no covariates (xformla=~1), all data types", {
  set.seed(20260401)
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)

  # Fit one configuration in both modes and require identical ATT and SE.
  # Extra arguments (fix_weights, panel) are forwarded through `...`.
  check_modes <- function(tag, ...) {
    slow_fit <- att_gt(yname = "Y", data = data, tname = "period",
                       idname = "id", gname = "G", ...,
                       faster_mode = FALSE, bstrap = FALSE, cband = FALSE)
    fast_fit <- att_gt(yname = "Y", data = data, tname = "period",
                       idname = "id", gname = "G", ...,
                       faster_mode = TRUE, bstrap = FALSE, cband = FALSE)
    expect_equal(slow_fit$att, fast_fit$att, tolerance = 1e-10,
                 label = paste("ATT", tag))
    expect_equal(slow_fit$se, fast_fit$se, tolerance = 1e-10,
                 label = paste("SE", tag))
  }

  # balanced panel, no covariates, default and "varying" weights
  for (weight_arg in list(NULL, "varying")) {
    weight_tag <- if (is.null(weight_arg)) "NULL" else weight_arg
    check_modes(paste("no-covar panel", weight_tag), fix_weights = weight_arg)
  }

  # repeated cross-sections, no covariates
  check_modes("RC no-covar", panel = FALSE)
})

test_that("clustered standard errors", {
  set.seed(09142024)
  # Clustered standard errors must be computable whether the cluster variable
  # is supplied as numeric or as a factor, and both encodings must agree.
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)

  data$cluster <- as.numeric(data$cluster)
  res_numeric <- att_gt(yname = "Y", xformla = ~X, data = data,
                        tname = "period", idname = "id", gname = "G",
                        est_method = "dr", clustervars = "cluster")

  data$cluster <- as.factor(data$cluster)
  res_factor <- att_gt(yname = "Y", xformla = ~X, data = data,
                       tname = "period", idname = "id", gname = "G",
                       est_method = "dr", clustervars = "cluster")

  # point estimates must not depend on the cluster variable's type
  expect_equal(res_factor$att[1], res_numeric$att[1])
  # SEs are only compared up to 2% relative tolerance -- presumably because
  # the default bootstrap draws differ between the two calls (confirm)
  expect_equal(res_factor$se[1], res_numeric$se[1], tolerance = .02)

  #-----------------------------------------------------------------------------
  # clustered standard errors with an unbalanced panel (one row dropped),
  # in both the slow and the faster code path
  data <- data[-3, ]
  for (fm in c(FALSE, TRUE)) {
    res_ub <- att_gt(yname = "Y", tname = "period", idname = "id", gname = "G",
                     xformla = ~X, data = data, panel = TRUE,
                     allow_unbalanced_panel = TRUE, faster_mode = fm,
                     clustervars = "cluster")
    expect_equal(res_ub$att[1], 1, tolerance = .5)
  }

  #-----------------------------------------------------------------------------
  # A clustering variable that varies within unit over time must be rejected
  # identically in both modes (the slow path used to accept this input and
  # fall back to i.i.d. SEs with bstrap = FALSE).
  set.seed(09142024)
  sp <- did::reset.sim()
  data <- did::build_sim_dataset(sp)
  data$cluster <- as.numeric(data$cluster)
  data$cluster[1] <- data$cluster[1] + 1

  for (fm in c(TRUE, FALSE)) {
    expect_error(
      att_gt(yname = "Y", xformla = ~X, data = data, tname = "period",
             idname = "id", control_group = "notyettreated", gname = "G",
             est_method = "dr", clustervars = "cluster", faster_mode = fm),
      "Time-varying cluster variables are not supported"
    )
  }

  #-----------------------------------------------------------------------------
  # clustered standard errors with repeated cross sections data
  data <- did::build_sim_dataset(sp, panel = FALSE)
  res_rc <- att_gt(yname = "Y", xformla = ~X, data = data, tname = "period",
                   idname = "id", control_group = "notyettreated",
                   gname = "G", est_method = "dr", clustervars = "cluster",
                   panel = FALSE)
  expect_equal(res_rc$att[1], 1, tolerance = .5)
})

test_that("faster mode enabled for panel data", {
  data <- did::mpdta

  # Fit mpdta with everything fixed except base period, control group and
  # mode.
  fit_mpdta <- function(base_period, control_group, faster_mode) {
    att_gt(yname = "lemp", gname = "first.treat", idname = "countyreal",
           tname = "year", xformla = ~1, data = data,
           bstrap = FALSE, cband = FALSE, base_period = base_period,
           control_group = control_group, est_method = "dr",
           faster_mode = faster_mode)
  }

  # (base_period, control_group) combinations, in the original order
  settings <- list(
    c("universal", "nevertreated"),
    c("varying",   "nevertreated"),
    c("varying",   "notyettreated"),
    c("universal", "notyettreated")
  )

  for (setting in settings) {
    slow_fit <- fit_mpdta(setting[[1]], setting[[2]], faster_mode = FALSE)
    fast_fit <- fit_mpdta(setting[[1]], setting[[2]], faster_mode = TRUE)

    # both modes must give the same estimates and standard errors
    expect_equal(slow_fit$att, fast_fit$att)
    expect_equal(slow_fit$se, as.numeric(fast_fit$se))
  }
})

test_that("faster mode enabled for repeated cross sectional data", {
  # fixed seed so the simulated data (and any failure) are reproducible
  set.seed(09142024)

  data_rcs <- as.data.table(
    did::build_sim_dataset(did::reset.sim(time.periods = 4, n = 1000),
                           panel = FALSE)
  )
  data_rcs$period <- as.integer(data_rcs$period)
  # never-treated units are coded as G = Inf instead of 0
  data_rcs[G == 0, G := Inf]

  # Fit the RC data with everything fixed except base period, control group
  # and mode.
  fit_rcs <- function(base_period, control_group, faster_mode) {
    att_gt(yname = "Y", gname = "G", tname = "period", xformla = ~1,
           data = data_rcs, panel = FALSE, bstrap = FALSE, cband = FALSE,
           base_period = base_period, control_group = control_group,
           est_method = "dr", faster_mode = faster_mode)
  }

  for (bp in c("universal", "varying")) {
    for (cg in c("nevertreated", "notyettreated")) {
      tag <- paste("RC", bp, cg)
      slow_fit <- fit_rcs(bp, cg, faster_mode = FALSE)
      fast_fit <- fit_rcs(bp, cg, faster_mode = TRUE)

      # both modes must give the same estimates and standard errors
      expect_equal(slow_fit$att, fast_fit$att, label = paste("ATT", tag))
      expect_equal(slow_fit$se, as.numeric(fast_fit$se),
                   label = paste("SE", tag))
    }
  }
})

test_that("faster mode enabled for unbalanced panel data and time-varying covariates", {

  set.seed(05202025)
  # balanced panel dimensions
  N  <- 1000        # individuals
  TT <- 6           # time periods (t = 1, ..., 6)
  # id-time grid; CJ() is a cross join, i.e. a balanced panel
  dt <- CJ(id = seq_len(N), t = seq_len(TT))

  # cohort = first treated period, 0 = never treated:
  # 40% treated in period 4, 30% in 5, 10% in 6, rest never treated
  cohort_vals  <- c(0, 4, 5, 6)
  cohort_probs <- c(1 - .40 - .30 - .10, .40, .30, .10)

  # one draw per id, so the cohort is time-invariant within a unit
  dt[, g := sample(cohort_vals, size = 1, prob = cohort_probs), by = id]

  # treatment indicator: D = 1 if t >= g and g > 0
  dt[, D := as.integer(t >= g & g > 0)]

  # four time-varying covariates
  dt[, `:=`(
    x1 = rnorm(.N),
    x2 = runif(.N, 0, 10) + 0.1 * t,
    x3 = rnorm(N)[id] + 0.2 * t + rnorm(.N),
    x4 = sin(2 * pi * t / TT) + rnorm(.N)
  )]

  # outcome model: unit effect + time trend + covariate effects + treatment
  # effect. Only meant to exercise the code paths -- not a well-defined DGP
  # for theoretical results.
  beta    <- c(1.5, -0.8, 0.4, 2.0)              # coefficients on x1-x4
  tau     <- 3                                   # true treatment effect
  alpha_i <- rnorm(N)[dt$id]                     # unit fixed effect
  gamma_t <- seq(-1, 1, length.out = TT)[dt$t]   # common time trend

  dt[, y := alpha_i + gamma_t +
       beta[1] * x1 + beta[2] * x2 + beta[3] * x3 + beta[4] * x4 +
       tau * D + rnorm(.N, sd = 2)]

  # Make the panel UNBALANCED: each row is kept independently with
  # probability 1 - drop_rate. This does NOT guarantee >= 1 row per id; an id
  # that loses every row simply does not appear in dt_unbal.
  drop_rate <- 0.25
  dt_unbal <- dt[, .SD[runif(.N) > drop_rate], by = id]

  # ------------------------------------------------------------------------
  # balanced panel: slow vs fast mode with covariates
  att_slower <- att_gt(
    yname       = "y",
    tname       = "t",
    idname      = "id",
    gname       = "g",
    xformla     = ~ x1 + x2 + x3 + x4,   # four time-varying covariates
    data        = dt,
    base_period = "universal",
    panel       = TRUE,
    faster_mode = FALSE,
    bstrap      = FALSE
  )

  att_faster <- att_gt(
    yname       = "y",
    tname       = "t",
    idname      = "id",
    gname       = "g",
    xformla     = ~ x1 + x2 + x3 + x4,   # four time-varying covariates
    data        = dt,
    panel       = TRUE,
    base_period = "universal",
    faster_mode = TRUE,
    bstrap      = FALSE
  )

  # both modes must give the same estimates and standard errors
  expect_equal(att_slower$att, att_faster$att)
  expect_equal(att_slower$se, as.numeric(att_faster$se))

  # event-study aggregation must agree as well
  out1 <- aggte(att_slower, type = "dynamic", cband = FALSE, bstrap = FALSE)
  out2 <- aggte(att_faster, type = "dynamic", cband = FALSE, bstrap = FALSE)

  expect_equal(out1$att.egt, out2$att.egt)
  expect_equal(out1$se.egt, as.numeric(out2$se.egt))

  # ------------------------------------------------------------------------
  # the same comparison on the unbalanced panel
  att_slow <- att_gt(
    yname       = "y",
    tname       = "t",
    idname      = "id",
    gname       = "g",
    xformla     = ~ x1 + x2 + x3 + x4,
    data        = dt_unbal,
    panel       = TRUE,
    allow_unbalanced_panel = TRUE,
    faster_mode = FALSE,
    bstrap      = FALSE
  )

  expect_message(att_fast <- att_gt(
    yname       = "y",
    tname       = "t",
    idname      = "id",
    gname       = "g",
    xformla     = ~ x1 + x2 + x3 + x4,
    data        = dt_unbal,
    panel       = TRUE,
    allow_unbalanced_panel = TRUE,
    faster_mode = TRUE,
    bstrap      = FALSE
  ), "unbalanced panel")

  expect_equal(att_slow$att, att_fast$att)
  expect_equal(att_slow$se, as.numeric(att_fast$se))

  out1 <- aggte(att_slow, type = "dynamic", cband = FALSE, bstrap = FALSE)
  out2 <- aggte(att_fast, type = "dynamic", cband = FALSE, bstrap = FALSE)

  expect_equal(out1$att.egt, out2$att.egt)
  expect_equal(out1$se.egt, out2$se.egt, tolerance = .0005)
})


test_that("faster_mode = TRUE matches baseline on filtered sim dataset with non-consecutive cohorts and time periods", {
  # simulate full panel
  set.seed(09142024)
  sp <- did::reset.sim()
  dt <- did::build_sim_dataset(sp)

  # Keep only two periods. build_sim_dataset() has periods 1:4 here, so pick
  # 2 & 4: cohorts -> [3, 4], periods -> [2, 4].
  dt2 <- dt[dt$period %in% c(2, 4), ]

  # One fit of the filtered data in the requested mode.
  fit_two_periods <- function(faster_mode) {
    att_gt(
      yname         = "Y",
      tname         = "period",
      idname        = "id",
      gname         = "G",
      data          = dt2,
      panel         = TRUE,
      control_group = "nevertreated",
      xformla       = NULL,
      est_method    = "dr",
      base_period   = "universal",
      faster_mode   = faster_mode
    )
  }

  # Each mode must warn on its own about dropping the units already treated
  # in the first retained period; a single expect_warning() around both
  # calls would let one call's warning stand in for the other.
  dropped_msg <- "Dropped 999 units that were already treated in the first period"
  expect_warning(res_slow <- fit_two_periods(FALSE), dropped_msg)
  expect_warning(res_fast <- fit_two_periods(TRUE), dropped_msg)

  # same length and (within tolerance) the same values
  expect_length(res_slow$att, 4)
  expect_length(res_fast$att, 4)
  expect_equal(res_slow$att, res_fast$att, tolerance = 1e-8)
})


#-----------------------------------------------------------------------------
# Regression tests for time indexing bug (Issue: faster_mode returns wrong time indices)
# These tests ensure that faster_mode=TRUE returns the same time periods (t field)
# as faster_mode=FALSE, not just the same ATT estimates
#-----------------------------------------------------------------------------

test_that("faster_mode time indexing matches baseline with repeated cross-sections", {
  set.seed(12345)

  # Repeated cross-section over the consecutive periods 1..n_periods, with
  # three staggered cohorts (first treated in periods 4, 6 and 8) and cell
  # sizes that vary by group and by period.
  n_groups <- 50
  n_periods <- 9

  df <- data.frame(
    g = rep(seq_len(n_groups), each = n_periods),
    t = rep(seq_len(n_periods), times = n_groups),
    gfe = rep(rnorm(n_groups, 0, 1), each = n_periods)
  )

  # time fixed effects (t doubles as an index into tfe_vec)
  tfe_vec <- 0.1 * rnorm(n_periods, 0, 1)
  df$tfe <- tfe_vec[df$t]

  # treatment cohorts, expressed as time values (0 = never treated)
  df$cohort <- 0
  df$cohort[df$g >= 21 & df$g <= 30] <- 4
  df$cohort[df$g >= 11 & df$g <= 20] <- 6
  df$cohort[df$g >= 1 & df$g <= 10] <- 8

  # untreated outcome; treated outcome adds a cohort-specific effect that
  # grows linearly from the treatment period on
  df$y0 <- 0.1 * df$g + df$gfe + 0.1 * df$t + df$tfe + rnorm(nrow(df), 0, 1)
  df$y1 <- df$y0

  treated_c4 <- df$cohort == 4 & df$t >= 4
  df$y1[treated_c4] <- df$y1[treated_c4] + 1.5 + 1.0 * (df$t[treated_c4] - 4)

  treated_c6 <- df$cohort == 6 & df$t >= 6
  df$y1[treated_c6] <- df$y1[treated_c6] + 1.0 + 0.7 * (df$t[treated_c6] - 6)

  treated_c8 <- df$cohort == 8 & df$t >= 8
  df$y1[treated_c8] <- df$y1[treated_c8] + 0.5 + 0.4 * (df$t[treated_c8] - 8)

  # Replicate rows to get unequal cell sizes: the replication factor tends to
  # grow with g and to shrink with t.
  set.seed(123)
  rep_by_g <- vapply(seq_len(n_groups),
                     function(j) floor(5 + 2 * (1.1)^j * runif(1)),
                     numeric(1))
  df <- df[rep(seq_len(nrow(df)), rep_by_g[df$g]), ]

  rep_by_t <- vapply(seq_len(n_periods),
                     function(j) floor(2 + 3 * (1.2)^(n_periods - j) * runif(1)),
                     numeric(1))
  df <- df[rep(seq_len(nrow(df)), rep_by_t[df$t]), ]

  # treatment indicator, observed outcome, one id per row (cross-sections)
  df$tx <- (df$cohort > 0) & (df$t >= df$cohort)
  df$y <- ifelse(df$tx, df$y1, df$y0)
  df$id <- seq_len(nrow(df))

  # estimate with both modes
  res_slow <- att_gt(
    yname = "y",
    tname = "t",
    gname = "cohort",
    idname = "id",
    data = df,
    est_method = "reg",
    panel = FALSE,
    bstrap = FALSE,
    cband = FALSE,
    faster_mode = FALSE
  )

  res_fast <- att_gt(
    yname = "y",
    tname = "t",
    gname = "cohort",
    idname = "id",
    data = df,
    est_method = "reg",
    panel = FALSE,
    bstrap = FALSE,
    cband = FALSE,
    faster_mode = TRUE
  )

  # time indexing must be identical (this was the bug)
  expect_equal(res_slow$t, res_fast$t,
               info = "Time periods (t field) should match between faster_mode=TRUE and FALSE")

  expect_equal(res_slow$group, res_fast$group,
               info = "Group values should match between faster_mode=TRUE and FALSE")

  expect_equal(res_slow$att, res_fast$att,
               info = "ATT estimates should match between faster_mode=TRUE and FALSE")

  expect_equal(res_slow$se, as.numeric(res_fast$se),
               info = "Standard errors should match between faster_mode=TRUE and FALSE")
})


test_that("faster_mode time indexing matches baseline with panel data and varying base period", {
  set.seed(54321)

  # mpdta: the package's county panel, indexed by `year` rather than 1..T
  data <- did::mpdta

  # Varying base period (the default), where the indexing bug was most
  # visible; only the mode differs between the two fits.
  run_mode <- function(fast) {
    att_gt(
      yname = "lemp", gname = "first.treat", idname = "countyreal",
      tname = "year", xformla = ~1, data = data,
      bstrap = FALSE, cband = FALSE,
      base_period = "varying", control_group = "nevertreated",
      est_method = "dr", faster_mode = fast
    )
  }
  res_slow <- run_mode(FALSE)
  res_fast <- run_mode(TRUE)

  # time periods must be identical, not merely numerically close
  expect_identical(res_slow$t, res_fast$t,
                   info = "Time periods must be identical between modes with varying base period")

  expect_identical(res_slow$group, res_fast$group)
  expect_equal(res_slow$att, res_fast$att)
  expect_equal(res_slow$se, as.numeric(res_fast$se))
})


test_that("faster_mode time indexing with non-consecutive time periods", {
  set.seed(99999)

  # Calendar periods with gaps: catches code that reports sequential indices
  # (1, 2, ...) where actual calendar times are required.
  time_periods <- c(2000, 2002, 2005, 2007, 2010)
  n_units <- 100

  # balanced panel
  df <- expand.grid(id = seq_len(n_units), year = time_periods)
  df$id <- as.integer(df$id)

  # cohorts: ids 1-30 first treated in 2005, ids 31-60 in 2007,
  # the rest never treated (cohort = 0)
  df$cohort <- 0
  df$cohort[df$id <= 30] <- 2005
  df$cohort[df$id > 30 & df$id <= 60] <- 2007

  # outcomes with a constant treatment effect of 2 once treated
  df$y0 <- df$id * 0.1 + rnorm(nrow(df), 0, 0.5)
  df$y1 <- df$y0
  treated <- df$cohort > 0 & df$year >= df$cohort
  df$y1[treated] <- df$y1[treated] + 2.0
  df$y <- ifelse(treated, df$y1, df$y0)

  # estimate with both modes
  res_slow <- att_gt(
    yname = "y",
    tname = "year",
    idname = "id",
    gname = "cohort",
    data = df,
    panel = TRUE,
    bstrap = FALSE,
    cband = FALSE,
    est_method = "reg",
    faster_mode = FALSE
  )

  res_fast <- att_gt(
    yname = "y",
    tname = "year",
    idname = "id",
    gname = "cohort",
    data = df,
    panel = TRUE,
    bstrap = FALSE,
    cband = FALSE,
    est_method = "reg",
    faster_mode = TRUE
  )

  # reported times must be actual calendar years, not indices
  expect_true(all(res_slow$t %in% time_periods),
              info = "Time values should be actual calendar years")
  expect_true(all(res_fast$t %in% time_periods),
              info = "Time values should be actual calendar years in faster_mode")

  expect_equal(res_slow$t, res_fast$t,
               info = "Time periods must match with non-consecutive calendar times")

  # With the (default) varying base period, the first calendar period only
  # serves as a base period, so the earliest reported time must be exactly
  # the second calendar period (2002) -- not 1, 2 or 2000.
  expect_equal(min(res_slow$t), time_periods[2],
               info = "With varying base period, earliest t should be the second calendar period")
  expect_equal(min(res_fast$t), time_periods[2],
               info = "With varying base period, earliest t should be the second calendar period in faster_mode")

  expect_equal(res_slow$group, res_fast$group)
  expect_equal(res_slow$att, res_fast$att)
  expect_equal(res_slow$se, as.numeric(res_fast$se))
})


test_that("faster_mode time indexing with universal base period", {
  set.seed(11111)

  # simulated panel with five periods, universal base period
  sp <- did::reset.sim(time.periods = 5)
  data <- did::build_sim_dataset(sp)

  # Fit in the requested mode; every other argument is shared.
  run_mode <- function(fast) {
    att_gt(
      yname = "Y", xformla = ~X, data = data,
      tname = "period", idname = "id", gname = "G",
      est_method = "dr", base_period = "universal",
      bstrap = FALSE, faster_mode = fast
    )
  }
  res_slow <- run_mode(FALSE)
  res_fast <- run_mode(TRUE)

  # time indexing must match
  expect_equal(res_slow$t, res_fast$t,
               info = "Time periods must match with universal base period")

  # the remaining components must agree as well
  expect_equal(res_slow$group, res_fast$group)
  expect_equal(res_slow$att, res_fast$att)
  expect_equal(res_slow$se, as.numeric(res_fast$se))
})

Try the did package in your browser

Any scripts or data that you put into this service are public.

did documentation built on June 13, 2026, 5:07 p.m.