tests/testthat/test-recipe_integration.R

# tests/testthat/test-recipe_integration.R
library(testthat)
library(parsnip)
library(bnns)
library(dplyr)

# Integration check: a regression workflow whose recipe normalizes the numeric
# predictors is fitted with the "bnns" engine, and predictions for five rows
# must come back as a one-column tibble of numeric values.
test_that("bnns regression works with recipes and step_normalize", {
  skip_if_not_installed("recipes")
  skip_if_not_installed("workflows")
  skip_on_cran()

  library(recipes)
  library(workflows)

  # Model specification: a small MLP routed to the bnns engine. The engine
  # arguments are small values, presumably to keep the fit quick.
  bnns_spec <- set_engine(
    mlp(mode = "regression", hidden_units = 2, epochs = 20),
    "bnns",
    warmup = 10,
    refresh = 0,
    chains = 1
  )

  # Preprocessing: normalize every numeric predictor of mpg.
  base_recipe <- recipe(mpg ~ hp + wt + disp, data = mtcars)
  norm_recipe <- step_normalize(base_recipe, all_numeric_predictors())

  # Assemble the workflow step by step, then fit it on the full data set.
  reg_wf <- workflow()
  reg_wf <- add_recipe(reg_wf, norm_recipe)
  reg_wf <- add_model(reg_wf, bnns_spec)
  fitted_wf <- fit(reg_wf, data = mtcars)

  # The object stored by the engine must carry the "bnns" class.
  expect_s3_class(extract_fit_engine(fitted_wf), "bnns")

  # Predict for the first five rows and check the shape of the result.
  first_rows <- head(mtcars, 5)
  pred_tbl <- predict(fitted_wf, new_data = first_rows)

  expect_s3_class(pred_tbl, "tbl_df")
  expect_named(pred_tbl, ".pred")
  expect_equal(nrow(pred_tbl), 5)
  expect_true(is.numeric(pred_tbl$.pred))
})

# Integration check: a two-class classification workflow with a normalizing
# recipe is fitted with the "bnns" engine, and probability predictions must
# expose one column per outcome level.
test_that("bnns classification works with recipes and step_normalize", {
  skip_if_not_installed("recipes")
  skip_if_not_installed("workflows")
  skip_on_cran()

  library(recipes)
  library(workflows)

  # Reduce iris to two species and drop the now-unused factor level.
  two_class <- filter(iris, Species != "virginica")
  two_class <- mutate(two_class, Species = droplevels(Species))

  # Model specification for the bnns engine.
  bnns_spec <- set_engine(
    mlp(mode = "classification", hidden_units = 2, epochs = 20),
    "bnns",
    warmup = 10,
    refresh = 0,
    chains = 1
  )

  # Preprocessing: normalize the two numeric predictors.
  base_recipe <- recipe(Species ~ Sepal.Length + Sepal.Width, data = two_class)
  norm_recipe <- step_normalize(base_recipe, all_numeric_predictors())

  # Assemble and fit the workflow.
  class_wf <- workflow()
  class_wf <- add_recipe(class_wf, norm_recipe)
  class_wf <- add_model(class_wf, bnns_spec)
  fitted_wf <- fit(class_wf, data = two_class)

  # Probability predictions: the column names must reflect both outcome
  # levels, in order, showing the levels were carried through the recipe.
  prob_tbl <- predict(fitted_wf, new_data = head(two_class, 5), type = "prob")

  expect_s3_class(prob_tbl, "tbl_df")
  expect_named(prob_tbl, c(".pred_setosa", ".pred_versicolor"))
  expect_equal(nrow(prob_tbl), 5)
})

