R/PredictionData.R

#' @title Convert to PredictionData
#'
#' @name PredictionData
#' @rdname PredictionData
#'
#' @description
#' Objects of type `PredictionData` serve as an intermediate representation for objects of type [Prediction].
#' It is an internal data structure, implemented to reduce runtime overhead and to avoid issues that arise when serializing R6 objects.
#' End users typically do not need to worry about the details; package developers are advised to continue reading for technical information.
#'
#' Unlike most other \CRANpkg{mlr3} objects, `PredictionData` relies on the S3 class system.
#' The following operations must be supported to extend mlr3 for new task types:
#'
#' * [as_prediction_data()] converts objects to class `PredictionData`, e.g. objects of type [Prediction].
#' * [as_prediction()] converts objects to class [Prediction], e.g. objects of type `PredictionData`.
#' * `check_prediction_data()` is called on the return value of the predict method of a [Learner] to perform assertions and type conversions.
#'   Returns an updated object of class `PredictionData`.
#' * `is_missing_prediction_data()` is used by the fallback learner (see [Learner]) to impute missing predictions.
#'   Returns a vector of row ids that need imputation.
#'
#'
NULL
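A minimal sketch of how S3 dispatch ties these operations together, assuming a hypothetical task type `toy` whose prediction data carries only `row_ids` and `response`. The generics are re-declared exactly as in this file so the snippet runs in plain R without mlr3; the method bodies are illustrative and are not the methods mlr3 ships for its built-in task types.

```r
# Re-declare the generics as in this file so the sketch runs without mlr3.
check_prediction_data = function(pdata, ...) UseMethod("check_prediction_data")
is_missing_prediction_data = function(pdata, ...) UseMethod("is_missing_prediction_data")

# Hypothetical task type "toy": its prediction data is a named list with
# class c("PredictionDataToy", "PredictionData").
check_prediction_data.PredictionDataToy = function(pdata, ...) {
  # Type conversions: integer row ids, numeric responses
  pdata$row_ids = as.integer(pdata$row_ids)
  pdata$response = as.numeric(pdata$response)
  # Assertion: exactly one response per row id
  stopifnot(length(pdata$response) == length(pdata$row_ids))
  pdata
}

is_missing_prediction_data.PredictionDataToy = function(pdata, ...) {
  # Row ids whose response is missing and would need imputation
  pdata$row_ids[is.na(pdata$response)]
}

# Raw output as a (hypothetical) predict method might return it
pdata = structure(
  list(row_ids = c(1, 2, 3), response = c("0.5", NA, "1.5")),
  class = c("PredictionDataToy", "PredictionData")
)
pdata = check_prediction_data(pdata)
missing_ids = is_missing_prediction_data(pdata)
print(missing_ids)
```

Because the class vector lists the task-specific class first, each generic dispatches to the `toy` method, while code that only needs the shared behaviour can test `inherits(pdata, "PredictionData")`.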

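new_prediction_data = function(li, task_type) {
  # drop NULL elements, such as optional fields that were not computed
  li = discard(li, is.null)
  # prepend the task type's prediction data class, as registered in mlr_reflections$task_types
  class(li) = c(fget(mlr_reflections$task_types, task_type, "prediction_data", "type"), "PredictionData")
  li
}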
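The constructor above drops `NULL` elements and then prepends the task-specific class to `"PredictionData"`. The base-R sketch below mirrors that logic under two assumptions: `Filter(Negate(is.null), li)` stands in for `mlr3misc::discard()`, and the named vector `task_types` is a hypothetical stand-in for looking up the `prediction_data` column of `mlr_reflections$task_types` by `type`, with class names assumed to follow a `PredictionData<Type>` pattern.

```r
# Hypothetical stand-in for mlr_reflections$task_types:
# maps a task type to the class name of its prediction data.
task_types = c(classif = "PredictionDataClassif", regr = "PredictionDataRegr")

new_prediction_data_sketch = function(li, task_type) {
  # Drop NULL elements, mirroring discard(li, is.null)
  li = Filter(Negate(is.null), li)
  # Task-specific class first, shared base class last
  class(li) = c(task_types[[task_type]], "PredictionData")
  li
}

pdata = new_prediction_data_sketch(
  list(row_ids = 1:3, truth = c(1.0, 2.0, 3.0), response = c(1.1, 1.9, 3.2), se = NULL),
  task_type = "regr"
)
print(names(pdata))
print(class(pdata))
```

In this sketch the `se = NULL` entry is removed, so `names(pdata)` contains only `row_ids`, `truth` and `response`.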

#' @rdname PredictionData
#'
#' @param task ([Task]).
#' @param learner ([Learner]).
#'
#' @export
create_empty_prediction_data = function(task, learner) {
  UseMethod("create_empty_prediction_data")
}

#' @export
print.PredictionData = function(x, ...) {
  catf("<%s:%i>", class(x)[1L], length(x$row_ids))
  # print methods conventionally return their input invisibly
  invisible(x)
}

#' @rdname PredictionData
#' @param pdata ([PredictionData])\cr
#'   Named list inheriting from `"PredictionData"`.
#' @export
check_prediction_data = function(pdata, ...) {
  UseMethod("check_prediction_data")
}

#' @rdname PredictionData
#' @export
is_missing_prediction_data = function(pdata, ...) {
  UseMethod("is_missing_prediction_data")
}

#' @rdname PredictionData
#' @template param_row_ids
#' @export
filter_prediction_data = function(pdata, row_ids, ...) {
  UseMethod("filter_prediction_data")
}
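`filter_prediction_data()` restricts a `PredictionData` object to a subset of row ids. As a hedged sketch, a method for a hypothetical `PredictionDataToy` class might subset every per-row field, including the rows of an optional probability matrix. The field names `response` and `prob` are assumptions for illustration, and the generic is re-declared so the snippet runs without mlr3.

```r
# Re-declare the generic as in this file so the sketch runs without mlr3.
filter_prediction_data = function(pdata, row_ids, ...) UseMethod("filter_prediction_data")

filter_prediction_data.PredictionDataToy = function(pdata, row_ids, ...) {
  # Logical index of entries whose row id is requested
  keep = pdata$row_ids %in% row_ids
  pdata$row_ids = pdata$row_ids[keep]
  pdata$response = pdata$response[keep]
  # Optional probability matrix: one row per row id
  if (!is.null(pdata$prob)) {
    pdata$prob = pdata$prob[keep, , drop = FALSE]
  }
  pdata
}

pdata = structure(
  list(
    row_ids = 1:4,
    response = c("a", "b", "a", "b"),
    prob = cbind(a = c(0.9, 0.2, 0.7, 0.4), b = c(0.1, 0.8, 0.3, 0.6))
  ),
  class = c("PredictionDataToy", "PredictionData")
)
sub = filter_prediction_data(pdata, row_ids = c(2L, 4L))
print(sub$row_ids)
print(sub$prob)
```

This sketch preserves the order of `pdata$row_ids` rather than the order of the `row_ids` argument; an actual method may make a different choice.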


mlr3 documentation built on Oct. 18, 2024, 5:11 p.m.