# R/InterpretingMethod.R

# Defines functions: get_lrp_rule, aggregate_channels, get_aggr_function,
# apply_results, move_channels_last, transform_label_to_idx,
# check_output_idx_for_plot, create_grid, create_dataframe_from_result,
# tensor_list_to_named_array, check_output_idx, print_output_idx,
# print_result, plot_global.InterpretingMethod, plot_global,
# boxplot.InterpretingMethod, get_result.InterpretingMethod, get_result

# Documented in: get_result, plot_global

###############################################################################
#                         Interpreting Method
###############################################################################

#' @title Super class for interpreting methods
#' @description This is a super class for all interpreting methods in the
#' `innsight` package. Implemented are the following methods:
#'
#' - *Deep Learning Important Features* ([`DeepLift`])
#' - *Deep Shapley additive explanations* ([`DeepSHAP`])
#' - *Layer-wise Relevance Propagation* ([`LRP`])
#' - Gradient-based methods:
#'    - *Vanilla gradients* including *Gradient\eqn{\times}Input* ([`Gradient`])
#'    - Smoothed gradients including *SmoothGrad\eqn{\times}Input* ([`SmoothGrad`])
#'    - *Integrated gradients* ([`IntegratedGradient`])
#'    - *Expected gradients* ([`ExpectedGradient`])
#' - *Connection Weights* (global and local) ([`ConnectionWeights`])
#' - Also some model-agnostic approaches:
#'    - *Local interpretable model-agnostic explanations* ([`LIME`])
#'    - *Shapley values* ([`SHAP`])
#'
#' @template param-converter
#' @template param-data
#' @template param-channels_first
#' @template param-ignore_last_act
#' @template param-dtype
#' @template param-aggr_channels
#' @template param-output_label
#' @template param-as_plotly
#' @template param-verbose
#' @template param-ref_data_idx
#' @template param-preprocess_FUN
#' @template param-individual_data_idx
#' @template param-individual_max
#' @template param-winner_takes_all
#' @template field-data
#' @template field-converter
#' @template field-channels_first
#' @template field-dtype
#' @template field-ignore_last_act
#' @template field-result
#' @template field-output_idx
#' @template field-output_label
#' @template field-verbose
#' @template field-winner_takes_all
#'
#' @field preds (`list`)\cr
#' In this field, all calculated predictions are stored as a list of
#' `torch_tensor`s. Each output layer has its own list entry and contains
#' the respective predicted values.\cr
#' @field decomp_goal (`list`)\cr
#' In this field, the method-specific decomposition objectives are stored as
#' a list of `torch_tensor`s for each output layer. For example,
#' GradientxInput and LRP attempt to decompose the prediction into
#' feature-wise additive effects. DeepLift and IntegratedGradient decompose
#' the difference between \eqn{f(x)} and \eqn{f(x')}. On the other hand,
#' DeepSHAP and ExpectedGradient aim to decompose \eqn{f(x)} minus the
#' averaged prediction across the reference values.\cr
#'
InterpretingMethod <- R6Class(
  classname = "InterpretingMethod",
  public = list(
    data = NULL,
    converter = NULL,
    channels_first = NULL,
    dtype = NULL,
    winner_takes_all = TRUE,
    ignore_last_act = NULL,
    result = NULL,
    output_idx = NULL,
    output_label = NULL,
    verbose = NULL,
    preds = NULL,
    decomp_goal = NULL,

    #' @description
    #' Create a new instance of this super class.
    #'
    #' @param output_idx (`integer`, `list` or `NULL`)\cr
    #' These indices specify the output nodes for which the method is to be
    #' applied. In order to allow models with multiple output layers, there are
    #' the following possibilities to select the indices of the output
    #' nodes in the individual output layers:
    #' \itemize{
    #'   \item An `integer` vector of indices: If the model has only one output
    #'   layer, the values correspond to the indices of the output nodes, e.g.
    #'   `c(1,3,4)` for the first, third and fourth output node. If there are
    #'   multiple output layers, the indices of the output nodes from the first
    #'   output layer are considered.
    #'   \item A `list` of `integer` vectors of indices: If the method is to be
    #'   applied to output nodes from different layers, a list can be passed
    #'   that specifies the desired indices of the output nodes for each
    #'   output layer. Unwanted output layers have the entry `NULL` instead
    #'   of a vector of indices, e.g. `list(NULL, c(1,3))` for the first and
    #'   third output node in the second output layer.
    #'   \item `NULL` (default): The method is applied to all output nodes
    #'   in the first output layer but is limited to the first ten as the
    #'   calculations become more computationally expensive for more
    #'   output nodes.\cr
    #' }
    initialize = function(converter, data,
                          channels_first = TRUE,
                          output_idx = NULL,
                          output_label = NULL,
                          ignore_last_act = TRUE,
                          winner_takes_all = TRUE,
                          verbose = interactive(),
                          dtype = "float") {
      # Validate all arguments first (in the same order as before, so the
      # first reported error for invalid input is unchanged) ...
      cli_check(checkClass(converter, "Converter"), "converter")
      cli_check(checkLogical(channels_first), "channels_first")
      cli_check(checkLogical(ignore_last_act), "ignore_last_act")
      cli_check(checkLogical(winner_takes_all), "winner_takes_all")
      cli_check(checkLogical(verbose), "verbose")
      cli_check(checkChoice(dtype, c("float", "double")), "dtype")

      # ... then store the validated configuration
      self$converter <- converter
      self$channels_first <- channels_first
      self$ignore_last_act <- ignore_last_act
      self$winner_takes_all <- winner_takes_all
      self$verbose <- verbose
      self$dtype <- dtype

      # Propagate the requested precision to the converted model
      self$converter$model$set_dtype(dtype)

      # Check and normalize the selected output indices and labels
      checked_outputs <- check_output_idx(output_idx, converter$output_dim,
                                          output_label, converter$output_names)
      self$output_idx <- checked_outputs[[1]]
      self$output_label <- checked_outputs[[2]]

      # Check and store the given data (method-specific validation)
      self$data <- private$test_data(data)
    },

    #' @description
    #' This function returns the result of this method for the given data
    #' either as an array (`'array'`), a torch tensor (`'torch.tensor'`,
    #' or `'torch_tensor'`) of size *(batch_size, dim_in, dim_out)* or as a
    #' data.frame (`'data.frame'`). This method is also implemented as a
    #' generic S3 function [`get_result`]. For a detailed description, we refer
    #' to our in-depth vignette (`vignette("detailed_overview", package = "innsight")`)
    #' or our [website](https://bips-hb.github.io/innsight/articles/detailed_overview.html#get-results).
    #'
    #' @param type (`character(1)`)\cr
    #' The data type of the result. Use one of `'array'`,
    #' `'torch.tensor'`, `'torch_tensor'` or `'data.frame'`
    #' (default: `'array'`).\cr
    #'
    #' @return The result of this method for the given data in the chosen
    #' type.
    get_result = function(type = "array") {
      cli_check(checkChoice(type, c("array", "data.frame", "torch.tensor",
                           "torch_tensor")), "type")

      # Get the result as an array
      if (type == "array") {
        # Get the input names and move the channel dimension (if necessary)
        input_names <- self$converter$input_names
        if (!self$channels_first) {
          input_names <- move_channels_last(input_names)
        }
        # Convert the torch_tensor result into a named array
        result <- tensor_list_to_named_array(
          self$result, input_names, self$converter$output_names,
          self$output_idx)
      } else if (type == "data.frame") {
        # Get the result as a data.frame
        # The function 'create_dataframe_from_result' assumes the channels
        # first format, hence move the channel axis back to position 2 first
        result <- self$result
        if (!self$channels_first) {
          FUN <- function(result, out_idx, in_idx) {
            res <- result[[out_idx]][[in_idx]]
            if (res$dim() > 1) {
              res <- torch_movedim(res, source = -2, destination = 2)
            }

            res
          }
          result <- apply_results(result, FUN)
        }

        # Convert the stored per-output-layer predictions (torch_tensors)
        # into plain arrays; NULL if no predictions were calculated
        prepare_preds <- function(x) {
          if (is.null(unlist(x))) {
            NULL
          } else {
            lapply(seq_along(x), function(i) as.array(x[[i]]))
          }
        }
        preds <- prepare_preds(self$preds)
        decomp_goals <- prepare_preds(self$decomp_goal)

        # Convert the torch_tensor result into a data.frame
        result <- create_dataframe_from_result(
          seq_len(dim(self$data[[1]])[1]), result,
          self$converter$input_names, self$converter$output_names,
          self$output_idx, preds, decomp_goals)

        # Remove columns that carry no information for low-dim. inputs
        if (all(result$input_dimension <= 2)) {
          result$feature_2 <- NULL
        }
        if (all(result$input_dimension <= 1)) {
          result$channel <- NULL
        }
      } else {
        # Get the result as a torch_tensor and remove unnecessary axis
        num_inputs <- length(self$converter$input_names)
        out_null_idx <- unlist(lapply(self$output_idx, is.null))
        out_nonnull_idx <-
          seq_along(self$converter$output_names)[!out_null_idx]
        result <- self$result
        # Name the inner list or remove the inner list (the condition is
        # loop-invariant, so it is decided once for all output layers)
        if (num_inputs == 1) {
          result <- lapply(result, function(res_out) res_out[[1]])
        } else {
          result <- lapply(result, function(res_out) {
            names(res_out) <- paste0("Input_", seq_len(num_inputs))
            res_out
          })
        }
        # Name the outer list or remove it
        if (length(self$output_idx) == 1) {
          result <- result[[1]]
        } else {
          names(result) <- paste0("Output_", out_nonnull_idx)
        }
      }

      result
    },

    #' @description
    #' This method visualizes the result of the selected
    #' method and enables a visual in-depth investigation with the help
    #' of the S4 classes [`innsight_ggplot2`] and [`innsight_plotly`].\cr
    #' You can use the argument `data_idx` to select the data points in the
    #' given data for the plot. In addition, the individual output nodes for
    #' the plot can be selected with the argument `output_idx`. The different
    #' results for the selected data points and outputs are visualized using
    #' the ggplot2-based S4 class `innsight_ggplot2`. You can also use the
    #' `as_plotly` argument to generate an interactive plot with
    #' `innsight_plotly` based on the plot function [plotly::plot_ly]. For
    #' more information and the whole bunch of possibilities,
    #' see [`innsight_ggplot2`] and [`innsight_plotly`].\cr
    #' \cr
    #' **Notes:**
    #' 1. For the interactive plotly-based plots, the suggested package
    #' `plotly` is required.
    #' 2. The ggplot2-based plots for models with multiple input layers are
    #' a bit more complex, therefore the suggested packages `'grid'`,
    #' `'gridExtra'` and `'gtable'` must be installed in your R session.
    #' 3. If the global *Connection Weights* method was applied, the
    #' unnecessary argument `data_idx` will be ignored.
    #' 4. The predictions, the sum of relevances, and, if available, the
    #' decomposition target are displayed by default in a box within the plot.
    #' Currently, these are not generated for `plotly` plots.
    #'
    #' @param data_idx (`integer`)\cr
    #' An integer vector containing the numbers of the data
    #' points whose result is to be plotted, e.g., `c(1,3)` for the first
    #' and third data point in the given data. Default: `1`. This argument
    #' will be ignored for the global *Connection Weights* method.\cr
    #' @param output_idx (`integer`, `list` or `NULL`)\cr
    #' The indices of the output nodes for which the results
    #' is to be plotted. This can be either a `integer` vector of indices or a
    #' `list` of `integer` vectors of indices but must be a subset of the indices for
    #' which the results were calculated, i.e., a subset of `output_idx` from the
    #' initialization `new()` (see argument `output_idx` in method `new()` of
    #' this R6 class for details). By default (`NULL`), the smallest index
    #' of all calculated output nodes and output layers is used.\cr
    #' @param same_scale (`logical`)\cr
    #' A logical value that specifies whether the individual plots have the
    #' same fill scale across multiple input layers or whether each is
    #' scaled individually. This argument is only used if more than one input
    #' layer results are plotted.\cr
    #' @param show_preds (`logical`)\cr
    #' This logical value indicates whether the plots display the prediction,
    #' the sum of calculated relevances, and, if available, the targeted
    #' decomposition value. For example, in the case of GradientxInput, the
    #' goal is to obtain a decomposition of the predicted value, while for
    #'  DeepLift and IntegratedGradient, the goal is the difference between
    #'  the prediction and the reference value, i.e., \eqn{f(x) - f(x')}.\cr
    #'
    #' @return
    #' Returns either an [`innsight_ggplot2`] (`as_plotly = FALSE`) or an
    #' [`innsight_plotly`] (`as_plotly = TRUE`) object with the plotted
    #' individual results.
    #'
    plot = function(data_idx = 1,
                    output_idx = NULL,
                    output_label = NULL,
                    aggr_channels = "sum",
                    as_plotly = FALSE,
                    same_scale = FALSE,
                    show_preds = TRUE) {

      # Get method-specific arguments -----------------------------------------
      # Each method defines its own value label and whether the underlying
      # data is included in the plot
      if (inherits(self, "ConnectionWeights")) {
        if (!self$times_input) {
          if (!identical(data_idx, 1)) {
            messagef(
              "Without the 'times_input' argument, the method ",
              "'ConnectionWeights' is a global method, therefore no individual",
              " data instances can be plotted. But you passed the argument ",
              "'data_idx': 'c(", paste(data_idx, collapse = ", "), ")'!",
              "\nThe argument 'data_idx' will be ignored in the following!"
            )
          }
          data_idx <- 1
          # The global method has no data; set a dummy entry so the
          # index checks and result selection below still work
          self$data <- list(array(0, dim = c(1, 1)))
          include_data <- TRUE
        } else {
          include_data <- FALSE
        }
        value_name <- "Relative Importance"
      } else if (inherits(self, "LRP")) {
        value_name <- "Relevance"
        include_data <- TRUE
      } else if (inherits(self, c("DeepLift", "DeepSHAP"))) {
        value_name <- "Contribution"
        include_data <- TRUE
      } else if (inherits(self, "GradientBased")) {
        value_name <- if (self$times_input) "Relevance" else "Gradient"
        include_data <- TRUE
      } else if (inherits(self, "LIME")) {
        value_name <- "Weight"
        include_data <- TRUE
      } else if (inherits(self, "SHAP")) {
        value_name <- "Shapley Value"
        include_data <- TRUE
      }

      # Check correctness of arguments ----------------------------------------
      cli_check(
        checkIntegerish(data_idx, lower = 1, upper = dim(self$data[[1]])[1]),
        "data_idx")
      cli_check(checkLogical(as_plotly), "as_plotly")
      cli_check(checkLogical(same_scale), "same_scale")

      # Set aggregation function for channels
      aggr_channels <- get_aggr_function(aggr_channels)

      # Check for output_label and output_idx (scalar conditions, hence '&&')
      if (!is.null(output_label) && is.null(output_idx)) {
        output_idx <- transform_label_to_idx(output_label,
                                             self$converter$output_names)
      } else if (!is.null(output_label) && !is.null(output_idx)) {
        warningf("You passed non-{.code NULL} values for the arguments ",
                 "{.arg output_label} and {.arg output_idx}. Both are used ",
                 "to specify the output nodes to be plotted. In the ",
                 "following, only the values of {.arg output_idx} are used!")
      }
      output_idx <- check_output_idx_for_plot(output_idx, self$output_idx)

      # Get only the relevant results -----------------------------------------
      # Get only relevant model output layer
      null_idx <- unlist(lapply(output_idx, is.null))
      result <- self$result[!null_idx]
      preds <- if (length(self$preds) != 0) self$preds[!null_idx] else NULL
      decomp_goals <- if (length(self$decomp_goal) != 0) self$decomp_goal[!null_idx] else NULL

      # Get the relevant output and class node indices
      # This is done by matching the given output indices ('output_idx') with
      # the calculated output indices ('self$output_idx'). Afterwards,
      # all non-relevant output indices are removed
      idx_matches <- lapply(
        seq_along(output_idx),
        function(i) match(output_idx[[i]], self$output_idx[[i]]))[!null_idx]

      # Get only the relevant results, predictions and decomposition goals
      result <- apply_results(result, aggregate_channels, idx_matches,
                              data_idx, self$channels_first, aggr_channels)
      if (show_preds) {
        # Select the requested data rows and output nodes from the stored
        # predictions/decomposition goals and convert them to arrays
        fun <- function(i, x) {
          as.array(x[[i]][data_idx, idx_matches[[i]], drop = FALSE])
        }
        if (!is.null(unlist(preds))) {
          preds <- lapply(seq_along(preds), fun, x = preds)
        } else {
          preds <- NULL
        }
        if (!is.null(unlist(decomp_goals))) {
          decomp_goals <- lapply(seq_along(decomp_goals), fun, x = decomp_goals)
        } else {
          decomp_goals <- NULL
        }
      } else {
        preds <- NULL
        decomp_goals <- NULL
      }

      # Get and modify input names (channels are aggregated, so the channel
      # axis gets a generic label)
      input_names <- lapply(self$converter$input_names, function(in_name) {
        if (length(in_name) > 1) {
          in_name[[1]] <- "aggregated"
        }
        in_name
      })

      # Create the data.frame with all information necessary for the plot
      result_df <- create_dataframe_from_result(
        data_idx, result, input_names, self$converter$output_names, output_idx,
        preds, decomp_goals)

      # Get plot (interactive plotly or static ggplot2 version)
      if (as_plotly) {
        p <- create_plotly(result_df, value_name, include_data, FALSE, NULL,
                           same_scale)
      } else {
        p <- create_ggplot(result_df, value_name, include_data, FALSE, NULL,
                           same_scale, show_preds)
      }

      p
    },

    #' @description
    #' This method visualizes the results of the selected method summarized as
    #' boxplots/median image and enables a visual in-depth investigation of the global
    #' behavior with the help of the S4 classes [`innsight_ggplot2`] and
    #' [`innsight_plotly`].\cr
    #' You can use the argument `output_idx` to select the individual output
    #' nodes for the plot. For tabular and 1D data, boxplots are created in
    #' which a reference value can be selected from the data using the
    #' `ref_data_idx` argument. For images, only the pixel-wise median is
    #' visualized due to the complexity. The plot is generated using the
    #' ggplot2-based S4 class `innsight_ggplot2`. You can also use the
    #' `as_plotly` argument to generate an interactive plot with
    #' `innsight_plotly` based on the plot function [plotly::plot_ly]. For
    #' more information and the whole bunch of possibilities, see
    #' [`innsight_ggplot2`] and [`innsight_plotly`].\cr \cr
    #' **Notes:**
    #' 1. This method can only be used for the local *Connection Weights*
    #' method, i.e., if `times_input` is `TRUE` and `data` is provided.
    #' 2. For the interactive plotly-based plots, the suggested package
    #' `plotly` is required.
    #' 3. The ggplot2-based plots for models with multiple input layers are
    #' a bit more complex, therefore the suggested packages `'grid'`,
    #' `'gridExtra'` and `'gtable'` must be installed in your R session.
    #'
    #' @param output_idx (`integer`, `list` or `NULL`)\cr
    #' The indices of the output nodes for which the
    #' results is to be plotted. This can be either a `vector` of indices or
    #' a `list` of vectors of indices but must be a subset of the indices for
    #' which the results were calculated, i.e., a subset of `output_idx` from
    #' the initialization `new()` (see argument `output_idx` in method `new()`
    #' of this R6 class for details). By default (`NULL`), the smallest index
    #' of all calculated output nodes and output layers is used.\cr
    #' @param data_idx (`integer`)\cr
    #' By default, all available data points are used
    #' to calculate the boxplot information. However, this parameter can be
    #' used to select a subset of them by passing the indices. For example, with
    #' `c(1:10, 25, 26)` only the first 10 data points and
    #' the 25th and 26th are used to calculate the boxplots.\cr
    #'
    #' @return
    #' Returns either an [`innsight_ggplot2`] (`as_plotly = FALSE`) or an
    #' [`innsight_plotly`] (`as_plotly = TRUE`) object with the plotted
    #' summarized results.
    plot_global = function(output_idx = NULL,
                           output_label = NULL,
                           data_idx = "all",
                           ref_data_idx = NULL,
                           aggr_channels = "sum",
                           preprocess_FUN = abs,
                           as_plotly = FALSE,
                           individual_data_idx = NULL,
                           individual_max = 20) {

      # Get method-specific arguments -----------------------------------------
      # Each method defines its own value label; the global ConnectionWeights
      # variant has no data and therefore cannot be summarized
      if (inherits(self, "ConnectionWeights")) {
        if (!self$times_input) {
          stopf(
            "Only if the result of the {.emph ConnectionWeights} method is ",
            "multiplied by the data ({.arg times_input} = TRUE), it is a local ",
            "method and only then boxplots can be generated over multiple ",
            "instances. Thus, the argument {.arg data} must be specified and ",
            "{.arg times_input} = TRUE when applying the ",
            "{.code ConnectionWeights$new} method.")
        }
        value_name <- "Relative Importance"
      } else if (inherits(self, "LRP")) {
        value_name <- "Relevance"
      } else if (inherits(self, c("DeepLift", "DeepSHAP"))) {
        value_name <- "Contribution"
      } else if (inherits(self, "GradientBased")) {
        value_name <- if (self$times_input) "Relevance" else "Gradient"
      } else if (inherits(self, "LIME")) {
        value_name <- "Weight"
      } else if (inherits(self, "SHAP")) {
        value_name <- "Shapley Value"
      }

      # Check correctness of arguments ----------------------------------------
      # data_idx ("all" selects every available data point)
      num_data <- dim(self$data[[1]])[1]
      if (identical(data_idx, "all")) {
        data_idx <- seq_len(num_data)
      }
      cli_check(checkIntegerish(data_idx, lower = 1, upper = num_data,
                                any.missing = FALSE), "data_idx")
      # ref_data_idx
      cli_check(
        checkInt(ref_data_idx, lower = 1, upper = num_data, null.ok = TRUE),
        "ref_data_idx")
      # aggr_channels
      aggr_channels <- get_aggr_function(aggr_channels)
      # preprocess_FUN
      cli_check(checkFunction(preprocess_FUN), "preprocess_FUN")
      # as_plotly
      cli_check(checkLogical(as_plotly), "as_plotly")
      # individual_data_idx
      cli_check(
        checkIntegerish(individual_data_idx, lower = 1, upper = num_data,
                        null.ok = TRUE, any.missing = FALSE),
        "individual_data_idx")
      if (is.null(individual_data_idx)) individual_data_idx <- seq_len(num_data)
      # individual_max (cannot exceed the number of data points)
      cli_check(checkInt(individual_max, lower = 1), "individual_max")
      individual_max <- min(individual_max, num_data)

      # Check for output_label and output_idx (scalar conditions, hence '&&')
      if (!is.null(output_label) && is.null(output_idx)) {
        output_idx <- transform_label_to_idx(output_label,
                                             self$converter$output_names)
      } else if (!is.null(output_label) && !is.null(output_idx)) {
        warningf("You passed non-{.code NULL} values for the arguments ",
                 "{.arg output_label} and {.arg output_idx}. Both are used ",
                 "to specify the output nodes to be plotted. In the ",
                 "following, only the values of {.arg output_idx} are used!")
      }
      output_idx <- check_output_idx_for_plot(output_idx, self$output_idx)

      # Set the individual instances for the plot (only the interactive
      # plotly version shows individual data points)
      if (!as_plotly) {
        individual_idx <- ref_data_idx
      } else {
        individual_idx <- unique(
          c(ref_data_idx, individual_data_idx[seq_len(individual_max)]))
        individual_idx <- individual_idx[!is.na(individual_idx)]
      }

      # Get only the relevant results -----------------------------------------
      # Get only relevant model outputs
      null_idx <- unlist(lapply(output_idx, is.null))
      result <- self$result[!null_idx]

      # Get the relevant output and class node indices
      # This is done by matching the given output indices ('output_idx') with
      # the calculated output indices ('self$output_idx'). Afterwards,
      # all non-relevant output indices are removed
      idx_matches <- lapply(
        seq_along(output_idx),
        function(i) match(output_idx[[i]], self$output_idx[[i]]))[!null_idx]

      # Apply preprocess function ---------------------------------------------
      # (e.g. 'abs' by default); NULL entries are passed through unchanged
      preprocess <- function(result, out_idx, in_idx, idx_matches) {
        res <- result[[out_idx]][[in_idx]]
        if (!is.null(res)) {
          res <- preprocess_FUN(res)
        }

        res
      }
      result <- apply_results(result, preprocess, idx_matches)

      # Get and modify input names (channels are aggregated, so the channel
      # axis gets a generic label)
      input_names <- lapply(self$converter$input_names, function(in_name) {
        if (length(in_name) > 1) {
          in_name[[1]] <- "aggregated"
        }
        in_name
      })

      # All data points needed for either the boxplots or individual traces
      idx <- sort(unique(c(individual_idx, data_idx)))

      # Create boxplot data and plots -----------------------------------------
      result <-
        apply_results(result, aggregate_channels, idx_matches, idx,
                      self$channels_first, aggr_channels)
      result_df <- create_dataframe_from_result(
        idx, result, input_names, self$converter$output_names, output_idx)

      # Flag each row: part of the boxplot statistics and/or shown as an
      # individual data point
      idx <- as.numeric(gsub("data_", "", as.character(result_df$data)))
      result_df$boxplot_data <- idx %in% data_idx
      result_df$individual_data <- idx %in% individual_idx

      # Get plot (interactive plotly or static ggplot2 version)
      if (as_plotly) {
        p <- create_plotly(result_df, value_name, FALSE, TRUE, ref_data_idx,
                           TRUE)
      } else {
        p <- create_ggplot(result_df, value_name, FALSE, TRUE, ref_data_idx,
                           TRUE)
      }

      p
    },

    #' @description
    #' Print a summary of the method object. This summary contains the
    #' individual fields and in particular the results of the applied method.
    #'
    #' @return Returns the method object invisibly via [`base::invisible`].
    #'
    print = function() {
      # Header showing the concrete method class
      cli_h1(paste0("Method {.emph ", class(self)[1], "} ({.pkg innsight})"))
      cat("\n")

      list_theme <- list(ul = list(`margin-left` = 2, before = ""),
                         dl = list(`margin-left` = 2, before = ""))
      cli_div(theme = list_theme)

      # Fields defined by the concrete subclass
      cli_text("{.strong Fields} (method-specific):")
      private$print_method_specific()
      cat("\n")

      # Fields shared by all interpreting methods
      cli_text("{.strong Fields} (other):")
      list_id <- cli_ul()
      print_output_idx(self$output_idx, self$converter$output_names)
      field_lines <- c(
        paste0("{.field ignore_last_act}:  ", self$ignore_last_act),
        paste0("{.field channels_first}:  ", self$channels_first),
        paste0("{.field dtype}:  '", self$dtype, "'")
      )
      for (line in field_lines) cli_li(line)
      cli_end(id = list_id)

      # Summary of the calculated results
      cli_h2("{.strong Result} ({.field result})")
      print_result(self$result)

      cli_h1("")

      invisible(self)
    }
  ),
  private = list(

    # ----------------------- backward Function -------------------------------
    run = function(method_name, reset = TRUE) {

      # Declare vector for relevances for each output node
      rel_list <- vector(mode = "list",
                         length = length(self$converter$model$output_nodes))
      # Declare vector for the predictions for each output node
      pred_list <- vector(mode = "list",
                          length = length(self$converter$model$output_nodes))
      # Declare vector for the decomposition goal for each output node
      decomp_list <- vector(mode = "list",
                            length = length(self$converter$model$output_nodes))

      if (self$verbose) {
        #messagef("Backward pass '", method_name, "':")
        # Define Progressbar
        cli_progress_bar(name = paste0("Backward pass '", method_name, "'"),
                         total = length(self$converter$model$graph),
                         type = "iterator", clear = FALSE)
      }

      # We go through the graph in reversed order
      for (step in rev(self$converter$model$graph)) {
        # set the rule name
        rule_name <- self$rule_name

        # Get the current layer
        layer <- self$converter$model$modules_list[[step$used_node]]

        # Get the upper layer relevance ...
        # ... for an output layer
        if (step$used_node %in% self$converter$model$output_nodes) {
          # check if current node is required in 'self$output_idx'
          # get index of the current layer in 'self$output_idx'
          idx <- match(step$used_node,
                       self$converter$model$output_nodes)
          # The current node is not required, i.e. we do not need to calculate
          # relevances for this output
          if (is.null(self$output_idx[[idx]])) {
            rel <- NULL
            pred <- NULL
          } else {
            # Otherwise ...
            # get the corresponding output depending on the argument
            # 'ignore_last_act'
            if (self$ignore_last_act) {
              out <- layer$preactivation
            } else {
              out <- layer$output

              # For probabilistic output we need to subtract 0.5, such that
              # 0 means no relevance
              if (method_name == "LRP" & layer$activation_name %in%
                  c("softmax", "sigmoid", "logistic")) {
                out <- out - 0.5
              }
            }

            # Save the prediction
            pred <- layer$output
            decomp_goal <- out
            if (method_name %in% c("DeepLift", "DeepSHAP")) {
              if (self$ignore_last_act) {
                pred_ref <- layer$preactivation_ref
              } else {
                pred_ref <- layer$output_ref
              }
              num_samples <- dim(self$data[[1]])[1]
              decomp_goal <- torch_stack((out - pred_ref)$chunk(num_samples),
                                  dim = 1)$mean(2)
              pred <- torch_stack(pred$chunk(num_samples), dim = 1)$mean(2)
            }
            pred_list[[idx]] <- pred[, self$output_idx[[idx]], drop = FALSE]
            decomp_list[[idx]] <- decomp_goal[, self$output_idx[[idx]], drop = FALSE]

            # For DeepLift, we only need ones
            if (method_name %in% c("DeepLift", "DeepSHAP")) {
              rel <- torch_diag_embed(torch_ones_like(out))
              # Overwrite rule name
              if (self$ignore_last_act) {
                rule_name <- "ignore_last_act"
              }
            } else if (method_name == "Connection-Weights") {
              if (self$dtype == "float") {
                rel <- torch_diag_embed(torch_ones(c(1, layer$output_dim)))
              } else {
                rel <- torch_diag_embed(
                  torch_ones(c(1, layer$output_dim), dtype = torch_double()))
              }

            } else {
              rel <- torch_diag_embed(out)
            }

            # Get necessary output nodes and fill up with zeros
            #
            # We flatten the list of outputs and put the corresponding outputs
            # into the last axis of the relevance tensor, e.g. we have
            # output_idx = list(c(1), c(2,4,5)) and the current layer
            # (of shape (10,4)) corresponds to the first entry (c(1)), then
            # we concatenate the output of this layer (shape (10,1)) and
            # three times the same tensor with zeros (shape (10,3) )
            tensor_list <- list()
            for (i in seq_along(self$output_idx)) {
              out_idx <- self$output_idx[[i]]
              # if current layer, use the true output/preactivation and only
              # relevant output nodes
              if (i == idx) {
                tensor_list <-
                  append(tensor_list, list(rel[, , out_idx, drop = FALSE]))
              } else if (!is.null(out_idx)) {
                # otherwise, create for each output node a tensor of zeros
                dims <- c(rel$shape[-length(rel$shape)], length(out_idx))
                tensor_list <- append(tensor_list, list(torch_zeros(dims)))
              }
            }
            # concatenate all together
            rel <- torch_cat(tensor_list, dim = -1)
          }
        } else {
          # ... or a normal layer
          # Get relevant entries from 'rel_list' for the current layer
          rel <- rel_list[seq_len(step$times) + min(step$used_idx) - 1]
          if (step$times == 1) {
            rel <- rel[[1]]
          } else {
            # If more than one output for this layer was created, we sum up
            # all relevances from the corresponding upper nodes
            result <- 0
            for (res in rel) {
              if (!is.null(res)) {
                result <- result + res
              }
            }
            rel <- result
          }
        }

        # Remove the used relevances from 'rel_list'
        rel_list <- rel_list[-(seq_len(step$times) + min(step$used_idx) - 1)]

        # Apply the LRP method for the current layer and reset the layer
        # afterwards
        if (!is.null(rel)) {
          if (method_name == "LRP") {
            lrp_rule <-
              get_lrp_rule(self$rule_name, self$rule_param, class(layer)[1])
            rel <- layer$get_input_relevances(rel, rule_name = lrp_rule$rule_name,
                                              rule_param = lrp_rule$rule_param,
                                              winner_takes_all = self$winner_takes_all)
          } else if (method_name %in% c("DeepLift", "DeepSHAP")) {
            rel <- layer$get_input_multiplier(rel, rule_name = rule_name,
                                              winner_takes_all = self$winner_takes_all,
                                              use_grad_near_zero = TRUE)
          } else if (method_name == "Connection-Weights") {
            rel <- layer$get_gradient(rel, weight = layer$W,
                                      use_avgpool = !self$winner_takes_all)
          }
        }

        if (reset) layer$reset()

        # Transform it back to a list
        if (!is.list(rel)) {
          rel <- list(rel)
        }

        # Save the lower-layer relevances in the list 'rel_list' in the
        # required order
        order <- order(step$used_idx)
        ordered_idx <- step$used_idx[order]
        rel_ordered <- rel[order]
        for (i in seq_along(step$used_idx)) {
          rel_list <-
            append(rel_list, rel_ordered[i], after = ordered_idx[i] - 1)
        }

        if (self$verbose) {
          # Update progress bar
          cli_progress_update(force = TRUE)
        }
      }

      # Save output predictions
      self$preds <- pred_list
      self$decomp_goal <- decomp_list

      if (self$verbose) cli_progress_done()

      # If necessary, move channels last
      if (self$channels_first == FALSE) {
        rel_list <- lapply(
          rel_list,
          function(x) torch_movedim(x, source = 2, destination = -2))
      }

      # As mentioned above, the results of the individual output nodes are
      # stored in the last dimension of the results for each input. Hence,
      # we need to transform it back to the structure: outer list (model output)
      # and inner list (model input)
      result <- list()
      sum_nodes <- 0
      for (i in seq_along(self$output_idx)) {
        if (!is.null(self$output_idx[[i]])) {
          index <- seq_len(length(self$output_idx[[i]])) + sum_nodes
          res_output_i <- lapply(rel_list, torch_index_select, dim = -1,
                                 index = as.integer(index))
          result <- append(result, list(res_output_i))

          sum_nodes <- sum_nodes + length(self$output_idx[[i]])
        }
      }

      # For the DeepLift method, we only get the multiplier.
      # Hence, we have to multiply this by the differences of inputs
      if (method_name %in% c("DeepLift")) {
        fun <- function(result, out_idx, in_idx, x, x_ref) {
          res <- result[[out_idx]][[in_idx]]
          if (is.null(res)) {
            res <- NULL
          } else {
            res <- res * (x[[in_idx]] - x_ref[[in_idx]])$unsqueeze(-1)
          }
        }
        result <- apply_results(result, fun, x = self$data, x_ref = self$x_ref)
      }

      result
    },

    # ----------------------- Test data ----------------------------------

    # Validate the argument 'data' and convert it to a list of torch tensors
    # (one entry per input layer of the converted model).
    #
    # @param data A single array-like object (array, matrix or data.frame) or
    #   a list of such objects for models with multiple input layers.
    # @param name The argument name shown in error messages (default "data").
    # @return A list of torch tensors whose dtype follows 'self$dtype' and
    #   whose non-batch dimensions match 'self$converter$input_dim'.
    test_data = function(data, name = "data") {
      if (missing(data)) {
        stopf("Argument {.arg data} is missing!")
      }
      # Wrap a single input in a list. A data.frame is itself a list, hence
      # the extra check. (Fix: scalar condition uses '||' instead of the
      # vectorized '|'.)
      if (!is.list(data) || is.data.frame(data)) {
        data <- list(data)
      }

      lapply(seq_along(data), function(i) {
        input_data <- data[[i]]
        # Coerce the i-th input to a plain numeric array, failing with an
        # informative error that names the offending argument
        input_data <- tryCatch({
          if (is.data.frame(input_data)) {
            input_data <- as.matrix(input_data)
          }
          input_data <- as.array(input_data)
          assertNumeric(input_data)
          input_data
        },
        error = function(e) {
          stopf("Failed to convert the argument {.arg ", name,
               "[[", i, "]]} to a numeric array ",
               "using the function {.fn base::as.array}. The class of your ",
               "argument {.arg ", name, "[[", i, "]]}: '",
               paste(class(input_data), collapse = "', '"), "'",
               " (of type: '", paste(typeof(input_data), collapse = "', '"), "')")
        })

        # Expected shape without the batch axis; if the data is given in
        # channels-last format, move the channel dimension accordingly
        ordered_dim <- self$converter$input_dim[[i]]
        if (!self$channels_first) {
          channels <- ordered_dim[1]
          ordered_dim <- c(ordered_dim[-1], channels)
        }

        # Compare the non-batch dimensions against the model's input shape
        if (length(dim(input_data)[-1]) != length(ordered_dim) ||
            !all(dim(input_data)[-1] == ordered_dim)) {
          stopf(
            "Unmatch in model input dimension (*, ",
            paste0(ordered_dim, collapse = ", "), ") and dimension of ",
            "argument {.arg ", name, "[[", i, "]]} (",
            paste0(dim(input_data), collapse = ", "),
            "). Try to change the argument {.arg channels_first}, if only ",
            "the channels are wrong."
          )
        }

        # Convert the validated array to a torch tensor of the chosen dtype
        if (self$dtype == "float") {
          input_data <- torch_tensor(input_data, dtype = torch_float())
        } else {
          input_data <- torch_tensor(input_data, dtype = torch_double())
        }

        input_data
      })
    },

    # Hook for subclasses to print method-specific fields in the print()
    # output; the base implementation intentionally prints nothing.
    print_method_specific = function() {
      NULL
    }
  )
)


