build_pytorch_net: Build a PyTorch Multilayer Perceptron


View source: R/build_pytorch_net.R

Description

Utility function to build an MLP with a choice of activation function and weight initialization, with optional dropout and batch normalization.

Usage

build_pytorch_net(
  n_in,
  n_out,
  nodes = c(32, 32),
  activation = "relu",
  act_pars = list(),
  dropout = 0.1,
  bias = TRUE,
  batch_norm = TRUE,
  batch_pars = list(eps = 1e-05, momentum = 0.1, affine = TRUE),
  init = "uniform",
  init_pars = list()
)

Arguments

n_in

(integer(1))
Number of input features.

n_out

(integer(1))
Number of targets.

nodes

(numeric())
Hidden nodes in the network; each element of the vector gives the number of hidden nodes in the respective layer.

activation

(character(1)|list())
Activation function; either a single character, in which case the same function is used in all layers, or a list of length length(nodes) giving one activation per layer. See get_activation for options.

act_pars

(list())
Passed to get_activation.

dropout

(numeric())
Optional dropout. If NULL, no dropout layers are added; otherwise either a single numeric applied to all layers, or a vector giving a different dropout rate for each layer.

bias

(logical(1))
If TRUE (default) then a bias parameter is added to all linear layers.

batch_norm

(logical(1))
If TRUE (default) then batch normalisation is applied to all layers.

batch_pars

(list())
Parameters for batch normalisation, see reticulate::py_help(torch$nn$BatchNorm1d).

init

(character(1))
Weight initialization method. See get_init for options.

init_pars

(list())
Passed to get_init.

Details

This function is a helper for R users with limited Python experience. It is currently restricted to simple MLPs; more advanced networks require manual creation with reticulate, as in the sketch below.
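
A minimal sketch of such manual construction (this assumes the Python package torch is installed and visible to reticulate; the layer sizes are purely illustrative):

torch <- reticulate::import("torch")
nn <- torch$nn

# Hand-built equivalent of a small MLP: 10 inputs -> 32 hidden units -> 1 output
net <- nn$Sequential(
  nn$Linear(10L, 32L),
  nn$ReLU(),
  nn$Linear(32L, 1L)
)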

Examples

build_pytorch_net(10, 1)

build_pytorch_net(
  n_in = 10, n_out = 1, nodes = c(4, 4, 4), activation = "elu",
  act_pars = list(alpha = 0.5), dropout = c(0.2, 0.1, 0.6),
  batch_norm = TRUE, init = "kaiming_normal", init_pars = list(non_linearity = "relu")
)
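
If, as the Details suggest, the network is built through reticulate, the result can be inspected with reticulate helpers. A sketch, assuming torch is available in the active Python environment:

net <- build_pytorch_net(n_in = 10, n_out = 1)
reticulate::py_str(net)  # show the layer structure of the underlying torch module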
