# tests/testthat/test-dirichlet.R

context("test-dirichlet")

# rsamplestudy generates the samples used as fixtures below, so nothing in
# this file can run without it.
testthat::skip_if_not_installed('rsamplestudy')

# Fix the seed so the simulated samples, and therefore every estimate compared
# against the targets below, are the same on every run.
# NOTE(review): this assumes fun_rdirichlet() draws from R's own RNG -- confirm.
set.seed(20210310)

# Test that Dirichlet functions work
# Number of draws requested for each source.
n <- 100000

# Target parameter vectors: the estimators are expected to recover these
# (within tolerance) from the samples generated with them.
alpha_target_1 <- c(1, 2, 3, 4, 5)
alpha_target_2 <- c(0.1, 0.2, 3, 4, 0.5)
alpha_target_3 <- c(0.1, 0.1, 0.1, 0.1, 0.1)

# One sample per target.
df_1 <- rsamplestudy::fun_rdirichlet(n, alpha_target_1)
df_2 <- rsamplestudy::fun_rdirichlet(n, alpha_target_2)
df_3 <- rsamplestudy::fun_rdirichlet(n, alpha_target_3)

# Length of the parameter vectors; used as the expected length of the
# hyperparameter estimates further down.
p <- length(alpha_target_1)

# Bind the dataframes
# Work on copies so df_1, df_2 and df_3 stay free of the `source` column: the
# single-source tests use them as they are.
df_1_s <- df_1
df_2_s <- df_2
df_3_s <- df_3
df_1_s$source <- 1
df_2_s$source <- 2
df_3_s$source <- 3

# All three samples stacked, each row labelled with its source (1, 2 or 3).
df <- dplyr::bind_rows(df_1_s, df_2_s, df_3_s)


# Single source -----------------------------------------------------------

# Shorthand for expect_equal() with the tolerance fixed at 0.05 and
# scale = NULL; every other argument is forwarded unchanged.
expect_equal_tol <- function(...) {
  expect_equal(..., tolerance = 0.05, scale = NULL)
}

test_that("Single source: ML converges", {
  # Each ML fit is compared, via expect_equal_tol(), with the parameter vector
  # that generated its sample.
  # NOTE(review): eps = 1e-14 is only passed for the ML fits; presumably a
  # convergence threshold -- confirm against the estimator's documentation.
  fit_1 <- fun_estimate_Dirichlet_from_single_source(
    df_1,
    use = "ML",
    eps = 1e-14
  )
  expect_equal_tol(as.numeric(fit_1), as.numeric(alpha_target_1))

  fit_2 <- fun_estimate_Dirichlet_from_single_source(
    df_2,
    use = "ML",
    eps = 1e-14
  )
  expect_equal_tol(as.numeric(fit_2), as.numeric(alpha_target_2))

  fit_3 <- fun_estimate_Dirichlet_from_single_source(
    df_3,
    use = "ML",
    eps = 1e-14
  )
  expect_equal_tol(as.numeric(fit_3), as.numeric(alpha_target_3))
})

test_that("Single source: naive converges", {
  # Same comparison as for the ML fits, but with use = "naive" and without an
  # explicit eps.
  fit_1 <- fun_estimate_Dirichlet_from_single_source(df_1, use = "naive")
  expect_equal_tol(as.numeric(fit_1), as.numeric(alpha_target_1))

  fit_2 <- fun_estimate_Dirichlet_from_single_source(df_2, use = "naive")
  expect_equal_tol(as.numeric(fit_2), as.numeric(alpha_target_2))

  fit_3 <- fun_estimate_Dirichlet_from_single_source(df_3, use = "naive")
  expect_equal_tol(as.numeric(fit_3), as.numeric(alpha_target_3))
})

test_that("Single source: wrong parameters", {
  # Values of `use` other than the two exercised above ("ML", "naive") are
  # expected to raise an error.
  for (bad_use in c("ZZZZ", "3141")) {
    expect_error(
      fun_estimate_Dirichlet_from_single_source(df_1, use = bad_use)
    )
  }
})

# Multiple sources --------------------------------------------------------

# Copy of `df` whose source column is named `item` instead of `source`.
df_item <- dplyr::rename(df, item = source)