#' Get the result of an interpretation method
#'
#' This is a generic S3 method for the R6 method
#' `InterpretingMethod$get_result()`. See the respective method described in
#' [`InterpretingMethod`] for details.
#'
#' @param x An object of the class [`InterpretingMethod`] including the
#' subclasses [`Gradient`], [`SmoothGrad`], [`LRP`], [`DeepLift`],
#' [`DeepSHAP`], [`IntegratedGradient`], [`ExpectedGradient`] and
#' [`ConnectionWeights`].
#' @param ... Other arguments specified in the R6 method
#' `InterpretingMethod$get_result()`. See [`InterpretingMethod`] for details.
#'
#' @export
get_result <- function(x, ...) {
  # Dispatch on the class of the first argument
  UseMethod("get_result")
}

#' @exportS3Method
get_result.InterpretingMethod <- function(x, ...) {
  # Delegate to the R6 method of the same name on the method object
  x$get_result(...)
}


#'
#' @importFrom graphics boxplot
#' @exportS3Method
#'
boxplot.InterpretingMethod <- function(x, ...) {
  # Boxplots are only meant for tabular (rank <= 2) inputs; for anything
  # higher-dimensional we warn and delegate to plot_global() instead.
  input_ranks <- vapply(x$converter$input_dim, length, integer(1))
  if (any(input_ranks > 2)) {
    warningf("The {.fn boxplot} function is only intended for tabular or signal ",
             "data. It is called {.fn plot_global} instead. ")
  }
  x$plot_global(...)
}


