#' Plot PD, ICE, and ALE curves
#'
#' Creates partial dependence (PD), individual conditional expectation (ICE),
#' combined PD + ICE, or accumulated local effects (ALE) plots based on
#' `iml::FeatureEffect`. Curves can optionally be centered at the minimum of
#' the feature. Plots for several features of one model are created at the
#' same time and arranged in a grid.
#'
#' @param pred A `Predictor` object from package iml.
#' @param features A character vector containing the names of the features
#'   for which plots should be created. Must contain at least one name.
#' @param method A character string indicating the method to be applied:
#'   one of "pdp", "pdp+ice", "ice", or "ale".
#' @param xlabel An optional character vector of x-axis labels, one per
#'   feature in the order of `features`. `NA` entries (or entries missing
#'   because `xlabel` is shorter than `features`) fall back to the feature
#'   name. Defaults to `features`.
#' @param ylabel An optional character string used as the y-axis label of
#'   every plot. Defaults to `method`.
#' @param title An optional character string used as the title of the
#'   combined plot.
#' @param center Logical indicating whether curves should be centered at the
#'   minimum of the feature (passed to `iml::FeatureEffect` as `center.at`).
#'   Default is `FALSE`.
#' @param limits An optional two-entry vector giving the limits of the y-axis
#'   (passed to `ggplot2::ylim()`); an `NA` entry keeps that limit automatic.
#' @param nCol An optional number of columns of the plot grid. Defaults to
#'   `ceiling(length(features) / 2)`.
#' @param alpha An optional numeric value giving the alpha level of ICE curves.
#' @param sample_prop An optional numeric value between 0 and 1 giving the
#'   proportion of individuals whose ICE curves are drawn (only used for
#'   "ice" and "pdp+ice").
#'
#' @return A ggplot object (created with `ggplotify::as.ggplot()`) containing
#'   the arranged feature plots.
#'
#' @examples
#' \dontrun{
#' N <- 1000
#' x1 <- runif(N, -1, 1)
#' x2 <- runif(N, -1, 1)
#' y <- 5 + 2 * x1 + x2 + rnorm(N, 1)
#' dat <- data.frame(x1, x2, y)
#' rfmod <- randomForest::randomForest(y ~ ., dat)
#' pred <- iml::Predictor$new(rfmod, data = dat, y = "y")
#' pdp <- pdp_ice_ale(pred, "x1", method = "pdp")
#' ice <- pdp_ice_ale(pred, "x1", method = "ice", alpha = .1)
#' pdp_ice <- pdp_ice_ale(pred, "x1", method = "pdp+ice", alpha = .2)
#' ale <- pdp_ice_ale(pred, c("x1", "x2"), method = "ale", ylabel = "ALE")
#'}
#'
#' @export
#'
pdp_ice_ale <- function(pred,
                        features,
                        method,
                        xlabel = features,
                        ylabel = method,
                        title = "",
                        center = FALSE,
                        limits = c(NA, NA),
                        nCol = NA,
                        alpha = .5,
                        sample_prop = .5) {
  # Validate inputs up front so problems surface with a clear message instead
  # of deep inside iml or ggplot2.
  valid_methods <- c("pdp", "pdp+ice", "ice", "ale")
  if (length(features) == 0) {
    stop("`features` must contain at least one feature name.", call. = FALSE)
  }
  if (length(method) != 1 || !method %in% valid_methods) {
    stop("`method` must be one of ",
         paste0("\"", valid_methods, "\"", collapse = ", "), ".",
         call. = FALSE)
  }
  with_ice <- method %in% c("ice", "pdp+ice")
  if (with_ice &&
      (!is.numeric(sample_prop) || length(sample_prop) != 1 ||
       is.na(sample_prop) || sample_prop < 0 || sample_prop > 1)) {
    stop("`sample_prop` must be a single number between 0 and 1.",
         call. = FALSE)
  }

  list_plot <- vector("list", length(features))
  for (nfeat in seq_along(features)) {
    feature <- features[nfeat]

    # Build plot data
    if (isTRUE(center)) {
      effect <- iml::FeatureEffect$new(
        pred,
        feature = feature,
        method = method,
        center.at = min(pred$data$X[[feature]], na.rm = TRUE)
      )
    } else {
      effect <- iml::FeatureEffect$new(pred, feature = feature, method = method)
    }
    plot_dat <- effect$results
    if (".class" %in% names(plot_dat)) {
      if (method == "ice") {
        stop("ICE plots are not yet implemented for categorical outcomes")
      }
      # NOTE(review): only rows whose .class equals 1 are kept; outcomes with
      # other class labels would leave no data here -- confirm intended.
      plot_dat <- plot_dat[plot_dat$.class == 1, ]
    }

    if (with_ice) {
      # Sample individuals (.id) rather than row indices so that exactly
      # round(sample_prop * n_ids) complete ICE curves are kept. PDP rows
      # have .id = NA and are always kept.
      ids <- unique(plot_dat$.id[!is.na(plot_dat$.id)])
      n_keep <- round(length(ids) * sample_prop)
      keep_ids <- ids[sample.int(length(ids), n_keep)]
      plot_dat <- plot_dat[is.na(plot_dat$.id) | plot_dat$.id %in% keep_ids, ]
      # Unnamed manual scale values are matched to the sorted .type levels:
      # "ice" first, then "pdp".
      temp_plot <- ggplot(plot_dat,
                          aes(x = .data[[feature]],
                              y = .data$.value,
                              group = .data$.id,
                              size = .data$.type,
                              color = .data$.type,
                              alpha = .data$.type)) +
        geom_line() +
        scale_size_manual(values = c(.5, 2)) +
        scale_color_manual(values = c("#2C738EFF", "#000000")) +
        scale_alpha_manual(values = c(alpha, 1))
    } else {
      temp_plot <- ggplot(plot_dat,
                          aes(x = .data[[feature]], y = .data$.value)) +
        geom_point() +
        geom_line()
    }

    # Add theme and labels. A missing (NA) x label falls back to the feature
    # name; a NULL `xlabel` removes the label.
    x_lab <- xlabel[nfeat]
    if (length(x_lab) == 1 && is.na(x_lab)) {
      x_lab <- feature
    }
    temp_plot <- temp_plot +
      theme_bw() +
      xlab(x_lab) +
      ylab(ylabel) +
      theme(legend.position = "none")
    if (!all(is.na(limits))) {
      temp_plot <- temp_plot + ylim(limits)
    }
    list_plot[[nfeat]] <- temp_plot
  }

  if (is.null(nCol) || is.na(nCol)) {
    nCol <- ceiling(length(features) / 2)
  }
  arranged <- do.call(gridExtra::grid.arrange, c(list_plot, ncol = nCol))
  ggplotify::as.ggplot(arranged) + ggtitle(title)
}
# NOTE: the following two lines are website embed boilerplate, not R code:
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.