
# Defines functions: print.Prediction, makePrediction.CostSensTaskDesc, makePrediction.ClusterTaskDesc, makePrediction.SurvTaskDesc, makePrediction.MultilabelTaskDesc, makePrediction.ClassifTaskDesc, makePrediction.RegrTaskDesc, makePrediction
#
# Documented in: makePrediction

#' @title Prediction object.
#'
#' @description
#' Result from [predict.WrappedModel].
#' Use `as.data.frame` to access all information in a convenient format.
#' The function [getPredictionProbabilities] is useful to access predicted probabilities.
#'
#' The `data` member of the object always contains the following columns:
#' `id`, the index numbers of the predicted cases from the task; `response`,
#' either a numeric or a factor, the predicted response values; and `truth`,
#' either a numeric or a factor, the true target values.
#' If probabilities were predicted, it also contains as many numeric columns as there
#' were classes, named `prob.classname`. If standard errors were predicted, it contains
#' a numeric column named `se`.
#'
#' The constructor `makePrediction` is mainly for internal use.
#'
#' Object members:
#' \describe{
#' \item{predict.type (`character(1)`)}{Type set in [setPredictType].}
#' \item{data ([data.frame])}{See details.}
#' \item{threshold (`numeric(1)`)}{Threshold set in predict function.}
#' \item{task.desc ([TaskDesc])}{Task description object.}
#' \item{time (`numeric(1)`)}{Time learner needed to generate predictions.}
#' \item{error (`character(1)`)}{Any error messages generated by the learner (default `NA_character_`).}
#' }
#' @name Prediction
#' @rdname Prediction
NULL

#' @keywords internal
#' @rdname Prediction
#' @description
#' Internal, do not use!
#' @param task.desc ([TaskDesc])\cr
#'   Task description; its class selects the method that builds the prediction.
#' @param row.names Row names that are set on the resulting `data` member.
#' @param id Index numbers of the predicted cases, stored in the `id` column.
#' @param truth True target values. May be `NULL`; the methods for cluster and
#'   cost-sensitive tasks do not use it.
#' @param predict.type (`character(1)`)\cr
#'   Predict type, e.g. `"response"` or `"prob"`.
#' @param predict.threshold Threshold(s) applied when `predict.type` is `"prob"`.
#'   If `NULL`, the classification and multilabel methods fall back to a default.
#' @param y Predictions as returned by the learner.
#' @param time (`numeric(1)`)\cr
#'   Time the learner needed to generate the predictions.
#' @param error (`character(1)`)\cr
#'   Error message generated by the learner, `NA_character_` if there was none.
#' @param dump Stored unchanged in the `dump` member of the object.
#' @export
makePrediction = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {
  # S3 generic: dispatch on the class of the task description
  UseMethod("makePrediction")
}

# Builds a "PredictionRegr" object for regression tasks.
# The data frame holds the columns id, truth (dropped if NULL), response and,
# when standard errors were predicted, se.
#' @export
makePrediction.RegrTaskDesc = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {
  data = namedList(c("id", "truth", "response", "se"))
  data$id = id
  data$truth = truth
  if (predict.type == "response") {
    data$response = y
  } else {
    # any other predict type: y is indexed as a two-column object,
    # column 1 is the response and column 2 the standard error
    data$response = y[, 1L]
    data$se = y[, 2L]
  }

  # entries that were never filled (still NULL) are removed before the
  # data frame is built; regression has no threshold, hence NA
  makeS3Obj(c("PredictionRegr", "Prediction"),
    predict.type = predict.type,
    data = setRowNames(as.data.frame(filterNull(data)), row.names),
    threshold = NA_real_,
    task.desc = task.desc,
    time = time,
    error = error,
    dump = dump
  )
}

# Builds a "PredictionClassif" object for classification tasks.
# For predict.type "response" the data frame holds id, truth (dropped if NULL)
# and response; otherwise y is stored as probability columns named
# "prob.<colname of y>". For predict.type "prob" a threshold is applied through
# setThreshold, defaulting to 1 / (number of classes) for every class.
#' @export
makePrediction.ClassifTaskDesc = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {

  data = namedList(c("id", "truth", "response", "prob"))
  data$id = id
  # truth can come from a simple "newdata" df. then there might not be all factor levels present
  if (!is.null(truth)) {
    levels(truth) = union(levels(truth), task.desc$class.levels)
  }
  data$truth = truth
  if (predict.type == "response") {
    data$response = y
    data = as.data.frame(filterNull(data))
  } else {
    data$prob = y
    data = as.data.frame(filterNull(data))
    # fix columnnames for prob if strange chars are in factor levels
    # (as.data.frame may have mangled them, so they are rebuilt from colnames(y))
    indices = stri_detect_fixed(names(data), "prob.")
    if (any(indices)) {
      names(data)[indices] = stri_paste("prob.", colnames(y))
    }
  }

  p = makeS3Obj(c("PredictionClassif", "Prediction"),
    predict.type = predict.type,
    data = setRowNames(data, row.names),
    threshold = NA_real_,
    task.desc = task.desc,
    time = time,
    error = error,
    dump = dump
  )

  if (predict.type == "prob") {
    # set default threshold to 1/k
    if (is.null(predict.threshold)) {
      predict.threshold = rep(1 / length(task.desc$class.levels), length(task.desc$class.levels))
      names(predict.threshold) = task.desc$class.levels
    }
    p = setThreshold(p, predict.threshold)
  }
  p
}

