test_that("fitting", {
skip_if_not_installed("klaR")
skip_if_not_installed("modeldata")
data("ames", package = "modeldata")
ames_cat <- dplyr::select(ames, dplyr::where(is.factor))
set.seed(1234)
spec <- k_means(num_clusters = 3) %>%
set_engine("klaR")
expect_no_error(
res <- fit(spec, ~., ames_cat)
)
expect_no_error(
res <- fit_xy(spec, ames_cat)
)
expect_identical(res$fit$weighted, FALSE)
spec <- k_means(num_clusters = 3) %>%
set_engine("klaR", weighted = TRUE)
res <- fit(spec, ~., ames_cat)
expect_identical(res$fit$weighted, TRUE)
})
test_that("predicting", {
skip_if_not_installed("klaR")
skip_if_not_installed("modeldata")
data("ames", package = "modeldata")
ames_cat <- dplyr::select(ames, dplyr::where(is.factor))
set.seed(1234)
spec <- k_means(num_clusters = 3) %>%
set_engine("klaR")
res <- fit(spec, ~., ames_cat)
preds <- predict(res, ames_cat[c(1:5), ])
expect_identical(
preds,
tibble::tibble(
.pred_cluster = factor(
paste0("Cluster_", c(1, 1, 1, 1, 2)),
paste0("Cluster_", 1:3)
)
)
)
})
test_that("all levels are preserved with 1 row predictions", {
skip_if_not_installed("klaR")
skip_if_not_installed("modeldata")
data("ames", package = "modeldata")
ames_cat <- dplyr::select(ames, dplyr::where(is.factor))
set.seed(1234)
spec <- k_means(num_clusters = 3) %>%
set_engine("klaR")
res <- fit(spec, ~., ames_cat)
preds <- predict(res, ames_cat[1, ])
expect_identical(
levels(preds$.pred_cluster),
paste0("Cluster_", 1:3)
)
})
test_that("predicting ties argument works", {
skip_if_not_installed("klaR")
dat <- data.frame(
x = c("A", "A", "B", "B", "C"),
y = c("A", "A", "B", "B", "C")
)
set.seed(1234)
spec <- k_means(num_clusters = 2) %>%
set_engine("klaR")
res <- fit(spec, ~., dat)
expect_identical(
predict(res, data.frame(x = "C", y = "C"), ties = "first"),
tibble::tibble(.pred_cluster = factor("Cluster_1", paste0("Cluster_", 1:2)))
)
expect_identical(
predict(res, data.frame(x = "C", y = "C"), ties = "last"),
tibble::tibble(.pred_cluster = factor("Cluster_2", paste0("Cluster_", 1:2)))
)
})
test_that("extract_centroids() works", {
skip_if_not_installed("klaR")
skip_if_not_installed("modeldata")
data("ames", package = "modeldata")
ames_cat <- dplyr::select(ames, dplyr::where(is.factor))
set.seed(1234)
spec <- k_means(num_clusters = 3) %>%
set_engine("klaR")
res <- fit(spec, ~., ames_cat)
centroids <- extract_centroids(res)
expected <- vctrs::vec_cbind(
tibble::tibble(.cluster = factor(paste0("Cluster_", 1:3))),
tibble::as_tibble(res$fit$modes)
)
expect_identical(
centroids,
expected
)
})
test_that("extract_cluster_assignment() works", {
skip_if_not_installed("klaR")
skip_if_not_installed("modeldata")
data("ames", package = "modeldata")
ames_cat <- dplyr::select(ames, dplyr::where(is.factor))
set.seed(1234)
spec <- k_means(num_clusters = 3) %>%
set_engine("klaR")
res <- fit(spec, ~., ames_cat)
clusters <- extract_cluster_assignment(res)
exp_cluster <- res$fit$cluster
exp_cluster <- order(unique(exp_cluster))[exp_cluster]
expected <- vctrs::vec_cbind(
tibble::tibble(.cluster = factor(paste0("Cluster_", exp_cluster)))
)
expect_identical(
clusters,
expected
)
})
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.