explain: Explain predictions

View source: R/explain.R

explain    R Documentation

Explain predictions

Description

Explain a single prediction of the surrogate GLM by showing each feature's contribution to it.
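A feature's contribution can be read from the GLM's linear predictor. The sketch below assumes that every feature enters the surrogate as a treatment-coded factor, as the segmented features in the example do. Here $\beta_{j,x_j}$ is the coefficient of the level of feature $j$ observed in the instance (zero for the reference level), $g$ is the link function, and the offset is, for example, $\log(\text{expo})$:

```latex
\eta(x) = \beta_0 + \sum_{j=1}^{p} \beta_{j,x_j} + \text{offset},
\qquad
\hat{\mu}(x) = g^{-1}\bigl(\eta(x)\bigr)

% with the log link, g^{-1} = \exp, so the prediction factorises:
\hat{\mu}(x) = e^{\text{offset}} \, e^{\beta_0} \prod_{j=1}^{p} e^{\beta_{j,x_j}}
```

Under the log link, each $e^{\beta_{j,x_j}}$ acts as a multiplicative factor on the prediction. A contribution equal to $g^{-1}(0) = 1$ leaves the prediction unchanged.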

Usage

explain(surro, instance, plt = TRUE)

Arguments

surro

The surrogate GLM fit (i.e., a "glm" object).

instance

Single-row data frame containing the instance to be explained.

plt

Logical: whether to return a ggplot (TRUE, the default) or the underlying data (FALSE).

Value

Tidy data frame or ggplot with each feature's contribution to the prediction of model surro for observation instance.

When plt = FALSE, the columns fit_link and se_link contain the fitted coefficient and its standard error on the linear predictor scale. The column fit_resp contains the coefficient on the response scale, obtained by applying the inverse link function. The columns upr_conf and lwr_conf contain the upper and lower bounds of a 95% confidence interval on the response scale.

When plt = TRUE, the ggplot shows the coefficient and confidence interval on the response scale. A green dashed line marks the value of the inverse link function applied to zero. Features whose bars lie close to this line have a negligible impact on the prediction.
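The columns described above can be approximated in base R. The following is a minimal sketch and not maidrr's implementation. The helper contrib_sketch and the toy data are hypothetical. It assumes factor-only features with treatment coding, a Wald interval based on the normal quantile, and it leaves out the intercept and offset:

```r
# Hypothetical helper (not maidrr's code): per-feature contributions of a
# factor-only GLM for one observation, on the link and response scales.
contrib_sketch <- function(fit, instance, level = 0.95) {
  inv <- family(fit)$linkinv              # inverse link, exp() for the log link
  cf  <- coef(fit)
  se  <- sqrt(diag(vcov(fit)))            # standard errors, named like coef()
  z   <- qnorm(1 - (1 - level) / 2)       # about 1.96 for a 95% interval
  feats <- attr(terms(fit), "term.labels")
  rows <- lapply(feats, function(f) {
    # treatment-coded coefficient name, e.g. "fuelgasoline"
    nm <- paste0(f, as.character(instance[[f]]))
    b  <- if (nm %in% names(cf)) cf[[nm]] else 0   # reference level -> 0
    s  <- if (nm %in% names(se)) se[[nm]] else 0
    data.frame(term = f, fit_link = b, se_link = s,
               fit_resp = inv(b),
               lwr_conf = inv(b - z * s), upr_conf = inv(b + z * s))
  })
  do.call(rbind, rows)
}

# Toy Poisson GLM on simulated factor data (hypothetical example data)
set.seed(1)
n <- 500
toy <- data.frame(
  fuel     = factor(sample(c("diesel", "gasoline"), n, replace = TRUE)),
  coverage = factor(sample(c("TPL", "TPL+", "TPL++"), n, replace = TRUE))
)
toy$nclaims <- rpois(n, lambda = 0.1 * ifelse(toy$fuel == "diesel", 1.3, 1))
fit <- glm(nclaims ~ fuel + coverage, family = poisson(link = "log"), data = toy)

res <- contrib_sketch(fit, toy[1, ])
print(res)
```

A reference level gets fit_link = 0, which maps to fit_resp = 1 under the log link. This is the inverse link applied to zero, so such a feature does not move the prediction.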

Examples

## Not run: 
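# Load the mtpl_be data shipped with maidrr and select the rating features
data('mtpl_be')
features <- setdiff(names(mtpl_be), c('id', 'nclaims', 'expo', 'long', 'lat'))

# Fit a Poisson gradient boosting machine as the black-box model
set.seed(12345)
gbm_fit <- gbm::gbm(as.formula(paste('nclaims ~',
                               paste(features, collapse = ' + '))),
                    distribution = 'poisson',
                    data = mtpl_be,
                    n.trees = 50,
                    interaction.depth = 3,
                    shrinkage = 0.1)

# Prediction function: average predicted response over newdata
gbm_fun <- function(object, newdata) {
  mean(predict(object, newdata, n.trees = object$n.trees, type = 'response'))
}

# Extract feature effects (four main effects plus the user-specified
# bm_fuel interaction) and segment each effect into the requested number of groups
data_segm <- gbm_fit %>% insights(vars = c('ageph', 'bm', 'coverage', 'fuel', 'bm_fuel'),
                                  data = mtpl_be,
                                  interactions = 'user',
                                  pred_fun = gbm_fun) %>%
                          segmentation(data = mtpl_be,
                                       type = 'ngroups',
                                       values = setNames(c(7, 8, 2, 2, 3),
                                                         c('ageph', 'bm', 'coverage', 'fuel', 'bm_fuel')))

# Fit the surrogate Poisson GLM on the segmented features and
# explain its prediction for observation 34
data_segm %>% surrogate(formula = nclaims ~ ageph_ + bm_ + coverage_ + fuel_ + bm_fuel_,
                        family = poisson(link = 'log'),
                        offset = log(expo)) %>%
              explain(instance = data_segm[34, ])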

## End(Not run)

henckr/maidrr documentation built on July 27, 2023, 3:17 p.m.