mlr_learners.torchvision | R Documentation
Classic image classification networks from torchvision.
Parameters: all parameters inherited from LearnerTorchImage, plus:
pretrained :: logical(1)
Whether to use the pretrained model. The final linear layer is replaced with a new nn_linear whose number of outputs equals the number of classes inferred from the Task.
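A minimal sketch of setting pretrained at construction, alongside the inherited training parameters epochs and batch_size. It assumes that the ResNet-18 variant of this learner family is registered in the mlr3 dictionary as "classif.resnet18"; that id and the parameter values are illustrative assumptions, not taken from this page.

```r
library(mlr3)
library(mlr3torch)

# Assumption: "classif.resnet18" is one of the ids this learner family registers.
learner = lrn("classif.resnet18",
  pretrained = TRUE, # start from pretrained weights; the final nn_linear is replaced to match the task's classes
  epochs = 1,        # inherited from LearnerTorch; illustrative value
  batch_size = 32    # inherited from LearnerTorch; illustrative value
)

# pretrained can be changed afterwards like any other hyperparameter
learner$param_set$values$pretrained = FALSE
print(learner$param_set$values$pretrained)
```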
Supported task types: "classif"
Predict Types: "response" and "prob"
Feature Types: "lazy_tensor"
Required packages: "mlr3torch", "torch", "torchvision"
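These properties can be inspected on a constructed learner through the standard mlr3::Learner fields. A minimal sketch, assuming the AlexNet variant is registered as "classif.alexnet" (an assumed id, not stated on this page):

```r
library(mlr3)
library(mlr3torch)

# Assumption: the AlexNet variant is registered as "classif.alexnet".
learner = lrn("classif.alexnet")

# Fields inherited from mlr3::Learner that expose the documented properties
print(learner$task_type)
print(learner$predict_types)
print(learner$feature_types)
print(learner$packages)

# Request class probabilities instead of hard labels
learner$predict_type = "prob"
```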
Super classes:
mlr3::Learner -> mlr3torch::LearnerTorch -> mlr3torch::LearnerTorchImage -> LearnerTorchVision
Inherited methods:
mlr3::Learner$base_learner()
mlr3::Learner$configure()
mlr3::Learner$encapsulate()
mlr3::Learner$help()
mlr3::Learner$predict()
mlr3::Learner$predict_newdata()
mlr3::Learner$reset()
mlr3::Learner$selected_features()
mlr3::Learner$train()
mlr3torch::LearnerTorch$dataset()
mlr3torch::LearnerTorch$format()
mlr3torch::LearnerTorch$marshal()
mlr3torch::LearnerTorch$print()
mlr3torch::LearnerTorch$unmarshal()
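Because the methods listed above come from the super classes, an instance answers to each class in the chain. A small sketch that checks this, again assuming the id "classif.resnet18":

```r
library(mlr3)
library(mlr3torch)

# Assumption: "classif.resnet18" is registered for this learner family.
learner = lrn("classif.resnet18")

# The R6 class vector contains LearnerTorchVision and its super classes
print(class(learner))

# Each class in the documented chain is recognised by inherits()
supers = c("LearnerTorchVision", "LearnerTorchImage", "LearnerTorch", "Learner")
print(vapply(supers, function(cl) inherits(learner, cl), logical(1)))
```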
Method new():
Creates a new instance of this R6 class.
Usage:
LearnerTorchVision$new(name, module_generator, label, optimizer = NULL, loss = NULL, callbacks = list())
Arguments:
name (character(1))
The name of the network.
module_generator (function(pretrained, num_classes))
Function that generates the network.
label (character(1))
The label of the network.
optimizer (TorchOptimizer)
The optimizer to use for training. By default, adam is used.
loss (TorchLoss)
The loss used to train the network. By default, mse is used for regression and cross_entropy for classification.
callbacks (list() of TorchCallbacks)
The callbacks. Must have unique ids.
References
Krizhevsky, Alex, Sutskever, Ilya, Hinton, Geoffrey E. (2017). “ImageNet classification with deep convolutional neural networks.” Communications of the ACM, 60(6), 84–90.
Sandler, Mark, Howard, Andrew, Zhu, Menglong, Zhmoginov, Andrey, Chen, Liang-Chieh (2018). “MobileNetV2: Inverted residuals and linear bottlenecks.” In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 4510–4520.
He, Kaiming, Zhang, Xiangyu, Ren, Shaoqing, Sun, Jian (2016). “Deep residual learning for image recognition.” In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 770–778.
Simonyan, Karen, Zisserman, Andrew (2014). “Very deep convolutional networks for large-scale image recognition.” arXiv preprint arXiv:1409.1556.
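A hedged sketch of calling the constructor directly with all of its arguments. It assumes that LearnerTorchVision is exported by mlr3torch and that torchvision::model_resnet18 satisfies the function(pretrained, num_classes) contract by taking pretrained and forwarding num_classes through its ... argument. The optimizer, loss, and callback choices are illustrative only.

```r
library(mlr3)
library(mlr3torch)

# Assumptions (not confirmed by this page):
# - LearnerTorchVision is exported by mlr3torch;
# - torchvision::model_resnet18() fits the function(pretrained, num_classes) contract,
#   taking `pretrained` and forwarding `num_classes` through `...`.
learner = LearnerTorchVision$new(
  name = "resnet18",
  module_generator = torchvision::model_resnet18,
  label = "ResNet-18",
  optimizer = t_opt("sgd", lr = 0.01), # replaces the default adam
  loss = t_loss("cross_entropy"),      # the classification default, made explicit
  callbacks = list(t_clbk("history"))  # ids must be unique across callbacks
)

print(learner)
```

Passing the optimizer and loss explicitly is optional; leaving them as NULL keeps the defaults described above.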
Method clone():
The objects of this class are cloneable with this method.
Usage:
LearnerTorchVision$clone(deep = FALSE)
Arguments:
deep
Whether to make a deep clone.
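A short sketch of why deep = TRUE matters, again assuming the id "classif.resnet18": a deep clone receives its own copy of nested R6 objects such as the parameter set, so changing a hyperparameter on the copy leaves the original learner unchanged.

```r
library(mlr3)
library(mlr3torch)

# Assumption: "classif.resnet18" is registered for this learner family.
learner = lrn("classif.resnet18", epochs = 1, batch_size = 16)

# A deep clone copies nested R6 objects such as the parameter set
learner2 = learner$clone(deep = TRUE)
learner2$param_set$values$epochs = 5

# Compare the hyperparameter on the original and on the copy
print(learner$param_set$values$epochs)
print(learner2$param_set$values$epochs)
```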