# Shared fixtures: two linear models for Sepal.Length on iris. The "small"
# model uses only Petal.Width; the "large" one adds Species and Sepal.Width.
fit1 <- stats::lm(Sepal.Length ~ Petal.Width, data = iris)
fit2 <- stats::lm(
  Sepal.Length ~ Petal.Width + Species + Sepal.Width,
  data = iris
)

# Wrap each model in a flashlight with the same data and response, then
# bundle both into a multiflashlight.
fl1 <- flashlight(
  model = fit1, label = "small", data = iris, y = "Sepal.Length"
)
fl2 <- flashlight(
  model = fit2, label = "large", data = iris, y = "Sepal.Length"
)
fls <- multiflashlight(list(fl1, fl2))
test_that("basic functionality works", {
  res <- light_breakdown(fl1, iris[1L, ])
  parts <- res$data

  # fl1's model only uses Petal.Width, so every step except the one that
  # visits it (row 2 with the default visit_strategy) must leave the
  # prediction unchanged.
  expect_equal(parts$before_[-2L], parts$after_[-2L])

  # Contribution of Petal.Width to the prediction for the first iris row.
  pw_step <- parts$after_[2L] - parts$before_[2L]
  expect_equal(pw_step, -0.8879879, tolerance = 0.001)

  expect_s3_class(plot(res), "ggplot")
})
test_that("light_breakdown reacts on v", {
  # Restricting v to a single variable shrinks the breakdown to three rows
  # (presumably baseline, the Petal.Width step and the prediction -- verify).
  bd <- light_breakdown(fl1, iris[1L, ], v = "Petal.Width")
  expect_equal(nrow(bd$data), 3L)
})
test_that("light_breakdown reacts on visit_strategy v", {
  res <- light_breakdown(fl1, iris[1L, ], visit_strategy = "v")
  parts <- res$data

  # With visit_strategy = "v", the Petal.Width step lands in row 4; every
  # other step must leave the prediction unchanged.
  expect_equal(parts$before_[-4L], parts$after_[-4L])

  # The contribution itself does not depend on the visit order for fl1.
  pw_step <- parts$after_[4L] - parts$before_[4L]
  expect_equal(pw_step, -0.8879879, tolerance = 0.001)

  expect_s3_class(plot(res), "ggplot")
})
# The description names the strategy actually exercised ("permutation");
# it previously said "shap", which did not match the call below.
test_that("light_breakdown reacts on visit_strategy permutation", {
  # A fixed seed keeps the permutation-based visit order reproducible.
  res <- light_breakdown(
    fl1, iris[1L, ],
    visit_strategy = "permutation", seed = 1L
  )
  parts <- res$data

  # With this seed the Petal.Width step lands in row 4; all other steps
  # must leave the prediction unchanged because fl1 only uses Petal.Width.
  expect_equal(parts$before_[-4L], parts$after_[-4L])

  pw_step <- parts$after_[4L] - parts$before_[4L]
  expect_equal(pw_step, -0.8879879, tolerance = 0.001)

  expect_s3_class(plot(res), "ggplot")
})
test_that("light_breakdown reacts on weights", {
  # Re-wrap fl1 with Petal.Length as case weights.
  weighted_fl <- flashlight(fl1, w = "Petal.Length")
  res <- light_breakdown(weighted_fl, iris[1L, ])
  parts <- res$data

  # Only the Petal.Width step (row 2) may move the prediction.
  expect_equal(parts$before_[-2L], parts$after_[-2L])

  # Weighting changes the expected Petal.Width contribution vs. the
  # unweighted case.
  pw_step <- parts$after_[2L] - parts$before_[2L]
  expect_equal(pw_step, -1.192293, tolerance = 0.001)

  expect_s3_class(plot(res), "ggplot")
})
test_that("light_breakdown reacts on by", {
  # Re-wrap fl1 so results are computed within Species groups.
  grouped_fl <- flashlight(fl1, by = "Species")
  res <- light_breakdown(grouped_fl, iris[1L, ])
  parts <- res$data

  # Only the Petal.Width step (row 2) may move the prediction.
  expect_equal(parts$before_[-2L], parts$after_[-2L])

  # Grouping by Species changes the expected Petal.Width contribution.
  pw_step <- parts$after_[2L] - parts$before_[2L]
  expect_equal(pw_step, -0.04087469, tolerance = 0.001)

  expect_s3_class(plot(res), "ggplot")
})
test_that("light_breakdown reacts on multiflashlight", {
  res <- light_breakdown(fls, iris[1L, ])
  parts <- res$data

  # One block of rows per model, both of equal size.
  n_per_model <- 6L
  expect_equal(nrow(parts), n_per_model * 2L)

  first <- seq_len(n_per_model)
  second <- n_per_model + first

  # Compare only the numeric step values of the two models. Comparing whole
  # rows with `==` (as before) also compared the per-model label column,
  # which always differs, so that check could never fail. It could also
  # yield NA if any cell were NA.
  expect_false(
    isTRUE(all.equal(parts$after_[first], parts$after_[second]))
  )

  expect_s3_class(plot(res), "ggplot")
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.