tests/testthat/test-mixstanvar.R

## Test context for the mixstanvar adapter; the whole file is skipped
## when brms is not installed.
context("mixstanvar: Mixture Distribution brms Adapter")

skip_if_not_installed("brms")

## various tests around mixture distributions; fixed seed for the R
## random number generator
set.seed(234534)

## Take the brms backend from the BRMS_BACKEND environment variable,
## unless the brms.backend option has been set already.
if (getOption("brms.backend", default = "not_set") == "not_set") {
    .brms_backend <- Sys.getenv("BRMS_BACKEND", unset = "not_set")
    if (.brms_backend != "not_set") {
        options(brms.backend = .brms_backend)
    }
}

## Same scheme for the cmdstanr_write_stan_file_dir option, which is
## taken from the BRMS_CACHE_DIR environment variable.
if (getOption("cmdstanr_write_stan_file_dir", default = "not_set") == "not_set") {
    .brms_cache_dir <- Sys.getenv("BRMS_CACHE_DIR", unset = "not_set")
    if (.brms_cache_dir != "not_set") {
        options(cmdstanr_write_stan_file_dir = .brms_cache_dir)
    }
}

## tolerance used for all sample based match requirements
eps <- 0.1

## define the different test cases (univariate)

## beta: single component and two-component mixture
beta <- mixbeta(c(1, 11, 4))
betaMix <- mixbeta(
    c(0.8, 11, 4),
    c(0.2, 1, 1)
)

## gamma: single component and a two-component mixture with named
## components ("rob" and "inf")
gamma <- mixgamma(c(1, 5, 10), param = "mn")
gammaMix <- mixgamma(
    rob = c(0.25, 8, 0.5),
    inf = c(0.75, 8, 10),
    param = "mn"
)

## normal: single component, two-component mixture and a variant with
## an additional zero-weight component
norm <- mixnorm(c(1, 0, sqrt(2)), sigma = 1)

normMix <- mixnorm(
    c(0.2, 0, 2),
    c(0.8, 1, 2),
    sigma = 1
)
normMixWeak <- mixnorm(
    c(0.2, 0, 2),
    c(0.8, 1, 2),
    c(0, 0, 1),
    sigma = 1
)

## Checks the quantile function (qmix) and the distribution function
## (pmix) of a mixture against draws of b_Intercept which brms
## generates when sampling from the prior only.
##
## mix        mixture used as prior (handed to brms via mixstanvar)
## brms_args  list of arguments for brms::brm (formula, data, prior)
## eps        maximal absolute deviation tolerated
## qtest      quantiles at which the distribution function is compared
## ptest      probabilities at which the quantile function is compared
##
## Skipped on CRAN and on CI.
mixstanvar_simul_test <- function(mix,
                                  brms_args,
                                  eps, qtest, ptest = seq(0.2, 0.8, by=0.2)) {
    skip_on_cran()
    skip_on_ci()
    ## whatever is written to stdout while fitting is captured and dropped
    capture.output({
        fit_args <- c(brms_args,
                      list(seed = 1423545,
                           refresh = 0,
                           sample_prior = "only",
                           stanvars = mixstanvar(prior = mix)))
        prior_fit <- do.call(brms::brm, fit_args)
    })
    draws <- as.numeric(brms::as_draws_matrix(prior_fit, variable = "b_Intercept")[, 1])

    ## quantile function: analytic reference vs. empirical quantiles
    empirical_quants <- quantile(draws, ptest)
    quant_error <- abs(qmix(mix, ptest) - empirical_quants)
    expect_true(all(quant_error < eps))

    ## distribution function: analytic reference vs. empirical
    ## fraction of draws below each test quantile
    empirical_probs <- vapply(qtest, function(q) mean(draws < q), numeric(1))
    prob_error <- abs(pmix(mix, qtest) - empirical_probs)
    expect_true(all(prob_error < eps))
}


## Checks that brms picks up a univariate mixture prior: the generated
## Stan code must contain the mixture density and the Stan data must
## contain every row of the mixture as "prior_<rowname>".
##
## mix        univariate mixture used as prior (handed to brms via mixstanvar)
## brms_args  list of arguments for brms::brm (formula, data, prior)
##
## brms::brm is called with empty=TRUE, i.e. only the model object is
## requested here.
mixstanvar_test <- function(mix,
                            brms_args) {
    brms_prior_empty <- do.call(brms::brm, c(brms_args, list(seed=1423545, refresh=0, sample_prior="only", stanvars=mixstanvar(prior=mix), empty=TRUE)))
    ## name of the Stan density, e.g. class "betaMix" gives "mixbeta_lpdf"
    stan_dist <- paste0("mix", gsub("Mix$", "", class(mix)[1]), "_lpdf")
    ## look for the declared density in Stan; grepl always returns a
    ## proper logical, also when there is no match at all
    expect_true(any(grepl(stan_dist, brms::stancode(brms_prior_empty), fixed=TRUE)), info="Looking for declared Stan mixture density in generated brms Stan code.")
    ## now check for the mixture being passed to Stan as data
    stan_data <- brms::standata(brms_prior_empty)
    for(i in seq_len(nrow(mix))) {
        param <- paste0("prior_", rownames(mix)[i])
        expected <- unname(mix[i,])
        passed <- unname(stan_data[[param]])
        ## A missing element is NULL, the comparison below then has
        ## length zero and all(logical(0)) is TRUE. The length check
        ## makes sure such a case fails instead of passing silently.
        expect_equal(length(passed), length(expected), info=paste0("Length of Stan data element ", param, "."))
        expect_true(all(expected == passed), info=paste0("Values of Stan data element ", param, "."))
    }
}

