bnns_train | R Documentation |
This function performs the actual fitting of the Bayesian Neural Network. It is called by the exported bnns methods and is not intended for direct use.
bnns_train(
train_x,
train_y,
L = 1,
nodes = rep(2, L),
act_fn = rep(2, L),
out_act_fn = 1,
iter = 1000,
warmup = 200,
thin = 1,
chains = 2,
cores = 2,
seed = 123,
prior_weights = NULL,
prior_bias = NULL,
prior_sigma = NULL,
verbose = FALSE,
refresh = max(iter/10, 1),
normalize = TRUE,
...
)
train_x |
A numeric matrix representing the input features (predictors) for training. Rows correspond to observations, and columns correspond to features. |
train_y |
A numeric vector representing the target values for training. Its length must match the number of rows in train_x. |
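The shape requirements for train_x and train_y can be checked before any sampling starts. The sketch below uses a hypothetical helper, check_train_inputs(), which is not part of bnns; the package's own input checks may differ.

```r
# Hypothetical helper (not part of bnns): checks the documented shape
# requirements for train_x and train_y before fitting.
check_train_inputs <- function(train_x, train_y) {
  if (!is.matrix(train_x) || !is.numeric(train_x)) {
    stop("train_x must be a numeric matrix (rows = observations, columns = features)")
  }
  if (!is.numeric(train_y)) {
    stop("train_y must be a numeric vector")
  }
  if (length(train_y) != nrow(train_x)) {
    stop(sprintf("length(train_y) = %d does not match nrow(train_x) = %d",
                 length(train_y), nrow(train_x)))
  }
  invisible(TRUE)
}

x <- matrix(runif(20), nrow = 10, ncol = 2)
check_train_inputs(x, rnorm(10))       # passes silently
try(check_train_inputs(x, rnorm(9)))   # error: lengths differ
```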
L |
An integer specifying the number of hidden layers in the neural network. Default is 1. |
nodes |
An integer or vector specifying the number of nodes in each hidden layer. If a single value is provided, it is applied to all layers. Default is rep(2, L), i.e. 2 nodes in each hidden layer. |
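The recycling rule for nodes (one value applied to every hidden layer) can be sketched as follows. expand_nodes() is a hypothetical helper, and raising an error on a length mismatch is an assumption, not documented bnns behaviour.

```r
# Hypothetical helper (not part of bnns): expands `nodes` to one entry per
# hidden layer, recycling a single value as the documentation describes.
expand_nodes <- function(nodes, L) {
  if (length(nodes) == 1) {
    nodes <- rep(nodes, L)
  }
  if (length(nodes) != L) {
    stop("nodes must have length 1 or length L")  # assumed behaviour
  }
  as.integer(nodes)
}

expand_nodes(16, L = 3)       # 16 16 16
expand_nodes(c(8, 4), L = 2)  # 8 4
```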
act_fn |
An integer or vector specifying the activation function(s) for the hidden layers. Default is rep(2, L). Options are:
|
out_act_fn |
An integer specifying the activation function for the output layer. Default is 1. Options are:
|
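To make L, nodes, act_fn and out_act_fn more concrete, here is a sketch of a deterministic forward pass through a fully connected network for one fixed set of weights. It assumes a logistic sigmoid in every hidden layer and an identity (linear) output; the integer codes bnns uses to select activations are not reproduced here, and forward_pass() is a hypothetical name. In the Bayesian model the weights and biases are random quantities sampled by Stan rather than fixed values.

```r
# Illustrative only: deterministic forward pass of a fully connected network.
# Assumes sigmoid hidden activations and a linear (identity) output layer.
sigmoid <- function(z) 1 / (1 + exp(-z))

forward_pass <- function(x, weights, biases, hidden_act = sigmoid,
                         out_act = identity) {
  h <- x
  n_layers <- length(weights)
  for (k in seq_len(n_layers)) {
    # affine map: (n x d_in) %*% (d_in x d_out), bias added to each column
    z <- sweep(h %*% weights[[k]], 2, biases[[k]], "+")
    h <- if (k < n_layers) hidden_act(z) else out_act(z)
  }
  drop(h)  # n x 1 output matrix -> numeric vector of length n
}

set.seed(1)
x <- matrix(runif(20), nrow = 10, ncol = 2)  # 10 observations, 2 features
nodes <- c(2)                                # L = 1 hidden layer with 2 nodes
dims <- c(ncol(x), nodes, 1)                 # 2 -> 2 -> 1
weights <- lapply(seq_len(length(dims) - 1), function(k)
  matrix(rnorm(dims[k] * dims[k + 1]), dims[k], dims[k + 1]))
biases <- lapply(seq_len(length(dims) - 1), function(k) rnorm(dims[k + 1]))
yhat <- forward_pass(x, weights, biases)
length(yhat)  # 10
```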
iter |
An integer specifying the total number of iterations for the Stan sampler. Default is 1000. |
warmup |
An integer specifying the number of warmup iterations for the Stan sampler. Default is 200. |
thin |
An integer specifying the thinning interval for Stan samples. Default is 1. |
chains |
An integer specifying the number of Markov chains. Default is 2. |
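iter, warmup, thin and chains together determine how many posterior draws are retained. The sketch below assumes the usual Stan convention (each chain discards its first warmup iterations and keeps every thin-th remaining draw, starting with the first); n_post_warmup_draws() is a hypothetical helper, and the exact thinning rule should be checked against the rstan documentation.

```r
# Hypothetical helper: post-warmup draws kept across all chains, assuming
# each chain keeps every `thin`-th of its (iter - warmup) post-warmup draws.
n_post_warmup_draws <- function(iter, warmup, thin = 1, chains = 1) {
  stopifnot(iter > warmup, thin >= 1)
  chains * ceiling((iter - warmup) / thin)
}

n_post_warmup_draws(iter = 1000, warmup = 200, thin = 1, chains = 2)  # 1600 (defaults)
n_post_warmup_draws(iter = 10, warmup = 5, chains = 1)                # 5 (settings in the Examples)
```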
cores |
An integer specifying the number of CPU cores to use for parallel sampling. Default is 2. |
seed |
An integer specifying the random seed for reproducibility. Default is 123. |
prior_weights |
A list specifying the prior distribution for the weights in the neural network. The list must include two components:
If prior_weights is NULL (the default), a default prior is used. |
prior_bias |
A list specifying the prior distribution for the biases in the neural network. The list must include two components:
If prior_bias is NULL (the default), a default prior is used. |
prior_sigma |
A list specifying the prior distribution for the sigma parameter.
If prior_sigma is NULL (the default), a default prior is used. |
verbose |
Logical. Whether to print intermediate output from Stan to the console, which can help when debugging the model. Default is FALSE. |
refresh |
An integer controlling how often sampling progress is reported: progress is shown every refresh iterations. Default is max(iter/10, 1). Setting refresh <= 0 turns the progress indicator off. |
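The default refresh value comes straight from the signature and can be evaluated directly. This tiny sketch (default_refresh() is a hypothetical name) shows what it gives for the default iter and for the small iter used in the Examples.

```r
# Default progress interval from the signature: refresh = max(iter/10, 1)
default_refresh <- function(iter) max(iter / 10, 1)

default_refresh(1000)  # 100: progress reported every 100 iterations
default_refresh(10)    # 1
default_refresh(5)     # 1 (iter/10 = 0.5, raised to 1 by max())
```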
normalize |
Logical. Whether to normalize the training data before fitting. Default is TRUE. |
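This page does not spell out exactly what normalize = TRUE does. One common choice, shown here purely as an assumption, is to standardize each predictor column to mean 0 and standard deviation 1 with base R's scale(), keeping the centering and scaling constants so that new data can be transformed the same way. The helpers standardize_fit() and standardize_apply() are hypothetical.

```r
# Assumption: column-wise standardization; not necessarily what bnns does.
standardize_fit <- function(x) {
  list(center = colMeans(x), scale = apply(x, 2, sd))
}

standardize_apply <- function(x, params) {
  scale(x, center = params$center, scale = params$scale)
}

set.seed(42)
train_x <- matrix(runif(20, min = 0, max = 100), nrow = 10, ncol = 2)
params <- standardize_fit(train_x)
z <- standardize_apply(train_x, params)
round(colMeans(z), 10)  # approximately 0 for each column
apply(z, 2, sd)         # 1 for each column
```

Storing the constants separately matters at prediction time: test inputs should be shifted and scaled with the training means and standard deviations, not their own.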
... |
Currently not in use. |
The function uses the generate_stan_code function to dynamically generate Stan code based on the specified number of layers and nodes, and then fits the Bayesian Neural Network with Stan.
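Because the generated Stan code depends on the number of layers and nodes, the number of parameters to sample grows with the architecture. Assuming a fully connected network with a single output node, the sketch below counts weight and bias parameters for a given architecture. count_bnn_params() is a hypothetical helper and does not call generate_stan_code.

```r
# Hypothetical helper: counts weights and biases of a fully connected network
# with p inputs, hidden layer widths `nodes`, and `n_out` output nodes.
count_bnn_params <- function(p, nodes, n_out = 1) {
  dims <- c(p, nodes, n_out)
  k <- seq_len(length(dims) - 1)
  n_weights <- sum(dims[k] * dims[k + 1])
  n_biases <- sum(dims[k + 1])
  c(weights = n_weights, biases = n_biases, total = n_weights + n_biases)
}

count_bnn_params(p = 2, nodes = 2)           # weights 6, biases 3, total 9
count_bnn_params(p = 10, nodes = c(16, 16))  # weights 432, biases 33, total 465
```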
An object of class "bnns" containing the following components:
fit
The fitted Stan model object.
call
The matched call.
data
A list containing the Stan data used in the model.
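The return value is an ordinary S3 list tagged with class "bnns". The sketch below builds a stand-in object with the same three documented components, using a hypothetical constructor new_bnns_like() and a NULL placeholder instead of a real Stan fit, to show how match.call() records the call and how the components are accessed. The element names inside data are illustrative assumptions, not the package's actual Stan data names.

```r
# Hypothetical stand-in for the documented return structure (no Stan fit).
new_bnns_like <- function(train_x, train_y, L = 1, ...) {
  cl <- match.call()  # records the call with arguments matched by name
  stan_data <- list(n = nrow(train_x), m = ncol(train_x), L = L,  # names assumed
                    X = train_x, y = train_y)
  structure(list(fit = NULL, call = cl, data = stan_data), class = "bnns")
}

obj <- new_bnns_like(matrix(runif(20), nrow = 10), rnorm(10), L = 1)
inherits(obj, "bnns")  # TRUE
names(obj)             # "fit" "call" "data"
obj$call
```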
See also: stan (rstan package).
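# Example usage (':::' also reaches functions not exported by bnns):
set.seed(123)
train_x <- matrix(runif(20), nrow = 10, ncol = 2)
train_y <- rnorm(10)
# Very few iterations keep the example fast; real fits need far more
model <- bnns:::bnns_train(train_x, train_y,
  L = 1, nodes = 2, act_fn = 2,
  iter = 1e1, warmup = 5, chains = 1
)
# Access the fitted Stan model
model$fit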