across_models: Model fitting and prediction.


View source: R/across_models.R

Description

Fit a list of models on the training part of a single resample and predict on its validation part.

Usage

across_models(data, target, split, models, model_params, model_args,
  preproc_funs)

Arguments

data

data.table with all input data.

target

Target variable name (character).

split

Indicator variable in which 1 marks observations that belong to the validation dataset.

models

Named list of fit functions from the tuneR package (xgb_fit, lgb_fit, etc.).

model_params

List of data.tables with tunable model parameters (one per model).

model_args

List of fixed (non-tunable) model arguments (one element per model).

preproc_funs

List of preprocessing functions (one per model). Each function takes a data.table containing the data and the split indicator as input and returns a processed data.table with the same target and split columns.
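
A preprocessing function of the kind described above could be sketched as follows. This is only an illustration under stated assumptions: the function name preproc_impute_median is hypothetical, and the split indicator is assumed to live in a column named "split" (1 = validation), which may not match the column name the package actually uses. It shows the general idea of computing statistics on training rows only and applying them to all rows.

```r
library(data.table)

# Hypothetical preprocessing function: median imputation fitted on train rows only.
# Assumption: the split indicator is stored in a column named "split"
# (1 = validation, anything else = train); the real column name may differ.
preproc_impute_median <- function(data, split_col = "split") {
    data <- copy(data)                       # avoid modifying the caller's table by reference
    is_train <- data[[split_col]] != 1
    num_cols <- names(data)[vapply(data, is.numeric, logical(1))]
    num_cols <- setdiff(num_cols, split_col) # never touch the split indicator itself
    for (col in num_cols) {
        # Statistic is computed on training rows only ...
        train_median <- median(data[[col]][is_train], na.rm = TRUE)
        # ... and applied to missing values in both train and validation rows
        set(data, i = which(is.na(data[[col]])), j = col, value = train_median)
    }
    data[]
}

# Toy input: mtcars with two missing values and a made-up split indicator
dt <- as.data.table(mtcars)
dt[c(1, 4), wt := NA_real_]                        # row 1 is train, row 4 is validation
dt[, split := as.integer(seq_len(.N) %% 4 == 0)]   # every 4th row goes to validation

processed <- preproc_impute_median(dt)
```

Because the median comes from training rows only, the validation fold does not leak into the imputed values. The target and split columns are returned unchanged apart from imputation of missing numeric values.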

Value

data.table with the ground truth for the validation part of the resample and the predictions from all fitted models.
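
A table of this shape can be scored per model with a few lines of base R. The sketch below uses a made-up table; the column names ground_truth, xgboost and catboost are assumptions for illustration, since the actual column names of the returned data.table are not stated here, and the helper rmse is hypothetical.

```r
library(data.table)

# Made-up result table; column names and values are hypothetical
res <- data.table(
    ground_truth = c(110, 93, 175, 245),
    xgboost      = c(105, 100, 170, 230),
    catboost     = c(112, 90, 180, 250)
)

# Root mean squared error between ground truth and one model's predictions
rmse <- function(truth, pred) sqrt(mean((truth - pred)^2))

# Every column except the ground truth is treated as a model's predictions
model_cols <- setdiff(names(res), "ground_truth")
scores <- vapply(model_cols,
                 function(m) rmse(res$ground_truth, res[[m]]),
                 numeric(1))
```

Here scores is a named numeric vector with one validation RMSE per model, which can be compared directly or collected across resamples.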

Examples

library(data.table)

# Input data
dt <- as.data.table(mtcars)

# data.table with resamples
splits <- resampleR::cv_base(dt, "hp")

# List of models
models <- list("xgboost" = xgb_fit, "catboost" = catboost_fit)

# Model parameters
xgb_params <- data.table(
    max_depth = 6,
    eta = 0.025,
    colsample_bytree = 0.9,
    subsample = 0.8,
    gamma = 0,
    min_child_weight = 5,
    alpha = 0,
    lambda = 1
)
xgb_args <- list(
    nrounds = 500,
    early_stopping_rounds = 10,
    booster = "gbtree",
    eval_metric = "rmse",
    objective = "reg:squarederror",  # replaces "reg:linear", deprecated in recent xgboost
    verbose = 0
)

catboost_params <- data.table(
    iterations = 1000,
    learning_rate = 0.05,
    depth = 8,
    loss_function = "RMSE",
    eval_metric = "RMSE",
    random_seed = 42,
    od_type = 'Iter',
    od_wait = 10,
    use_best_model = TRUE,
    logging_level = "Silent"
)
catboost_args <- NULL

model_params <- list(xgb_params, catboost_params)
model_args <- list(xgb_args, catboost_args)

# Dummy preprocessing function that returns its input unchanged.
# A real function would contain imputation, feature engineering etc.,
# with all statistics computed on train folds and applied to the validation fold
preproc_fun_example <- function(data) return(data[])
# List of preprocessing functions, one for each model
preproc_funs <- list(preproc_fun_example, preproc_fun_example)

across_models(data = dt,
              target = "hp",
              split = splits[, split_1],
              models = models,
              model_params = model_params,
              model_args = model_args,
              preproc_funs = preproc_funs)

statist-bhfz/stackeR documentation built on Aug. 7, 2019, 4:57 a.m.