# Nothing
# Tests for layer_text_vectorization(): adapting the layer, calling it on
# text input, and reading/writing its vocabulary.
context("layer_text_vectorization")
# Smoke test: a default text-vectorization layer is adapted on a one-column
# character matrix and then called on that same matrix; the result must be a
# tensor.
test_call_succeeds("layer_text_vectorization", {
  if (tensorflow::tf_version() < "2.1") {
    skip("TextVectorization requires TF version >= 2.1")
  }

  texts <- matrix(rep("hello world", 2), ncol = 1)
  vectorizer <- layer_text_vectorization()
  adapt(vectorizer, texts)

  result <- vectorizer(texts)
  expect_s3_class(result, "tensorflow.tensor")
})
# Checks that a layer created with `output_mode = "binary"` and
# `pad_to_max_tokens = FALSE` can be adapted on a one-column character matrix
# and called on it, producing a tensor.
#
# The description is deliberately different from the default-configuration
# test so that a failure report identifies which scenario broke.
test_call_succeeds("layer_text_vectorization with binary output_mode", {
  if (tensorflow::tf_version() < "2.1") {
    skip("TextVectorization requires TF version >= 2.1")
  }

  x <- matrix(c("hello world", "hello world"), ncol = 1)
  layer <- layer_text_vectorization(
    output_mode = "binary",
    pad_to_max_tokens = FALSE
  )
  layer %>% adapt(x)

  output <- layer(x)
  expect_s3_class(output, "tensorflow.tensor")
})
# End-to-end check: an adapted text-vectorization layer is applied to a
# string-typed input layer, wrapped in a functional model, and the model is
# run through predict() on raw strings.
test_call_succeeds("can use layer_text_vectorization in a functional model", {
  if (tensorflow::tf_version() < "2.1") {
    skip("TextVectorization requires TF version >= 2.1")
  }

  # Matrix form of the input, kept for reference:
  # matrix(c("hello world", "hello world"), ncol = 1)
  texts <- rep("hello world", 2)

  vectorizer <- layer_text_vectorization()
  adapt(vectorizer, texts)

  string_input <- layer_input(shape = 1, dtype = "string")
  vectorized <- vectorizer(string_input)
  model <- keras_model(string_input, vectorized)

  # No expectation is placed on the predicted values here.
  predictions <- predict(model, texts)
})
# Sets a two-token vocabulary on a fresh layer, calls the layer on text, and
# checks both the output type and the size of the vocabulary read back.
test_call_succeeds("can set and get the vocabulary of layer_text_vectorization", {
  if (tensorflow::tf_version() < "2.1") {
    skip("TextVectorization requires TF version >= 2.1")
  }

  # Matrix form of the input, kept for reference:
  # matrix(c("hello world", "hello world"), ncol = 1)
  texts <- rep("hello world", 2)
  vectorizer <- layer_text_vectorization()

  # Workaround for an upstream regression: getting an empty vocabulary throws
  # an exception in 2.5, so the call is only exercised on older versions.
  if (tf_version() < "2.5") {
    vectorizer$get_vocabulary()
  }

  set_vocabulary(vectorizer, vocabulary = c("hello", "world"))
  result <- vectorizer(texts)
  vocab <- get_vocabulary(vectorizer)

  expect_s3_class(result, "tensorflow.tensor")

  # From TF 2.3 on, the vocabulary has two extra entries:
  # 0 is used for padding and 1 for unknown.
  n_expected <- if (tensorflow::tf_version() < "2.3") 2 else 4
  expect_length(vocab, n_expected)
})
# Adapts the layer on a tfdatasets dataset built from a one-column character
# matrix and checks the size of the learned vocabulary.
test_call_succeeds("can use layer_text_vectorization", {
  if (tensorflow::tf_version() < "2.1") {
    skip("TextVectorization requires TF version >= 2.1")
  }

  texts <- matrix(rep("hello world", 2), ncol = 1)
  dataset <- tfdatasets::tensor_slices_dataset(texts)

  vectorizer <- layer_text_vectorization()
  adapt(vectorizer, dataset)

  # From TF 2.3 on, the vocabulary has two extra entries:
  # 0 is used for padding and 1 for unknown.
  n_expected <- if (tensorflow::tf_version() < "2.3") 2 else 4
  expect_length(get_vocabulary(vectorizer), n_expected)
})
# Builds a TF-IDF text-vectorization layer, adapts it on two short strings,
# and checks that calling it on a one-cell character matrix yields a tensor.
test_call_succeeds("can create a tf-idf layer", {
  if (tensorflow::tf_version() < "2.1") {
    skip("TextVectorization requires TF version >= 2.1")
  }

  num_words <- 10000

  # The mode string passed is "tf_idf" for TF >= 2.6 and "tf-idf" before that.
  text_vectorization <- layer_text_vectorization(
    max_tokens = num_words,
    output_mode = if (tf_version() >= "2.6") "tf_idf" else "tf-idf"
  )

  # adapt() is run inside an explicit CPU device scope.
  with(tf$device("/cpu:0"), {
    text_vectorization %>% adapt(c("hello world", "hello"))
  })

  x <- text_vectorization(matrix(c("hello"), ncol = 1))
  expect_s3_class(x, "tensorflow.tensor")
})
# Verifies that get_vocabulary() hands back a plain R character vector that
# includes the tokens seen during adapt().
test_call_succeeds("get_vocabulary() returns R character vector", {
  # Same version guard as the other tests in this file, so that an
  # unsupported TF installation is reported as a skip, not a failure.
  if (tensorflow::tf_version() < "2.1") {
    skip("TextVectorization requires TF version >= 2.1")
  }

  text_vectorization <- layer_text_vectorization()

  # adapt() is run inside an explicit CPU device scope.
  with(tf$device("/cpu:0"), {
    text_vectorization %>% adapt(c("hello world", "hello"))
  })

  vocab <- get_vocabulary(text_vectorization)
  expect_type(vocab, "character")
  expect_contains(vocab, c("hello", "world"))
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.