R/plot_multiclass.R

#' Plot of Metrics for Multiclass Classification Models
#'
#' Function plot.multiclass plots quality measures of multiclass classification models.
#'
#' @param x An object returned by the \code{train()} function.
#' @param models A character or numeric vector indicating which models
#' will be presented. If `NULL` (the default), the three best models
#' are presented.
#' @param type A character, one of `comparison`, `confusion-matrix`, or
#' `train-test`, indicating the type of chart.
#' @param metric A character, such as `accuracy`, indicating the metric shown on the plots.
#' @param ... Other parameters, kept for consistency with the generic plot function.
#'
#' @return a ggplot2 object
#'
#' @examples
#' \dontrun{
#' library('forester')
#' data('iris')
#'
#' x <- train(iris, 'Species', bayes_iter = 0, random_evals = 0)
#' plot(x)
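#'
#' # Further illustrative calls on the same trained object `x`,
#' # showing the documented `type`, `metric`, and `models` arguments:
#' plot(x, type = 'confusion-matrix')
#' plot(x, type = 'train-test', metric = 'accuracy')
#' plot(x, models = c(1, 2), type = 'comparison')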
#' }
#'
#' @import ggplot2
#' @import patchwork
#' @export


plot.multiclass <- function(x,
                            models = NULL,
                            type   = 'comparison',
                            metric = 'accuracy',
                            ...){

  if (!inherits(x, 'multiclass'))
    stop('The plot() function requires an object created by the train() function for a multiclass classification task.')

  if (!any(type %in% c('comparison', 'confusion-matrix', 'train-test')))
    stop('The selected plot type does not exist.')

  models_names <- models_num <- NULL

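  # Model selection: by default the three best models from the test-set
  # ranking; otherwise the models requested by name or by position in
  # x$models_list (unknown entries yield a message or an error).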
  if (is.null(models)) {
    models_names <- x$score_test$name[1:3]
    models_num   <- 1:3
  } else{
    if (is.character(models)) {
      if (any(models %in% names(x$models_list))) {
        models_names <- models[models %in% names(x$models_list)]
        models_num   <- which(names(x$models_list) %in% models)
        if (length(models_num) < length(models)) {
          message(paste0(
            'Check the given models.',
            ' Models do not exist: ',
            paste(models[!models %in% names(x$models_list)], collapse = ', '),
            '.'
          ))
        }
      } else{
        stop('Models with the given names do not exist.')
      }
    } else{
      if (any(models %in% 1:length(x$models_list))) {
        models_names <- names(x$models_list)[which(1:length(x$models_list) %in% models)]
        models_num   <- which(1:length(x$models_list) %in% models)
        if (length(models_num) < length(models)) {
          message(paste0(
            'Check the given models.',
            ' Models do not exist: ',
            paste(models[which(!models %in% 1:length(x$models_list))], collapse = ', '),
            '.'
          ))
        }
      } else{
        stop('Models with the given numbers do not exist.')
      }
    }
  }
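
  # type = 'comparison': take the last four columns of the score frames (the
  # metric columns), repeat each model row once per metric and bind the
  # long-format metric values, then draw one line per metric for both the
  # test and the validation set.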
  if (type == 'comparison') {
    test_scores    <- data.frame(t(x$score_test[, (NCOL(x$score_test) - 3):NCOL(x$score_test)]))
    no_cols        <- min(10, ncol(test_scores))
    test_scores    <- test_scores[, 1:no_cols]
    test_y         <- data.frame(metric = rownames(test_scores), value = unlist(test_scores))
    test_data      <- x$score_test[1:no_cols, ]
    test_data$name <- factor(test_data$name, levels = unique(test_data$name))
    test_all       <- cbind(test_data[rep(seq_len(nrow(test_data)), each = 4), ],
                            data.frame(metric = rownames(test_scores), value = unlist(test_scores)))
    test_all       <- test_all[, c('name', 'engine', 'tuning', 'metric', 'value')]

    valid_scores    <- data.frame(t(x$score_valid[, (NCOL(x$score_valid) - 3):NCOL(x$score_valid)]))
    valid_scores    <- valid_scores[, 1:no_cols]
    valid_y         <- data.frame(metric = rownames(valid_scores), value = unlist(valid_scores))
    valid_data      <- x$score_valid[1:no_cols, ]
    valid_data$name <- factor(valid_data$name, levels = unique(valid_data$name))
    valid_all       <- cbind(valid_data[rep(seq_len(nrow(valid_data)), each = 4), ],
                             data.frame(metric = rownames(valid_scores), value = unlist(valid_scores)))
    valid_all       <- valid_all[, c('name', 'engine', 'tuning', 'metric', 'value')]

    comparison_plot <- function(test) {
      if (test) {
        all  <- test_all
        text <- 'testing'
      } else {
        all  <- valid_all
        text <- 'validation'
      }
      p <- ggplot(all, aes(
        x     = name,
        y     = value,
        color = metric,
        group = metric
      )) +
        geom_line(linewidth = 1) +
        scale_color_manual(values = colors_discrete_forester(length(unique(all$metric)))) +
        theme_forester() +
        geom_point() +
        labs(
          title    = 'Model comparison',
          subtitle = paste0('on ', text, ' dataset'),
          y        = 'Value of metric',
          x        = 'Model',
          color    = 'Metric'
        ) + theme(axis.text.x = element_text(angle = 25))
      return(p)
    }

    plot_test  <- comparison_plot(TRUE)
    plot_valid <- comparison_plot(FALSE)
    return(patchwork::wrap_plots(list(plot_test, plot_valid), ncol = 1))
  }


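  # type = 'confusion-matrix': compare the rounded class predictions of the
  # best model (first row of the score frames) with the observed labels on
  # the test and validation sets, shown as two heatmaps.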
  if (type == 'confusion-matrix') {
    test_preds     <- sapply(x$predictions_test[[x$score_test$name[[1]]]], FUN = round)
    test_observed  <- x$test_observed
    valid_preds    <- sapply(x$predictions_valid[[x$score_valid$name[[1]]]], FUN = round)
    valid_observed <- x$valid_observed

    confusion_matrix <- function(test) {
      if (test) {
        observed  <- test_observed
        preds     <- test_preds
        text      <- 'testing'
        name      <- x$score_test$name[[1]]
      } else {
        observed  <- valid_observed
        preds     <- valid_preds
        text      <- 'validation'
        name      <- x$score_valid$name[[1]]
      }

      n <- length(unique(observed))
      conf_mat <- matrix(rep(0, n * n), nrow = n, ncol = n)

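      # Tally each (observed, predicted) pair into the n x n confusion matrix;
      # factor or integer labels are used directly as matrix indices.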
      for (i in 1:length(observed)) {
        conf_mat[observed[i], preds[i]] <- conf_mat[observed[i], preds[i]] + 1
      }

      Y          <- as.vector(conf_mat)
      Target     <- factor(rep(1:n, times = n))
      Prediction <- factor(rep(1:n, each = n))
      df         <- data.frame(Target, Prediction, Y)

      p <- ggplot(data =  df, mapping = aes(x = Target, y = Prediction)) +
        geom_tile(aes(fill = Y), colour = 'white') +
        geom_text(aes(label = sprintf('%1.0f', Y)), vjust = 1, colour = colors_discrete_forester(5)[5], size = 5) +
        scale_fill_gradient(low = colors_diverging_forester()[1], high = colors_diverging_forester()[2]) +
        theme_forester() +
        labs(
          title    = paste0('Confusion Matrix'),
          subtitle = paste0('for the best ', text, ' model: ' , name),
          y        = 'Prediction',
          x        = 'Target'
        ) +
        theme(legend.position = 'bottom')

      return(p)
    }
    plot_test  <- confusion_matrix(TRUE)
    plot_valid <- confusion_matrix(FALSE)
    return(patchwork::wrap_plots(plot_test, plot_valid))
  }

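  # type = 'train-test': for up to ten top models, scatter the chosen metric
  # on the train set against the same metric on the test set, with a y = x
  # reference line; points well below the line suggest overfitting.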
  if (type == 'train-test') {
    no_models    <- min(10, nrow(x$score_test))
    models_names <- x$score_test$name[1:no_models]
    metric_cols  <- c('accuracy', 'weighted_precision', 'weighted_recall', 'weighted_f1')

    train_score <- x$score_train[x$score_train$name %in% models_names, ]
    names(train_score)[which(names(train_score) %in% metric_cols)] <-
      paste0(names(train_score)[which(names(train_score) %in% metric_cols)], '_train')

    test_score        <- x$score_test[x$score_test$name %in% models_names, metric_cols]
    names(test_score) <- paste0(names(test_score), '_test')

    score <- cbind(train_score, test_score)

    p <- ggplot(score, aes(x = .data[[paste0(metric, '_train')]], y = .data[[paste0(metric, '_test')]], color = .data[['engine']])) +
      geom_point() +
      geom_abline(intercept = 0, slope = 1) +
      theme_forester() +
      scale_color_manual(values = colors_discrete_forester(length(unique(score$engine)))) +
      labs(
        title = paste(toupper(metric), 'train vs test'),
        x     = 'Train',
        y     = 'Test',
        color = 'Engine'
      ) +
      ggrepel::geom_text_repel(aes(label = .data[['name']]), show.legend = FALSE, max.time = 3) +
      scale_x_continuous(limits = c(min(score[paste0(metric, '_train')], score[paste0(metric, '_test')]),
                                    max(score[paste0(metric, '_train')], score[paste0(metric, '_test')]))) +
      scale_y_continuous(limits = c(min(score[paste0(metric, '_train')], score[paste0(metric, '_test')]),
                                    max(score[paste0(metric, '_train')], score[paste0(metric, '_test')])))
  }

  return(p)
}