pred_nestcv_glmnet: Prediction wrappers to use fastshap with nestedcv

Source: R/shap.R

pred_nestcv_glmnet    R Documentation


Description

Prediction wrapper functions that enable the fastshap package to generate SHAP values from models trained with nestedcv.

Usage

pred_nestcv_glmnet(x, newdata)

pred_nestcv_glmnet_class1(x, newdata)

pred_nestcv_glmnet_class2(x, newdata)

pred_nestcv_glmnet_class3(x, newdata)

pred_train(x, newdata)

pred_train_class1(x, newdata)

pred_train_class2(x, newdata)

pred_train_class3(x, newdata)

Arguments

x

a nestcv.glmnet or nestcv.train object

newdata

a matrix of new data

Details

These prediction wrapper functions are designed for use with the fastshap package. pred_nestcv_glmnet and pred_train work with nestcv.glmnet and nestcv.train models, respectively, for either binary classification or regression.
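The wrappers follow fastshap's pred_wrapper convention: a function of the fitted model and new data that returns a plain numeric vector with one prediction per row. As an illustration only, the sketch below writes a wrapper of the same shape for a base R lm() model on the built-in mtcars data. pred_lm and fit_lm are hypothetical names, and this is not how nestedcv's own wrappers are implemented.

```r
# Hypothetical wrapper with the (x, newdata) signature that fastshap
# expects for pred_wrapper: it must return a plain numeric vector.
pred_lm <- function(x, newdata) {
  # predict.lm() needs a data frame, so coerce a matrix first;
  # as.numeric() drops the row names from the result
  as.numeric(predict(x, newdata = as.data.frame(newdata)))
}

# Fit a small linear model on a built-in dataset
fit_lm <- lm(mpg ~ wt + hp, data = mtcars)

# Calling the wrapper directly gives one prediction per row of newdata
p <- pred_lm(fit_lm, as.matrix(mtcars[1:5, c("wt", "hp")]))
print(p)
```

A function of this shape is what gets passed as the pred_wrapper argument of fastshap::explain(), in the same way that pred_nestcv_glmnet is passed in the example below.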

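For multiclass classification, use pred_nestcv_glmnet_class1, pred_nestcv_glmnet_class2 and pred_nestcv_glmnet_class3 for the first three classes; likewise, use pred_train_class1, pred_train_class2 and pred_train_class3 for nestcv.train objects. To analyse further classes, print one of these functions to inspect its body, then adapt it to the class of interest.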
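A minimal sketch of extending the pattern to a fourth class, assuming (as for caret models wrapped by nestcv.train) that predict(x, newdata, type = "prob") returns one probability column per class. make_class_wrapper, pred_train_class4 and the mock_model stand-in are hypothetical, so check the body of pred_train_class1 in your installed version before relying on this. Multinomial nestcv.glmnet predictions are indexed differently, so adapt pred_nestcv_glmnet_class1 for those instead.

```r
# Hypothetical factory: returns a fastshap-style wrapper for class k.
# Assumes predict(x, newdata, type = "prob") gives one column per class.
make_class_wrapper <- function(k) {
  force(k)  # evaluate k now rather than lazily on first use
  function(x, newdata) {
    as.numeric(predict(x, newdata, type = "prob")[, k])
  }
}

# e.g. a wrapper for the 4th class of a multiclass nestcv.train model
pred_train_class4 <- make_class_wrapper(4)

# Stand-in model for illustration only: its predict() method returns a
# data frame of class probabilities with four class columns.
mock <- structure(list(), class = "mock_model")
predict.mock_model <- function(object, newdata, type = "raw", ...) {
  n <- nrow(newdata)
  probs <- matrix(rep(c(0.1, 0.2, 0.3, 0.4), each = n), nrow = n,
                  dimnames = list(NULL, c("a", "b", "c", "d")))
  as.data.frame(probs)
}

# Returns the 4th-class probability for each row of newdata
p4 <- pred_train_class4(mock, data.frame(v = 1:3))
print(p4)
```

Using a factory keeps every class-specific wrapper identical apart from the column index, so no per-class function has to be copied and edited by hand.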

Value

A numeric vector of predictions, one per row of newdata. The functions themselves are intended to be passed as the pred_wrapper argument of fastshap::explain().

Examples

library(nestedcv)
library(fastshap)

# Boston housing dataset
library(mlbench)
data(BostonHousing2)
dat <- BostonHousing2
y <- dat$cmedv
x <- subset(dat, select = -c(cmedv, medv, town, chas))

# Fit a glmnet model using nested CV
# Only 3 outer CV folds and 1 alpha value for speed
fit <- nestcv.glmnet(y, x, family = "gaussian", n_outer_folds = 3, alphaSet = 1)

# Generate SHAP values using fastshap::explain
# Only 1 Monte Carlo repeat (nsim = 1) here for speed; higher values of nsim are recommended
sh <- explain(fit, X = x, pred_wrapper = pred_nestcv_glmnet, nsim = 1)

# Plot overall variable importance
plot_shap_bar(sh, x)

# Plot beeswarm plot
plot_shap_beeswarm(sh, x, size = 1)


nestedcv documentation built on Oct. 26, 2023, 5:08 p.m.