#' Measure performance of prediction.
#'
#' Measures the quality of a prediction w.r.t. some performance measure.
#'
#' @template arg_pred
#' @template arg_measures
#' @param task ([Task])\cr
#' Learning task, might be requested by performance measure, usually not needed except for clustering or survival.
#' @param model ([WrappedModel])\cr
#' Model built on training data, might be requested by performance measure, usually not needed except for survival.
#' @param feats ([data.frame])\cr
#' Features of predicted data, usually not needed except for clustering.
#' If the prediction was generated from a `task`, you can also pass this instead and the features
#' are extracted from it.
#' @param simpleaggr ([logical])\cr
#' If TRUE, aggregation of \code{ResamplePrediction} objects is skipped. This is used internally for threshold tuning. Default is \code{FALSE}.
#' @return (named [numeric]). Performance value(s), named by measure(s).
#' @export
#' @family performance
#' @examples
#' training.set = seq(1, nrow(iris), by = 2)
#' test.set = seq(2, nrow(iris), by = 2)
#'
#' task = makeClassifTask(data = iris, target = "Species")
#' lrn = makeLearner("classif.lda")
#' mod = train(lrn, task, subset = training.set)
#' pred = predict(mod, newdata = iris[test.set, ])
#' performance(pred, measures = mmce)
#'
#' # Compute multiple performance measures at once
#' ms = list("mmce" = mmce, "acc" = acc, "timetrain" = timetrain)
#' performance(pred, measures = ms, task, mod)
performance = function(pred, measures, task = NULL, model = NULL, feats = NULL, simpleaggr = FALSE) {
if (!is.null(pred)) {
assertClass(pred, classes = "Prediction")
}
measures = checkMeasures(measures, pred$task.desc)
res = vnapply(measures, doPerformanceIteration, pred = pred, task = task, model = model, td = NULL, feats = feats, simpleaggr = simpleaggr)
# FIXME: This is really what the names should be, but it breaks all kinds of other stuff
# if (inherits(pred, "ResamplePrediction")) {
# setNames(res, vcapply(measures, measureAggrName))
# } else {
# setNames(res, extractSubList(measures, "id"))
# }
setNames(res, extractSubList(measures, "id"))
}
doPerformanceIteration = function(measure, pred = NULL, task = NULL, model = NULL, td = NULL, feats = NULL, simpleaggr = simpleaggr) {
m = measure
props = getMeasureProperties(m)
if ("req.pred" %in% props) {
if (is.null(pred)) {
stopf("You need to pass pred for measure %s!", m$id)
}
}
if ("req.truth" %in% props) {
type = getTaskDesc(pred)$type
if (type == "surv") {
if (is.null(pred$data$truth.time) || is.null(pred$data$truth.event)) {
stopf("You need to have 'truth.time' and 'truth.event' columns in your pred object for measure %s!", m$id)
}
} else if (type == "multilabel") {
if (!(any(stri_detect_regex(colnames(pred$data), "^truth\\.")))) {
stopf("You need to have 'truth.*' columns in your pred object for measure %s!", m$id)
}
} else {
if (is.null(pred$data$truth)) {
stopf("You need to have a 'truth' column in your pred object for measure %s!", m$id)
}
}
}
if ("req.model" %in% props) {
if (is.null(model)) {
stopf("You need to pass model for measure %s!", m$id)
}
assertClass(model, classes = "WrappedModel")
}
if ("req.task" %in% props) {
if (is.null(task)) {
stopf("You need to pass task for measure %s!", m$id)
}
assertClass(task, classes = "Task")
}
if ("req.feats" %in% props) {
if (is.null(task) && is.null(feats)) {
stopf("You need to pass either task or features for measure %s!", m$id)
} else if (is.null(feats)) {
feats = task$env$data[pred$data$id, , drop = FALSE]
} else {
assertClass(feats, "data.frame")
}
}
# we need to find desc somewhere
td = if (!is.null(pred)) {
pred$task.desc
} else if (!is.null(model)) {
model$task.desc
} else if (!is.null(task)) {
getTaskDesc(task)
}
# null only happens in custom resampled measure when we do no individual measurements
if (!is.null(td)) {
if (td$type %nin% props) {
stopf("Measure %s does not support task type %s!", m$id, td$type)
}
if (td$type == "classif" && length(td$class.levels) > 2L && "classif.multi" %nin% props) {
stopf("Multiclass problems cannot be used for measure %s!", m$id)
}
# if we have multiple req.pred.types, check if we have one of them (currently we only need prob)
req.pred.types = if ("req.prob" %in% props) "prob" else character(0L)
if (!is.null(pred) && length(req.pred.types) > 0L && pred$predict.type %nin% req.pred.types) {
on.measure.not.applicable = getMlrOption(name = "on.measure.not.applicable")
msg = sprintf("Measure %s requires predict type to be: '%s'!", m$id, collapse(req.pred.types))
if (on.measure.not.applicable == "stop") {
stop(msg)
} else if (on.measure.not.applicable == "warn") {
warning(msg)
}
return(NA_real_)
}
}
# if it's a ResamplePrediction, aggregate
if (simpleaggr) {
measure$fun(task, model, pred, feats, m$extra.args)
} else {
if (inherits(pred, "ResamplePrediction")) {
if (is.null(pred$data$iter)) pred$data$iter = 1L
if (is.null(pred$data$set)) pred$data$set = "test"
fun = function(ss) {
is.train = ss$set == "train"
if (any(is.train)) {
pred$data = as.data.frame(ss[is.train, ])
perf.train = measure$fun(task, model, pred, feats, m$extra.args)
} else {
perf.train = NA_real_
}
pred$data = as.data.frame(ss[!is.train, ])
perf.test = measure$fun(task, model, pred, feats, m$extra.args)
list(perf.train = perf.train, perf.test = perf.test)
}
perfs = as.data.table(pred$data)[, fun(.SD), by = "iter"]
measure$aggr$fun(task, perfs$perf.test, perfs$perf.train, measure, perfs$iter, pred)
} else {
measure$fun(task, model, pred, feats, m$extra.args)
}
}
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.