# inst/doc/nestedcv_shap.R

## ----setup, include = FALSE---------------------------------------------------
# Chunk defaults: collapse source and output together, hide warnings
knitr::opts_chunk$set(collapse = TRUE, warning = FALSE)

## -----------------------------------------------------------------------------
library(nestedcv)
library(mlbench)  # supplies the BostonHousing2 dataset

# Outcome is the cmedv column; predictors are every column except
# cmedv, medv, town and chas
data(BostonHousing2)
dat <- BostonHousing2
y <- dat[["cmedv"]]
x <- subset(dat, select = -c(cmedv, medv, town, chas))

# Seed the L'Ecuyer-CMRG generator, then fit glmnet with nested CV
set.seed(1, "L'Ecuyer-CMRG")
fit <- nestcv.glmnet(
  y = y,
  x = x,
  family = "gaussian",
  min_1se = 1,
  alphaSet = 1,
  cv.cores = 2
)

# Compute the variable stability table and auto-print it
vs <- var_stability(fit)
vs

## -----------------------------------------------------------------------------
# Call plot_var_stability() with its default settings on the nested CV glmnet
# fit created above
plot_var_stability(fit)

## ----fig.width = 9, fig.height = 5--------------------------------------------
# Two variants of the variable stability plot, both with final = FALSE.

# overlay directionality using colour
p1 <- plot_var_stability(fit, final = FALSE, direction = 1)

# show directionality with the sign of the variable importance
# (FALSE is spelled out: `F` is an ordinary variable that can be reassigned)
p2 <- plot_var_stability(fit, final = FALSE, percent = FALSE)

# Arrange the two plots side by side
ggpubr::ggarrange(p1, p2, ncol = 2)

## ----eval=FALSE---------------------------------------------------------------
#  # change bubble colour scheme
#  p1 + scale_fill_manual(values=c("orange", "green3"))

## -----------------------------------------------------------------------------
library(fastshap)

# SHAP values for the glmnet fit via fastshap::explain.
# nsim is kept at 5 so this runs quickly; larger nsim values are recommended.
sh <- explain(
  fit,
  X = x,
  pred_wrapper = pred_nestcv_glmnet,
  nsim = 5
)

# Bar plot of overall variable importance
plot_shap_bar(sh, x)

## -----------------------------------------------------------------------------
# Beeswarm plot of the SHAP values in `sh` (computed by explain() on the glmnet
# fit), passing the predictor data `x` and size = 1
plot_shap_beeswarm(sh, x, size = 1)

## -----------------------------------------------------------------------------
# Refit with method = "gbm" through nestcv.train; the outer CV is limited to
# 3 folds to keep the run time down
fit <- nestcv.train(
  y = y,
  x = x,
  method = "gbm",
  n_outer_folds = 3,
  cv.cores = 2
)

# nsim is kept at 5 for speed; larger nsim values are recommended
sh <- explain(
  fit,
  X = x,
  pred_wrapper = pred_train,
  nsim = 5
)
plot_shap_beeswarm(sh, x, size = 1)

## ----fig.width = 9, fig.height = 3.5------------------------------------------
library(ggplot2)

# iris: Species is the outcome, the first four columns are the predictors
data(iris)
dat <- iris
y <- dat[["Species"]]
x <- dat[, seq_len(4)]

# Multinomial glmnet with the outer CV limited to 3 folds for speed
fit <- nestcv.glmnet(
  y = y,
  x = x,
  family = "multinomial",
  n_outer_folds = 3,
  alphaSet = 0.6
)

# One set of SHAP values per class, each using its own prediction wrapper
sh1 <- explain(fit, X = x, pred_wrapper = pred_nestcv_glmnet_class1, nsim = 5)
sh2 <- explain(fit, X = x, pred_wrapper = pred_nestcv_glmnet_class2, nsim = 5)
sh3 <- explain(fit, X = x, pred_wrapper = pred_nestcv_glmnet_class3, nsim = 5)

# Unsorted bar plots, one per class, each with a title
s1 <- plot_shap_bar(sh1, x, sort = FALSE) + ggtitle("Setosa")
s2 <- plot_shap_bar(sh2, x, sort = FALSE) + ggtitle("Versicolor")
s3 <- plot_shap_bar(sh3, x, sort = FALSE) + ggtitle("Virginica")

# Three panels in a row sharing one legend at the bottom
ggpubr::ggarrange(
  s1, s2, s3,
  ncol = 3,
  legend = "bottom",
  common.legend = TRUE
)

## ----fig.width = 9.5, fig.height = 3.5----------------------------------------
# Unsorted beeswarm plots of the per-class SHAP values, each with a title
s1 <- plot_shap_beeswarm(sh1, x, sort = FALSE, cex = 0.7) + ggtitle("Setosa")
s2 <- plot_shap_beeswarm(sh2, x, sort = FALSE, cex = 0.7) + ggtitle("Versicolor")
s3 <- plot_shap_beeswarm(sh3, x, sort = FALSE, cex = 0.7) + ggtitle("Virginica")

# Three panels in a row sharing one legend on the right
ggpubr::ggarrange(
  s1, s2, s3,
  ncol = 3,
  legend = "right",
  common.legend = TRUE
)

# Try the nestedcv package in your browser
#
# Any scripts or data that you put into this service are public.
#
# nestedcv documentation built on Oct. 26, 2023, 5:08 p.m.