task_dataset | R Documentation |
Creates a torch dataset from an mlr3 Task.
The resulting dataset's $.getbatch() method returns a list with elements x, y and .index:
x is a list of tensors whose content is defined by the parameter feature_ingress_tokens.
y is the target variable; its content is defined by the parameter target_batchgetter.
.index is the index of the batch in the task's data.
The data is returned on the device specified by the parameter device.
task_dataset(task, feature_ingress_tokens, target_batchgetter = NULL)
task |
( |
feature_ingress_tokens |
(named |
target_batchgetter |
( |
torch::dataset
library(mlr3)
library(mlr3torch)
library(torch)

task = tsk("iris")

# One ingress token per feature group; each yields a tensor with two columns
sepal_ingress = TorchIngressToken(
  features = c("Sepal.Length", "Sepal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
petal_ingress = TorchIngressToken(
  features = c("Petal.Length", "Petal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
ingress_tokens = list(sepal = sepal_ingress, petal = petal_ingress)

# Converts the target column into a float tensor with a trailing dimension of 1
target_batchgetter = function(data) {
  torch_tensor(data = data[[1L]], dtype = torch_float32())$unsqueeze(2)
}

dataset = task_dataset(task, ingress_tokens, target_batchgetter)

# Retrieve the first ten rows as a single batch
batch = dataset$.getbatch(1:10)
batch
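The batch structure described above can be inspected element by element. The following is a minimal sketch, not part of the original documentation. It assumes the mtcars task shipped with mlr3, which is used here only because its target is numeric, and it assumes that the names given to the ingress token list carry over to the x element. The token name num is a hypothetical choice.

```r
library(mlr3)
library(mlr3torch)
library(torch)

# Regression task with a numeric target (mpg), chosen only for illustration
task = tsk("mtcars")

# A single ingress token covering two numeric features
num_ingress = TorchIngressToken(
  features = c("cyl", "disp"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)

# Converts the target column into a float tensor of shape (batch, 1)
target_batchgetter = function(data) {
  torch_tensor(data[[1L]], dtype = torch_float32())$unsqueeze(2)
}

dataset = task_dataset(task, list(num = num_ingress), target_batchgetter)
batch = dataset$.getbatch(1:4)

names(batch)        # top-level elements of the batch
names(batch$x)      # entries of the feature list
batch$x$num$shape   # dimensions of the feature tensor
batch$y$shape       # dimensions of the target tensor
```

Printing the shapes rather than the tensors keeps the output short and makes it easy to confirm that each ingress token produces a tensor matching its declared shape.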