#' @title Assertions for mlr3 Objects
#' @description
#' Functions intended to be used in packages extending \pkg{mlr3}.
#' Most assertion functions ensure the right class attribute, and optionally additional properties.
#' Additionally, the following compound assertions are implemented:
#' * `assert_learnable(task, learner)`\cr
#'   ([Task], [Learner]) -> `NULL`\cr
#'   Checks if the learner is applicable to the task.
#'   This includes checks on the task type, the feature types, and the properties.
#'
#' If an assertion fails, an exception is raised.
#' Otherwise, the input object is returned invisibly.
#' @name mlr_assertions
#' @keywords internal

#' @export
#' @param b ([DataBackend]).
#' @rdname mlr_assertions
assert_backend = function(b, .var.name = vname(b)) {
  # Only the class is checked. The result of assert_class() is returned, which per checkmate
  # is `b` itself, invisibly.
  assert_class(b, "DataBackend", .var.name = .var.name)
}

#' @param task ([Task]).
#' @template param_feature_types
#' @param task_properties (`character()`)\cr
#'   Set of required task properties.
#' @rdname mlr_assertions
#' @export
assert_task = function(task, task_type = NULL, feature_types = NULL, task_properties = NULL, .var.name = vname(task)) {
  assert_class(task, "Task", .var.name = .var.name)

  # Exact comparison of the task type string (NULL skips the check).
  if (!is.null(task_type) && task$task_type != task_type) {
    stopf("Task '%s' must have type '%s'", task$id, task_type)
  }

  # Every feature type occurring in the task must be among the allowed `feature_types`.
  if (!is.null(feature_types)) {
    unsupported = setdiff(task$feature_types$type, feature_types)
    if (length(unsupported) > 0L) {
      stopf("Task '%s' has the following unsupported feature types: %s", task$id, str_collapse(unsupported))
    }
  }

  # Every requested property must be present on the task.
  if (!is.null(task_properties)) {
    missing_props = setdiff(task_properties, task$properties)
    if (length(missing_props) > 0L) {
      stopf("Task '%s' is missing the following properties: %s", task$id, str_collapse(missing_props))
    }
  }

  invisible(task)
}


#' @export
#' @param tasks (list of [Task]).
#' @rdname mlr_assertions
assert_tasks = function(tasks, task_type = NULL, feature_types = NULL, task_properties = NULL, .var.name = vname(tasks)) {
  # Runs assert_task() on each element with the same constraints; all elements share one
  # `.var.name`. The list of per-element results is returned invisibly.
  invisible(lapply(tasks, assert_task, task_type = task_type, feature_types = feature_types,
    task_properties = task_properties, .var.name = .var.name))
}

#' @export
#' @param learner ([Learner]).
#' @param task_type (`character(1)`).
#' @rdname mlr_assertions
assert_learner = function(learner, task = NULL, task_type = NULL, properties = character(), .var.name = vname(learner)) {
  assert_class(learner, "Learner", .var.name = .var.name)

  # An explicit `task_type` takes precedence over the type of `task`; if both are NULL the
  # type check below passes trivially.
  task_type = task_type %??% task$task_type
  # check on class(learner) does not work with GraphLearner and AutoTuner
  # check on learner$task_type does not work with TaskUnsupervised
  if (!test_matching_task_type(task_type, learner, "learner")) {
    stopf("Learner '%s' must have task type '%s'", learner$id, task_type)
  }

  # Every requested property must be declared by the learner.
  if (length(properties) > 0L) {
    missing_props = setdiff(properties, learner$properties)
    if (length(missing_props) > 0L) {
      stopf("Learner '%s' must have the properties: %s", learner$id, str_collapse(missing_props))
    }
  }

  invisible(learner)
}


test_matching_task_type = function(task_type, object, class) {
  # Decides whether `object` (a learner or measure) is usable with tasks of type `task_type`.
  # `class` selects the column of `mlr_reflections$task_types` holding the registered class
  # name ("learner" or "measure" at the call sites in this file).
  # Always returns a single TRUE/FALSE: if a lookup yields NA or nothing, the result is FALSE
  # rather than an error inside the caller's `if`.

  # No type requested, or the object's task type matches exactly.
  if (is.null(task_type) || isTRUE(object$task_type == task_type)) {
    return(TRUE)
  }

  # The object inherits from the class registered for the requested task type
  # (e.g. wrapper learners whose own task_type differs).
  cl_task_type = fget(mlr_reflections$task_types, task_type, class, "type")
  if (inherits(object, cl_task_type)) {
    return(TRUE)
  }

  # Both task types are registered with the same class.
  cl_object = fget(mlr_reflections$task_types, object$task_type, class, "type")
  isTRUE(cl_task_type == cl_object)
}

