explore: Explore a model's "reasoning" via counterfactual predictions

View source: R/explore.R

explore R Documentation

Explore a model's "reasoning" via counterfactual predictions

Description

Make predictions for observations that vary over features of interest. There are two major use cases for this function. One is to understand how the model responds to features, not just individually but over combinations of features (i.e. interaction effects). The other is to explore how an individual prediction would vary if feature values were different. Note, however, that this function does not establish causality, so the latter use case should be applied judiciously.
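The general idea of counterfactual predictions over a grid can be sketched in base R. This is a minimal illustration only, not healthcareai's implementation: it assumes the built-in mtcars data and an ordinary linear model as stand-ins, and the names fit and grid are hypothetical.

```r
# Sketch only: a base R stand-in for the idea, not healthcareai's code.
# Fit a small model with an interaction between the two features we will vary.
fit <- lm(mpg ~ wt * hp + qsec, data = mtcars)

# Vary wt and hp over their 5th, 25th, 50th, 75th, and 95th percentiles,
# and hold the remaining feature (qsec) at its median.
probs <- c(0.05, 0.25, 0.5, 0.75, 0.95)
grid <- expand.grid(
  wt = unname(quantile(mtcars$wt, probs)),
  hp = unname(quantile(mtcars$hp, probs)),
  qsec = median(mtcars$qsec)
)

# One prediction per combination of the varying features.
grid$predicted_mpg <- predict(fit, newdata = grid)
head(grid)
```

Because every combination of wt and hp appears in the grid, differences in how predictions change with wt at low versus high hp reflect the interaction the model has learned.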

Usage

explore(
  models,
  vary = 4,
  hold = list(numerics = median, characters = Mode),
  numerics = c(0.05, 0.25, 0.5, 0.75, 0.95),
  characters = 5
)

Arguments

models

A model_list object. The data the model was trained on must have been prepared, either by training with machine_learn or by preparing with prep_data before model training.

vary

Which (or how many) features to vary? Default is 4; if vary is a single integer (n), the n-most-important features are varied (see Details for how importance is determined). If vary is a vector of integers, those rankings of features are used (e.g. vary = 2:4 varies the 2nd, 3rd, and 4th most-important features). Alternatively, you can specify which features to vary by passing a vector of feature names. For the finest level of control, you can choose the alternative values to use by passing a list whose names are feature names and whose entries are the values to use; in this case numerics and characters are ignored.

hold

How to choose the values of features not being varied? To make counterfactual predictions for a particular patient, this can be a row of the training data frame (or a one-row data frame containing values for all of the non-varying features). Alternatively, this can be functions to determine the values of non-varying features, in which case it must be a length-2 list with names "numerics" and "characters", each being a function to determine the values of non-varying features of that data type. The default is list(numerics = median, characters = Mode); numerics is applied to the column from the training data, characters is applied to a frequency table of the column from the training data.
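As a rough illustration of the function-based form of hold, the sketch below applies median to a numeric column and a mode function to a frequency table of a character column. The helper table_mode and the toy data frame d are hypothetical stand-ins; this is not the package's Mode function.

```r
# Hypothetical stand-in for a mode function that, like the default described
# above, receives a frequency table rather than the raw column.
table_mode <- function(freq_table) names(freq_table)[which.max(freq_table)]

# Toy data, invented for illustration.
d <- data.frame(
  age = c(21, 35, 35, 50, 62),
  weight_class = c("obese", "normal", "obese", "overweight", "obese"),
  stringsAsFactors = FALSE
)

# Numeric features are summarized from the column itself; character features
# are summarized from a frequency table of the column.
held <- list(
  age = median(d$age),
  weight_class = table_mode(table(d$weight_class))
)
held
```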

numerics

How to determine values of numeric features being varied? By default, the 5th, 25th, 50th (median), 75th, and 95th percentile values from the training dataset will be used. To specify evenly spaced quantiles, starting with the 5th and ending with the 95th, pass an integer to this argument. To specify which quantiles to use, pass a numeric vector in [0, 1] to this argument, e.g. c(0, .5, 1) for the minimum, median, and maximum values from the training dataset.
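The three ways of specifying numerics can be approximated with base R's quantile. This is a sketch under assumptions: the vector x is invented, and the integer case is assumed to mean evenly spaced probabilities between 0.05 and 0.95, which may differ in detail from what the package does internally.

```r
# Toy numeric feature, invented for illustration.
x <- c(2, 4, 4, 5, 7, 9, 12, 15, 21, 30)

# Default: the 5th, 25th, 50th, 75th, and 95th percentiles.
default_probs <- c(0.05, 0.25, 0.5, 0.75, 0.95)
default_values <- quantile(x, probs = default_probs)

# Integer input (assumed interpretation): n evenly spaced probabilities
# starting at 0.05 and ending at 0.95.
n <- 3
even_probs <- seq(0.05, 0.95, length.out = n)
even_values <- quantile(x, probs = even_probs)

# Explicit probabilities: minimum, median, and maximum.
extremes <- quantile(x, probs = c(0, 0.5, 1))
extremes
```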

characters

Integer. For categorical variables being varied, how many values to use? Values are used from most- to least-common; default is 5.
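Choosing values from most- to least-common can be sketched with a sorted frequency table. The data and the names n_values and top_values are hypothetical, and the package may handle ties or missing values differently.

```r
# Toy categorical feature, invented for illustration.
weight_class <- c("obese", "normal", "obese", "overweight",
                  "obese", "normal", "morbidly obese")

# How many values to use for this feature.
n_values <- 2

# Sort categories by frequency, most common first, then keep the first n
# (or all of them if there are fewer than n distinct values).
freq <- sort(table(weight_class), decreasing = TRUE)
top_values <- names(freq)[seq_len(min(n_values, length(freq)))]
top_values
```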

Details

If vary is an integer, the most important features are determined by get_variable_importance, unless glm is the only model present, in which case interpret is used with a warning. When selecting the most important features to vary, for categorical features the sum of feature importance of all the levels as dummies is used.
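The aggregation of dummy-level importance into a single score per feature might look like the following sketch. The importance values are made up for illustration, and the column names are assumptions rather than the output format of get_variable_importance.

```r
# Made-up importance scores for dummy-coded columns; the feature column maps
# each dummy back to the original feature it came from.
dummy_importance <- data.frame(
  feature = c("plasma_glucose", "weight_class", "weight_class",
              "age", "weight_class"),
  dummy = c("plasma_glucose", "weight_class_obese", "weight_class_overweight",
            "age", "weight_class_normal"),
  importance = c(40, 15, 10, 20, 5),
  stringsAsFactors = FALSE
)

# Sum importance over all dummies belonging to the same original feature.
feature_importance <- aggregate(importance ~ feature,
                                data = dummy_importance, FUN = sum)

# Rank features from most to least important and keep the top two,
# analogous to vary = 2.
feature_importance <-
  feature_importance[order(-feature_importance$importance), ]
top_features <- head(feature_importance$feature, 2)
top_features
```

In this toy example no single weight_class dummy outranks age, but the summed importance of its three levels does, so weight_class is selected ahead of age.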

Value

A tibble containing the feature values used to make predictions, along with the predictions themselves. Has class explore_df and attribute vi giving information about the varying features.

See Also

plot.explore_df

Examples

# Load the package that provides machine_learn, explore, and pima_diabetes
library(healthcareai)

# First, we need a model on which to make counterfactual predictions
set.seed(5176)
m <- machine_learn(pima_diabetes, patient_id, outcome = diabetes,
                   tune = FALSE, models = "xgb")

# By default, the four most important features are varied, with numeric
# features taking their 5th, 25th, 50th, 75th, and 95th percentile values, and
# categoricals taking their five most common values. Other features are
# held at their median (numeric) or modal (categorical) values. This can
# provide insight into how the model responds to different features.
explore(m)

# It is easy to plot counterfactual predictions. By default, only the two most
# important features are plotted over; see `?plot.explore_df` for
# customization options
explore(m) %>%
  plot()

# You can specify which features vary and what values they take in a variety of
# ways. For example, you could vary only "weight_class" and "plasma_glucose"
explore(m, vary = c("weight_class", "plasma_glucose"))

# You can also control what values non-varying features take.
# For example, to simulate alternative scenarios for patient 321:
patient321 <- dplyr::filter(pima_diabetes, patient_id == 321)
patient321
explore(m, hold = patient321)

# Here is an example in which both the varying and non-varying feature values
# are explicitly specified.
explore(m,
        vary = list(weight_class = c("normal", "overweight", "obese"),
                    plasma_glucose = seq(60, 200, 10)),
        hold = list(pregnancies = 2,
                    pedigree = .5,
                    age = 25,
                    insulin = NA,
                    skinfold = NA,
                    diastolic_bp = 85)) %>%
  plot()

healthcareai documentation built on Sept. 5, 2022, 5:12 p.m.