## brms arguments for the beta test cases: intercept-only binomial model
## with identity link and no centering; the mixture prior is placed on
## the Intercept coefficient and refers to the data passed by mixstanvar
## (prior_w, prior_a, prior_b).
brms_beta_args <- list(formula=brms::bf(r | trials(n) ~ 1, family=brms::brmsfamily("binomial", link="identity"), center=FALSE),
                       data=data.frame(r=0, n=0),
                       prior=brms::prior(mixbeta(prior_w, prior_a, prior_b), coef=Intercept))

## The test code is wrapped in braces as asked for by test_that in
## order to obtain accurate locations of failures.
test_that("Beta quantiles are correct for brms sampled prior", {
    mixstanvar_simul_test(beta, brms_beta_args, eps, c(0.1, 0.9))
})
test_that("Beta prior is declared correctly in brms generated model and data", {
    mixstanvar_test(beta, brms_beta_args)
})
test_that("Beta mixture quantiles are correct for brms sampled prior", {
    mixstanvar_simul_test(betaMix, brms_beta_args, eps, c(0.1, 0.9))
})
test_that("Beta mixture prior is declared correctly in brms generated model and data", {
    mixstanvar_test(betaMix, brms_beta_args)
})

## brms arguments for the normal test cases: intercept-only gaussian
## model with identity link and no centering; sigma is fixed to 1 via a
## constant prior and the mixture prior is placed on the Intercept
## coefficient (data prior_w, prior_m, prior_s passed by mixstanvar).
brms_normal_args <- list(formula=brms::bf(y ~ 1, family=brms::brmsfamily("gaussian", link="identity"), center=FALSE),
                         data=data.frame(y=0),
                         prior=brms::prior(mixnorm(prior_w, prior_m, prior_s), coef=Intercept) + brms::prior(constant(1), class=sigma))

## The test code is wrapped in braces as asked for by test_that in
## order to obtain accurate locations of failures.
test_that("Normal quantiles are correct for brms sampled prior", {
    mixstanvar_simul_test(norm, brms_normal_args, eps, c(-1, 0))
})
test_that("Normal prior is declared correctly in brms generated model and data", {
    mixstanvar_test(norm, brms_normal_args)
})
test_that("Normal mixture quantiles are correct for brms sampled prior", {
    mixstanvar_simul_test(normMix, brms_normal_args, eps, c(2, 1), ptest=c(0.3, 0.5, 0.7))
})
test_that("Normal mixture prior is declared correctly in brms generated model and data", {
    mixstanvar_test(normMix, brms_normal_args)
})

## brms arguments for the gamma test cases: intercept-only gaussian
## model with identity link and no centering; sigma is fixed to 1 via a
## constant prior and the gamma mixture prior is placed on the Intercept
## coefficient (data prior_w, prior_a, prior_b passed by mixstanvar).
brms_gamma_args <- list(formula=brms::bf(y ~ 1, family=brms::brmsfamily("gaussian", link="identity"), center=FALSE),
                        data=data.frame(y=1),
                        prior=brms::prior(mixgamma(prior_w, prior_a, prior_b), coef=Intercept) + brms::prior(constant(1), class=sigma))

## The test code is wrapped in braces as asked for by test_that in
## order to obtain accurate locations of failures.
test_that("Gamma quantiles are correct for brms sampled prior", {
    mixstanvar_simul_test(gamma, brms_gamma_args, eps, c(2, 7))
})
test_that("Gamma prior is declared correctly in brms generated model and data", {
    mixstanvar_test(gamma, brms_gamma_args)
})
test_that("Gamma mixture quantile function is correct for brms sampled prior", {
    mixstanvar_simul_test(gammaMix, brms_gamma_args, eps, c(2, 7), ptest = seq(0.2, 0.8, by=0.2))
})
test_that("Gamma mixture prior is declared correctly in brms generated model and data", {
    mixstanvar_test(gammaMix, brms_gamma_args)
})

