inst/paramtest/test_paramtest_classif.xgboost.R

library(mlr3learners)
library(magrittr, exclude = c("equals", "is_less_than", "not"))
library(rvest)

# Reference list of XGBoost parameter names, scraped from the online
# "parameter" documentation page. The page is fetched when this file is
# sourced, so running it requires network access.
# Every <li>/<p> text containing "default=" is assumed to begin with the
# parameter name: its first space-separated token is kept and commas stripped.
add_params_xgboost = read_html("https://xgboost.readthedocs.io/en/latest/parameter.html") %>%
  html_elements(c("li", "p")) %>%
  html_text2() %>%
  grep("default=", ., value = TRUE) %>%
  strsplit(., split = " ") %>%
  mlr3misc::map_chr(., function(x) x[1]) %>%
  gsub(",", replacement = "", .) %>%
  ## these are documented on the same line as colsample_bytree, so the
  ## scraper above cannot pick them up
  append(values = c("colsample_bylevel", "colsample_bynode")) %>%
  # documented parameters whose text does not match the "default=" pattern
  append(values = c("interaction_constraints", "monotone_constraints", "base_score")) %>%
  # only documented in the R help page, neither in a signature nor on the website
  append(values = c("lambda_bias"))

test_that("classif.xgboost", {
  # Builds the failure report from a run_paramtest() result exposing the
  # `$missing` and `$extra` name vectors.
  # NOTE(review): on success run_paramtest() presumably returns a bare TRUE, for
  # which `res$missing` would error; like the inline form, this relies on
  # expect_true() forcing `info` only when the expectation fails -- confirm.
  describe_mismatch = function(res) {
    missing_lines = paste0("- ", res$missing, "\n", collapse = "")
    extra_lines = paste0("- ", res$extra, "\n", collapse = "")
    paste0(
      "\nMissing parameters in mlr3 param set:\n",
      missing_lines,
      "\nOutdated param or param defined in additional control function not included in list of function definitions:\n",
      extra_lines
    )
  }

  xgb_learner = lrn("classif.xgboost", nrounds = 1)

  # Reference argument names: the two R training entry points plus the
  # parameter names scraped from the online documentation.
  arg_sources = list(xgboost::xgb.train, xgboost::xgboost, add_params_xgboost)

  # Names intentionally not mirrored in the mlr3 param set.
  ignored = c(
    "x",                  # supplied by mlr3
    "params",             # supplied by mlr3
    "data",               # supplied by mlr3
    "obj",                # covered by the type parameter
    "verbosity",          # supplied by mlr3
    "seed",               # not offered by the R package
    "train",              # supplied by mlr3
    "task",               # supplied by mlr3
    "model_in",           # supplied by mlr3
    "model_out",          # supplied by mlr3
    "model_dir",          # supplied by mlr3
    "dump_format",        # command-line only, not in the R package
    "name_dump",          # command-line only, not in the R package
    "name_pred",          # command-line only, not in the R package
    "pred_margin",        # command-line only, not in the R package
    "eval_metric",        # supplied by mlr3
    "label",              # supplied by mlr3
    "weight",             # supplied by mlr3
    "nthread",            # supplied by mlr3
    "early_stopping_set"  # mlr3-specific extra parameter
  )

  # The variable name is kept as `ParamTest` because expect_true() uses the
  # deparsed expression as the label in its failure message.
  ParamTest = run_paramtest(xgb_learner, arg_sources, ignored, tag = "train")
  expect_true(ParamTest, info = describe_mismatch(ParamTest))
})

test_that("predict classif.xgboost", {
  learner = lrn("classif.xgboost")
  # Predict-time arguments are compared against the Booster predict method,
  # accessed via `:::` (presumably because it is not exported).
  fun = xgboost:::predict.xgb.Booster
  # Names intentionally not mirrored among the learner's predict parameters.
  exclude = c(
    "object", # handled by mlr3
    "newdata", # handled by mlr3
    "objective" # defined in xgboost::xgboost and already in param set
  )

  # run_paramtest() is defined elsewhere; on failure its result is expected to
  # expose `$missing` and `$extra`, which the `info` report lists.
  ParamTest = run_paramtest(learner, fun, exclude, tag = "predict")
  expect_true(ParamTest, info = paste0(
    "\nMissing parameters in mlr3 param set:\n",
    paste0("- ", ParamTest$missing, "\n", collapse = ""),
    "\nOutdated param or param defined in additional control function not included in list of function definitions:\n",
    paste0("- ", ParamTest$extra, "\n", collapse = "")
  ))
})

Try the mlr3learners package in your browser

Any scripts or data that you put into this service are public.

mlr3learners documentation built on Nov. 21, 2023, 5:07 p.m.