# tests/testthat/test-11-callback.R

context("stan4bart callback argument")

# Pull the package's shared test-data generator (common/friedmanData.R)
# into this file's scope, draw one data set, then remove the generator so
# it does not linger in the test environment.
# NOTE(review): the first argument (100) appears to be the number of rows,
# given the 80/20 split below; the logical flags are not documented here.
source(system.file("common", "friedmanData.R", package = "stan4bart"), local = TRUE)
testData <- generateFriedmanData(100, TRUE, TRUE, FALSE)
rm(generateFriedmanData)

# The grouping variables become factors because the model formulas below
# use them as random-effect grouping terms, e.g. (1 + X4 | g.1).
df <- with(
  testData,
  data.frame(x, g.1 = as.factor(g.1), g.2 = as.factor(g.2), y, z)
)

# Rows 1-80 train the model; rows 81-100 are the held-out test sample.
df.train <- df[seq_len(80L), ]
df.test  <- df[81L:100L, ]

# Callback handed to stan4bart() that rebuilds the test-sample predictions
# for one posterior draw. It returns them as a single named vector laid
# out as [BART fit | fixed-effect fit | random-effect fit].
#
# Arguments (supplied by stan4bart):
#   yhat.train - BART fit for the training sample (not used here).
#   yhat.test  - BART fit for the test sample.
#   stan_pars  - named numeric vector holding the current Stan draw.
#
# The calling frame (parent.frame()) is read for X_means and
# test$X / test$reTrms$Zt. The same frame is also used to cache the
# positions of the fixed- and random-effect parameters, so they are
# computed only on the first call.
callback <- function(yhat.train, yhat.test, stan_pars) {
  rho <- parent.frame()
  if (is.null(rho$fixef.ind)) {
    # The parentheses make "^" anchor both alternatives. The ungrouped
    # "^beta|gamma" parses as "(^beta)|(gamma)" and would also pick up
    # any parameter whose name merely contains "gamma".
    rho$fixef.ind <- which(grepl("^(beta|gamma)", names(stan_pars)))
    rho$ranef.ind <- which(startsWith(names(stan_pars), "b."))
  }
  fixef <- stan_pars[rho$fixef.ind]
  ranef <- stan_pars[rho$ranef.ind]

  # Subtracting sum(beta * mean) from X %*% beta is equivalent to using
  # predictors centred at X_means.
  # NOTE(review): presumably the Stan model is fit on centred predictors;
  # this assumes fixef is ordered like names(X_means) -- confirm.
  x_means <- rho$X_means
  keep_cols <- names(x_means) != "(Intercept)"
  intercept_delta <- sum(fixef[keep_cols] * x_means[keep_cols])

  # training X is rho$X, training Zt is rho$reTrms$Zt
  fit.fixef <- as.vector(rho$test$X %*% fixef) - intercept_delta
  fit.ranef <- as.vector(Matrix::crossprod(rho$test$reTrms$Zt, ranef))

  result <- c(yhat.test, fit.fixef, fit.ranef)
  names(result) <- c(
    paste0("yhat.test_", seq_along(yhat.test)),
    paste0("fit.fixef_", seq_along(fit.fixef)),
    paste0("fit.ranef_", seq_along(fit.ranef))
  )
  result
}
# Give the callback an environment whose only parent is base. It therefore
# cannot silently depend on objects defined in this test file; anything
# else it needs must come from its arguments, its caller's frame, or an
# explicit pkg:: reference.
fn_env <- new.env(parent = baseenv())
environment(callback) <- fn_env

