# Nothing
## ----setup, include=FALSE-----------------------------------------------------
# Set the knitr chunk option `echo = TRUE` via opts_chunk$set().
knitr::opts_chunk$set(echo = TRUE)
## ----generate_data------------------------------------------------------------
# True regression curves for the two mixture components.
f0 <- function(x) {
  90 + exp(0.06 * x)
}
f1 <- function(x) {
  72 + 3 * sqrt(x)
}

set.seed(2793)

# Latent group membership: n Bernoulli(p.0) draws.
n <- 120
p.0 <- 0.5
z.0 <- rbinom(n, size = 1, prob = p.0)
n1 <- sum(z.0)
n0 <- n - n1

# Row positions of each group; reused for both x and y below.
idx0 <- which(z.0 == 0)
idx1 <- which(z.0 == 1)

# The covariate distribution depends on the group (mean 20 vs. mean 40), so x
# is associated with z as well as with y. Group 0 is drawn first, then group 1,
# which fixes the order in which random numbers are consumed.
x <- numeric(n)
x[idx0] <- rnorm(n0, mean = 20, sd = 10)
x[idx1] <- rnorm(n1, mean = 40, sd = 10)

# Noise-free response from the group-specific curve, then unit-variance
# Gaussian noise added to every observation.
y <- numeric(n)
y[idx0] <- f0(x[idx0])
y[idx1] <- f1(x[idx1])
y <- y + rnorm(n, mean = 0, sd = 1)

# Training data carries the true labels; the test data uses the same x with
# every label flipped.
data_train <- data.frame(y = y, x = x, z = z.0)
data_test <- data.frame(x = x, z = 1 - z.0)
## ----gibbs_sampler------------------------------------------------------------
# Gibbs sampler for a two-component mixture. Each iteration alternates between
#   1. one BART draw of the regression function and residual sd given z,
#   2. a draw of the group indicators z given the fitted values and p,
#   3. a draw of the mixing probability p = P(z = 1) given z.
# The first n_warmup iterations are discarded; the next n_samples are stored.
n_warmup <- 100
n_samples <- 500
n_total <- n_warmup + n_samples

# Allocate storage for result: one row per retained iteration.
samples_p <- rep.int(NA_real_, n_samples)
samples_z <- matrix(NA_real_, n_samples, n)
samples_mu0 <- matrix(NA_real_, n_samples, n)
samples_mu1 <- matrix(NA_real_, n_samples, n)

library(dbarts, quietly = TRUE)

# We only need to draw one sample at a time, although for illustrative purposes
# a small degree of thinning is done to the BART component.
control <- dbartsControl(updateState = FALSE, verbose = FALSE,
                         n.burn = 0L, n.samples = 1L, n.thin = 3L,
                         n.chains = 1L)

# We create the sampler with a z vector that contains at least one 1 and one 0,
# so that all of the cut points are set correctly. The test data holds the
# same x with z flipped, so every run yields fits under both labels.
sampler <- dbarts(y ~ x + z, data_train, data_test, control = control)

# Sample from prior: p ~ Beta(1, 1), z_i ~ Bernoulli(p).
p <- rbeta(1, 1, 1)
z <- rbinom(n, 1, p)

# Install the initial z in column 2 of the predictors (x is column 1).
# Ignore result of this sampler call.
invisible(sampler$setPredictor(x = z, column = 2, forceUpdate = TRUE))
sampler$setTestPredictor(x = 1 - z, column = 2)

sampler$sampleTreesFromPrior()

