For users who are comfortable writing their own posterior samplers, `dbarts` makes it easy to incorporate a BART component into a Gibbs sampler. In the example below, we fit a functional mixture model in which each observation may have come from one of two underlying functions and class membership is unobserved.
Performing inference conditional on $x$, we assume that:
$$Y\mid Z = z \sim \mathrm{N}\left(f(x, z), \sigma^2 \right),$$ $$Z \sim \mathrm{Bern}(p),$$
and that $Z$ is unobserved. To make the model fully Bayesian, we impose a $\mathrm{Beta}(1, 1)$ prior on $p$. $f$ and $\sigma^2$ have implicit priors defined by BART.
To implement a Gibbs sampler, we will need to produce updates for $Z$ and $p$. Writing down the joint distribution of $Z$, $Y$, $p$, $f$, and $\sigma^2$ and isolating the terms that involve $Z$, we have:

$$P(Z = 1 \mid Y = y, p, f, \sigma^2) = \frac{p\,\phi\left(y; f(x, 1), \sigma^2\right)}{p\,\phi\left(y; f(x, 1), \sigma^2\right) + (1 - p)\,\phi\left(y; f(x, 0), \sigma^2\right)},$$

where $\phi$ is the density of a normal distribution. Similarly, since the $\mathrm{Beta}(1, 1)$ prior on $p$ is conjugate, $p \mid z \sim \mathrm{Beta}(1 + n_1, 1 + n_0)$, where $n_1 = \sum_i z_i$ and $n_0 = n - n_1$. This yields an overall strategy of:

1. drawing $f$ and $\sigma^2$ conditional on $z$ using a single iteration of the BART sampler,
2. drawing each $z_i$ from its conditional distribution given $y_i$, $f$, $\sigma^2$, and $p$, and
3. drawing $p$ from its Beta conditional given $z$.
To illustrate, we create some toy data. The simulated data is visualized later, together with the fitted model.
```r
# Underlying functions
f0 <- function(x) 90 + exp(0.06 * x)
f1 <- function(x) 72 + 3 * sqrt(x)

set.seed(2793)

# Generate true values.
n <- 120
p.0 <- 0.5
z.0 <- rbinom(n, 1, p.0)
n1 <- sum(z.0); n0 <- n - n1

# In order to make the problem more interesting, x is confounded with both
# y and z.
x <- numeric(n)
x[z.0 == 0] <- rnorm(n0, 20, 10)
x[z.0 == 1] <- rnorm(n1, 40, 10)

y <- numeric(n)
y[z.0 == 0] <- f0(x[z.0 == 0])
y[z.0 == 1] <- f1(x[z.0 == 1])
y <- y + rnorm(n)

data_train <- data.frame(y = y, x = x, z = z.0)
data_test  <- data.frame(x = x, z = 1 - z.0)
```
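Before fitting, it can help to glance at the simulated data; the checks below are purely illustrative.

```r
# Quick sanity checks on the simulated data.
table(z.0)          # class balance of the true labels
summary(data_train) # ranges of y, x, and z
```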
For this specific model, the sampler needs to be updated with new predictor values as it runs. This can be done on a `dbartsSampler` reference class object by calling its `sampler$setPredictor` function. Since, given any $Z = z$, the sampler also needs to evaluate $f(x, 1 - z)$, we utilize the test slot of the sampler and keep it up to date with the "counterfactual" predictor using the `sampler$setTestPredictor` function. Models that instead modify the response variable can use the `sampler$setResponse` or `sampler$setOffset` functions.
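For concreteness, the update calls have the following shape. This is a minimal sketch; `sampler` and `z` are assumed to be set up as in the full example below.

```r
# Minimal sketch of the updating API; `sampler` is assumed to have been
# created by dbarts(y ~ x + z, ...), with z occupying column 2 of the
# predictor matrix, and `z` is assumed to be a 0/1 vector of length n.
sampler$setPredictor(x = z, column = 2)          # replace the z column
sampler$setTestPredictor(x = 1 - z, column = 2)  # keep the counterfactual current

# Models that modify the response instead would call, for example:
# sampler$setResponse(y.new)   # replace the response vector
# sampler$setOffset(offset)    # supply an additive offset to the response
```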
One complication in updating the predictor matrix while the sampler runs is that new values of $z$ must leave the sampler in an internally consistent state; in particular, no leaf node of any tree may end up empty. During warmup this constraint is ignored by passing the `forceUpdate` argument to `setPredictor`; after warmup is complete, rejection sampling is used instead, by calling `setPredictor` without forcing and checking that its logical return value is `TRUE`.
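In isolation, the post-warmup rejection step looks like the following sketch, assuming `sampler`, `n`, and the vector of conditional probabilities `p.z` from the full example below.

```r
# Post-warmup rejection step: setPredictor() returns FALSE when the proposed
# z column would leave a leaf node empty, in which case z is redrawn.
z <- rbinom(n, 1, p.z)
while (sampler$setPredictor(x = z, column = 2) == FALSE)
  z <- rbinom(n, 1, p.z)
```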
```r
n_warmup  <- 100
n_samples <- 500
n_total   <- n_warmup + n_samples

# Allocate storage for results.
samples_p   <- rep.int(NA_real_, n_samples)
samples_z   <- matrix(NA_real_, n_samples, n)
samples_mu0 <- matrix(NA_real_, n_samples, n)
samples_mu1 <- matrix(NA_real_, n_samples, n)

library(dbarts, quietly = TRUE)

# We only need to draw one sample at a time, although for illustrative purposes
# a small degree of thinning is applied to the BART component.
control <- dbartsControl(updateState = FALSE, verbose = FALSE,
                         n.burn = 0L, n.samples = 1L, n.thin = 3L,
                         n.chains = 1L)
# We create the sampler with a z vector that contains at least one 1 and one 0,
# so that all of the cut points are set correctly.
sampler <- dbarts(y ~ x + z, data_train, data_test, control = control)

# Sample from the prior.
p <- rbeta(1, 1, 1)
z <- rbinom(n, 1, p)

# Ignore the result of this sampler call.
invisible(sampler$setPredictor(x = z, column = 2, forceUpdate = TRUE))
sampler$setTestPredictor(x = 1 - z, column = 2)
sampler$sampleTreesFromPrior()

for (i in seq_len(n_total)) {
  # Draw a single sample from the posterior of f and sigma^2.
  samples <- sampler$run()

  # Recover f(x, 0) and f(x, 1).
  mu0 <- ifelse(z == 0, samples$train[,1], samples$test[,1])
  mu1 <- ifelse(z == 1, samples$train[,1], samples$test[,1])

  # Draw z from its conditional distribution.
  p0 <- dnorm(y, mu0, samples$sigma[1]) * (1 - p)
  p1 <- dnorm(y, mu1, samples$sigma[1]) * p
  p.z <- p1 / (p0 + p1)
  z <- rbinom(n, 1, p.z)

  if (i <= n_warmup) {
    sampler$setPredictor(x = z, column = 2, forceUpdate = TRUE)
  } else {
    # Rejection step: redraw z until the update is accepted.
    while (sampler$setPredictor(x = z, column = 2) == FALSE)
      z <- rbinom(n, 1, p.z)
  }
  sampler$setTestPredictor(x = 1 - z, column = 2)

  # Conjugate Beta update for p = P(Z = 1): Beta(1 + n1, 1 + n0).
  n1 <- sum(z); n0 <- n - n1
  p <- rbeta(1, 1 + n1, 1 + n0)

  # Store samples if no longer warming up.
  if (i > n_warmup) {
    offset <- i - n_warmup
    samples_p[offset]    <- p
    samples_z[offset,]   <- z
    samples_mu0[offset,] <- mu0
    samples_mu1[offset,] <- mu1
  }
}
```
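Once the loop completes, posterior summaries follow directly from the stored draws; for example:

```r
# Posterior mean and 95% credible interval for the mixing weight p.
mean(samples_p)
quantile(samples_p, probs = c(0.025, 0.975))
```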
Saving the sampler: if it is desired to use `save` and `load` on the sampler, the reference class object must first be instructed to write its internal state out as an R object:
```r
sampler$storeState()
```
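A sketch of the full round trip follows; the file name here is illustrative.

```r
sampler$storeState()                  # copy internal sampler state into R fields
save(sampler, file = "sampler.RData")
# ...later, possibly in a fresh R session:
load("sampler.RData")
```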
Finally, we visualize the results. In the graph below, the estimated label of each observation encircles its true label, with both represented by point color. We see that, beyond the range of the observed data, the estimated functions regress towards each other and points begin to be mislabeled.
```r
mean_mu0 <- apply(samples_mu0, 2, mean)
mean_mu1 <- apply(samples_mu1, 2, mean)
ub_mu0 <- apply(samples_mu0, 2, quantile, 0.975)
lb_mu0 <- apply(samples_mu0, 2, quantile, 0.025)
ub_mu1 <- apply(samples_mu1, 2, quantile, 0.975)
lb_mu1 <- apply(samples_mu1, 2, quantile, 0.025)

curve(f0(x), 0, 80, ylim = c(80, 110), ylab = expression(f(x)),
      main = "Mixture Model")
curve(f1(x), add = TRUE)
points(x, y, pch = 20, col = ifelse(z.0 == 0, "black", "gray"))

lines(sort(x), mean_mu0[order(x)], col = "red")
lines(sort(x), mean_mu1[order(x)], col = "red")

# Add point-wise confidence intervals.
lines(sort(x), ub_mu0[order(x)], col = "gray", lty = 2)
lines(sort(x), lb_mu0[order(x)], col = "gray", lty = 2)
lines(sort(x), ub_mu1[order(x)], col = "gray", lty = 3)
lines(sort(x), lb_mu1[order(x)], col = "gray", lty = 3)

# Without constraining z for an observation to 0/1, its interpretation may be
# flipped from that which generated the data.
mean_z <- ifelse(apply(samples_z, 2, mean) <= 0.5, 0L, 1L)
if (mean(mean_z != z.0) > 0.5) mean_z <- 1 - mean_z

points(x, y, pch = 1, col = ifelse(mean_z == 0, "black", "gray"))

legend("topright",
       c("true func", "est func", "group 0", "group 1",
         "est group 0", "est group 1"),
       lty = c(1, 1, NA, NA, NA, NA),
       pch = c(NA, NA, 20, 20, 1, 1),
       col = c("black", "red", "black", "gray", "black", "gray"),
       cex = 0.8, box.col = "white", bg = "white")
```
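As a quick numeric companion to the plot, the agreement between the estimated and true labels can be checked directly, using `mean_z` from the block above:

```r
# Fraction of observations whose majority-vote posterior label matches the
# label that generated the data.
mean(mean_z == z.0)
```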
When implementing a Gibbs sampler using `dbarts` in R, it is most often the case that multiple threads will need to be handled by creating separate copies of the sampler and the data. While `dbartsSampler` objects are natively multithreaded, few of their slots are stored independently across chains. For an example of this approach, and a more complete implementation of the principles in this document, consult the implementation of `rbart_vi`.
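The following is a sketch of that separate-copies approach using the parallel package; the helper `run_chain` is illustrative, and for brevity it takes a single BART draw where the full Gibbs loop from above would go.

```r
library(parallel)

# Each worker receives its own copies of the data and builds its own sampler.
run_chain <- function(seed, data_train, data_test) {
  library(dbarts)
  set.seed(seed)
  control <- dbartsControl(updateState = FALSE, verbose = FALSE,
                           n.burn = 0L, n.samples = 1L, n.chains = 1L)
  sampler <- dbarts(y ~ x + z, data_train, data_test, control = control)
  # The full Gibbs loop would go here; one draw suffices to illustrate.
  sampler$run()$sigma
}

cl <- makeCluster(2L)
sigma_draws <- parLapply(cl, c(101, 102), run_chain,
                         data_train = data_train, data_test = data_test)
stopCluster(cl)
```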