test_that("callback is passed sample correctly", {
  # The callback stacks [BART | fixed effects | random effects] for the
  # test sample, so each component fills a consecutive run of n.test rows.
  n.test <- nrow(df.test)
  rows.bart  <- seq_len(n.test)
  rows.fixef <- rows.bart + n.test
  rows.ranef <- rows.bart + 2L * n.test

  fit <- stan4bart(y ~ bart(. - g.1 - g.2 - X4 - z) + X4 + z + (1 + X4 | g.1) + (1 | g.2),
                   data = df.train,
                   test = df.test,
                   cores = 1, verbose = -1L, chains = 2, warmup = 7, iter = 13,
                   bart_args = list(n.trees = 11),
                   seed = 0,
                   callback = callback)
  expect_s3_class(fit, "stan4bartFit")

  indiv.bart  <- fit$callback[rows.bart, , ]
  indiv.fixef <- fit$callback[rows.fixef, , ]
  indiv.ranef <- fit$callback[rows.ranef, , ]

  # Each component must match what extract() reports for the test sample.
  expect_equal(unname(indiv.bart),
               unname(extract(fit, type = "indiv.bart",  sample = "test", combine_chains = FALSE)))
  expect_equal(unname(indiv.fixef),
               unname(extract(fit, type = "indiv.fixef", sample = "test", combine_chains = FALSE)))
  expect_equal(unname(indiv.ranef),
               unname(extract(fit, type = "indiv.ranef", sample = "test", combine_chains = FALSE)))

  ev <- indiv.fixef + indiv.ranef + indiv.bart
  expect_equal(unname(ev),
               unname(extract(fit, "ev", sample = "test", combine_chains = FALSE)))

  # The callback's element names become the first dimnames of the result.
  expect_equal(
    dimnames(fit$callback)[[1L]],
    c(paste0("yhat.test_", rows.bart),
      paste0("fit.fixef_", rows.bart),
      paste0("fit.ranef_", rows.bart))
  )

  # Refit with the same seed and keep_fits = FALSE. The callback output must
  # be unchanged, while the stored BART/Stan samples are dropped.
  fit <- stan4bart(y ~ bart(. - g.1 - g.2 - X4 - z) + X4 + z + (1 + X4 | g.1) + (1 | g.2),
                   data = df.train,
                   test = df.test,
                   cores = 1, verbose = -1L, chains = 2, warmup = 7, iter = 13,
                   bart_args = list(n.trees = 11),
                   seed = 0, keep_fits = FALSE,
                   callback = callback)
  expect_s3_class(fit, "stan4bartFit")

  indiv.bart  <- fit$callback[rows.bart, , ]
  indiv.fixef <- fit$callback[rows.fixef, , ]
  indiv.ranef <- fit$callback[rows.ranef, , ]

  expect_equal(unname(ev), unname(indiv.fixef + indiv.ranef + indiv.bart))

  expect_null(fit$bart_train)
  expect_null(fit$bart_test)
  expect_null(fit$bart_varcount)
  expect_null(fit$stan)
  expect_null(fit$par_names)
  expect_null(fit$warmup$bart_train)
  expect_null(fit$warmup$bart_test)
  expect_null(fit$warmup$bart_varcount)
  expect_null(fit$warmup$stan)
})


test_that("callback works with multiple threads", {
  # Same layout as the single-core test: [BART | fixef | ranef] blocks of
  # n.test rows each.
  n.test <- nrow(df.test)
  rows.bart  <- seq_len(n.test)
  rows.fixef <- rows.bart + n.test
  rows.ranef <- rows.bart + 2L * n.test

  fit <- stan4bart(y ~ bart(. - g.1 - g.2 - X4 - z) + X4 + z + (1 + X4 | g.1) + (1 | g.2),
                   data = df.train,
                   test = df.test,
                   cores = 2, verbose = -1L, chains = 2, warmup = 7, iter = 13,
                   bart_args = list(n.trees = 11),
                   seed = 0,
                   callback = callback)
  expect_s3_class(fit, "stan4bartFit")

  indiv.bart  <- fit$callback[rows.bart, , ]
  indiv.fixef <- fit$callback[rows.fixef, , ]
  indiv.ranef <- fit$callback[rows.ranef, , ]

  expect_equal(unname(indiv.bart),
               unname(extract(fit, type = "indiv.bart",  sample = "test", combine_chains = FALSE)))
  expect_equal(unname(indiv.fixef),
               unname(extract(fit, type = "indiv.fixef", sample = "test", combine_chains = FALSE)))
  expect_equal(unname(indiv.ranef),
               unname(extract(fit, type = "indiv.ranef", sample = "test", combine_chains = FALSE)))

  ev <- indiv.fixef + indiv.ranef + indiv.bart
  expect_equal(unname(ev),
               unname(extract(fit, "ev", sample = "test", combine_chains = FALSE)))
})