#' @export
#' @param learners (list of [Learner]).
#' @rdname mlr_assertions
assert_learners = function(learners, task = NULL, task_type = NULL, properties = character(), .var.name = vname(learners)) {
  # Runs assert_learner() on each element. `task_type` is forwarded (it was previously hard-coded
  # to NULL, which silently disabled an explicitly requested type check).
  invisible(lapply(learners, assert_learner, task = task, task_type = task_type,
    properties = properties, .var.name = .var.name))
}

# this does not check the validation task, as this is only possible once the validation set is known,
# which happens during worker(), so it cannot be checked before that
assert_task_learner = function(task, learner, cols = NULL) {
  # Hyperparameters still holding a TuneToken have no concrete value to train with.
  token_pars = learner$param_set$get_values(type = "only_token", check_required = FALSE)
  if (length(token_pars) > 0L) {
    stopf("%s cannot be trained with TuneToken present in hyperparameter: %s", learner$format(), str_collapse(names(token_pars)))
  }

  # check on class(learner) does not work with GraphLearner and AutoTuner
  # check on learner$task_type does not work with TaskUnsupervised
  if (!test_matching_task_type(task$task_type, learner, "learner")) {
    stopf("Type '%s' of %s does not match type '%s' of %s",
      task$task_type, task$format(), learner$task_type, learner$format())
  }

  # Every feature type of the task must be supported by the learner.
  unsupported = setdiff(task$feature_types$type, learner$feature_types)
  if (length(unsupported) > 0L) {
    stopf("%s has the following unsupported feature types: %s", task$format(), str_collapse(unsupported))
  }

  # Missing values are only allowed if the learner has the "missings" property.
  # `cols` limits the check to these columns (assert_predictable() passes the feature names);
  # presumably NULL means all columns -- confirm against Task$missings().
  if ("missings" %nin% learner$properties) {
    has_missing = task$missings(cols = cols) > 0L
    if (any(has_missing)) {
      stopf("Task '%s' has missing values in column(s) %s, but learner '%s' does not support this",
        task$id, str_collapse(names(has_missing)[has_missing], quote = "'"), learner$id)
    }
  }

  # Task properties registered as mandatory for this task type must be supported by the learner.
  mandatory = mlr_reflections$task_mandatory_properties[[task$task_type]]
  if (length(mandatory) > 0L) {
    unsupported_props = setdiff(intersect(task$properties, mandatory), learner$properties)
    if (length(unsupported_props) > 0L) {
      stopf("Task '%s' has property '%s', but learner '%s' does not support that",
        task$id, unsupported_props[1L], learner$id)
    }
  }

  # An internal validation task attached to the task conflicts with `validate` = "test" or a ratio.
  validate = get0("validate", learner)
  if (!is.null(task$internal_valid_task) && (is.numeric(validate) || identical(validate, "test"))) {
    stopf("Parameter 'validate' of Learner '%s' cannot be set to 'test' or a ratio when internal_valid_task is present, remove it first", learner$id)
  }

  invisible(NULL)
}

#' @export
#' @rdname mlr_assertions
assert_learnable = function(task, learner) {
  # Unsupervised tasks can never be used for training; reject them before the generic checks.
  if (task$task_type == "unsupervised") {
    stopf("%s cannot be trained with %s", learner$format(), task$format())
  }
  # All columns are checked here (no `cols` restriction).
  assert_task_learner(task, learner)
}

#' @export
#' @rdname mlr_assertions
assert_predictable = function(task, learner) {
  # For prediction, only the feature columns matter (e.g. a missing target is acceptable).
  assert_task_learner(task, learner, cols = task$feature_names)
}