#' Plot global results of an interpretation method
#'
#' This is a generic S3 method for the R6 method
#' `InterpretingMethod$plot_global()`. See the respective method described in
#' [`InterpretingMethod`] for details.
#'
#' @param x An object of the class [`InterpretingMethod`] including the
#' subclasses [`Gradient`], [`SmoothGrad`], [`LRP`], [`DeepLift`],
#' [`DeepSHAP`], [`IntegratedGradient`], [`ExpectedGradient`] and
#' [`ConnectionWeights`].
#' @param ... Other arguments specified in the R6 method
#' `InterpretingMethod$plot_global()`. See [`InterpretingMethod`] for details.
#'
#' @export
plot_global <- function(x, ...) UseMethod("plot_global", x)

#' @exportS3Method
plot_global.InterpretingMethod <- function(x, ...) {
  # Delegate to the R6 method of the same name on the method object
  x$plot_global(...)
}

###############################################################################
#                           print utility functions
###############################################################################
# Pretty-print a nested result structure via cli. The outer list runs over
# output layers, the inner list over input layers; each entry is either a
# torch tensor of attribution values or NULL (layer combination not
# connected).
print_result <- function(result) {
  num_outlayers <- length(result)
  num_inlayers <- length(result[[1]])

  for (i in seq_along(result)) {
    # Only print headings/sub-lists when there are multiple layers
    if (num_outlayers > 1) cli_text(paste0("{.strong Output layer ", i, ":}"))
    for (j in seq_along(result[[i]])) {
      if (num_inlayers > 1) {
        in_l <- cli_ul()
        cli_li(paste0("Input layer ", j, ":"))
      }
      if (is.null(result[[i]][[j]])) {
        # No result for this output/input layer combination
        items <- paste0(col_cyan(symbol$i), " {.emph (not connected to output layer ", i, ")}")
        cli_bullets(c(" " = items))
      } else {
        # Report shape, value range (min/median/max) and NaN count
        items <- list(
          paste0("(", paste0(result[[i]][[j]]$shape, collapse = ", "), ")"),
          paste0(paste0(c("min: ", "median: ", "max: "),
                        signif(as_array(result[[i]][[j]]$quantile(c(0,0.5,1))))),
                 collapse = ", "),
          as_array(result[[i]][[j]]$isnan()$sum())
        )

        names(items) <- paste0(symbol$line,
                               c(" Shape", " Range", " Number of NaN values"))
        cli_dl(items)
      }
      if (num_inlayers > 1) cli_end(in_l)
    }
  }
}

