This package computes the partial dependence of covariates, permutation importance, and observation-level distances between data points, given a random forest fitted with one of the following packages (supported outcome types in parentheses): party (multivariate, regression, and classification), randomForestSRC (regression and classification), randomForest (regression and classification), and ranger (regression and classification).
partial_dependence outputs a data.frame of class pd. By default, all of the supported random forest classifiers return probabilities. The vars argument is a character vector naming the features for which the partial dependence is desired.
library(edarf)
data(iris)
library(party)
fit <- cforest(Species ~ ., iris, controls = cforest_unbiased(mtry = 2))
pd <- partial_dependence(fit, vars = "Petal.Width", n = c(10, 25))
print(pd)
n is a numeric vector of length two. Its first element controls the size of the grid on which the columns specified by the vars argument are evaluated, and its second element gives the number of rows of the other variables to sample when marginalizing the prediction function. The actual computation is performed by the marginalPrediction function in the mmpf package. Additional arguments can be passed to this function by name.
The output of partial_dependence can be visualized using plot_pd, which uses ggplot2 to create simple visualizations. In this simple case, that is a line plot for each class.
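To make the roles of the two elements of n concrete, here is a minimal base R sketch of the grid-and-marginalize idea. It is not the mmpf implementation: a linear regression stands in for the random forest, and the helper pd_sketch and its arguments are hypothetical names used only for illustration.

```r
# Hypothetical helper (not part of edarf or mmpf) illustrating the idea.
pd_sketch <- function(predict_fun, data, var, n) {
  # n[1]: number of grid points for the variable of interest
  grid <- seq(min(data[[var]]), max(data[[var]]), length.out = n[1])
  # n[2]: number of rows of the other variables to sample
  rows <- data[sample(nrow(data), min(n[2], nrow(data))), , drop = FALSE]
  avg <- vapply(grid, function(value) {
    rows[[var]] <- value        # hold the variable fixed at the grid value
    mean(predict_fun(rows))     # average predictions over the sampled rows
  }, numeric(1))
  out <- data.frame(grid, avg)
  names(out) <- c(var, "prediction")
  out
}

set.seed(1)
# A linear model is used here only so the sketch needs no extra packages.
fit_lm <- lm(Sepal.Length ~ Petal.Width + Petal.Length, data = iris)
pd_lm <- pd_sketch(function(d) predict(fit_lm, newdata = d), iris,
                   var = "Petal.Width", n = c(10, 25))
print(pd_lm)
```

Because the stand-in model is linear, the resulting curve is a straight line whose slope equals the Petal.Width coefficient; with a random forest the same procedure would generally trace a non-linear curve.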
partial_dependence can also return either interactions (the partial dependence on unique combinations of a subset of the covariates) or a list of bivariate partial dependence estimates (when multiple covariates are specified in vars and interaction = FALSE).
pd_list <- partial_dependence(fit, c("Petal.Width", "Petal.Length"), n = c(10, 25), interaction = FALSE)
plot_pd(pd_list)
pd_int <- partial_dependence(fit, c("Petal.Length", "Petal.Width"), n = c(10, 25), interaction = TRUE)
plot_pd(pd_int)
When an interaction is computed, the facet argument is used to construct plots wherein the feature not set as the facet is shown conditional on a particular value of the facet feature. This works best when the facet feature is an integer or a factor.
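The practical difference between the two modes is the set of points at which predictions are evaluated. The base R sketch below only builds those evaluation grids. It assumes that an interaction uses every combination of the per-variable grid values, which may not match the internals of edarf or mmpf exactly.

```r
# Illustration of the evaluation grids only; not edarf internals.
grid_pw <- seq(min(iris$Petal.Width), max(iris$Petal.Width), length.out = 10)
grid_pl <- seq(min(iris$Petal.Length), max(iris$Petal.Length), length.out = 10)

# Without an interaction: one grid per variable, each evaluated separately.
separate <- list(Petal.Width = grid_pw, Petal.Length = grid_pl)

# With an interaction (assumed here): every combination of the two grids.
joint <- expand.grid(Petal.Width = grid_pw, Petal.Length = grid_pl)

print(lengths(separate))
print(dim(joint))
```

Under this assumption the joint grid grows multiplicatively with the number of variables, so a smaller grid size may be preferable when several covariates interact.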
variable_importance computes the permutation importance of covariates. The covariate or covariates specified by the vars parameter are randomly shuffled, and predictions are computed. The aggregate permutation importance is the mean difference between the unpermuted predictions and the predictions made when the covariates specified in vars are permuted. If unspecified, vars is set to all of the covariates used in the fitted model.
imp <- variable_importance(fit, nperm = 10)
plot_imp(imp)
The joint permutation importance of multiple covariates can be computed by setting interaction = TRUE.
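The shuffle-and-compare procedure can be sketched in base R as follows. This is not the mmpf implementation. The sketch assumes a regression setting, measures importance as the increase in mean squared error after shuffling (one common choice of loss and contrast), and treats joint importance as shuffling all listed covariates with the same row permutation. The helper perm_importance_sketch is a hypothetical name.

```r
# Hypothetical helper (not part of edarf or mmpf).
perm_importance_sketch <- function(fit, data, target, vars, nperm = 10) {
  # Assumed loss: mean squared error against the observed outcome.
  mse <- function(d) mean((data[[target]] - predict(fit, newdata = d))^2)
  baseline <- mse(data)
  increases <- replicate(nperm, {
    shuffled <- data
    idx <- sample(nrow(data))   # one shuffle shared by all of vars
    for (v in vars) shuffled[[v]] <- data[[v]][idx]
    mse(shuffled) - baseline    # contrast: permuted loss minus baseline loss
  })
  mean(increases)
}

set.seed(1)
# A linear model is used here only so the sketch needs no extra packages.
fit_lm <- lm(Sepal.Length ~ Petal.Width + Petal.Length + Sepal.Width,
             data = iris)
covariates <- c("Petal.Width", "Petal.Length", "Sepal.Width")

# Importance of each covariate on its own.
individual <- sapply(covariates, function(v)
  perm_importance_sketch(fit_lm, iris, "Sepal.Length", v))

# Joint importance of two covariates shuffled together.
joint_imp <- perm_importance_sketch(fit_lm, iris, "Sepal.Length",
                                    c("Petal.Width", "Petal.Length"))
print(individual)
print(joint_imp)
```

Larger values indicate that predictions degrade more when the covariate's link to the outcome is broken.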
variable_importance relies on mmpf for computation. Additional arguments, which allow the computation of local or class-specific importance as well as importance under arbitrary loss and contrast functions, can be passed through by name.
extract_proximity extracts or computes a matrix that gives the co-occurrence of data points in the terminal nodes of the trees in the fitted random forest. This can be used to visualize the distance between data points in the model.
This matrix is too large to be visualized directly. plot_prox takes a principal components decomposition of this matrix, computed with prcomp, and plots it as a biplot. Additional arguments allow other covariates to be mapped onto the size of the points, their color (shown), or their shape.
fit <- cforest(Species ~ ., iris, controls = cforest_unbiased(mtry = 2))
prox <- extract_proximity(fit)
pca <- prcomp(prox, scale = TRUE)
plot_prox(pca, color = iris$Species, color_label = "Species", size = 2)
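For intuition about what a proximity matrix contains, the base R sketch below builds one from made-up terminal-node assignments and then decomposes it with prcomp. The node values are invented for illustration, and dividing by the number of trees is an assumed (though common) normalization; extract_proximity may scale its result differently.

```r
# Toy terminal-node assignments: rows are observations, columns are trees.
# The values are made up for illustration.
nodes <- cbind(tree1 = c(1, 1, 2, 2, 3),
               tree2 = c(1, 2, 2, 2, 1),
               tree3 = c(4, 4, 4, 5, 5))

# proximity[i, j]: share of trees in which observations i and j
# fall into the same terminal node.
proximity <- matrix(0, nrow(nodes), nrow(nodes))
for (k in seq_len(ncol(nodes))) {
  proximity <- proximity + outer(nodes[, k], nodes[, k], "==")
}
proximity <- proximity / ncol(nodes)
print(round(proximity, 2))

# Principal components of the proximity matrix.
pca_toy <- prcomp(proximity, scale. = TRUE)
print(dim(pca_toy$x))
```

The matrix has one row and one column per observation, which is why a low-dimensional projection is the practical way to look at it for real data sets.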
plot_pred plots predictions against observations and can additionally "tag" outliers and assign them labels.
fit <- cforest(Fertility ~ ., swiss)
pred <- as.numeric(predict(fit, newdata = swiss))
plot_pred(pred, swiss$Fertility,
  outlier_idx = which((pred - swiss$Fertility)^2 > var(swiss$Fertility)),
  labs = row.names(swiss))