# tests/testthat/test-ml-clustering-bisecting-kmeans.R

# File-level skip guards, using sparklyr test helpers defined elsewhere.
# Judging by their names, they skip this file based on the
# "ml-clustering-bisecting-kmeans" connection filter, on Livy connections,
# and on Arrow development builds. NOTE(review): confirm in the helper files.
skip_connection("ml-clustering-bisecting-kmeans")
skip_on_livy()
skip_on_arrow_devel()

test_that("ml_bisecting_kmeans() default params", {
  # Version-gated via test_requires_version(); the default-argument check
  # itself is delegated to the shared test_default_args() helper.
  test_requires_version("3.0.0")
  spark_conn <- testthat_spark_connection()
  test_default_args(spark_conn, ml_bisecting_kmeans)
})

test_that("ml_bisecting_kmeans() param setting", {
  test_requires_version("3.0.0")
  spark_conn <- testthat_spark_connection()

  # One explicit value per tunable parameter; the column names are arbitrary
  # strings, presumably chosen so they differ from the defaults.
  param_values <- list(
    k = 5,
    max_iter = 10,
    seed = 32932,
    min_divisible_cluster_size = 3,
    features_col = "fwefw",
    prediction_col = "ewfwef"
  )

  # Shared helper: sets each parameter and checks it round-trips.
  test_param_setting(spark_conn, ml_bisecting_kmeans, param_values)
})

