Partial Dependence Functions

Description

Compute partial dependence functions (i.e., marginal effects) for a variety of fitted model objects.
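To make the idea concrete, here is a minimal base-R sketch of the brute-force partial dependence computation described by Friedman (2001). For each grid value of the predictor, it sets that predictor to the value in every training row, predicts, and averages the predictions. The helper `pd_manual` is hypothetical, not part of pdp, and the sketch assumes a regression model whose `predict()` method accepts `newdata`.

```r
# Hypothetical helper: brute-force partial dependence for one predictor.
pd_manual <- function(fit, train, var, grid) {
  yhat <- vapply(grid, function(value) {
    tmp <- train
    tmp[[var]] <- value                # fix the predictor at one grid value for every row
    mean(predict(fit, newdata = tmp))  # average the predictions over the training data
  }, numeric(1))
  out <- data.frame(grid, yhat)
  names(out) <- c(var, "yhat")
  out
}

# Linear model on the built-in mtcars data
fit <- lm(mpg ~ wt + hp, data = mtcars)
wt_grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 5)
pd <- pd_manual(fit, mtcars, "wt", wt_grid)
pd
```

For a linear model without interactions, the resulting curve is a straight line whose slope equals the coefficient of `wt`, which makes a handy sanity check.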

Usage

partial(object, ...)

## Default S3 method:
partial(object, pred.var, pred.grid, grid.resolution = NULL,
  type, which.class = 1L, plot = FALSE, smooth = FALSE, rug = FALSE,
  chull = FALSE, train, check.class = TRUE, ...)

Arguments

object

A fitted model object of appropriate class (e.g., "gbm", "lm", "randomForest", etc.).

...

Additional optional arguments to be passed on to plyr::aaply (e.g., .progress = "text").

pred.var

Character vector giving the names of the predictor variables of interest. For computational and interpretive reasons, this should include no more than three variables.
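The computational cost grows multiplicatively with each added variable. The arithmetic below assumes every variable in pred.var is continuous with 51 grid points, that each grid point requires predicting over every training row, and a training set of 506 rows (the size of the Boston housing data used in the examples).

```r
# Grid size and number of predictions for 1, 2, and 3 continuous predictors,
# assuming 51 grid points per predictor and 506 training rows.
n_train <- 506
res <- 51
n_vars <- 1:3
grid_points <- res^n_vars              # 51, 2601, 132651 joint grid values
predictions <- grid_points * n_train   # predictions needed over the training data
data.frame(n_vars, grid_points, predictions)
```

Going from two to three variables raises the work from about 1.3 million to about 67 million predictions, and the result is also much harder to visualize.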

pred.grid

Data frame containing the joint values of the variables listed in pred.var.
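A user-supplied pred.grid is simply a data frame with one column per variable in pred.var and one row per joint value to evaluate. A sketch using base R's expand.grid(); the variable names and values are illustrative choices for the Boston housing example.

```r
# Every combination of the listed lstat and rm values (4 x 4 = 16 rows)
pred_grid <- expand.grid(
  lstat = c(5, 10, 20, 30),
  rm    = c(5, 6, 7, 8)
)
nrow(pred_grid)

# It could then be supplied as, e.g.:
# partial(boston.rf, pred.var = c("lstat", "rm"), pred.grid = pred_grid)
```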

grid.resolution

Integer giving the number of equally spaced points to use for each continuous variable listed in pred.var (only used when pred.grid is not supplied). If left NULL, it defaults to the minimum of 51 and the number of unique values of each such variable.
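The default rule can be sketched for a single continuous predictor as follows. The helper `default_grid` is hypothetical; it assumes the grid spans the observed range of the variable, which the description above implies but does not state outright.

```r
# Hypothetical helper: equally spaced grid over the range of x, using
# min(51, number of unique values) points when grid.resolution is NULL.
default_grid <- function(x, grid.resolution = NULL) {
  if (is.null(grid.resolution)) {
    grid.resolution <- min(51, length(unique(x)))
  }
  seq(from = min(x), to = max(x), length.out = grid.resolution)
}

length(default_grid(1:100))       # 100 unique values, capped at 51 points
default_grid(c(2, 2, 4, 6))       # only 3 unique values, so 3 points: 2 4 6
length(default_grid(1:100, 10))   # an explicit resolution overrides the default
```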

type

Character string specifying the type of supervised learning. Current options are "regression" and "classification". For some objects (e.g., tree-based models like "rpart"), partial can usually extract this information from object, so type need not be supplied.

which.class

Integer specifying which column of the matrix of predicted probabilities to use as the "focus" class. Default is to use the first class. Only used for classification problems (i.e., when type = "classification").
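The role of which.class can be illustrated with a simplified sketch. It assumes partial dependence is taken as the average predicted probability of the focus class, i.e., column which.class of the n x K probability matrix; the package's exact scale for classification is not documented on this page. The helper `prob_matrix` is hypothetical and builds a two-column probability matrix from a binary logistic regression.

```r
# Hypothetical helper: n x 2 matrix of class probabilities from a logistic model.
prob_matrix <- function(fit, newdata) {
  p1 <- predict(fit, newdata = newdata, type = "response")
  cbind("0" = 1 - p1, "1" = p1)
}

fit <- glm(am ~ wt, data = mtcars, family = binomial)
which.class <- 2L        # focus on the second column, am == 1 (manual transmission)
wt_grid <- c(2, 3, 4)

pd <- vapply(wt_grid, function(w) {
  tmp <- mtcars
  tmp$wt <- w                                    # fix wt at the grid value
  mean(prob_matrix(fit, tmp)[, which.class])     # average focus-class probability
}, numeric(1))
pd
```

Choosing which.class = 1 here would give the complementary curve, since the two columns sum to one.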

plot

Logical indicating whether to return a data frame containing the partial dependence values (FALSE) or plot the partial dependence function directly (TRUE). Default is FALSE.

smooth

Logical indicating whether or not to overlay a LOESS smooth. Default is FALSE.
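As a rough numerical picture of what a LOESS overlay adds, the sketch below fits base R's loess() to a synthetic partial dependence data frame and evaluates the smooth on the same grid. The data are made up for illustration, and the default loess() settings are an assumption, not necessarily what partial uses.

```r
set.seed(1)
# Synthetic, noisy partial dependence values on a 51-point grid
pd <- data.frame(x = seq(0, 10, length.out = 51))
pd$yhat <- 20 - 2 * sqrt(pd$x) + rnorm(51, sd = 0.3)

# Local regression of yhat on x, evaluated back on the grid
fit_lo <- loess(yhat ~ x, data = pd)
pd$smooth <- predict(fit_lo, newdata = pd)
head(pd)
```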

rug

Logical indicating whether or not to include rug marks on the predictor axes. Only used when plot = TRUE. Default is FALSE.

chull

Logical indicating whether or not to restrict the first two variables in pred.var to lie within the convex hull of their training values; grid points outside the hull are dropped from pred.grid. Default is FALSE.
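Restricting a two-variable grid to the convex hull avoids extrapolating into regions with no training data. The sketch below uses grDevices::chull() to find the hull vertices and a hypothetical ray-casting helper, `in_polygon`, to test grid points; this is one standard way to do the filtering, not necessarily how the package does it internally.

```r
# Hypothetical helper: ray-casting point-in-polygon test (vectorized over points).
in_polygon <- function(px, py, vx, vy) {
  inside <- rep(FALSE, length(px))
  j <- length(vx)
  for (i in seq_along(vx)) {
    # Does a horizontal ray to the left of each point cross edge (j, i)?
    crosses <- ((vy[i] > py) != (vy[j] > py)) &
      (px < (vx[j] - vx[i]) * (py - vy[i]) / (vy[j] - vy[i]) + vx[i])
    inside <- xor(inside, crosses)   # an odd number of crossings means inside
    j <- i
  }
  inside
}

set.seed(42)
train <- data.frame(x1 = runif(100), x2 = runif(100))
train$x2 <- train$x2 * train$x1                  # training data satisfy x2 < x1
hull <- grDevices::chull(train$x1, train$x2)     # indices of the hull vertices

grid <- expand.grid(x1 = seq(0, 1, length.out = 21), x2 = seq(0, 1, length.out = 21))
keep <- in_polygon(grid$x1, grid$x2, train$x1[hull], train$x2[hull])
grid_chull <- grid[keep, ]
nrow(grid_chull)   # fewer rows than the full 441-point grid
```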

train

An optional data frame containing the original training data. Whether it is needed depends on the class of object: it is required for objects that do not store a copy of the original training data.

check.class

Logical indicating whether or not to make sure each column in pred.grid has the correct class, levels, etc. Default is TRUE.
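Many predict() methods expect new data to have the same column types and factor levels as the training data. The sketch below shows the kind of coercion such a check involves, assuming it means matching each pred.grid column to the class and levels of the corresponding training column. The helper `match_classes` is hypothetical.

```r
# Hypothetical helper: coerce pred.grid columns to match the training data.
match_classes <- function(pred.grid, train) {
  for (v in names(pred.grid)) {
    if (is.factor(train[[v]])) {
      # Use the training factor levels, even if the grid omits some of them
      pred.grid[[v]] <- factor(pred.grid[[v]], levels = levels(train[[v]]))
    } else if (is.integer(train[[v]])) {
      pred.grid[[v]] <- as.integer(pred.grid[[v]])
    } else if (is.numeric(train[[v]])) {
      pred.grid[[v]] <- as.numeric(pred.grid[[v]])
    }
  }
  pred.grid
}

train <- data.frame(Species = iris$Species, Petal.Width = iris$Petal.Width)
grid <- data.frame(Species = c("setosa", "virginica"), Petal.Width = c(1L, 2L),
                   stringsAsFactors = FALSE)
fixed <- match_classes(grid, train)
str(fixed)   # Species is now a 3-level factor; Petal.Width is double
```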

References

J. H. Friedman. Greedy function approximation: A gradient boosting machine. Annals of Statistics, 29(5): 1189-1232, 2001.

Examples

## Not run: 
#
# Regression example (requires randomForest package to run)
#

# Fit a random forest to the Boston housing data
library(pdp)           # for partial, plotPartial, and the example data sets
library(randomForest)
data(boston)  # load the Boston housing data
set.seed(101)  # for reproducibility
boston.rf <- randomForest(cmedv ~ ., data = boston)

# Using randomForest's partialPlot function
partialPlot(boston.rf, pred.data = boston, x.var = "lstat")

# Using pdp's partial function
head(partial(boston.rf, pred.var = "lstat"))  # returns a data frame
partial(boston.rf, pred.var = "lstat", plot = TRUE, rug = TRUE)

# The partial function allows for multiple predictors
partial(boston.rf, pred.var = c("lstat", "rm"), grid.resolution = 40,
        plot = TRUE, chull = TRUE, .progress = "text")

# The plotPartial function offers more flexible plotting
pd <- partial(boston.rf, pred.var = c("lstat", "rm"), grid.resolution = 40)
plotPartial(pd)  # the default
plotPartial(pd, levelplot = FALSE, zlab = "cmedv", drape = TRUE,
            colorkey = FALSE, screen = list(z = -20, x = -60))

#
# Classification example (requires randomForest package to run)
#

# Fit a random forest to the Pima Indians diabetes data
data(pima)  # load the Pima Indians diabetes data
set.seed(102)  # for reproducibility
pima.rf <- randomForest(diabetes ~ ., data = pima, na.action = na.omit)

# Partial dependence of the diabetes test result (neg/pos) on glucose and age
partial(pima.rf, pred.var = c("glucose", "age"), plot = TRUE, chull = TRUE,
        .progress = "text")

#
# Interface with caret (requires caret package to run)
#

# Load required packages
library(caret)  # for model training/tuning

# Set up for 5-fold cross-validation
ctrl <- trainControl(method = "cv", number = 5, verboseIter = TRUE)

# Tune a support vector machine (SVM) with a radial basis function kernel on
# the Pima Indians diabetes data
set.seed(103)  # for reproducibility
pima.svm <- train(diabetes ~ ., data = pima, method = "svmRadial",
                  prob.model = TRUE, na.action = na.omit, trControl = ctrl,
                  tuneLength = 10)

# Partial dependence of the diabetes test result (neg/pos) on glucose
partial(pima.svm, pred.var = "glucose", plot = TRUE, rug = TRUE)

## End(Not run)