R/zzz.R

Defines functions .onUnload .onLoad register_mlr3graphs register_mlr3pipelines register_mlr3

#' @rawNamespace import(data.table, except = transpose)
#' @import checkmate
#' @import mlr3
#' @import mlr3misc
#' @import paradox
#' @import ordinal
#' @import mlr3pipelines
#' @import nloptr
#' @importFrom R6 R6Class
"_PACKAGE"

# Register the "ordinal" task type plus this package's tasks, learners and
# measures with mlr3. Works purely by side effect on objects fetched from the
# mlr3 namespace; called from .onLoad() and from the mlr3 onLoad hook.
register_mlr3 = function() {
  # let mlr3 know about ordinal
  reflections = utils::getFromNamespace("mlr_reflections", ns = "mlr3")

  # Drop an "ordinal" row left over from a previous call (e.g. this package
  # was reloaded while mlr3 stayed loaded) so the row is never duplicated.
  # `fill = TRUE` tolerates mlr3 versions whose task_types table has more
  # columns than the ones listed here (missing entries become NA).
  task_types = reflections$task_types
  keep = task_types$type != "ordinal"
  reflections$task_types = setkeyv(rbind(task_types[keep], rowwise_table(
    ~type,     ~package,      ~task,         ~learner,         ~prediction,         ~measure,
    "ordinal", "mlr3ordinal", "TaskOrdinal", "LearnerOrdinal", "PredictionOrdinal", "MeasureOrdinal"
  ), fill = TRUE), "type")

  reflections$task_col_roles$ordinal = c("feature", "target", "order", "stratum", "group", "weight")
  reflections$task_properties$ordinal = "weights"
  # Previously this was first copied from classif and then overwritten here;
  # only this (final) value was ever in effect.
  reflections$learner_properties$ordinal = c("missings", "weights", "parallel", "importance") # FIXME for ordinal
  reflections$learner_predict_types$ordinal = list(response = "response", prob = c("response", "prob"))
  # allow regression tasks to carry a "target_ordinal" column role
  # (union() keeps this idempotent across repeated calls)
  reflections$task_col_roles$regr = union(reflections$task_col_roles$regr, "target_ordinal")
  reflections$measure_properties$ordinal = reflections$measure_properties$regr
  reflections$default_measures$ordinal = "ordinal.ce"

  # tasks
  tasks = utils::getFromNamespace("mlr_tasks", ns = "mlr3")
  tasks$add("winerating", load_task_winerating)

  # learners
  learners = utils::getFromNamespace("mlr_learners", ns = "mlr3")
  learners$add("ordinal.clm", LearnerOrdinalClm)

  # measures
  measures = utils::getFromNamespace("mlr_measures", ns = "mlr3")
  measures$add("ordinal.ce", MeasureOrdinalCE)
  measures$add("ordinal.acc", MeasureOrdinalACC)
  measures$add("ordinal.mae", MeasureOrdinalMAE)
}

# Add this package's PipeOps to the mlr3pipelines PipeOp dictionary
# ("mlr_pipeops"). Side effect only; called from .onLoad() and from the
# mlr3pipelines onLoad hook.
register_mlr3pipelines = function() {
  pipeop_dict = utils::getFromNamespace("mlr_pipeops", ns = "mlr3pipelines")

  pipeop_dict$add("ordinalregr", PipeOpOrdinalRegr)
  pipeop_dict$add("ordinalclassif", PipeOpOrdinalClassif)
  pipeop_dict$add("update_target", PipeOpUpdateTarget)
}

# Add the "ordinal" pipeline constructor to the mlr3pipelines graph
# dictionary ("mlr_graphs"). Side effect only.
register_mlr3graphs = function() {
  graph_dict = utils::getFromNamespace("mlr_graphs", ns = "mlr3pipelines")
  graph_dict$add("ordinal", pipeline_ordinal)
}

.onLoad = function(libname, pkgname) { # nocov start
  # Registers everything right away, and appends hooks so registration is
  # repeated whenever mlr3 / mlr3pipelines fire their onLoad event.
  # The hook closures are deliberately created inside this function: their
  # enclosing environment then holds `pkgname`, which .onUnload() inspects
  # to find and remove them.
  on_mlr3_load = function(...) register_mlr3()
  on_pipelines_load = function(...) {
    register_mlr3pipelines()
    register_mlr3graphs()
  }

  register_mlr3()
  setHook(packageEvent("mlr3", "onLoad"), on_mlr3_load, action = "append")

  on_pipelines_load()
  setHook(packageEvent("mlr3pipelines", "onLoad"), on_pipelines_load, action = "append")
} # nocov end

.onUnload = function(libpath) { # nocov start
  # Remove the onLoad hooks this package appended for mlr3 and mlr3pipelines.
  # A hook is attributed to a package through the `pkgname` binding in its
  # enclosing environment (the .onLoad() frame that created it). Hooks without
  # such a binding, e.g. defined at top level by a user, are kept rather than
  # making vapply() fail on a NULL lookup.
  remove_own_hooks = function(event, owner = "mlr3ordinal") {
    hooks = getHook(event)
    hook_pkg = vapply(hooks, function(hook) {
      env = environment(hook)
      if (is.environment(env) && exists("pkgname", envir = env, inherits = FALSE)) {
        as.character(get("pkgname", envir = env, inherits = FALSE))[1L]
      } else {
        NA_character_
      }
    }, NA_character_)
    # %in% never returns NA, so hooks with an unknown owner are retained
    setHook(event, hooks[!(hook_pkg %in% owner)], action = "replace")
  }

  remove_own_hooks(packageEvent("mlr3", "onLoad"))
  remove_own_hooks(packageEvent("mlr3pipelines", "onLoad"))
} # nocov end
mlr-org/mlr3ordinal documentation built on Jan. 10, 2023, 10:04 a.m.