# File-level guards: these run once, before any of the tests below are
# registered.
# NOTE(review): all four helpers are defined outside this file. Judging by
# their names they skip this file for particular connection types or
# environments (Livy, Arrow devel, Databricks Connect) -- confirm against the
# test helper definitions.
skip_connection("ml-evaluation")
skip_on_livy()
skip_on_arrow_devel()
skip_databricks_connect()
# Fits a logistic regression on a four-row data set whose only feature equals
# the label, scores the same data, and checks that the binary classification
# evaluator returns 1 for it.
test_that("basic binary classification evaluation works", {
  sc <- testthat_spark_connection()

  # NOTE(review): copy_to() is called without `name`, so the remote table name
  # presumably derives from this variable's name -- keep it called `df`.
  df <- data.frame(label = c(1, 1, 0, 0), features1 = c(1, 1, 0, 0))
  sdf <- dplyr::copy_to(sc, df, overwrite = TRUE)

  # Assemble the single input column into a "features" vector column and fit.
  training_sdf <- ft_vector_assembler(sdf, "features1", "features")
  model <- ml_logistic_regression(training_sdf)

  # Score a freshly assembled copy of the same data.
  scoring_sdf <- ft_vector_assembler(sdf, "features1", "features")
  predictions <- ml_predict(model, scoring_sdf)

  auc <- ml_binary_classification_evaluator(
    predictions,
    label_col = "label",
    raw_prediction_col = "rawPrediction"
  )

  expect_equal(auc, 1)
})
# Checks that the Spark regression evaluator's "mse" metric agrees with the
# mean squared error computed locally in R from the same label/prediction
# pairs.
test_that("basic regression evaluation works", {
  sc <- testthat_spark_connection()
  df <- data.frame(
    label = c(1.2, 4.5, 6.7),
    prediction = c(3, 5, 7)
  )
  df_tbl <- dplyr::copy_to(sc, df, overwrite = TRUE)

  # Local reference value. mean() keeps the denominator tied to the actual
  # number of rows instead of a hard-coded 3.
  mse_r <- df %>%
    dplyr::summarize(mse = mean((label - prediction)^2)) %>%
    dplyr::pull(mse)

  # Value under test, computed by Spark.
  mse_s <- ml_regression_evaluator(
    df_tbl,
    label_col = "label", prediction_col = "prediction", metric_name = "mse"
  )

  # Actual value first, expected value second, so failure messages read
  # correctly.
  expect_equal(mse_s, mse_r)
})
# Prints each evaluator (print = TRUE) and compares the output against a
# reference file under print/.
# NOTE(review): uid = "foo" is presumably fixed so the printed output is
# stable across runs -- confirm.
test_that("ml evaluator print methods work", {
  sc <- testthat_spark_connection()

  expect_known_output(
    ml_binary_classification_evaluator(sc, uid = "foo"),
    output_file("print/binary-classification-evaluator.txt"),
    print = TRUE
  )

  # The multiclass evaluator has a separate reference file from Spark 3.0.0
  # onwards. The condition is a scalar, so a plain if/else is used rather than
  # the vectorised ifelse().
  multiclass_reference <- if (spark_version(sc) < "3.0.0") {
    "print/multiclass-classification-evaluator.txt"
  } else {
    "print/multiclass-classification-evaluator-spark-3.0.0.txt"
  }
  expect_known_output(
    ml_multiclass_classification_evaluator(sc, uid = "foo"),
    output_file(multiclass_reference),
    print = TRUE
  )

  expect_known_output(
    ml_regression_evaluator(sc, uid = "foo"),
    output_file("print/regression-evaluator.txt"),
    print = TRUE
  )
})
# Runs after all tests in this file have been registered.
# NOTE(review): helper defined outside this file; by its name it presumably
# clears cached test state -- confirm.
test_clear_cache()
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.