# R/xai.R

# Defines functions visualize_variable_importance_rf visualize_varImp list_variable_importance save_pimp_plots get_pimp_range_List_models_list get_pimp_for_models_list get_xai_explanations get_percentile_from_model

################################################################################
#
# XAI Functions
#
################################################################################

######################################################################
# Function get_percentile_from_model()
# IN:   model (list-like, e.g. caret::train) holding the training outcome
#       in model$trainingData$.outcome
#       percentile (numeric) probability/ies passed to quantile()
# OUT:  named numeric: requested quantile(s) of the training outcome
######################################################################
get_percentile_from_model <- function(model, percentile) {

  outcome <- model$trainingData$.outcome

  # quantile(NULL, ...) silently yields NA, which would later filter away
  # every observation; fail loudly instead
  if (is.null(outcome)) {
    stop("model$trainingData$.outcome is missing", call. = FALSE)
  }

  # base call instead of a pipe: no attached pipe package is required
  stats::quantile(outcome, percentile)
}

######################################################################
# Function get_xai_explanations()
# IN:   models_list (list) containing caret models; each element must
#         provide $method and $trainingData with an .outcome column
#       seed: passed to furrr_options(seed = )
#       cutoff_greater: the DALEX explainer target is
#         (.outcome >= cutoff_greater)
#       n_features_lime: n_features for lime::explain()
#       local_obs: observations to explain; if NULL, up to local_no
#         training rows with .outcome at or above the local_min_cutoff
#         quantile are sampled
#       random_case: observation for the Shapley plot; if NULL, one row
#         is sampled from the local observations
#       save_path, suffix, width, height: if save_path is not NULL, the
#         feature importance, pdp, attribution, attribution uncertainty
#         and LIME plots are saved with ggsave() as
#         <save_path>.<plot name>.<method>[.<suffix>].png
#       no_permutations: B for model_parts() and iBreakDown::shap()
#       get_DALEX_* / get_LIME_*: switches for the individual outputs
#       workers: number of workers for plan(multicore)
# OUT:  list with one element per model, each a list of the DALEX / LIME
#       outputs (returned invisibly); outputs not requested are NULL
######################################################################
get_xai_explanations <- function(
  models_list,
  seed = 171,
  cutoff_greater = 0,
  n_features_lime = 5,
  local_obs = NULL,
  local_min_cutoff = 0.95,
  local_no = 6,
  random_case = NULL,
  save_path = NULL,
  suffix = NULL,
  width = 6, height = 6,
  get_DALEX_residual_plot = TRUE,
  no_permutations = 50,
  get_DALEX_feature_importance = TRUE,
  get_DALEX_feature_importance_plot = TRUE,
  get_DALEX_pdp_plot = TRUE,
  get_DALEX_attribution_plot = TRUE,
  get_DALEX_attribution_text = TRUE,
  get_DALEX_attribution_uncertainty_plot = TRUE,
  get_DALEX_shapley_plot = TRUE,
  get_LIME_explanations = FALSE,
  get_LIME_explanations_plot = FALSE,
  get_LIME_features_plot = FALSE,
  workers = 8
) {
  # library() instead of require(): a missing package fails here, not later
  # tidyverse packages
  library(ggplot2) # ggsave
  library(dplyr)
  library(furrr)
  # XAI packages
  library(DALEX)
  library(iBreakDown)
  library(ingredients)
  library(lime)

  # parallel setup is local to this call: restore options and plan on exit
  old.options <- options(parallelly.fork.enable = TRUE)
  on.exit(options(old.options), add = TRUE)
  old.plan <- plan(multicore, workers = workers)
  on.exit(plan(old.plan), add = TRUE)

  # Saves plot_object when save_path is set and always returns it.
  # The plot is handed to ggsave() explicitly: relying on last_plot() saved
  # an incomplete plot whenever a layer was added right before saving.
  save_xai_plot <- function(plot_object, plot_name, method) {
    if (!is.null(save_path)) {
      ggsave(
        filename = paste(
          c(save_path, plot_name, method, suffix, "png"),
          collapse = "."),
        plot = plot_object,
        width = width, height = height
      )
    }
    plot_object
  }

  needs.LIME <- get_LIME_explanations ||
    get_LIME_explanations_plot ||
    get_LIME_features_plot

  xai.list <- models_list %>%

    future_map(function(model_object) {

      print(paste("*********", model_object$method))
      training.set <- model_object$trainingData %>%
        select(.outcome, everything())

      print(paste("***target"))

      features <- training.set %>% select(-.outcome)

      # local observations for prediction
      local.obs <- if (!is.null(local_obs)) {
        local_obs
      } else {
        top.cases <- training.set %>%
          filter(
            .outcome >=
              get_percentile_from_model(model_object, local_min_cutoff))
        # never ask for more rows than passed the cutoff
        top.cases %>% sample_n(min(local_no, nrow(top.cases)))
      }

      random.case <- if (!is.null(random_case)) {
        random_case
      } else {
        local.obs %>% sample_n(1)
      }

      print("*** DALEX.explainer")
      # NOTE(review): y is the logical vector (.outcome >= cutoff_greater),
      # not the raw outcome used in get_pimp_for_models_list() - confirm
      DALEX.explainer <- DALEX::explain(
        model = model_object,
        data = features,
        y = training.set$.outcome >= cutoff_greater,
        label = paste(model_object$method, " model"),
        colorize = TRUE
      )

      # for residual plots by plot(geom = "histogram")
      DALEX.performance <- DALEX::model_performance(DALEX.explainer)

      DALEX.residual.plot <- if (get_DALEX_residual_plot) {
        plot(DALEX.performance, geom = "histogram")
      } else {
        NULL
      }

      DALEX.feature.importance <- if (get_DALEX_feature_importance) {

        print("*** DALEX.permutation.feature.importance")

        model_parts(
          DALEX.explainer,
          B = no_permutations,
          type = "ratio"
        )
      } else {
        NULL
      }

      DALEX.feature.importance.plot <- if (
        get_DALEX_feature_importance_plot &&
        !is.null(DALEX.feature.importance)) {

        print("*** DALEX.feature.importance.plot")

        feature.importance.plot <- plot(
          DALEX.feature.importance,
          bar_width = 20 / log(ncol(features)),
          show_boxplots = FALSE,
          title = "Permutation Feature importance",
          subtitle = ""
        ) +
          # reduce space to axis
          scale_y_continuous(expand = expansion())

        save_xai_plot(
          feature.importance.plot,
          "DALEX.feature.importance.plot",
          model_object$method
        )
      } else {
        NULL
      }

      DALEX.pdp.plot <- if (get_DALEX_pdp_plot) {

        print("*** DALEX.pdp.plot")

        DALEX.pdp <- ingredients::partial_dependency(DALEX.explainer)

        save_xai_plot(plot(DALEX.pdp), "plot.pdp.DALEX", model_object$method)
      } else {
        NULL
      }

      # computed once; reused for text, attribution plot, distribution plot
      DALEX.attribution <- iBreakDown::local_attributions(
        DALEX.explainer,
        local.obs,
        keep_distributions = TRUE
      )
      print("*** DALEX.attribution")

      DALEX.attribution.text <- if (get_DALEX_attribution_text) {

        print("*** DALEX.attribution.text")

        iBreakDown::describe(DALEX.attribution)
      } else {
        NULL
      }

      DALEX.attribution.plot <- if (get_DALEX_attribution_plot) {

        print("*** DALEX.attribution.plot")

        save_xai_plot(
          plot(DALEX.attribution, shift_contributions = 0.03),
          "DALEX.attribution.plot",
          model_object$method
        )
      } else {
        NULL
      }

      DALEX.attribution.uncertainty.plot <-
        if (get_DALEX_attribution_uncertainty_plot) {

          print("*** DALEX.attribution.uncertainty.plot")

          save_xai_plot(
            plot(iBreakDown::break_down_uncertainty(DALEX.explainer, local.obs)),
            "DALEX.attribution.uncertainty.plot",
            model_object$method
          )
        } else {
          NULL
        }

      DALEX.distribution.plot <- plot(
        DALEX.attribution,
        plot_distributions = TRUE
      )

      print("*** DALEX.distribution.plot")

      # uses random.case (argument or sampled default), so the plot is no
      # longer skipped when the caller passes no random_case
      DALEX.shapley.plot <- if (get_DALEX_shapley_plot) {

        print("*** DALEX.shapley.plot")

        plot(
          iBreakDown::shap(
            DALEX.explainer,
            random.case,
            B = no_permutations
          )
        )
      } else {
        NULL
      }

      LIME.explainer <- if (needs.LIME) {

        print("*** LIME.explainer")
        lime::lime(
          # tricky: features not training.set
          x = features,
          model = model_object
        )
      } else {
        NULL
      }

      # both LIME plots need the explanations, so they are computed whenever
      # any LIME output is requested
      lime.explanations <- if (needs.LIME) {

        print("***LIME.explanations")
        lime::explain(
          # tricky: features not training.set
          x = local.obs[
            , setdiff(names(local.obs), ".outcome"), drop = FALSE],
          explainer = LIME.explainer,
          n_features = n_features_lime
        )
      } else {
        NULL
      }

      # only returned (and printed) when explicitly requested
      LIME.explanations <- if (get_LIME_explanations) {
        print(lime.explanations)
        lime.explanations
      } else {
        NULL
      }

      LIME.explanations.plot <- if (get_LIME_explanations_plot) {

        print("***LIME.explanations.plot")

        explanations.plot <- lime::plot_explanations(lime.explanations) +
          ggtitle(model_object$method)

        save_xai_plot(
          explanations.plot,
          "LIME.explanations.plot",
          model_object$method
        )
      } else {
        NULL
      }

      LIME.features.plot <- if (get_LIME_features_plot) {

        print("***LIME.features.plot")

        features.plot <- lime::plot_features(lime.explanations, ncol = 2) +
          ggtitle(model_object$method)

        save_xai_plot(
          features.plot,
          "LIME.features.plot",
          model_object$method
        )
      } else {
        NULL
      }

      list(
        DALEX.explainer = DALEX.explainer
        , DALEX.performance = DALEX.performance
        , DALEX.feature.importance = DALEX.feature.importance
        , DALEX.feature.importance.plot = DALEX.feature.importance.plot
        , DALEX.residual.plot = DALEX.residual.plot
        , DALEX.pdp.plot = DALEX.pdp.plot
        , DALEX.attribution.text = DALEX.attribution.text
        , DALEX.attribution.plot = DALEX.attribution.plot
        , DALEX.attribution.uncertainty.plot = DALEX.attribution.uncertainty.plot
        , DALEX.distribution.plot = DALEX.distribution.plot
        , DALEX.shapley.plot = DALEX.shapley.plot
        , LIME.explainer = LIME.explainer
        , LIME.explanations = LIME.explanations
        , LIME.explanations.plot = LIME.explanations.plot
        , LIME.features.plot = LIME.features.plot
      )
    },
    .options = furrr_options(
      seed = seed
      , packages = c(
        "dplyr", "ggplot2", "DALEX", "iBreakDown", "ingredients", "lime")
    ))

  invisible(xai.list)
}

