Nothing
#' Plots response surfaces for tree-based models.
#'
#' @description Plots a response surface plot or interaction plot (2 predictors and a model response) for models of the functions \code{\link[ranger]{ranger}}, \code{\link[randomForest]{randomForest}}, and \code{\link[rpart]{rpart}}. It also plots the observed data on top of the predicted surface.
#'
#' @details The surface is predicted over a regular grid spanning the observed range of \code{x} and \code{y} (missing values are ignored when computing the ranges). Every other column of \code{data} is held constant across the grid: factor and character columns are set to their most frequent value, all remaining columns to their mean (computed with \code{na.rm = TRUE}). Predictions are obtained with \code{predict(model, newdata)$predictions} for objects inheriting from \code{"ranger"}, with \code{type = "vector"} for objects inheriting from \code{"rpart"}, and with \code{type = "response"} for any other model. When the grid yields fewer than 15 distinct predicted values, numeric predictions are rounded to integers and the surface is drawn with a discrete fill scale; otherwise a continuous gradient is used.
#'
#' @usage plotInteraction(
#'   model = NULL,
#'   data = NULL,
#'   x = NULL,
#'   y = NULL,
#'   z = NULL,
#'   grid = 100,
#'   point.size.range = c(0.1, 1)
#' )
#'
#' @param model a model object produced by the functions \code{\link[ranger]{ranger}}, \code{\link[randomForest]{randomForest}}, or \code{\link[rpart]{rpart}}.
#' @param data dataframe used to fit the model.
#' @param x character string, name of a numeric column in \code{data} to be plotted in the x axis.
#' @param y character string, name of a numeric column in \code{data} to be plotted in the y axis.
#' @param z character string, name of column in \code{data} whose observed values set the color, size, and transparency of the points drawn over the surface. Generally, the name of the response variable in \code{model}.
#' @param grid numeric, resolution of the x and y axes (number of grid cells along each axis).
#' @param point.size.range numeric vector with two values defining the range size of the points representing the observed data.
#'
#' @author Blas M. Benito <blasbenito@gmail.com>
#'
#' @return A ggplot object.
#'
#'@export
plotInteraction <- function(model = NULL,
                            data = NULL,
                            x = NULL,
                            y = NULL,
                            z = NULL,
                            grid = 100,
                            point.size.range = c(0.1, 1)
){

  # ---- argument checks: fail early with readable messages ----
  if (is.null(model)) {
    stop("Argument 'model' is required.", call. = FALSE)
  }
  if (!is.data.frame(data)) {
    stop("Argument 'data' must be a data frame.", call. = FALSE)
  }
  axes <- list(x = x, y = y)
  for (arg in names(axes)) {
    column <- axes[[arg]]
    if (!is.character(column) || length(column) != 1L ||
        !(column %in% names(data))) {
      stop("Argument '", arg, "' must be the name of a column in 'data'.",
           call. = FALSE)
    }
    if (!is.numeric(data[[column]])) {
      stop("Column '", column, "' must be numeric to be used as an axis.",
           call. = FALSE)
    }
  }
  if (!is.null(z) &&
      (!is.character(z) || length(z) != 1L || !(z %in% names(data)))) {
    stop("Argument 'z' must be the name of a column in 'data'.", call. = FALSE)
  }
  if (!is.numeric(grid) || length(grid) != 1L || is.na(grid) || grid < 1) {
    stop("Argument 'grid' must be a single number >= 1.", call. = FALSE)
  }

  # from https://stackoverflow.com/a/49167211
  # lightens a vector of colors (lower saturation, higher value) so the
  # continuous fill scale looks washed-out underneath the observed points
  vir_lite <- function(cols, ds = 0.4, dv = 0.7) {
    cols <- rgb2hsv(col2rgb(cols))
    cols["v", ] <- cols["v", ] + dv * (1 - cols["v", ])
    cols["s", ] <- ds * cols["s", ]
    apply(cols, 2, function(col) hsv(col[1], col[2], col[3]))
  }

  # single representative value used to hold a non-plotted column constant:
  # most frequent value for factor/character columns, mean for everything else
  typical_value <- function(v) {
    if (is.factor(v) || is.character(v)) {
      counts <- table(v)
      top <- names(counts)[which.max(counts)]
      if (is.factor(v)) {
        top <- factor(top, levels = levels(v), ordered = is.ordered(v))
      }
      return(top)
    }
    mean(v, na.rm = TRUE)
  }

  # generating grid over the observed range of both axes
  x_values <- seq(
    min(data[[x]], na.rm = TRUE),
    max(data[[x]], na.rm = TRUE),
    length.out = grid
  )
  y_values <- seq(
    min(data[[y]], na.rm = TRUE),
    max(data[[y]], na.rm = TRUE),
    length.out = grid
  )
  newdata <- expand.grid(x_values, y_values)
  colnames(newdata) <- c(x, y)

  # holding every other column (the response included) constant, so the
  # predicted surface only reflects the effect of x and y
  for (column in setdiff(names(data), c(x, y))) {
    newdata[[column]] <- rep(
      typical_value(data[[column]]),
      length.out = nrow(newdata)
    )
  }

  # predicting different types of models; inherits() copes with objects
  # carrying several classes (e.g. "randomForest.formula" + "randomForest")
  if (inherits(model, "ranger")) {
    newdata$prediction <- predict(model, newdata)$predictions
  } else if (inherits(model, "rpart")) {
    newdata$prediction <- predict(model, newdata, type = "vector")
  } else {
    newdata$prediction <- predict(model, newdata, type = "response")
  }

  # fewer than 15 unique predictions: treat the surface as discrete classes
  is_discrete <- length(unique(newdata$prediction)) < 15

  if (is_discrete) {
    # rounding only makes sense for numeric predictions (factors are kept)
    if (is.numeric(newdata$prediction)) {
      newdata$prediction <- round(newdata$prediction, 0)
    }
    newdata$prediction <- as.factor(newdata$prediction)
    p1 <- ggplot(newdata, aes_string(x = x, y = y)) +
      geom_raster(aes(fill = prediction), alpha = 0.3) +
      scale_fill_viridis(discrete = TRUE) +
      guides(fill = guide_legend(override.aes = list(alpha = 0.2)))
  } else {
    p1 <- ggplot(newdata, aes_string(x = x, y = y)) +
      geom_raster(aes(fill = prediction), alpha = 0.8) +
      scale_fill_gradientn(colors = vir_lite(viridis(10)))
  }

  # observed data on top of the surface; identical for both surface types
  p1 <- p1 +
    geom_point(
      data = data,
      aes_string(x = x, y = y, color = z, size = z, alpha = z),
      shape = 16
    ) +
    scale_color_gradient(low = "white", high = "black", guide = "none") +
    scale_size_continuous(range = point.size.range) +
    scale_alpha_continuous(guide = "none") +
    labs(fill = "Predicted", size = "Observed")

  p1
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.