LearnerRegrKeras | R Documentation
Neural network using Keras and TensorFlow.
This learner lets you supply a custom architecture.
Training calls keras::fit() from package keras.
Parameters:
Most parameters are described in the keras documentation.
The exceptions are documented here.
model
: A compiled keras model.
class_weight
: A named list of class weights, one per class, with the classes numbered from 0 to c-1 (for c classes).
Example: wts = c(0.5, 1); setNames(as.list(wts), seq_len(length(wts)) - 1), which yields list(`0` = 0.5, `1` = 1).
callbacks
: A list of keras callbacks. See ?callbacks.
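The class_weight format described above can also be built programmatically. The sketch below is a minimal base-R illustration, not part of mlr3keras: both helper names (make_class_weights, balanced_class_weights) are hypothetical, and the "balanced" inverse-frequency scheme w_k = n / (c * n_k) is an assumption borrowed from common practice, not something this learner computes for you.

```r
# Hypothetical helper (not part of mlr3keras): turns a numeric vector of
# weights into the named list class_weight expects, keyed "0", "1", ..., "c-1".
make_class_weights <- function(wts) {
  stopifnot(is.numeric(wts), length(wts) >= 1)
  # names<- coerces the numeric indices 0, 1, ... to the strings "0", "1", ...
  setNames(as.list(wts), seq_along(wts) - 1)
}

# Hypothetical helper: inverse-frequency ("balanced") weights computed from a
# vector of labels, w_k = n / (c * n_k), so rarer classes get larger weights.
# Classes are numbered in factor-level order.
balanced_class_weights <- function(y) {
  y <- as.factor(y)
  counts <- as.numeric(table(y))   # one count per level, in level order
  stopifnot(all(counts > 0))       # an empty level would give an infinite weight
  wts <- length(y) / (length(counts) * counts)
  make_class_weights(wts)
}

cw <- make_class_weights(c(0.5, 1))
str(cw)
```

The numbering follows the factor levels, so the labels must map to classes 0 to c-1 in the same order the model uses; check that mapping before passing the list to the learner.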
Format:
R6::R6Class() inheriting from mlr3::LearnerRegr.
Construction:
LearnerRegrKeras$new()
mlr3::mlr_learners$get("regr.keras")
mlr3::lrn("regr.keras")
Methods:
Keras learners offer several methods for easy access to the stored model.
.$plot()
Plots the training history, i.e. the training and validation loss over the course of training.
.$save(file_path)
Saves the model to the given file_path in HDF5 ('h5') format.
.$load_model_from_file(file_path)
Loads a model saved with .$save() back into the learner.
When the learner is serialized, the model must be saved separately;
the learner can then be restored with this function.
Currently not implemented for 'TabNet'.
.$lr_find(task, epochs, lr_min, lr_max, batch_size)
Runs a learning rate finder for the learner, as popularized by
Jeremy Howard in fast.ai (http://course.fast.ai/).
For details on the parameters, see find_lr.
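The idea behind a learning rate finder can be sketched in base R. This is a minimal illustration under stated assumptions, not the find_lr implementation: it assumes the common fast.ai-style range test, in which the learning rate grows geometrically from lr_min to lr_max over successive batches and a good rate is read off where the loss falls fastest. The helpers lr_schedule and suggest_lr are hypothetical; consult find_lr for the schedule and selection rule it actually uses.

```r
# Hypothetical helper (not mlr3keras's find_lr): a geometric sweep from
# lr_min to lr_max, i.e. equally spaced on a log scale, one rate per step.
lr_schedule <- function(lr_min, lr_max, n_steps) {
  stopifnot(lr_min > 0, lr_max > lr_min, n_steps >= 2)
  exp(seq(log(lr_min), log(lr_max), length.out = n_steps))
}

# Hypothetical helper: given the loss recorded at each rate, return the rate
# at the left end of the interval where the loss drops most steeply with
# respect to log(lr).
suggest_lr <- function(lrs, losses) {
  stopifnot(length(lrs) == length(losses), length(lrs) >= 2)
  slopes <- diff(losses) / diff(log(lrs))
  lrs[which.min(slopes)]
}

lrs <- lr_schedule(1e-5, 1, n_steps = 6)
signif(lrs, 3)
```

Sweeping on a log scale spends equal numbers of steps per order of magnitude, which is why range tests are usually plotted with a logarithmic learning-rate axis.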
See Also:
Dictionary of learners: mlr3::mlr_learners
Examples:
# Load the packages
library(keras)
library(mlr3keras)
# Define a model: 10 inputs (the mtcars features), one linear output
model = keras_model_sequential() %>%
  layer_dense(units = 12L, input_shape = 10L, activation = "relu") %>%
  layer_dense(units = 12L, activation = "relu") %>%
  layer_dense(units = 1L, activation = "linear") %>%
  compile(
    optimizer = optimizer_sgd(),
    loss = "mean_squared_error",
    metrics = "mean_squared_error"
  )
# Create the learner and attach the compiled model
learner = LearnerRegrKeras$new()
learner$param_set$values$model = model
# Train on the built-in mtcars regression task (target: mpg)
learner$train(mlr3::mlr_tasks$get("mtcars"))