test_that("Multiple sources: standard", {
  # Fit once with the default source column and keep the value that
  # expect_silent() returns, so the row-count check reuses this fit instead of
  # running the same estimation a second time.
  est_default <- expect_silent(
    fun_estimate_Dirichlet_from_samples(df, use = "ML")
  )

  # `df` stacks three sources; the estimate is expected to have three rows.
  expect_equal(nrow(est_default), 3)

  # Same data with the source column called `item`, named explicitly.
  expect_silent(
    fun_estimate_Dirichlet_from_samples(
      df_item,
      use = "ML",
      col_source = "item"
    )
  )
})

test_that("Multiple sources: source column checks", {
  # --- Data whose source column is called `item` ---
  # Naming the column explicitly works ...
  expect_silent(
    fun_estimate_Dirichlet_from_samples(
      df_item,
      use = "ML",
      col_source = "item"
    )
  )
  # ... naming a column that does not exist is an error ...
  expect_error(
    fun_estimate_Dirichlet_from_samples(
      df_item,
      use = "ML",
      col_source = "AAAA"
    )
  )
  # ... and so is leaving col_source at its default.
  expect_error(fun_estimate_Dirichlet_from_samples(df_item, use = "ML"))

  # --- Data whose source column is called `source` ---
  expect_silent(
    fun_estimate_Dirichlet_from_samples(
      df,
      use = "ML",
      col_source = "source"
    )
  )
  # NOTE(review): `item` is passed unquoted here and no object of that name is
  # defined in this file -- confirm the bare symbol is intentional.
  expect_error(
    fun_estimate_Dirichlet_from_samples(df, use = "ML", col_source = item)
  )
  # `df` has no column called `item`.
  expect_error(
    fun_estimate_Dirichlet_from_samples(df, use = "ML", col_source = "item")
  )

  # Once the source column is dropped, the estimates must not depend on what
  # that column was called.
  est_by_source <- dplyr::select(
    fun_estimate_Dirichlet_from_samples(
      df,
      use = "ML",
      col_source = "source"
    ),
    -source
  )
  est_by_item <- dplyr::select(
    fun_estimate_Dirichlet_from_samples(
      df_item,
      use = "ML",
      col_source = "item"
    ),
    -item
  )
  expect_identical(est_by_source, est_by_item)
})


# Multiple sources: source estimates --------------------------------------

test_that("Multiple sources: source estimates are correct", {
  # Estimate for source 1 obtained from its own sample only.
  from_single <- fun_estimate_Dirichlet_from_single_source(df_1, use = "ML")

  # Estimate for source 1 taken out of the fit on the stacked data, with the
  # source label removed so the two results have the same columns.
  all_sources <- fun_estimate_Dirichlet_from_samples(df, use = "ML")
  from_multiple <- dplyr::select(
    dplyr::filter(all_sources, source == 1),
    -source
  )

  expect_identical(from_single, from_multiple)
})



# Hyperparameter for the DirDir model ----------------------------------------------------------

test_that("Multiple sources: DirDir hyperparameter ML estimation does not fail", {
  # The call itself is wrapped in expect_silent(); the returned value is kept
  # for the shape checks below.
  hyper <- expect_silent(
    fun_estimate_DirDir_hyperparameter(
      df,
      method = "ML",
      col_source = "source"
    )
  )

  # Expect a numeric result with one entry per component of the targets.
  expect_length(hyper, p)
  expect_is(hyper, "numeric")
})



# Hyperparameter for the DirDirGamma model ----------------------------------------------------------

test_that("Multiple sources: DirDirGamma hyperparameter ML estimation does not fail", {
  # The call itself is wrapped in expect_silent(); the returned value is kept
  # for the structure checks below.
  hyper <- expect_silent(
    fun_estimate_DirDirGamma_hyperparameter(df, col_source = "source")
  )

  # Container: a list with exactly these three names, in this order.
  expect_is(hyper, "list")
  expect_named(hyper, c("alpha_0", "beta_0", "nu_0"))

  # alpha_0 is expected to be a single number.
  expect_is(hyper$alpha_0, "numeric")
  expect_length(hyper$alpha_0, 1)

  # beta_0 is expected to be a single number.
  expect_is(hyper$beta_0, "numeric")
  expect_length(hyper$beta_0, 1)

  # nu_0 is expected to have one entry per component of the targets.
  expect_is(hyper$nu_0, "numeric")
  expect_length(hyper$nu_0, p)
})
# lgaborini/rstanBF documentation built on March 10, 2021, 1:12 p.m.