# Nothing
# tests/testthat/test-tune_integration.R
library(testthat)
library(parsnip)
library(bnns)
library(dplyr)
# Integration test: a regression `mlp()` spec using the "bnns" engine is tuned
# over `hidden_units` with `tune_grid()` on 2-fold cross-validation of `mtcars`,
# and both metrics and saved predictions can be collected from the result.
test_that("bnns regression cross-validation works with tune_grid", {
  skip_if_not_installed("rsample")
  skip_if_not_installed("tune")
  skip_if_not_installed("workflows")
  skip_on_cran()

  # Attached only after the skips above, so a missing optional package skips
  # the test instead of raising an error from library().
  library(rsample)
  library(tune)
  library(workflows)

  # Fast sampling parameters for testing. Tuning the hidden_units parameter.
  reg_spec <- mlp(mode = "regression", hidden_units = tune(), epochs = 20) %>%
    set_engine("bnns", warmup = 10, refresh = 0, chains = 1)

  reg_wf <- workflow() %>%
    add_model(reg_spec) %>%
    add_formula(mpg ~ hp + wt)

  # Seed is fixed immediately before the folds are created.
  set.seed(123)
  cv_folds <- vfold_cv(mtcars, v = 2)

  # Two candidate values for the tuned parameter.
  reg_grid <- data.frame(hidden_units = c(2, 3))

  # save_pred = TRUE keeps the out-of-sample predictions that are read back
  # with collect_predictions() below.
  tune_res <- tune_grid(
    reg_wf,
    resamples = cv_folds,
    grid = reg_grid,
    control = control_grid(save_pred = TRUE)
  )

  expect_s3_class(tune_res, "tune_results")

  # expect_gt() reports the actual row count on failure, unlike
  # expect_true(nrow(x) > 0).
  metrics <- collect_metrics(tune_res)
  expect_gt(nrow(metrics), 0)

  preds <- collect_predictions(tune_res)
  expect_gt(nrow(preds), 0)
  expect_true(".pred" %in% names(preds))
})
# Integration test: a classification `mlp()` spec using the "bnns" engine is
# fitted with `fit_resamples()` on 2-fold cross-validation of a two-class
# subset of `iris`, and both metrics and saved predictions can be collected.
test_that("bnns classification cross-validation works with fit_resamples", {
  skip_if_not_installed("rsample")
  skip_if_not_installed("tune")
  skip_if_not_installed("workflows")
  skip_on_cran()

  # Attached only after the skips above, so a missing optional package skips
  # the test instead of raising an error from library().
  library(rsample)
  library(tune)
  library(workflows)

  # Create a binary target: drop one species, then drop its unused factor
  # level so only two levels remain.
  df_bin <- iris %>%
    filter(Species != "virginica") %>%
    mutate(Species = droplevels(Species))

  class_spec <- mlp(mode = "classification", hidden_units = 2, epochs = 20) %>%
    set_engine("bnns", warmup = 10, refresh = 0, chains = 1)

  class_wf <- workflow() %>%
    add_model(class_spec) %>%
    add_formula(Species ~ Sepal.Length + Sepal.Width)

  # Seed is fixed immediately before the folds are created.
  set.seed(123)
  cv_folds <- vfold_cv(df_bin, v = 2)

  # save_pred = TRUE keeps the out-of-sample predictions that are read back
  # with collect_predictions() below.
  res <- fit_resamples(
    class_wf,
    resamples = cv_folds,
    control = control_resamples(save_pred = TRUE)
  )

  expect_s3_class(res, "tune_results")

  # expect_gt() reports the actual row count on failure, unlike
  # expect_true(nrow(x) > 0).
  metrics <- collect_metrics(res)
  expect_gt(nrow(metrics), 0)

  preds <- collect_predictions(res)
  expect_gt(nrow(preds), 0)
  expect_true(all(c(".pred_class", "id", ".row") %in% names(preds)))
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.