model_support: Methods for extending lime's model support

model_support    R Documentation

Methods for extending lime's model support

Description

To support a model of your choice, lime needs to get predictions from it in a standardised way and to know whether it is a classification or a regression model. For the former, lime calls the predict_model() generic, for which you are free to supply methods without overriding the standard predict() method. For the latter, the model must respond to the model_type() generic.

Usage

predict_model(x, newdata, type, ...)

model_type(x, ...)

Arguments

x

A model object

newdata

The new observations to predict

type

Either 'raw' to indicate predicted values, or 'prob' to indicate class probabilities

...

Passed on to the predict() method

Value

predict_model() returns a data.frame. If type = 'raw', it contains a single column named 'Response' holding the predicted values. If type = 'prob', it contains one column per possible class, named after the class and holding the probability score for membership of that class. model_type() returns a character string; only 'regression' and 'classification' are currently supported.
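
The return shapes described above can be sketched for a regression model. The methods below target lm objects from the stats package; they are hypothetical illustrations and are not shipped with lime. The method is called directly here so that the example runs with base R only; with lime loaded, the predict_model() and model_type() generics would dispatch to them.

```r
# Hypothetical methods for stats::lm objects (an assumption for illustration).
predict_model.lm <- function(x, newdata, type, ...) {
  # A regression model only produces raw predictions, so `type` is not used.
  res <- predict(x, newdata = newdata, ...)
  # One column named 'Response', as required for type = 'raw'.
  data.frame(Response = unname(res), stringsAsFactors = FALSE)
}

model_type.lm <- function(x, ...) 'regression'

# Fit a small model on a built-in dataset and inspect the returned shape.
fit <- lm(mpg ~ wt + hp, data = mtcars)
out <- predict_model.lm(fit, newdata = head(mtcars), type = 'raw')
str(out)
```

The result is a data.frame with one row per observation in newdata and a single numeric 'Response' column.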

Supported Models

Out of the box, lime supports the following model objects:

  • train from caret

  • WrappedModel from mlr

  • xgb.Booster from xgboost

  • H2OModel from h2o

  • keras.engine.training.Model from keras

  • lda from MASS (used for low-dependency examples)

If your model is not one of the above, you'll need to implement support yourself. If the model has a predict interface mimicking that of predict.train() from caret, it is enough to wrap your model in as_classifier()/as_regressor() to gain support. Otherwise, you'll need to implement a predict_model() method and potentially a model_type() method (if the latter is omitted, the model must be wrapped in as_classifier()/as_regressor() every time it is used in lime()).
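
As a rough sketch of the wrapping route, the toy classifier below has a predict() method that accepts type = 'raw' or type = 'prob', in the style of predict.train(). The toy_model class and its fields are hypothetical and exist only for this illustration. The final step assumes lime is installed and is skipped otherwise.

```r
# Hypothetical toy classifier: two class labels and a fixed probability for 'a'.
toy_model <- structure(
  list(classes = c('a', 'b'), p_a = 0.7),
  class = 'toy_model'
)

# predict() method mimicking the predict.train() interface from caret:
# type = 'prob' gives one column per class, type = 'raw' gives class labels.
predict.toy_model <- function(object, newdata, type = c('raw', 'prob'), ...) {
  type <- match.arg(type)
  n <- nrow(newdata)
  prob <- data.frame(a = rep(object$p_a, n), b = rep(1 - object$p_a, n))
  if (type == 'prob') return(prob)
  # Pick the most probable class for each row.
  object$classes[max.col(as.matrix(prob), ties.method = 'first')]
}

newdata <- data.frame(x = 1:3)
probs <- predict(toy_model, newdata, type = 'prob')
labels <- predict(toy_model, newdata, type = 'raw')

# Wrapping declares the model type, so no model_type() method is written here.
if (requireNamespace('lime', quietly = TRUE)) {
  wrapped <- lime::as_classifier(toy_model)
  print(lime::model_type(wrapped))
}
```

The wrapped object, rather than the bare model, is what would then be passed to lime().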

Examples

# Example of adding support for lda models (already available in lime)
predict_model.lda <- function(x, newdata, type, ...) {
  res <- predict(x, newdata = newdata, ...)
  switch(
    type,
    raw = data.frame(Response = res$class, stringsAsFactors = FALSE),
    prob = as.data.frame(res$posterior, check.names = FALSE)
  )
}

model_type.lda <- function(x, ...) 'classification'


lime documentation built on Aug. 19, 2022, 9:07 a.m.