tests/testthat/tests-breakdown.R

fit1 <- stats::lm(Sepal.Length ~ Petal.Width, data = iris)
fit2 <- stats::lm(Sepal.Length ~ Petal.Width + Species + Sepal.Width, data = iris)
fl1 <- flashlight(model = fit1, label = "small", data = iris, y = "Sepal.Length")
fl2 <- flashlight(model = fit2, label = "large", data = iris, y = "Sepal.Length")
fls <- multiflashlight(list(fl1, fl2))

test_that("basic functionality works", {
  br <- light_breakdown(fl1, iris[1L, ])
  dat <- br$data
  expect_equal(dat$before_[-2L], dat$after_[-2L])
  expect_equal(dat$after_[2L] - dat$before_[2L], -0.8879879, tolerance = 0.001)
  expect_s3_class(plot(br), "ggplot")
})

test_that("light_breakdown reacts on v", {
  dat <- light_breakdown(fl1, iris[1L, ], v = "Petal.Width")$data
  expect_equal(nrow(dat), 3L)
})

test_that("light_breakdown reacts on visit_strategy v", {
  br <- light_breakdown(fl1, iris[1L, ], visit_strategy = "v")
  dat <- br$data
  expect_equal(dat$before_[-4L], dat$after_[-4L])
  expect_equal(dat$after_[4L] - dat$before_[4L], -0.8879879, tolerance = 0.001)
  expect_s3_class(plot(br), "ggplot")
})

test_that("light_breakdown reacts on visit_strategy shap", {
  br <- light_breakdown(fl1, iris[1L, ], visit_strategy = "permutation", seed = 1L)
  dat <- br$data
  expect_equal(dat$before_[-4L], dat$after_[-4L])
  expect_equal(dat$after_[4L] - dat$before_[4L], -0.8879879, tolerance = 0.001)
  expect_s3_class(plot(br), "ggplot")
})

test_that("light_breakdown reacts on weights", {
  br <- light_breakdown(flashlight(fl1, w = "Petal.Length"), iris[1L, ])
  dat <- br$data
  expect_equal(dat$before_[-2L], dat$after_[-2L])
  expect_equal(dat$after_[2L] - dat$before_[2L], -1.192293, tolerance = 0.001)
  expect_s3_class(plot(br), "ggplot")
})

test_that("light_breakdown reacts on by", {
  br <- light_breakdown(flashlight(fl1, by = "Species"), iris[1L, ])
  dat <- br$data
  expect_equal(dat$before_[-2L], dat$after_[-2L])
  expect_equal(dat$after_[2L] - dat$before_[2L], -0.04087469, tolerance = 0.001)
  expect_s3_class(plot(br), "ggplot")
})

test_that("light_breakdown reacts on multiflashlight", {
  br <- light_breakdown(fls, iris[1L, ])
  dat <- br$data
  expect_equal(nrow(dat), 6L * 2L)
  expect_false(all(dat[1:6, ] == dat[7:12, ]))
  expect_s3_class(plot(br), "ggplot")
})
mayer79/flashlight documentation built on July 26, 2024, 6:35 a.m.