bnns_train: Internal function for training the BNN

View source: R/bnns.R


Description

This function performs the actual fitting of the Bayesian Neural Network. It is called by the exported bnns methods and is not intended for direct use.

Usage

bnns_train(
  train_x,
  train_y,
  L = 1,
  nodes = rep(2, L),
  act_fn = rep(2, L),
  out_act_fn = 1,
  iter = 1000,
  warmup = 200,
  thin = 1,
  chains = 2,
  cores = 2,
  seed = 123,
  prior_weights = NULL,
  prior_bias = NULL,
  prior_sigma = NULL,
  verbose = FALSE,
  refresh = max(iter/10, 1),
  normalize = TRUE,
  ...
)

Arguments

train_x

A numeric matrix representing the input features (predictors) for training. Rows correspond to observations, and columns correspond to features.

train_y

A numeric vector representing the target values for training. Its length must match the number of rows in train_x.

L

An integer specifying the number of hidden layers in the neural network. Default is 1.

nodes

An integer or vector specifying the number of nodes in each hidden layer. If a single value is provided, it is applied to all layers. Default is rep(2, L), i.e. two nodes in each hidden layer.

act_fn

An integer or vector specifying the activation function(s) for the hidden layers. Options are:

  • 1 for tanh

  • 2 for sigmoid (default)

  • 3 for softplus

  • 4 for ReLU

  • 5 for linear
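For reference, the sketch below writes out in plain R what each hidden-layer activation code computes. The helper hidden_activation is hypothetical and not part of bnns. In the package the activations are applied inside the generated Stan code, so this illustrates the formulas, not the implementation.

```r
# Hypothetical helper (not part of bnns): maps an act_fn code to an R function.
hidden_activation <- function(code) {
  switch(as.character(code),
    "1" = tanh,                               # tanh
    "2" = function(z) 1 / (1 + exp(-z)),      # sigmoid (logistic)
    "3" = function(z) log1p(exp(z)),          # softplus: log(1 + e^z)
    "4" = function(z) pmax(z, 0),             # ReLU: max(z, 0)
    "5" = function(z) z,                      # linear (identity)
    stop("unknown activation code: ", code)
  )
}

z <- c(-2, 0, 2)
sapply(1:5, function(k) hidden_activation(k)(z))  # 3 x 5 matrix, one column per code
```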

out_act_fn

An integer specifying the activation function for the output layer. Options are:

  • 1 for linear (default)

  • 2 for sigmoid

  • 3 for softmax
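A similar plain-R sketch for the output-layer codes follows. output_activation is a hypothetical helper, not the package's Stan code. Softmax is computed row-wise with each row's maximum subtracted first, a standard trick that avoids overflow in exp() without changing the result.

```r
# Hypothetical helper (not part of bnns): maps an out_act_fn code to an R function.
output_activation <- function(code) {
  switch(as.character(code),
    "1" = function(eta) eta,                  # linear: identity
    "2" = function(eta) 1 / (1 + exp(-eta)),  # sigmoid: values in (0, 1)
    "3" = function(eta) {                     # softmax across the columns of each row
      eta <- as.matrix(eta)
      shifted <- eta - apply(eta, 1, max)     # subtract each row's maximum
      e <- exp(shifted)
      e / rowSums(e)                          # each row now sums to 1
    },
    stop("unknown output activation code: ", code)
  )
}

eta <- matrix(c(1, 2, 3,
                0, 0, 0), nrow = 2, byrow = TRUE)
probs <- output_activation(3)(eta)
rowSums(probs)   # each row sums to 1
```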

iter

An integer specifying the number of iterations per chain for the Stan sampler, including warmup. Default is 1000.

warmup

An integer specifying the number of warmup iterations per chain for the Stan sampler. Warmup draws are not used for inference. Default is 200.

thin

An integer specifying the thinning interval for Stan samples. Default is 1.

chains

An integer specifying the number of Markov chains. Default is 2.
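The sampler settings above determine how many posterior draws end up in the fit. The sketch below assumes rstan's convention that iter counts iterations per chain including warmup and that every thin-th post-warmup iteration is kept; retained_draws is a hypothetical helper.

```r
# Hypothetical helper: total post-warmup draws kept across all chains,
# assuming iter includes warmup and every thin-th draw is retained.
retained_draws <- function(iter = 1000, warmup = 200, thin = 1, chains = 2) {
  per_chain <- ceiling((iter - warmup) / thin)
  chains * per_chain
}

retained_draws()                                   # 1600 with the bnns_train defaults
retained_draws(iter = 10, warmup = 5, chains = 1)  # 5, as in the Examples call
```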

cores

An integer specifying the number of CPU cores to use for parallel sampling. Default is 2.

seed

An integer specifying the random seed for reproducibility. Default is 123.

prior_weights

A list specifying the prior distribution for the weights in the neural network. The list must include two components:

  • dist: A character string specifying the distribution type. Supported values are "normal", "uniform", and "cauchy".

  • params: A named list specifying the parameters for the chosen distribution:

    • For "normal": Provide mean (mean of the distribution) and sd (standard deviation).

    • For "uniform": Provide alpha (lower bound) and beta (upper bound).

    • For "cauchy": Provide mu (location parameter) and sigma (scale parameter).

If prior_weights is NULL, the default prior is a normal(0, 1) distribution. For example:

  • list(dist = "normal", params = list(mean = 0, sd = 1))

  • list(dist = "uniform", params = list(alpha = -1, beta = 1))

  • list(dist = "cauchy", params = list(mu = 0, sigma = 2.5))
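Because prior_bias accepts the same specification, a single validator can cover both arguments. check_prior is a hypothetical helper written against the distributions and parameter names documented here; bnns may perform its own, different checks.

```r
# Hypothetical helper (not part of bnns): validates a prior_weights / prior_bias list.
check_prior <- function(prior) {
  # Documented distributions and the parameter names each one requires
  required <- list(
    normal  = c("mean", "sd"),
    uniform = c("alpha", "beta"),
    cauchy  = c("mu", "sigma")
  )
  if (is.null(prior)) {
    # NULL stands for the documented default, normal(0, 1)
    prior <- list(dist = "normal", params = list(mean = 0, sd = 1))
  }
  if (!is.list(prior) || !is.character(prior$dist) || length(prior$dist) != 1) {
    stop("prior must be a list with a single character 'dist'")
  }
  if (!prior$dist %in% names(required)) {
    stop("unsupported dist: ", prior$dist)
  }
  absent <- setdiff(required[[prior$dist]], names(prior$params))
  if (length(absent) > 0) {
    stop("missing params for ", prior$dist, ": ", paste(absent, collapse = ", "))
  }
  invisible(prior)
}

check_prior(list(dist = "cauchy", params = list(mu = 0, sigma = 2.5)))  # passes silently
check_prior(NULL)$dist                                                   # "normal"
try(check_prior(list(dist = "normal", params = list(mu = 0, sd = 1))))  # reports missing 'mean'
```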

prior_bias

A list specifying the prior distribution for the biases in the neural network. The list must include two components:

  • dist: A character string specifying the distribution type. Supported values are "normal", "uniform", and "cauchy".

  • params: A named list specifying the parameters for the chosen distribution:

    • For "normal": Provide mean (mean of the distribution) and sd (standard deviation).

    • For "uniform": Provide alpha (lower bound) and beta (upper bound).

    • For "cauchy": Provide mu (location parameter) and sigma (scale parameter).

If prior_bias is NULL, the default prior is a normal(0, 1) distribution. For example:

  • list(dist = "normal", params = list(mean = 0, sd = 1))

  • list(dist = "uniform", params = list(alpha = -1, beta = 1))

  • list(dist = "cauchy", params = list(mu = 0, sigma = 2.5))

prior_sigma

A list specifying the prior distribution for the sigma parameter in regression models (out_act_fn = 1). This allows for setting priors on the standard deviation of the residuals. The list must include two components:

  • dist: A character string specifying the distribution type. Supported values are "half_normal" and "inv_gamma".

  • params: A named list specifying the parameters for the chosen distribution:

    • For "half_normal": Provide mean (location, normally 0) and sd (standard deviation).

    • For "inv_gamma": Provide alpha (shape parameter) and beta (scale parameter).

If prior_sigma is NULL, the default prior is a half-normal(0, 1) distribution. For example:

  • list(dist = "half_normal", params = list(mean = 0, sd = 1))

  • list(dist = "inv_gamma", params = list(alpha = 1, beta = 1))
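To see what the two sigma priors imply, their densities can be written out in base R. The names dhalfnorm and dinvgamma are hypothetical (base R provides neither), and the inverse-gamma density uses the shape/scale parameterization of Stan's inv_gamma(alpha, beta). Both should integrate to 1 over the positive half-line, which integrate() can confirm numerically.

```r
# Hypothetical helper: half-normal density on sigma >= 0,
# i.e. twice the normal(0, sd) density, and zero below 0
dhalfnorm <- function(x, sd = 1) {
  ifelse(x >= 0, 2 * dnorm(x, mean = 0, sd = sd), 0)
}

# Hypothetical helper: inverse-gamma density with shape alpha and scale beta,
# beta^alpha / Gamma(alpha) * x^(-alpha - 1) * exp(-beta / x) for x > 0
dinvgamma <- function(x, alpha = 1, beta = 1) {
  out <- numeric(length(x))
  pos <- x > 0
  xp <- x[pos]
  # evaluate on the log scale for numerical stability
  out[pos] <- exp(alpha * log(beta) - lgamma(alpha) - (alpha + 1) * log(xp) - beta / xp)
  out
}

integrate(dhalfnorm, 0, Inf)$value   # approximately 1
integrate(dinvgamma, 0, Inf)$value   # approximately 1
dhalfnorm(c(0.5, 1, 2))
dinvgamma(c(0.5, 1, 2))
```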

verbose

Logical. If TRUE, intermediate output from Stan is printed to the console, which can be helpful for model debugging. Default is FALSE.

refresh

An integer controlling how often sampling progress is reported: progress is shown every refresh iterations. Default is max(iter/10, 1). Progress reporting is turned off if refresh <= 0.

normalize

Logical. If TRUE (default), the input predictors are normalized to have zero mean and unit variance before training. Normalization ensures stable and efficient Bayesian sampling by standardizing the input scale, which is particularly beneficial for neural network training. If FALSE, no normalization is applied, and it is assumed that the input data is already pre-processed appropriately.
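A minimal sketch of zero-mean, unit-variance standardization with base R's scale(), assuming normalization is column-wise as described. It shows that new data should be transformed with the training means and standard deviations; how bnns stores and reuses these values internally is not shown here.

```r
set.seed(1)
train_x <- matrix(runif(20), nrow = 10, ncol = 2)

# Standardize each column; scale() keeps the constants as attributes
x_scaled <- scale(train_x, center = TRUE, scale = TRUE)
x_center <- attr(x_scaled, "scaled:center")
x_scale  <- attr(x_scaled, "scaled:scale")

# Apply the *training* means and standard deviations to new observations
new_x <- matrix(runif(6), nrow = 3, ncol = 2)
new_scaled <- scale(new_x, center = x_center, scale = x_scale)

round(colMeans(x_scaled), 10)   # zero up to rounding
apply(x_scaled, 2, sd)          # 1 for each column
```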

...

Currently not in use.

Details

The function uses the generate_stan_code function to dynamically generate Stan code based on the specified number of layers and nodes. Stan is then used to fit the Bayesian Neural Network.
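To make the network structure concrete, the sketch below evaluates a one-hidden-layer network (L = 1) for a single set of weights and biases in plain R, using the default sigmoid hidden activation and linear output. forward_one_layer and the names W1, b1, W2, b2 are hypothetical and need not match the parameter names in the Stan code produced by generate_stan_code. In a Bayesian fit, each posterior draw of the weights defines one such function.

```r
# Hypothetical helper: forward pass of a one-hidden-layer network for one draw
forward_one_layer <- function(X, W1, b1, W2, b2,
                              act = function(z) 1 / (1 + exp(-z)),  # sigmoid
                              out_act = identity) {                 # linear
  H <- act(sweep(X %*% W1, 2, b1, "+"))   # n x nodes hidden activations
  out_act(sweep(H %*% W2, 2, b2, "+"))    # n x K outputs
}

set.seed(123)
X  <- matrix(runif(20), nrow = 10, ncol = 2)  # 10 observations, 2 features
W1 <- matrix(rnorm(2 * 2), nrow = 2)          # 2 features -> 2 hidden nodes
b1 <- rnorm(2)
W2 <- matrix(rnorm(2), nrow = 2)              # 2 hidden nodes -> 1 output
b2 <- rnorm(1)
y_hat <- forward_one_layer(X, W1, b1, W2, b2)
dim(y_hat)   # 10 1
```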

Value

An object of class "bnns" containing the following components:

fit

The fitted Stan model object.

call

The matched call.

data

A list containing the Stan data used in the model.

See Also

stan

Examples


# Example usage:
train_x <- matrix(runif(20), nrow = 10, ncol = 2)
train_y <- rnorm(10)
model <- bnns::bnns_train(train_x, train_y,
  L = 1, nodes = 2, act_fn = 2,
  iter = 1e1, warmup = 5, chains = 1
)

# Access Stan model fit
model$fit


bnns documentation built on April 3, 2025, 6:12 p.m.