# Builds a "PredictionMultilabel" object for multilabel tasks.
# y is stored as response for predict.type "response" and as prob otherwise.
# For predict.type "prob" a threshold is applied through setThreshold,
# defaulting to 0.5 for every class level.
#' @export
makePrediction.MultilabelTaskDesc = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {
  data = namedList(c("id", "truth", "response", "prob"))
  data$id = id
  data$truth = truth
  if (predict.type == "response") {
    data$response = y
  } else {
    data$prob = y
  }

  # entries that were never filled (still NULL) are removed before the data frame is built
  p = makeS3Obj(c("PredictionMultilabel", "Prediction"),
    predict.type = predict.type,
    data = setRowNames(as.data.frame(filterNull(data)), row.names),
    threshold = NA_real_,
    task.desc = task.desc,
    time = time,
    error = error,
    dump = dump
  )
  if (predict.type == "prob") {
    # set default threshold to 0.5
    if (is.null(predict.threshold)) {
      predict.threshold = rep(0.5, length(task.desc$class.levels))
      names(predict.threshold) = task.desc$class.levels
    }
    p = setThreshold(p, predict.threshold)
  }
  p
}

# Builds a "PredictionSurv" object for survival tasks.
# truth is indexed as a two-column object: column 1 becomes truth.time and
# column 2 becomes truth.event. y is stored as response whatever the predict type.
#' @export
makePrediction.SurvTaskDesc = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {
  data = namedList(c("id", "truth.time", "truth.event", "response"))
  data$id = id
  # FIXME: recode times
  data$truth.time = truth[, 1L]
  data$truth.event = truth[, 2L]
  data$response = y

  # entries that were never filled (still NULL) are removed before the data frame is built
  makeS3Obj(c("PredictionSurv", "Prediction"),
    predict.type = predict.type,
    data = setRowNames(as.data.frame(filterNull(data)), row.names),
    threshold = NA_real_,
    task.desc = task.desc,
    time = time,
    error = error,
    dump = dump
  )
}

# Builds a "PredictionCluster" object for cluster tasks.
# There is no truth column. For predict.type "response" y is stored as response;
# otherwise y is stored as prob and the response is derived from it with
# getMaxIndexOfRows (presumably the column index of the row-wise maximum).
#' @export
makePrediction.ClusterTaskDesc = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {
  data = namedList(c("id", "response", "prob"))
  data$id = id
  if (predict.type == "response") {
    data$response = y
    data = as.data.frame(filterNull(data))
  } else {
    # this is a bit uncool, but as long we only use cl_predict we are OK I guess
    class(y) = "matrix"
    data$prob = y
    data$response = getMaxIndexOfRows(y)
    data = as.data.frame(filterNull(data))
  }
  p = makeS3Obj(c("PredictionCluster", "Prediction"),
    predict.type = predict.type,
    data = setRowNames(data, row.names),
    threshold = NA_real_,
    task.desc = task.desc,
    time = time,
    error = error,
    dump = dump
  )
  p
}

# Builds a "PredictionCostSens" object for cost-sensitive tasks.
# The data frame only holds id and response; truth is not used.
#' @export
makePrediction.CostSensTaskDesc = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {
  data = namedList(c("id", "response"))
  data$id = id
  data$response = y

  # entries that were never filled (still NULL) are removed before the data frame is built
  makeS3Obj(c("PredictionCostSens", "Prediction"),
    predict.type = predict.type,
    data = setRowNames(as.data.frame(filterNull(data)), row.names),
    threshold = NA_real_,
    task.desc = task.desc,
    time = time,
    error = error,
    dump = dump
  )
}

# Print method for Prediction objects: writes a short summary (number of
# observations, predict type, threshold, time and, if present, the error
# message), followed by the head of the object converted to a data frame.
# Returns x invisibly, as is conventional for print methods.
#' @export
print.Prediction = function(x, ...) {
  catf("Prediction: %i observations", nrow(x$data))
  catf("predict.type: %s", x$predict.type)
  catf("threshold: %s", collapse(sprintf("%s=%.2f", names(x$threshold), x$threshold)))
  catf("time: %.2f", x$time)
  # the error line is only shown when the learner actually reported one
  if (!is.na(x$error)) catf("errors: %s", x$error)
  printHead(as.data.frame(x), ...)
  invisible(x)
}
