Nothing
test_that("primary arguments", {
basic <- mean_shift(mode = "partition")
basic_LPCM <- translate_tidyclust(basic |> set_engine("LPCM"))
expect_equal(
basic_LPCM$method$fit$args,
list(
x = rlang::expr(missing_arg()),
bandwidth = rlang::expr(missing_arg())
)
)
ms <- mean_shift(bandwidth = 0.5, mode = "partition")
ms_LPCM <- translate_tidyclust(ms |> set_engine("LPCM"))
expect_equal(
ms_LPCM$method$fit$args,
list(
x = rlang::expr(missing_arg()),
bandwidth = rlang::expr(missing_arg()),
bandwidth = new_empty_quosure(0.5)
)
)
basic_meanShiftR <- translate_tidyclust(basic |> set_engine("meanShiftR"))
expect_equal(
basic_meanShiftR$method$fit$args,
list(
x = rlang::expr(missing_arg()),
bandwidth = rlang::expr(missing_arg())
)
)
ms_meanShiftR <- translate_tidyclust(ms |> set_engine("meanShiftR"))
expect_equal(
ms_meanShiftR$method$fit$args,
list(
x = rlang::expr(missing_arg()),
bandwidth = rlang::expr(missing_arg()),
bandwidth = new_empty_quosure(0.5)
)
)
})
test_that("bad input", {
expect_snapshot(error = TRUE, mean_shift(mode = "bogus"))
skip_if_not_installed("LPCM")
expect_snapshot(error = TRUE, {
bt <- mean_shift(bandwidth = -1) |> set_engine("LPCM")
fit(bt, mpg ~ ., mtcars)
})
expect_snapshot(
error = TRUE,
translate_tidyclust(mean_shift(), engine = NULL)
)
expect_snapshot(error = TRUE, translate_tidyclust(mean_shift(formula = ~x)))
})
test_that("predictions", {
skip_if_not_installed("LPCM")
set.seed(1234)
ms_fit <- mean_shift(bandwidth = 0.5) |>
set_engine("LPCM") |>
fit(~., mtcars)
preds <- predict(ms_fit, mtcars)
expect_s3_class(preds$.pred_cluster, "factor")
expect_identical(nrow(preds), nrow(mtcars))
assignments <- extract_cluster_assignment(ms_fit)
expect_identical(
levels(preds$.pred_cluster),
levels(assignments$.cluster)
)
})
test_that("extract_centroids work", {
skip_if_not_installed("LPCM")
set.seed(1234)
ms_fit <- mean_shift(bandwidth = 0.5) |>
set_engine("LPCM") |>
fit(~., mtcars)
set.seed(1234)
ref_res <- LPCM::ms(as.matrix(mtcars), h = 0.5, plot = FALSE)
ref_centroids <- sweep(ref_res$cluster.center, 2, ref_res$scaled.by, "*")
ref_centroids <- tibble::as_tibble(ref_centroids, .name_repair = "minimal")
colnames(ref_centroids) <- colnames(mtcars)
expect_equal(
extract_centroids(ms_fit) |> dplyr::select(-.cluster),
ref_centroids
)
})
test_that("extract_cluster_assignment works", {
skip_if_not_installed("LPCM")
set.seed(1234)
ms_fit <- mean_shift(bandwidth = 0.5) |>
set_engine("LPCM") |>
fit(~., mtcars)
set.seed(1234)
ref_res <- LPCM::ms(as.matrix(mtcars), h = 0.5, plot = FALSE)
expect_equal(
extract_cluster_assignment(ms_fit)$.cluster |> as.numeric(),
ref_res$cluster.label
)
})
test_that("Right classes", {
expect_equal(
class(mean_shift()),
c("mean_shift", "cluster_spec", "unsupervised_spec")
)
})
test_that("printing", {
expect_snapshot(
mean_shift()
)
expect_snapshot(
mean_shift(bandwidth = 0.5)
)
})
test_that("updating", {
expect_snapshot(
mean_shift(bandwidth = 0.5) |>
update(bandwidth = tune())
)
})
test_that("errors if `bandwidth` isn't specified", {
skip_if_not_installed("LPCM")
expect_snapshot(
error = TRUE,
mean_shift() |>
set_engine("LPCM") |>
fit(~., data = mtcars)
)
})
# ------------------------------------------------------------------------------
# meanShiftR engine
test_that("predictions (meanShiftR)", {
skip_if_not_installed("meanShiftR")
set.seed(1234)
x_scaled <- as.data.frame(scale(mtcars))
ms_fit <- mean_shift(bandwidth = 3) |>
set_engine("meanShiftR") |>
fit(~., x_scaled)
preds <- predict(ms_fit, x_scaled)
expect_s3_class(preds$.pred_cluster, "factor")
expect_identical(nrow(preds), nrow(mtcars))
assignments <- extract_cluster_assignment(ms_fit)
expect_identical(
levels(preds$.pred_cluster),
levels(assignments$.cluster)
)
expect_identical(preds$.pred_cluster, assignments$.cluster)
})
test_that("predictions on new data (meanShiftR)", {
skip_if_not_installed("meanShiftR")
set.seed(1234)
x_scaled <- as.data.frame(scale(mtcars))
train <- x_scaled[1:24, ]
new <- x_scaled[25:32, ]
ms_fit <- mean_shift(bandwidth = 3) |>
set_engine("meanShiftR") |>
fit(~., train)
preds <- predict(ms_fit, new)
expect_s3_class(preds$.pred_cluster, "factor")
expect_identical(nrow(preds), nrow(new))
expect_identical(
levels(preds$.pred_cluster),
levels(extract_cluster_assignment(ms_fit)$.cluster)
)
})
test_that("extract_centroids works (meanShiftR)", {
skip_if_not_installed("meanShiftR")
set.seed(1234)
x_scaled <- as.data.frame(scale(mtcars))
ms_fit <- mean_shift(bandwidth = 3) |>
set_engine("meanShiftR") |>
fit(~., x_scaled)
centroids <- extract_centroids(ms_fit)
expect_named(
centroids,
c(".cluster", colnames(mtcars))
)
expect_identical(
nrow(centroids),
length(unique(ms_fit$fit$assignment))
)
# Centroid values must align with cluster IDs: row i is the converged
# mode of cluster i, regardless of the order assignments first appear in.
fit <- ms_fit$fit
ids <- sort(unique(fit$assignment))
expected <- fit$value[match(ids, fit$assignment), , drop = FALSE]
expect_equal(
unname(as.matrix(centroids[, -1])),
unname(expected)
)
})
test_that("extract_cluster_assignment works (meanShiftR)", {
skip_if_not_installed("meanShiftR")
set.seed(1234)
x_scaled <- as.data.frame(scale(mtcars))
ms_fit <- mean_shift(bandwidth = 3) |>
set_engine("meanShiftR") |>
fit(~., x_scaled)
expect_equal(
extract_cluster_assignment(ms_fit)$.cluster |> as.numeric(),
ms_fit$fit$assignment
)
})
test_that("scalar bandwidth is recycled (meanShiftR)", {
skip_if_not_installed("meanShiftR")
set.seed(1234)
x <- as.data.frame(scale(mtcars))
ms_fit <- mean_shift(bandwidth = 3) |>
set_engine("meanShiftR") |>
fit(~., x)
expect_identical(ms_fit$fit$bandwidth, rep(3, ncol(mtcars)))
})
test_that("errors if `bandwidth` isn't specified (meanShiftR)", {
skip_if_not_installed("meanShiftR")
expect_snapshot(
error = TRUE,
mean_shift() |>
set_engine("meanShiftR") |>
fit(~., data = mtcars)
)
})
test_that("errors on bandwidth length mismatch (meanShiftR)", {
skip_if_not_installed("meanShiftR")
expect_snapshot(
error = TRUE,
.mean_shift_fit_meanShiftR(as.matrix(mtcars), bandwidth = c(1, 2))
)
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.