# tests/testthat/test_local_explanations.R

context("Test functions for local explanations")

# Binary classification -----
# One neural-network learner with probability output; it is reused for every
# model trained in this file.
lrn_nnet <- makeLearner("classif.nnet",
                        predict.type = "prob")
# Classification task on the response `y` of the shared training fixture.
test_task <- makeClassifTask(target = "y",
                             data = training_dataset)
# Seed immediately before fitting so the trained model is reproducible.
set.seed(17)
cl_nnet <- train(lrn_nnet, test_task)
# Validation rows whose `fct` is not "6" -----
# NOTE(review): this section was headed "Zbedne" (Polish: "unnecessary"), yet
# `smaller_validation` is used to build `egalitarian_explanation2` below.
# Rebuilding the factor from its character values drops the unused "6" level;
# the remaining levels come out in sorted (character) order.
smaller_validation <- mutate(
  filter(validation_dataset, fct != "6"),
  fct = as.factor(as.character(fct))
)
# Classification with 3 levels of response -----
## With factor variable
# `rbinom(n, 2, 0.5)` yields values in {0, 1, 2}, so the replaced `y` is a
# factor with (at most) three levels. Drawing one value per row of
# `training_dataset`, instead of a hard-coded 100, keeps `mutate()` valid if
# the fixture size ever changes (identical draws for a 100-row fixture).
set.seed(17)
new_training_dataset <- mutate(
  training_dataset,
  y = as.factor(as.character(rbinom(nrow(training_dataset), 2, 0.5)))
)
test_task2 <- makeClassifTask(data = new_training_dataset, target = "y")
set.seed(17)
cl_nnet2 <- train(lrn_nnet, test_task2)
## Without factor variable
# NOTE(review): column 21 is assumed to be the factor predictor (`fct`) —
# confirm against the fixture, or drop the column by name instead.
test_task3 <- makeClassifTask(data = new_training_dataset[, -21], target = "y")
set.seed(17)
cl_nnet3 <- train(lrn_nnet, test_task3)
# Create explainers ----
# Binary model `cl_nnet`, with `smaller_validation` as the second data input.
egalitarian_explanation2 <- egalitarian(
  training_dataset,
  smaller_validation,
  "y",
  cl_nnet
)

# Three-class model `cl_nnet2`; the same data frame is passed twice.
egalitarian_explanation3 <- egalitarian(
  new_training_dataset,
  new_training_dataset,
  "y",
  cl_nnet2
)
# Create local explanation for each method ----
# Shared prediction wrappers passed to `local_explanation()`. Both take
# `(x, y, ...)` so every method receives a wrapper with the same signature
# (`x` is the trained mlr model, `y` the new data).
predict_response <- function(x, y, ...) {
  getPredictionResponse(predict(x, newdata = y))
}
predict_probabilities <- function(x, y, ...) {
  getPredictionProbabilities(predict(x, newdata = y))
}
# NOTE(review): `lm` masks stats::lm within this file; the name is kept
# because the tests below refer to it.
set.seed(17)
lm <- local_explanation(5, egalitarian_explanation2, "lime",
                        predict_function = predict_response)
## live with binomial response
set.seed(17)
lv <- local_explanation(1, egalitarian_explanation2, "live",
                        size = 500,
                        predict_function = predict_response)
## live with multinomial response
set.seed(17)
lv2 <- local_explanation(5, egalitarian_explanation3, "live",
                         size = 500,
                         predict_function = predict_response)
## breakDown: its wrapper returns class probabilities, not the response
set.seed(17)
bd <- local_explanation(5, egalitarian_explanation2, "breakDown",
                        predict_function = predict_probabilities)
