skip_connection("ml-classification-one-vs-rest")
skip_on_livy()
skip_on_arrow_devel()
skip_databricks_connect()
test_that("ml_one_vs_rest() default params", {
test_requires_version("3.0.0")
sc <- testthat_spark_connection()
test_default_args(sc, ml_one_vs_rest)
})
test_that("ml_one_vs_rest() param setting", {
test_requires_version("3.0.0")
sc <- testthat_spark_connection()
test_args <- list(
features_col = "wefaef",
label_col = "weijfw",
prediction_col = "weifjwifj"
)
test_param_setting(sc, ml_one_vs_rest, test_args)
})
test_that("ml_one_vs_rest with two classes agrees with logistic regression", {
sc <- testthat_spark_connection()
iris_tbl2 <- testthat_tbl("iris") %>%
mutate(is_versicolor = ifelse(
Species == "versicolor", "versicolor", "other"
)) %>%
select(-Species)
lr_model <- ml_logistic_regression(iris_tbl2, formula = is_versicolor ~ .)
lr <- ml_logistic_regression(sc)
ovr_model <- ml_one_vs_rest(iris_tbl2, is_versicolor ~ ., classifier = lr)
expect_equal(
ml_predict(ovr_model, iris_tbl2) %>% pull(predicted_label),
ml_predict(lr_model, iris_tbl2) %>% pull(predicted_label)
)
})
test_that("ml_one_vs_rest fits the right number of estimators", {
sc <- testthat_spark_connection()
iris_tbl <- testthat_tbl("iris")
lr <- ml_logistic_regression(sc)
ovr_model <- ml_one_vs_rest(iris_tbl, Species ~ ., classifier = lr)
expect_equal(length(ovr_model$model$models), 3)
})
test_that("ml_one_vs_rest() errors when not given classifier", {
sc <- testthat_spark_connection()
expect_error(
ml_one_vs_rest(sc, classifier = ml_random_forest_regressor(sc)),
"`classifier` must be an `ml_classifier`\\."
)
})
test_clear_cache()
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.