Lrnr_gru_keras: Recurrent Neural Network with Gated Recurrent Unit (GRU) with...

Description

This learner fits Recurrent Neural Networks (RNNs) with Gated Recurrent Units (GRUs) using the keras package. GRU networks rely on the same principles as LSTM networks but are more streamlined and thus cheaper to run, at the expense of some representational power. Note that all preprocessing, such as differencing and adjusting for seasonal effects in time series, should be done before using this learner. Desired lags of the time series should also be added as predictors before using the learner.
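
The preprocessing described above can be sketched in base R. This is a minimal illustration under stated assumptions, not part of sl3: the helper `lag_vec()` and the toy data frame are hypothetical, and whether differencing is appropriate depends on the series at hand.

```r
# Hypothetical helper (not part of sl3): shift a vector back by k steps,
# padding the start with NA so the length is unchanged.
lag_vec <- function(x, k) {
  c(rep(NA, k), head(x, -k))
}

# Toy series standing in for an outcome such as `cnt`.
toy <- data.frame(cnt = c(10, 12, 15, 11, 18, 20))

# Add lag-1 and lag-2 copies of the outcome as predictor columns.
toy$cnt_lag1 <- lag_vec(toy$cnt, 1)
toy$cnt_lag2 <- lag_vec(toy$cnt, 2)

# First-difference the series; the first entry has no predecessor.
toy$cnt_diff <- c(NA, diff(toy$cnt))

# Drop the leading rows that lack a complete lag history.
toy_complete <- toy[complete.cases(toy), ]
```

The resulting lag columns could then be listed in the `covariates` argument when building the task, alongside any other predictors.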

Format

An R6Class object inheriting from Lrnr_base.

Value

A learner object inheriting from Lrnr_base with methods for training and prediction. For a full list of learner functionality, see the complete documentation of Lrnr_base.

Parameters

See Also

Other Learners: Custom_chain, Lrnr_HarmonicReg, Lrnr_arima, Lrnr_bartMachine, Lrnr_base, Lrnr_bayesglm, Lrnr_bilstm, Lrnr_caret, Lrnr_cv_selector, Lrnr_cv, Lrnr_dbarts, Lrnr_define_interactions, Lrnr_density_discretize, Lrnr_density_hse, Lrnr_density_semiparametric, Lrnr_earth, Lrnr_expSmooth, Lrnr_gam, Lrnr_ga, Lrnr_gbm, Lrnr_glm_fast, Lrnr_glmnet, Lrnr_glm, Lrnr_grf, Lrnr_gts, Lrnr_h2o_grid, Lrnr_hal9001, Lrnr_haldensify, Lrnr_hts, Lrnr_independent_binomial, Lrnr_lightgbm, Lrnr_lstm_keras, Lrnr_mean, Lrnr_multiple_ts, Lrnr_multivariate, Lrnr_nnet, Lrnr_nnls, Lrnr_optim, Lrnr_pca, Lrnr_pkg_SuperLearner, Lrnr_polspline, Lrnr_pooled_hazards, Lrnr_randomForest, Lrnr_ranger, Lrnr_revere_task, Lrnr_rpart, Lrnr_rugarch, Lrnr_screener_augment, Lrnr_screener_coefs, Lrnr_screener_correlation, Lrnr_screener_importance, Lrnr_sl, Lrnr_solnp_density, Lrnr_solnp, Lrnr_stratified, Lrnr_subset_covariates, Lrnr_svm, Lrnr_tsDyn, Lrnr_ts_weights, Lrnr_xgboost, Pipeline, Stack, define_h2o_X(), undocumented_learner

Examples

## Not run: 
library(sl3) # provides sl3_Task, Lrnr_gru_keras, and the bsds data
library(origami)
data(bsds)

# make folds appropriate for time-series cross-validation
folds <- make_folds(bsds,
  fold_fun = folds_rolling_window, window_size = 500,
  validation_size = 100, gap = 0, batch = 50
)

# build task by passing in external folds structure
task <- sl3_Task$new(
  data = bsds,
  folds = folds,
  covariates = c(
    "weekday", "temp"
  ),
  outcome = "cnt"
)

# create tasks for training and validation (simplified example)
train_task <- training(task, fold = task$folds[[1]])
valid_task <- validation(task, fold = task$folds[[1]])

# instantiate learner, then fit and predict (simplified example)
gru_lrnr <- Lrnr_gru_keras$new(batch_size = 1, epochs = 200)
gru_fit <- gru_lrnr$train(train_task)
gru_preds <- gru_fit$predict(valid_task)

## End(Not run)

jeremyrcoyle/sl3 documentation built on Feb. 3, 2022, 9:12 a.m.