################################################################################
# Permutation feature importance (pimp) by DALEX for each model in models_list
# IN:   models_list (named list) of models providing $method and
#         $trainingData with an .outcome column
#       no_permutations: B for model_parts()
#       seed: set by set.seed() before each model's permutations
# OUT:  list (one element per model) with DALEX.explainer,
#       DALEX.permutation.fi and DALEX.permutation.fi.plot
################################################################################
get_pimp_for_models_list <- function(
  models_list, no_permutations = 50, seed = 171) {

  require(dplyr)
  require(DALEX)

  # explainer, permutation importance and its plot for a single model;
  # model_name is the list element name supplied by imap()
  explain_one_model <- function(model_object, model_name) {

    print(paste("*********", model_object$method))

    training.data <- model_object$trainingData %>%
      select(.outcome, everything())

    outcome <- training.data$.outcome
    print(paste("***target"))

    predictors <- training.data %>% select(-.outcome)

    explainer <- DALEX::explain(
      model = model_object,
      data = predictors,
      y = outcome,
      label = model_name,
      colorize = TRUE
    )
    print("*** DALEX.explainer")

    # same seed for every model
    set.seed(seed)
    permutation.fi <- model_parts(
      explainer = explainer,
      loss_function = loss_root_mean_square,
      B = no_permutations,
      type = "ratio"
    )
    print("*** DALEX.permutation.fi")

    permutation.fi.plot <- plot(
      permutation.fi,
      # bar width shrinks with the number of features
      bar_width = 20 / log(ncol(predictors)),
      show_boxplots = FALSE,
      title = "Permutation Feature importance",
      subtitle = ""
    ) +
      # reduce space to axis
      scale_y_continuous(expand = expansion())
    print("*** DALEX.permutation.fi.plot")

    list(
      DALEX.explainer = explainer,
      DALEX.permutation.fi = permutation.fi,
      DALEX.permutation.fi.plot = permutation.fi.plot
    )
  }

  imap(models_list, explain_one_model)
}

