Nothing
#' Compare different metrics
#'
#' @rdname fairness_compare_metrics
#'
#' @description
#' Compare learners with respect to one or multiple metrics.
#' Metrics can be, but are not limited to, fairness metrics.
#'
#' @template pta
#'
#' @param object ([PredictionClassif] | [BenchmarkResult] | [ResampleResult])\cr
#' The object to create a plot for.
#' * If provided a ([PredictionClassif]):
#' the scores of the measures on the predictions are compared through a bar plot,
#' with one bar per metric.
#' * If provided a ([ResampleResult]):
#' it is converted to a [BenchmarkResult] and plotted as described below.
#' * If provided a ([BenchmarkResult]):
#' the aggregated scores of every learner are compared through a grouped bar plot
#' (learners on the x-axis, one bar colour per metric), with one panel per task.
#' @param ...
#' The arguments to be passed to methods, such as:
#' * `measures` ([Measure] | list of [Measure])\cr
#' The measures that will be evaluated on `object`; either a single [Measure] or a list of [Measure]s.
#' Defaults to `msr("fairness.acc")`.
#' * `task` ([TaskClassif])\cr
#' The data task that contains the protected column; only required when `object` is a [PredictionClassif].
#'
#' @export
#' @return A 'ggplot2' object.
#' @examples
#' library("mlr3")
#' library("mlr3learners")
#'
#' # Setup the Fairness Measures and tasks
#' task = tsk("adult_train")$filter(1:500)
#' learner = lrn("classif.ranger", predict_type = "prob")
#' learner$train(task)
#' predictions = learner$predict(task)
#' design = benchmark_grid(
#' tasks = task,
#' learners = lrns(c("classif.ranger", "classif.rpart"),
#' predict_type = "prob", predict_sets = c("train", "predict")),
#' resamplings = rsmps("cv", folds = 3)
#' )
#'
#' bmr = benchmark(design)
#' fairness_measure = msr("fairness.tpr")
#' fairness_measures = msrs(c("fairness.tpr", "fairness.fnr", "fairness.acc"))
#'
#' # Predictions
#' compare_metrics(predictions, fairness_measure, task)
#' compare_metrics(predictions, fairness_measures, task)
#'
#' # BenchmarkResult and ResampleResult
#' compare_metrics(bmr, fairness_measure)
#' compare_metrics(bmr, fairness_measures)
compare_metrics = function(object, ...) {
  # S3 generic: the work happens in the class-specific methods below
  # (PredictionClassif, BenchmarkResult, ResampleResult); all other
  # arguments travel through `...` to the selected method.
  UseMethod("compare_metrics")
}
#' @export
compare_metrics.PredictionClassif = function(object, measures = msr("fairness.acc"), task, ...) { # nolint
  # Accept a single Measure or a list of Measures.
  measures = as_measures(measures)
  # `$score()` yields one named score per measure; transposing turns it into a
  # one-row table whose columns are the measure ids. `task` is passed along
  # (presumably so fairness measures can look up the protected attribute).
  scores = setDT(as.data.frame(t(object$score(measures, task = task, ...))))
  # Long format: `variable` = measure id, `value` = score.
  data = melt(scores[, ids(measures), with = FALSE], measure.vars = names(scores))
  # `fill` is mapped to the metric so that `scale_fill_hue()` below has an
  # aesthetic to act on (matching the BenchmarkResult plot's colouring); the
  # legend is hidden because the x-axis already names every metric.
  ggplot(data, aes(x = variable, y = value, fill = variable)) +
    geom_bar(stat = "identity") +
    xlab("Metrics") +
    ylab("Value") +
    theme(legend.position = "none") +
    scale_fill_hue(c = 100, l = 60)
}
#' @export
compare_metrics.BenchmarkResult = function(object, measures = msr("fairness.acc"), ...) { # nolint
  # Accept a single Measure or a list of Measures.
  measures = as_measures(measures)
  # Aggregated scores: one column per measure id, plus the identifying
  # `learner_id` and `task_id` columns used below.
  scores = object$aggregate(measures, ...)
  # Long format: one row per (learner_id, task_id, metric); the metric name is
  # stored in `variable` and its score in `value`.
  data = melt(
    scores[, c(ids(measures), "learner_id", "task_id"), with = FALSE],
    id.vars = c("learner_id", "task_id")
  )
  # Learners go on the x-axis and metrics are distinguished by bar colour
  # (legend titled "Metric"), so the x-axis is labelled with the learners.
  ggplot(data, aes(x = learner_id, y = value, fill = variable)) +
    geom_bar(stat = "identity", position = "dodge") +
    xlab("Learner") +
    ylab("Value") +
    scale_fill_hue(name = "Metric", c = 100, l = 60) +
    facet_wrap(~task_id)
}
#' @export
compare_metrics.ResampleResult = function(object, measures = msr("fairness.acc"), ...) { # nolint
  # A single ResampleResult is wrapped into a BenchmarkResult and re-dispatched,
  # so both object types produce the same kind of plot.
  object = as_benchmark_result(object)
  # Forward `...` so extra arguments are not silently dropped and reach the
  # BenchmarkResult method exactly as when a BenchmarkResult is passed directly.
  compare_metrics(object, measures, ...)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.