context("Test functions for local explanations")
# Binary classification -----
# A single nnet learner configured to return class probabilities; the same
# learner object is trained again on the three-class tasks further down.
lrn_nnet <- makeLearner("classif.nnet", predict.type = "prob")
test_task <- makeClassifTask(data = training_dataset, target = "y")
# Seed immediately before training so the fitted weights are reproducible.
set.seed(17)
cl_nnet <- train(lrn_nnet, test_task)
# Unnecessary -----
# (Translated from the original Polish header "Zbedne".)
# NOTE(review): despite that label, smaller_validation is passed to the
# binary explainer below, so this section is still needed.
# Drop rows whose `fct` equals "6", then rebuild the factor from its character
# values so its levels only contain values still present.
smaller_validation <- mutate(
  filter(validation_dataset, fct != "6"),
  fct = as.factor(as.character(fct))
)
# Classification with 3 levels of response -----
## With factor variable
set.seed(17)
# rbinom(n, size = 2, prob = 0.5) draws values from {0, 1, 2}, which are
# turned into a factor to give a three-class response. The number of draws is
# taken from the data instead of a hard-coded 100, so the fixture can change
# size without mutate() rejecting a wrong-length column.
new_training_dataset <- mutate(
  training_dataset,
  y = as.factor(as.character(rbinom(nrow(training_dataset), 2, 0.5)))
)
test_task2 <- makeClassifTask(data = new_training_dataset, target = "y")
set.seed(17)
cl_nnet2 <- train(lrn_nnet, test_task2)
## Without factor variable
# NOTE(review): column 21 is assumed to be the factor column of the fixture
# (per the header above); confirm, and consider dropping it by name instead.
test_task3 <- makeClassifTask(data = new_training_dataset[, -21], target = "y")
set.seed(17)
cl_nnet3 <- train(lrn_nnet, test_task3)
# Create explainers ----
# Binary-response explainer: training data, the reduced validation set,
# the target column name and the binary nnet model.
egalitarian_explanation2 <- egalitarian(
  training_dataset,
  smaller_validation,
  "y",
  cl_nnet
)
# Three-class explainer: the relabelled dataset is passed in both data slots,
# together with the three-class nnet model.
egalitarian_explanation3 <- egalitarian(
  new_training_dataset,
  new_training_dataset,
  "y",
  cl_nnet2
)
# Create local explanation for each method ----
## lime
# NOTE(review): this object masks stats::lm for the rest of the file; the name
# is kept because other code in this file refers to it as `lm`.
# The predict function now accepts `...` like every other predict_function in
# this file, so extra arguments passed by the explainer are tolerated.
set.seed(17)
lm <- local_explanation(5, egalitarian_explanation2, "lime",
                        predict_function = function(x, y, ...)
                          getPredictionResponse(predict(x, newdata = y)))
