proj_test: Test a project with smaller data and fewer resampling...

View source: R/proj.R

proj_test R Documentation

Test a project with smaller data and fewer resampling iterations

Description

Runs proj_grid to create a new project in the test sub-directory, with a smaller number of samples in each task, and with only one iteration per Resampling. Runs proj_compute_all on this new test project, and then reads any CSV result files.

Usage

proj_test(proj_dir, min_samples_per_stratum = 10,
  edit_learner = edit_learner_default, max_jobs = Inf)

Arguments

proj_dir

Project directory created by proj_grid.

min_samples_per_stratum

Minimum number of samples to include in the smallest stratum. Other strata will be down-sampled proportionally.
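As an illustration of the proportional down-sampling described for min_samples_per_stratum, here is a minimal base R sketch. The helper name downsample_counts is hypothetical and this is not the package's implementation; it only assumes that the smallest stratum is reduced to the requested minimum and that the other strata are scaled by the same factor.

```r
# Hypothetical helper (not part of mlr3resampling): compute per-stratum
# sample sizes so that the smallest stratum keeps min_samples_per_stratum
# rows and every other stratum is scaled down by the same factor.
downsample_counts <- function(stratum, min_samples_per_stratum = 10) {
  tab <- table(stratum)
  counts <- setNames(as.integer(tab), names(tab))
  # Multiply before dividing so whole-number results stay exact.
  scaled <- as.integer(round(counts * min_samples_per_stratum / min(counts)))
  # Never request more rows than a stratum actually has.
  pmin(counts, scaled)
}

# Same imbalance as the Examples section: 10% Alice, 90% Bob.
person <- factor(rep(c("Alice", "Bob"), c(800, 7200)))
target <- downsample_counts(person)

# Draw the requested number of row indices from each stratum.
set.seed(1)
keep <- unlist(lapply(names(target), function(s) {
  sample(which(person == s), target[[s]])
}))
table(person[keep])
```

With 800 Alice rows and 7200 Bob rows, the sketch keeps 10 Alice rows and 90 Bob rows, so the 1:9 ratio between strata is preserved.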

edit_learner

Function that takes a learner object and modifies it so that it takes less time during testing. The default calls the learner's edit_learner method if it exists; for an AutoTuner based on LearnerTorch, it reduces max epochs and patience to 2.
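A custom edit_learner function might look like the sketch below. The function name edit_learner_shallow and the choice of limiting maxdepth are illustrative assumptions, not package defaults. The sketch also assumes that proj_test relies on the learner being changed in place, which works because mlr3 learners are R6 objects; the learner is returned invisibly as well, in case the return value is used.

```r
# Hypothetical custom edit_learner (name and settings are assumptions):
# make rpart learners cheaper to fit by limiting the tree depth.
edit_learner_shallow <- function(L) {
  if (inherits(L, "LearnerRegrRpart")) {
    # L has reference semantics, so this assignment modifies it in place.
    L$param_set$values$maxdepth <- 2L
  }
  # Other learners are left unchanged.
  invisible(L)
}

# Usage, with proj_dir a project directory created by proj_grid:
# mlr3resampling::proj_test(proj_dir, edit_learner = edit_learner_shallow)
```

Note that a function passed this way replaces edit_learner_default, so any speed-ups the default would have applied to other learner types have to be handled in the custom function too.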

max_jobs

Numeric, maximum number of jobs to test (default Inf).

Value

Same value as proj_fread on the test project (a list of data tables).

Author(s)

Toby Dylan Hocking

Examples


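library(data.table)
# Simulate N rows with an imbalanced person column (10% Alice, 90% Bob).
N <- 8000
set.seed(1)
reg.dt <- data.table(
  x=runif(N, -2, 2),
  person=factor(rep(c("Alice","Bob"), c(0.1,0.9)*N)))
# Two target patterns: easy depends only on x, whereas in impossible
# the sign of x^2 depends on person (negative for Alice, positive for Bob).
reg.pattern.list <- list(
  easy=function(x, person)x^2,
  impossible=function(x, person)(x^2)*(-1)^as.integer(person))
# 2-fold cross-validation.
kfold <- mlr3::ResamplingCV$new()
kfold$param_set$values$folds <- 2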
reg.task.list <- list()
for(pattern in names(reg.pattern.list)){
  f <- reg.pattern.list[[pattern]]
  task.dt <- data.table(reg.dt)[
  , y := f(x,person)+rnorm(N, sd=0.5)
  ][]
  task.obj <- mlr3::TaskRegr$new(
    pattern, task.dt, target="y")
  task.obj$col_roles$feature <- "x"
  task.obj$col_roles$stratum <- "person"
  task.obj$col_roles$subset <- "person"
  reg.task.list[[pattern]] <- task.obj
}
reg.learner.list <- list(
  featureless=mlr3::LearnerRegrFeatureless$new())
if(requireNamespace("rpart")){
  reg.learner.list$rpart <- mlr3::LearnerRegrRpart$new()
}
# Create the project in a temporary directory, then run the quick test.
pkg.proj.dir <- tempfile()
mlr3resampling::proj_grid(
  pkg.proj.dir,
  reg.task.list,
  reg.learner.list,
  kfold,
  save_learner=function(L){
    # For rpart learners, return only the tree frame (L$model$frame).
    if(inherits(L, "LearnerRegrRpart")){
      list(rpart=L$model$frame)
    }
  },
  score_args=mlr3::msrs(c("regr.rmse", "regr.mae")))
mlr3resampling::proj_test(pkg.proj.dir)


mlr3resampling documentation built on Nov. 21, 2025, 1:07 a.m.