# Pretty-print the selected output nodes (and their labels) via cli.
#
# @param output_idx List of output node indices, one entry per output layer
#   (a NULL entry means the method was not applied to that layer).
# @param out_names List (per output layer) of lists of output node names.
print_output_idx <- function(output_idx, out_names) {
  # Only draw a sub-list per output layer if there are multiple layers
  # (idiom fix: direct logical assignment instead of `if (...) TRUE else FALSE`)
  draw_layer <- length(output_idx) > 1

  if (draw_layer) {
    cli_li("{.field output_idx}:")
    # NOTE(review): 'layer_list' is never closed via cli_end() — confirm
    # whether the open cli container is intended here
    layer_list <- cli_ul()
  }

  for (i in seq_along(output_idx)) {
    if (draw_layer) {
      prefix <- paste0("Output layer ", i, ": {.emph ")
    } else {
      prefix <- "{.field output_idx}: {.emph "
    }

    if (is.null(output_idx[[i]])) {
      # The method was not applied to this output layer
      output_idx[[i]] <- "not applied!"
      labels <- ""
    } else {
      # Append the corresponding output node labels after the indices
      labels <- paste0(
        " (", symbol$arrow_right, " corresponding labels: {.emph '",
        paste0(out_names[[i]][[1]][output_idx[[i]]], collapse = "'}, {.emph '"),
        "'})")
    }

    cli_li(paste0(
      prefix,
      paste0(output_idx[[i]], collapse = "}, {.emph "), "}", labels))
  }
}


