proj_grid: Initialize a new project grid table


proj_grid    R Documentation

Initialize a new project grid table

Description

A project grid consists of all combinations of tasks, learners, resampling types, and resampling iterations, to be computed in parallel. This function creates a project directory with files to describe the grid.

Usage

proj_grid(
  proj_dir, tasks, learners, resamplings,
  order_jobs = NULL, score_args = NULL,
  save_learner = FALSE, save_pred = FALSE)

Arguments

proj_dir

Path to directory to create.

tasks

List of Tasks, or a single Task.

learners

List of Learners, or a single Learner.

resamplings

List of Resamplings, or a single Resampling.

order_jobs

Function that takes the split table as input and returns an integer vector of row numbers of the split table to write to grid_jobs.csv. Worker processes read grid_jobs.csv to determine what work to do next (smaller numbers have higher priority). The default, NULL, keeps the default order.

score_args

Passed to pred$score().

save_learner

Function to process the Learner after training and prediction, but before the result is saved to disk. To interpret complex models, write a function that returns only the parts of the model that you need and discards the rest, which would otherwise take up disk space for no reason. The default, FALSE, means the Learner is not kept (always returns NULL). TRUE means it is kept without any special processing.

save_pred

Function to process the Prediction before it is saved to disk. The default, FALSE, means the Prediction is not kept (always returns NULL). TRUE means it is kept without any special processing.
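
The function-valued arguments (order_jobs, save_learner, save_pred) can be sketched as below. This is a minimal illustration under stated assumptions, not code from the package: the column name test.fold, the helper names by_fold, keep_importance and keep_response, and the mock objects are all hypothetical. Check the split table returned by proj_grid for the real column names.

```r
# Hypothetical stand-in for the split table; real column names may differ.
split_table <- data.frame(
  task.id = c("easy", "easy", "impossible", "impossible"),
  test.fold = c(2L, 1L, 2L, 1L))

# order_jobs: takes the split table, returns an integer vector of row numbers.
# Assumption: rows earlier in the returned vector are written first.
by_fold <- function(split_dt) order(split_dt$test.fold)
by_fold(split_table)

# save_learner: keep only the part of the fitted model that is needed.
# mlr3 Learners store the fitted object in $model; an rpart fit has a
# small $variable.importance vector, much smaller than the whole model.
keep_importance <- function(learner) learner$model$variable.importance

# Mock learner (a plain list) so the sketch runs without mlr3 installed.
mock_learner <- list(model = list(
  variable.importance = c(x = 3.2),
  frame = data.frame(var = rep("x", 100))))
keep_importance(mock_learner)

# save_pred: keep only row ids and predicted values, dropping the rest.
keep_response <- function(pred) {
  data.frame(row_ids = pred$row_ids, response = pred$response)
}

# Mock prediction (a plain list) with the fields used above.
mock_pred <- list(
  row_ids = 1:3,
  truth = c(0.1, 0.5, 0.9),
  response = c(0.2, 0.4, 1.0))
keep_response(mock_pred)
```

Functions like these would then be supplied through the order_jobs, save_learner, and save_pred arguments of proj_grid.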

Details

This is Step 1 of the typical three-step pipeline (initialize grid, submit, read results). It creates a grid_jobs.csv table with a status column; each row is initialized to "not started" or "done", depending on whether the corresponding result RDS file already exists.
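
The status rule described above can be illustrated with a minimal sketch. Only the "not started" / "done" rule comes from the description; the file layout (a results subdirectory holding one file per grid row, named by row number) is purely hypothetical and is not taken from the package.

```r
# Hypothetical layout: one result file per grid row, named <row>.rds.
proj_dir <- tempfile()
result_dir <- file.path(proj_dir, "results")  # assumed directory name
dir.create(result_dir, recursive = TRUE)

# Pretend the job for row 2 has already finished.
saveRDS(list(), file.path(result_dir, "2.rds"))

# A row is "done" if its result file exists, otherwise "not started".
jobs <- data.frame(row = 1:3)
result_file <- file.path(result_dir, paste0(jobs$row, ".rds"))
jobs$status <- ifelse(file.exists(result_file), "done", "not started")
jobs
```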

Value

Data table of splits to be processed (same as table saved to grid_jobs.rds).

Author(s)

Toby Dylan Hocking

Examples


N <- 80
library(data.table)
set.seed(1)
reg.dt <- data.table(
  x=runif(N, -2, 2),
  person=factor(rep(c("Alice","Bob"), each=0.5*N)))
reg.pattern.list <- list(
  easy=function(x, person)x^2,
  impossible=function(x, person)(x^2)*(-1)^as.integer(person))
SOAK <- mlr3resampling::ResamplingSameOtherSizesCV$new()
reg.task.list <- list()
for(pattern in names(reg.pattern.list)){
  f <- reg.pattern.list[[pattern]]
  # Copy reg.dt and add a noisy target column y.
  task.dt <- data.table(reg.dt)[
    , y := f(x, person) + rnorm(N, sd=0.5)
  ][]
  task.obj <- mlr3::TaskRegr$new(
    pattern, task.dt, target="y")
  task.obj$col_roles$feature <- "x"
  task.obj$col_roles$stratum <- "person"
  task.obj$col_roles$subset <- "person"
  reg.task.list[[pattern]] <- task.obj
}
reg.learner.list <- list(
  featureless=mlr3::LearnerRegrFeatureless$new())
if(requireNamespace("rpart")){
  reg.learner.list$rpart <- mlr3::LearnerRegrRpart$new()
}

pkg.proj.dir <- tempfile()
mlr3resampling::proj_grid(
  pkg.proj.dir,
  reg.task.list,
  reg.learner.list,
  SOAK,
  score_args=mlr3::msrs(c("regr.rmse", "regr.mae")))
mlr3resampling::proj_compute(pkg.proj.dir)
fread(file.path(pkg.proj.dir, "grid_jobs.csv"))


mlr3resampling documentation built on June 23, 2025, 5:08 p.m.