# Here we approximate the samples by a multivariate normal using a
# moment-based approximation and compare this to the respective
# approximation of the multivariate normal mixture. While this is not
# correct per se, it is sufficient for testing, as these
# approximations are unique and they do test for marginal means and
# correlations. For details on the divergence used see
# https://statproofbook.github.io/P/mvn-kl.html
## Kullback-Leibler divergence KL( N(m_1, sigma_1) || N(m_2, sigma_2) )
## between two multivariate normal distributions.
##
## m_1, m_2          mean vectors
## sigma_1, sigma_2  covariance matrices
##
## Returns the divergence as a 1x1 matrix (result of the quadratic form).
KLdiv_mvnorm <- function(m_1, sigma_1, m_2, sigma_2) {
    mean_shift <- (m_2 - m_1)
    precision_2 <- solve(sigma_2)
    n_dim <- length(m_1)
    ## quadratic form of the mean difference under the second precision
    quad_form <- t(mean_shift) %*% precision_2 %*% mean_shift
    ## trace of sigma_2^{-1} sigma_1
    trace_term <- sum(diag(precision_2 %*% sigma_1))
    0.5 * ( quad_form + trace_term - log(det(sigma_1)) + log(det(sigma_2)) - n_dim )
}

## Compares prior draws from brms for a multivariate normal mixture
## with the mixture itself: mean and covariance of the draws of all
## "b_" parameters are compared with mean and covariance reported by
## summary(mvmix) through the KL divergence of the two corresponding
## multivariate normal distributions.
##
## mvmix      multivariate normal mixture used as prior
## brms_args  list of arguments for brms::brm (formula, data, prior)
## eps        upper bound for the absolute KL divergence
##
## Skipped on CRAN and on CI.
mixstanvar_simul_mv_test <- function(mvmix, brms_args, eps) {
    skip_on_cran()
    skip_on_ci()
    ## whatever is written to stdout while fitting is captured and dropped
    capture.output({
        fit_args <- c(brms_args,
                      list(seed = 1423545,
                           refresh = 0,
                           sample_prior = "only",
                           stanvars = mixstanvar(prior = mvmix)))
        prior_fit <- do.call(brms::brm, fit_args)
    })
    draws <- brms::as_draws_matrix(prior_fit, variable = "^b_", regex = TRUE)
    reference <- summary(mvmix)
    kl <- KLdiv_mvnorm(colMeans(draws), cov(draws), reference$mean, reference$cov)
    expect_true(abs(kl) < eps)
}

## Checks that brms picks up a multivariate normal mixture prior: the
## generated Stan code must contain the mixture density and the Stan
## data must contain the number of components, the dimension, the
## weights, the means and the covariance matrices of the mixture.
##
## mix        multivariate normal mixture used as prior
## brms_args  list of arguments for brms::brm (formula, data, prior)
##
## brms::brm is called with empty=TRUE, i.e. only the model object is
## requested here. Elements of the Stan data are extracted with [[ ]],
## as $ matches list names partially and could silently return a
## different element (e.g. prior_sigma_L in place of prior_sigma).
mixstanvar_test_mvnormMix <- function(mix,
                                      brms_args) {
    brms_prior_empty <- do.call(brms::brm, c(brms_args, list(seed=1423545, refresh=0, sample_prior="only", stanvars=mixstanvar(prior=mix), empty=TRUE)))
    ## name of the Stan density derived from the class of the mixture
    stan_dist <- paste0("mix", gsub("Mix$", "", class(mix)[1]), "_lpdf")
    ## look for the declared density in Stan; grepl always returns a
    ## proper logical, also when there is no match at all
    expect_true(any(grepl(stan_dist, brms::stancode(brms_prior_empty), fixed=TRUE)), info="Looking for declared Stan mixture density in generated brms Stan code.")
    ## now check for the mixture being passed to Stan as data
    stan_data <- brms::standata(brms_prior_empty)
    ## number of mixture components
    Nc <- ncol(mix)
    expect_equal(stan_data[["prior_Nc"]], Nc)
    ## dimensionality
    p <- length(summary(mix)$mean)
    expect_equal(stan_data[["prior_p"]], p)
    ## weights per component (first row of the mixture)
    expect_equal(stan_data[["prior_w"]], array(unname(mix[1,])))
    ## means per component (rows 2 to p+1 of the mixture)
    expect_equal(unname(stan_data[["prior_m"]]), unname(t(mix[2:(p+1),])))
    ## covariance information
    prior_sigma <- stan_data[["prior_sigma"]]
    for(i in seq_len(Nc)) {
        S_c <- matrix(prior_sigma[i,,,drop=FALSE], p, p)
        ## standard deviations are compared with rows p+2 to 2p+1
        expect_equal(sqrt(diag(S_c)), unname(mix[(p+2):(1+2*p),i]))
        ## lower triangle of the correlation matrix is compared with
        ## the remaining rows
        Rho_c <- cov2cor(S_c)
        expect_equal(Rho_c[lower.tri(Rho_c)], unname(mix[(1+2*p+1):nrow(mix),i]))
    }
}

