task_dataset: Create a Dataset from a Task

task_dataset    R Documentation

Create a Dataset from a Task

Description

Creates a torch dataset from an mlr3 Task. The resulting dataset's $.getbatch() method returns a list with elements x, y and .index:

  • x is a list of tensors whose content is defined by the parameter feature_ingress_tokens.

  • y is the target variable; its content is defined by the parameter target_batchgetter.

  • .index contains the indices of the batch's rows in the task's data.

The data is returned on the device specified by the parameter device.

Usage

task_dataset(task, feature_ingress_tokens, target_batchgetter = NULL)

Arguments

task

(Task)
The task for which to build the dataset.

feature_ingress_tokens

(named list() of TorchIngressToken)
Each ingress token defines one element of a batch's $x value; the elements of $x carry the names of this list.
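
The correspondence between token names and the names in $x can be sketched as follows. This is a minimal illustration, assuming the iris task; the token names length and width are arbitrary choices made for this sketch, and no target batchgetter is supplied.

```r
library(mlr3torch)

task = tsk("iris")

# Two hypothetical tokens: one groups the length features, one the width features.
tokens = list(
  length = TorchIngressToken(
    features = c("Sepal.Length", "Petal.Length"),
    batchgetter = batchgetter_num,
    shape = c(NA, 2)
  ),
  width = TorchIngressToken(
    features = c("Sepal.Width", "Petal.Width"),
    batchgetter = batchgetter_num,
    shape = c(NA, 2)
  )
)

dataset = task_dataset(task, tokens)
batch = dataset$.getbatch(1:5)

# The names of the token list reappear as the names of batch$x.
names(batch$x)
batch$x$length$shape
batch$x$width$shape
```

Each element of batch$x holds one row per requested index and one column per feature listed in the corresponding token.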

target_batchgetter

(function(data, device))
A function taking in arguments data, which is a data.table containing only the target variable, and device. It must return the target as a torch tensor on the selected device.
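
A function of this kind for a classification target might look like the sketch below. The name classif_target_batchgetter is hypothetical, and giving device a default of "cpu" is an assumption made here so that the function can also be called with data alone; neither is part of mlr3torch.

```r
library(torch)
library(data.table)

# Hypothetical target batchgetter for a factor target (illustration only).
# `data` is a data.table holding just the target column; `device` defaults
# to "cpu" (an assumption) so the function also works without a device.
classif_target_batchgetter = function(data, device = "cpu") {
  # as.integer() maps the factor levels to class codes 1, 2, ..., n_levels
  codes = as.integer(data[[1L]])
  torch_tensor(codes, dtype = torch_long(), device = device)
}

# Try it on a small stand-in for a task's target column.
target = data.table(Species = factor(c("a", "b", "a"), levels = c("a", "b")))
y = classif_target_batchgetter(target)
y$dtype
y$shape
```

Returning integer class codes as a long tensor is one common convention for classification losses; a regression target would typically be returned as a float tensor instead.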

Value

torch::dataset

Examples


library(mlr3torch)

task = tsk("iris")
sepal_ingress = TorchIngressToken(
  features = c("Sepal.Length", "Sepal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
petal_ingress = TorchIngressToken(
  features = c("Petal.Length", "Petal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
ingress_tokens = list(sepal = sepal_ingress, petal = petal_ingress)

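# Convert the target column to a float32 tensor of shape (batch size, 1)
target_batchgetter = function(data) {
  torch_tensor(data = data[[1L]], dtype = torch_float32())$unsqueeze(2)
}
dataset = task_dataset(task, ingress_tokens, target_batchgetter)
# Retrieve the first ten rows of the task as a single batch
batch = dataset$.getbatch(1:10)
batch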
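
Because the return value is a torch::dataset, it can presumably be passed to a torch dataloader to draw mini-batches. The following is a sketch under that assumption, not a documented mlr3torch workflow; the token name features, the batch size of 50 and the long-tensor target encoding are illustrative choices.

```r
library(mlr3torch)

task = tsk("iris")

# One ingress token covering all four numeric features (the name is arbitrary).
ingress_tokens = list(
  features = TorchIngressToken(
    features = task$feature_names,
    batchgetter = batchgetter_num,
    shape = c(NA, 4)
  )
)

# Encode the factor target as integer class codes in a long tensor.
target_batchgetter = function(data) {
  torch_tensor(as.integer(data[[1L]]), dtype = torch_long())
}

dataset = task_dataset(task, ingress_tokens, target_batchgetter)

# Wrap the dataset in a dataloader and draw one shuffled mini-batch.
loader = torch::dataloader(dataset, batch_size = 50, shuffle = TRUE)
iter = torch::dataloader_make_iter(loader)
batch = torch::dataloader_next(iter)

batch$x$features$shape
batch$y$shape
```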


mlr3torch documentation built on April 4, 2025, 3:03 a.m.