#' Build and fit a neural network with random effects
#'
#' Validates the inputs, initializes the adjustable parameters (latent
#' variables, weights, biases and any adjustable parameters of the error
#' distribution), assembles them into a \code{mistnet_network} object and,
#' unless \code{fit = FALSE}, hands that object to \code{\link{mistnet_fit}}.
#'
#' @param x A numeric matrix of predictor variables
#' @param y A numeric or integer matrix of response variables. Must have the
#'    same number of rows as \code{x}.
#' @param n_z The number of latent random variables to include as predictors
#'    alongside x
#' @param layers A list of \code{\link{layer}} objects. Note that \code{n_nodes}
#'    in the final layer must match \code{ncol(y)}. A single \code{layer}
#'    object is accepted as well and is wrapped in a list automatically.
#' @param error_distribution A \code{\link{distribution}} object determining
#'    the error distribution for the response variables y. Its \code{mu}
#'    parameter must not be specified by the user; learning \code{mu} is the
#'    network's job.
#' @param z_prior A \code{\link{distribution}} object (standard Gaussian by
#'    default) determining the prior on the latent variables.
#' @param fit Logical. Should the model be fitted or should an untrained model
#'    be returned? Defaults to \code{TRUE}.
#' @param mistnet_optimizer passed to \code{\link{mistnet_fit}}. By default,
#'    models are fitted using \code{\link{mistnet_fit_optimx}} using
#'    \code{method = "L-BFGS-B"}.
#' @param ... Additional arguments to \code{\link{mistnet_fit}}
#' @return An object of class \code{network} and subclass \code{mistnet_network}.
#'    This object will contain the original \code{x} and \code{y} matrices,
#'    a list of adjustable parameters (\code{par_list}, with empty entries
#'    dropped), the layers' \code{activators} and \code{weight_priors}, the
#'    \code{z_prior}, the \code{error_distribution} (with its adjustable
#'    family parameters replaced by \code{NA}, since their values live in
#'    \code{par_list}), and \code{optimization_results} (initialized as an
#'    empty list).
#' @useDynLib mistnet2
#' @importFrom methods is
#' @importFrom optimx optimx
#' @importFrom assertthat assert_that is.scalar is.count are_equal noNA is.flag
#' @importFrom purrr every
#' @export
#' @examples
#' set.seed(1)
#'
#' # Load data from the `vegan` package
#' data(mite, mite.env, package = "vegan")
#'
#' # x is a matrix of environmental predictors
#' x = scale(model.matrix(~., data = mite.env)[, -1])
#'
#' # y is a matrix of abundances (counts) for 35 species of mites
#' y = as.matrix(mite)
#'
#' # Fit a neural network with one hidden layer of 10 nodes and an elu
#' # activation function. The response variable has a Poisson distribution
#' # with a log link (exp_activator). The prior distributions for each layer
#' # are each standard normal distributions, and two latent variables are used.
#' net = mistnet(
#'    x = x,
#'    y = y,
#'    n_z = 2,
#'    layers = list(
#'      layer(
#'        activator = elu_activator,
#'        n_nodes = 10,
#'        weight_prior = make_distribution("NO", mu = 0, sigma = 1)
#'      ),
#'      layer(
#'        activator = exp_activator,
#'        n_nodes = ncol(y),
#'        weight_prior = make_distribution("NO", mu = 0, sigma = 1)
#'      )
#'    ),
#'    error_distribution = make_distribution("PO")
#' )
#'
#' print(net)
#'
#' # show the model's predictions for each layer
#' str(feedforward(net, par = unlist(net$par_list)))
#'
#' # Calculate the log-likelihood for each observation under the fitted model
#' log_prob(net, par = unlist(net$par_list), include_penalties = FALSE)
#'
#' # Include penalty terms from the prior to calculate the log-posterior instead
#' log_prob(net, par = unlist(net$par_list), include_penalties = TRUE)
mistnet = function(
  x,
  y,
  n_z,
  layers,
  error_distribution,
  z_prior = make_distribution("NO", mu = 0, sigma = 1),
  fit = TRUE,
  mistnet_optimizer = mistnet_fit_optimx,
  ...
){
  if (is(layers, "layer")) {
    # Correct for easy mistake with single-layer networks
    layers = list(layers)
  }

  # Validate inputs up front so that failures point at the offending argument
  assert_that(is.matrix(x), is.numeric(x), noNA(x))
  assert_that(is.matrix(y), is.numeric(y), noNA(y))
  assert_that(are_equal(nrow(x), nrow(y)))
  # An empty `layers` would otherwise fail later at `layers[[0]]` with an
  # unhelpful subscript error
  assert_that(is.list(layers), length(layers) > 0)
  assert_that(every(layers, is, "layer"))
  assert_that(is.count(n_z))
  assert_that(are_equal(layers[[length(layers)]]$n_nodes, ncol(y)))
  assert_that(is(z_prior, "distribution"))
  assert_that(is(error_distribution, "distribution"))
  assert_that(is.flag(fit))

  # `family_parameters` is spelled out in full rather than relying on `$`
  # partial matching. A missing (zero-length) mu counts as "not specified"
  # instead of crashing the `if` with a length-zero condition.
  user_mu = error_distribution$family_parameters$mu
  if (length(user_mu) > 0 && !all(is.na(user_mu))) {
    stop("user cannot specify mu for the error_distribution; learning mu is the network's job")
  }

  activators = lapply(layers, function(l) l$activator)
  n = nrow(x)

  # Initial values for everything the optimizer is allowed to adjust
  par_list = list(
    z = matrix(draw_samples(z_prior, n = n * n_z), nrow = n, ncol = n_z),
    weights = make_weight_list(n_x = ncol(x), n_z = n_z, layers = layers),
    biases = make_bias_list(
      layers = layers,
      error_distribution = error_distribution,
      activators = activators,
      y = y
    ),
    error_distribution_par = get_adjustables(error_distribution)
  )

  # If a parameter is adjustable, replace it with NA. Its value is stored in
  # the par_list, not in the distribution.
  error_distribution$family_parameters = lapply(
    error_distribution$family_parameters,
    function(family_parameter){
      if (is(family_parameter, "adjustable")) {
        NA
      }else{
        family_parameter
      }
    }
  )

  assert_that(is.numeric(unlist(par_list)), noNA(unlist(par_list)))

  network = list(
    x = x,
    y = y,
    # Drop empty entries (e.g. when get_adjustables() returned nothing)
    par_list = purrr::compact(par_list),
    activators = activators,
    weight_priors = lapply(layers, function(l) l$weight_prior),
    z_prior = z_prior,
    error_distribution = error_distribution,
    optimization_results = list()
  )
  class(network) = c("mistnet_network", "network")

  if (fit) {
    # Forward the optimizer explicitly: as a named formal of this function it
    # is never part of `...`, so it was previously dropped without notice.
    network = mistnet_fit(network, mistnet_optimizer = mistnet_optimizer, ...)
  }

  network
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.