# tests/testthat/tests-profile2d.R

# Reference model for the first tests: Sepal.Length explained by Species only.
fit <- stats::lm(Sepal.Length ~ Species, data = iris)
fl <- flashlight(
  model = fit,
  label = "lm",
  data = iris,
  y = "Sepal.Length"
)

test_that("light_profile2d works correctly for type predicted", {
  pr <- light_profile2d(
    fl, v = c("Species", "Petal.Width"), type = "predicted"
  )
  # The model depends on Species only, so the first row per Species must carry
  # the observed group mean of the response.
  first_per_species <- pr$data[!duplicated(pr$data$Species), ]
  group_means <- stats::aggregate(
    Sepal.Length ~ Species, data = iris, FUN = mean
  )
  expect_equal(first_per_species$value_, group_means$Sepal.Length)
  expect_s3_class(plot(pr), "ggplot")
})

test_that("predicted + residuals = response", {
  v <- c("Species", "Petal.Length")
  # Same binning for every type, so the cells line up row by row.
  profile_by_type <- function(type) {
    light_profile2d(fl, v = v, type = type, n_bins = 3)
  }
  pr_res <- profile_by_type("residual")
  pr_pred <- profile_by_type("predicted")
  pr_y <- profile_by_type("response")
  expect_equal(pr_res$data$value_ + pr_pred$data$value_, pr_y$data$value_)
  expect_s3_class(plot(pr_res), "ggplot")
})

test_that("partial dependence is constant if covariable not in model", {
  # The model uses Species only, so varying the petal measurements must not
  # move the partial dependence. Compare the variance to 0 with a tolerance
  # (and an informative failure message) instead of exact `==` on a double.
  pr <- light_profile2d(fl, v = c("Petal.Length", "Petal.Width"))
  expect_equal(stats::var(pr$data$value_), 0)
})

# From here on: model using all remaining iris columns as covariables.
fit <- stats::lm(Sepal.Length ~ ., data = iris)
fl <- flashlight(
  model = fit,
  label = "lm",
  data = iris,
  y = "Sepal.Length"
)

test_that("partial dependence is the same as ice", {
  # With a single observation (index 1), the partial dependence surface is
  # that observation's ICE surface. The variables are taken from the grid, so
  # no separate `v` is needed.
  a_grid <- expand.grid(list(Petal.Width = 1:2, Petal.Length = 2:3))
  pr <- light_profile2d(fl, pd_indices = 1, pd_grid = a_grid)$data
  pr <- pr[order(pr$Petal.Width, pr$Petal.Length), ]
  ice <- light_ice(fl, indices = 1, grid = a_grid)$data
  ice <- ice[order(ice$Petal.Width, ice$Petal.Length), ]
  # unname() drops any names on the ICE values before comparing.
  expect_equal(pr$value_, unname(ice$value_))
})

test_that("n_bins work approximately (one value)", {
  # Without quantile cuts n_bins is only a target, so just bound the grid size.
  v <- c("Petal.Width", "Petal.Length")
  n_cells <- nrow(light_profile2d(fl, v = v, n_bins = 3)$data)
  expect_lt(n_cells, 20L)
})

test_that("n_bins work exactly with quantile cuts", {
  v <- c("Petal.Width", "Petal.Length")
  pr <- light_profile2d(fl, v = v, n_bins = 3, cut_type = "quantile")
  # 3 bins per variable -> 3 x 3 cells.
  expect_equal(nrow(pr$data), 9L)
})

test_that("n_bins work (two values)", {
  v <- c("Petal.Width", "Petal.Length")
  quantile_profile <- function(n_bins) {
    light_profile2d(fl, v = v, n_bins = n_bins, cut_type = "quantile")
  }
  pr1 <- quantile_profile(3:2)
  pr2 <- quantile_profile(3)
  pr3 <- quantile_profile(c(3, 3))
  # A scalar n_bins must behave like the same value given for both variables.
  expect_equal(dim(pr2$data), dim(pr3$data))
  # 3 x 2 cells versus 3 x 3 cells.
  expect_equal(nrow(pr1$data) + 3L, nrow(pr2$data))
})

test_that("cut_type reacts", {
  v <- c("Petal.Width", "Petal.Length")
  profile_cut <- function(cut_type) {
    light_profile2d(fl, v = v, n_bins = 3, cut_type = cut_type)
  }
  first_cell <- function(pr) pr$data[1L, 1L]
  pr1 <- profile_cut("equal")
  pr2 <- profile_cut("quantile")
  pr3 <- profile_cut(c("quantile", "equal"))
  expect_false(first_cell(pr1) == first_cell(pr2))
  # pr2 and pr3 share the cut type of the first variable; presumably the
  # first column holds that variable's values.
  expect_true(first_cell(pr2) == first_cell(pr3))
})

