Nothing
test_that("BART Serialization", {
  skip_on_cran()

  # Generate simulated data: a piecewise-constant function of the first
  # covariate plus Gaussian noise
  n <- 100
  p <- 5
  X <- matrix(runif(n * p), ncol = p)
  f_XW <- (
    ((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) +
      ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
      ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
      ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)
  )
  noise_sd <- 1
  y <- f_XW + rnorm(n, 0, noise_sd)

  # Hold out 20% of the observations as a test set; the training set is the
  # complement of the (sorted) test indices
  test_set_pct <- 0.2
  n_test <- round(test_set_pct * n)
  test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
  train_inds <- setdiff(seq_len(n), test_inds)
  X_test <- X[test_inds, ]
  X_train <- X[train_inds, ]
  y_train <- y[train_inds]

  # Sample a BART model
  general_param_list <- list(num_chains = 1, keep_every = 1)
  bart_model <- bart(
    X_train = X_train, y_train = y_train,
    num_gfr = 0, num_burnin = 10, num_mcmc = 10,
    general_params = general_param_list
  )
  y_hat_orig <- predict(bart_model, X_test)$y_hat

  # Round-trip the model through a JSON string and predict again
  bart_json_string <- saveBARTModelToJsonString(bart_model)
  bart_model_roundtrip <- createBARTModelFromJsonString(bart_json_string)
  y_hat_reloaded <- predict(bart_model_roundtrip, X_test)$y_hat

  # Compare every posterior draw, not just the posterior mean: a round trip
  # that dropped or reordered draws could still reproduce the row means.
  expect_equal(y_hat_reloaded, y_hat_orig)
})
test_that("BCF Serialization", {
  skip_on_cran()

  # Generate simulated data with a binary treatment whose probability
  # (pi_x) depends on x1; the prognostic term depends on x1 and the
  # treatment effect on x2
  n <- 500
  x1 <- runif(n)
  x2 <- runif(n)
  x3 <- runif(n)
  x4 <- runif(n)
  x5 <- runif(n)
  X <- cbind(x1, x2, x3, x4, x5)
  pi_x <- 0.25 + 0.5 * X[, 1]
  mu_x <- pi_x * 5
  tau_x <- X[, 2] * 2
  Z <- rbinom(n, 1, pi_x)
  E_XZ <- mu_x + Z * tau_x
  y <- E_XZ + rnorm(n, 0, 1)

  # Hold out 20% of the observations as a test set; the training set is the
  # complement of the (sorted) test indices
  test_set_pct <- 0.2
  n_test <- round(test_set_pct * n)
  test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
  train_inds <- setdiff(seq_len(n), test_inds)
  X_test <- X[test_inds, ]
  X_train <- X[train_inds, ]
  pi_test <- pi_x[test_inds]
  pi_train <- pi_x[train_inds]
  Z_test <- Z[test_inds]
  Z_train <- Z[train_inds]
  y_train <- y[train_inds]

  # Sample a BCF model, supplying the true propensity score
  bcf_model <- bcf(
    X_train = X_train, Z_train = Z_train, y_train = y_train,
    propensity_train = pi_train, num_gfr = 100, num_burnin = 0,
    num_mcmc = 100
  )
  bcf_preds_orig <- predict(bcf_model, X_test, Z_test, pi_test)

  # Round-trip the model through a JSON string and predict again
  bcf_json_string <- saveBCFModelToJsonString(bcf_model)
  bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string)
  bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test, pi_test)

  # Each prediction component must survive the round trip, draw by draw.
  # Checking only y_hat could miss a bug in one forest that happens to be
  # offset elsewhere.
  for (term in c("mu_hat", "tau_hat", "y_hat")) {
    expect_equal(
      bcf_preds_reloaded[[term]],
      bcf_preds_orig[[term]],
      label = term
    )
  }
})
test_that("BCF Serialization (no propensity)", {
  skip_on_cran()

  # Generate simulated data with a binary treatment whose probability
  # (pi_x) depends on x1; the prognostic term depends on x1 and the
  # treatment effect on x2. pi_x is only used to draw Z here -- it is not
  # passed to the model.
  n <- 500
  x1 <- runif(n)
  x2 <- runif(n)
  x3 <- runif(n)
  x4 <- runif(n)
  x5 <- runif(n)
  X <- cbind(x1, x2, x3, x4, x5)
  pi_x <- 0.25 + 0.5 * X[, 1]
  mu_x <- pi_x * 5
  tau_x <- X[, 2] * 2
  Z <- rbinom(n, 1, pi_x)
  E_XZ <- mu_x + Z * tau_x
  y <- E_XZ + rnorm(n, 0, 1)

  # Hold out 20% of the observations as a test set; the training set is the
  # complement of the (sorted) test indices
  test_set_pct <- 0.2
  n_test <- round(test_set_pct * n)
  test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
  train_inds <- setdiff(seq_len(n), test_inds)
  X_test <- X[test_inds, ]
  X_train <- X[train_inds, ]
  Z_test <- Z[test_inds]
  Z_train <- Z[train_inds]
  y_train <- y[train_inds]

  # Sample a BCF model without supplying a propensity score
  bcf_model <- bcf(
    X_train = X_train, Z_train = Z_train, y_train = y_train,
    num_gfr = 100, num_burnin = 0, num_mcmc = 100
  )
  bcf_preds_orig <- predict(bcf_model, X_test, Z_test)

  # Round-trip the model through a JSON string and predict again
  bcf_json_string <- saveBCFModelToJsonString(bcf_model)
  bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string)
  bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test)

  # Each prediction component must survive the round trip, draw by draw.
  # Checking only y_hat could miss a bug in one forest that happens to be
  # offset elsewhere.
  for (term in c("mu_hat", "tau_hat", "y_hat")) {
    expect_equal(
      bcf_preds_reloaded[[term]],
      bcf_preds_orig[[term]],
      label = term
    )
  }
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.