mlr_learners.module: Learner Torch Module

mlr_learners.module    R Documentation

Learner Torch Module

Description

Create a torch learner from a torch module.

Dictionary

This Learner can be instantiated using the sugar function lrn():

lrn("classif.module", ...)
lrn("regr.module", ...)

Properties

  • Supported task types: 'classif', 'regr'

  • Predict Types:

    • classif: 'response', 'prob'

    • regr: 'response'

  • Feature Types: “logical”, “integer”, “numeric”, “character”, “factor”, “ordered”, “POSIXct”, “Date”, “lazy_tensor”

  • Required Packages: mlr3, mlr3torch, torch
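The sketch below illustrates the 'prob' predict type listed above. It reuses the one-layer network from the Examples section and requests class probabilities through the standard mlr3 predict_type field. The small number of epochs is an arbitrary choice to keep the run short.

```r
library(mlr3)
library(torch)
library(mlr3torch)

# One hidden layer; 'size_hidden' becomes a parameter of the learner
nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, output_dim_for(task))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  predict_type = "prob",  # request class probabilities, not only labels
  epochs = 2,
  batch_size = 16,
  size_hidden = 20
)
task = tsk("iris")
learner$train(task)
prediction = learner$predict(task)
head(prediction$prob)  # one column per class level
```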

Super classes

mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchModule

Methods

Public methods

  • LearnerTorchModule$new()

  • LearnerTorchModule$clone()

Method new()

Creates a new instance of this R6 class.

Usage
LearnerTorchModule$new(
  module_generator = NULL,
  param_set = NULL,
  ingress_tokens = NULL,
  task_type,
  properties = NULL,
  optimizer = NULL,
  loss = NULL,
  callbacks = list(),
  packages = character(0),
  feature_types = NULL
)
Arguments
module_generator

(function or nn_module_generator)
An nn_module_generator or a function returning an nn_module. Either one must accept, as its task argument, the task for which the network is constructed. Additional arguments (of the generator's initialize method or of the function) can be provided as parameters of the learner.
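A plain function can stand in for an nn_module generator. In this sketch, make_mlp is a hypothetical helper. The example assumes that size_hidden is inferred as a learner parameter from the function's formals, since no param_set is given. Because torch's nn_sequential() defines its forward method as forward(input), the ingress token is named input.

```r
library(mlr3)
library(torch)
library(mlr3torch)

# Hypothetical helper: a plain function (not an nn_module generator) that
# receives the task and returns a ready-made nn_module
make_mlp = function(task, size_hidden) {
  nn_sequential(
    nn_linear(task$n_features, size_hidden),
    nn_relu(),
    nn_linear(size_hidden, output_dim_for(task))
  )
}

learner = lrn("classif.module",
  module_generator = make_mlp,
  # nn_sequential's forward() takes one argument named 'input',
  # so the ingress token must carry that name
  ingress_tokens = list(input = ingress_num()),
  epochs = 2,
  batch_size = 16,
  size_hidden = 10  # assumed to be inferred from make_mlp's formals
)
learner$train(tsk("iris"))
learner$network
```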

param_set

(NULL or ParamSet)
If provided, contains the parameters for the module_generator. If NULL, parameters will be inferred from the module_generator.
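Instead of relying on inference, a ParamSet can be passed explicitly, for example to restrict size_hidden to positive integers. This sketch builds it with paradox's ps() and p_int(). Tagging the parameter with "train" is an assumption about how the learner forwards parameter values to the module generator; this page does not state it.

```r
library(mlr3)
library(paradox)
library(torch)
library(mlr3torch)

nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, output_dim_for(task))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

# Explicit parameter set for the module generator's extra arguments.
# The "train" tag is an assumption, not taken from this documentation.
module_params = ps(
  size_hidden = p_int(lower = 1L, tags = "train")
)

learner = lrn("classif.module",
  module_generator = nn_one_layer,
  param_set = module_params,
  ingress_tokens = list(x = ingress_num()),
  epochs = 2,
  batch_size = 16,
  size_hidden = 20
)
learner$train(tsk("iris"))
```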

ingress_tokens

(list of TorchIngressToken())
A named list of ingress tokens that defines how the dataset is constructed from the task. The names must correspond to the arguments of the network's forward method. For numeric, categorical, and lazy tensor features, ingress_num(), ingress_categ(), and ingress_ltnsr() can be used to create them.

task_type

(character(1))
The task type, either "classif" or "regr".

properties

(NULL or character())
The properties of the learner. Defaults to all available properties for the given task type.

optimizer

(TorchOptimizer)
The optimizer to use for training. By default, adam is used.

loss

(TorchLoss)
The loss used to train the network. By default, mse is used for regression and cross_entropy for classification.
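The default optimizer and loss can be replaced with objects created by t_opt() and t_loss(). In this sketch, SGD with a learning rate of 0.1 is an arbitrary choice. The example assumes that optimizer settings are exposed in the learner's parameter set under the prefix opt., so they can still be changed after construction.

```r
library(mlr3)
library(torch)
library(mlr3torch)

nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, output_dim_for(task))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  optimizer = t_opt("sgd", lr = 0.1),     # replaces the default adam
  loss = t_loss("cross_entropy"),         # same as the classification default
  epochs = 2,
  batch_size = 16,
  size_hidden = 20
)
# Optimizer settings are assumed to live under the 'opt.' prefix
learner$param_set$values$opt.lr
learner$train(tsk("iris"))
```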

callbacks

(list() of TorchCallbacks)
The callbacks to use during training. Their ids must be unique.
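Callbacks are passed as a list of TorchCallback objects. As a sketch, the history callback created by t_clbk("history") records per-epoch information. Logging training accuracy through the measures_train parameter and reading the result from learner$model$callbacks$history are assumptions about the trained model's structure.

```r
library(mlr3)
library(torch)
library(mlr3torch)

nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, output_dim_for(task))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  callbacks = list(t_clbk("history")),  # each callback id must be unique
  measures_train = msrs("classif.acc"), # scores recorded after each epoch
  epochs = 3,
  batch_size = 16,
  size_hidden = 20
)
learner$train(tsk("iris"))

# Per-epoch log stored by the history callback (assumed location)
history = learner$model$callbacks$history
history
```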

packages

(character())
The R packages this object depends on.

feature_types

(NULL or character())
The feature types. Defaults to all available feature types.
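When the network only consumes numeric inputs, the accepted feature types can be narrowed so that mlr3 rejects tasks containing other column types before training starts. This sketch assumes that ingress_num() covers both numeric and integer columns.

```r
library(mlr3)
library(torch)
library(mlr3torch)

nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, output_dim_for(task))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  # Only the types assumed to be handled by ingress_num()
  feature_types = c("numeric", "integer"),
  epochs = 2,
  batch_size = 16,
  size_hidden = 20
)
learner$feature_types
learner$train(tsk("iris"))  # iris has only numeric features
```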


Method clone()

The objects of this class are cloneable with this method.

Usage
LearnerTorchModule$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.
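A deep clone copies the learner's parameter set, so changing a setting on the copy leaves the original untouched. The sketch below does not train either learner; it only compares the epochs setting.

```r
library(mlr3)
library(torch)
library(mlr3torch)

nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, output_dim_for(task))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  epochs = 2,
  batch_size = 16,
  size_hidden = 20
)

# Independent copy: modifying its parameters does not affect 'learner'
learner2 = learner$clone(deep = TRUE)
learner2$param_set$values$epochs = 5
learner$param_set$values$epochs   # still 2
learner2$param_set$values$epochs  # 5
```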

See Also

Other Learner: mlr_learners.mlp, mlr_learners.tab_resnet, mlr_learners.torch_featureless, mlr_learners_torch, mlr_learners_torch_image, mlr_learners_torch_model

Examples


library(mlr3)
library(torch)
library(mlr3torch)

nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    self$second = nn_linear(size_hidden, output_dim_for(task))
  },
  # argument x corresponds to the ingress token x
  forward = function(x) {
    x = self$first(x)
    x = nnf_relu(x)
    self$second(x)
  }
)
learner = lrn("classif.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  epochs = 10,
  size_hidden = 20,
  batch_size = 16
)
task = tsk("iris")
learner$train(task)
learner$network
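The same module generator can also be used for regression. This sketch assumes that output_dim_for() returns 1 for a regression task, so the network has a single output unit. It swaps in lrn("regr.module") and the mtcars task, then scores the training-set predictions with msr("regr.mse") under the default mse loss. The settings are not tuned, so the score only demonstrates the workflow.

```r
library(mlr3)
library(torch)
library(mlr3torch)

nn_one_layer = nn_module("nn_one_layer",
  initialize = function(task, size_hidden) {
    self$first = nn_linear(task$n_features, size_hidden)
    # output_dim_for() is assumed to yield 1 for regression
    self$second = nn_linear(size_hidden, output_dim_for(task))
  },
  forward = function(x) {
    self$second(nnf_relu(self$first(x)))
  }
)

learner = lrn("regr.module",
  module_generator = nn_one_layer,
  ingress_tokens = list(x = ingress_num()),
  epochs = 10,
  batch_size = 16,
  size_hidden = 20
)
task = tsk("mtcars")
learner$train(task)
prediction = learner$predict(task)
prediction$score(msr("regr.mse"))
```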


mlr-org/mlr3torch documentation built on April 17, 2025, 8:22 p.m.