In the following, we explain the counterfactuals workflow for both a classification and a regression task using
concrete use cases.
# NOT_CRAN <- identical(tolower(Sys.getenv("NOT_CRAN")), "true")
knitr::opts_chunk$set(
  collapse = TRUE,
  fig.width = 7,
  fig.height = 3,
  comment = "#>"
  # purl = NOT_CRAN,
  # eval = NOT_CRAN
)
options(width = 200)
library("counterfactuals")
library("iml")
library("rpart")
knitr::opts_chunk$set(echo = TRUE)
The Predictor class of the iml package provides the flexibility needed
to cover classification and regression models fitted with diverse R packages.
In the introduction vignette, we saw models fitted with the mlr3 and randomForest packages.
In the following, we extend this to a regression tree fitted with
the caret package, tidymodels, the mlr package (a predecessor of mlr3), and rpart directly.
For each model, we generate counterfactuals for the 100th row of the plasma dataset of the gamlss.data package
using the NICE method for regression tasks (NICERegr).
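As a minimal sketch of where this flexibility comes from: when a model class is not handled automatically, `Predictor$new()` accepts a `predict.function` argument that takes the model and the new data and returns the predictions. The wrapper name `predict_tree` below is hypothetical and chosen only for illustration; the rpart model mirrors the one used in this vignette.

```r
library("iml")
library("rpart")

data(plasma, package = "gamlss.data")
train = plasma[-100L, ]
x_interest = plasma[100L, ]

tree = rpart(retplasma ~ ., data = train)

# Hypothetical wrapper: first argument is the model, second the new data.
# It returns one numeric prediction per row of newdata.
predict_tree = function(model, newdata) {
  predict(model, newdata = newdata)
}

pred = Predictor$new(
  model = tree,
  data = train,
  y = "retplasma",
  predict.function = predict_tree
)

# Returns a data.frame with one row per observation
pred$predict(x_interest)
```

The same pattern should carry over to models from other packages, provided the wrapper returns a numeric vector (regression) or a matrix or data.frame of class probabilities (classification).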
data(plasma, package = "gamlss.data")
x_interest = plasma[100L,]
# Regression tree fitted with caret
library("caret")
treecaret = caret::train(retplasma ~ ., data = plasma[-100L,], method = "rpart",
  tuneGrid = data.frame(cp = 0.01))
predcaret = Predictor$new(model = treecaret, data = plasma[-100L,], y = "retplasma")
predcaret$predict(x_interest)
nicecaret = NICERegr$new(predcaret, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicecaret$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
# Regression tree fitted with tidymodels
library("tidymodels")
treetm = decision_tree(mode = "regression", engine = "rpart") %>%
  fit(retplasma ~ ., data = plasma[-100L,])
predtm = Predictor$new(model = treetm, data = plasma[-100L,], y = "retplasma")
predtm$predict(x_interest)
nicetm = NICERegr$new(predtm, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicetm$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
# Regression tree fitted with mlr
library("mlr")
task = mlr::makeRegrTask(data = plasma[-100L,], target = "retplasma")
mod = mlr::makeLearner("regr.rpart")
treemlr = mlr::train(mod, task)
predmlr = Predictor$new(model = treemlr, data = plasma[-100L,], y = "retplasma")
predmlr$predict(x_interest)
nicemlr = NICERegr$new(predmlr, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicemlr$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
# Regression tree fitted with rpart directly
treerpart = rpart(retplasma ~ ., data = plasma[-100L,])
predrpart = Predictor$new(model = treerpart, data = plasma[-100L,], y = "retplasma")
predrpart$predict(x_interest)
nicerpart = NICERegr$new(predrpart, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicerpart$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
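The calls above only print the result of `find_counterfactuals()`. A minimal sketch of how the returned object could be stored and inspected follows, assuming the `$data` field and the `$predict()` and `$evaluate()` methods of the package's `Counterfactuals` class behave as described in the package documentation. The variable name `cfactuals` is chosen here for illustration, and whether a counterfactual in the desired range is found depends on the fitted tree.

```r
library("counterfactuals")
library("iml")
library("rpart")

data(plasma, package = "gamlss.data")
x_interest = plasma[100L, ]

treerpart = rpart(retplasma ~ ., data = plasma[-100L, ])
predrpart = Predictor$new(model = treerpart, data = plasma[-100L, ], y = "retplasma")
nicerpart = NICERegr$new(predrpart, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)

# Store the result instead of only printing it
cfactuals = nicerpart$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))

cfactuals$data        # feature values of the counterfactual(s)
cfactuals$predict()   # model predictions for the counterfactual(s)
cfactuals$evaluate()  # quality measures, e.g. distance to x_interest
```

The same inspection steps apply to the objects returned for the caret, tidymodels, and mlr models, since each goes through the same Predictor interface.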