train_model: Train a model across horizons and validation datasets


View source: R/train_model.R

Description

Train a user-defined forecast model for each of the 'h' forecast horizons and across each of the 'd' validation datasets. If method = "direct" in create_lagged_df(), a total of 'h' * 'd' models are trained. If method = "multi_output", a total of 1 * 'd' models are trained. These models can be trained in parallel with the future package.
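To make the 'h' * 'd' count concrete, the base R sketch below mimics the nested training loop for the direct method. It is a hypothetical illustration, not forecastML's internal code: the toy data, the lm() wrapper, the two validation windows, and the assumption that each model is fit on the rows outside its validation window are all stand-ins.

```r
# Hypothetical sketch (not forecastML internals): one model per horizon-specific
# dataset and per validation window, stored as horizon_h$window_w$model.
set.seed(1)

# Toy stand-ins for create_lagged_df() output: one data.frame per horizon.
make_toy_data <- function(n = 60) {
  data.frame(y = rnorm(n), x1 = rnorm(n), x2 = rnorm(n))
}
lagged_list <- list(horizon_1 = make_toy_data(), horizon_12 = make_toy_data())

# Toy stand-in for create_windows() output: row ranges held out for validation.
windows <- data.frame(start = c(41, 51), stop = c(50, 60))

# Hypothetical user-defined wrapper (base R lm() instead of glmnet).
model_function <- function(data) lm(y ~ ., data = data)

trained <- lapply(lagged_list, function(data) {
  by_window <- lapply(seq_len(nrow(windows)), function(w) {
    held_out <- windows$start[w]:windows$stop[w]
    # Assumption: fit on the rows outside the validation window.
    list(model = model_function(data[-held_out, , drop = FALSE]))
  })
  names(by_window) <- paste0("window_", seq_len(nrow(windows)))
  by_window
})

# Direct method: 2 horizons x 2 windows = 4 fitted models.
n_models <- sum(lengths(trained))
print(n_models)  # [1] 4
```

With method = "multi_output", the outer loop over horizons would collapse to a single dataset, leaving one model per window.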

Usage

train_model(
  lagged_df,
  windows,
  model_name,
  model_function,
  ...,
  use_future = FALSE
)

Arguments

lagged_df

An object of class 'lagged_df' from create_lagged_df.

windows

An object of class 'windows' from create_windows.

model_name

A name for the model.

model_function

A user-defined wrapper function for model training. It takes the following arguments: (1) a horizon-specific data.frame made with create_lagged_df(..., type = "train") (i.e., one of the datasets stored in lagged_df) and, optionally, (2) any number of additional named arguments, which are passed in '...' of train_model(). The function should return the trained model object.

...

Optional. Named arguments passed into the user-defined model_function.

use_future

Logical. If TRUE, the future package is used to train models in parallel. The models train in parallel across either (1) forecast horizons or (2) validation windows, whichever has more elements (i.e., length(create_lagged_df()) or nrow(create_windows())). Before calling this function, run future::plan(future::multiprocess) or a similar plan to enable parallel training; in newer versions of future, multiprocess is deprecated and future::multisession can be used instead.
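A minimal sketch of the model_function contract, assuming train_model() calls the wrapper as model_function(data, ...). The toy data, the column names, and the helper call_wrapper() are hypothetical, and base R lm() stands in for glmnet::cv.glmnet() so the sketch runs without extra packages.

```r
# Hypothetical wrapper following the documented contract: the first argument is
# a horizon-specific training data.frame; extra named arguments (here
# my_outcome_col) arrive through `...` of train_model().
model_function <- function(data, my_outcome_col) {
  outcome_name <- names(data)[my_outcome_col]
  predictors <- setdiff(names(data), outcome_name)
  formula <- stats::reformulate(predictors, response = outcome_name)
  stats::lm(formula, data = data)  # return the trained model object
}

# Hypothetical stand-in for how train_model() might forward `...`.
call_wrapper <- function(data, model_function, ...) {
  model_function(data, ...)
}

set.seed(42)
toy <- data.frame(outcome = rnorm(30), lag_1 = rnorm(30), lag_2 = rnorm(30))
fit <- call_wrapper(toy, model_function, my_outcome_col = 1)
coef(fit)
```

Passing my_outcome_col through '...' keeps the wrapper reusable across datasets whose outcome sits in a different column.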

Value

An S3 object of class 'forecast_model': a nested list of trained models. Models can be accessed with my_trained_model$horizon_h$window_w$model, where 'h' is the forecast horizon and 'w' is the validation window number from create_windows().
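The nested return value can be walked with ordinary list operations. This sketch builds a mock object with the documented horizon_h$window_w$model shape (the character strings are hypothetical placeholders for fitted models) and flattens it into one named element per horizon-window pair.

```r
# Mock object with the documented shape; the strings are placeholders for
# fitted models (hypothetical values, not real forecastML output).
my_trained_model <- list(
  horizon_1  = list(window_1 = list(model = "model_h1_w1")),
  horizon_12 = list(window_1 = list(model = "model_h12_w1"))
)

# Access one model directly: horizon 12, validation window 1.
my_trained_model$horizon_12$window_1$model

# Flatten into one named list element per (horizon, window) pair.
flat <- do.call(c, lapply(names(my_trained_model), function(h) {
  windows <- my_trained_model[[h]]
  models <- lapply(windows, `[[`, "model")
  stats::setNames(models, paste(h, names(windows), sep = "/"))
}))

names(flat)
```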

Methods and related functions

The output of train_model can be passed into

and has the following generic S3 methods

Examples

library(forecastML)

# Sampled Seatbelts data from the R package datasets.
data("data_seatbelts", package = "forecastML")

# Example - Training data for 2 horizon-specific models w/ common lags per predictor.
horizons <- c(1, 12)
lookback <- 1:15

data_train <- create_lagged_df(data_seatbelts, type = "train", outcome_col = 1,
                               lookback = lookback, horizon = horizons)

# One custom validation window at the end of the dataset.
windows <- create_windows(data_train, window_start = 181, window_stop = 192)

# User-defined model - LASSO
# A user-defined wrapper function for model training that takes the following
# arguments: (1) a horizon-specific data.frame made with create_lagged_df(..., type = "train")
# (e.g., my_lagged_df$horizon_h) and, optionally, (2) any number of additional named arguments
# which are passed as '...' in train_model().
library(glmnet)
model_function <- function(data, my_outcome_col) {

  x <- data[, -(my_outcome_col), drop = FALSE]
  y <- data[, my_outcome_col, drop = FALSE]
  x <- as.matrix(x)  # glmnet expects a numeric predictor matrix
  y <- as.matrix(y)

  model <- glmnet::cv.glmnet(x, y, nfolds = 3)
  return(model)
}

# my_outcome_col = 1 is passed in ... but could have been defined in model_function().
model_results <- train_model(data_train, windows, model_name = "LASSO",
                             model_function = model_function, my_outcome_col = 1)

# View the results for the model (a) trained on the first horizon
# and (b) to be assessed on the first outer-loop validation window.
model_results$horizon_1$window_1$model

Example output

Loading required package: dplyr

Attaching package: ‘dplyr’

The following objects are masked from ‘package:stats’:

    filter, lag

The following objects are masked from ‘package:base’:

    intersect, setdiff, setequal, union

Loading required package: Matrix
Loaded glmnet 4.0-2

Call:  glmnet::cv.glmnet(x = x, y = y, nfolds = 3) 

Measure: Mean-Squared Error 

    Lambda Measure    SE Nonzero
min  2.454   301.1 15.91       7
1se  4.288   316.0 17.84       5

forecastML documentation built on July 8, 2020, 7:27 p.m.