# Integration check: predictors containing NAs are mean-imputed by the recipe,
# so the workflow must both fit and predict without producing NA predictions.
test_that("bnns regression handles missing data with step_impute_mean", {
  skip_if_not_installed("recipes")
  skip_if_not_installed("workflows")
  skip_on_cran()

  library(recipes)
  library(workflows)

  # Copy mtcars and blank out a few predictor values in two columns.
  cars_with_na <- mtcars
  cars_with_na[c(1, 5, 10), "hp"] <- NA
  cars_with_na[c(2, 6, 12), "wt"] <- NA

  # Model specification for the bnns engine.
  bnns_spec <- set_engine(
    mlp(mode = "regression", hidden_units = 2, epochs = 20),
    "bnns",
    warmup = 10,
    refresh = 0,
    chains = 1
  )

  # Preprocessing: replace missing numeric predictors with the column mean.
  base_recipe <- recipe(mpg ~ hp + wt + disp, data = cars_with_na)
  impute_recipe <- step_impute_mean(base_recipe, all_numeric_predictors())

  # Assemble and fit the workflow on the data that contains NAs.
  reg_wf <- workflow()
  reg_wf <- add_recipe(reg_wf, impute_recipe)
  reg_wf <- add_model(reg_wf, bnns_spec)
  fitted_wf <- fit(reg_wf, data = cars_with_na)

  # The object stored by the engine must carry the "bnns" class.
  expect_s3_class(extract_fit_engine(fitted_wf), "bnns")

  # The first five rows include rows 1, 2 and 5, which hold injected NAs.
  pred_tbl <- predict(fitted_wf, new_data = head(cars_with_na, 5))

  expect_s3_class(pred_tbl, "tbl_df")
  expect_named(pred_tbl, ".pred")
  expect_equal(nrow(pred_tbl), 5)
  expect_true(is.numeric(pred_tbl$.pred))
  expect_false(anyNA(pred_tbl$.pred))
})

# Integration check: a categorical predictor with missing values is handled by
# step_unknown(), so the workflow must fit and return non-missing hard class
# predictions.
test_that("bnns classification handles missing factor levels with step_unknown", {
  skip_if_not_installed("recipes")
  skip_if_not_installed("workflows")
  skip_on_cran()

  library(recipes)
  library(workflows)

  # Categorical predictor: the pattern A, B, NA, A repeated 25 times gives 100
  # values, one per row left after the third species is removed.
  category_values <- factor(rep(c("A", "B", NA, "A"), 25))

  # Two-class subset of iris with the extra categorical column attached.
  two_class <- filter(iris, Species != "virginica")
  two_class <- mutate(
    two_class,
    Species = droplevels(Species),
    Category = category_values
  )

  # Model specification for the bnns engine.
  bnns_spec <- set_engine(
    mlp(mode = "classification", hidden_units = 2, epochs = 20),
    "bnns",
    warmup = 10,
    refresh = 0,
    chains = 1
  )

  # Preprocessing: apply step_unknown() to every nominal predictor.
  base_recipe <- recipe(
    Species ~ Sepal.Length + Sepal.Width + Category,
    data = two_class
  )
  unknown_recipe <- step_unknown(base_recipe, all_nominal_predictors())

  # Assemble and fit the workflow.
  class_wf <- workflow()
  class_wf <- add_recipe(class_wf, unknown_recipe)
  class_wf <- add_model(class_wf, bnns_spec)
  fitted_wf <- fit(class_wf, data = two_class)

  # The object stored by the engine must carry the "bnns" class.
  expect_s3_class(extract_fit_engine(fitted_wf), "bnns")

  # Hard class predictions for the first five rows; row 3 has a missing
  # Category value.
  class_tbl <- predict(fitted_wf, new_data = head(two_class, 5), type = "class")

  expect_s3_class(class_tbl, "tbl_df")
  expect_named(class_tbl, ".pred_class")
  expect_equal(nrow(class_tbl), 5)
  expect_false(anyNA(class_tbl$.pred_class))
})

Try the bnns package in your browser

Any scripts or data that you put into this service are public.

bnns documentation built on June 8, 2026, 1:06 a.m.