skip_connection("serialization")
skip_on_livy()
sc <- testthat_spark_connection()
test_requires("nycflights13")
flights_small <- flights %>% dplyr::sample_n(10000)
flights_tbl <- testthat_tbl("flights_small")
logical_nas <- dplyr::tibble(bools = c(TRUE, NA, FALSE))
logical_nas_tbl <- testthat_tbl("logical_nas")
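# Helper: round-trip a local data frame through Spark via copy_to()/collect()
# and check that the result matches the original (modulo row names and name
# attributes, which are not preserved).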
ensure_round_trip <- function(sc, data) {
# round-trip data through Spark
copied <- copy_to(sc, data, overwrite = TRUE)
collected <- as.data.frame(collect(copied))
# compare without row.names (as we don't preserve those)
lhs <- data
row.names(lhs) <- NULL
rhs <- collected
row.names(rhs) <- NULL
expect_equal(unname(lhs), unname(rhs))
}
test_that("objects survive Spark roundtrips", {
datasets <- list(mtcars = mtcars)
for (dataset in datasets) {
ensure_round_trip(sc, dataset)
}
})
test_that("primitive values survive Spark roundtrips", {
n <- 10
df <- data.frame(
int = as.integer(1:n),
dbl = as.double(1:n),
lgl = rep_len(c(TRUE, FALSE), n),
chr = letters[1:n],
stringsAsFactors = FALSE
)
ensure_round_trip(sc, df)
})
test_that("NA values survive Spark roundtrips", {
n <- 10
df <- data.frame(
int = as.integer(1:n),
dbl = as.double(1:n),
# lgl = rep_len(c(TRUE, FALSE), n), # TODO
# chr = letters[1:10], # TODO
stringsAsFactors = FALSE
)
df[n / 2, ] <- NA
ensure_round_trip(sc, df)
})
test_that("data.frames with '|' can be copied", {
pipes <- data.frame(
x = c("|||", "|||", "|||"),
y = c(1, 2, 3),
stringsAsFactors = FALSE
)
ensure_round_trip(sc, pipes)
})
test_that("data.frames with many columns survive roundtrip", {
skip_slow("takes too long to measure coverage")
skip_databricks_connect()
n <- 1E3
data <- as.data.frame(replicate(n, 1L, simplify = FALSE))
names(data) <- paste0("X", seq_len(n))
ensure_round_trip(sc, data)
})
test_that("data.frames with many columns don't cause Java StackOverflows", {
skip_databricks_connect()
version <- numeric_version(Sys.getenv("SPARK_VERSION", unset = "2.2.0"))
n <- if (version >= "2.0.0") 500 else 5000
df <- matrix(0, ncol = n, nrow = 2) %>% as_tibble(.name_repair = "unique")
sdf <- copy_to(sc, df, overwrite = TRUE)
# older versions of sparklyr threw a Java StackOverflow during the copy above;
# verify the copy succeeded and all columns made it across
expect_equal(sdf_ncol(sdf), n)
})
test_that("'ml_predict()', 'predict()' return same results", {
skip_databricks_connect()
test_requires("dplyr")
model <- flights_tbl %>%
na.omit() %>%
ml_decision_tree(sched_dep_time ~ dep_time)
predictions <- ml_predict(model)
n1 <- spark_dataframe(predictions) %>% invoke("count")
n2 <- length(predict(model))
expect_equal(n1, n2)
lhs <- predictions %>%
sdf_read_column("prediction")
rhs <- predict(model)
expect_identical(lhs, rhs)
})
test_that("copy_to() succeeds when last column contains missing / empty values", {
df <- data.frame(
x = c(1, 2),
z = c(NA, ""),
stringsAsFactors = FALSE
)
df_tbl <- copy_to(sc, df, overwrite = TRUE)
expect_equal(sdf_nrow(df_tbl), 2)
expect_equal(sdf_ncol(df_tbl), 2)
})
arrow_compat <- using_arrow()
if (arrow_compat && packageVersion("arrow") > "0.17.1") {
# Arrow will pull int64 into R as integer if all values fit in 32-bit int
# (which they do in this test suite)
arrowbigint <- "integer"
} else {
arrowbigint <- "integer64"
}
test_that("collect() can retrieve all data types correctly", {
skip_databricks_connect()
# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types#LanguageManualTypes
library(dplyr)
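# Fixture naming convention: s* are Spark SQL expressions, r* are the values
# expected from a plain collect(), and a* are the values expected when
# collecting through arrow; one set at the Unix epoch, one at a 2010 time.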
epoch_utime <- 0
epoch_sdate <- "from_unixtime(unix_timestamp('01-01-1970', 'dd-MM-yyyy'))"
epoch_rdate <- as.Date("01-01-1970", "%d-%m-%Y") %>% as.character()
epoch_stime <- paste0("to_utc_timestamp(from_unixtime(", epoch_utime, "), 'UTC')")
epoch_rtime <- "1970-01-01"
epoch_atime <- as.character(as.POSIXct(epoch_utime, origin = "1970-01-01"))
utime <- as.numeric(as.POSIXct("2010-01-01 01:01:10", origin = "1970-01-01", tz = "UTC"))
sdate <- "from_unixtime(unix_timestamp('01-01-2010', 'dd-MM-yyyy'))"
rdate <- as.Date("01-01-2010", "%d-%m-%Y") %>% as.character()
stime <- paste0("to_utc_timestamp(from_unixtime(", utime, "), 'UTC')")
rtime <- "2010-01-01 01:01:10"
atime <- as.character(as.POSIXct(utime, origin = "1970-01-01"))
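# Column key: stype = Hive type, svalue = SQL literal/expression to cast,
# rtype/rvalue = expected R class/value from a plain collect(),
# atype/avalue = expected class/value when collecting through arrow.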
hive_type <- dplyr::tribble(
~stype, ~svalue, ~rtype, ~rvalue, ~atype, ~avalue,
"tinyint", "1", "integer", "1", "integer", "1",
"smallint", "1", "integer", "1", "integer", "1",
"integer", "1", "integer", "1", "integer", "1",
"bigint", "1", "numeric", "1", arrowbigint, "1",
"float", "1", "numeric", "1", "numeric", "1",
"double", "1", "numeric", "1", "numeric", "1",
"decimal", "1", "numeric", "1", "numeric", "1",
"timestamp", epoch_stime, "POSIXct", epoch_rtime, "POSIXct", epoch_atime,
"date", epoch_sdate, "Date", epoch_rdate, "Date", epoch_rdate,
"timestamp", stime, "POSIXct", rtime, "POSIXct", atime,
"date", sdate, "Date", rdate, "Date", rdate,
"string", "1", "character", "1", "character", "1",
"varchar(10)", "1", "character", "1", "character", "1",
"char(10)", "1", "character", "1", "character", "1",
"boolean", "true", "logical", "TRUE", "logical", "TRUE",
)
if (spark_version(sc) < "2.2.0") {
hive_type <- hive_type %>% filter(stype != "integer")
}
if (.Platform$OS.type == "windows") {
# Deserialization of Date type from Spark SQL has been problematic on Windows
# for some strange, platform-specific reasons.
hive_type <- hive_type %>% filter(stype != "date")
}
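# Build one SELECT that casts every literal to its Hive type, e.g.
# SELECT cast(1 as tinyint) as tinyint_col1, cast(1 as smallint) as smallint_col2, ...
# (parentheses are stripped from type names such as varchar(10))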
spark_query <- hive_type %>%
mutate(
query = paste0("cast(", svalue, " as ", stype, ") as ", gsub("\\(|\\)", "", stype), "_col", row_number())
) %>%
pull(query) %>%
paste(collapse = ", ") %>%
paste("SELECT", .)
spark_types <- DBI::dbGetQuery(sc, spark_query) %>%
lapply(function(e) class(e)[[1]]) %>%
as.character()
expect_equal(
spark_types,
hive_type %>% pull(!!if (arrow_compat) "atype" else "rtype")
)
spark_results <- DBI::dbGetQuery(sc, spark_query)
names(spark_results) <- NULL
spark_results <- sapply(spark_results, as.character)
skip_on_os("mac")
expect_equal(
spark_results,
hive_type %>% pull(!!if (arrow_compat) "avalue" else "rvalue")
)
})
test_that("collect() can retrieve NULL data types as NAs", {
library(dplyr)
hive_type <- dplyr::tribble(
~stype, ~rtype, ~atype,
"tinyint", "integer", "integer",
"smallint", "integer", "integer",
"integer", "integer", "integer",
"bigint", "numeric", arrowbigint,
"float", "numeric", "numeric",
"double", "numeric", "numeric",
"decimal", "numeric", "numeric",
"timestamp", "POSIXct", "POSIXct",
"date", "Date", "Date",
"string", "character", "character",
"varchar(10)", "character", "character",
"char(10)", "character", "character",
)
if (spark_version(sc) < "2.2.0") {
hive_type <- hive_type %>% filter(stype != "integer")
}
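# Cast NULL to every Hive type in a single SELECT, then check both the
# collected column classes and that every value comes back as NA.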
spark_query <- hive_type %>%
mutate(
query = paste0("cast(NULL as ", stype, ") as ", gsub("\\(|\\)", "", stype), "_col")
) %>%
pull(query) %>%
paste(collapse = ", ") %>%
paste("SELECT", .)
spark_types <- DBI::dbGetQuery(sc, spark_query) %>%
lapply(function(e) class(e)[[1]]) %>%
as.character()
expect_equal(
spark_types,
hive_type %>% pull(!!if (arrow_compat) "atype" else "rtype")
)
spark_results <- DBI::dbGetQuery(sc, spark_query)
lapply(names(spark_results), function(e) {
expect_true(is.na(spark_results[[e]]), paste(e, "expected to be NA"))
})
})
test_that("collect() can retrieve date types successfully", {
skip_on_windows()
df <- dplyr::tibble(
date = as.Date(
c(
"1000-01-01",
"1888-06-01",
"1969-12-31",
"1970-01-01",
"1970-01-02",
"1981-01-20",
"2001-01-20",
"3111-01-20"
)
)
)
expect_equivalent(
df,
df %>% sdf_copy_to(sc, ., overwrite = TRUE) %>% sdf_collect()
)
})
test_that("invoke() can roundtrip POSIXlt fields", {
invoke_static(
sc,
"sparklyr.Test",
"roundtrip",
list(
as.POSIXlt(Sys.time(), "GMT"),
as.POSIXlt(Sys.time(), "GMT")
)
)
succeed()
})
test_that("invoke() can roundtrip collect fields", {
invoke_static(
sc,
"sparklyr.Test",
"roundtrip",
list(
as.POSIXlt(Sys.time(), "GMT"),
as.POSIXlt(Sys.time(), "GMT")
)
)
succeed()
})
test_that("collect() can retrieve specific dates without timezones", {
data_tbl <- sdf_copy_to(
sc,
tibble(t = 1419126103)
)
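# 1419126103 seconds after the Unix epoch is 2014-12-21 01:41:43 UTC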
expect_equal(
as.double(
data_tbl %>%
mutate(date_alt = from_utc_timestamp(timestamp(t), "UTC")) %>%
pull(date_alt)
),
as.double(as.POSIXct("2014-12-21 01:41:43 UTC", tz = "UTC")),
tolerance = 0.1,
scale = 1
)
expect_equal(
as.double(
data_tbl %>%
mutate(date_alt = to_date(from_utc_timestamp(timestamp(t), "UTC"))) %>%
pull(date_alt)
),
as.double(
as.Date(
data_tbl %>%
mutate(date_alt = as.character(to_date(from_utc_timestamp(timestamp(t), "UTC")))) %>%
pull(date_alt)
)
),
tolerance = 0.1,
scale = 1
)
})
test_that("collect() can retrieve logical columns with NAs", {
expect_equal(
logical_nas,
logical_nas_tbl %>% dplyr::collect()
)
})
test_that("environments are sent to Scala Maps (#1058)", {
expect_identical(
invoke_static(sc, "sparklyr.Test", "readMap", as.environment(list(foo = 5))),
list(foo = 5)
)
expect_identical(
invoke_static(sc, "sparklyr.Test", "readMap", as.environment(list(foo = 2L))),
list(foo = 2L)
)
expect_identical(
invoke_static(sc, "sparklyr.Test", "readMap", as.environment(list(foo = "bar"))),
list(foo = "bar")
)
})
test_that("collect() can retrieve nested list efficiently", {
skip_databricks_connect()
skip_on_windows()
if (spark_version(sc) < "2.0.0") skip("performance improvement not available")
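# two records, each carrying a 1000-element nested array, written to JSON
# and read back without caching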
temp_json <- tempfile(fileext = ".json")
list(list(g = 1, d = 1:1000), list(g = 2, d = 1:1000)) %>%
jsonlite::write_json(temp_json, auto_unbox = TRUE)
nested <- spark_read_json(sc, temp_json, memory = FALSE)
expect_equal(nrow(collect(nested)), 2)
})
test_that("array of temporal values are preserved with Spark 3.0+", {
test_requires_version("3.0.0")
skip_on_arrow()
df <- dplyr::tibble(
char = c("one", "two"),
int = c(3L, 4L),
int_arr = list(seq(5), seq(10)),
date_arr = list(
as.Date(seq(5) * 700, origin = "1970-01-01"),
as.Date(seq(7) * 100, origin = "1970-01-01")
),
timestamp_arr = list(
as.POSIXct(seq(5) * 700000, origin = "1970-01-01"),
as.POSIXct(seq(7) * 1000000, origin = "1970-01-01")
),
char_arr = list(letters[1:13], letters[14:26])
)
sdf <- copy_to(sc, df, overwrite = TRUE)
expect_equivalent(sdf %>% collect(), df)
})
test_clear_cache()