# Source file: tests/testthat/test-parsnip_integration.R

# tests/testthat/test-parsnip_integration.R
library(testthat)
library(parsnip)
library(bnns)
library(dplyr)

test_that("bnns regression parsnip integration works", {
  skip_on_cran()

  # Small epoch / warmup / chain settings keep the fit quick enough for a test.
  engine_ready <- set_engine(
    mlp(mode = "regression", hidden_units = 2, epochs = 20),
    "bnns",
    warmup = 10,
    refresh = 0,
    chains = 1
  )
  model_fit <- fit(engine_ready, mpg ~ hp + wt, data = mtcars)

  # The parsnip fit object should carry a "bnns" object in its `fit` slot.
  expect_s3_class(model_fit$fit, "bnns")

  # Predict on the first five rows and check the shape of the result:
  # a tibble with a single numeric `.pred` column, one row per input row.
  first_rows <- mtcars[seq_len(5), ]
  predicted <- predict(model_fit, new_data = first_rows)

  expect_s3_class(predicted, "tbl_df")
  expect_named(predicted, ".pred")
  expect_equal(nrow(predicted), 5)
  expect_true(is.numeric(predicted$.pred))
})

test_that("bnns binary classification parsnip integration works", {
  skip_on_cran()

  # Two-class problem: remove the "virginica" rows from iris and drop the
  # now-unused factor level so the outcome has exactly two levels.
  two_species <- filter(iris, Species != "virginica")
  df_bin <- mutate(two_species, Species = droplevels(Species))

  engine_ready <- set_engine(
    mlp(mode = "classification", hidden_units = 2, epochs = 20),
    "bnns",
    warmup = 10,
    refresh = 0,
    chains = 1
  )
  model_fit <- fit(
    engine_ready,
    Species ~ Sepal.Length + Sepal.Width,
    data = df_bin
  )

  # Both prediction types are requested for the same five rows.
  first_rows <- head(df_bin, 5)

  # Hard class predictions: a one-column tibble holding a factor whose levels
  # match the levels of the training outcome.
  hard_calls <- predict(model_fit, new_data = first_rows, type = "class")

  expect_s3_class(hard_calls, "tbl_df")
  expect_named(hard_calls, ".pred_class")
  expect_equal(nrow(hard_calls), 5)
  expect_s3_class(hard_calls$.pred_class, "factor")
  expect_equal(levels(hard_calls$.pred_class), levels(df_bin$Species))

  # Class probabilities: one `.pred_<level>` column per outcome level.
  class_probs <- predict(model_fit, new_data = first_rows, type = "prob")

  expect_s3_class(class_probs, "tbl_df")
  expect_named(class_probs, c(".pred_setosa", ".pred_versicolor"))
  expect_equal(nrow(class_probs), 5)

  # Every row of probabilities must add up to 1, within 1e-6.
  deviation_from_one <- abs(rowSums(class_probs) - 1)
  expect_true(all(deviation_from_one < 1e-6))
})

test_that("bnns multiclass classification parsnip integration works", {
  skip_on_cran()

  # Three-class problem: the full iris data with all Species levels.
  engine_ready <- set_engine(
    mlp(mode = "classification", hidden_units = 2, epochs = 20),
    "bnns",
    warmup = 10,
    refresh = 0,
    chains = 1
  )
  model_fit <- fit(
    engine_ready,
    Species ~ Sepal.Length + Sepal.Width,
    data = iris
  )

  # Both prediction types are requested for the same five rows.
  first_rows <- head(iris, 5)

  # Hard class predictions: a one-column tibble holding a factor whose levels
  # match the levels of the training outcome.
  hard_calls <- predict(model_fit, new_data = first_rows, type = "class")

  expect_s3_class(hard_calls, "tbl_df")
  expect_named(hard_calls, ".pred_class")
  expect_equal(nrow(hard_calls), 5)
  expect_s3_class(hard_calls$.pred_class, "factor")
  expect_equal(levels(hard_calls$.pred_class), levels(iris$Species))

  # Class probabilities: one `.pred_<level>` column per outcome level.
  class_probs <- predict(model_fit, new_data = first_rows, type = "prob")

  expect_s3_class(class_probs, "tbl_df")
  expect_named(
    class_probs,
    c(".pred_setosa", ".pred_versicolor", ".pred_virginica")
  )
  expect_equal(nrow(class_probs), 5)

  # Every row of probabilities must add up to 1, within 1e-6.
  deviation_from_one <- abs(rowSums(class_probs) - 1)
  expect_true(all(deviation_from_one < 1e-6))
})

# Try the bnns package in your browser
#
# Any scripts or data that you put into this service are public.
#
# bnns documentation built on June 8, 2026, 1:06 a.m.