#####################################################
# Summarize permutation feature importance over a list of model lists
# IN:   list_models_list (named list) of named lists whose elements carry
#         $DALEX.permutation.fi
#       model_string: if not "", list_models_list is first reduced by
#         get_list_elements_by_string()
#       value_function: applied to each model's mean dropout losses
#         (one mean per variable); default max
# OUT:  data.frame with one value per model, row names taken from the
#       names of the (selected) list_models_list
#####################################################
get_pimp_range_List_models_list <- function(
  list_models_list, model_string = "", value_function = max) {

  models.selected <- if (model_string != "") {
    get_list_elements_by_string(list_models_list, model_string)
  } else {
    list_models_list
  }

  # value_function over the per-variable mean dropout loss of one model
  summarize_model <- function(model, model_label) {
    mean.losses <- model$DALEX.permutation.fi %>%
      # remove the baseline model (seems without any predictor)
      filter(variable != "_baseline_") %>%
      group_by(variable) %>%
      dplyr::summarise(mean_dropout_loss = mean(dropout_loss))

    value_function(mean.losses$mean_dropout_loss)
  }

  range.table <- imap_dfr(
    models.selected,
    function(models_list, models_list_label) {
      imap(models_list, summarize_model)
    }
  )

  set_rownames(as.data.frame(range.table), names(models.selected))
}