## live with binomial response
# live explanation for observation 1 of the binary explainer; the predict
# function returns predicted responses (not probabilities).
set.seed(17)
lv <- local_explanation(
  1,
  egalitarian_explanation2,
  "live",
  size = 500,
  predict_function = function(x, y, ...) {
    getPredictionResponse(predict(x, newdata = y))
  }
)
## live with multinomial response
# Same live setup, but using the three-class explainer and observation 5.
set.seed(17)
lv2 <- local_explanation(
  5,
  egalitarian_explanation3,
  "live",
  size = 500,
  predict_function = function(x, y, ...) {
    getPredictionResponse(predict(x, newdata = y))
  }
)
## breakDown
# Unlike lime/live above, this predict function returns class probabilities.
set.seed(17)
bd <- local_explanation(
  5,
  egalitarian_explanation2,
  "breakDown",
  predict_function = function(x, y, ...) {
    getPredictionProbabilities(predict(x, newdata = y))
  }
)
## Gower distances
# Distance object for observation 5 of the binary explainer (no seed is set).
dists <- distances(
  5,
  egalitarian_explanation2
)
# Tests ----
testthat::test_that("Throw an error when method is incorrect", {
  # "method" is not one of the supported explanation methods.
  unknown_method <- "method"
  testthat::expect_error(
    local_explanation(5, egalitarian_explanation2, unknown_method)
  )
})
testthat::test_that("All methods work", {
  # Every explanation built above must carry the same S3-style class,
  # whichever method produced it. Checked in the original order.
  for (explanation in list(bd, lv, lv2, lm)) {
    testthat::expect_is(explanation, "egalitarian_local_explanation")
  }
})
testthat::test_that("live explainer works with different factor levels", {
  # Predict function returning predicted responses for new data.
  predict_response <- function(x, y, ...) {
    getPredictionResponse(predict(x, newdata = y))
  }
  set.seed(17)
  extracted <- live_extractor(validation_dataset[4, ], "y", training_dataset,
                              cl_nnet,
                              predict_function = predict_response,
                              size = 500)
  testthat::expect_is(extracted, "list")
})
testthat::test_that("live explainer works with multinomial response", {
  # Predict function returning predicted responses for new data.
  predict_response <- function(x, y, ...) {
    getPredictionResponse(predict(x, newdata = y))
  }
  set.seed(17)
  extracted <- live_extractor(new_training_dataset[1, ], "y",
                              new_training_dataset, cl_nnet2,
                              predict_function = predict_response,
                              size = 500)
  testthat::expect_is(extracted, "list")
})
testthat::test_that("live explainer works with strictly numerical data", {
  # Only the numeric-only configuration is exercised here. The call on
  # new_training_dataset[1, ] with cl_nnet2 that used to open this test was an
  # exact copy of the "multinomial response" test and did not involve
  # numeric-only data, so it was removed as a duplicate.
  # NOTE(review): column 21 is assumed to be the factor column; confirm.
  set.seed(17)
  testthat::expect_is(live_extractor(new_training_dataset[4, -21],
                                     "y", new_training_dataset[, -21],
                                     cl_nnet3,
                                     predict_function = function(x, y, ...)
                                       getPredictionResponse(
                                         predict(x, newdata = y)),
                                     size = 500),
                      "list")
})
testthat::test_that("live explainer throws an error with correct prediction", {
  # Row 3 of the three-class data is expected to make live_extractor fail.
  # NOTE(review): the test name suggests the failure is tied to the model's
  # prediction for this row being correct; confirm against live_extractor.
  predict_response <- function(x, y, ...) {
    getPredictionResponse(predict(x, newdata = y))
  }
  set.seed(17)
  testthat::expect_error(
    live_extractor(new_training_dataset[3, ], "y", new_training_dataset,
                   cl_nnet2,
                   predict_function = predict_response,
                   size = 500)
  )
})
testthat::test_that("Gower distances are calculated", {
  # The container has its own class ...
  testthat::expect_is(dists, "egalitarian_similarity")
  # ... and holds the distances themselves as a numeric vector.
  distance_values <- dists$distances
  testthat::expect_is(distance_values, "numeric")
})
testthat::test_that("Generic and plotting functions work", {
  plot_types <- c("histogram", "density", "boxplot")
  # print() methods must write something to the console.
  for (explanation in list(bd, lm, lv)) {
    testthat::expect_output(print(explanation))
  }
  # regexp = NA asserts that the call produces no console output.
  for (explanation in list(bd, lm, lv)) {
    testthat::expect_output(plot(explanation), regexp = NA)
  }
  testthat::expect_output(print(dists))
  for (plot_type in plot_types) {
    testthat::expect_output(plot(dists, plot_type), regexp = NA)
  }
  for (plot_type in plot_types) {
    testthat::expect_output(plot(bd, plot_type), regexp = NA)
  }
  for (explanation in list(lm, lv, bd)) {
    testthat::expect_output(plot_local_explanation(explanation), regexp = NA)
  }
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.