#' @export
#' @param measure ([Measure]).
#' @rdname mlr_assertions
assert_measure = function(measure, task = NULL, learner = NULL, .var.name = vname(measure)) {
  assert_class(measure, "Measure", .var.name = .var.name)

  # Checks against the task: the task type (error) and the required task properties (warning).
  # A measure whose task_type is NA skips the type check.
  if (!is.null(task)) {
    if (!is_scalar_na(measure$task_type) && !test_matching_task_type(task$task_type, measure, "measure")) {
      stopf("Measure '%s' is not compatible with type '%s' of task '%s'",
        measure$id, task$task_type, task$id)
    }

    if (measure$check_prerequisites != "ignore") {
      missing_props = setdiff(measure$task_properties, task$properties)
      if (length(missing_props) > 0L) {
        warningf("Measure '%s' is missing properties %s of task '%s'",
          measure$id, str_collapse(missing_props, quote = "'"), task$id)
      }
    }
  }

  # Checks against the learner: the task type (error), then predict type and predict sets (warnings).
  if (!is.null(learner)) {
    # NOTE(review): unlike the task check above, this is an exact string comparison and does not
    # go through test_matching_task_type(); confirm that the difference is intended.
    if (!is_scalar_na(measure$task_type) && measure$task_type != learner$task_type) {
      stopf("Measure '%s' is not compatible with type '%s' of learner '%s'",
        measure$id, learner$task_type, learner$id)
    }

    if (!is_scalar_na(measure$predict_type) && measure$check_prerequisites != "ignore") {
      # presumably: the predict types implied by the learner's predict_type, per mlr_reflections
      predict_types = mlr_reflections$learner_predict_types[[learner$task_type]][[learner$predict_type]]
      if (measure$predict_type %nin% predict_types) {
        warningf("Measure '%s' is missing predict type '%s' of learner '%s'", measure$id, measure$predict_type, learner$id)
      }
    }

    if (measure$check_prerequisites != "ignore") {
      missing_sets = setdiff(measure$predict_sets, learner$predict_sets)
      if (length(missing_sets) > 0L) {
        warningf("Measure '%s' needs predict sets %s, but learner '%s' only predicted on sets %s",
          measure$id, str_collapse(missing_sets, quote = "'"), learner$id, str_collapse(learner$predict_sets, quote = "'"))
      }
    }
  }

  invisible(measure)
}


#' @export
#' @param measures (list of [Measure]).
#' @rdname mlr_assertions
assert_measures = function(measures, task = NULL, learner = NULL, .var.name = vname(measures)) {
  # Validate each measure individually, then require unique IDs across the collection.
  lapply(measures, assert_measure, task = task, learner = learner, .var.name = .var.name)
  if (anyDuplicated(ids(measures))) {
    stopf("Measures need to have unique IDs")
  }
  invisible(measures)
}

#' @export
#' @param resampling ([Resampling]).
#' @rdname mlr_assertions
assert_resampling = function(resampling, instantiated = NULL, .var.name = vname(resampling)) {
  assert_class(resampling, "Resampling", .var.name = .var.name)

  # `instantiated = NULL` skips the check; TRUE / FALSE require the resampling to be
  # instantiated / not instantiated.
  if (!is.null(instantiated)) {
    if (instantiated && !resampling$is_instantiated) {
      stopf("Resampling '%s' must be instantiated", resampling$id)
    }
    if (!instantiated && resampling$is_instantiated) {
      stopf("Resampling '%s' may not be instantiated", resampling$id)
    }
  }

  invisible(resampling)
}


#' @export
#' @param resamplings (list of [Resampling]).
#' @rdname mlr_assertions
assert_resamplings = function(resamplings, instantiated = NULL, .var.name = vname(resamplings)) {
  # Runs assert_resampling() on each element; the per-element results are returned invisibly.
  invisible(lapply(resamplings, assert_resampling, instantiated = instantiated, .var.name = .var.name))
}

#' @export
#' @param prediction ([Prediction]).
#' @rdname mlr_assertions
assert_prediction = function(prediction, .var.name = vname(prediction)) {
  # Class check only; returns whatever assert_class() returns (per checkmate: the input, invisibly).
  assert_class(prediction, "Prediction", .var.name = .var.name)
}