###############################################################################
#                                 Utils
###############################################################################

# Validate and harmonize the arguments 'output_idx' and 'output_label'.
#
# Both arguments may be given as a single vector (referring to the first
# output layer) or as a list with one entry per output layer. If only one of
# them is given, the other is derived from it; if neither is given, the
# first min(10, n) nodes of the first output layer are used. Both lists are
# padded with NULLs to the number of output layers and finally checked for
# consistency against each other.
#
# @param output_idx Integerish vector or list of such vectors (or NULL).
# @param output_dim List with the number of output nodes per output layer.
# @param output_label Character/factor vector or list of such (or NULL).
# @param output_names List (per output layer) of lists of output node names.
# @return A list of two entries: the harmonized 'output_idx' and
#   'output_label' (both lists of length 'length(output_dim)').
check_output_idx <- function(output_idx, output_dim, output_label, output_names) {
  # Check the output indices --------------------------
  # Check if output_idx is a single vector
  if (testIntegerish(output_idx,
                     lower = 1,
                     upper = output_dim[[1]])) {
    output_idx <- list(output_idx)
  # Check if it's a list of vectors
  } else if (testList(output_idx, max.len = length(output_dim))) {
    # the argument output_idx is a list of output_nodes for each output
    n <- 1
    for (output in output_idx) {
      limit <- output_dim[[n]]
      cli_check(checkInt(limit), "limit")
      if (!testIntegerish(output, lower = 1, upper = limit, null.ok = TRUE)) {
        stopf("Assertion on {.arg output_idx[[", n, "]]} failed: Value(s) ",
              paste(output, collapse = ","), " not <= ", limit, ".")
      }
      n <- n + 1
    }
  } else if (!is.null(output_idx)) {
    stopf("The argument {.arg output_idx} has to be either a vector with maximum ",
          "value of '", output_dim[[1]], "' or a list of length '",
          length(output_dim), "' with maximal values of '",
          paste(unlist(output_dim), collapse = ","), "'.")
  }

  # Check the output labels -------------------------
  # Check if output_label is a single vector
  if (testCharacter(output_label, min.len = 1, max.len = output_dim[[1]]) ||
      testFactor(output_label, min.len = 1, max.len = output_dim[[1]])) {
    # Check if labels are a subset of output_names
    cli_check(checkSubset(as.factor(output_label), unlist(output_names[[1]])),
              "output_label")
    output_label <- list(output_label)
    # Check if it's a list of vectors
  } else if (testList(output_label, max.len = length(output_names))) {
    # the argument output_label is a list of names for each output
    n <- 1
    for (output in output_label) {
      # Check if labels are a subset of output_names
      cli_check(checkSubset(as.factor(unlist(output)), unlist(output_names[[n]])),
                "output_label")
      n <- n + 1
    }
  } else if (!is.null(output_label)) {
    stopf("The argument {.arg output_label} has to be either a vector of ",
          "characters/factors or a list of vectors of characters/factors!")
  }

  # Derive the missing argument from the given one ---------------------
  if (is.null(output_idx) && is.null(output_label)) {
    # Neither given: default to the first min(10, n) nodes of output layer 1
    output_idx <- list(1:min(10, output_dim[[1]]))
    output_label <- list(output_names[[1]][[1]][output_idx[[1]]])
  } else if (is.null(output_idx) && !is.null(output_label)) {
    # Only labels given: look up the matching node indices per layer
    output_idx <- list()
    for (i in seq_along(output_label)) {
      output_idx[[i]] <- match(output_label[[i]], output_names[[i]][[1]])
      # NOTE(review): assigning NULL removes the list element (shifting later
      # indices); subsequent [[i]] assignments re-pad with NULLs — confirm
      # this is the intended behavior for layers without labels
      if (length(output_idx[[i]]) == 0) output_idx[[i]] <- NULL
    }
  } else if (is.null(output_label) && !is.null(output_idx)) {
    # Only indices given: look up the matching labels per layer
    output_label <- list()
    for (i in seq_along(output_idx)) {
      output_label[[i]] <- output_names[[i]][[1]][output_idx[[i]]]
    }
  }

  # Fill up with NULLs
  num_layers <- length(output_dim)
  if (length(output_idx) < num_layers) {
    output_idx <- append(output_idx,
                         rep(list(NULL), num_layers - length(output_idx)))
  }
  if (length(output_label) < num_layers) {
    output_label <- append(output_label,
                           rep(list(NULL), num_layers - length(output_label)))
  }

  # Check if both are consistent
  for (i in seq_along(output_dim)) {
    if (testTRUE(length(output_idx[[i]]) != length(output_label[[i]]))) {
      stopf("Both the {.arg output_idx} and {.arg output_label} arguments ",
            "were passed (i.e., not {.code NULL}). However, they do not ",
            "match and point to different output nodes.")
    }

    # Get labels from output_idx
    labels <- output_names[[i]][[1]][output_idx[[i]]]
    labels_ref <- as.factor(output_label[[i]])
    if (length(labels) == 0) labels <- NULL
    if (length(labels_ref) == 0) labels_ref <- NULL
    if (!testSetEqual(labels, labels_ref)) {
      stopf("Both the {.arg output_idx} and {.arg output_label} arguments ",
            "were passed (i.e., not {.code NULL}). However, they do not ",
            "match and point to different output nodes.")
    }
  }

  list(output_idx, output_label)
}


