catboost_fit: Wrapper for catboost.train.

View source: R/catboost_fit.R

Description

Fit and evaluate a catboost model with a data.table as input data. The model is trained (including all preprocessing steps) on the train part and evaluated on the validation part, according to the split indicator variable.

Usage

catboost_fit(data = data, target = target, split = split,
  preproc_fun = preproc_fun, params = params, args = NULL,
  metrics = metrics, return_val_preds = FALSE,
  return_model_obj = FALSE, train_on_all_data = FALSE, ...)

Arguments

data

data.table with all input data.

split

Indicator variable in which 1 marks observations that belong to the validation dataset.

preproc_fun

Preprocessing function that takes a data.table (the data plus the split column) as input and returns a processed data.table with the same target and split columns.

params

1-row data.table with all hyperparameters.

args

NULL value for consistency with xgb_fit().

metrics

Vector of metric function names.

return_val_preds

If TRUE, predictions for validation data will be returned.

return_model_obj

If TRUE, model object will be returned.

train_on_all_data

If TRUE, model will be fitted on all data (without train/validation split) and model object will be returned.

...

Other parameters for catboost.train().

target

Target variable name (character).
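As an illustration of the preproc_fun contract described above, here is a minimal sketch of a preprocessing function that imputes missing values using statistics from the training rows only. It assumes the indicator column inside the data.table is named `split` (1 = validation, 0 = train) and that the target is `hp`; both names, and the function name `preproc_median_impute`, are hypothetical and not taken from the package.

```r
library(data.table)

# Hypothetical preprocessing function (not part of grideR).
# Assumes columns named `split` (1 = validation) and `hp` (target).
preproc_median_impute <- function(data) {
  out <- copy(data)  # work on a copy so the caller's table is not modified by reference
  feature_cols <- setdiff(names(out), c("hp", "split"))
  for (col in feature_cols) {
    if (!is.numeric(out[[col]])) next
    # The statistic is computed on training rows only ...
    train_median <- median(out[split == 0][[col]], na.rm = TRUE)
    # ... and applied to missing values in both train and validation rows
    na_rows <- which(is.na(out[[col]]))
    if (length(na_rows) > 0L) set(out, i = na_rows, j = col, value = train_median)
  }
  out[]
}

dt <- as.data.table(mtcars)
dt[, split := rep(c(0L, 1L), length.out = .N)]  # toy split indicator
dt[c(1, 2), wt := NA_real_]                     # row 1 is train, row 2 is validation
res <- preproc_median_impute(dt)
```

The target and split columns pass through unchanged, and the validation row receives the median computed from the training rows, so no information from the validation part leaks into the statistic.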

Value

data.table with the optimal number of iterations (if early stopping is used) and all metrics calculated on the validation part of the data. It also contains the predictions for the validation data if return_val_preds = TRUE and the model object if return_model_obj = TRUE. If train_on_all_data = TRUE, only the model object is returned.
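To make the "metrics by name" idea concrete, the following sketch looks up metric functions from a character vector and collects the scores into a 1-row data.table. The functions `rmse` and `mae` are defined locally here, and their (actual, predicted) argument order is an assumption; the package may rely on differently defined functions.

```r
library(data.table)

# Hypothetical metric functions; the (actual, predicted) signature is an assumption
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))
mae  <- function(actual, predicted) mean(abs(actual - predicted))

metrics   <- c("rmse", "mae")
actual    <- c(110, 93, 175, 105)  # toy validation targets
predicted <- c(100, 95, 180, 105)  # toy validation predictions

# Resolve each function by its name and evaluate it on the validation data
scores <- lapply(metrics, function(m) get(m, mode = "function")(actual, predicted))

# One row, one column per metric
res <- as.data.table(setNames(scores, metrics))
```

Passing names rather than function objects keeps the metric labels available as column names in the result.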

Examples

# Packages used below (grideR provides catboost_fit)
library(data.table)
library(grideR)
# Input data
dt <- as.data.table(mtcars)
# data.table with resamples
splits <- resampleR::cv_base(dt, "hp")
# data.table with all hyperparameters
catboost_grid <- CJ(
    iterations = 1000,
    learning_rate = 0.05,
    depth = c(8, 9),
    loss_function = "RMSE",
    eval_metric = "RMSE",
    random_seed = 42,
    od_type = "Iter",
    # metric_period = 50,
    od_wait = 10,
    use_best_model = TRUE,
    logging_level = "Silent"
) 
# Dummy preprocessing function that returns the data unchanged.
# A real function would contain imputation, feature engineering, etc.,
# with all statistics computed on the train folds and applied to the validation fold
preproc_fun_example <- function(data) return(data[])
catboost_fit(data = dt,
             target = "hp",
             split = splits[, split_1],
             preproc_fun = preproc_fun_example,
             params = catboost_grid[1, ],
             metrics = c("rmse", "mae"),
             return_val_preds = TRUE)
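The example above passes a single row of the hyperparameter grid as params. The short sketch below shows what CJ() produces and how one row can be turned into a named list of parameters. Whether catboost_fit performs this conversion internally is an assumption, not something the documentation states.

```r
library(data.table)

# CJ() builds the cross join of its arguments: one row per combination
grid <- CJ(
    iterations = 1000,
    learning_rate = 0.05,
    depth = c(8, 9),
    loss_function = "RMSE"
)

n_combinations <- nrow(grid)  # two depth values, so two rows

# A single row converts to a named list with one element per hyperparameter
params_list <- as.list(grid[1, ])
```

Looping over seq_len(nrow(grid)) and passing grid[i, ] each time is one straightforward way to evaluate every combination.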

statist-bhfz/grideR documentation built on Aug. 8, 2019, 7:08 p.m.