################################################################################
# Save permutation feature importance plots on a common importance scale
# IN:   datasets_pimp_lists: nested list (dataset > model) whose leaves carry
#         $DALEX.permutation.fi.plot and $DALEX.explainer
#       model_string: passed to get_list_elements_by_string() and
#         get_pimp_range_List_models_list()
#       scaling_factor: scale for ggsave()
#       width, height: numeric size for ggsave(); any non-numeric value
#         derives the size from log(number of features)
#       axis_tick_labels: optional labels for scale_x_discrete()
# OUT:  nested list of the ggsave() results; files are written to
#       figures/pimp.<dataset label>.<model label>.png
################################################################################
save_pimp_plots <- function(
  datasets_pimp_lists,
  model_string = "",
  scaling_factor = 1,
  width = "automatic",
  height = "automatic",
  axis_tick_labels = NULL) {

  # shared axis limits over all selected models (both values are printed)
  pimp.max <- print(
    max(get_pimp_range_List_models_list(datasets_pimp_lists, model_string))
  )

  pimp.min <- print(
    min(get_pimp_range_List_models_list(datasets_pimp_lists, model_string, min))
  )

  # number of columns of the explainer data (printed as a side effect)
  no.features <- function(model) {
    print(ncol(pluck(pluck(model, "DALEX.explainer"), "data")))
  }

  selected.lists <- get_list_elements_by_string(
    datasets_pimp_lists, model_string)

  imap(selected.lists, function(model, model_label) {

    imap(model, function(pimp, pimp_label) {

      # NULL when no labels are given; adding NULL leaves the plot unchanged
      tick.label.scale <- if (!is.null(axis_tick_labels)) {
        scale_x_discrete(labels = axis_tick_labels)
      }

      pimp.plot <- pimp$DALEX.permutation.fi.plot +
        # same scale on flipped x-axis for same datasets_pimp_lists
        scale_y_continuous(limits = c(pimp.min, pimp.max)) +
        tick.label.scale

      ggsave(
        filename = paste0(
          c("figures/pimp", model_label, pimp_label, "png"),
          collapse = "."),
        plot = pimp.plot,
        scale = scaling_factor,
        width = ifelse(is.numeric(width), width, log(no.features(pimp)) * 3),
        height = ifelse(
          is.numeric(height), height, log(no.features(pimp)) * 2.2),
        dpi = 150
      )
    })
  })
}