#' @export
#' @param rr ([ResampleResult]).
#' @rdname mlr_assertions
assert_resample_result = function(rr, .var.name = vname(rr)) {
  # Class check only; returns whatever assert_class() returns (per checkmate: the input, invisibly).
  assert_class(rr, "ResampleResult", .var.name = .var.name)
}

#' @export
#' @param bmr ([BenchmarkResult]).
#' @rdname mlr_assertions
assert_benchmark_result = function(bmr, .var.name = vname(bmr)) {
  # Class check only; returns whatever assert_class() returns (per checkmate: the input, invisibly).
  assert_class(bmr, "BenchmarkResult", .var.name = .var.name)
}

assert_set = function(x, empty = TRUE, .var.name = vname(x)) {
  # A "set" here is a character vector of unique, non-missing, non-empty strings.
  # With `empty = FALSE` it must additionally contain at least one element.
  assert_character(x, min.len = as.integer(!empty), any.missing = FALSE, min.chars = 1L, unique = TRUE, .var.name = .var.name)
}

assert_range = function(range, .var.name = vname(range)) {
  assert_numeric(range, len = 2L, any.missing = FALSE, .var.name = .var.name)

  # diff(range) is range[2L] - range[1L]: the lower bound must be strictly below the upper bound.
  # (The message previously claimed the opposite ordering.)
  if (diff(range) <= 0) {
    stopf("Invalid range specified. First value (%f) must be less than second value (%f)", range[1L], range[2L])
  }

  invisible(range)
}


#' @export
#' @template param_row_ids
#' @rdname mlr_assertions
assert_row_ids = function(row_ids, null.ok = FALSE, .var.name = vname(row_ids)) {
  # Row ids may arrive as whole-number doubles; per checkmate, `coerce = TRUE` makes the returned
  # value integer. `.var.name` is forwarded so error messages name the caller's variable
  # (it was previously accepted but ignored).
  assert_integerish(row_ids, coerce = TRUE, null.ok = null.ok, .var.name = .var.name)
}

assert_has_backend = function(task) {
  # A task whose backend is NULL cannot provide data anymore; the message tells the user how to
  # keep the backend.
  if (is.null(task$backend)) {
    stopf("The backend of Task '%s' has been removed. Set `store_backends` to `TRUE` during model fitting to conserve it.", task$id)
  }
}

# assertion to ensure a helpful error message
assert_prediction_count = function(actual, expected, type) {
  # Too few and too many predictions get separate messages that state the size of the mismatch.
  # `type` names the kind of prediction and is interpolated into the message.
  if (actual < expected) {
    stopf("Predicted %s not complete, %s for %i observations is missing",
      type, type, expected - actual)
  }
  if (actual > expected) {
    stopf("Predicted %s contains %i additional predictions without matching rows",
      type, actual - expected)
  }
}

assert_row_sums = function(prob) {
  # Validates a 2-D probability table (assumed matrix-like; one row per observation):
  # a row is either entirely NA (allowed) or complete with a sum within 0.001 of 1.
  # The first offending row (in row order) is reported, as in the former per-row loop.
  n_missing = rowSums(is.na(prob))
  partly_missing = n_missing > 0L & n_missing < ncol(prob)
  sums = rowSums(prob)
  # Rows with any NA have an NA sum; `n_missing == 0L &` turns those into FALSE.
  bad_sum = n_missing == 0L & abs(sums - 1) > 0.001
  offending = which(partly_missing | bad_sum)

  if (length(offending) > 0L) {
    i = offending[1L]
    if (partly_missing[i]) {
      stopf("Probabilities for observation %i are partly missing", i)
    }
    stopf("Probabilities for observation %i do sum up to %f != 1", i, sums[i])
  }

  invisible(NULL)
}

assert_param_values = function(x, n_learners = NULL, .var.name = vname(x)) {
  # Outer level: a list, of length `n_learners` when that is given.
  assert_list(x, len = n_learners, .var.name = .var.name)

  # Second level: each element is a list. Innermost level: each of its elements is NULL or a
  # list with unique names (checkmate `names = "unique"`).
  ok = every(x, function(el) {
    test_list(el) && every(el, test_list, names = "unique", null.ok = TRUE)
  })

  if (!ok) {
    stopf("'%s' must be a three-time nested list and the most inner list must be named", .var.name)
  }

  invisible(x)
}
