In the following, we explain the workflow of the counterfactuals package
for both a classification and a regression task using
concrete use cases.
```r
# NOT_CRAN <- identical(tolower(Sys.getenv("NOT_CRAN")), "true")
knitr::opts_chunk$set(
  collapse = TRUE,
  fig.width = 7,
  fig.height = 3,
  comment = "#>"
  # purl = NOT_CRAN,
  # eval = NOT_CRAN
)
options(width = 200)
```
```r
library("counterfactuals")
library("iml")
library("rpart")
```

```r
knitr::opts_chunk$set(echo = TRUE)
```
The `Predictor` class of the iml package provides the flexibility needed
to cover classification and regression models fitted with diverse R packages.
In the introduction vignette, we saw models fitted with the mlr3 and randomForest packages.
In the following, we show extensions to regression trees fitted with

- the caret package,
- tidymodels,
- the mlr package (a predecessor of mlr3), and
- the rpart package directly.
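For a model whose predictions iml cannot extract automatically, `Predictor$new()` also accepts a `predict.function` argument that receives the model and new data and returns the predictions. The following is a minimal sketch of that route, assuming a linear model fitted with base R's `lm()`; this model is our illustrative choice and is not one of the models used in this vignette.

```r
library("iml")

data(plasma, package = "gamlss.data")
train = plasma[-100L, ]
x_interest = plasma[100L, ]

# Illustrative model only (not part of the original vignette)
linmod = lm(retplasma ~ ., data = train)

# predict.function tells iml how to get predictions from this model:
# it is called with the model and a data frame of new observations
predlm = Predictor$new(
  model = linmod,
  data = train,
  y = "retplasma",
  predict.function = function(model, newdata) predict(model, newdata = newdata)
)
p = predlm$predict(x_interest)
p
```

The resulting `predlm` object could then be passed to `NICERegr$new()` in the same way as the predictors below.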
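For each model, we generate counterfactuals for the 100th row of the `plasma` dataset
from the gamlss.data package
using the NICE method (`NICERegr`).
Each model is trained on all remaining rows, and we search for counterfactuals whose predicted `retplasma` value is 500 or more (`desired_outcome = c(500, Inf)`).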
```r
# Load the data; row 100 is the instance whose prediction we want to change
data(plasma, package = "gamlss.data")
x_interest = plasma[100L, ]
```
```r
# caret: rpart regression tree with a fixed complexity parameter
library("caret")
treecaret = caret::train(retplasma ~ ., data = plasma[-100L, ],
  method = "rpart", tuneGrid = data.frame(cp = 0.01))
predcaret = Predictor$new(model = treecaret, data = plasma[-100L, ], y = "retplasma")
predcaret$predict(x_interest)
nicecaret = NICERegr$new(predcaret, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicecaret$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
```
```r
# tidymodels: parsnip decision tree with the rpart engine
library("tidymodels")
treetm = decision_tree(mode = "regression", engine = "rpart") %>%
  fit(retplasma ~ ., data = plasma[-100L, ])
predtm = Predictor$new(model = treetm, data = plasma[-100L, ], y = "retplasma")
predtm$predict(x_interest)
nicetm = NICERegr$new(predtm, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicetm$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
```
```r
# mlr: regression task plus rpart learner; mlr::train is namespaced because
# caret also exports a train() function
library("mlr")
task = mlr::makeRegrTask(data = plasma[-100L, ], target = "retplasma")
mod = mlr::makeLearner("regr.rpart")
treemlr = mlr::train(mod, task)
predmlr = Predictor$new(model = treemlr, data = plasma[-100L, ], y = "retplasma")
predmlr$predict(x_interest)
nicemlr = NICERegr$new(predmlr, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicemlr$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
```
```r
# rpart: tree fitted directly, without a modeling framework
treerpart = rpart(retplasma ~ ., data = plasma[-100L, ])
predrpart = Predictor$new(model = treerpart, data = plasma[-100L, ], y = "retplasma")
predrpart$predict(x_interest)
nicerpart = NICERegr$new(predrpart, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicerpart$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
```
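To check that a returned counterfactual actually reaches the desired range, you can predict its feature values with the same `Predictor`. The sketch below does this for the rpart fit and repeats its setup so that it runs on its own. It assumes that the `$data` field of the object returned by `find_counterfactuals()` holds the counterfactuals (at most one row here, since `return_multiple = FALSE`) and that the lower bound 500 counts as inside the desired range.

```r
library("counterfactuals")
library("iml")
library("rpart")

data(plasma, package = "gamlss.data")
x_interest = plasma[100L, ]
train = plasma[-100L, ]

# Same rpart setup as above
treerpart = rpart(retplasma ~ ., data = train)
predrpart = Predictor$new(model = treerpart, data = train, y = "retplasma")
nicerpart = NICERegr$new(predrpart, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
cfrpart = nicerpart$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))

# Predict the counterfactual feature values with the same predictor
cf_pred = predrpart$predict(cfrpart$data)[[1]]
# TRUE for every counterfactual whose prediction is 500 or more
cf_pred >= 500
```

The same check works for the caret, tidymodels, and mlr fits by swapping in the corresponding predictor and counterfactuals object.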