################################################################################
# List variable importance scores
# IN:   model_object: either an object with an $importance matrix that has
#         an IncNodePurity column, or an object of class "train" whose
#         $finalModel carries that matrix
# OUT:  data.frame with columns rowname (variable), the original importance
#       columns and Importance (IncNodePurity as percent of its maximum,
#       rounded to 2 digits), sorted by decreasing IncNodePurity
################################################################################
list_variable_importance <- function(model_object) {

  # a "train" object has no $importance itself; the scores sit on the
  # final model
  importance_matrix <- if (inherits(model_object, "train")) {
    model_object[["finalModel"]][["importance"]]
  } else {
    model_object[["importance"]]
  }

  if (is.null(importance_matrix)) {
    stop("model_object contains no importance scores", call. = FALSE)
  }

  importance_table <- as.data.frame(importance_matrix)

  if (!"IncNodePurity" %in% names(importance_table)) {
    stop("importance scores have no IncNodePurity column", call. = FALSE)
  }

  # move the row names (variables) into a leading "rowname" column
  importance_table <- cbind(
    rowname = rownames(importance_table),
    importance_table,
    stringsAsFactors = FALSE
  )

  node_purity <- importance_table$IncNodePurity
  importance_table$Importance <- round(
    node_purity * 100 / max(node_purity),
    digits = 2
  )

  importance_table <- importance_table[order(-node_purity), , drop = FALSE]
  rownames(importance_table) <- NULL

  importance_table
}

################################################################################
# visualize feature importance by caret::varImp
# IN:   model_object: object accepted by caret::varImp()
#       x_label, y_label: axis titles
#       fill_color: bar fill color
# OUT:  ggplot object: horizontal bar chart of the importance per variable,
#       bars ordered by importance, importance axis limited to 0..102
################################################################################
visualize_varImp <- function (
  model_object,
  x_label = "",
  y_label = "feature importance",
  fill_color = "#114151") {

  library(dplyr)
  library(caret)

  importance_object <- varImp(model_object)

  # a varImp.train result keeps the score table in $importance
  if (inherits(importance_object, "varImp.train")) {
    importance_object <- importance_object$importance
  }
  # variables are stored as row names unless a rowname column exists
  if (!hasName(importance_object, "rowname")) {
    importance_object <- tibble::rownames_to_column(importance_object)
  }

  # the renaming below needs exactly: variable column + one score column
  if (ncol(importance_object) != 2) {
    stop(
      "visualize_varImp() needs a single importance score per variable",
      call. = FALSE
    )
  }

  importance_object %>%
    setNames(c("variable", "Importance")) %>%
    ggplot2::ggplot(
      data = .,
      ggplot2::aes(x = reorder(variable, Importance), y = Importance)
    ) +
    ggplot2::theme_minimal() +
    ggplot2::geom_bar(stat = "identity", fill = fill_color) +
    ggplot2::coord_flip() +
    ggplot2::theme(
      axis.title = ggplot2::element_text(size = 12),
      axis.text = ggplot2::element_text(size = 12)
    ) +
    ggplot2::scale_y_continuous(expand = c(0, 0), limits = c(0, 102)) +
    ggplot2::xlab(x_label) + ggplot2::ylab(y_label)
}


################################################################################
# Visualize variable importance for randomForest objects
# IN:   rf_object: object with an $importance matrix that has an
#         IncNodePurity column
# OUT:  ggplot object: horizontal bar chart of Importance (IncNodePurity as
#       percent of its maximum, rounded to 2 digits) per variable
################################################################################
visualize_variable_importance_rf <- function(rf_object) {

  require(dplyr)

  # importance table: variables in "rowname", rescaled score in "Importance"
  importance.table <- rf_object$importance %>%
    as.data.frame %>%
    tibble::rownames_to_column() %>%
    dplyr::mutate(
      Importance = round(IncNodePurity * 100/max(IncNodePurity), digits =2)
    ) %>%
    arrange(-IncNodePurity)

  bar.chart <- ggplot(
    data = importance.table,
    aes(x = reorder(rowname, Importance), y = Importance)
  )

  bar.chart +
    theme_minimal() +
    geom_bar(stat = "identity", fill = "#114151") +
    coord_flip() +
    theme(
      axis.title = element_text(size = 12),
      axis.text = element_text(size = 12)
    ) +
    scale_y_continuous(expand = c(0, 0), limits = c(0, 102)) +
    xlab("item") +
    ylab("variable importance")
}
# agilebean/machinelearningtools documentation built on May 2, 2023, 7:38 p.m.