proj_grid | R Documentation |
A project grid consists of all combinations of tasks, learners, resampling types, and resampling iterations, to be computed in parallel. This function creates a project directory with files to describe the grid.
proj_grid(
proj_dir, tasks, learners, resamplings,
order_jobs = NULL, score_args = NULL,
save_learner = FALSE, save_pred = FALSE)
proj_dir |
Path to directory to create. |
tasks |
List of Tasks, or a single Task. |
learners |
List of Learners, or a single Learner. |
resamplings |
List of Resamplings, or a single Resampling. |
order_jobs |
Function which takes split table as input, and
returns integer vector of row numbers of the split table to write to
|
score_args |
Passed to |
save_learner |
Function to process Learner, after
training/prediction, but before saving result to disk.
For interpreting complex models, you should write a function that
returns only the parts of the model that you need (and discards the
other parts which would take up disk space for no reason).
Default FALSE. |
save_pred |
Function to process Prediction before saving to disk.
Default FALSE. |
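The function-valued arguments above (order_jobs and save_learner) can be illustrated with a minimal sketch. The helper names largest_first and keep_rpart_frame are hypothetical, and the column name n.train.groups is only an assumption about the split table. The toy inputs stand in for the real split table and for a trained Learner, so the code runs without mlr3.

```r
## Hypothetical order_jobs helper: put the largest training sets first.
## The column name "n.train.groups" is an assumption about the split table;
## if that column is absent, the original row order is kept.
largest_first <- function(split.table, size.col="n.train.groups"){
  if(!size.col %in% names(split.table)){
    return(seq_len(nrow(split.table)))
  }
  order(split.table[[size.col]], decreasing=TRUE)
}

## Hypothetical save_learner helper: keep only a few columns of the
## rpart split table, instead of the whole fitted model.
## Assumes learner$model is an rpart fit (which has a "frame" data frame);
## any other model is returned unchanged.
keep_rpart_frame <- function(learner){
  fit <- learner$model
  if(is.null(fit$frame)) return(fit)
  fit$frame[, c("var", "n", "yval")]
}

## Toy stand-in for the split table.
toy.splits <- data.frame(
  iteration=1:4,
  n.train.groups=c(10L, 40L, 20L, 40L))
job.order <- largest_first(toy.splits)

## Toy stand-in for a trained rpart Learner.
toy.learner <- list(model=list(frame=data.frame(
  var=c("x", "<leaf>", "<leaf>"),
  n=c(80L, 40L, 40L),
  dev=c(90.5, 20.1, 30.2),
  yval=c(1.3, 0.4, 2.2))))
small.model <- keep_rpart_frame(toy.learner)
```

Under these assumptions, the helpers would be passed as order_jobs=largest_first and save_learner=keep_rpart_frame in the call to proj_grid.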
This is Step 1 out of the typical 3 step pipeline (init grid, submit, read results). It creates a grid_jobs.csv table which has a column status; each row is initialized to "not started" or "done", depending on whether the corresponding result RDS file exists already.
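As a rough illustration of how the status column might be used, the sketch below counts jobs by status and selects the ones still to run. The toy directory and its three-row grid_jobs.csv are made up for the example (a real file written by proj_grid has more columns); only the file name, the status column, and the two status values come from the description above.

```r
## Toy stand-in for a project directory created by proj_grid.
toy.proj.dir <- tempfile()
dir.create(toy.proj.dir)
write.csv(
  data.frame(
    iteration=1:3,
    status=c("done", "not started", "not started")),
  file.path(toy.proj.dir, "grid_jobs.csv"),
  row.names=FALSE)

## Read the job table and summarize progress.
jobs <- read.csv(file.path(toy.proj.dir, "grid_jobs.csv"))
status.counts <- table(jobs$status)

## Rows that still need to be computed.
todo <- jobs[jobs$status == "not started", ]
```

Because rows whose result RDS file already exists are initialized to "done", a table like todo shows what remains after an interrupted run.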
Data table of splits to be processed (same as table saved to grid_jobs.rds).
Toby Dylan Hocking
N <- 80
library(data.table)
set.seed(1)
reg.dt <- data.table(
  x=runif(N, -2, 2),
  person=factor(rep(c("Alice","Bob"), each=0.5*N)))
reg.pattern.list <- list(
  easy=function(x, person)x^2,
  impossible=function(x, person)(x^2)*(-1)^as.integer(person))
SOAK <- mlr3resampling::ResamplingSameOtherSizesCV$new()
reg.task.list <- list()
for(pattern in names(reg.pattern.list)){
  f <- reg.pattern.list[[pattern]]
  task.dt <- data.table(reg.dt)[
  , y := f(x,person)+rnorm(N, sd=0.5)
  ][]
  task.obj <- mlr3::TaskRegr$new(
    pattern, task.dt, target="y")
  task.obj$col_roles$feature <- "x"
  task.obj$col_roles$stratum <- "person"
  task.obj$col_roles$subset <- "person"
  reg.task.list[[pattern]] <- task.obj
}
reg.learner.list <- list(
  featureless=mlr3::LearnerRegrFeatureless$new())
if(requireNamespace("rpart")){
  reg.learner.list$rpart <- mlr3::LearnerRegrRpart$new()
}
pkg.proj.dir <- tempfile()
mlr3resampling::proj_grid(
  pkg.proj.dir,
  reg.task.list,
  reg.learner.list,
  SOAK,
  score_args=mlr3::msrs(c("regr.rmse", "regr.mae")))
mlr3resampling::proj_compute(pkg.proj.dir)
fread(file.path(pkg.proj.dir, "grid_jobs.csv"))