test_that("argument 'breaks' works for one variable", {
  v <- c("Petal.Width", "Species")
  # Breaks 0:3 delimit three Petal.Width bins.
  pr <- light_profile2d(fl, v = v, breaks = list(Petal.Width = 0:3))
  # expect_length() reports the actual length on failure, unlike
  # expect_true(length(...) == n).
  expect_length(unique(pr$data$Petal.Width), 3L)
  expect_s3_class(plot(pr), "ggplot")
})

test_that("argument 'breaks' works for two variables", {
  v <- c("Petal.Width", "Petal.Length")
  brks <- list(Petal.Width = 0:3, Petal.Length = c(1, 4, 7))
  pr <- light_profile2d(fl, v = v, breaks = brks)
  # 3 Petal.Width bins x 2 Petal.Length bins; expect_equal() shows the actual
  # row count on failure.
  expect_equal(nrow(pr$data), 6L)
  expect_s3_class(plot(pr), "ggplot")
})

test_that("argument 'pd_evaluate_at' works for one variable", {
  v <- c("Petal.Width", "Species")
  pr <- light_profile2d(fl, v = v, pd_evaluate_at = list(Petal.Width = 1:2))
  # Exactly the two requested Petal.Width values must appear.
  expect_length(unique(pr$data$Petal.Width), 2L)
  expect_s3_class(plot(pr), "ggplot")
})

test_that("argument 'pd_evaluate_at' works for two variables", {
  v <- c("Petal.Width", "Petal.Length")
  pr <- light_profile2d(
    fl, v = v, pd_evaluate_at = list(Petal.Width = 1:2, Petal.Length = 2:3)
  )
  # 2 x 2 evaluation points.
  expect_equal(nrow(pr$data), 4L)
})

test_that("argument 'pd_grid' works for two variables", {
  v <- c("Petal.Width", "Petal.Length")
  a_list <- list(Petal.Width = 1:2, Petal.Length = 2:3)
  # Values per variable and the equivalent full grid must give the same shape.
  via_list <- light_profile2d(fl, v = v, pd_evaluate_at = a_list)
  via_grid <- light_profile2d(fl, v = v, pd_grid = expand.grid(a_list))
  expect_equal(dim(via_list$data), dim(via_grid$data))
})

test_that("label construction can be controlled by '...'", {
  v <- c("Petal.Width", "Petal.Length")
  # `sep` is passed through `...` and should show up in the bin labels.
  pr <- light_profile2d(
    fl,
    v = v,
    n_bins = 3,
    cut_type = "equal",
    type = "response",
    sep = ";"
  )
  first_label <- pr$data$Petal.Width[1L]
  expect_true(grepl(";", first_label))
})

test_that("light_profile2d reacts on 'by' variable", {
  v <- c("Petal.Width", "Petal.Length")
  at <- list(Petal.Width = 1:2, Petal.Length = 2:3)
  fl_all <- flashlight(
    model = fit, label = "lm", data = iris, y = "Sepal.Length"
  )
  fl_setosa <- flashlight(fl_all, data = iris[iris$Species == "setosa", ])
  pd_by <- light_profile2d(fl_all, v = v, by = "Species", pd_evaluate_at = at)
  pd_setosa <- light_profile2d(fl_setosa, v = v, pd_evaluate_at = at)
  expect_s3_class(plot(pd_by), "ggplot")
  # The setosa slice of the grouped result must match a setosa-only run.
  setosa_rows <- pd_by$data$Species == "setosa"
  expect_equal(pd_by$data$value_[setosa_rows], pd_setosa$data$value_)
})

test_that("argument 'counts' works", {
  v <- c("Petal.Width", "Species")
  brks <- list(Petal.Width = 0:3)
  pr <- light_profile2d(fl, v = v, type = "response", breaks = brks)
  pr_nocounts <- light_profile2d(
    fl, v = v, type = "response", breaks = brks, counts = FALSE
  )
  # Count in the cell "Petal.Width in [0, 1], versicolor" vs. a direct count.
  in_cell <- pr$data$Petal.Width == "[0, 1]" &
    pr$data$Species == "versicolor"
  cnt <- pr$data$counts_[in_cell]
  ref <- sum(iris$Petal.Width <= 1 & iris$Species == "versicolor")
  expect_equal(cnt, ref)
  expect_true("counts_" %in% colnames(pr$data))
  expect_false("counts_" %in% colnames(pr_nocounts$data))
})

test_that("argument 'linkinv' works", {
  v <- c("Petal.Width", "Species")
  brks <- list(Petal.Width = 0:3)
  fl_log <- flashlight(fl, linkinv = log)
  pr <- light_profile2d(fl, v = v, type = "response", breaks = brks)
  pr_log <- light_profile2d(fl_log, v = v, type = "response", breaks = brks)
  # With use_linkinv = FALSE the log link must be ignored completely.
  pr_log2 <- light_profile2d(
    fl_log, v = v, type = "response", use_linkinv = FALSE, breaks = brks
  )
  expect_true(all(pr_log$data$value_ < pr$data$value_))
  expect_equal(pr_log2$data, pr$data)
})

