Nothing
# tests/testthat/test-parsnip_integration.R
library(testthat)
library(parsnip)
library(bnns)
library(dplyr)
test_that("bnns regression parsnip integration works", {
  skip_on_cran()

  # Build a small regression MLP specification on the "bnns" engine.
  # Low epoch/warmup counts and a single chain keep the fit quick for tests.
  reg_model_spec <- set_engine(
    mlp(mode = "regression", hidden_units = 2, epochs = 20),
    "bnns",
    warmup = 10,
    refresh = 0,
    chains = 1
  )
  reg_model <- fit(reg_model_spec, mpg ~ hp + wt, data = mtcars)

  # The parsnip fit object should wrap an object of class "bnns" in `$fit`.
  expect_s3_class(reg_model$fit, "bnns")

  # Predict for the first five rows and verify the shape of the output:
  # a tibble with one numeric `.pred` column and one row per observation.
  first_rows <- mtcars[seq_len(5), ]
  reg_out <- predict(reg_model, new_data = first_rows)

  expect_s3_class(reg_out, "tbl_df")
  expect_named(reg_out, ".pred")
  expect_equal(nrow(reg_out), 5)
  expect_true(is.numeric(reg_out[[".pred"]]))
})
test_that("bnns binary classification parsnip integration works", {
  skip_on_cran()

  # Turn iris into a two-class problem: drop "virginica" rows, then drop the
  # now-unused factor level so the outcome has exactly two levels.
  two_class <- filter(iris, Species != "virginica")
  two_class <- mutate(two_class, Species = droplevels(Species))

  # Small classification MLP on the "bnns" engine; low epoch/warmup counts
  # and a single chain keep the fit quick for tests.
  clf_spec <- set_engine(
    mlp(mode = "classification", hidden_units = 2, epochs = 20),
    "bnns",
    warmup = 10,
    refresh = 0,
    chains = 1
  )
  clf_model <- fit(
    clf_spec,
    Species ~ Sepal.Length + Sepal.Width,
    data = two_class
  )

  first_rows <- two_class[seq_len(5), ]

  # --- Hard class predictions ----
  # Expect a one-column tibble whose factor levels match the training outcome.
  hard_out <- predict(clf_model, new_data = first_rows, type = "class")
  expect_s3_class(hard_out, "tbl_df")
  expect_named(hard_out, ".pred_class")
  expect_equal(nrow(hard_out), 5)
  expect_true(is.factor(hard_out[[".pred_class"]]))
  expect_equal(levels(hard_out[[".pred_class"]]), levels(two_class$Species))

  # --- Class probability predictions ----
  # Expect one `.pred_<level>` column per class, in level order.
  prob_out <- predict(clf_model, new_data = first_rows, type = "prob")
  expect_s3_class(prob_out, "tbl_df")
  expect_named(prob_out, c(".pred_setosa", ".pred_versicolor"))
  expect_equal(nrow(prob_out), 5)

  # Each row of probabilities must sum to 1 (within 1e-6).
  max_deviation <- max(abs(rowSums(prob_out) - 1))
  expect_true(max_deviation < 1e-6)
})
test_that("bnns multiclass classification parsnip integration works", {
  skip_on_cran()

  # Small classification MLP on the "bnns" engine, fitted to the full
  # three-species iris data; low epoch/warmup counts and a single chain keep
  # the fit quick for tests.
  clf_spec <- set_engine(
    mlp(mode = "classification", hidden_units = 2, epochs = 20),
    "bnns",
    warmup = 10,
    refresh = 0,
    chains = 1
  )
  clf_model <- fit(
    clf_spec,
    Species ~ Sepal.Length + Sepal.Width,
    data = iris
  )

  first_rows <- iris[seq_len(5), ]

  # --- Hard class predictions ----
  # Expect a one-column tibble whose factor levels match the training outcome.
  hard_out <- predict(clf_model, new_data = first_rows, type = "class")
  expect_s3_class(hard_out, "tbl_df")
  expect_named(hard_out, ".pred_class")
  expect_equal(nrow(hard_out), 5)
  expect_true(is.factor(hard_out[[".pred_class"]]))
  expect_equal(levels(hard_out[[".pred_class"]]), levels(iris$Species))

  # --- Class probability predictions ----
  # Expect one `.pred_<level>` column per class, in level order.
  prob_out <- predict(clf_model, new_data = first_rows, type = "prob")
  expect_s3_class(prob_out, "tbl_df")
  expect_named(
    prob_out,
    c(".pred_setosa", ".pred_versicolor", ".pred_virginica")
  )
  expect_equal(nrow(prob_out), 5)

  # Each row of probabilities must sum to 1 (within 1e-6).
  max_deviation <- max(abs(rowSums(prob_out) - 1))
  expect_true(max_deviation < 1e-6)
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.