Nothing
context("stan4bart callback argument")

# Pull in the Friedman data generator that ships with the package, build one
# data set from it, and drop the generator again so that it does not linger in
# the test environment.
source(system.file("common", "friedmanData.R", package = "stan4bart"), local = TRUE)
testData <- generateFriedmanData(100, TRUE, TRUE, FALSE)
rm(generateFriedmanData)

# Assemble a single data frame with the grouping variables stored as factors.
df <- with(
  testData,
  data.frame(x, g.1 = as.factor(g.1), g.2 = as.factor(g.2), y, z)
)

# Rows 1-80 are used for fitting, rows 81-100 are held out as the test sample.
df.train <- df[seq_len(80L), ]
df.test <- df[80L + seq_len(20L), ]
# Per-sample callback passed to stan4bart(): computes the fixed-effect and
# random-effect contributions to the test-sample fit and returns them together
# with the supplied `yhat.test` values.
#
# Arguments:
#   yhat.train  accepted for interface compatibility; not used here.
#   yhat.test   numeric vector of test-sample values supplied by the sampler.
#   stan_pars   named numeric vector of parameter values for the current draw.
#
# Value:
#   A named numeric vector c(yhat.test, fit.fixef, fit.ranef) whose names are
#   "yhat.test_<i>", "fit.fixef_<i>" and "fit.ranef_<i>".
#
# The function reads `X_means`, `test$X` and `test$reTrms$Zt` from its calling
# frame, and caches the parameter index vectors `fixef.ind` / `ranef.ind` there
# on first use. Only base functions (or `pkg::` qualified calls) are used so
# that the function also works in an environment whose parent is baseenv().
callback <- function(yhat.train, yhat.test, stan_pars) {
  rho <- parent.frame()

  # Locate the fixed- and random-effect parameters once and remember the
  # positions in the caller's frame for subsequent samples.
  if (is.null(rho$fixef.ind)) {
    par_names <- names(stan_pars)
    # NOTE(review): this pattern matches names that start with "beta" OR that
    # contain "gamma" anywhere; confirm that "^(beta|gamma)" was not intended.
    rho$fixef.ind <- grep("^beta|gamma", par_names)
    rho$ranef.ind <- which(startsWith(par_names, "b."))
  }

  fixef <- stan_pars[rho$fixef.ind]
  ranef <- stan_pars[rho$ranef.ind]

  # Offset removed from the linear predictor: the non-intercept coefficients
  # multiplied by the corresponding entries of X_means.
  x_means <- rho$X_means
  is_slope <- names(x_means) != "(Intercept)"
  intercept_delta <- sum(fixef[is_slope] * x_means[is_slope])

  # training X is rho$X, training Zt is rho$reTrms$Zt
  fit.fixef <- as.vector(rho$test$X %*% fixef) - intercept_delta
  fit.ranef <- as.vector(Matrix::crossprod(rho$test$reTrms$Zt, ranef))

  result <- c(yhat.test, fit.fixef, fit.ranef)
  names(result) <- c(
    paste0("yhat.test_", seq_along(yhat.test)),
    paste0("fit.fixef_", seq_along(fit.fixef)),
    paste0("fit.ranef_", seq_along(fit.ranef))
  )
  result
}
# Give the callback a fresh enclosing environment whose parent is the base
# environment. Inside the callback, free names therefore resolve only to
# objects in `fn_env` or the base package; anything else must be reached
# explicitly (e.g. via `pkg::name` or the calling frame).
fn_env <- new.env(parent = baseenv())
environment(callback) <- fn_env
# Verifies that the values returned by the callback and stored in
# `fit$callback` agree with what extract() reports for the test sample, and
# that the callback results are unchanged when the fit is created with
# keep_fits = FALSE (in which case the stored fit components must be NULL).
test_that("callback is passed sample correctly", {
  # The callback returns three consecutive segments, each nrow(df.test) long:
  # yhat.test, fit.fixef and fit.ranef.
  n_test <- nrow(df.test)
  bart_rows <- seq_len(n_test)
  fixef_rows <- bart_rows + n_test
  ranef_rows <- bart_rows + 2L * n_test

  fit <- stan4bart(y ~ bart(. - g.1 - g.2 - X4 - z) + X4 + z + (1 + X4 | g.1) + (1 | g.2),
                   data = df.train,
                   test = df.test,
                   cores = 1, verbose = -1L, chains = 2, warmup = 7, iter = 13,
                   bart_args = list(n.trees = 11),
                   seed = 0,
                   callback = callback)
  expect_is(fit, "stan4bartFit")

  indiv.bart <- fit$callback[bart_rows, , ]
  indiv.fixef <- fit$callback[fixef_rows, , ]
  indiv.ranef <- fit$callback[ranef_rows, , ]

  # Each segment must match the corresponding extracted component.
  expect_equal(unname(indiv.bart),
               unname(extract(fit, type = "indiv.bart", sample = "test", combine_chains = FALSE)))
  expect_equal(unname(indiv.fixef),
               unname(extract(fit, type = "indiv.fixef", sample = "test", combine_chains = FALSE)))
  expect_equal(unname(indiv.ranef),
               unname(extract(fit, type = "indiv.ranef", sample = "test", combine_chains = FALSE)))

  # The sum of the three components is the expected value.
  ev <- indiv.fixef + indiv.ranef + indiv.bart
  expect_equal(unname(ev),
               unname(extract(fit, "ev", sample = "test", combine_chains = FALSE)))

  # Names given by the callback are carried over to the first dimension.
  expect_equal(
    dimnames(fit$callback)[[1L]],
    c(paste0("yhat.test_", seq_len(n_test)),
      paste0("fit.fixef_", seq_len(n_test)),
      paste0("fit.ranef_", seq_len(n_test)))
  )

  # Same model and seed, but without keeping the fits: the callback output
  # must reproduce the expected values computed above.
  fit <- stan4bart(y ~ bart(. - g.1 - g.2 - X4 - z) + X4 + z + (1 + X4 | g.1) + (1 | g.2),
                   data = df.train,
                   test = df.test,
                   cores = 1, verbose = -1L, chains = 2, warmup = 7, iter = 13,
                   bart_args = list(n.trees = 11),
                   seed = 0, keep_fits = FALSE,
                   callback = callback)
  expect_is(fit, "stan4bartFit")

  indiv.bart <- fit$callback[bart_rows, , ]
  indiv.fixef <- fit$callback[fixef_rows, , ]
  indiv.ranef <- fit$callback[ranef_rows, , ]

  expect_equal(unname(ev), unname(indiv.fixef + indiv.ranef + indiv.bart))

  # With keep_fits = FALSE none of the sampler output is stored.
  expect_null(fit$bart_train)
  expect_null(fit$bart_test)
  expect_null(fit$bart_varcount)
  expect_null(fit$stan)
  expect_null(fit$par_names)

  expect_null(fit$warmup$bart_train)
  expect_null(fit$warmup$bart_test)
  expect_null(fit$warmup$bart_varcount)
  expect_null(fit$warmup$stan)
})
# Same consistency checks between `fit$callback` and extract() as in the
# single-core case, but with the chains run on two cores.
test_that("callback works with multiple threads", {
  fit <- stan4bart(
    y ~ bart(. - g.1 - g.2 - X4 - z) + X4 + z + (1 + X4 | g.1) + (1 | g.2),
    data = df.train,
    test = df.test,
    cores = 2, verbose = -1L, chains = 2, warmup = 7, iter = 13,
    bart_args = list(n.trees = 11),
    seed = 0,
    callback = callback
  )
  expect_is(fit, "stan4bartFit")

  # Test-sample extraction with chains kept separate, stripped of dimnames.
  extract_test <- function(...) {
    unname(extract(fit, ..., sample = "test", combine_chains = FALSE))
  }

  # The callback output holds three stacked segments of 20 rows each.
  first_rows <- seq_len(20L)
  callback_out <- fit$callback
  indiv.bart <- callback_out[first_rows, , ]
  indiv.fixef <- callback_out[first_rows + 20L, , ]
  indiv.ranef <- callback_out[first_rows + 40L, , ]

  expect_equal(unname(indiv.bart), extract_test(type = "indiv.bart"))
  expect_equal(unname(indiv.fixef), extract_test(type = "indiv.fixef"))
  expect_equal(unname(indiv.ranef), extract_test(type = "indiv.ranef"))

  ev <- indiv.fixef + indiv.ranef + indiv.bart
  expect_equal(unname(ev), extract_test("ev"))
})
# Variant of the per-sample callback that returns a matrix instead of a flat
# vector: one row per element of `yhat.test`, one column per component.
#
# Arguments:
#   yhat.train  accepted for interface compatibility; not used here.
#   yhat.test   numeric vector of test-sample values supplied by the sampler.
#   stan_pars   named numeric vector of parameter values for the current draw.
#
# Value:
#   A numeric matrix with columns "yhat.test", "fit.fixef" and "fit.ranef" and
#   dimnames list(indiv = NULL, value = <column names>).
#
# The function reads `X_means`, `test$X` and `test$reTrms$Zt` from its calling
# frame, and caches the parameter index vectors `fixef.ind` / `ranef.ind` there
# on first use. Only base functions (or `pkg::` qualified calls) are used so
# that the function also works in an environment whose parent is baseenv().
callback <- function(yhat.train, yhat.test, stan_pars) {
  rho <- parent.frame()

  # Locate the fixed- and random-effect parameters once and remember the
  # positions in the caller's frame for subsequent samples.
  if (is.null(rho$fixef.ind)) {
    par_names <- names(stan_pars)
    # NOTE(review): this pattern matches names that start with "beta" OR that
    # contain "gamma" anywhere; confirm that "^(beta|gamma)" was not intended.
    rho$fixef.ind <- grep("^beta|gamma", par_names)
    rho$ranef.ind <- which(startsWith(par_names, "b."))
  }

  fixef <- stan_pars[rho$fixef.ind]
  ranef <- stan_pars[rho$ranef.ind]

  # Offset removed from the linear predictor: the non-intercept coefficients
  # multiplied by the corresponding entries of X_means.
  x_means <- rho$X_means
  is_slope <- names(x_means) != "(Intercept)"
  intercept_delta <- sum(fixef[is_slope] * x_means[is_slope])

  # training X is rho$X, training Zt is rho$reTrms$Zt
  fit.fixef <- as.vector(rho$test$X %*% fixef) - intercept_delta
  fit.ranef <- as.vector(Matrix::crossprod(rho$test$reTrms$Zt, ranef))

  # Column names are given explicitly rather than relying on cbind() deparsing
  # the argument expressions.
  result <- cbind(
    yhat.test = yhat.test,
    fit.fixef = fit.fixef,
    fit.ranef = fit.ranef
  )
  dimnames(result) <- list(indiv = NULL, value = colnames(result))
  result
}
# The callback was just redefined, so its enclosing environment has to be set
# to `fn_env` again.
environment(callback) <- fn_env
# When the callback returns a matrix, `fit$callback` is expected to be a
# four-dimensional array whose dimensions and dimnames are checked below.
test_that("callback works with multiple dimmed results", {
  fit <- stan4bart(
    y ~ bart(. - g.1 - g.2 - X4 - z) + X4 + z + (1 + X4 | g.1) + (1 | g.2),
    data = df.train,
    test = df.test,
    cores = 1, verbose = -1L, chains = 2, warmup = 7, iter = 13,
    bart_args = list(n.trees = 11),
    seed = 0,
    callback = callback
  )
  expect_is(fit, "stan4bartFit")

  cb_dim <- dim(fit$callback)
  cb_dimnames <- dimnames(fit$callback)

  # Extents: test rows, callback columns, post-warmup iterations, chains.
  expect_equal(cb_dim[1L], nrow(df.test))
  expect_equal(cb_dim[2L], 3L)
  expect_equal(cb_dim[3L], 13L - 7L)
  expect_equal(cb_dim[4L], 2L)

  # Only the callback columns and the chains carry labels.
  expect_null(cb_dimnames[[1L]])
  expect_equal(cb_dimnames[[2]], c("yhat.test", "fit.fixef", "fit.ranef"))
  expect_null(cb_dimnames[[3]])
  expect_equal(cb_dimnames[[4]], paste0("chain:", seq_len(2L)))

  expect_equal(names(cb_dimnames), c("indiv", "value", "iterations", "chain"))
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.