# Convert the nested list of torch result tensors into (a nested list of)
# named R arrays.
#
# @param torch_result List (output layers) of lists (input layers) of torch
#   tensors, or NULL for unconnected layer combinations.
# @param input_names List of dimension names per input layer.
# @param output_names List of output node names per output layer.
# @param output_idx List of selected output node indices (NULL entries for
#   layers the method was not applied to).
# @return Nested list of named arrays; a list level is dropped when there is
#   only one input or one output layer. Unconnected combinations become NaN.
tensor_list_to_named_array <- function(torch_result, input_names, output_names,
                                       output_idx) {
  # get the indices of the output for which we have and haven't calculated
  # attribution values
  out_null_idx <- unlist(lapply(output_idx, is.null))
  out_nonnull_idx <- seq_along(output_names)[!out_null_idx]

  # select only relevant output indices and output names
  output_idx <- output_idx[out_nonnull_idx]
  output_names <- output_names[out_nonnull_idx]

  # 'torch_result' is a list (with output layer indices) of list (input layer
  # indices) and the inner list contains the corresponding result of the
  # respective output and input layer combination
  result <- lapply(
    # for each output layer
    seq_along(torch_result),
    function(out_idx) {
      result_i <- lapply(
        # and for each input layer
        seq_along(torch_result[[out_idx]]),
        function(in_idx) {
          # get the corresponding result
          result_ij <- torch_result[[out_idx]][[in_idx]]
          # if the output layer isn't connected to the input layer, we set
          # the value NaN
          if (is.null(result_ij)) {
            result_ij <- NaN
          } else {
            # otherwise convert the result to an array and set dimnames
            # (batch axis first, hence the leading NULL entry)
            result_ij <- as_array(result_ij)
            in_name <- input_names[[in_idx]]
            out_name <-
              list(output_names[[out_idx]][[1]][output_idx[[out_idx]]])
            names <- append(list(NULL), in_name)
            names <- append(names, out_name)
            dimnames(result_ij) <- names
          }

          result_ij
        }
      )
      # Skip one list dimension if there is only one input layer, otherwise
      # set the names of the list entries
      if (length(input_names) == 1) {
        result_i <- result_i[[1]]
      } else {
        names(result_i) <- paste0("Input_", seq_along(input_names))
      }

      result_i
    }
  )
  # Skip one list dimension if there is only one output layer, otherwise set
  # the names of the list entries
  if (length(output_idx) == 1) {
    result <- result[[1]]
  } else {
    names(result) <- paste0("Output_", out_nonnull_idx)
  }

  result
}


