test_that("primary arguments", {
basic <- k_means(mode = "partition")
basic_stats <- translate_tidyclust(basic %>% set_engine("stats"))
expect_equal(
basic_stats$method$fit$args,
list(
x = rlang::expr(missing_arg()),
centers = rlang::expr(missing_arg())
)
)
k <- k_means(num_clusters = 15, mode = "partition")
k_stats <- translate_tidyclust(k %>% set_engine("stats"))
expect_equal(
k_stats$method$fit$args,
list(
x = rlang::expr(missing_arg()),
centers = rlang::expr(missing_arg()),
centers = new_empty_quosure(15)
)
)
})
test_that("engine arguments", {
stats_print <- k_means(mode = "partition")
expect_equal(
translate_tidyclust(
stats_print %>%
set_engine("stats", nstart = 1L)
)$method$fit$args,
list(
x = rlang::expr(missing_arg()),
centers = rlang::expr(missing_arg()),
nstart = new_empty_quosure(1L)
)
)
})
test_that("bad input", {
expect_snapshot(error = TRUE, k_means(mode = "bogus"))
expect_snapshot(error = TRUE, {
bt <- k_means(num_clusters = -1) %>% set_engine("stats")
fit(bt, mpg ~ ., mtcars)
})
expect_snapshot(error = TRUE, translate_tidyclust(k_means(), engine = NULL))
expect_snapshot(error = TRUE, translate_tidyclust(k_means(formula = ~x)))
})
test_that("predictions", {
set.seed(1234)
kmeans_fit <- k_means(num_clusters = 4) %>%
set_engine("stats") %>%
fit(~., mtcars)
set.seed(1234)
ref_res <- kmeans(mtcars, 4)
ref_predictions <- ref_res$centers %>%
flexclust::dist2(mtcars) %>%
apply(2, which.min) %>%
unname()
relevel_preds <- function(x) {
factor(unname(x), unique(unname(x))) %>% as.numeric()
}
expect_equal(
relevel_preds(ref_predictions),
predict(kmeans_fit, mtcars)$.pred_cluster %>% as.numeric()
)
expect_equal(
relevel_preds(ref_predictions),
extract_cluster_assignment(kmeans_fit)$.cluster %>% as.numeric()
)
})
test_that("extract_centroids work", {
set.seed(1234)
kmeans_fit <- k_means(num_clusters = 4) %>%
set_engine("stats") %>%
fit(~., mtcars)
set.seed(1234)
ref_res <- kmeans(mtcars, 4)
ref_centroids <- dplyr::bind_cols(
.cluster = c(1L, 4L, 2L, 3L),
ref_res$centers
) %>%
dplyr::arrange(.cluster) %>%
dplyr::select(-.cluster)
expect_identical(
extract_centroids(kmeans_fit) %>% dplyr::select(-.cluster),
ref_centroids
)
})
test_that("Right classes", {
expect_equal(
class(k_means()),
c("k_means", "cluster_spec", "unsupervised_spec")
)
})
test_that("printing", {
expect_snapshot(
k_means()
)
expect_snapshot(
k_means(num_clusters = 10)
)
})
test_that("updating", {
expect_snapshot(
k_means(num_clusters = 5) %>%
update(num_clusters = tune())
)
})
test_that("Engine-specific arguments are passed to ClusterR models", {
spec <- k_means(num_clusters = 2) %>%
set_engine("ClusterR", fuzzy = FALSE)
fit <- fit(spec, ~., data = mtcars)
expect_true(is.null(fit$fit$fuzzy_clusters))
spec <- k_means(num_clusters = 2) %>%
set_engine("ClusterR", fuzzy = TRUE)
fit <- fit(spec, ~., data = mtcars)
expect_false(is.null(fit$fit$fuzzy_clusters))
})
test_that("reordering is done correctly for stats k_means", {
set.seed(42)
kmeans_fit <- k_means(num_clusters = 6) %>%
set_engine("stats") %>%
fit(~., data = mtcars)
summ <- extract_fit_summary(kmeans_fit)
expect_identical(
summ$n_members,
unname(as.integer(table(summ$cluster_assignments)))
)
})
test_that("reordering is done correctly for ClusterR k_means", {
set.seed(42)
kmeans_fit <- k_means(num_clusters = 6) %>%
set_engine("ClusterR") %>%
fit(~., data = mtcars)
summ <- extract_fit_summary(kmeans_fit)
expect_identical(
summ$n_members,
unname(as.integer(table(summ$cluster_assignments)))
)
})
test_that("errors if `num_clust` isn't specified", {
expect_snapshot(
error = TRUE,
k_means() %>%
set_engine("stats") %>%
fit(~., data = mtcars)
)
expect_snapshot(
error = TRUE,
k_means() %>%
set_engine("ClusterR") %>%
fit(~., data = mtcars)
)
})
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.