# File-level skips for the random forest tests. These helpers are defined
# elsewhere in the sparklyr test suite; judging by their names they skip this
# file on unsupported connection types / environments -- NOTE(review): confirm.
skip_connection("ml-supervised-random-forest")
skip_on_livy()
skip_on_arrow_devel()
skip_databricks_connect()
# Smoke test: a classification forest fitted with every tuning argument
# supplied explicitly must not raise an error (`regexp = NA` means "no error").
test_that("rf runs successfully when all args specified", {
  sc <- testthat_spark_connection()
  iris_tbl <- testthat_tbl("iris")

  expect_error(
    ml_random_forest(
      iris_tbl,
      Species ~ Sepal_Width + Sepal_Length + Petal_Width,
      type = "classification",
      feature_subset_strategy = "onethird",
      impurity = "entropy",
      max_bins = 16,
      max_depth = 3,
      min_info_gain = 1e-5,
      min_instances_per_node = 2L,
      num_trees = 25L,
      thresholds = c(1 / 2, 1 / 3, 1 / 4),
      seed = 42L
    ),
    regexp = NA
  )
})
# Each case zeroes exactly one class threshold and expects that class index
# to be the most frequently predicted label.
test_that("thresholds parameter behaves as expected", {
  skip_slow("takes too long to measure coverage")
  sc <- testthat_spark_connection()
  iris_tbl <- testthat_tbl("iris")

  # Label that occurs most often in the `prediction` column.
  top_prediction <- function(predictions) {
    counts <- predictions %>%
      count(prediction) %>%
      arrange(desc(n)) %>%
      collect()
    first(pull(counts, prediction))
  }

  # Fit a one-feature classifier with the given thresholds and score iris_tbl.
  predict_with_thresholds <- function(thresholds) {
    model <- ml_random_forest(
      iris_tbl,
      Species ~ Sepal_Width,
      type = "classification",
      thresholds = thresholds
    )
    ml_predict(model, iris_tbl)
  }

  cases <- list(
    list(thresholds = c(0, 1, 1), expected = 0),
    list(thresholds = c(1, 0, 1), expected = 1),
    list(thresholds = c(1, 1, 0), expected = 2)
  )
  for (case in cases) {
    expect_equal(
      top_prediction(predict_with_thresholds(case$thresholds)),
      case$expected
    )
  }
})
# From Spark 2.1.0 on (per the skip message below), a `thresholds` vector
# whose length does not match the number of classes must be rejected. The
# response has three classes elsewhere in this file, so two thresholds are
# too few.
test_that("error for thresholds with wrong length", {
  sc <- testthat_spark_connection()
  iris_tbl <- testthat_tbl("iris")
  if (spark_version(sc) < "2.1.0") {
    skip("threshold length checking implemented in 2.1.0")
  }

  too_few_thresholds <- c(0, 1)
  expect_error(
    ml_random_forest(
      iris_tbl,
      Species ~ Sepal_Width,
      type = "classification",
      thresholds = too_few_thresholds
    )
  )
})
# Impurity measures are task-specific; using one from the other task type
# must fail with an explanatory message.
test_that("error for bad impurity specification", {
  sc <- testthat_spark_connection()
  iris_tbl <- testthat_tbl("iris")

  # A regression impurity is rejected for classification ...
  expect_error(
    ml_random_forest(
      iris_tbl,
      Species ~ Sepal_Width,
      type = "classification",
      impurity = "variance"
    ),
    regexp = "`impurity` must be \"gini\" or \"entropy\" for classification\\."
  )

  # ... and a classification impurity is rejected for regression.
  expect_error(
    ml_random_forest(
      iris_tbl,
      Sepal_Length ~ Sepal_Width,
      type = "regression",
      impurity = "gini"
    ),
    regexp = "`impurity` must be \"variance\" for regression\\."
  )
})
# Two fits with the same seed must produce the same trees, as seen through
# the Spark model's debug string.
test_that("random seed setting works", {
  skip_slow("takes too long to measure coverage")
  sc <- testthat_spark_connection()
  iris_tbl <- testthat_tbl("iris")

  # Debug string of the fitted model split into lines, with the first line
  # dropped (presumably a header carrying the per-model uid -- confirm).
  tree_dump <- function(fit) {
    dump <- invoke(spark_jobj(fit$model), "toDebugString")
    tail(unlist(strsplit(dump, "\n")), -1)
  }

  # Fit the same classifier with a fixed seed.
  fit_seeded <- function() {
    ml_random_forest(
      iris_tbl,
      Species ~ Sepal_Width,
      type = "classification",
      seed = 42L
    )
  }

  first_fit <- fit_seeded()
  second_fit <- fit_seeded()
  expect_equal(tree_dump(first_fit), tree_dump(second_fit))
})
# A forest of a single tree that uses every row (subsampling_rate = 1) and
# every feature (feature_subset_strategy = "all") should closely match a plain
# decision tree fitted on the same data and formula.
test_that("one-tree forest agrees with ml_decision_tree()", {
  sc <- testthat_spark_connection()
  iris_tbl <- testthat_tbl("iris")
  # One shared formula so the two models cannot drift apart.
  petal_formula <- Petal_Length ~ Sepal_Width + Sepal_Length + Petal_Width

  rf <- ml_random_forest(
    iris_tbl,
    petal_formula,
    type = "regression",
    subsampling_rate = 1,
    feature_subset_strategy = "all",
    num_trees = 1L
  )
  dt <- ml_decision_tree(
    iris_tbl,
    petal_formula,
    type = "regression"
  )

  rf_predictions <- collect(ml_predict(rf, iris_tbl))
  dt_predictions <- collect(ml_predict(dt, iris_tbl))

  # Loose tolerance, made absolute via `scale = 1` (all.equal semantics): the
  # forest's single tree is not required to be identical to the decision tree.
  # NOTE(review): presumably due to per-tree bootstrap sampling; confirm
  # before tightening.
  expect_equal(rf_predictions, dt_predictions, tolerance = 0.5, scale = 1)
})
# Fitting with node-id caching and a checkpoint interval must succeed once a
# checkpoint directory has been configured on the connection.
# NOTE(review): the checkpoint directory set here is not restored afterwards.
test_that("checkpointing works for rf", {
  sc <- testthat_spark_connection()
  iris_tbl <- testthat_tbl("iris")
  checkpoint_dir <- tempdir()
  spark_set_checkpoint_dir(sc, checkpoint_dir)

  expect_error(
    ml_random_forest(
      iris_tbl,
      Petal_Length ~ Sepal_Width + Sepal_Length + Petal_Width,
      type = "regression",
      cache_node_ids = TRUE,
      checkpoint_interval = 5
    ),
    regexp = NA
  )
})
# The test table's columns use underscores (e.g. Sepal_Length), so dotted
# names do not exist. The error must name the missing response column.
test_that("ml_random_forest() provides informative error for bad response_col", {
  sc <- testthat_spark_connection()
  iris_tbl <- testthat_tbl("iris")
  expect_error(
    iris_tbl %>% ml_random_forest(Sepal.Length ~ Sepal.Width),
    regexp = "`Sepal.Length` is not a column in the input dataset\\."
  )
})
# A forest fitted on a numeric response is an
# ml_model_random_forest_regression, and residuals() must refuse it with an
# explicit message.
test_that("residuals() call on ml_model_random_forest_regression errors", {
  sc <- testthat_spark_connection()
  iris_tbl <- testthat_tbl("iris")
  expect_error(
    residuals(ml_random_forest(iris_tbl, Sepal_Length ~ Sepal_Width)),
    regexp = "'residuals\\(\\)' not supported for ml_model_random_forest_regression"
  )
})
# Instead of a formula, the model can be specified by naming the response
# column and a character vector of feature columns.
test_that("ml_random_forest() supports response-features syntax", {
  sc <- testthat_spark_connection()
  iris_tbl <- testthat_tbl("iris")
  predictors <- c("Sepal_Width", "Petal_Length")
  expect_error(
    iris_tbl %>%
      ml_random_forest(response = "Sepal_Length", features = predictors),
    regexp = NA
  )
})
# Test-suite helper defined elsewhere; presumably clears tables cached by
# testthat_tbl() during this file -- NOTE(review): confirm.
test_clear_cache()
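# NOTE(review): the two lines below are web-page boilerplate captured when this
# file was scraped; they are not part of the test suite.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.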