mlr_learners.module (R Documentation)
Create a torch learner from a torch module.
This Learner can be instantiated using the sugar function lrn():
lrn("classif.module", ...)
lrn("regr.module", ...)
Supported task types: 'classif', 'regr'
Predict Types:
classif: 'response', 'prob'
regr: 'response'
Feature Types: “logical”, “integer”, “numeric”, “character”, “factor”, “ordered”, “POSIXct”, “Date”, “lazy_tensor”
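A minimal regression sketch, assuming the mlr3, mlr3torch, and torch packages are installed. It passes a plain function instead of an nn_module_generator as module_generator. The name make_regr_net and its size_hidden argument are hypothetical; size_hidden is expected to become a learner parameter because the parameter set is inferred from the generator when param_set is NULL. Since torch's nn_sequential() has a forward() method with a single argument named input, the ingress token is named input here.

```r
library(mlr3)
library(mlr3torch)
library(torch)

# Hypothetical generator: a plain function that receives the task and
# returns an nn_module. `size_hidden` should become a learner parameter.
make_regr_net = function(task, size_hidden) {
  nn_sequential(
    nn_linear(task$n_features, size_hidden),
    nn_relu(),
    nn_linear(size_hidden, output_dim_for(task))
  )
}

task = tsk("mtcars")
split = partition(task, ratio = 0.7)

learner = lrn("regr.module",
  module_generator = make_regr_net,
  # nn_sequential's forward() has one argument named `input`,
  # so the ingress token must carry that name
  ingress_tokens = list(input = ingress_num()),
  epochs = 10,
  batch_size = 8,
  size_hidden = 16
)

learner$train(task, row_ids = split$train)

# Regression learners only support the 'response' predict type
prediction = learner$predict(task, row_ids = split$test)
prediction$score(msr("regr.mse"))
```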
Super classes:
mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchModule
Inherited methods:
mlr3::Learner$base_learner()
mlr3::Learner$configure()
mlr3::Learner$encapsulate()
mlr3::Learner$help()
mlr3::Learner$predict()
mlr3::Learner$predict_newdata()
mlr3::Learner$reset()
mlr3::Learner$selected_features()
mlr3::Learner$train()
mlr3torch::LearnerTorch$dataset()
mlr3torch::LearnerTorch$format()
mlr3torch::LearnerTorch$marshal()
mlr3torch::LearnerTorch$print()
mlr3torch::LearnerTorch$unmarshal()
Method new()
Creates a new instance of this R6 class.
Usage:
LearnerTorchModule$new(
  module_generator = NULL,
  param_set = NULL,
  ingress_tokens = NULL,
  task_type,
  properties = NULL,
  optimizer = NULL,
  loss = NULL,
  callbacks = list(),
  packages = character(0),
  feature_types = NULL
)
Arguments:
module_generator (function or nn_module_generator)
An nn_module_generator or a function returning an nn_module. Both must take as an argument the task for which to construct the network. Other arguments to its initialize method can be provided as parameters.
param_set (NULL or ParamSet)
If provided, contains the parameters for the module_generator. If NULL, the parameters are inferred from the module_generator.
ingress_tokens (list of TorchIngressToken())
A list of ingress tokens that defines how the dataset is constructed. The names must correspond to the arguments of the network's forward method. For numeric, categorical, and lazy tensor features, you can create them with ingress_num(), ingress_categ(), and ingress_ltnsr(), respectively.
task_type (character(1))
The task type, either "classif" or "regr".
properties (NULL or character())
The properties of the learner. Defaults to all available properties for the given task type.
optimizer (TorchOptimizer)
The optimizer to use for training. By default, adam is used.
loss (TorchLoss)
The loss used to train the network. By default, mse is used for regression and cross_entropy for classification.
callbacks (list() of TorchCallbacks)
The callbacks. Must have unique ids.
packages (character())
The R packages this object depends on.
feature_types (NULL or character())
The feature types. Defaults to all available feature types.
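The sketch below, assuming mlr3, mlr3torch, and torch are installed, calls the constructor directly instead of going through lrn() and overrides the optimizer, loss, and callbacks arguments with the mlr3torch helpers t_opt(), t_loss(), and t_clbk(). The module name nn_two_layer, the learning rate of 0.01, and the choice of the history callback are illustrative assumptions. Because the constructor does not take hyperparameters, epochs, batch_size, and the module's own size_hidden are set on the learner's param_set afterwards.

```r
library(mlr3)
library(mlr3torch)
library(torch)

# Hypothetical module; initialize() receives the task plus size_hidden
nn_two_layer = nn_module("nn_two_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, output_dim_for(task))
  },
  # `x` matches the name of the ingress token below
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

learner = LearnerTorchModule$new(
  module_generator = nn_two_layer,
  ingress_tokens = list(x = ingress_num()),
  task_type = "classif",
  optimizer = t_opt("adam", lr = 0.01),   # replaces the default adam settings
  loss = t_loss("cross_entropy"),         # same as the classification default
  callbacks = list(t_clbk("history"))     # callback ids must be unique
)

# Training hyperparameters and module parameters live in the param_set
learner$param_set$set_values(epochs = 10, batch_size = 16, size_hidden = 20)
learner$predict_type = "prob"

task = tsk("iris")
learner$train(task)
prediction = learner$predict(task)
head(prediction$prob)
```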
Method clone()
The objects of this class are cloneable with this method.
Usage:
LearnerTorchModule$clone(deep = FALSE)
Arguments:
deep
Whether to make a deep clone.
See also:
Other Learner: mlr_learners.mlp, mlr_learners.tab_resnet, mlr_learners.torch_featureless, mlr_learners_torch, mlr_learners_torch_image, mlr_learners_torch_model
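Examples:
library(mlr3torch)

# initialize() receives the task plus any extra arguments (here size_hidden)
nn_one_layer = nn_module("nn_one_layer",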
initialize = function(task, size_hidden) {
self$first = nn_linear(task$n_features, size_hidden)
self$second = nn_linear(size_hidden, output_dim_for(task))
},
# argument x corresponds to the ingress token x
forward = function(x) {
x = self$first(x)
x = nnf_relu(x)
self$second(x)
}
)
learner = lrn("classif.module",
module_generator = nn_one_layer,
ingress_tokens = list(x = ingress_num()),
epochs = 10,
size_hidden = 20,
batch_size = 16
)
task = tsk("iris")
learner$train(task)
learner$network