for (i in seq_len(n_total)) {
  # Draw a single sample from the posterior of f and sigma^2.
  samples <- sampler$run()

  # Recover f(x, 1) and f(x, 0). The training fit is evaluated at the current
  # z and the test fit at 1 - z, so pick whichever matches the wanted label.
  mu0 <- ifelse(z == 0, samples$train[, 1], samples$test[, 1])
  mu1 <- ifelse(z == 1, samples$train[, 1], samples$test[, 1])

  # Unnormalised posterior weight of each label, then P(z_i = 1 | rest).
  p0 <- dnorm(y, mu0, samples$sigma[1]) * (1 - p)
  p1 <- dnorm(y, mu1, samples$sigma[1]) * p
  p.z <- p1 / (p0 + p1)
  z <- rbinom(n, 1, p.z)

  if (i <= n_warmup) {
    # During warm-up the new z is always installed (forceUpdate = TRUE).
    sampler$setPredictor(x = z, column = 2, forceUpdate = TRUE)
  } else {
    # After warm-up, redraw z until the sampler accepts it. setPredictor()
    # returning FALSE is treated as a rejection; presumably that means the
    # proposed z is incompatible with the current trees -- confirm against the
    # dbarts documentation.
    while (!sampler$setPredictor(x = z, column = 2)) {
      z <- rbinom(n, 1, p.z)
    }
  }
  sampler$setTestPredictor(x = 1 - z, column = 2)

  # Conjugate update for p = P(z = 1) under the Beta(1, 1) prior: the first
  # shape counts the ones and the second the zeros. (The shapes were previously
  # swapped, which drew p from the posterior of P(z = 0).)
  n1 <- sum(z); n0 <- n - n1
  p <- rbeta(1, 1 + n1, 1 + n0)

  # Store samples if no longer warming up. mu0 and mu1 are the fits computed
  # at the top of this iteration, before z was refreshed.
  if (i > n_warmup) {
    offset <- i - n_warmup
    samples_p[offset] <- p
    samples_z[offset, ] <- z
    samples_mu0[offset, ] <- mu0
    samples_mu1[offset, ] <- mu1
  }
}
## ----sampler_state------------------------------------------------------------
# Explicitly ask the sampler to store its state. The control object was built
# with updateState = FALSE, so presumably the stored state is not refreshed
# automatically after each run -- confirm against the dbarts documentation.
sampler$storeState()
## ----result_visualization-----------------------------------------------------
# Column-wise (per-observation) summary of a draws-by-observations matrix.
col_summary <- function(draws, fn, ...) {
  apply(draws, 2, fn, ...)
}

# Posterior mean and point-wise 2.5% / 97.5% quantiles of each fitted curve.
mean_mu0 <- col_summary(samples_mu0, mean)
mean_mu1 <- col_summary(samples_mu1, mean)
ub_mu0 <- col_summary(samples_mu0, quantile, probs = 0.975)
lb_mu0 <- col_summary(samples_mu0, quantile, probs = 0.025)
ub_mu1 <- col_summary(samples_mu1, quantile, probs = 0.975)
lb_mu1 <- col_summary(samples_mu1, quantile, probs = 0.025)

# Everything drawn as a line is plotted against x in increasing order.
ord <- order(x)
x_sorted <- x[ord]
group_cols <- c("black", "gray")

# True curves and the observations, coloured by the generating group.
curve(f0(x), from = 0, to = 80, ylim = c(80, 110), ylab = expression(f(x)),
      main = "Mixture Model")
curve(f1(x), add = TRUE)
points(x, y, pch = 20, col = group_cols[z.0 + 1])

# Estimated curves (posterior means).
for (fit in list(mean_mu0, mean_mu1)) {
  lines(x_sorted, fit[ord], col = "red")
}

# Point-wise interval bounds: dashed for component 0, dotted for component 1.
for (bound in list(ub_mu0, lb_mu0)) {
  lines(x_sorted, bound[ord], col = "gray", lty = 2)
}
for (bound in list(ub_mu1, lb_mu1)) {
  lines(x_sorted, bound[ord], col = "gray", lty = 3)
}

# The labels 0/1 are not identified by the model, so the estimated labelling
# may be the mirror image of the generating one. Classify each observation by
# majority vote over the draws and flip all labels if more than half disagree
# with the truth.
mean_z <- as.integer(col_summary(samples_z, mean) > 0.5)
if (mean(mean_z != z.0) > 0.5) {
  mean_z <- 1 - mean_z
}
points(x, y, pch = 1, col = group_cols[mean_z + 1])

legend("topright",
       legend = c("true func", "est func", "group 0", "group 1",
                  "est group 0", "est group 1"),
       lty = c(1, 1, NA, NA, NA, NA),
       pch = c(NA, NA, 20, 20, 1, 1),
       col = c("black", "red", "black", "gray", "black", "gray"),
       cex = 0.8, box.col = "white", bg = "white")
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.