# Multiflashlight setup: a Species + Petal.Length model and an
# intercept-only model.
fit1 <- stats::lm(Sepal.Length ~ Species + Petal.Length, data = iris)
fit2 <- stats::lm(Sepal.Length ~ 1, data = iris)
fl1 <- flashlight(
  model = fit1, label = "lm", data = iris, y = "Sepal.Length"
)
fl2 <- flashlight(
  model = fit2, label = "Empty", data = iris, y = "Sepal.Length"
)
fls <- multiflashlight(list(fl1, fl2))
v <- c("Species", "Petal.Length")

test_that("multiflashlight works for partial dependence", {
  multi <- light_profile2d(fls, v = v, n_bins = 3)
  single <- light_profile2d(fls$lm, v = v, n_bins = 3)
  # expect_equal() also checks lengths: all(a == b) passes vacuously when the
  # label filter matches no rows (all(logical(0)) is TRUE) and recycles
  # silently on length mismatch.
  expect_equal(
    unname(multi$data[multi$data$label == "lm", ]$value_),
    unname(single$data$value_)
  )
  expect_s3_class(plot(multi), "ggplot")
})

test_that("multiflashlight works for type 'response'", {
  multi <- light_profile2d(fls, v = v, n_bins = 3, type = "response")
  single <- light_profile2d(fls$lm, v = v, n_bins = 3, type = "response")
  # expect_equal() fails on an empty selection or a length mismatch, whereas
  # all(a == b) would pass vacuously or recycle.
  expect_equal(
    unname(multi$data[multi$data$label == "lm", ]$value_),
    unname(single$data$value_)
  )
  expect_s3_class(plot(multi), "ggplot")
})

test_that("multiflashlight works for type 'predicted'", {
  multi <- light_profile2d(fls, v = v, n_bins = 3, type = "predicted")
  single <- light_profile2d(fls$lm, v = v, n_bins = 3, type = "predicted")
  # expect_equal() fails on an empty selection or a length mismatch, whereas
  # all(a == b) would pass vacuously or recycle.
  expect_equal(
    unname(multi$data[multi$data$label == "lm", ]$value_),
    unname(single$data$value_)
  )
  expect_s3_class(plot(multi), "ggplot")
})

# Variables profiled by the remaining multiflashlight tests.
v <- c(
  "Petal.Width",
  "Petal.Length"
)

test_that("multiflashlight works with 'breaks', 'n_bins', 'cut_type'", {
  # Petal.Length uses the explicit breaks (bin midpoints 2 and 6), and
  # Petal.Width is cut into 8 quantile bins.
  multi <- light_profile2d(
    fls,
    v = v,
    n_bins = c(8, 2),
    cut_type = "quantile",
    breaks = list(Petal.Length = c(0, 4, 8))
  )
  expect_true(all(multi$data$Petal.Length %in% c(2, 6)))
  # Must be an expectation: a bare comparison is evaluated and discarded.
  expect_length(unique(multi$data$Petal.Width), 8L)
})

test_that("multiflashlight works with 'evaluate_at'", {
  at <- list(Petal.Length = c(0, 4, 8), Petal.Width = 1:2)
  multi <- light_profile2d(
    fls, v = v, n_bins = c(8, 2), cut_type = "quantile", pd_evaluate_at = at
  )
  # pd_evaluate_at takes precedence over n_bins/cut_type.
  expect_true(all(multi$data$Petal.Length %in% c(0, 4, 8)))
  # Must be an expectation: a bare comparison is evaluated and discarded.
  expect_length(unique(multi$data$Petal.Width), 2L)
})

test_that("multiflashlight works with 'by' variable", {
  # Smoke test: the grouped multiflashlight profile must be plottable.
  expect_s3_class(
    plot(light_profile2d(fls, v = v, n_bins = 3, by = "Species")),
    "ggplot"
  )
})

test_that("light_profile2d uses weights and counts_weighted correctly", {
  ir <- transform(iris, pw = Petal.Width > 1)
  fit <- stats::lm(Sepal.Length ~ 1, data = ir)
  fl <- flashlight(
    model = fit,
    label = "empty",
    data = ir,
    y = "Sepal.Length",
    w = "Sepal.Width"
  )
  v <- c("Species", "pw")
  pr <- light_profile2d(fl, v = v, type = "response")
  prc <- light_profile2d(fl, v = v, type = "response", counts_weighted = TRUE)
  out <- pr$data
  # Reference: Sepal.Width-weighted mean response within the first cell.
  in_first_cell <- ir$Species == out$Species[1L] & ir$pw == out$pw[1L]
  reference <- stats::weighted.mean(
    ir$Sepal.Length[in_first_cell], ir$Sepal.Width[in_first_cell]
  )
  expect_equal(out$value_[1L], reference)
  # Weighted counts must differ from the unweighted ones.
  expect_false(all(out$counts_ == prc$data$counts_))
})
# mayer79/flashlight documentation built on Feb. 13, 2024, 1:09 p.m.