LearnerMultioutputKeras: Keras Neural Network with custom architecture (Multilabel)


Description

Neural network using Keras and TensorFlow. This learner allows supplying a custom architecture. It calls keras::fit() from package keras.

Parameters: Most parameters are described in the keras documentation. Some exceptions are documented here.

Format

R6::R6Class() inheriting from mlr3::LearnerMultioutput.

Construction

LearnerMultioutputKeras$new()
mlr3::mlr_learners$get("multiout.keras")
mlr3::lrn("multiout.keras")

Learner Methods

Keras learners offer several methods for easy access to the stored model, such as save() and load_model_from_file().

Super classes

mlr3::Learner -> mlr3multioutput::LearnerMultioutput -> LearnerMultioutputKeras

Methods


Method new()

Usage
LearnerMultioutputKeras$new(
  id = "multiout.keras",
  predict_types = c("response", "prob"),
  feature_types = c("integer", "numeric"),
  properties = c("multilabel"),
  packages = "keras",
  man = "mlr3keras::mlr_learners_classif.keras",
  architecture = KerasArchitectureCustomModel$new()
)

Method save()

Usage
LearnerMultioutputKeras$save(filepath)

Method load_model_from_file()

Usage
LearnerMultioutputKeras$load_model_from_file(filepath)
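Both save() and load_model_from_file() take a file path. The following is a minimal round-trip sketch, assuming the "flags" task and model setup from the Examples section. The ".h5" file extension is an assumption, since this page does not state the file format. Which learner field receives the reloaded model is also not documented here, so the sketch reloads into the same learner.

```r
library(mlr3)
library(mlr3multioutput)
library(keras)

# Setup as in the Examples section: train a learner on the "flags" task
task = tsk("flags")
model = keras_model_sequential() %>%
  layer_dense(units = 12L, input_shape = length(task$feature_names), activation = "relu") %>%
  layer_dense(units = length(task$target_names), activation = "sigmoid") %>%
  compile(optimizer = optimizer_sgd(), loss = "binary_crossentropy")
learner = lrn("multiout.keras")
learner$param_set$values$model = model
learner$train(task)

# Write the trained model to a temporary file (".h5" extension is an assumption)
path = tempfile(fileext = ".h5")
learner$save(path)

# Read the stored model back from that file
learner$load_model_from_file(path)
```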

Method plot()

Usage
LearnerMultioutputKeras$plot()

Method lr_find()

Usage
LearnerMultioutputKeras$lr_find(
  task,
  epochs = 5L,
  lr_min = 10^-4,
  lr_max = 0.8,
  batch_size = 128L
)

Method clone()

The objects of this class are cloneable with this method.

Usage
LearnerMultioutputKeras$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.
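The deep argument follows R6 cloning semantics. Under R6's default deep-clone behavior, a deep clone also clones fields that are themselves R6 objects, whereas a shallow clone shares those objects with the original. The sketch below illustrates this with two hypothetical R6 classes (Params, Holder) rather than with the learner itself. Individual mlr3 classes may customize how deep cloning treats specific fields.

```r
library(R6)

# Hypothetical classes, used only to illustrate clone(deep = ...)
Params = R6Class("Params", public = list(lr = 0.01))
Holder = R6Class("Holder",
  public = list(
    params = NULL,
    initialize = function() {
      self$params = Params$new()
    }
  )
)

original = Holder$new()
shallow = original$clone(deep = FALSE)
deep = original$clone(deep = TRUE)

# Modify the nested R6 object through the original
original$params$lr = 0.5

shallow$params$lr  # 0.5: shares the nested object with `original`
deep$params$lr     # 0.01: holds its own copy
```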

See Also

Dictionary of Learners: mlr3::mlr_learners

Examples

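# Load the required packages
library(mlr3)
library(mlr3multioutput)
library(keras)

# Use the multilabel "flags" task and derive layer sizes from it
task = mlr3::mlr_tasks$get("flags")
n_features = length(task$feature_names)
n_labels = length(task$target_names)

# Define a model: one sigmoid output unit per label
model = keras_model_sequential() %>%
  layer_dense(units = 12L, input_shape = n_features, activation = "relu") %>%
  layer_dense(units = 12L, activation = "relu") %>%
  layer_dense(units = n_labels, activation = "sigmoid") %>%
  compile(
    optimizer = optimizer_sgd(),
    loss = "binary_crossentropy",
    metrics = "accuracy"
  )

# Create the learner, attach the model, and train
learner = LearnerMultioutputKeras$new()
learner$param_set$values$model = model
learner$train(task)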

mlr-org/mlr3multioutput documentation built on Nov. 22, 2020, 1:17 p.m.