View source: R/shapley.row.plot.R
| shapley.row.plot | R Documentation |
Computes and visualizes Weighted Mean SHAP contributions (WMSHAP) for a single row
(subject/observation) across multiple models in a shapley object.
For each feature, the function computes a weighted mean of the row-level SHAP contributions
across models, using shapley$weights, and reports an approximate 95% confidence
interval that summarizes the variability across models.
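The weighting step described above can be illustrated with a small base-R sketch for one feature of one row. It is not the package's code. It assumes, for illustration only, that the weights are normalized to sum to 1 and that the interval is the weighted mean plus or minus 1.96 times the weighted standard deviation across models. The exact formula used by shapley.row.plot() may differ, and the SHAP values and weights below are hypothetical toy numbers.

```r
# Hypothetical SHAP contributions of one feature for a single row,
# one value per model, plus hypothetical model weights
# (in the package, the weights come from shapley$weights).
shap_values <- c(0.12, 0.08, 0.15, 0.10)
weights     <- c(0.30, 0.20, 0.40, 0.10)

# Normalize the weights so that they sum to 1
w <- weights / sum(weights)

# Weighted mean SHAP contribution across models
wmean <- sum(w * shap_values)

# Weighted standard deviation across models (assumed variability measure)
wsd <- sqrt(sum(w * (shap_values - wmean)^2))

# Approximate 95% interval under a normal approximation (assumption)
lower <- wmean - 1.96 * wsd
upper <- wmean + 1.96 * wsd

print(c(mean = wmean, lower = lower, upper = upper))
```

Models with larger weights pull the mean toward their own SHAP values, and disagreement between models widens the interval.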
shapley.row.plot(
shapley,
row_index,
top_n_features = NULL,
features = NULL,
nonzeroCI = FALSE,
plot = TRUE,
print = FALSE
)
shapley |
object of class |
row_index |
Integer (length 1). The row/subject identifier to visualize. This is
matched against the |
top_n_features |
Integer. If specified, only the top n features with the highest weighted mean SHAP values are plotted. This is overridden by the 'features' argument. |
features |
Optional character vector of feature names to plot. If |
nonzeroCI |
Logical. If |
plot |
Logical. If |
print |
Logical. If |
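The documented precedence between 'features' and 'top_n_features' can be sketched with a hypothetical helper, select_features(), which is not part of the package. The column names feature and mean in the toy summary table are also hypothetical. The sketch assumes that "top n" ranks features by the absolute value of their weighted mean SHAP contribution; the package may rank on signed values instead.

```r
# Hypothetical helper illustrating the documented precedence:
# 'features' overrides 'top_n_features'.
select_features <- function(summary_df, top_n_features = NULL, features = NULL) {
  if (!is.null(features)) {
    # Keep only the explicitly requested features
    return(summary_df[summary_df$feature %in% features, , drop = FALSE])
  }
  if (!is.null(top_n_features)) {
    # Rank by absolute weighted mean SHAP (assumed ranking rule)
    ord <- order(abs(summary_df$mean), decreasing = TRUE)
    n <- min(top_n_features, nrow(summary_df))
    return(summary_df[ord[seq_len(n)], , drop = FALSE])
  }
  # Neither argument given: keep all features
  summary_df
}

# Toy summary table with hypothetical values
toy <- data.frame(
  feature = c("AGE", "PSA", "GLEASON", "VOL"),
  mean    = c(0.02, -0.30, 0.25, 0.05),
  stringsAsFactors = FALSE
)

res_top <- select_features(toy, top_n_features = 2)
res_sel <- select_features(toy, top_n_features = 2, features = c("AGE", "VOL"))
print(res_top)
print(res_sel)
```

Here res_top keeps PSA and GLEASON, while res_sel ignores top_n_features and keeps AGE and VOL.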
A list including the ggplot2 object and the data frame of WMSHAP summary values.
E. F. Haghish
## Not run:
# load the required libraries for building the base-learners and the ensemble models
library(h2o) #shapley supports h2o models
library(shapley)
# initiate the h2o server
h2o.init(ignore_config = TRUE, nthreads = 2, bind_to_localhost = FALSE,
insecure = TRUE)
# upload data to h2o cloud
prostate_path <- system.file("extdata", "prostate.csv", package = "h2o")
prostate <- h2o.importFile(path = prostate_path, header = TRUE)
set.seed(10)
### H2O provides two ways of tuning models: AutoML and Grid search.
### Below, I demonstrate how weighted mean SHAP values
### can be computed for both.
#######################################################
### EXAMPLE 1: PREPARE AutoML Grid (takes a couple of minutes)
#######################################################
# run AutoML to tune various models (GBM only) for up to 120 seconds
y <- "CAPSULE"
prostate[,y] <- as.factor(prostate[,y]) #convert to factor for classification
aml <- h2o.automl(y = y, training_frame = prostate, max_runtime_secs = 120,
include_algos=c("GBM"),
seed = 2023, nfolds = 10,
keep_cross_validation_predictions = TRUE)
### call 'shapley' function to compute the weighted mean and weighted confidence intervals
### of SHAP values across all trained models.
### Note that 'newdata' should be the testing dataset!
### (For brevity, this example reuses the training data.)
result <- shapley(models = aml, newdata = prostate,
performance_metric = "aucpr", plot = TRUE)
shapley.row.plot(result, row_index = 11)
#######################################################
### EXAMPLE 2: PREPARE H2O Grid (takes a couple of minutes)
#######################################################
# make sure the same number of folds ("nfolds") is specified for all grids
grid <- h2o.grid(algorithm = "gbm", y = y, training_frame = prostate,
hyper_params = list(ntrees = seq(1,50,1)),
grid_id = "ensemble_grid",
# this setting ensures the models are comparable for building a meta learner
seed = 2023, fold_assignment = "Modulo", nfolds = 10,
keep_cross_validation_predictions = TRUE)
result2 <- shapley(models = grid, newdata = prostate,
performance_metric = "aucpr", plot = TRUE)
shapley.row.plot(result2, row_index = 9)
shapley.row.plot(result2, row_index = 9, nonzeroCI = TRUE)
shapley.row.plot(result2, row_index = 9, top_n_features = 10)
#######################################################
### EXAMPLE 3: PREPARE autoEnsemble STACKED ENSEMBLE MODEL
#######################################################
### get the models' IDs from the AutoML and grid searches.
### this is all that is needed before building the ensemble,
### i.e., to specify the model IDs that should be evaluated.
library(autoEnsemble)
ids <- c(h2o.get_ids(aml), h2o.get_ids(grid))
autoSearch <- ensemble(models = ids, training_frame = prostate, strategy = "search")
result3 <- shapley(models = autoSearch, newdata = prostate,
performance_metric = "aucpr", plot = TRUE)
#plot all important features
shapley.row.plot(result3, row_index = 13)
#plot only the given features
shapPlot <- shapley.row.plot(result3, row_index = 13, features = c("PSA", "AGE"))
# inspect the computed data for row 13
print(shapPlot$summary)
## End(Not run)