# Convert the nested result structure into one long-format data.frame for
# plotting, with one row per data point, feature and output node.
#
# @param data_idx Indices of the data points to include (empty -> NULL).
# @param result List (output layers) of lists (input layers) of result
#   tensors or NULL for unconnected combinations.
# @param input_names,output_names Dimension/node names per layer.
# @param output_idx Selected output node indices per output layer (NULL
#   entries for layers the method was not applied to).
# @param preds,decomp_goal Optional per-output-layer predictions and
#   decomposition targets added as extra columns.
# @return A data.frame with columns for data point, model input/output,
#   feature axes, output node, value and summary columns — or NULL if
#   'data_idx' is empty.
create_dataframe_from_result <- function(data_idx, result, input_names,
                                         output_names, output_idx,
                                         preds = NULL, decomp_goal = NULL) {

  if (length(data_idx) == 0) {
    result_df <- NULL
  } else {
    # Keep only output layers the method was applied to
    null_idx <- unlist(lapply(output_idx, is.null))
    nonnull_idx <- seq_along(output_names)[!null_idx]
    output_idx <- output_idx[nonnull_idx]
    output_levels <- paste0("Output_", seq_along(output_names))
    output_names <- output_names[nonnull_idx]

    # Build the data.frame for a single output/input layer combination
    fun <- function(result, out_idx, in_idx, input_names, output_names,
                    output_idx, nonnull_idx, preds) {
      res <- result[[out_idx]][[in_idx]]

      result_df <-
        create_grid(data_idx, input_names[[in_idx]],
                    output_names[[out_idx]][[1]][output_idx[[out_idx]]])
      if (is.null(res)) {
        # Unconnected combination: fill values with NaN
        result_df$value <- NaN
      } else {
        result_df$value <- as.vector(as.array(res))
        # Sum of attributions per data point and output node
        # NOTE(review): 'res' is indexed via apply() before conversion —
        # relies on its array-like interface; confirm for torch tensors
        res_sum <- apply(res, c(1, length(dim(res))), sum)
        # Replicate the per-sample summaries to match the grid's row count
        num_reps <- nrow(result_df) %/% (length(data_idx) * length(output_idx[[out_idx]]))
        res_sum <- do.call("rbind", lapply(seq_len(num_reps), function(x) res_sum))

        if (length(preds) == 0) {
          pred <- NA
        } else {
          pred <- preds[[out_idx]]
          pred <- do.call("rbind", lapply(seq_len(num_reps), function(x) pred))
        }

        if (length(decomp_goal) == 0) {
          dec_goal <- NA
        } else {
          dec_goal <- decomp_goal[[out_idx]]
          dec_goal <- do.call("rbind", lapply(seq_len(num_reps), function(x) dec_goal))
        }

        result_df$pred <- as.vector(pred)
        result_df$decomp_sum <- as.vector(res_sum)
        result_df$decomp_goal <- as.vector(dec_goal)
      }
      result_df$model_input <- paste0("Input_", in_idx)
      result_df$model_output <- factor(
        paste0("Output_", nonnull_idx[out_idx]),
        levels = output_levels)



      result_df
    }

    # Apply to all layer combinations, then bind everything row-wise and
    # reorder the columns into the expected layout
    result <- apply_results(result, fun, input_names, output_names,
                            output_idx, nonnull_idx, preds)
    result_df <- do.call("rbind",
                         lapply(result, function(x) do.call("rbind", x)))
    result_df <- result_df[, c(1, 11, 12, 3, 4, 2, 5, 7, 8, 9, 10, 6)]
  }

  result_df
}

