#' @title Prediction object.
#'
#' @description
#' Result from [predict.WrappedModel].
#' Use `as.data.frame` to access all information in a convenient format.
#' The function [getPredictionProbabilities] is useful to access predicted probabilities.
#'
#' The `data` member of the object contains always the following columns:
#' `id`, index numbers of predicted cases from the task, `response`
#' either a numeric or a factor, the predicted response values, `truth`,
#' either a numeric or a factor, the true target values.
#' If probabilities were predicted, as many numeric columns as there were classes named
#' `prob.classname`. If standard errors were predicted, a numeric column named `se`.
#'
#' The constructor `makePrediction` is mainly for internal use.
#'
#' Object members:
#' \describe{
#' \item{predict.type (`character(1)`)}{Type set in [setPredictType].}
#' \item{data ([data.frame])}{See details.}
#' \item{threshold (`numeric(1)`)}{Threshold set in predict function.}
#' \item{task.desc ([TaskDesc])}{Task description object.}
#' \item{time (`numeric(1)`)}{Time learner needed to generate predictions.}
#' \item{error (`character(1)`)}{Any error messages generated by the learner (default NA_character_).}
#' }
#' @name Prediction
#' @rdname Prediction
NULL
#' @keywords internal
#' @rdname Prediction
#' @description
#' Internal, do not use!
#' @export
makePrediction = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {
UseMethod("makePrediction")
}
#' @export
makePrediction.RegrTaskDesc = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {
data = namedList(c("id", "truth", "response", "se"))
data$id = id
data$truth = truth
if (predict.type == "response") {
data$response = y
} else {
data$response = y[, 1L]
data$se = y[, 2L]
}
makeS3Obj(c("PredictionRegr", "Prediction"),
predict.type = predict.type,
data = setRowNames(as.data.frame(filterNull(data)), row.names),
threshold = NA_real_,
task.desc = task.desc,
time = time,
error = error,
dump = dump
)
}
#' @export
makePrediction.ClassifTaskDesc = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {
data = namedList(c("id", "truth", "response", "prob"))
data$id = id
# truth can come from a simple "newdata" df. then there might not be all factor levels present
if (!is.null(truth)) {
levels(truth) = union(levels(truth), task.desc$class.levels)
}
data$truth = truth
if (predict.type == "response") {
data$response = y
data = as.data.frame(filterNull(data))
} else {
data$prob = y
data = as.data.frame(filterNull(data))
# fix columnnames for prob if strange chars are in factor levels
indices = stri_detect_fixed(names(data), "prob.")
if (sum(indices) > 0) {
names(data)[indices] = stri_paste("prob.", colnames(y))
}
}
p = makeS3Obj(c("PredictionClassif", "Prediction"),
predict.type = predict.type,
data = setRowNames(data, row.names),
threshold = NA_real_,
task.desc = task.desc,
time = time,
error = error,
dump = dump
)
if (predict.type == "prob") {
# set default threshold to 1/k
if (is.null(predict.threshold)) {
predict.threshold = rep(1 / length(task.desc$class.levels), length(task.desc$class.levels))
names(predict.threshold) = task.desc$class.levels
}
p = setThreshold(p, predict.threshold)
}
return(p)
}
#' @export
makePrediction.MultilabelTaskDesc = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {
data = namedList(c("id", "truth", "response", "prob"))
data$id = id
data$truth = truth
if (predict.type == "response") {
data$response = y
} else {
data$prob = y
}
p = makeS3Obj(c("PredictionMultilabel", "Prediction"),
predict.type = predict.type,
data = setRowNames(as.data.frame(filterNull(data)), row.names),
threshold = NA_real_,
task.desc = task.desc,
time = time,
error = error,
dump = dump
)
if (predict.type == "prob") {
# set default threshold to 0.5
if (is.null(predict.threshold)) {
predict.threshold = rep(0.5, length(task.desc$class.levels))
names(predict.threshold) = task.desc$class.levels
}
p = setThreshold(p, predict.threshold)
}
return(p)
}
#' @export
makePrediction.SurvTaskDesc = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {
data = namedList(c("id", "truth.time", "truth.event", "response"))
data$id = id
# FIXME: recode times
data$truth.time = truth[, 1L]
data$truth.event = truth[, 2L]
data$response = y
makeS3Obj(c("PredictionSurv", "Prediction"),
predict.type = predict.type,
data = setRowNames(as.data.frame(filterNull(data)), row.names),
threshold = NA_real_,
task.desc = task.desc,
time = time,
error = error,
dump = dump
)
}
#' @export
makePrediction.ClusterTaskDesc = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {
data = namedList(c("id", "response", "prob"))
data$id = id
if (predict.type == "response") {
data$response = y
data = as.data.frame(filterNull(data))
} else {
# this is a bit uncool, but as long we only use cl_predict we are OK I guess
class(y) = "matrix"
data$prob = y
data$response = getMaxIndexOfRows(y)
data = as.data.frame(filterNull(data))
}
p = makeS3Obj(c("PredictionCluster", "Prediction"),
predict.type = predict.type,
data = setRowNames(data, row.names),
threshold = NA_real_,
task.desc = task.desc,
time = time,
error = error,
dump = dump
)
return(p)
}
#' @export
makePrediction.CostSensTaskDesc = function(task.desc, row.names, id, truth, predict.type, predict.threshold = NULL, y, time, error = NA_character_, dump = NULL) {
data = namedList(c("id", "response"))
data$id = id
data$response = y
makeS3Obj(c("PredictionCostSens", "Prediction"),
predict.type = predict.type,
data = setRowNames(as.data.frame(filterNull(data)), row.names),
threshold = NA_real_,
task.desc = task.desc,
time = time,
error = error,
dump = dump
)
}
#' @export
print.Prediction = function(x, ...) {
catf("Prediction: %i observations", nrow(x$data))
catf("predict.type: %s", x$predict.type)
catf("threshold: %s", collapse(sprintf("%s=%.2f", names(x$threshold), x$threshold)))
catf("time: %.2f", x$time)
if (!is.na(x$error)) catf("errors: %s", x$error)
printHead(as.data.frame(x), ...)
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.