# Matrix-returning variant of the callback. For one posterior draw it
# returns an (n.test x 3) matrix whose columns are the BART fit, the
# fixed-effect fit and the random-effect fit for the test sample. The
# dimnames are named "indiv" and "value".
#
# Arguments (supplied by stan4bart):
#   yhat.train - BART fit for the training sample (not used here).
#   yhat.test  - BART fit for the test sample.
#   stan_pars  - named numeric vector holding the current Stan draw.
#
# The calling frame (parent.frame()) supplies X_means and
# test$X / test$reTrms$Zt, and caches the parameter positions across calls.
callback <- function(yhat.train, yhat.test, stan_pars) {
  rho <- parent.frame()
  if (is.null(rho$fixef.ind)) {
    # The parentheses make "^" anchor both alternatives. The ungrouped
    # "^beta|gamma" would match any name merely containing "gamma".
    rho$fixef.ind <- which(grepl("^(beta|gamma)", names(stan_pars)))
    rho$ranef.ind <- which(startsWith(names(stan_pars), "b."))
  }
  fixef <- stan_pars[rho$fixef.ind]
  ranef <- stan_pars[rho$ranef.ind]

  # Subtracting sum(beta * mean) from X %*% beta is equivalent to using
  # predictors centred at X_means.
  # NOTE(review): assumes fixef is ordered like names(X_means) -- confirm.
  x_means <- rho$X_means
  keep_cols <- names(x_means) != "(Intercept)"
  intercept_delta <- sum(fixef[keep_cols] * x_means[keep_cols])

  # training X is rho$X, training Zt is rho$reTrms$Zt
  fit.fixef <- as.vector(rho$test$X %*% fixef) - intercept_delta
  fit.ranef <- as.vector(Matrix::crossprod(rho$test$reTrms$Zt, ranef))

  # cbind() takes its column names from the argument symbols:
  # "yhat.test", "fit.fixef", "fit.ranef".
  result <- cbind(yhat.test, fit.fixef, fit.ranef)
  dimnames(result) <- list(indiv = NULL, value = colnames(result))
  result
}
environment(callback) <- fn_env



test_that("callback works with multiple dimmed results", {
  fit <- stan4bart(y ~ bart(. - g.1 - g.2 - X4 - z) + X4 + z + (1 + X4 | g.1) + (1 | g.2),
                   data = df.train,
                   test = df.test,
                   cores = 1, verbose = -1L, chains = 2, warmup = 7, iter = 13,
                   bart_args = list(n.trees = 11),
                   seed = 0,
                   callback = callback)
  expect_s3_class(fit, "stan4bartFit")

  # The callback's own (indiv x value) dimensions are followed by one slot
  # per post-warmup iteration and one per chain.
  expect_equal(dim(fit$callback), c(nrow(df.test), 3L, 13L - 7L, 2L))

  expect_null(dimnames(fit$callback)[[1L]])
  expect_equal(dimnames(fit$callback)[[2L]], c("yhat.test", "fit.fixef", "fit.ranef"))
  expect_null(dimnames(fit$callback)[[3L]])
  expect_equal(dimnames(fit$callback)[[4L]], paste0("chain:", seq_len(2L)))

  expect_equal(names(dimnames(fit$callback)), c("indiv", "value", "iterations", "chain"))
})

# Try the stan4bart package in your browser
#
# Any scripts or data that you put into this service are public.
#
# stan4bart documentation built on Sept. 12, 2024, 7:39 a.m.