# tests/testthat/test-serialization.R

# Round-trip test: a BART model saved to a JSON string and reloaded must
# produce the same test-set predictions as the original fitted model.
test_that("BART Serialization", {
    skip_on_cran()
    
    # Simulate a step-function mean in X[, 1] plus Gaussian noise
    n <- 100
    p <- 5
    X <- matrix(runif(n * p), ncol = p)
    f_XW <- (
        ((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + 
        ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + 
        ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + 
        ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)
    )
    noise_sd <- 1
    y <- f_XW + rnorm(n, 0, noise_sd)
    
    # Hold out 20% of the rows as a test set
    test_set_pct <- 0.2
    n_test <- round(test_set_pct * n)
    test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
    train_inds <- setdiff(seq_len(n), test_inds)
    # drop = FALSE keeps a matrix even if a split ends up with a single row
    X_test <- X[test_inds, , drop = FALSE]
    X_train <- X[train_inds, , drop = FALSE]
    y_train <- y[train_inds]
    
    # Sample a BART model
    general_param_list <- list(num_chains = 1, keep_every = 1)
    bart_model <- bart(X_train = X_train, y_train = y_train, 
                       num_gfr = 0, num_burnin = 10, num_mcmc = 10, 
                       general_params = general_param_list)
    y_hat_orig <- predict(bart_model, X_test)$y_hat
    
    # Save to a JSON string and reload as a BART model
    bart_json_string <- saveBARTModelToJsonString(bart_model)
    bart_model_roundtrip <- createBARTModelFromJsonString(bart_json_string)
    
    # Predict from the round-tripped BART model
    y_hat_reloaded <- predict(bart_model_roundtrip, X_test)$y_hat
    
    # Compare every retained posterior draw, not only the per-row posterior
    # mean: averaging over draws could mask reordered or corrupted samples.
    expect_equal(y_hat_reloaded, y_hat_orig)
})

# Round-trip test: a BCF model fit with user-supplied propensity scores,
# saved to a JSON string and reloaded, must reproduce the original model's
# prognostic (mu_hat), treatment-effect (tau_hat) and outcome (y_hat)
# predictions on the test set.
test_that("BCF Serialization", {
    skip_on_cran()
    
    # Simulate a binary treatment Z with known propensity pi_x, prognostic
    # function mu_x and treatment effect tau_x
    n <- 500
    x1 <- runif(n)
    x2 <- runif(n)
    x3 <- runif(n)
    x4 <- runif(n)
    x5 <- runif(n)
    X <- cbind(x1, x2, x3, x4, x5)
    pi_x <- 0.25 + 0.5 * X[, 1]
    mu_x <- pi_x * 5
    tau_x <- X[, 2] * 2
    Z <- rbinom(n, 1, pi_x)
    E_XZ <- mu_x + Z * tau_x
    y <- E_XZ + rnorm(n, 0, 1)
    
    # Hold out 20% of the rows as a test set
    test_set_pct <- 0.2
    n_test <- round(test_set_pct * n)
    test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
    train_inds <- setdiff(seq_len(n), test_inds)
    # drop = FALSE keeps a matrix even if a split ends up with a single row
    X_test <- X[test_inds, , drop = FALSE]
    X_train <- X[train_inds, , drop = FALSE]
    pi_test <- pi_x[test_inds]
    pi_train <- pi_x[train_inds]
    Z_test <- Z[test_inds]
    Z_train <- Z[train_inds]
    y_train <- y[train_inds]
    
    # Sample a BCF model, supplying the true propensity scores
    bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, 
                     propensity_train = pi_train, num_gfr = 100, num_burnin = 0, num_mcmc = 100)
    bcf_preds_orig <- predict(bcf_model, X_test, Z_test, pi_test)
    
    # Save to a JSON string and reload as a BCF model
    bcf_json_string <- saveBCFModelToJsonString(bcf_model)
    bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string)
    
    # Predict from the round-tripped BCF model
    bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test, pi_test)
    
    # Check each prediction term, draw by draw. Previously mu_hat and tau_hat
    # were computed but never asserted, so only y_hat was actually tested.
    expect_equal(bcf_preds_reloaded[["mu_hat"]], bcf_preds_orig[["mu_hat"]])
    expect_equal(bcf_preds_reloaded[["tau_hat"]], bcf_preds_orig[["tau_hat"]])
    expect_equal(bcf_preds_reloaded[["y_hat"]], bcf_preds_orig[["y_hat"]])
})

# Round-trip test: a BCF model fit WITHOUT user-supplied propensity scores,
# saved to a JSON string and reloaded, must reproduce the original model's
# prognostic (mu_hat), treatment-effect (tau_hat) and outcome (y_hat)
# predictions on the test set.
test_that("BCF Serialization (no propensity)", {
    skip_on_cran()
    
    # Simulate a binary treatment Z, prognostic function mu_x and treatment
    # effect tau_x. The true propensity pi_x is used only to draw Z; it is
    # deliberately never passed to bcf() or predict().
    n <- 500
    x1 <- runif(n)
    x2 <- runif(n)
    x3 <- runif(n)
    x4 <- runif(n)
    x5 <- runif(n)
    X <- cbind(x1, x2, x3, x4, x5)
    pi_x <- 0.25 + 0.5 * X[, 1]
    mu_x <- pi_x * 5
    tau_x <- X[, 2] * 2
    Z <- rbinom(n, 1, pi_x)
    E_XZ <- mu_x + Z * tau_x
    y <- E_XZ + rnorm(n, 0, 1)
    
    # Hold out 20% of the rows as a test set
    test_set_pct <- 0.2
    n_test <- round(test_set_pct * n)
    test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
    train_inds <- setdiff(seq_len(n), test_inds)
    # drop = FALSE keeps a matrix even if a split ends up with a single row
    X_test <- X[test_inds, , drop = FALSE]
    X_train <- X[train_inds, , drop = FALSE]
    Z_test <- Z[test_inds]
    Z_train <- Z[train_inds]
    y_train <- y[train_inds]
    
    # Sample a BCF model without propensity_train, so the round trip covers
    # whatever state bcf() keeps when no propensity scores are supplied
    bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, 
                     num_gfr = 100, num_burnin = 0, num_mcmc = 100)
    bcf_preds_orig <- predict(bcf_model, X_test, Z_test)
    
    # Save to a JSON string and reload as a BCF model
    bcf_json_string <- saveBCFModelToJsonString(bcf_model)
    bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string)
    
    # Predict from the round-tripped BCF model
    bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test)
    
    # Check each prediction term, draw by draw. Previously mu_hat and tau_hat
    # were computed but never asserted, so only y_hat was actually tested.
    expect_equal(bcf_preds_reloaded[["mu_hat"]], bcf_preds_orig[["mu_hat"]])
    expect_equal(bcf_preds_reloaded[["tau_hat"]], bcf_preds_orig[["tau_hat"]])
    expect_equal(bcf_preds_reloaded[["y_hat"]], bcf_preds_orig[["y_hat"]])
})

# Try the stochtree package in your browser

# Any scripts or data that you put into this service are public.

# stochtree documentation built on April 4, 2025, 2:11 a.m.