knitr::opts_chunk$set( collapse = TRUE, comment = "#>" )
The bnns package provides an efficient and user-friendly implementation of Bayesian Neural Networks (BNNs) for regression, binary classification, and multiclass classification problems. By integrating Bayesian inference, bnns allows for uncertainty quantification in predictions and robust parameter estimation.
This vignette covers:
1. Installing and loading the package
2. Preparing data
3. Fitting a BNN model
4. Summarizing the model
5. Making predictions
6. Model evaluation
7. Customizing priors
To install the package, use the following commands:
# Install from CRAN (if available)
# install.packages("bnns")
# Or install the development version from GitHub
# devtools::install_github("swarnendu-stat/bnns")
Load the package in your R session:
library(bnns)
The bnns() function takes a model formula and a data frame: the right-hand side of the formula selects the predictor columns and the left-hand side names the response.
Here’s an example of generating synthetic data:
# Generate training data
set.seed(123)
df <- data.frame(x1 = runif(10), x2 = runif(10), y = rnorm(10))
For binary or multiclass classification:
# Binary classification response
df$y_bin <- sample(0:1, 10, replace = TRUE)
# Multiclass classification response
df$y_cat <- factor(sample(letters[1:3], 10, replace = TRUE)) # 3 classes
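The model formulas in the fitting examples have the form y ~ -1 + x1 + x2. A quick base-R way to see which predictor matrix such a formula describes is model.matrix(): the -1 term drops the intercept column. Whether bnns builds its internal design matrix exactly this way is an assumption; df_demo, X_no_int, and X_int are illustrative names chosen so the block does not overwrite df.

```r
# Rebuild the synthetic training data under a separate name so that df
# (which now also holds y_bin and y_cat) is left untouched
set.seed(123)
df_demo <- data.frame(x1 = runif(10), x2 = runif(10), y = rnorm(10))

# Predictor matrix implied by the formula; "-1" removes the intercept
X_no_int <- model.matrix(y ~ -1 + x1 + x2, data = df_demo)
# Same formula without "-1" keeps an "(Intercept)" column of ones
X_int <- model.matrix(y ~ x1 + x2, data = df_demo)

dim(X_no_int)       # 10 rows, 2 columns
colnames(X_no_int)  # "x1" "x2"
colnames(X_int)     # "(Intercept)" "x1" "x2"
```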
Fit a Bayesian Neural Network using the bnns() function. Specify the network architecture with the number of hidden layers (L), the number of nodes in each hidden layer (nodes), the activation function of each hidden layer (act_fn), and the output activation function (out_act_fn).
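To make the roles of L, nodes, and act_fn concrete, here is a minimal base-R sketch of a single forward pass through a fully connected network with three hidden layers of 32, 16, and 8 nodes. It is an illustration under stated assumptions, not the bnns implementation: forward_pass(), relu(), and sigmoid() are hypothetical helpers, the weights are one random draw from a standard normal distribution rather than posterior draws, and the output layer uses the identity activation as in regression.

```r
# Hypothetical activation helpers (not bnns exports)
relu    <- function(z) pmax(z, 0)          # keeps the matrix dimensions of z
sigmoid <- function(z) 1 / (1 + exp(-z))

# Hypothetical sketch: one forward pass with randomly drawn weights.
# nodes = hidden-layer widths, act = list of one activation per hidden layer.
forward_pass <- function(X, nodes, act) {
  h <- X
  for (l in seq_along(nodes)) {
    W <- matrix(rnorm(ncol(h) * nodes[l]), nrow = ncol(h), ncol = nodes[l])
    b <- rnorm(nodes[l])
    # h %*% W gives an n x nodes[l] matrix; sweep() adds b[j] to column j
    h <- act[[l]](sweep(h %*% W, 2, b, "+"))
  }
  # Output layer: a single node with identity activation (regression case)
  w_out <- rnorm(ncol(h))
  b_out <- rnorm(1)
  drop(h %*% w_out) + b_out
}

set.seed(2024)
X_demo <- matrix(runif(20), nrow = 10, ncol = 2)  # 10 observations, 2 predictors
y_hat <- forward_pass(X_demo, nodes = c(32, 16, 8),
                      act = list(relu, sigmoid, sigmoid))
length(y_hat)  # one prediction per row of X_demo
```

In a Bayesian network, a forward pass like this is repeated for each posterior draw of the weights, and the spread of the resulting outputs expresses predictive uncertainty.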
model_reg <- bnns(
  y ~ -1 + x1 + x2,
  data = df,
  L = 1, # Number of hidden layers
  nodes = 2, # Nodes per layer
  act_fn = 3, # Activation functions: 3 = ReLU
  out_act_fn = 1, # Output activation function: 1 = Identity (for regression)
  iter = 1e1, # Very few iterations are shown; increase to at least 1e3 for meaningful inference
  warmup = 5, # Very few warmup iterations are shown; increase to at least 2e2 for meaningful inference
  chains = 1
)
model_bin <- bnns(
  y_bin ~ -1 + x1 + x2,
  data = df,
  L = 1,
  nodes = c(16),
  act_fn = c(2),
  out_act_fn = 2, # Output activation: 2 = Logistic sigmoid
  iter = 2e2,
  warmup = 1e2,
  chains = 1
)
model_cat <- bnns(
  y_cat ~ -1 + x1 + x2,
  data = df,
  L = 3,
  nodes = c(32, 16, 8),
  act_fn = c(3, 2, 2),
  out_act_fn = 3, # Output activation: 3 = Softmax
  iter = 2e2,
  warmup = 1e2,
  chains = 1
)
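The out_act_fn codes in the comments above (1 = Identity, 2 = Logistic sigmoid, 3 = Softmax) refer to standard output functions. A base-R sketch of those three functions follows; identity_fn(), logistic(), and softmax() are hypothetical helpers for illustration, not functions exported by bnns.

```r
identity_fn <- function(z) z                  # regression: raw output
logistic    <- function(z) 1 / (1 + exp(-z))  # binary: probability of class 1

# Multiclass: each row of z holds one observation's class scores
softmax <- function(z) {
  # Subtract each row's maximum for numerical stability (result unchanged)
  z <- z - apply(z, 1, max)
  e <- exp(z)
  e / rowSums(e)                              # each row now sums to 1
}

scores <- matrix(c(2, 1, 0,
                   0, 0, 0), nrow = 2, byrow = TRUE)
p <- softmax(scores)
rowSums(p)   # 1 1
p[2, ]       # equal scores give equal probabilities, 1/3 each
logistic(0)  # 0.5
```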
Use the summary() function to view details of the fitted model, including the network architecture, posterior distributions, and predictive performance.
summary(model_reg)
summary(model_bin)
summary(model_cat)
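Posterior summaries are computed from MCMC draws rather than from a single point estimate. As a generic sketch (the layout that summary() prints may differ), the posterior mean, standard deviation, and a 95% equal-tailed credible interval of one parameter can be obtained from its draws; the draws vector below is simulated, not extracted from model_reg.

```r
set.seed(42)
# Simulated stand-in for 1000 posterior draws of a single weight
draws <- rnorm(1000, mean = 0.3, sd = 0.1)

post_summary <- c(
  mean = mean(draws),
  sd   = sd(draws),
  quantile(draws, probs = c(0.025, 0.975))  # named "2.5%" and "97.5%"
)
round(post_summary, 3)
```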
The predict() function generates predictions for new data. The format of the predictions depends on the output activation function.
# New data
test_x <- matrix(runif(10), nrow = 5, ncol = 2) |>
  data.frame() |>
  `colnames<-`(c("x1", "x2"))
# Regression predictions
pred_reg <- predict(model_reg, test_x)
# Binary classification predictions
pred_bin <- predict(model_bin, test_x)
# Multiclass classification predictions
pred_cat <- predict(model_cat, test_x)
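Because the fitted model yields a distribution of predictions rather than one value, the draws are usually condensed into a point prediction and an interval. Inspect the returned object with dim() first: the sketch below assumes (without confirmation from this vignette) a matrix with one row per test observation and one column per posterior draw, and uses simulated stand-ins pred_draws and prob_draws instead of pred_reg and pred_bin. The 0.5 threshold for binary classes is likewise an illustrative choice.

```r
set.seed(1)
# Hypothetical stand-in: 5 test observations x 200 posterior predictive draws
pred_draws <- matrix(rnorm(5 * 200, mean = 1:5), nrow = 5)

point_pred <- rowMeans(pred_draws)  # posterior mean per observation
interval   <- t(apply(pred_draws, 1, quantile, probs = c(0.025, 0.975)))
cbind(mean = point_pred, interval)  # columns: mean, 2.5%, 97.5%

# Hypothetical binary case: draws of P(y = 1) for the same 5 observations
prob_draws <- matrix(runif(5 * 200), nrow = 5)
p_hat     <- rowMeans(prob_draws)
class_hat <- as.integer(p_hat > 0.5)  # illustrative 0.5 threshold
class_hat
```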
The bnns package includes utility functions for evaluating model performance: measure_cont for regression, measure_bin for binary classification, and measure_cat for multiclass classification.
# True responses
test_y <- rnorm(5)
# Evaluate predictions
metrics_reg <- measure_cont(obs = test_y, pred = pred_reg)
print(metrics_reg)
# True responses
test_y_bin <- sample(c(rep(0, 2), rep(1, 3)), 5)
# Evaluate predictions
metrics_bin <- measure_bin(obs = test_y_bin, pred = pred_bin)
# True responses
test_y_cat <- factor(sample(letters[1:3], 5, replace = TRUE))
# Evaluate predictions
metrics_cat <- measure_cat(obs = test_y_cat, pred = pred_cat)
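For intuition about what evaluation metrics measure, here is a hand computation of root mean squared error (RMSE), mean absolute error (MAE), and classification accuracy on small made-up vectors. This is not a reimplementation of measure_cont(), measure_bin(), or measure_cat(); which metrics those functions report is not assumed here.

```r
# Regression: toy observed and predicted values
obs_cont  <- c(1.0, 2.0, 3.0)
pred_cont <- c(1.5, 2.0, 2.0)
rmse <- sqrt(mean((obs_cont - pred_cont)^2))  # sqrt(1.25 / 3), about 0.645
mae  <- mean(abs(obs_cont - pred_cont))       # 1.5 / 3 = 0.5

# Binary classification: toy labels and predicted probabilities
obs_bin  <- c(0, 1, 1, 0)
prob_bin <- c(0.2, 0.7, 0.4, 0.1)
accuracy <- mean(as.integer(prob_bin > 0.5) == obs_bin)  # 3 of 4 correct

c(rmse = rmse, mae = mae, accuracy = accuracy)
```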
Customized priors can be used for the weights as well as for the sigma parameter (in regression). Here we show an example that uses a Cauchy prior for the weights in the multiclass classification case.
model_cat_cauchy <- bnns(
  y_cat ~ -1 + x1 + x2,
  data = df,
  L = 3,
  nodes = c(32, 16, 8),
  act_fn = c(3, 2, 2),
  out_act_fn = 3, # Output activation: 3 = Softmax
  iter = 2e2,
  warmup = 1e2,
  chains = 1,
  prior_weights = list(dist = "cauchy", params = list(mu = 0, sigma = 2.5))
)
# Evaluate predictions
metrics_cat_cauchy <- measure_cat(obs = test_y_cat, pred = predict(model_cat_cauchy, test_x))
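One reason to choose a Cauchy prior is its much heavier tails: compared with a normal prior of the same scale, it leaves room for occasional large weights while still concentrating mass near zero. The base-R comparison below computes the prior probability that a weight exceeds 10 in absolute value under Normal(0, 2.5) and Cauchy(0, 2.5). Reading mu and sigma in the params list above as the Cauchy location and scale is an assumption.

```r
s <- 2.5  # assumed to match sigma = 2.5 in prior_weights above

# P(|w| > 10) = 2 * P(w < -10) for both symmetric, zero-centred priors
tail_normal <- 2 * pnorm(-10, mean = 0, sd = s)             # below 1e-4
tail_cauchy <- 2 * pcauchy(-10, location = 0, scale = s)    # roughly 0.16

c(normal = tail_normal, cauchy = tail_cauchy)
```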
For more details, consult the source code on GitHub.