# Nothing
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# list_compute_functions() should hand back a character vector of function
# names, none of which start with "hash_".
test_that("list_compute_functions() works", {
  expect_type(list_compute_functions(), "character")

  is_hash_function <- grepl("^hash_", list_compute_functions())
  expect_false(any(is_hash_function))
})
# Checks how arrow_scalar_function() stores `in_type`/`out_type` when they are
# given as a schema, a bare data type, a field, or lists of data types, and
# that passing NULL as the function raises an error.
test_that("arrow_scalar_function() works", {
  # check in/out type as schema/data type
  fun <- arrow_scalar_function(
    function(context, x) x$cast(int64()),
    schema(x = int32()), int64()
  )
  expect_equal(fun$in_type[[1]], schema(x = int32()))
  expect_equal(fun$out_type[[1]](), int64())

  # check in/out type as data type/data type
  fun <- arrow_scalar_function(
    function(context, x) x$cast(int64()),
    int32(), int64()
  )
  # a bare data type is stored as an unnamed field
  expect_equal(fun$in_type[[1]][[1]], field("", int32()))
  expect_equal(fun$out_type[[1]](), int64())

  # check in/out type as field/data type
  fun <- arrow_scalar_function(
    # the body must use the declared argument name (`a_name`), not `x`
    function(context, a_name) a_name$cast(int64()),
    field("a_name", int32()),
    int64()
  )
  expect_equal(fun$in_type[[1]], schema(a_name = int32()))
  expect_equal(fun$out_type[[1]](), int64())

  # check in/out type as lists
  fun <- arrow_scalar_function(
    function(context, x) x,
    list(int32(), int64()),
    list(int64(), int32()),
    auto_convert = TRUE
  )
  expect_equal(fun$in_type[[1]][[1]], field("", int32()))
  expect_equal(fun$in_type[[2]][[1]], field("", int64()))
  expect_equal(fun$out_type[[1]](), int64())
  expect_equal(fun$out_type[[2]](), int32())

  # a NULL `fun` is rejected; the message is pinned by a snapshot
  expect_snapshot_error(arrow_scalar_function(NULL, int32(), int32()))
})
# With auto_convert = TRUE the generated wrapper is called directly with a
# list of Arrow inputs and is expected to return an Arrow Array.
test_that("arrow_scalar_function() works with auto_convert = TRUE", {
  scalar_fun <- arrow_scalar_function(
    function(context, x) x * 32,
    float64(),
    float64(),
    auto_convert = TRUE
  )

  # the wrapper only needs *some* context object for this check
  fake_context <- list()
  wrapper_args <- list(Scalar$create(2))

  expect_equal(
    scalar_fun$wrapper_fun(fake_context, wrapper_args),
    Array$create(2 * 32)
  )
})
# Registers a "times_32" user-defined function and verifies that it shows up
# in the package function cache and in list_compute_functions(), that it can
# be invoked through call_function(), and that it works inside dplyr::mutate().
test_that("register_scalar_function() adds a compute function to the registry", {
  skip_if_not(CanRunWithCapturedR())
  # TODO(ARROW-17178): User-defined function-friendly ExecPlan execution has
  # occasional valgrind errors
  skip_on_linux_devel()

  register_scalar_function(
    "times_32",
    function(context, x) x * 32.0,
    int32(), float64(),
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32", update_cache = TRUE))

  # visible both in the R-level cache and in the compute function listing
  cached_function_names <- names(asNamespace("arrow")$.cache$functions)
  expect_true("times_32" %in% cached_function_names)
  expect_true("times_32" %in% list_compute_functions())

  # Array input
  array_result <- call_function("times_32", Array$create(1L, int32()))
  expect_equal(array_result, Array$create(32L, float64()))

  # Scalar input
  scalar_result <- call_function("times_32", Scalar$create(1L, int32()))
  expect_equal(scalar_result, Scalar$create(32L, float64()))

  skip_if_not_available("acero")

  # usable from a dplyr pipeline
  collected <- record_batch(a = 1L) %>%
    dplyr::mutate(b = times_32(a)) %>%
    dplyr::collect()
  expect_identical(collected, tibble::tibble(a = 1L, b = 32.0))
})
# A user-defined function declared to return float64() but actually returning
# an int32 Array or Scalar must fail with an informative error when called.
test_that("arrow_scalar_function() with bad return type errors", {
  skip_if_not(CanRunWithCapturedR())

  register_scalar_function(
    "times_32_bad_return_type_array",
    function(context, x) Array$create(x, int32()),
    int32(),
    float64()
  )
  # add = TRUE so that the later on.exit() call does not replace this handler,
  # which would leave this binding registered after the test finishes
  on.exit(
    unregister_binding("times_32_bad_return_type_array", update_cache = TRUE),
    add = TRUE
  )

  expect_error(
    call_function("times_32_bad_return_type_array", Array$create(1L)),
    "Expected return Array or Scalar with type 'double'"
  )

  register_scalar_function(
    "times_32_bad_return_type_scalar",
    function(context, x) Scalar$create(x, int32()),
    int32(),
    float64()
  )
  on.exit(
    unregister_binding("times_32_bad_return_type_scalar", update_cache = TRUE),
    add = TRUE
  )

  expect_error(
    call_function("times_32_bad_return_type_scalar", Array$create(1L)),
    "Expected return Array or Scalar with type 'double'"
  )
})
# One registration with several input types: the output type is computed from
# the input types, so each input type should come back as the same type.
test_that("register_scalar_function() can register multiple kernels", {
  skip_if_not(CanRunWithCapturedR())

  register_scalar_function(
    "times_32",
    function(context, x) x * 32L,
    in_type = list(int32(), int64(), float64()),
    out_type = function(in_types) in_types[[1]],
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32", update_cache = TRUE))

  # same check for every registered kernel, in registration order
  for (kernel_type in list(int32(), int64(), float64())) {
    expect_equal(
      call_function("times_32", Scalar$create(1L, kernel_type)),
      Scalar$create(32L, kernel_type)
    )
  }
})
# Invalid registrations must be rejected with a specific message: no kernels
# at all, a function with the wrong number of arguments, and kernels whose
# argument counts differ.
test_that("register_scalar_function() errors for unsupported specifications", {
  # empty in_type/out_type lists -> zero kernels
  register_without_kernels <- function() {
    register_scalar_function(
      "no_kernels",
      function(...) NULL,
      list(),
      list()
    )
  }
  expect_error(
    register_without_kernels(),
    "Can't register user-defined scalar function with 0 kernels"
  )

  # one-argument function for a kernel that passes context plus one input
  register_with_wrong_arity <- function() {
    register_scalar_function(
      "wrong_n_args",
      function(x) NULL,
      int32(),
      int32()
    )
  }
  expect_error(
    register_with_wrong_arity(),
    "Expected `fun` to accept 2 argument\\(s\\)"
  )

  # one kernel takes a single input, the other takes two
  register_with_mixed_arity <- function() {
    register_scalar_function(
      "var_kernels",
      function(...) NULL,
      list(float64(), schema(x = float64(), y = float64())),
      float64()
    )
  }
  expect_error(
    register_with_mixed_arity(),
    "Kernels for user-defined function must accept the same number of arguments"
  )
})
# Writes a partitioned dataset, then runs a user-defined function over it both
# through collect() and through write_dataset(), checking the results against
# the same computation done in plain R.
test_that("user-defined functions work during multi-threaded execution", {
  skip_if_not(CanRunWithCapturedR())
  skip_if_not_available("dataset")
  # Skip on linux devel because:
  # TODO(ARROW-17283): Snappy has a UBSan issue that is fixed in the dev version
  # TODO(ARROW-17178): User-defined function-friendly ExecPlan execution has
  # occasional valgrind errors
  skip_on_linux_devel()

  n_rows <- 10000
  n_partitions <- 10
  example_df <- expand.grid(
    part = letters[seq_len(n_partitions)],
    value = seq_len(n_rows),
    stringsAsFactors = FALSE
  )

  # Make sure values are different for each partition, and keep a row number
  # so results can be sorted back into the original order before comparing.
  example_df$row_num <- seq_len(nrow(example_df))
  example_df$value <- example_df$value + match(example_df$part, letters)

  tf_dataset <- tempfile()
  tf_dest <- tempfile()
  # Both paths become directories once written to, so unlink() needs
  # recursive = TRUE; add = TRUE keeps this handler when the next on.exit()
  # is registered below.
  on.exit(unlink(c(tf_dataset, tf_dest), recursive = TRUE), add = TRUE)
  write_dataset(example_df, tf_dataset, partitioning = "part")

  register_scalar_function(
    "times_32",
    function(context, x) x * 32.0,
    int32(),
    float64(),
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32", update_cache = TRUE), add = TRUE)

  # check a regular collect()
  result <- open_dataset(tf_dataset) %>%
    dplyr::mutate(fun_result = times_32(value)) %>%
    dplyr::collect() %>%
    dplyr::arrange(row_num)

  expect_identical(result$fun_result, example_df$value * 32)

  # check a write_dataset()
  open_dataset(tf_dataset) %>%
    dplyr::mutate(fun_result = times_32(value)) %>%
    write_dataset(tf_dest)

  result2 <- open_dataset(tf_dest) %>%
    dplyr::collect() %>%
    dplyr::arrange(row_num)

  expect_identical(result2$fun_result, example_df$value * 32)
})
# Runs plans containing a user-defined function from inside helper functions:
# once streamed through a record batch reader into a Table, and once through
# head() followed by collect().
test_that("nested exec plans can contain user-defined functions", {
  skip_if_not_available("dataset")
  skip_if_not(CanRunWithCapturedR())

  register_scalar_function(
    "times_32",
    function(context, x) x * 32.0,
    int32(),
    float64(),
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32", update_cache = TRUE))

  # plan consumed as a stream and materialized as an Arrow Table
  run_streaming_plan <- function() {
    record_batch(a = 1:1000) %>%
      dplyr::mutate(b = times_32(a)) %>%
      as_record_batch_reader() %>%
      as_arrow_table()
  }

  # plan truncated with head() before collecting into R
  run_head_plan <- function() {
    record_batch(a = 1:1000) %>%
      dplyr::mutate(fun_result = times_32(a)) %>%
      head(11) %>%
      dplyr::collect()
  }

  streamed <- run_streaming_plan()
  collected_directly <- record_batch(a = 1:1000) %>%
    dplyr::mutate(b = times_32(a)) %>%
    dplyr::collect(as_data_frame = FALSE)
  expect_equal(streamed, collected_directly)

  head_result <- run_head_plan()
  expect_equal(nrow(head_result), 11)
})
# Currently skipped unconditionally (ARROW-18101): head() applied to a record
# batch reader built from a plan that contains a user-defined function.
test_that("head() on exec plan containing user-defined functions", {
  skip("ARROW-18101")
  skip_if_not_available("dataset")
  skip_if_not(CanRunWithCapturedR())

  register_scalar_function(
    "times_32",
    function(context, x) x * 32.0,
    int32(),
    float64(),
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32", update_cache = TRUE))

  reader <- record_batch(a = 1:1000) %>%
    dplyr::mutate(b = times_32(a)) %>%
    as_record_batch_reader()

  first_rows <- reader %>%
    head(11) %>%
    dplyr::collect()

  expect_equal(nrow(first_rows), 11)
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.