# File-level skip guards. These helpers are defined elsewhere in the test
# suite; presumably each one skips this whole file on an unsupported backend
# or configuration (NOTE(review): confirm exact semantics).
skip_connection("ml-pipeline-utils")
skip_on_livy()
skip_on_arrow_devel()
skip_databricks_connect()
# Shared Spark connection used by every test_that() block below.
sc <- testthat_spark_connection()
test_that("ml_transform() fails on estimators", {
  iris_tbl <- testthat_tbl("iris")
  # An unfitted string indexer is an estimator, not a transformer, so
  # ml_transform() must reject it.
  indexer_estimator <- ft_string_indexer(sc, "Species", "species_idx")
  expect_error(
    ml_transform(indexer_estimator, iris_tbl),
    "Transformers must be 'ml_transformer' objects"
  )
})
test_that("ml_fit() and ml_fit_and_transform() fail on transformers", {
  iris_tbl <- testthat_tbl("iris")
  # A binarizer is a transformer; both fitting entry points must refuse it.
  binarizer_stage <- ft_binarizer(sc, "Petal_Width", "petal_width_binarized")
  fit_error <- "is only applicable to"
  expect_error(ml_fit(binarizer_stage, iris_tbl), fit_error)
  expect_error(ml_fit_and_transform(binarizer_stage, iris_tbl), fit_error)
})
test_that("ml_stage() and ml_stages() work properly", {
  # Map a list of stages to their uids. vapply() guarantees a character
  # vector, whereas sapply() may silently return a list or other shape.
  stage_uids <- function(stages) vapply(stages, ml_uid, character(1))

  pipeline <- ml_pipeline(sc) %>%
    ft_tokenizer("a", "b", uid = "tok1") %>%
    ft_tokenizer("c", "d", uid = "tok2") %>%
    ft_binarizer("e", "f", uid = "bin1")

  # ml_stage(): an identifier must match exactly one stage
  # (e.g. "bin" resolves to uid "bin1", "tok" is ambiguous).
  expect_error(
    ml_stage(pipeline, "blah"),
    "stage not found"
  )
  expect_error(
    ml_stage(pipeline, "tok"),
    "multiple stages found"
  )
  expect_equal(
    pipeline %>%
      ml_stage("bin") %>%
      ml_uid(),
    "bin1"
  )

  # ml_stages(): every identifier must match exactly one stage.
  expect_error(
    ml_stages(pipeline, c("blah")),
    "no stages found for identifier blah"
  )
  expect_error(
    ml_stages(pipeline, c("tok", "bin")),
    "multiple stages found for identifier tok"
  )
  expect_equal(
    ml_stages(pipeline, c("tok1", "bin")) %>%
      stage_uids(),
    c("tok1", "bin1")
  )

  # With no identifiers, all stages are returned in pipeline order.
  expect_equal(
    ml_stages(pipeline) %>%
      stage_uids(),
    c("tok1", "tok2", "bin1")
  )

  # Numeric identifiers select stages by position.
  expect_equal(
    ml_stage(pipeline, 1) %>%
      ml_uid(),
    "tok1"
  )
  expect_equal(
    ml_stages(pipeline, 1) %>%
      stage_uids(),
    "tok1"
  )
  expect_equal(
    ml_stages(pipeline, 1:2) %>%
      stage_uids(),
    c("tok1", "tok2")
  )
})
test_that("ml_is_set works", {
  lr <- ml_logistic_regression(sc, reg_param = 0L)
  lr_jobj <- spark_jobj(lr)
  # ml_is_set() must give the same answers for the R wrapper and the
  # underlying Java object: reg_param was set explicitly, thresholds was not.
  for (obj in list(lr, lr_jobj)) {
    expect_true(ml_is_set(obj, "reg_param"))
    expect_false(ml_is_set(obj, "thresholds"))
  }
})
test_that("ml_transform take list of transformers (#1444)", {
  test_requires_version("2.0.0")
  iris_tbl <- testthat_tbl("iris")
  string_indexer <- ft_string_indexer(sc, "Species", "label") %>%
    ml_fit(iris_tbl)
  pipeline <- ml_pipeline(string_indexer) %>%
    ft_vector_assembler(c("Petal_Width", "Petal_Length"), "features") %>%
    ml_logistic_regression() %>%
    ft_index_to_string(
      "prediction", "predicted_label",
      labels = ml_labels(string_indexer)
    )
  pipeline_model <- ml_fit(pipeline, iris_tbl)
  # Select only the assembler, logistic regression and index-to-string
  # stages from the fitted pipeline model.
  stages <- pipeline_model %>%
    ml_stages(c("vector_assembler", "logistic", "index_to_string"))

  # Transforming with the whole list of stages at once ...
  transformed1 <- ml_transform(stages, iris_tbl) %>%
    dplyr::pull(prediction)

  # ... must match applying each stage in turn. Reduce() calls its function
  # as f(accumulator, element), i.e. (current data, next stage).
  apply_stage <- function(data, transformer) ml_transform(transformer, data)
  transformed2 <- Reduce(apply_stage, stages, init = iris_tbl) %>%
    dplyr::pull(prediction)

  expect_equal(transformed1, transformed2)
})
# File-level teardown. test_clear_cache() is a test-suite helper defined
# elsewhere; presumably it clears cached state created by these tests
# (NOTE(review): confirm).
test_clear_cache()
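# NOTE: the two lines below are web-page residue from the source this file was
# copied from; they are commented out so the file parses as R.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.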