# R/plot.overall_comparison.R
#
# Defines functions plot.overall_comparison
#
# Documented in plot.overall_comparison

#' Plot function for overall_comparison
#'
#' The function plots data created with \code{\link{overall_comparison}}. For the radar plot it uses auditor's
#' \code{\link[auditor]{plot_radar}}. Keep in mind that the function creates two plots, returned as a list.
#' The \pkg{auditor} package must be installed; otherwise an error is raised.
#'
#' @param x an object of class \code{overall_comparison}, created with \code{\link{overall_comparison}}
#' @param ... other parameters; currently ignored
#'
#' @return A named list of ggplot objects.
#'
#' It consists of:
#' \itemize{
#' \item \code{radar_plot} plot created with \code{\link[auditor]{plot_radar}}
#' \item \code{accordance_plot} accordance plot of responses. The x axis shows the champion's response and the y axis
#'                              shows the response of one of the challengers. Colour identifies the challenger.
#'                              The diagonal line marks points where both models give the same response.
#' }
#'
#' @importFrom graphics plot
#'
#' @rdname plot.overall_comparison
#' @export
#'
#' @examples
#' \donttest{
#' library("DALEXtra")
#' library("mlr")
#' task <- mlr::makeRegrTask(
#'   id = "R",
#'   data = apartments,
#'   target = "m2.price"
#' )
#' learner_lm <- mlr::makeLearner(
#'   "regr.lm"
#' )
#' model_lm <- mlr::train(learner_lm, task)
#' explainer_lm <- explain_mlr(model_lm, apartmentsTest, apartmentsTest$m2.price, label = "LM")
#'
#' learner_rf <- mlr::makeLearner(
#'   "regr.ranger"
#' )
#' model_rf <- mlr::train(learner_rf, task)
#' explainer_rf <- explain_mlr(model_rf, apartmentsTest, apartmentsTest$m2.price, label = "RF")
#'
#' learner_gbm <- mlr::makeLearner(
#'   "regr.gbm"
#' )
#' model_gbm <- mlr::train(learner_gbm, task)
#' explainer_gbm <- explain_mlr(model_gbm, apartmentsTest, apartmentsTest$m2.price, label = "GBM")
#'
#' data <- overall_comparison(explainer_lm, list(explainer_gbm, explainer_rf), type = "regression")
#' plot(data)
#' }
plot.overall_comparison <- function(x, ...) {
  # auditor is an optional dependency; check quietly so the user only sees
  # the actionable error below instead of namespace-loading chatter.
  if (!requireNamespace("auditor", quietly = TRUE)) {
    stop("Please install the auditor package to access this functionality", call. = FALSE)
  }

  # `x$radar` is spliced into plot() as its argument list. Presumably it holds
  # auditor objects so that dispatch reaches auditor's radar plot
  # (TODO confirm against overall_comparison()).
  radar_plot <- do.call(plot, x$radar)

  accordance <- x$accordance
  # Local bindings only silence R CMD check's "no visible binding for global
  # variable" NOTE; inside aes() the columns of `accordance` take precedence.
  Champion <- Challenger <- Label <- NULL
  n_challengers <- length(unique(accordance$Label))

  accordance_plot <- ggplot(data = accordance,
                            aes(x = Champion, y = Challenger, colour = Label)) +
    geom_point() +
    # Identity line y = x: points on it are observations where the challenger
    # returns the same response as the champion.
    # NOTE(review): ggplot2 >= 3.4.0 prefers `linewidth` over `size` for lines;
    # `size` is kept here for compatibility with older ggplot2 releases.
    geom_abline(slope = 1, intercept = 0, size = 1, color = "#371ea3", show.legend = TRUE) +
    labs(x = "Champion response",
         y = "Challenger response",
         colour = "Challengers") +
    # NOTE(review): one more colour than there are challengers is requested,
    # presumably to select a particular DrWhy palette; confirm intent.
    scale_color_manual(values = colors_discrete_drwhy(n_challengers + 1)) +
    theme_drwhy()

  list("radar_plot" = radar_plot, "accordance_plot" = accordance_plot)
}
# ModelOriented/DALEXtra documentation built on June 28, 2023, 5:01 p.m.