Neural Network using Keras and Tensorflow.
This learner allows for supplying a custom architecture.
Calls keras::fit() from package keras.
Parameters:
Most of the parameters are described in the keras documentation.
Some exceptions are documented here.
model
: A compiled keras model suited for the task.
class_weight
: A named list of class weights for the different classes, numbered from 0 to c-1 (for c classes).
Example:
    wts = c(0.5, 1)
    setNames(as.list(wts), seq_len(length(wts)) - 1)
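As a minimal sketch of the class_weight format, the named list can be built in base R. The label vector `y` and the inverse-frequency weighting are assumptions made for illustration; the learner does not prescribe a weighting scheme.

```r
# Hypothetical label vector for a 3-class problem (illustrative data only).
y = factor(c("a", "a", "a", "b", "b", "c"))

# Inverse-frequency weights: rarer classes get larger weights.
# This scheme is an assumption for illustration; other weightings are possible.
freq = table(y)
wts = as.numeric(sum(freq) / (length(freq) * freq))

# Name the weights "0", ..., "c-1" in factor-level order, as described above.
class_weight = setNames(as.list(wts), seq_along(wts) - 1)
str(class_weight)
```

The resulting list could then be assigned to the class_weight parameter of the learner.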
callbacks
: A list of keras callbacks.
See ?callbacks.
R6::R6Class() inheriting from mlr3multioutput::LearnerMultioutput.
LearnerMultioutputKeras$new()
mlr3::mlr_learners$get("multiout.keras")
mlr3::lrn("multiout.keras")
Keras Learners offer several methods for easy access to the stored models.
.$plot()
Plots the history, i.e. the train-validation loss during training.
.$save(filepath)
Saves the model to the provided filepath in 'h5' format.
.$load_model_from_file(filepath)
Loads a model saved with $save() back into the learner.
The model needs to be saved separately when the learner is serialized.
In that case, the learner can be restored using this method.
Currently not implemented for 'TabNet'.
.$lr_find(task, epochs, lr_min, lr_max, batch_size)
Runs an implementation of the learning rate finder popularized by
Jeremy Howard in fast.ai (http://course.fast.ai/) on the learner.
For more information on the parameters, see find_lr.
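The idea behind a learning rate finder can be sketched in base R: train briefly while increasing the learning rate geometrically from lr_min to lr_max, record the loss at each step, and look for the range where the loss falls fastest. The snippet below builds only the schedule of rates. The number of steps (`n_steps`) is a hypothetical value, and the actual find_lr implementation may construct its schedule differently.

```r
# Defaults taken from the lr_find() signature of this learner.
lr_min = 10^-4
lr_max = 0.8

# Hypothetical number of sweep steps (not a parameter of lr_find()).
n_steps = 50L

# Geometric (log-spaced) schedule: each rate is a constant multiple of the previous one.
lrs = exp(seq(log(lr_min), log(lr_max), length.out = n_steps))

# Ratio between consecutive rates; constant for a geometric schedule.
ratio = lrs[-1] / lrs[-n_steps]
range(lrs)
```

Log spacing is used here because useful learning rates typically span several orders of magnitude, so a linear grid would spend almost all steps near lr_max.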
mlr3::Learner -> mlr3multioutput::LearnerMultioutput -> LearnerMultioutputKeras
new()
LearnerMultioutputKeras$new(
  id = "multiout.keras",
  predict_types = c("response", "prob"),
  feature_types = c("integer", "numeric"),
  properties = c("multilabel"),
  packages = "keras",
  man = "mlr3keras::mlr_learners_classif.keras",
  architecture = KerasArchitectureCustomModel$new()
)
save()
LearnerMultioutputKeras$save(filepath)
load_model_from_file()
LearnerMultioutputKeras$load_model_from_file(filepath)
plot()
LearnerMultioutputKeras$plot()
lr_find()
LearnerMultioutputKeras$lr_find(
  task,
  epochs = 5L,
  lr_min = 10^-4,
  lr_max = 0.8,
  batch_size = 128L
)
clone()
The objects of this class are cloneable with this method.
LearnerMultioutputKeras$clone(deep = FALSE)
deep
Whether to make a deep clone.
Dictionary of Learners: mlr3::mlr_learners
# Define a model
library(keras)
model = keras_model_sequential() %>%
  layer_dense(units = 12L, input_shape = 4L, activation = "relu") %>%
  layer_dense(units = 12L, activation = "relu") %>%
  layer_dense(units = 3L, activation = "sigmoid") %>%
  compile(optimizer = optimizer_sgd(),
    loss = "binary_crossentropy",
    metrics = "accuracy")
# Create the learner
learner = LearnerMultioutputKeras$new()
learner$param_set$values$model = model
learner$train(mlr3::mlr_tasks$get("flags"))