tests/testthat/test-ml-evaluation-clustering.R

skip_connection("ml-evaluation-clustering")
skip_on_livy()
skip_on_arrow_devel()

skip_databricks_connect()
test_that("ml_clustering_evaluator() works", {
  test_requires_version("2.3.0", "ml clustering evaluator requires spark 2.3+")
  sc <- testthat_spark_connection()
  sample_data_path <- get_test_data_path("sample_kmeans_data.txt")
  sample_data <- spark_read_libsvm(sc, "sample_data",
    sample_data_path,
    overwrite = TRUE
  )
  kmeans <- ml_kmeans(sample_data, k = 2, seed = 1)
  predictions <- ml_transform(kmeans, sample_data)
  expect_equal(
    ml_clustering_evaluator(predictions),
    0.999753,
    tolerance = 1e-5,
    scale = 1
  )
})

test_that("ml_clustering_evalator() can be used in a CV pipeline", {
  test_requires_version("2.3.0", "ml clustering evaluator requires spark 2.3+")
  sc <- testthat_spark_connection()
  iris_tbl <- testthat_tbl("iris")
  pipeline_cv <- ml_pipeline(sc) %>%
    ft_vector_assembler(
      input_cols = c("Sepal_Width", "Sepal_Length", "Petal_Width", "Petal_Length"),
      output_col = "features"
    ) %>%
    ml_kmeans(uid = "kmeans")

  # Specify hyperparameter grid
  grid <- list(
    kmeans = list(
      k = c(2, 3, 4)
    )
  )

  # Create the cross validator object
  cv <- ml_cross_validator(
    sc,
    estimator = pipeline_cv,
    estimator_param_maps = grid,
    evaluator = ml_clustering_evaluator(sc),
    num_folds = 3,
    parallelism = 1
  )

  # Train the models
  expect_error(
    ml_fit(cv, iris_tbl),
    NA
  )
})

test_clear_cache()
rstudio/sparklyr documentation built on April 30, 2024, 4:01 p.m.