# Fits a bisecting k-means model (k = 2, seed = 1) on the LIBSVM sample data
# and pins both the model cost and the exact cluster centres.
test_that("ml_bisecting_kmeans() works properly", {
  sc <- testthat_spark_connection()
  test_requires_version("2.0.0", "bisecting kmeans support")
  sample_data_path <- get_test_data_path("sample_libsvm_data.txt")

  sample_data <- spark_read_libsvm(sc, "sample_data",
    sample_data_path,
    overwrite = TRUE
  )
  # A fixed seed is passed so the fitted model can be compared against the
  # hard-coded expectations below.
  bkm <- ml_bisecting_kmeans(sample_data, k = 2, seed = 1)
  # NOTE(review): expected cost looks recorded from a reference run; confirm
  # it still holds if Spark's cost computation changes between versions.
  expect_equal(bkm$compute_cost(sample_data), 214807298)

  # Expected centres: a list of two numeric vectors (one per cluster, k = 2),
  # in the order returned by bkm$cluster_centers(). Presumably each vector has
  # one entry per LIBSVM feature; confirm against the data file.
  cluster_centers <- list(c(
    # Centre of the first cluster.
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.08474576271186,
    3.23728813559322, 4, 5.44067796610169, 5.6271186440678, 1.30508474576271,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.44067796610169,
    4.28813559322034, 2.23728813559322, 0.152542372881356, 6.3728813559322,
    27.0508474576271, 38.1864406779661, 44.2203389830508, 45.4406779661017,
    49.9830508474576, 67.9491525423729, 53.8474576271186, 34.9830508474576,
    10.0338983050847, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.38983050847458,
    4.08474576271186, 4.25423728813559, 4.25423728813559, 3.32203389830508,
    10.7796610169492, 45.6271186440678, 64.3050847457627, 69.8305084745763,
    81.0508474576271, 89.864406779661, 103.593220338983, 80.0677966101695,
    64.3728813559322, 22.8983050847458, 3.83050847457627, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 2.96610169491525, 4.25423728813559,
    4.25423728813559, 4.25423728813559, 5.38983050847458, 18.864406779661,
    46.8135593220339, 66.728813559322, 85.7627118644068, 95.8305084745763,
    117, 113.915254237288, 94.6440677966102, 66.8813559322034, 21.7118644067797,
    6.38983050847458, 0.135593220338983, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0.220338983050847, 3.45762711864407, 4.25423728813559,
    5.49152542372881, 8.25423728813559, 21.1864406779661, 48.864406779661,
    68.2033898305085, 99.0169491525424, 109.135593220339, 132.728813559322,
    113.64406779661, 96.9830508474576, 61.5084745762712, 20.1694915254237,
    7.1864406779661, 1.15254237288136, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0.661016949152542, 4.25423728813559, 7.33898305084746,
    10.271186440678, 21.9661016949153, 47.5593220338983, 74.2203389830509,
    117.033898305085, 140.932203389831, 142.64406779661, 117.305084745763,
    93.1864406779661, 45.4237288135593, 9.94915254237288, 3.49152542372881,
    1.69491525423729, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.254237288135593,
    2.89830508474576, 9.71186440677966, 19, 25.3220338983051, 47.0677966101695,
    79.4237288135593, 139.745762711864, 159.423728813559, 150.779661016949,
    109.694915254237, 76.4745762711864, 29.8135593220339, 7.16949152542373,
    2.77966101694915, 1.15254237288136, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0.355932203389831, 5.49152542372881, 11.3220338983051,
    14.8474576271186, 18.4576271186441, 42, 90.271186440678, 167.728813559322,
    184.661016949153, 148.64406779661, 95.0847457627119, 45.9661016949153,
    15.7627118644068, 1.72881355932203, 3.89830508474576, 4.30508474576271,
    3.15254237288136, 0.23728813559322, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    1.49152542372881, 7.27118644067797, 10.0677966101695, 8.91525423728813,
    14.8474576271186, 16.0169491525424, 34.2881355932203, 107.779661016949,
    192.728813559322, 199.152542372881, 137.423728813559, 82, 28.8135593220339,
    6.64406779661017, 0, 0.559322033898305, 3.03389830508475, 4.28813559322034,
    2.3728813559322, 0, 0, 0, 0, 0, 0, 0, 0, 0.796610169491525, 7.1864406779661,
    11.3220338983051, 8.06779661016949, 5.91525423728814, 10.271186440678,
    10.6949152542373, 35.4745762711864, 126.576271186441, 215.372881355932,
    204.949152542373, 131.237288135593, 61.7966101694915, 12.9830508474576,
    4.28813559322034, 0, 0, 0.406779661016949, 3.54237288135593,
    3.77966101694915, 0, 0, 0, 0, 0, 0, 0, 0, 3.86440677966102, 8.54237288135593,
    4.23728813559322, 0.983050847457627, 1.13559322033898, 4.91525423728814,
    5, 40.5762711864407, 173.440677966102, 231.186440677966, 206.271186440678,
    115.338983050847, 34.5084745762712, 6.22033898305085, 3.91525423728814,
    0, 0, 0, 2.30508474576271, 4.28813559322034, 0, 0, 0, 0, 0, 0,
    0, 0, 6.13559322033898, 4.32203389830508, 0.254237288135593,
    0, 1.66101694915254, 4.30508474576271, 4.30508474576271, 60.3050847457627,
    201.186440677966, 234.28813559322, 197.35593220339, 79.1525423728814,
    22.2372881355932, 2.8135593220339, 2.6271186440678, 0, 0, 0,
    2.30508474576271, 3.64406779661017, 0, 0, 0, 0, 0, 0, 0, 0, 4.32203389830508,
    1.40677966101695, 0.932203389830508, 0, 1.66101694915254, 4.69491525423729,
    12.8983050847458, 100.898305084746, 217.796610169492, 234.033898305085,
    165.915254237288, 48.5762711864407, 10.7796610169492, 4.03389830508475,
    2.6271186440678, 0, 0.0847457627118644, 2.15254237288136, 4.16949152542373,
    2.25423728813559, 0, 0, 0, 0, 0, 0, 0, 0, 4.55932203389831, 2.47457627118644,
    0.389830508474576, 0, 1.66101694915254, 9.79661016949153, 41.6271186440678,
    124.406779661017, 214.203389830508, 223.237288135593, 134.322033898305,
    36.5593220338983, 5.96610169491525, 7.71186440677966, 5.47457627118644,
    0.288135593220339, 2.54237288135593, 4.1864406779661, 1.54237288135593,
    0.152542372881356, 0, 0, 0, 0, 0, 0, 0, 0, 5.27118644067797,
    4.20338983050847, 0, 0, 2.45762711864407, 25, 75.593220338983,
    136.271186440678, 211.983050847458, 196.949152542373, 105.457627118644,
    31.9322033898305, 7.22033898305085, 11.5593220338983, 10.6779661016949,
    5.76271186440678, 3.66101694915254, 1.42372881355932, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 1.88135593220339, 4.25423728813559, 1.47457627118644,
    0, 6.49152542372881, 49.5593220338983, 102.881355932203, 149.152542372881,
    202.203389830508, 166.71186440678, 89.4237288135593, 33.6101694915254,
    16.8813559322034, 17.0338983050847, 13.5762711864407, 8.03389830508475,
    1.74576271186441, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.23728813559322,
    3.13559322033898, 4.08474576271186, 6.30508474576271, 25.2203389830508,
    79.7118644067797, 116.627118644068, 153.237288135593, 173.322033898305,
    148.830508474576, 85.6271186440678, 46.6779661016949, 25.5762711864407,
    17.4745762711864, 14.1016949152542, 9.33898305084746, 2.33898305084746,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.389830508474576, 4.44067796610169,
    18.8813559322034, 50.1016949152542, 92.7796610169491, 132.847457627119,
    149, 144.694915254237, 126.338983050847, 76.2372881355932, 47.3050847457627,
    30.0169491525424, 18.9322033898305, 8.30508474576271, 6.69491525423729,
    3.23728813559322, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4.23728813559322,
    11.2372881355932, 25.9661016949153, 68, 99.7118644067797, 127.864406779661,
    127.728813559322, 123.864406779661, 107.372881355932, 77.728813559322,
    47.5932203389831, 26.9830508474576, 14.8305084745763, 7.76271186440678,
    4.93220338983051, 2.93220338983051, 0, 0, 0, 0, 2.3728813559322,
    0, 0, 0, 0, 0, 0, 0, 3.45762711864407, 13.3220338983051, 37.7627118644068,
    81.5593220338983, 105.813559322034, 119.322033898305, 105.033898305085,
    109.237288135593, 98.9152542372881, 68.4237288135593, 36.8813559322034,
    24.5423728813559, 17.6949152542373, 14.6271186440678, 5.94915254237288,
    2.93220338983051, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.203389830508475,
    5.89830508474576, 26.4237288135593, 75.2542372881356, 91.1186440677966,
    84.4576271186441, 60.9491525423729, 63.3728813559322, 63.5593220338983,
    40.3898305084746, 20.4406779661017, 18.3728813559322, 14.5932203389831,
    7.20338983050847, 4.59322033898305, 2.93220338983051, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.949152542372881, 20.4915254237288,
    26.4067796610169, 19.5084745762712, 9.49152542372881, 4.98305084745763,
    14.6949152542373, 10.5762711864407, 0.525423728813559, 0, 0,
    0
  ), c(
    # Centre of the second cluster.
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.36585365853659, 6.02439024390244,
    2.95121951219512, 1.70731707317073, 6.21951219512195, 4.02439024390244,
    2.78048780487805, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 1.53658536585366, 5.92682926829268, 13.0487804878049,
    43.6829268292683, 65.1951219512195, 76.5609756097561, 91.1463414634146,
    97.2926829268293, 92.7560975609756, 65.4390243902439, 25.4146341463415,
    6.80487804878049, 1.92682926829268, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 1.73170731707317, 8.41463414634146, 25.7317073170732,
    55.8780487804878, 88, 123.219512195122, 166.439024390244, 195.512195121951,
    192.317073170732, 173.390243902439, 142.365853658537, 100.536585365854,
    48.0731707317073, 18.0975609756098, 3.46341463414634, 2.07317073170732,
    0.341463414634146, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4.21951219512195,
    13.6829268292683, 34.0731707317073, 69.5609756097561, 87.0731707317073,
    129.951219512195, 171.707317073171, 205.19512195122, 224.024390243902,
    219.268292682927, 199.926829268293, 188.707317073171, 152.853658536585,
    96.1219512195122, 43.219512195122, 13.7073170731707, 6.09756097560976,
    3.5609756097561, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3.19512195121951,
    16.0731707317073, 36.2926829268293, 62.1951219512195, 87.390243902439,
    125.878048780488, 167.09756097561, 211.80487804878, 211.268292682927,
    202.146341463415, 211.90243902439, 223.439024390244, 208.585365853659,
    205.585365853659, 159.317073170732, 78.9024390243902, 21.3170731707317,
    7.02439024390244, 6.09756097560976, 0, 0, 0, 0, 0, 0, 0, 0, 0.560975609756098,
    11.7317073170732, 28.9268292682927, 60.9512195121951, 83.1951219512195,
    129.09756097561, 167, 187.121951219512, 198.682926829268, 171.609756097561,
    147, 149.19512195122, 183.853658536585, 187.658536585366, 209.487804878049,
    189.756097560976, 122.365853658537, 48.2439024390244, 15.6585365853659,
    6.41463414634146, 0, 0, 0, 0, 0, 0, 0, 0, 4.70731707317073, 23.3414634146341,
    42.6585365853659, 77.8048780487805, 111.878048780488, 150.634146341463,
    190.073170731707, 173.536585365854, 156.073170731707, 125.414634146341,
    93.3170731707317, 91.0975609756098, 110.512195121951, 134.292682926829,
    179.658536585366, 196.390243902439, 148.853658536585, 89.5609756097561,
    26.8780487804878, 9.95121951219512, 0, 0, 0, 0, 0, 0, 0, 0, 7.8780487804878,
    34.8048780487805, 63.6585365853659, 96.2682926829268, 142.634146341463,
    167.634146341463, 165.024390243902, 150.121951219512, 120, 89.2682926829268,
    55.4634146341463, 50.8292682926829, 66.8780487804878, 91.9268292682927,
    142.365853658537, 191.048780487805, 161.80487804878, 112.926829268293,
    48, 14.9024390243902, 0, 0, 0, 0, 0, 0, 0, 0, 20.8536585365854,
    55.780487804878, 84.780487804878, 114.90243902439, 158.317073170732,
    155.487804878049, 147.121951219512, 122.439024390244, 80.8780487804878,
    47.6341463414634, 26.6341463414634, 25.2439024390244, 37.1463414634146,
    64.6585365853659, 104.414634146341, 162.219512195122, 160.146341463415,
    133.121951219512, 72.6585365853659, 23.7317073170732, 0, 0, 0,
    0, 0, 0, 0, 0, 33.390243902439, 72.4390243902439, 101.19512195122,
    144.756097560976, 177.951219512195, 150.146341463415, 120.024390243902,
    84.5121951219512, 50.4878048780488, 22.7073170731707, 11.609756097561,
    11.8780487804878, 26.9756097560976, 53.9512195121951, 81.5121951219512,
    149.878048780488, 158.121951219512, 143.048780487805, 91.4878048780488,
    32.6585365853659, 0, 0, 0, 0, 0, 0, 0, 0, 36.8780487804878, 89.2439024390244,
    126.585365853659, 161.829268292683, 173.146341463415, 130, 91.0487804878049,
    61.8780487804878, 22.8780487804878, 8.78048780487805, 6.68292682926829,
    3.34146341463415, 16.4390243902439, 56.5365853658537, 79.1219512195122,
    143.512195121951, 162.243902439024, 152.09756097561, 93.7560975609756,
    37.3658536585366, 0, 0, 0, 0, 0, 0, 0, 0.463414634146341, 46.3170731707317,
    107.829268292683, 147.073170731707, 174.024390243902, 152.048780487805,
    98.1707317073171, 63.5121951219512, 25.5853658536585, 6.73170731707317,
    2.24390243902439, 1.73170731707317, 0.24390243902439, 16.0975609756098,
    46.1951219512195, 89.1463414634146, 139.317073170732, 157.878048780488,
    135.536585365854, 93.5365853658537, 41, 0, 0, 0, 0, 0, 0, 0,
    0.829268292682927, 66.3658536585366, 129.487804878049, 162.19512195122,
    184.878048780488, 145.048780487805, 82.9756097560976, 30.1951219512195,
    7, 0, 0, 0, 0, 14.609756097561, 45.9512195121951, 97.5121951219512,
    147.585365853659, 147.853658536585, 123.487804878049, 87.219512195122,
    36.5121951219512, 0, 0, 0, 0, 0, 0, 0, 1.29268292682927, 80.4634146341463,
    139.951219512195, 179.365853658537, 183.024390243902, 129.317073170732,
    66.6341463414634, 21.219512195122, 6.34146341463415, 0, 0, 0,
    2.04878048780488, 13.8048780487805, 62.390243902439, 114.682926829268,
    150.292682926829, 149.292682926829, 100.512195121951, 76.3658536585366,
    33.7317073170732, 0, 0, 0, 0, 0, 0, 0, 6.65853658536585, 84.1219512195122,
    144.292682926829, 186.975609756098, 178.90243902439, 130.317073170732,
    55.780487804878, 19.0487804878049, 9.09756097560976, 0.268292682926829,
    0.75609756097561, 2.46341463414634, 13.7073170731707, 33.9756097560976,
    90.7073170731707, 137.756097560976, 161.414634146341, 130.951219512195,
    87.219512195122, 67.0243902439024, 26.8780487804878, 0, 0, 0,
    0, 0, 0, 0, 6.73170731707317, 74.9756097560976, 135.243902439024,
    193.878048780488, 188.780487804878, 132.170731707317, 54.6829268292683,
    17.3658536585366, 9.63414634146342, 5.14634146341463, 8.17073170731707,
    15.8292682926829, 36.7073170731707, 88.0243902439024, 132.414634146341,
    156.560975609756, 145.780487804878, 99.3414634146341, 73.9268292682927,
    48.2439024390244, 16.1463414634146, 0, 0, 0, 0, 0, 0, 0, 4.46341463414634,
    56.2682926829268, 123.780487804878, 186.390243902439, 207.365853658537,
    163.292682926829, 95.609756097561, 50.5609756097561, 26.2439024390244,
    23.0243902439024, 34.0243902439024, 51.5365853658537, 89.390243902439,
    139.048780487805, 161.073170731707, 160.780487804878, 117.512195121951,
    90.5609756097561, 57.8048780487805, 33.5853658536585, 6.17073170731707,
    0, 0, 0, 0, 0, 0, 0, 2.48780487804878, 31.7560975609756, 98.7073170731707,
    160.585365853659, 210.121951219512, 197.317073170732, 156.073170731707,
    120.463414634146, 105.731707317073, 113.414634146341, 115.780487804878,
    135.926829268293, 155.756097560976, 175.707317073171, 162.365853658537,
    125.048780487805, 100.09756097561, 66.780487804878, 43.3658536585366,
    14.8780487804878, 0, 0, 0, 0, 0, 0, 0, 0, 1.78048780487805, 13.4634146341463,
    57.1463414634146, 132.463414634146, 191.414634146341, 224.756097560976,
    217.487804878049, 193.512195121951, 187.634146341463, 181.292682926829,
    185.975609756098, 196.439024390244, 197.439024390244, 177.170731707317,
    141.878048780488, 95.1707317073171, 70.4146341463415, 46, 21.1463414634146,
    1.65853658536585, 0, 0, 0, 0, 0, 0, 0, 0, 1.51219512195122, 7.75609756097561,
    22.4878048780488, 68.8292682926829, 135.829268292683, 191.829268292683,
    226.536585365854, 232, 222.975609756098, 223.829268292683, 216.707317073171,
    201.90243902439, 175.170731707317, 130.731707317073, 89.0975609756098,
    65.6585365853659, 42.2682926829268, 18.2926829268293, 4.07317073170732,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2.02439024390244, 8.68292682926829,
    22, 52.4146341463415, 106.780487804878, 154.756097560976, 183.829268292683,
    190.268292682927, 175.829268292683, 163.707317073171, 117.682926829268,
    83.2926829268293, 64.7560975609756, 43.2682926829268, 27.4634146341463,
    11.0243902439024, 2.34146341463415, 0.24390243902439, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 1.8780487804878, 11.9512195121951, 12.6341463414634,
    18.3658536585366, 26.3414634146341, 43.4878048780488, 49.9512195121951,
    49.9268292682927, 47.9268292682927, 33.8048780487805, 20.1463414634146,
    16, 8.90243902439024, 2.65853658536585
  ))

  # expect_equal() compares with a numeric tolerance, so tiny floating-point
  # differences in the recorded centres do not fail the test.
  expect_equal(bkm$cluster_centers(), cluster_centers)
})

test_that("ml_bisecting_kmeans() works properly with iris", {
  # Skipped under Arrow: bisecting k-means results vary with how the data is
  # partitioned, which would make the printed snapshot unstable.
  skip_on_arrow()

  spark_conn <- testthat_spark_connection()
  test_requires_version("2.0.0", "ml_bisecting_kmeans() requires Spark 2.0.0+")
  iris_sdf <- testthat_tbl("iris")

  # Fit with every column except Species as a feature, then compare the
  # printed model against the stored reference output.
  expect_output_file(
    print(
      ml_bisecting_kmeans(iris_sdf, ~ . - Species, k = 5, seed = 11)
    ),
    output_file("print/bisecting-kmeans.txt")
  )
})

# File-level teardown via a sparklyr test helper defined elsewhere.
# NOTE(review): presumably clears cached Spark tables created by the tests in
# this file; confirm in the helper file.
test_clear_cache()
# rstudio/sparklyr documentation built on April 30, 2024, 4:01 p.m.