cross_train_predict: Train and predict out-of-fold so that predictions can be used...


Description

Train and predict out-of-fold so that predictions can be used in downstream models. Parallelise by calling registerDoMC(n_cores) before the cross_train_predict call.
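As a rough sketch of the parallel setup described above: register a multicore backend before the call. The value of n_cores here is an arbitrary illustration, and doMC relies on process forking, so this is assumed to run on a Unix-like system.

```r
library(foreach)
library(doMC)

# Arbitrary core count chosen for illustration; adjust to the machine.
n_cores <- 2
registerDoMC(n_cores)

# Confirm how many workers foreach will use for %dopar% loops.
getDoParWorkers()
```

Any subsequent cross_train_predict call in the same session should then pick up the registered backend.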

Usage

cross_train_predict(data, fold_id, train_model, predict_model = predict)

Arguments

data

any training data format that can be indexed by [i, ]

fold_id

integer vector of fold assignments, one per row of data, with fold numbers starting at 1

train_model

a function with signature data => model

predict_model

a function with signature (model, data) => predictions. Defaults to predict.

Value

Out-of-fold model predictions.
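To make the idea of out-of-fold prediction concrete, here is a minimal sequential sketch in base R. It is not the hacktoolkit implementation: oof_predict_sketch, toy and toy_folds are hypothetical names, and the sketch assumes predictions are a numeric vector with one value per row.

```r
# Hypothetical illustration of out-of-fold prediction; not the package source.
oof_predict_sketch <- function(data, fold_id, train_model,
                               predict_model = predict) {
  preds <- rep(NA_real_, nrow(data))
  for (k in sort(unique(fold_id))) {
    held_out <- fold_id == k
    # Fit on every row outside fold k.
    model <- train_model(data[!held_out, ])
    # Score only the rows in fold k, which the model has not seen.
    preds[held_out] <- predict_model(model, data[held_out, ])
  }
  preds
}

# Toy data: y is roughly 2 * x.
set.seed(1)
toy <- data.frame(x = rnorm(100))
toy$y <- 2 * toy$x + rnorm(100, sd = 0.1)
toy_folds <- rep(1:5, length.out = nrow(toy))

toy_preds <- oof_predict_sketch(
  toy, toy_folds,
  train_model = function(d) lm(y ~ x, data = d)
)
```

Because each row is scored by a model that never saw it during fitting, the resulting vector can be fed to a downstream model without leaking the target.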

Examples

library("xgboost")
library("ggplot2")
library("foreach")
library("doMC")

data(diamonds)

set.seed(1234)
train <- diamonds[order(runif(nrow(diamonds))), ][1:10000, ]

train_model <- function(data) {
  # Build the design matrix and labels from the fold-specific training rows only
  train_mm <- model.matrix(price ~ ., data)
  train_dm <- xgb.DMatrix(train_mm, label = data$price)

  xgb.train(params = list(max_depth = 2,
                          subsample = 0.5,
                          eta = 0.1,
                          colsample_bytree = 0.5,
                          # "reg:linear" is deprecated in newer xgboost releases
                          objective = "reg:squarederror"),
            nrounds = 100,
            data = train_dm)
}

predict_model <- function(model, data) {
  score_mm <- model.matrix(price ~ ., data)
  score_dm <- xgb.DMatrix(score_mm)

  predict(model, newdata = score_dm)
}

fold_id <- sample(1:10, nrow(train), TRUE)
ct_preds <- cross_train_predict(train, fold_id, train_model, predict_model)

DexGroves/hacktoolkit documentation built on May 6, 2019, 2:12 p.m.