# Global knitr chunk defaults for this document.
knitr::opts_chunk$set(
  # Scale rendered figures to the full width of the output container
  out.width = "100%",
  # Path prefix for generated figure files
  fig.path = "man/figures/README-",
  # Prefix printed output lines with "#>"
  comment = "#>",
  # Merge each chunk's source and its output into a single block
  collapse = TRUE
)

grideR

The goal of grideR is to provide general infrastructure, based on data.table, for grid search in model hyperparameter space. Key feature: all preprocessing steps are performed within resamples. This package aims to be as easy to extend as possible.

Installation

You can install the development version of grideR from GitHub with:

# Install grideR straight from its GitHub repository (requires devtools)
devtools::install_github(repo = "statist-bhfz/grideR")

Example

# data.table is attached explicitly: as.data.table(), CJ() and rbindlist()
# are called directly in this session, not only on the workers.
library(data.table)
library(resampleR)
library(grideR)
library(parallel)

# Input data
dt <- as.data.table(mtcars)

# data.table with resamples; "hp" is the target column used throughout
splits <- resampleR::cv_base(dt, "hp")

# data.table with tunable model hyperparameters.
# CJ() builds the cross join of all supplied values, so only max_depth and
# min_child_weight actually vary here (2 x 2 = 4 combinations).
xgb_grid <- CJ(
    max_depth = c(6, 8),
    eta = 0.025,
    colsample_bytree = 0.9,
    subsample = 0.8,
    gamma = 0,
    min_child_weight = c(3, 5),
    alpha = 0,
    lambda = 1
)

# Non-tunable parameters for xgboost
xgb_args <- list(
    nrounds = 500,
    early_stopping_rounds = 10,
    booster = "gbtree",
    eval_metric = "rmse",
    # "reg:linear" is a deprecated alias of the squared-error objective
    objective = "reg:squarederror",
    verbose = 0
)

# Placeholder preprocessing function: hands the data back as is.
# A real function would contain imputation, feature engineering etc.,
# with all statistics computed on the train folds and applied to the
# validation fold.
preproc_fun_example <- function(data) {
    data[]
}

# Run the grid search in parallel, one task per element of `splits`.
cl <- makePSOCKcluster(2L)

# All work that uses the cluster sits inside tryCatch(..., finally = ) so the
# worker processes are shut down even if exporting, package loading or the
# grid search itself fails.
res <- tryCatch(
    {
        # Workers start with empty sessions: ship them the objects that the
        # worker function refers to by name.
        clusterExport(
            cl,
            c("dt", "splits", "xgb_grid", "xgb_args", "preproc_fun_example")
        )

        # Attach the required packages on every worker
        clusterEvalQ(cl, {
            suppressMessages(library(resampleR))
            suppressMessages(library(grideR))
            suppressMessages(library(checkmate))
            suppressMessages(library(data.table))
            suppressMessages(library(xgboost))
        })

        # Evaluate the whole hyperparameter grid on each resample
        clusterApply(
            cl,
            splits,
            function(split) {
                across_grid(
                    data = dt,
                    target = "hp",
                    split = split,
                    fit_fun = xgb_fit,
                    preproc_fun = preproc_fun_example,
                    grid = xgb_grid,
                    args = xgb_args,
                    metrics = c("rmse", "mae")
                )
            }
        )
    },
    finally = stopCluster(cl)
)

# Stack the per-resample results; the "split" column identifies the resample
res <- rbindlist(res, idcol = "split")
res


statist-bhfz/grideR documentation built on Aug. 8, 2019, 7:08 p.m.