#' @title Multispecies partial plots (in development)
#'
#' @description
#'
#' Partial dependence plots show the response curves of an individual variable in the sum-of-trees models. The main line is the posterior median of the partial dependence curves from each posterior draw of the sum-of-trees model; each of those curves is generated by evaluating the BART model prediction at each specified x value for *each other combination of other x values in the data*. This is computationally very expensive, and gets slower the more smoothing you add, the more variables you ask for, and the more posterior draws you keep (ndpost; defaults to 1000) in the bart() function.
#'
#' This is a somewhat altered and simplified version of partial() that allows you to combine different species' responses to the same variables on the same axes. Run the models separately, then use this function to combine them.
#'
#' @param modl A list of fitted dbarts models, one per species (fitted with `keeptrees = TRUE`, as in the examples, so that `pdbart()` can reuse them).
#' @param spnames A character vector of species names, one per element of `modl` and in the same order. Used as the legend labels.
#' @param x.vars A vector of the variables for which you want to run the partials. Defaults to doing all of them.
#' @param smooth A multiplier for how finely continuous variables are sampled: each is evaluated on an even grid from its minimum to its maximum with step (max - min) / (10 * smooth). High values, like 10 or over, are much slower and don't add much.
#' @param trace If TRUE, also draw a faint line for each individual draw from the posterior.
#' @param transform If TRUE, apply `pnorm()` to convert the probit-scale output of the dbarts models to 0 to 1 probabilities. I wouldn't turn this off for binary responses unless you're really interested in a deep dive on the model.
#' @param panels If TRUE, combine the plots for multiple variables into one multipanel figure with a shared legend.
#'
#'
#' @return If `panels = FALSE`, a list of ggplot objects (one per variable); if `panels = TRUE`, a patchwork object combining them.
#'
#' @examples
#' set.seed(99)
#' f <- function(x) { 0.5 * x[, 1] + 3 * x[, 2] * x[, 3] - 5 * x[, 4] }
#' g <- function(x) { 0.6 * (x[, 1] - 0.3) + 1 * x[, 2] * x[, 3] - 5 * x[, 4] }
#' sigma <- 0.2
#' n <- 100
#' x <- matrix(2 * runif(n * 3) - 1, ncol = 3)
#' x <- data.frame(x)
#' x[, 4] <- rbinom(n, 1, 0.3)
#' colnames(x) <- c('rob', 'hugh', 'ed', 'phil')
#' Ey <- f(x)
#' y <- rnorm(n, Ey, sigma)
#' Ez <- g(x)
#' z <- rnorm(n, Ez, sigma)
#' df <- data.frame(z, y, x)
#'
#' bartFit1 <- bart(y ~ rob + hugh + ed + phil, df,
#' keepevery = 10, ntree = 100, keeptrees = TRUE)
#'
#' bartFit2 <- bart(z ~ rob + hugh + ed + phil, df,
#' keepevery = 10, ntree = 100, keeptrees = TRUE)
#'
#' model.list <- list(bartFit1, bartFit2)
#'
#' # y and z are continuous, so the probit transform is switched off here
#' multipartial(modl = model.list, spnames = c('Y','Z'),
#' x.vars = c('rob','hugh'),
#' smooth = 7, trace = TRUE, transform = FALSE, panels = TRUE)
#'
#' @export
#'
multipartial <- function(modl, spnames, x.vars = NULL,
                         smooth = 7, trace = TRUE,
                         transform = TRUE, panels = FALSE) {
  # Sanity checks on the arguments, before the expensive pdbart() calls.
  if (length(spnames) != length(modl)) {
    stop("`spnames` must contain exactly one name per model in `modl`.",
         call. = FALSE)
  }
  if (smooth > 10) {
    warning("You have chosen way, way too much smoothing... poorly")
  }
  if (!is.null(x.vars) && length(x.vars) == 1 && isTRUE(panels)) {
    stop("Hey bud, you can't do several panels on only one variable!")
  }
  species <- as.character(spnames)

  # Values at which one predictor is evaluated. Predictors with at most two
  # distinct values (binary, or constant) use their observed values; all
  # others get an even grid from min to max with step range / (10 * smooth).
  # Handling constants here avoids seq() erroring on a zero-width range.
  level_grid <- function(v) {
    observed <- sort(unique(v))
    if (length(observed) <= 2) {
      return(observed)
    }
    lo <- min(v)
    hi <- max(v)
    seq(lo, hi, by = (hi - lo) / (10 * smooth))
  }

  # Run pdbart() for one fitted model over the requested predictors.
  pd_get <- function(model) {
    raw <- model$fit$data@x
    if (!is.null(x.vars)) {
      # drop = FALSE keeps a one-column matrix when one variable is requested
      raw <- raw[, x.vars, drop = FALSE]
    }
    levs <- lapply(seq_len(ncol(raw)), function(k) level_grid(raw[, k]))
    pdbart(model, xind = x.vars, levs = levs, pl = FALSE)
  }
  pdl <- lapply(modl, pd_get)

  # Posterior-median curve of variable i for one species, in long format.
  # Each model's own levels are used, so models need not share data ranges.
  median_curve <- function(pd, i, sp) {
    value <- apply(pd$fd[[i]], 2, median)
    if (isTRUE(transform)) {
      value <- pnorm(value)
    }
    data.frame(x = pd$levs[[i]], Species = sp, value = value,
               stringsAsFactors = FALSE)
  }

  # One curve per posterior draw of variable i for one species, in long
  # format. pd$fd[[i]] has one row per draw and one column per level (as the
  # median over columns above relies on), so as.vector() walks the draws
  # within each level in turn.
  trace_curves <- function(pd, i, sp) {
    draws <- pd$fd[[i]]
    value <- as.vector(draws)
    if (isTRUE(transform)) {
      value <- pnorm(value)
    }
    data.frame(x = rep(pd$levs[[i]], each = nrow(draws)),
               Species = sp,
               iter = rep(seq_len(nrow(draws)), times = ncol(draws)),
               value = value,
               stringsAsFactors = FALSE)
  }

  # Apply one of the curve builders to every species and stack the results.
  stack_species <- function(curve_fun, i) {
    per_species <- Map(function(pd, sp) curve_fun(pd, i, sp), pdl, species)
    do.call(rbind, unname(per_species))
  }

  # Build the ggplot for variable i: optional faint traces, then the median.
  build_panel <- function(i) {
    g <- ggplot(stack_species(median_curve, i),
                aes(x = x, y = value, colour = Species)) +
      labs(title = pdl[[1]]$xlbs[[i]], y = '', x = '') +
      theme_light(base_size = 16) +
      theme(plot.title = element_text(hjust = 0.5),
            axis.title.y = element_text(vjust = 1.7))
    if (isTRUE(trace)) {
      # A single layer grouped by species x draw instead of one layer per
      # draw; the median column is no longer re-drawn as a trace.
      g <- g + geom_line(data = stack_species(trace_curves, i),
                         aes(group = interaction(Species, iter)),
                         alpha = 0.025, show.legend = FALSE)
    }
    # `size` (not `linewidth`, ggplot2 >= 3.4.0) keeps older ggplot2 working.
    g <- g + geom_line(size = 1.25)
    margin_cm <- if (isTRUE(panels)) 0.15 else 0.5
    g + theme(plot.margin = unit(rep(margin_cm, 4), "cm"))
  }

  plots <- lapply(seq_along(pdl[[1]]$fd), build_panel)
  if (isTRUE(panels)) {
    return(wrap_plots(plotlist = plots, guides = 'collect'))
  }
  plots
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.