skip_connection("stratified-sampling")
skip_on_livy()
skip_on_arrow_devel()
test_requires("dplyr")
sample_space_sz <- 100L
num_zeroes <- 50L
num_groups <- 5L
sample_sz <- 20L
sample_frac <- 0.10
num_sampling_iters <- 10L
alpha <- 0.05
## ------------------------ Data prep ------------------------------------------
sampling_test_data <- data.frame(
id = rep(seq(sample_space_sz + num_zeroes), num_groups),
group = seq(num_groups) %>%
lapply(function(x) rep(x, sample_space_sz + num_zeroes)) %>%
unlist(),
weight = rep(
c(
rep(1, 50),
rep(2, 25),
rep(4, 10),
rep(8, 10),
rep(16, 5),
rep(0, num_zeroes)
),
num_groups
)
)
sdf <- testthat_tbl(
name = "sampling_test_data",
repartition = 5L
)
## ---------------------- Sampling testing function ----------------------------
verify_sample_results <- function(weighted, replacement, sampling) {
expected_dist <- rep(0L, sample_space_sz + num_zeroes)
actual_dist <- rep(0L, sample_space_sz + num_zeroes)
if(weighted) {
weight <- expr(weight)
} else {
weight <- NULL
}
for (x in seq(num_sampling_iters)) {
set.seed(142857L + x)
if(sampling == "fraction") {
r_sample <- sampling_test_data %>%
group_by(group) %>%
sample_frac(
weight = !! weight,
replace = !! replacement,
size = sample_frac
)
spark_sample <- sdf %>%
group_by(group) %>%
sample_frac(
weight = !! weight,
replace = !! replacement,
size = sample_frac
) %>%
collect()
}
if(sampling == "number") {
r_sample <- sampling_test_data %>%
group_by(group) %>%
sample_n(
weight = !! weight,
replace = !! replacement,
size = sample_sz
)
spark_sample <- sdf %>%
group_by(group) %>%
sample_n(
weight = !! weight,
replace = !! replacement,
size = sample_sz
) %>%
collect()
}
spark_count <- spark_sample %>%
count(group) %>%
pull(n)
if(sampling == "fraction") {
expect_equal(
spark_count,
rep(ceiling((sample_space_sz + num_zeroes) * sample_frac), num_groups)
)}
if(sampling == "number") {
expect_equal(
spark_count,
rep(sample_sz, num_groups)
)}
for (id in r_sample$id) {
expected_dist[[id]] <- expected_dist[[id]] + 1L
}
for (id in spark_sample$id) {
actual_dist[[id]] <- actual_dist[[id]] + 1L
}
}
expect_warning(
res <- ks.test(x = actual_dist, y = expected_dist)
)
expect_gte(res$p.value, alpha)
}
test_that("stratified sampling without replacement works as expected", {
test_requires_version("3.0.0")
verify_sample_results(weighted = FALSE, replacement = FALSE, sampling = "number")
verify_sample_results(weighted = FALSE, replacement = FALSE, sampling = "fraction")
})
test_that("stratified sampling with replacement works as expected", {
test_requires_version("3.0.0")
verify_sample_results(weighted = FALSE, replacement = TRUE, sampling = "number")
verify_sample_results(weighted = FALSE, replacement = TRUE, sampling = "fraction")
})
test_that("stratified weighted sampling without replacement works as expected", {
test_requires_version("3.0.0")
verify_sample_results(weighted = TRUE, replacement = FALSE, sampling = "number")
verify_sample_results(weighted = TRUE, replacement = FALSE, sampling = "fraction")
})
test_that("stratified weighted sampling with replacement works as expected", {
test_requires_version("3.0.0")
verify_sample_results(weighted = TRUE, replacement = TRUE, sampling = "number")
verify_sample_results(weighted = TRUE, replacement = TRUE, sampling = "fraction")
})
test_that("stratified sampling returns repeatable results from a fixed PRNG seed", {
test_requires_version("3.0.0")
for (replacement in c(TRUE, FALSE)) {
for (weighted in c(TRUE, FALSE)) {
args <- list(
size = sample_sz,
weight = if (weighted) as.symbol("weight") else NULL,
replace = replacement
)
samples <- lapply(
seq(2),
function(x) {
set.seed(142857L)
sdf %>%
dplyr::group_by(group) %>>%
dplyr::sample_n %@% args %>%
collect()
}
)
expect_equivalent(
samples[[1]] %>% dplyr::arrange(id),
samples[[2]] %>% dplyr::arrange(id)
)
args <- list(
size = sample_frac,
weight = if (weighted) as.symbol("weight") else NULL,
replace = replacement
)
samples <- lapply(
seq(2),
function(x) {
set.seed(142857L)
sdf %>%
dplyr::group_by(group) %>>%
dplyr::sample_frac %@% args %>%
collect()
}
)
expect_equivalent(
samples[[1]] %>% dplyr::arrange(id),
samples[[2]] %>% dplyr::arrange(id)
)
}
}
})
test_clear_cache()
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.