## Gower distances
dists <- distances(5, egalitarian_explanation2)
# Tests ----
testthat::test_that("Throw an error when method is incorrect", {
  # "method" is not one of the supported explanation methods.
  testthat::expect_error(
    local_explanation(5, egalitarian_explanation2, "method")
  )
})
testthat::test_that("All methods work", {
  # Each explanation built above must carry the same S3 class.
  expected_class <- "egalitarian_local_explanation"
  testthat::expect_is(bd, expected_class)
  testthat::expect_is(lv, expected_class)
  testthat::expect_is(lv2, expected_class)
  testthat::expect_is(lm, expected_class)
})
testthat::test_that("live explainer works with different factor levels", {
  # Wrapper returning the model's predicted response for new data `y`.
  predict_response <- function(x, y, ...) {
    getPredictionResponse(predict(x, newdata = y))
  }
  # Row 4 of the full validation data (which, unlike `smaller_validation`,
  # is not filtered on `fct`) against the binary training data.
  set.seed(17)
  explanation <- live_extractor(validation_dataset[4, ], "y",
                                training_dataset, cl_nnet,
                                predict_function = predict_response,
                                size = 500)
  testthat::expect_is(explanation, "list")
})
testthat::test_that("live explainer works with multinomial response", {
  # Wrapper returning the model's predicted response for new data `y`.
  predict_response <- function(x, y, ...) {
    getPredictionResponse(predict(x, newdata = y))
  }
  # First row of the three-class data, explained with the three-class model.
  set.seed(17)
  explanation <- live_extractor(new_training_dataset[1, ], "y",
                                new_training_dataset, cl_nnet2,
                                predict_function = predict_response,
                                size = 500)
  testthat::expect_is(explanation, "list")
})
testthat::test_that("live explainer works with strictly numerical data", {
  # Wrapper returning the model's predicted response for new data `y`.
  predict_response <- function(x, y, ...) {
    getPredictionResponse(predict(x, newdata = y))
  }
  # Only the data without column 21 (the factor variable, per the setup
  # comment) is used here. The same call on the full data with `cl_nnet2`
  # is already covered by the multinomial-response test.
  set.seed(17)
  testthat::expect_is(live_extractor(new_training_dataset[4, -21],
                                     "y", new_training_dataset[, -21],
                                     cl_nnet3,
                                     predict_function = predict_response,
                                     size = 500),
                      "list")
})
testthat::test_that("live explainer throws an error with correct prediction", {
  # NOTE(review): the title's meaning is unclear from this file; the test only
  # asserts that explaining row 3 with `cl_nnet2` signals some error.
  predict_response <- function(x, y, ...) {
    getPredictionResponse(predict(x, newdata = y))
  }
  set.seed(17)
  testthat::expect_error(
    live_extractor(new_training_dataset[3, ], "y",
                   new_training_dataset, cl_nnet2,
                   predict_function = predict_response,
                   size = 500)
  )
})

testthat::test_that("Gower distances are calculated", {
  # `dists` is classed, and its `distances` component is a numeric vector.
  testthat::expect_is(object = dists, class = "egalitarian_similarity")
  testthat::expect_is(object = dists$distances, class = "numeric")
})

testthat::test_that("Generic and plotting functions work", {
  explanations <- list(bd, lm, lv)
  # print() methods must write something to the console.
  for (explanation in explanations) {
    testthat::expect_output(print(explanation))
  }
  # `regexp = NA` asserts that the call writes nothing to the console.
  # NOTE(review): expect_output() does not print the returned value, so if
  # plot() returns a plot object (e.g. ggplot) it is presumably never
  # rendered here — confirm whether rendering should be exercised.
  for (explanation in explanations) {
    testthat::expect_output(plot(explanation), regexp = NA)
  }
  testthat::expect_output(print(dists))
  plot_types <- c("histogram", "density", "boxplot")
  for (plot_type in plot_types) {
    testthat::expect_output(plot(dists, plot_type), regexp = NA)
  }
  for (plot_type in plot_types) {
    testthat::expect_output(plot(bd, plot_type), regexp = NA)
  }
  for (explanation in list(lm, lv, bd)) {
    testthat::expect_output(plot_local_explanation(explanation), regexp = NA)
  }
})
# mstaniak/egalitaRian documentation built on Aug. 26, 2019, 11:11 p.m.