lgb_fit: Wrapper for lgb.train.


View source: R/lgb_fit.R

Description

Fit and evaluate a lightgbm model with a data.table as input data. The model is trained (including all preprocessing steps) on the training part and evaluated on the validation part, as defined by the split indicator variable.

Usage

lgb_fit(data = data, target = target, split = split,
  preproc_fun = preproc_fun, params = params, args = args,
  metrics = metrics, return_val_preds = FALSE,
  return_model_obj = FALSE, train_on_all_data = FALSE, ...)

Arguments

data

data.table with all input data.

target

Target variable name (character).

split

Indicator variable in which 1 marks observations that belong to the validation dataset.

preproc_fun

Preprocessing function that takes a data.table (data plus split) as input and returns the processed data.table with the same target and split columns.

params

1-row data.table with all hyperparameters.

args

List of parameters that stay fixed during tuning.

metrics

Vector of metric function names.

return_val_preds

If TRUE, predictions for validation data will be returned.

return_model_obj

If TRUE, model object will be returned.

train_on_all_data

If TRUE, model will be fitted on all data (without train/validation split) and model object will be returned.

...

Other parameters for lgb.train().
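
The preproc_fun contract described above (statistics computed on the training part, then applied to every row) can be illustrated with a minimal sketch. It assumes the split indicator is available as a column named `split` inside the input table; that column name, the `target` argument and the function name `preproc_impute_median` are hypothetical and are not taken from the package.

```r
library(data.table)

# Hypothetical preprocessing function: median-impute numeric features
# using statistics from the training rows only (split == 0).
# Assumption: `data` carries the split indicator in a column named `split`.
preproc_impute_median <- function(data, target = "hp") {
  out <- copy(data)  # work on a copy so the caller's table is not modified
  features <- setdiff(names(out), c(target, "split"))
  for (col in features) {
    if (!is.numeric(out[[col]])) next
    # Median computed on training rows only, ignoring missing values
    train_median <- median(out[split == 0][[col]], na.rm = TRUE)
    # Fill missing values in both train and validation rows
    set(out, i = which(is.na(out[[col]])), j = col, value = train_median)
  }
  out[]
}

dt <- as.data.table(mtcars)
dt[, split := rep(c(0, 1), length.out = .N)]  # odd rows train, even rows validation
dt[c(2, 3), mpg := NA]                        # one missing value in each part
processed <- preproc_impute_median(dt)
```

Because the median is taken over training rows only, the validation fold never contributes to the imputed value, which is the property the argument description asks for.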

Value

A data.table with the optimal number of iterations (if early stopping is used) and all metrics calculated on the validation part of the data. It also contains predictions for the validation data if return_val_preds = TRUE and the model object if return_model_obj = TRUE. If train_on_all_data = TRUE, only the model object is returned.
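
Since metrics is given as a vector of function names, the metric columns are presumably produced by looking each function up by name and applying it to the validation predictions. The sketch below shows that idea in isolation; the (actual, predicted) signature and the toy values are assumptions for illustration, not the package's documented interface.

```r
# Hypothetical metric functions; the (actual, predicted) signature is assumed
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))
mae  <- function(actual, predicted) mean(abs(actual - predicted))

metrics   <- c("rmse", "mae")
actual    <- c(110, 93, 175)   # toy validation targets
predicted <- c(100, 95, 180)   # toy validation predictions

# Look up each function by its name and apply it to the predictions;
# vapply() names the result after the metric names
scores <- vapply(metrics,
                 function(m) match.fun(m)(actual, predicted),
                 numeric(1))
```

Any function visible under the given name and following the same assumed signature could be added to the vector in the same way.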

Examples

library(data.table)

# Input data
dt <- as.data.table(mtcars)
# data.table with resamples
splits <- resampleR::cv_base(dt, "hp")
# data.table with all hyperparameters
lgb_grid <- CJ(
    learning_rate = 0.03, 
    metric = "rmse",
    num_leaves = 30,
    subsample = 0.9,
    colsample_bytree = 0.8,
    random_state = 42,
    max_depth = c(3, 5, 7),
    lambda_l2 = 0.02,
    lambda_l1 = 0.004,
    bagging_fraction = 0.8,
    feature_fraction = 0.7,
    min_child_samples = 3,
    verbose = -1
)
# Non-tunable parameters for lightgbm
lgb_args <- list(
    nrounds = 1000,
    obj = "regression",
    early_stopping_rounds = 10,
    verbose = -1
)
# Dumb preprocessing function
# Real function will contain imputation, feature engineering etc.
# with all statistics computed on train folds and applied to validation fold
preproc_fun_example <- function(data) return(data[])
lgb_fit(data = dt,
        target = "hp",
        split = splits[, split_1],
        preproc_fun = preproc_fun_example,
        params = lgb_grid[1, ],
        args = lgb_args,
        metrics = c("rmse", "mae"),
        return_val_preds = TRUE)

statist-bhfz/grideR documentation built on Aug. 8, 2019, 7:08 p.m.