Lrnr_multiple_ts: Stratify univariable time-series learners by time-series

Lrnr_multiple_ts R Documentation

Stratify univariable time-series learners by time-series

Description

Fits the supplied univariable time-series learner separately to each time series in the data, so that every series forms its own stratum.

Format

R6Class object.

Value

Learner object with methods for training and prediction. See Lrnr_base for documentation on learners.

Parameters

learner="learner"

An initialized Lrnr_* object.

variable_stratify="variable_stratify"

A character string naming the covariate on which to stratify. Only variables with discrete levels coded as numeric are supported.
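
Because the stratification variable must have discrete levels coded as numeric, a series label stored as a factor or character may need converting first. The following is a minimal base-R sketch of that conversion, assuming a hypothetical data frame `df` with a factor column `id`. The names `df`, `y`, and `series_num` are illustrative only and are not part of sl3.

```r
# Hypothetical data: two series labelled by a factor column
df <- data.frame(
  id = factor(c("Series_1", "Series_1", "Series_2", "Series_2")),
  y = c(0.1, 0.4, -0.2, 0.3)
)

# Map each factor level to its integer code (1, 2, ...), stored as numeric
df$series_num <- as.numeric(df$id)

# Sanity checks: the new column is numeric, with one distinct value per series
stopifnot(is.numeric(df$series_num))
stopifnot(length(unique(df$series_num)) == nlevels(df$id))
```

Note that `as.numeric()` applied to a factor returns the level codes, not the label text, which is what is wanted here.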

...

Other parameters passed directly to learner$train. See its documentation for details.

See Also

Other Learners: Custom_chain, Lrnr_HarmonicReg, Lrnr_arima, Lrnr_bartMachine, Lrnr_base, Lrnr_bayesglm, Lrnr_caret, Lrnr_cv_selector, Lrnr_cv, Lrnr_dbarts, Lrnr_define_interactions, Lrnr_density_discretize, Lrnr_density_hse, Lrnr_density_semiparametric, Lrnr_earth, Lrnr_expSmooth, Lrnr_gam, Lrnr_ga, Lrnr_gbm, Lrnr_glm_fast, Lrnr_glm_semiparametric, Lrnr_glmnet, Lrnr_glmtree, Lrnr_glm, Lrnr_grfcate, Lrnr_grf, Lrnr_gru_keras, Lrnr_gts, Lrnr_h2o_grid, Lrnr_hal9001, Lrnr_haldensify, Lrnr_hts, Lrnr_independent_binomial, Lrnr_lightgbm, Lrnr_lstm_keras, Lrnr_mean, Lrnr_multivariate, Lrnr_nnet, Lrnr_nnls, Lrnr_optim, Lrnr_pca, Lrnr_pkg_SuperLearner, Lrnr_polspline, Lrnr_pooled_hazards, Lrnr_randomForest, Lrnr_ranger, Lrnr_revere_task, Lrnr_rpart, Lrnr_rugarch, Lrnr_screener_augment, Lrnr_screener_coefs, Lrnr_screener_correlation, Lrnr_screener_importance, Lrnr_sl, Lrnr_solnp_density, Lrnr_solnp, Lrnr_stratified, Lrnr_subset_covariates, Lrnr_svm, Lrnr_tsDyn, Lrnr_ts_weights, Lrnr_xgboost, Pipeline, Stack, define_h2o_X(), undocumented_learner

Examples

library(sl3)
library(origami)
library(dplyr)

set.seed(123)

# Simulate a simple AR(2) process of length 200
data <- matrix(arima.sim(model = list(ar = c(0.9, -0.2)), n = 200))
# Split the 200 observations into four series of 50 observations each
id <- c(rep("Series_1", 50), rep("Series_2", 50), rep("Series_3", 50), rep("Series_4", 50))
data <- data.frame(data)
data$id <- as.factor(id)
# Add a within-series time index running from 1 to 50
data <- data %>%
  group_by(id) %>%
  dplyr::mutate(time = 1:n())

data$W1 <- rbinom(200, 1, 0.6)
data$W2 <- rbinom(200, 1, 0.2)

folds <- origami::make_folds(data,
  t = max(data$time),
  id = data$id,
  time = data$time,
  fold_fun = folds_rolling_window_pooled,
  window_size = 20,
  validation_size = 15,
  gap = 0,
  batch = 10
)

task <- sl3_Task$new(
  data = data, outcome = "data",
  time = "time", id = "id",
  covariates = c("W1", "W2"),
  folds = folds
)

train_task <- training(task, fold = task$folds[[1]])
valid_task <- validation(task, fold = task$folds[[1]])

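# Wrap an ARIMA learner so that it is fit separately to each series
lrnr_arima <- Lrnr_arima$new()
multiple_ts_arima <- Lrnr_multiple_ts$new(learner = lrnr_arima)

# Train on the first training fold, then predict on its validation fold
multiple_ts_arima_fit <- multiple_ts_arima$train(train_task)
multiple_ts_arima_preds <- multiple_ts_arima_fit$predict(valid_task)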
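
For intuition about what stratifying by series means, the idea can be sketched outside sl3 with base R alone: split the outcome by series, fit one model per series, and forecast each series from its own fit. This is only an illustration under stated assumptions, not the sl3 implementation. The data frame `df`, the AR(1) order, and the 5-step horizon are all hypothetical choices, and `stats::arima` is called directly rather than through `Lrnr_arima`.

```r
set.seed(123)

# Hypothetical data: two independent AR(1) series stacked in long format
df <- data.frame(
  id = rep(c(1, 2), each = 60),
  y = c(
    as.numeric(arima.sim(model = list(ar = 0.7), n = 60)),
    as.numeric(arima.sim(model = list(ar = -0.3), n = 60))
  )
)

# Split the outcome by series identifier: one numeric vector per stratum
series_list <- split(df$y, df$id)

# Fit a separate AR(1) model to each series
fits <- lapply(series_list, function(y) {
  stats::arima(y, order = c(1, 0, 0), method = "ML")
})

# Forecast 5 steps ahead for each series from its own fitted model
forecasts <- lapply(fits, function(fit) {
  as.numeric(predict(fit, n.ahead = 5)$pred)
})
```

Each element of `forecasts` is named after its series identifier, so predictions stay aligned with the series they came from.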

tlverse/sl3 documentation built on Nov. 18, 2024, 12:46 a.m.