# tools/pdp-logo-img.R

# Load required packages
library(earth)
library(ggplot2)
library(grid)
library(pdp)
library(png)
library(randomForest)

# Load the `boston` housing data set (presumably the copy shipped with pdp,
# which is attached above) and fit two models predicting cmedv from all other
# columns: a random forest, whose partial dependence drives the logo, and a
# degree-2 MARS model, whose coefficients are printed for reference.
data("boston")
set.seed(101)  # randomForest and earth's cross-validation both draw from the RNG
boston.rf <- randomForest(cmedv ~ ., data = boston)
boston.earth <- earth(
  cmedv ~ .,
  data = boston,
  degree = 2,
  pmethod = "exhaustive",
  nfold = 5,
  ncross = 5
)

# MARS model coefficients (auto-printed at top level)
coef(boston.earth)

# Function to rescale vector to be between a and b
#
# Linearly maps `x` so that min(x) becomes `a` and max(x) becomes `b`.
#
# Args:
#   x: numeric vector. NA values are ignored when computing the range and
#     stay NA in the result.
#   a: value that min(x) is mapped to.
#   b: value that max(x) is mapped to.
#
# Returns:
#   A numeric vector the same length as `x`. If all non-missing values of `x`
#   are equal (zero range), they are mapped to the midpoint (a + b) / 2
#   instead of the NaN that 0 / 0 would give. An empty or all-NA `x` is
#   returned unchanged.
rescale <- function(x, a, b) {
  if (all(is.na(x))) {
    return(x)
  }
  rng <- range(x, na.rm = TRUE)
  width <- rng[2L] - rng[1L]
  # isTRUE() guards against NaN widths (e.g. x containing both +Inf values)
  if (isTRUE(width == 0)) {
    out <- rep((a + b) / 2, length(x))
    out[is.na(x)] <- NA
    return(out)
  }
  (x - rng[1L]) / width * (b - a) + a
}

# Partial dependence of cmedv on lstat and rm, using (up to) 100 grid points
# per predictor. chull = FALSE leaves the grid unrestricted by the convex hull
# of the training data, so it covers the full rectangle that is clipped to
# the hexagon below.
pd <- partial(boston.rf, pred.var = c("lstat", "rm"), chull = FALSE,
              progress = "text", grid.resolution = 100)

# Boundaries of the hexagon: a regular pointy-top hexagon centred at the
# origin with circumradius `hex_radius`, listed as a closed polygon (the first
# vertex is repeated at the end).
hex_radius <- 1.35
hex <- data.frame(
  x = hex_radius * c(-sqrt(3) / 2, 0, rep(sqrt(3) / 2, 2), 0,
                     rep(-sqrt(3) / 2, 2)),
  y = hex_radius * c(0.5, 1, 0.5, -0.5, -1, -0.5, 0.5)
)

# Restrict PDP to the boundaries of the hexagon: stretch the grid onto the
# hexagon's bounding box, then keep only the grid points inside the polygon.
# The predictor columns are selected by name, not position, so the clipping
# does not depend on the column order returned by partial().
pd_hex <- pd
pd_hex$lstat <- rescale(pd_hex$lstat, a = min(hex$x), b = max(hex$x))
pd_hex$rm <- rescale(pd_hex$rm, a = min(hex$y), b = max(hex$y))
pd_hex <- pd_hex[mgcv::in.out(as.matrix(hex),
                              as.matrix(pd_hex[, c("lstat", "rm")])), ]

# Hexagon logo
#
# Build the pdp hex sticker as a ggplot object. The sticker is a dark-grey
# hexagon filled with the hexagon-clipped partial dependence surface, overlaid
# with black contour lines, a thick black border, and the white label "pdp".
#
# Args:
#   option: viridis palette option passed to viridis::scale_fill_viridis()
#     (this script uses the letters "A" to "E").
#
# Returns:
#   A ggplot object. Nothing is drawn until it is printed.
#
# Reads the globals `pd_hex` (columns lstat, rm, yhat) and `hex` (columns x,
# y) defined earlier in this script.
make_pdp_sticker <- function(option) {
  ggplot(pd_hex, aes(lstat, rm)) +
    # Solid background hexagon: any part of the hexagon not covered by tiles
    # shows as dark grey.
    # NOTE(review): ggplot2 >= 3.4.0 prefers `linewidth` over `size` for
    # polygon borders; `size` is kept for compatibility with older versions.
    geom_polygon(data = hex, aes(x, y), color = "black", fill = grey(0.25),
                 size = 3) +
    # geom_tile() has no `z` aesthetic (ggplot2 drops it with an "Ignoring
    # unknown aesthetics: z" warning), so only fill is mapped to yhat.
    geom_tile(data = pd_hex, aes(x = lstat, y = rm, fill = yhat)) +
    # Alternative palette: scale_fill_distiller(name = "yhat",
    #                                           palette = "Spectral")
    viridis::scale_fill_viridis(option = option) +
    # Draw the border again on top of the tiles so they cannot hide it.
    geom_polygon(data = hex, aes(x, y), color = "black", fill = "transparent",
                 size = 3) +
    geom_contour(aes(z = yhat), color = "black") +
    annotate(geom = "text", x = 0, y = -0.15, color = "white", size = 18,
             label = "pdp") +
    coord_equal(xlim = range(hex$x), ylim = range(hex$y)) +
    scale_x_continuous(expand = c(0.04, 0)) +
    scale_y_reverse(expand = c(0.04, 0)) +
    # Strip every axis, legend, and background element so only the sticker
    # artwork remains.
    theme(axis.line = element_blank(),
          axis.text.x = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks = element_blank(),
          axis.title.x = element_blank(),
          axis.title.y = element_blank(),
          legend.position = "none",
          plot.background = element_blank(),
          panel.background = element_blank(),
          panel.border = element_blank(),
          panel.grid.major = element_blank(),
          panel.grid.minor = element_blank())
}

# Plot range of different logos: render the sticker with each viridis palette
# option "A" to "E" and arrange them on a 3-column grid in a transparent PNG.
logos <- lapply(LETTERS[1L:5L], make_pdp_sticker)
png("tools/pdp-logos.png", width = 900, height = 500, bg = "transparent",
    type = "cairo-png")
# grid.arrange() comes from gridExtra, not from the attached `grid` package,
# so it is namespace-qualified here. The device is closed in `finally` so
# that a drawing error cannot leave it open and capture later plots.
tryCatch(
  gridExtra::grid.arrange(grobs = logos, ncol = 3),
  finally = dev.off()
)

# Print hexagon logo: build the final sticker with viridis palette option "C"
# and draw it on the current graphics device.
pdp_logo <- make_pdp_sticker(option = "C")
print(pdp_logo)

# Save the sticker at 181 x 209 px. The height/width ratio is about
# 2 / sqrt(3), the aspect ratio of a pointy-top hexagon. It is written as a
# transparent PNG and as an SVG of the same size; svg() takes inches, so the
# pixel counts are divided by 72. Each device is opened first and then closed
# in `finally`, so a rendering error cannot leave it open and redirect later
# plots into the file.
png("tools/pdp-logo.png", width = 181, height = 209, bg = "transparent",
    type = "cairo-png")
tryCatch(print(pdp_logo), finally = dev.off())

svg("tools/pdp-logo.svg", width = 181 / 72, height = 209 / 72,
    bg = "transparent")
tryCatch(print(pdp_logo), finally = dev.off())
# bgreenwell/pdp documentation built on Aug. 15, 2018, 10:17 p.m.