## 4-dimensional test cases

## dimension
p <- 4
## correlation matrix: the lower triangle is filled column by column
## and then mirrored into the upper triangle
Rho <- diag(p)
Rho[lower.tri(Rho)] <- c(0.3, 0.8, -0.2, 0.1, 0.5, -0.4)
Rho[upper.tri(Rho)] <- t(Rho)[upper.tri(Rho)]
## standard deviations and resulting covariance matrix
s <- c(1, 2, 3, 4)
S <- diag(s, p) %*% Rho %*% diag(s, p)
zero <- rep(0, p)
## mean vectors
m1 <- 0:3
m2 <- 1:4

## single-component case
mvnorm_single_4 <- mixmvnorm(
    c(1, m1, 5),
    param = "mn", sigma = S
)

## two-component case
mvnorm_heavy_4 <- mixmvnorm(
    c(0.5, m1, 0.25),
    c(0.5, m2, 5),
    param = "mn", sigma = S
)



## brms arguments for the 4D multivariate normal test cases: gaussian
## model with identity link, no centering and four regression
## coefficients (intercept, l1, l2, l3); sigma is fixed to 1 via a
## constant prior and the mixture prior is placed on class b.
brms_mvn_4_args <- list(formula=brms::bf(y ~ 1 + l1 + l2 + l3, family=brms::brmsfamily("gaussian", link="identity"), center=FALSE),
                        data=data.frame(y=1, l1=0, l2=0, l3=0),
                        prior=brms::prior(mixmvnorm(prior_w, prior_m, prior_sigma_L), class=b) + brms::prior(constant(1), class=sigma))

## The test code is wrapped in braces as asked for by test_that in
## order to obtain accurate locations of failures.
test_that("Multivariate normal (4D) is correct for brms sampled prior", {
    mixstanvar_simul_mv_test(mvnorm_single_4, brms_mvn_4_args, eps)
})
test_that("Multivariate normal (4D) prior is declared correctly in brms generated model and data", {
    mixstanvar_test_mvnormMix(mvnorm_single_4, brms_mvn_4_args)
})
test_that("Multivariate normal with heavy (4D) tails is correct for brms sampled prior", {
    mixstanvar_simul_mv_test(mvnorm_heavy_4, brms_mvn_4_args, eps)
})
test_that("Multivariate normal with heavy (4D) tails is declared correctly in brms generated model and data", {
    mixstanvar_test_mvnormMix(mvnorm_heavy_4, brms_mvn_4_args)
})

## 2-dimensional test cases, built from the first two entries of the
## mean vectors and the leading 2x2 block of S

## single-component case
mvnorm_single_2 <- mixmvnorm(
    c(1, m1[1:2], 5),
    param = "mn", sigma = S[1:2, 1:2]
)

## two-component case
mvnorm_heavy_2 <- mixmvnorm(
    c(0.5, m1[1:2], 0.25),
    c(0.5, m2[1:2], 5),
    param = "mn", sigma = S[1:2, 1:2]
)

## brms arguments for the 2D multivariate normal test cases: gaussian
## model with identity link, no centering and two regression
## coefficients (intercept, l1); sigma is fixed to 1 via a constant
## prior and the mixture prior is placed on class b.
brms_mvn_2_args <- list(formula=brms::bf(y ~ 1 + l1, family=brms::brmsfamily("gaussian", link="identity"), center=FALSE),
                        data=data.frame(y=1, l1=0),
                        prior=brms::prior(mixmvnorm(prior_w, prior_m, prior_sigma_L), class=b) + brms::prior(constant(1), class=sigma))

## The test code is wrapped in braces as asked for by test_that in
## order to obtain accurate locations of failures.
test_that("Multivariate normal (2D) is correct for brms sampled prior", {
    mixstanvar_simul_mv_test(mvnorm_single_2, brms_mvn_2_args, eps)
})
test_that("Multivariate normal (2D) is declared correctly in brms generated model and data", {
    mixstanvar_test_mvnormMix(mvnorm_single_2, brms_mvn_2_args)
})
test_that("Multivariate normal with heavy (2D) tails is correct for brms sampled prior", {
    mixstanvar_simul_mv_test(mvnorm_heavy_2, brms_mvn_2_args, eps)
})
test_that("Multivariate normal with heavy (2D) is declared correctly in brms generated model and data", {
    mixstanvar_test_mvnormMix(mvnorm_heavy_2, brms_mvn_2_args)
})

Try the RBesT package in your browser

Any scripts or data that you put into this service are public.

RBesT documentation built on May 29, 2024, 10:40 a.m.