tests/testthat/test-tune_integration.R

# tests/testthat/test-tune_integration.R
library(testthat)
library(parsnip)
library(bnns)
library(dplyr)

test_that("bnns regression cross-validation works with tune_grid", {
  skip_if_not_installed("rsample")
  skip_if_not_installed("tune")
  skip_if_not_installed("workflows")
  skip_on_cran()
  
  library(rsample)
  library(tune)
  library(workflows)
  
  # Fast sampling parameters for testing. Tuning the hidden_units parameter.
  reg_spec <- mlp(mode = "regression", hidden_units = tune(), epochs = 20) %>%
    set_engine("bnns", warmup = 10, refresh = 0, chains = 1)
  
  reg_wf <- workflow() %>%
    add_model(reg_spec) %>%
    add_formula(mpg ~ hp + wt)
  
  set.seed(123)
  cv_folds <- vfold_cv(mtcars, v = 2)
  
  reg_grid <- data.frame(hidden_units = c(2, 3))
  
  tune_res <- tune_grid(
    reg_wf,
    resamples = cv_folds,
    grid = reg_grid,
    control = control_grid(save_pred = TRUE)
  )
  
  expect_s3_class(tune_res, "tune_results")
  
  metrics <- collect_metrics(tune_res)
  expect_true(nrow(metrics) > 0)
  
  preds <- collect_predictions(tune_res)
  expect_true(nrow(preds) > 0)
  expect_true(".pred" %in% names(preds))
})

test_that("bnns classification cross-validation works with fit_resamples", {
  skip_if_not_installed("rsample")
  skip_if_not_installed("tune")
  skip_if_not_installed("workflows")
  skip_on_cran()

  library(rsample)
  library(tune)
  library(workflows)

  # Create a binary target
  df_bin <- iris %>% 
    filter(Species != "virginica") %>%
    mutate(Species = droplevels(Species))
  
  class_spec <- mlp(mode = "classification", hidden_units = 2, epochs = 20) %>%
    set_engine("bnns", warmup = 10, refresh = 0, chains = 1)
  
  class_wf <- workflow() %>%
    add_model(class_spec) %>%
    add_formula(Species ~ Sepal.Length + Sepal.Width)
  
  set.seed(123)
  cv_folds <- vfold_cv(df_bin, v = 2)
  
  res <- fit_resamples(
    class_wf,
    resamples = cv_folds,
    control = control_resamples(save_pred = TRUE)
  )
  
  expect_s3_class(res, "tune_results")
  
  metrics <- collect_metrics(res)
  expect_true(nrow(metrics) > 0)
  
  preds <- collect_predictions(res)
  expect_true(nrow(preds) > 0)
  expect_true(all(c(".pred_class", "id", ".row") %in% names(preds)))
})

Try the bnns package in your browser

Any scripts or data that you put into this service are public.

bnns documentation built on June 8, 2026, 1:06 a.m.