# Shared setup for the dplyr integration tests.
# `skip_connection`, `test_requires`, `testthat_spark_connection` and
# `testthat_tbl` are project-local test helpers defined outside this file —
# presumably they gate tests on connection capabilities and register local
# datasets with the Spark connection; confirm against the test helpers.
skip_connection("dplyr")
test_requires("dplyr")
sc <- testthat_spark_connection()
# Well-known datasets copied to Spark. Note: tests below compare local column
# names through gsub("\\.", "_", .), i.e. dots in names (iris) appear as
# underscores on the Spark side.
iris_tbl <- testthat_tbl("iris")
mtcars_tbl <- testthat_tbl("mtcars")
# Whether tidyselect where(...) predicates work against this backend; several
# tests skip when FALSE.
has_predicates <- tidyselect_data_has_predicates(mtcars_tbl)
# Small join fixtures sharing column `b`.
df1 <- tibble(a = 1:3, b = letters[1:3])
df2 <- tibble(b = letters[1:3], c = letters[24:26])
df1_tbl <- testthat_tbl("df1")
df2_tbl <- testthat_tbl("df2")
# id-only Spark frames used by the compute()/mutate()/transmute() tests.
sdf_5 <- copy_to(sc, data.frame(id = 1:5))
sdf_10 <- copy_to(sc, data.frame(id = 1:10))
# Mixed-type frame (integer, character, timestamp) for across() coverage.
dplyr_across_test_cases_df <- tibble(
x = seq(3),
y = as.character(seq(3)),
t = as.POSIXct(seq(3), origin = "1970-01-01"),
z = seq(3) + 5L
)
dplyr_across_test_cases_tbl <- testthat_tbl("dplyr_across_test_cases_df")
# Assert that `x` (a tbl_spark) reports remote name `y`.
# dbplyr <= 2.3.4 wrapped remote names in ident(), so the expected value is
# normalized to match whichever dbplyr version is installed.
test_remote_name <- function(x, y) {
  expected <- if (packageVersion("dbplyr") <= "2.3.4") ident(y) else y
  expect_equal(dbplyr::remote_name(x), expected)
}
# Fixtures for the if_all()/if_any() tests below. Columns come in "_a"/"_b"
# suffixed pairs (boolean, numeric, character) so starts_with() can select
# each family of columns.
scalars_df <- dplyr::tibble(
row_num = seq(4),
b_a = c(FALSE, FALSE, TRUE, TRUE),
b_b = c(FALSE, TRUE, FALSE, TRUE),
ba = FALSE,
bb = TRUE,
n_a = c(2, 3, 6, 7),
n_b = c(3, 6, 2, 7),
c_a = c("aa", "ab", "ca", "dd"),
c_b = c("ab", "bc", "ac", "ad")
)
scalars_sdf <- copy_to(sc, scalars_df, overwrite = TRUE)
# List-columns — exercised with array_contains() on the Spark side, so they
# presumably map to Spark array columns.
arrays_df <- dplyr::tibble(
row_num = seq(4),
a_a = list(1:4, 2:5, 3:6, 4:7),
a_b = list(4:7, 3:6, 2:5, 1:4)
)
arrays_sdf <- copy_to(sc, arrays_df, overwrite = TRUE)
# select() with a where(...) predicate must pick the same columns remotely as
# locally; local iris names get "." -> "_" to match Spark-sanitized names.
test_that("'select' works with where(...) predicate", {
skip_if(!has_predicates)
expect_equal(
iris %>% select(where(is.numeric)) %>% tbl_vars() %>% gsub("\\.", "_", .),
iris_tbl %>% select(where(is.numeric)) %>% collect() %>% tbl_vars()
)
})
# n_distinct() on an expression input, in all three na.rm variants, must agree
# between local dplyr and the Spark translation.
test_that("'n_distinct' summarizer works as expected", {
skip_connection("supports-na")
summarize_n_distinct <- function(input) {
input %>%
summarize(
n_distinct_default = n_distinct(x ^ 2),
n_distinct_na_rm_true = n_distinct(x ^ 2, na.rm = TRUE),
n_distinct_na_rm_false = n_distinct(x ^ 2, na.rm = FALSE)
)
}
# Input deliberately mixes NA and NaN, plus values whose squares collide
# (e.g. -1 and 1), to exercise distinct-counting edge cases.
df <- dplyr::tibble(x = c(-3L:2L, NA, NaN, NA))
sdf <- copy_to(sc, df, name = random_string())
expect_equal(
df %>% summarize_n_distinct(),
sdf %>% summarize_n_distinct() %>% collect(),
ignore_attr = TRUE
)
})
# summarize(across(...)) with where()/starts_with() selections must match
# local results. Spark-side mean() needs explicit na.rm = TRUE to mirror the
# local default, and local factor columns arrive as character on Spark.
test_that("'summarize' works with where(...) predicate", {
skip_if(!has_predicates)
expect_equivalent(
iris %>% summarize(across(where(is.numeric), mean)),
iris_tbl %>% summarize(across(where(is.numeric), ~mean(.x, na.rm = TRUE))) %>% collect()
)
expect_equivalent(
iris %>% summarize(across(starts_with("Petal"), mean)),
iris_tbl %>% summarize(across(starts_with("Petal"), ~mean(.x, na.rm = TRUE))) %>% collect()
)
expect_equivalent(
iris %>% summarize(across(where(is.factor), n_distinct)),
iris_tbl %>% summarize(across(where(is.character), n_distinct)) %>% collect()
)
})
# mutate() adding a derived column must yield the same column set remotely as
# locally (local names normalized "." -> "_").
test_that("'mutate' works as expected", {
expect_equal(
iris %>% mutate(x = Species) %>% tbl_vars() %>% gsub("\\.", "_", .),
iris_tbl %>% mutate(x = Species) %>% collect() %>% tbl_vars()
)
})
# Tidy-eval NSE: a column referenced via a string variable with !!sym() and
# assigned via !!col := must translate the same on Spark as locally.
test_that("'mutate' and 'transmute' work with NSE", {
col <- "mpg"
expect_equivalent(
mtcars_tbl %>% mutate(!!col := !!rlang::sym(col) * 2) %>% collect(),
mtcars %>% mutate(!!col := !!rlang::sym(col) * 2)
)
expect_equivalent(
mtcars_tbl %>% transmute(!!col := !!rlang::sym(col) * 2) %>% collect(),
mtcars %>% transmute(!!col := !!rlang::sym(col) * 2)
)
})
# Chained filter() calls on the Spark table must match the local result.
# The local pipeline first renames iris columns ("." -> "_") to mirror the
# Spark-side sanitized names, and coerces Species to character because the
# collected Spark column is character, not factor.
test_that("the implementation of 'filter' functions as expected", {
expect_equivalent(
iris_tbl %>%
filter(Sepal_Length == 5.1) %>%
filter(Sepal_Width == 3.5) %>%
filter(Petal_Length == 1.4) %>%
filter(Petal_Width == 0.2) %>%
select(Species) %>%
collect(),
iris %>%
transmute(
Sepal_Length = `Sepal.Length`,
Sepal_Width = `Sepal.Width`,
Petal_Length = `Petal.Length`,
Petal_Width = `Petal.Width`,
Species = Species
) %>%
filter(Sepal_Length == 5.1) %>%
filter(Sepal_Width == 3.5) %>%
filter(Petal_Length == 1.4) %>%
filter(Petal_Width == 0.2) %>%
transmute(Species = as.character(Species))
)
})
# ifelse() translation: the 3-argument form propagates NA for NA input, while
# the 4-argument form supplies an explicit missing-value replacement.
test_that("if_else works as expected", {
sdf <- copy_to(sc, dplyr::tibble(x = c(0.9, NA_real_, 1.1)))
expect_equal(
sdf %>% dplyr::mutate(x = ifelse(x > 1, "good", "bad")) %>% dplyr::pull(x),
c("bad", NA, "good")
)
expect_equal(
sdf %>% dplyr::mutate(x = ifelse(x > 1, "good", "bad", "unknown")) %>%
dplyr::pull(x),
c("bad", "unknown", "good")
)
})
# if_any()/if_all() over the boolean "b_" column pair must select the same
# rows on Spark as locally (bare columns, no predicate function).
test_that("if_all and if_any work as expected", {
test_requires_package_version("dbplyr", 2)
expect_equivalent(
scalars_sdf %>%
filter(if_any(starts_with("b_"))) %>%
collect(),
scalars_df %>%
filter(if_any(starts_with("b_")))
)
expect_equivalent(
scalars_sdf %>%
filter(if_all(starts_with("b_"))) %>%
collect(),
scalars_df %>%
filter(if_all(starts_with("b_")))
)
})
# if_any()/if_all() with predicate functions: lambda predicates, lists of
# predicates, extra arguments forwarded to a named function (grepl), and
# Spark-only predicates (array_contains) on array columns.
test_that("if_all and if_any work as expected with boolean predicates", {
test_requires_package_version("dbplyr", 2)
test_requires_version("2.4.0")
skip_on_arrow()
expect_equivalent(
scalars_sdf %>%
filter(if_all(starts_with("n_"), ~ .x > 5)) %>%
collect(),
scalars_df %>% filter(if_all(starts_with("n_"), ~ .x > 5))
)
expect_equivalent(
scalars_sdf %>%
filter(if_any(starts_with("n_"), ~ .x > 5)) %>%
collect(),
scalars_df %>% filter(if_any(starts_with("n_"), ~ .x > 5))
)
# A list of predicates: every predicate must hold for every selected column
# (if_all) or at least one combination must hold (if_any).
expect_equivalent(
scalars_sdf %>%
filter(if_all(starts_with("n_"), c(~ .x > 5, ~ .x < 3))) %>%
collect(),
scalars_df %>% filter(if_all(starts_with("n_"), c(~ .x > 5, ~ .x < 3)))
)
expect_equivalent(
scalars_sdf %>%
filter(if_any(starts_with("n_"), c(~ .x > 6, ~ .x < 3))) %>%
collect(),
scalars_df %>% filter(if_any(starts_with("n_"), c(~ .x > 6, ~ .x < 3)))
)
# if_all/if_any is totally dependent on dbplyr implementation
# there is a warning that does not seem to be
# generated by sparklyr code
expect_warning(
scalars_sdf %>%
dplyr::filter(if_all(starts_with("c_"), grepl, "caabac"))
)
# Extra args after the function are forwarded, so each c_ column value is
# used as the grepl pattern against the fixed string.
expect_equivalent(
scalars_sdf %>%
dplyr::filter(if_all(starts_with("c_"), grepl, "caabac")) %>%
pull(row_num),
c(1L, 3L)
)
expect_equivalent(
scalars_sdf %>%
dplyr::filter(if_any(starts_with("c_"), grepl, "aac")) %>%
pull(row_num),
c(1L, 3L)
)
expect_equivalent(
scalars_sdf %>%
dplyr::filter(if_any(starts_with("c_"), grepl, "bcad")) %>%
pull(row_num),
c(2L, 3L, 4L)
)
# Spark-side array predicates on list-columns.
expect_equivalent(
arrays_sdf %>%
filter(if_all(starts_with("a_"), ~ array_contains(.x, 5L))) %>%
pull(row_num),
c(2L, 3L)
)
expect_equivalent(
arrays_sdf %>%
filter(if_any(starts_with("a_"), ~ array_contains(.x, 7L))) %>%
pull(row_num),
c(1L, 4L)
)
})
# grepl() translation across a battery of regexes, including edge cases:
# empty pattern, bare anchors ("^", "$"), empty-string match ("^$"), and
# character classes.
test_that("grepl works as expected", {
regexes <- c(
"a|c", ".", "b", "x|z", "", "y", "e", "^", "$", "^$", "[0-9]", "[a-z]", "[b-z]"
)
verify_equivalent <- function(actual, expected) {
# handle an edge case for arrow-enabled Spark connection
for (col in colnames(df2)) {
expect_equivalent(
as.character(actual[[col]]),
as.character(expected[[col]])
)
}
}
for (regex in regexes) {
verify_equivalent(
df2 %>% dplyr::filter(grepl(regex, b)),
df2_tbl %>% dplyr::filter(grepl(regex, b)) %>% collect()
)
verify_equivalent(
df2 %>% dplyr::filter(grepl(regex, c)),
df2_tbl %>% dplyr::filter(grepl(regex, c)) %>% collect()
)
}
})
# head() on a lazy Spark table should render to SQL containing a LIMIT
# clause rather than collecting rows eagerly.
test_that("'head' uses 'limit' clause", {
  test_requires("dbplyr")
  rendered_sql <- sql_render(head(iris_tbl))
  expect_true(grepl("LIMIT", rendered_sql))
})
# sdf_broadcast() should hint Spark's planner into a broadcast hash join;
# verify by inspecting the analyzed logical plan text of a hinted join.
test_that("'sdf_broadcast' forces broadcast hash join", {
  skip_connection("sdf-broadcast")
  query_plan <- df1_tbl %>%
    sdf_broadcast() %>%
    left_join(df2_tbl, by = "b") %>%
    spark_dataframe() %>%
    invoke("queryExecution") %>%
    invoke("analyzed") %>%
    invoke("toString")
  # BUG FIX: the previous pattern "B|broadcast" matched either a bare "B" or
  # "broadcast" (alternation spans the whole branches), so any plan containing
  # a capital B passed and the expectation was vacuous. Match "broadcast"
  # case-insensitively via a character class instead.
  expect_match(query_plan, "[Bb]roadcast")
})
# dplyr::compute() materializes a lazy query as a Spark temp view: derived
# queries have no remote name until computed; compute() accepts either a
# random name (default) or an explicit `name`.
test_that("compute() works as expected", {
sdf <- sdf_10
sdf_even <- sdf %>% dplyr::filter(id %% 2 == 0)
sdf_odd <- sdf %>% dplyr::filter(id %% 2 == 1)
# Lazy (un-materialized) queries report no remote name.
expect_null(dbplyr::remote_name(sdf_even))
expect_null(dbplyr::remote_name(sdf_odd))
# caching Spark dataframes with random names
sdf_even_cached <- sdf_even %>% dplyr::compute()
sdf_odd_cached <- sdf_odd %>% dplyr::compute()
expect_equivalent(
sdf_even_cached %>% collect(),
dplyr::tibble(id = c(2L, 4L, 6L, 8L, 10L))
)
expect_equivalent(
sdf_odd_cached %>% collect(),
dplyr::tibble(id = c(1L, 3L, 5L, 7L, 9L))
)
# caching Spark dataframes with pre-determined names
sdf_congruent_to_1_mod_3 <- sdf %>% dplyr::filter(id %% 3 == 1)
sdf_congruent_to_2_mod_3 <- sdf %>% dplyr::filter(id %% 3 == 2)
expect_null(sdf_congruent_to_1_mod_3 %>% dbplyr::remote_name())
expect_null(sdf_congruent_to_2_mod_3 %>% dbplyr::remote_name())
sdf_congruent_to_1_mod_3_cached <- sdf_congruent_to_1_mod_3 %>%
dplyr::compute(name = "congruent_to_1_mod_3")
sdf_congruent_to_2_mod_3_cached <- sdf_congruent_to_2_mod_3 %>%
dplyr::compute(name = "congruent_to_2_mod_3")
test_remote_name(
sdf_congruent_to_1_mod_3_cached,
"congruent_to_1_mod_3"
)
test_remote_name(
sdf_congruent_to_2_mod_3_cached,
"congruent_to_2_mod_3"
)
# `name` also works positionally.
temp_view <- sdf_congruent_to_2_mod_3 %>% dplyr::compute("temp_view")
test_remote_name(
temp_view, "temp_view"
)
expect_equivalent(
sdf_congruent_to_1_mod_3_cached %>% collect(),
dplyr::tibble(id = c(1L, 4L, 7L, 10L))
)
expect_equivalent(
sdf_congruent_to_2_mod_3_cached %>% collect(),
dplyr::tibble(id = c(2L, 5L, 8L))
)
})
# mutate() with a literal NA_real_ must produce a double column of NAs.
# Note: inside tibble(), `sq = id * id` refers to the `id` column defined
# just above (tibble evaluates arguments sequentially).
test_that("mutate creates NA_real_ column correctly", {
sdf <- sdf_5 %>% dplyr::mutate(z = NA_real_, sq = id * id)
expect_equivalent(
sdf %>% collect(),
dplyr::tibble(id = seq(5), z = NA_real_, sq = id * id)
)
})
# transmute() keeps only the newly created columns; the NA_real_ column must
# still come back as double NAs.
test_that("transmute creates NA_real_ column correctly", {
sdf <- sdf_5 %>% dplyr::transmute(z = NA_real_, sq = id * id)
expect_equivalent(
sdf %>% collect(),
dplyr::tibble(z = NA_real_, sq = seq(5) * seq(5))
)
})
# compute() twice under the same view name should overwrite: the second
# (foo-less) computation must be what both the handle and the named view show.
test_that("overwriting a temp view", {
# Skipping while researching why override works on non-connect methods
skip()
temp_view_name <- random_string()
sdf <- sdf_5 %>%
dplyr::mutate(foo = "foo") %>%
dplyr::compute(name = temp_view_name)
sdf <- sdf_5 %>%
dplyr::compute(name = temp_view_name)
expect_equivalent(sdf %>% collect(), dplyr::tibble(id = seq(5)))
expect_equivalent(
dplyr::tbl(sc, temp_view_name) %>% collect(), dplyr::tibble(id = seq(5))
)
})
# With option sparklyr.dplyr_distinct.impl = "tbl_lazy", distinct() should
# fall through to dbplyr's generic SELECT DISTINCT translation.
test_that("dplyr::distinct() impl is configurable", {
options(sparklyr.dplyr_distinct.impl = "tbl_lazy")
# Restore the default implementation even if an expectation fails.
on.exit(options(sparklyr.dplyr_distinct.impl = NULL))
tbl_name <- random_string()
sdf <- copy_to(sc, data.frame(a = c(1, 1)), name = tbl_name)
query <- sdf %>%
dplyr::distinct() %>%
dbplyr::remote_query() %>%
strsplit("\\s+")
# dbplyr may render an explicit column list (`tbl`.`a`) instead of "*";
# normalize the third token so the comparison is rendering-agnostic.
query[[1]][[3]] <- gsub(sprintf("`%s`.*", tbl_name), "*", query[[1]][[3]])
expect_equal(
toupper(query[[1]]),
c("SELECT", "DISTINCT", "*", "FROM", sprintf("`%s`", toupper(tbl_name)))
)
expect_equivalent(
sdf %>% dplyr::distinct() %>% collect(),
data.frame(a = 1)
)
})
# sparklyr:::process_tbl_name() maps dotted names to dbplyr schema/catalog
# objects; plain names and raw SQL queries pass through usable by tbl().
test_that("process_tbl_name works as expected", {
skip_if(any(grepl("connect_", class(sc))))
expect_equal(sparklyr:::process_tbl_name("a"), "a")
expect_equal(sparklyr:::process_tbl_name("xyz"), "xyz")
expect_equal(sparklyr:::process_tbl_name("x.y"), dbplyr::in_schema("x", "y"))
expect_equal(sparklyr:::process_tbl_name("x.y.z"), dbplyr::in_catalog("x", "y", "z"))
df1 <- dplyr::tibble(a = 1, g = 2) %>%
copy_to(sc, ., "df1", overwrite = TRUE)
df2 <- dplyr::tibble(b = 1, g = 2) %>%
copy_to(sc, ., "df2", overwrite = TRUE)
# tbl() should also accept a full SQL query wrapped in sql().
query <- sql("SELECT df1.a, df2.b, df1.g FROM df1 LEFT JOIN df2 ON df1.g = df2.g")
expect_equivalent(
tbl(sc, query) %>% collect(),
dplyr::tibble(a = 1, b = 1, g = 2)
)
})
# dbplyr::in_schema() should resolve a table created inside a dedicated
# database; exercised via a Hive-format table, which is only attempted on
# Spark versions below 3.4.0.
test_that("in_schema() works as expected", {
  skip_on_arrow()
  skip_on_livy()
  if (spark_version(sc) < "3.4.0") {
    db_name <- random_string("test_db_")
    DBI::dbGetQuery(sc, sprintf("CREATE DATABASE `%s`", db_name))
    DBI::dbGetQuery(
      sc,
      sprintf(
        "CREATE TABLE IF NOT EXISTS `%s`.`hive_tbl` (`x` INT) USING hive",
        db_name
      )
    )
    # The freshly created table is empty, so it collects to zero rows.
    expect_equivalent(
      dplyr::tbl(sc, dbplyr::in_schema(db_name, "hive_tbl")) %>% collect(),
      dplyr::tibble(x = integer())
    )
  }
})
# A registered table reports its remote name, but any derived (lazy) query
# no longer has one.
test_that("sdf_remote_name returns null for computed tables", {
  test_remote_name(iris_tbl, "iris")
  filtered_sdf <- iris_tbl %>% filter(Species == "virginica")
  expect_null(dbplyr::remote_name(filtered_sdf))
})
# Repeated trailing group_by() calls must not hide the underlying table's
# remote name.
test_that("sdf_remote_name ignores the last group_by() operation(s)", {
sdf <- iris_tbl
for (i in seq(4)) {
sdf <- sdf %>% dplyr::group_by(Species)
test_remote_name(sdf, "iris")
}
})
# Likewise, repeated trailing ungroup() calls must preserve the remote name.
test_that("sdf_remote_name ignores the last ungroup() operation(s)", {
sdf <- iris_tbl
for (i in seq(4)) {
sdf <- sdf %>% dplyr::ungroup()
test_remote_name(sdf, "iris")
}
})
# compute() after arrange() must register a named view that preserves both
# the remote name and the sorted row order.
test_that("sdf_remote_name works with arrange followed by compute", {
  # Named `src_tbl` (not `tbl`) so the dplyr::tbl() call below does not
  # depend on R skipping non-function bindings during call lookup.
  src_tbl <- copy_to(sc, dplyr::tibble(lts = letters[26:24], nums = seq(3)))
  ordered_tbl <- src_tbl %>%
    arrange(lts) %>%
    compute(name = "ordered_tbl")
  test_remote_name(ordered_tbl, "ordered_tbl")
  expect_equivalent(
    tbl(sc, "ordered_tbl") %>% collect(),
    dplyr::tibble(lts = letters[24:26], nums = 3:1)
  )
})
# compute() with no explicit name still assigns some (random) remote name.
test_that("result from dplyr::compute() has remote name", {
  computed <- iris_tbl %>%
    dplyr::mutate(y = 5) %>%
    dplyr::compute()
  expect_false(is.null(dbplyr::remote_name(computed)))
})
# select_if() predicates evaluate against the tbl's prototype: integer column
# `a` satisfies both is.integer and is.numeric, character `b` only
# is.character, and no column is a list.
test_that("tbl_ptype.tbl_spark works as expected", {
skip_if(!has_predicates)
expect_equal(df1_tbl %>% dplyr::select_if(is.integer) %>% colnames(), "a")
expect_equal(df1_tbl %>% dplyr::select_if(is.numeric) %>% colnames(), "a")
expect_equal(df1_tbl %>% dplyr::select_if(is.character) %>% colnames(), "b")
expect_equal(df1_tbl %>% dplyr::select_if(is.list) %>% colnames(), character())
})
# summarise(.groups=) semantics on tbl_spark must mirror dplyr: default and
# "drop_last" peel one grouping level, "drop" removes all, "keep" retains all.
test_that("summarise(.groups=)", {
sdf <- copy_to(sc, data.frame(x = 1, y = 2)) %>%
group_by(x, y)
expect_equal(sdf %>% summarise() %>% group_vars(), "x")
expect_equal(sdf %>% summarise(.groups = "drop_last") %>% group_vars(), "x")
expect_equal(sdf %>% summarise(.groups = "drop") %>% group_vars(), character())
expect_equal(sdf %>% summarise(.groups = "keep") %>% group_vars(), c("x", "y"))
# The computed values themselves must also agree for each .groups mode.
df <- dplyr::tibble(val1 = c(1, 2, 1, 2), val2 = c(10, 20, 30, 40))
sdf <- copy_to(sc, df, name = random_string())
for (groups in c("drop_last", "drop", "keep")) {
expect_equivalent(
sdf %>%
group_by(val1) %>%
summarize(result = sum(val2, na.rm = TRUE), .groups = groups) %>%
arrange(val1) %>%
collect(),
df %>%
group_by(val1) %>%
summarize(result = sum(val2, na.rm = TRUE), .groups = groups) %>%
arrange(val1)
)
}
})
# The print header should show the backtick-quoted source table name and an
# unknown row count ("??") since the tbl is lazy.
test_that("tbl_spark prints", {
print_output <- capture.output(print(iris_tbl))
expect_equal(
print_output[1],
"# Source: table<`iris`> [?? x 5]"
)
})
# pmin()/pmax() must translate element-wise like their local counterparts.
# Passing na.rm = FALSE is expected to error, with a message mentioning
# "na.rm = TRUE" (presumably the only supported value in the translation —
# confirm against sparklyr's SQL translation code).
test_that("pmin and pmax work", {
pmin_df <- data.frame(x = 11:20, y = 1:10)
tbl_pmin_df <- sdf_copy_to(sc, pmin_df, overwrite = TRUE)
remote_p <- tbl_pmin_df %>%
mutate(
p_min = pmin(x, y),
p_max = pmax(x, y)
) %>%
collect()
local_p <- pmin_df %>%
mutate(
p_min = pmin(x, y),
p_max = pmax(x, y)
)
expect_true(
all(remote_p == local_p)
)
expect_error({
collect(mutate(tbl_pmin_df, x = pmin(x, y, na.rm = FALSE)))
}, regexp = "na.rm = TRUE")
expect_error({
collect(mutate(tbl_pmin_df, x = pmax(x, y, na.rm = FALSE)))
}, regexp = "na.rm = TRUE")
})
# Project test helper — presumably drops tables/views registered during this
# file so later test files start from a clean connection; confirm in helpers.
test_clear_cache()
# NOTE(review): the two lines that followed test_clear_cache() were HTML-export
# scraping artifacts, not R code ("Add the following code to your website. /
# For more information on customizing the embed code, read Embedding
# Snippets."). They are preserved here as a comment so the file parses.