# Build the cartesian grid of data points, feature axes and output nodes that
# serves as the skeleton of the long-format result data.frame. Axes that do
# not exist for the given input rank (channel/feature_2) are filled with NaN.
create_grid <- function(data_idx, input_names, output_names) {
  num_axes <- length(input_names)

  if (num_axes == 1) {
    # tabular input: features only, no channel axis
    channel <- NaN
    feature <- input_names[[1]]
    feature_2 <- NaN
  } else if (num_axes == 2) {
    # 1d/signal input: channels plus one feature axis
    channel <- input_names[[1]]
    feature <- input_names[[2]]
    feature_2 <- NaN
  } else {
    # 2d/image input: channels plus two feature axes
    channel <- input_names[[1]]
    feature <- input_names[[2]]
    feature_2 <- input_names[[3]]
  }

  expand.grid(data = paste0("data_", data_idx),
              channel = channel,
              feature = feature,
              feature_2 = feature_2,
              output_node = output_names,
              input_dimension = num_axes)
}


# Validate the 'output_idx' argument of the plot functions against the
# output nodes the method was actually applied to ('true_output_idx') and
# normalize it to a list with one (possibly NULL) entry per output layer.
#
# @param output_idx NULL, an integerish vector (first output layer) or a
#   list of vectors per output layer.
# @param true_output_idx List of applied output node indices per layer.
# @return A list of the same length as 'true_output_idx', padded with NULLs.
check_output_idx_for_plot <- function(output_idx, true_output_idx) {
  if (is.null(output_idx)) {
    # Default: the first node of the first layer the method was applied to
    # Find first non-NULL value
    output_idx <- rep(list(NULL), length(true_output_idx))
    idx <- which(unlist(lapply(true_output_idx, is.null)) == FALSE)[1]
    output_idx[[idx]] <- true_output_idx[[idx]][1]
  } else if (testIntegerish(output_idx)) {
    # A plain vector refers to the first output layer
    cli_check(checkSubset(output_idx, true_output_idx[[1]]), "output_idx")
    output_idx <- list(output_idx)
  } else if (testList(output_idx, max.len = length(true_output_idx))) {
    # A list is validated entry-wise against the applied indices
    for (out_idx in seq_along(true_output_idx)) {
      cli_check(checkSubset(output_idx[[out_idx]], true_output_idx[[out_idx]]),
                paste0("output_idx[[", out_idx, "]]"))
    }
  } else {
    values <- unlist(lapply(true_output_idx, paste, collapse = ","))
    values <- paste0("[[", seq_along(values), "]] ", values, " ")
    stopf("The argument {.arg output_idx} has to be either a vector with value of '",
         paste(true_output_idx[[1]], collapse = ","),
         "' or a list of length '", length(true_output_idx),
         "' with values of '", values, "'. Only for these output nodes ",
         "the method has been applied!")
  }

  # Fill up with NULLs
  if (length(output_idx) < length(true_output_idx)) {
    output_idx <-
      append(output_idx,
             rep(list(NULL), length(true_output_idx) - length(output_idx)))
  }

  output_idx
}

# Convert (a list of) output node labels into their numeric positions within
# the corresponding output layer's names; NULL entries stay NULL.
transform_label_to_idx <- function(output_label, output_names) {
  if (!is.list(output_label)) {
    output_label <- list(output_label)
  }

  lapply(seq_along(output_label), function(layer_idx) {
    layer_labels <- output_label[[layer_idx]]
    if (is.null(layer_labels)) {
      return(NULL)
    }
    layer_labels <- as.factor(layer_labels)
    layer_names <- output_names[[layer_idx]][[1]]
    # Labels must be a subset of the layer's output node names
    cli_check(checkSubset(layer_labels, layer_names), "output_label")
    match(layer_labels, layer_names)
  })
}

# Reorder the per-layer dimension names so that the channel axis (stored
# first) comes last; entries that are not 1d or 2d inputs stay untouched.
move_channels_last <- function(names) {
  lapply(names, function(layer_names) {
    rank <- length(layer_names)
    if (rank == 2) {         # 1d input: (C, L) -> (L, C)
      layer_names[c(2, 1)]
    } else if (rank == 3) {  # 2d input: (C, H, W) -> (H, W, C)
      layer_names[c(2, 3, 1)]
    } else {                 # tabular input: nothing to move
      layer_names
    }
  })
}


# Walk the nested result structure (outer list: output layers, inner list:
# input layers) and apply FUN(result, out_idx, in_idx, ...) to every
# combination, returning the results in the same nested shape.
apply_results <- function(result, FUN, ...) {
  out_indices <- seq_along(result)
  lapply(out_indices, function(out_idx) {
    in_indices <- seq_along(result[[out_idx]])
    lapply(in_indices, function(in_idx) {
      FUN(result, out_idx, in_idx, ...)
    })
  })
}

# Resolve the 'aggr_channels' argument to an aggregation function: it may
# either already be a function or one of the keywords "norm", "sum", "mean".
get_aggr_function <- function(aggr_channels) {
  cli_check(c(
    checkFunction(aggr_channels),
    checkChoice(aggr_channels, c("norm", "sum", "mean"))
  ), "aggr_channels")

  if (is.function(aggr_channels)) {
    return(aggr_channels)
  }

  switch(aggr_channels,
    norm = function(x) sum(x^2)^0.5,  # Euclidean norm over the channels
    sum = sum,
    mean = mean
  )
}

# Aggregate the channel axis of a single result tensor.
#
# Selects the requested data points and output nodes from the result of one
# output/input layer combination, converts it to an R array and — for
# non-tabular inputs — collapses the channel axis with 'aggr_channels',
# keeping a singleton channel dimension in its place.
#
# @param result Nested result list (output layers x input layers).
# @param out_idx,in_idx Indices of the output/input layer combination.
# @param idx_matches List of output node positions to select per out layer.
# @param data_idx Indices of the data points to keep.
# @param channels_first TRUE if the channel axis is stored directly after
#   the batch axis, FALSE if it is second to last.
# @param aggr_channels Aggregation function applied across the channels.
# @return An R array with aggregated channels, or NULL for unconnected
#   layer combinations.
aggregate_channels <- function(result, out_idx, in_idx, idx_matches, data_idx,
                               channels_first, aggr_channels) {
  res <- result[[out_idx]][[in_idx]]
  if (is.null(res)) {
    res <- NULL
  } else {
    d <- length(dim(res))
    idx <- idx_matches[[out_idx]]

    # Select only relevant data and output class
    res <- res$index_select(1, as.integer(data_idx))
    res <- res$index_select(-1, as.integer(idx))
    res <- as_array(res)

    # Only aggregate if the input is non-tabular
    # (d == 3 means batch x features x output nodes, i.e. no channel axis)
    if (d != 3) {
      # get arguments for aggregating
      num_axis <- length(dim(res))
      channel_axis <- ifelse(channels_first, 2, num_axis - 1)
      aggr_axis <- setdiff(seq_len(num_axis), channel_axis)

      # aggregate channels and re-insert a singleton channel dimension
      res <- apply(res, aggr_axis, aggr_channels)
      dim(res) <- append(dim(res), 1, channel_axis)
    }
  }

  res
}

# Resolve possibly layer-specific LRP settings. If 'rule_name'/'rule_param'
# are given as named lists, the entry for 'layer_class' is selected, falling
# back to the default rule ("simple") or parameter (NULL) when the layer
# class has no entry; plain values are passed through unchanged.
get_lrp_rule <- function(rule_name, rule_param, layer_class) {
  resolve <- function(value, default) {
    if (!is.list(value)) {
      return(value)
    }
    if (layer_class %in% names(value)) value[[layer_class]] else default
  }

  list(rule_name = resolve(rule_name, "simple"),
       rule_param = resolve(rule_param, NULL))
}
# Scraped site footer (not R code): bips-hb/innsight documentation built on April 14, 2025, 6:01 p.m.