nn_module | R Documentation |
Base class for all neural network modules. Your models should also subclass this class.
nn_module(
  classname = NULL,
  inherit = nn_Module,
  ...,
  private = NULL,
  active = NULL,
  parent_env = parent.frame()
)
classname | an optional name for the module
inherit | an optional module to inherit from
... | methods implementation
private | passed to
active | passed to
parent_env | passed to
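As a rough sketch of how `classname` and `inherit` might be used together, the block below derives one module from another. The names `base` and `child` and the layer sizes are hypothetical, and the sketch assumes that a module generator created with `nn_module()` can be passed to `inherit` and that `super$` resolves to the parent's methods, as in R6 inheritance.

```r
library(torch)

if (torch_is_installed()) {
  # Hypothetical parent module: a single linear layer.
  base <- nn_module(
    classname = "base",
    initialize = function(n) {
      self$fc <- nn_linear(n, n)
    },
    forward = function(x) {
      self$fc(x)
    }
  )

  # Hypothetical child module: reuses the parent's initialize and
  # applies a ReLU on top of the parent's forward.
  child <- nn_module(
    classname = "child",
    inherit = base,
    forward = function(x) {
      nnf_relu(super$forward(x))
    }
  )

  m <- child(3)
  out <- m(torch_randn(2, 3))
  print(class(m))
}
```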
Modules can also contain other modules, which lets you nest them in a tree structure. You can assign the submodules as regular attributes.
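A minimal sketch of nesting: a small module is defined once and then assigned twice as a submodule of a larger one. The names `block` and `net` and the layer sizes are illustrative assumptions, not part of this documentation.

```r
library(torch)

if (torch_is_installed()) {
  # Hypothetical building block: a linear layer followed by a ReLU.
  block <- nn_module(
    classname = "block",
    initialize = function(in_features, out_features) {
      self$fc <- nn_linear(in_features, out_features)
    },
    forward = function(x) {
      nnf_relu(self$fc(x))
    }
  )

  # Assigning `block` instances as regular attributes registers them
  # as submodules of `net`.
  net <- nn_module(
    classname = "net",
    initialize = function() {
      self$block1 <- block(10, 20)
      self$block2 <- block(20, 5)
    },
    forward = function(x) {
      self$block2(self$block1(x))
    }
  )

  model <- net()
  # Parameters of the nested submodules are collected recursively.
  print(names(model$parameters))
  out <- model(torch_randn(3, 10))
  print(out$shape)
}
```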
You are expected to implement the initialize and forward methods to create a new nn_module.
The initialize function is called whenever a new instance of the nn_module is created. Use the initialize function to define the submodules and parameters of the module. For example:
initialize = function(input_size, output_size) {
  self$conv1 <- nn_conv2d(input_size, output_size, 5)
  self$conv2 <- nn_conv2d(output_size, output_size, 5)
}
The initialize function can have any number of parameters. All objects assigned to self$ will be available for other methods that you implement.
Tensors wrapped with nn_parameter() or nn_buffer() and submodules are automatically tracked when assigned to self$.
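The tracking described above can be sketched as follows. The module name `scale_shift` and its fields are hypothetical; the sketch assumes that the `parameters` and `buffers` fields of a module instance list what has been tracked.

```r
library(torch)

if (torch_is_installed()) {
  # Hypothetical module used only to illustrate tracking.
  scale_shift <- nn_module(
    initialize = function(n) {
      # Wrapped with nn_parameter(): tracked as a learnable parameter.
      self$scale <- nn_parameter(torch_ones(n))
      # Wrapped with nn_buffer(): tracked as module state, not a parameter.
      self$offset <- nn_buffer(torch_zeros(n))
      # A plain R value: available to other methods, but not tracked.
      self$n <- n
    },
    forward = function(x) {
      x * self$scale + self$offset
    }
  )

  m <- scale_shift(4)
  print(names(m$parameters))
  print(names(m$buffers))
}
```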
The initialize function is optional if the module you are defining doesn't have weights, submodules or buffers.
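For instance, a stateless module can define only forward. The module name `square` below is a hypothetical example, not something provided by the package.

```r
library(torch)

if (torch_is_installed()) {
  # Hypothetical stateless module: no initialize is defined because
  # there are no weights, submodules or buffers to set up.
  square <- nn_module(
    forward = function(x) {
      x^2
    }
  )

  m <- square()
  y <- m(torch_tensor(c(1, 2, 3)))
  print(y)
}
```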
The forward method is called whenever an instance of nn_module is called. It usually implements the computation that the module performs with the weights and submodules defined in the initialize function.
For example:
forward = function(input) {
  input <- self$conv1(input)
  input <- nnf_relu(input)
  input <- self$conv2(input)
  input <- nnf_relu(input)
  input
}
The forward function can use the self$training attribute to perform different computations depending on whether the model is training or not, for example if you were implementing a dropout module.
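A simplified sketch of a dropout-style module that branches on self$training is shown below. It is an illustration only, not the package's own dropout implementation; the name `simple_dropout` and the field `p` are hypothetical. It assumes that the `train()` and `eval()` methods of a module instance toggle self$training.

```r
library(torch)

if (torch_is_installed()) {
  # Hypothetical dropout-like module: zeroes elements with probability p
  # during training and is the identity in evaluation mode.
  simple_dropout <- nn_module(
    initialize = function(p = 0.5) {
      self$p <- p
    },
    forward = function(x) {
      if (!self$training) {
        return(x)
      }
      # Random 0/1 mask; kept elements are rescaled by 1 / (1 - p).
      mask <- (torch_rand_like(x) > self$p)$to(dtype = x$dtype)
      x * mask / (1 - self$p)
    }
  )

  m <- simple_dropout(p = 0.5)
  x <- torch_ones(8)

  m$eval()           # evaluation mode: input is returned unchanged
  y_eval <- m(x)

  m$train()          # training mode: some elements are zeroed
  y_train <- m(x)
}
```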
if (torch_is_installed()) {
  model <- nn_module(
    initialize = function() {
      self$conv1 <- nn_conv2d(1, 20, 5)
      self$conv2 <- nn_conv2d(20, 20, 5)
    },
    forward = function(input) {
      input <- self$conv1(input)
      input <- nnf_relu(input)
      input <- self$conv2(input)
      input <- nnf_relu(input)
      input
    }
  )
}
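Note that nn_module() returns a module generator, so it must be called to create an instance before it can be applied to data. The sketch below repeats the definition above and then instantiates and calls it; the variable names `net`, `input` and `output` and the 28 x 28 input size are assumptions chosen for illustration.

```r
library(torch)

if (torch_is_installed()) {
  model <- nn_module(
    initialize = function() {
      self$conv1 <- nn_conv2d(1, 20, 5)
      self$conv2 <- nn_conv2d(20, 20, 5)
    },
    forward = function(input) {
      input <- self$conv1(input)
      input <- nnf_relu(input)
      input <- self$conv2(input)
      input <- nnf_relu(input)
      input
    }
  )

  # Calling the generator runs initialize and returns an instance.
  net <- model()

  # Calling the instance runs forward. The input is a batch of one
  # single-channel 28 x 28 image (an arbitrary size for this sketch).
  input <- torch_randn(1, 1, 28, 28)
  output <- net(input)
  print(output$shape)
}
```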