#' @title Abstract Prediction Object
#'
#' @usage NULL
#' @format [R6::R6Class] object.
#'
#' @description
#' This is the abstract base class for task objects like [PredictionClassif] or [PredictionRegr].
#'
#' Prediction objects store the following information:
#' 1. The row ids of the test set
#' 2. The corresponding true (observed) response.
#' 3. The corresponding predicted response.
#' 4. Additional predictions based on the class and `predict_type`.
#' E.g., the class probabilities for classification or the estimated standard error for regression.
#'
#' @section Construction:
#' This object is constructed via a derived classes, e.g. [PredictionClassif] or [PredictionRegr].
#'
#' @section Fields:
#' * `row_ids` :: (`integer()` | `character()`)\cr
#' Vector of row ids for which predictions are stored.
#'
#' * `truth` :: `any`\cr
#' True (observed) outcome.
#'
#' * `task_type` :: `character(1)`\cr
#' Stores the type of the [Task].
#'
#' * `predict_types` :: `character()`\cr
#' Vector of predict types this object stores.
#'
#' * `missing` :: `logical()`\cr
#' Returns `row_ids` for which the predictions are missing or incomplete.
#'
#' @section Methods:
#' * `score(measures = NULL, task = NULL, learner = NULL)`\cr
#' (list of [Measure], [Task], [Learner]) -> [Prediction]\cr
#' Calculates the performance for all provided measures
#' [Task] and [Learner] may be `NULL` for most measures, but some measures need to extract information
#' from these objects.
#'
#' @section S3 Methods:
#' * `as.data.table(rr)`\cr
#' [Prediction] -> [data.table::data.table()]\cr
#' Converts the data to a [data.table::data.table()].
#'
#' * `c(..., keep_duplicates = TRUE)`\cr
#' ([Prediction], [Prediction], ...) -> [Prediction]\cr
#' Combines multiple `Prediction`s to a single `Prediction`.
#' If `keep_duplicates` is `FALSE` and there are duplicated row ids,
#' the data of the former passed objects get overwritten by the data of the later passed objects.
#'
#' @export
#' @family Prediction
Prediction = R6Class("Prediction",
public = list(
data = list(),
task_type = NULL,
predict_types = character(),
format = function() {
sprintf("<%s>", class(self)[1L])
},
print = function(...) {
if (!nrow(self$data$tab)) {
catf("%s for 0 observations", format(self))
} else {
data = as.data.table(self)
catf("%s for %i observations:", format(self), nrow(data))
print(data, nrows = 10L, topn = 3L, class = FALSE, row.names = FALSE, print.keys = FALSE)
}
},
score = function(measures = NULL, task = NULL, learner = NULL, train_set = NULL) {
measures = assert_measures(as_measures(measures, task_type = self$task_type), task = task, learner = learner)
scores = map_dbl(measures, function(m) m$score(prediction = self, task = task, learner = learner, train_set = train_set))
set_names(scores, ids(measures))
}
),
active = list(
row_ids = function() self$data$tab$row_id,
truth = function() self$data$tab$truth,
missing = function() self$data$tab$row_id[0L] # empty vector
)
)
#' @export
c.Prediction = function(..., keep_duplicates = TRUE) {
cl = class(list(...)[[1L]])[1L]
stopf("c.Prediction not implemented for '%s'", cl)
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.