R/FamiliarDataComputationSHAP.R

Defines functions ..extract_shap_dependence .extract_shap_dependence .extract_shap_force .extract_shap_summary .hash_mapping .predict_from_coalition .evaluate_shap_convergence .compute_shap_value_single_feature ..compute_shap_value .compute_shap_value .update_shap_matrices ..compute_shap_matrices .compute_shap_matrices .compute_shap_kernel_weights ...shap_randomise_mapping_from_coalition ..shap_randomise_mapping_from_coalition .shap_randomise_mapping_from_coalition .shap_mapping_to_feature_list .shap_mapping_to_data .shap_data_to_mapping .sample_shap_coalitions .get_shap_coalitions .get_shap_sample_set .get_shap_feature_set .compute_shap_iterative .extract_shap

#' @include FamiliarS4Generics.R
#' @include FamiliarS4Classes.R
NULL



# familiarDataElementSHAP object -----------------------------------------------
# Data element that holds SHAP values. Besides the slots inherited from
# familiarDataElement, .extract_shap stores the information required to
# translate feature value mappings back to feature values, and the
# predictions for the assessed samples.
setClass(
  "familiarDataElementSHAP",
  contains = "familiarDataElement",
  slots = list(
    # Identifiers of the samples for which SHAP values were computed.
    "sample_identifiers" = "ANY",
    # Integer matrix (one column per feature) that maps the feature values of
    # each assessed sample to positions in the lookup table.
    "data_mapping" = "ANY",
    # Predicted values for the assessed samples.
    "predicted_values" = "ANY",
    # Feature set: for each feature, the candidate values that the indices in
    # data_mapping refer to.
    "lookup_table" = "ANY"
  ),
  prototype = methods::prototype(
    detail_level = "ensemble",
    estimation_type = "point",
    value_column = "shap_value",
    # .extract_shap appends "shap_outcome" to the grouping columns for
    # multinomial and survival outcomes.
    grouping_column = c("feature_name", "feature_value_mapping", "sample_id", "phi_0"),
    sample_identifiers = NULL,
    data_mapping = NULL,
    predicted_values = NULL,
    lookup_table = NULL
  )
)


# familiarDataElementSHAPSummary object ----------------------------------------

# Objects for creating SHAP summary plots. These are created at run-time from
# data included in familiarDataElementSHAP objects.
setClass(
  "familiarDataElementSHAPSummary",
  # No slots are added; all slots are inherited from familiarDataElement.
  contains = "familiarDataElement"
)


# familiarDataElementSHAPForce object ------------------------------------------

# Objects for creating SHAP force plots. These are created at run-time from
# data included in familiarDataElementSHAP objects.
setClass(
  "familiarDataElementSHAPForce",
  # No slots are added; all slots are inherited from familiarDataElement.
  contains = "familiarDataElement"
)


# familiarDataElementSHAPDependence object -------------------------------------

# Objects for creating SHAP dependence plots. These are created at run-time from
# data included in familiarDataElementSHAP objects.
setClass(
  "familiarDataElementSHAPDependence",
  # No slots are added; all slots are inherited from familiarDataElement.
  contains = "familiarDataElement"
)




# extract_shap (generic) -------------------------------------------------------

#'@title Internal function for computing SHAP values.
#'
#'@description Computes SHAP values for feature values using a
#'  `familiarEnsemble`.
#'
#'@param features Features for whose values SHAP values need to be computed.
#'  Defaults to all features in the model.
#'@param n_sample_points Minimum number of values to sample for numeric
#'  features. By default, values are taken from the input dataset. If the
#'  number of unique values of a feature within that dataset is too low,
#'  additional values are drawn from the feature distribution (stored with the
#'  model). Must be at least 2.
#'@param shap_tolerance Relative tolerance for convergence of SHAP values. The
#'  tolerance is scaled with the range in SHAP values. Default: 0.05.
#'@param shap_max_iterations Maximum iterations for convergence of SHAP values.
#'  Default: 1000
#'@param shap_phi_0 Reference predicted value(s). Determined from data by
#'  default. For survival outcomes, one numeric value per evaluation time is
#'  required; for multinomial outcomes, one numeric value per outcome class.
#'
#'@inheritParams .extract_data
#'
#'@return A list of familiarDataElements with SHAP values.
#'@md
#'@keywords internal
setGeneric(
  "extract_shap",
  function(
    object,
    data,
    cl = NULL,
    features = NULL,
    n_sample_points = 20L,
    shap_tolerance = waiver(),
    shap_max_iterations = waiver(),
    shap_phi_0 = waiver(),
    ensemble_method = waiver(),
    evaluation_times = waiver(),
    sample_limit = waiver(),
    detail_level = waiver(),
    aggregate_results = waiver(),
    n_important_features = waiver(),
    is_pre_processed = FALSE,
    message_indent = 0L,
    verbose = FALSE,
    ...
  ) {
    standardGeneric("extract_shap")
  }
)



# extract_shap (familiarEnsemble) ----------------------------------------------
setMethod(
  "extract_shap",
  signature(object = "familiarEnsemble"),
  function(
    object,
    data,
    cl = NULL,
    features = NULL,
    n_sample_points = 20L,
    shap_tolerance = waiver(),
    shap_max_iterations = waiver(),
    shap_phi_0 = waiver(),
    ensemble_method = waiver(),
    evaluation_times = waiver(),
    sample_limit = waiver(),
    detail_level = waiver(),
    aggregate_results = waiver(),
    n_important_features = waiver(),
    is_pre_processed = FALSE,
    message_indent = 0L,
    verbose = FALSE,
    ...
  ) {
    # Compute SHAP values for a familiarEnsemble.
    #
    # Validates the arguments, filling in waived arguments from the settings
    # stored with the ensemble. It then prepares (and subsamples) the input
    # data and dispatches the computation to .extract_shap through
    # extract_dispatcher.
    #
    # Returns the output of extract_dispatcher, or NULL if no model in the
    # ensemble was trained.
    #
    # Note: `shap_phi_0` defaults to waiver(), as in the generic. In S4, a
    # method default takes precedence over the generic default, so a NULL
    # default here would mean object@settings$shap_phi_0 is never used.
    
    if (is.waive(features)) features <- NULL
    
    # Message extraction start
    logger_message(
      paste0("Extracting SHAP values for the ensemble."),
      indent = message_indent,
      verbose = verbose
    )
    
    if (is.null(features)) {
      logger_message(
        paste0(
          "Computing SHAP values for features in the dataset."
        ),
        indent = message_indent + 1L,
        verbose = verbose
      )
      
    } else {
      logger_message(
        paste0(
          "Computing SHAP values for selected features: ", paste_s(features), "."
        ),
        indent = message_indent + 1L,
        verbose = verbose
      )
    }
    
    # Load evaluation_times from the object settings attribute, if it is not provided.
    if (is.waive(evaluation_times)) evaluation_times <- object@settings$eval_times
    
    # Check evaluation_times argument
    if (object@outcome_type %in% c("survival")) {
      sapply(
        evaluation_times,
        .check_number_in_valid_range,
        var_name = "evaluation_times",
        range = c(0.0, Inf),
        closed = c(FALSE, TRUE)
      )
    }
    
    # Check n_sample_points argument. This defines the minimum number of
    # values for numeric features for which SHAP values are computed.
    .check_number_in_valid_range(
      x = n_sample_points,
      var_name = "n_sample_points",
      range = c(2L, Inf),
      closed = c(TRUE, TRUE)
    )
    
    # Check shap_tolerance argument. This sets stopping criteria for convergence
    # of SHAP values.
    if (is.waive(shap_tolerance)) shap_tolerance <- object@settings$shap_tolerance
    
    .check_number_in_valid_range(
      x = shap_tolerance,
      var_name = "shap_tolerance",
      range = c(0.0, Inf),
      closed = c(FALSE, FALSE)
    )
    
    # Check shap_max_iterations argument. This sets the maximum number of
    # iterations for computing SHAP values.
    if (is.waive(shap_max_iterations)) shap_max_iterations <- object@settings$shap_max_iterations
    
    .check_number_in_valid_range(
      x = shap_max_iterations,
      var_name = "shap_max_iterations",
      range = c(1L, Inf),
      closed = c(TRUE, TRUE)
    )
    
    # Check shap_phi_0. A NULL value means that phi_0 is determined from the
    # data later on.
    if (is.waive(shap_phi_0)) shap_phi_0 <- object@settings$shap_phi_0
    
    if (!is.null(shap_phi_0)) {
      if (object@outcome_type == "survival") {
        # One reference value per evaluation time.
        .check_argument_length(
          x = shap_phi_0,
          var_name = "shap_phi_0",
          min = length(evaluation_times),
          max = length(evaluation_times)
        )
        
        .check_is_numeric(
          x = shap_phi_0,
          var_name = "shap_phi_0"
        )
        
      } else if (object@outcome_type == "multinomial") {
        # One reference value per outcome class.
        .check_argument_length(
          x = shap_phi_0,
          var_name = "shap_phi_0",
          min = length(get_outcome_class_levels(object)),
          max = length(get_outcome_class_levels(object))
        )
        
        .check_is_numeric(
          x = shap_phi_0,
          var_name = "shap_phi_0"
        )
        
      } else {
        .check_is_numeric(
          x = shap_phi_0,
          var_name = "shap_phi_0"
        )
      }
    }
    
    # Obtain ensemble method from stored settings, if required.
    if (is.waive(ensemble_method)) ensemble_method <- object@settings$ensemble_method
    
    # Check ensemble_method argument
    .check_parameter_value_is_valid(
      x = ensemble_method,
      var_name = "ensemble_method",
      values = .get_available_ensemble_prediction_methods()
    )
    
    # Check the sample limit. This defines the subset of samples that are being
    # assessed (if real data is used instead of a minimum subset).
    sample_limit <- .parse_sample_limit(
      x = sample_limit,
      object = object,
      default = 200L,
      data_element = "shap"
    )
    
    # Check the level detail.
    detail_level <- .parse_detail_level(
      x = detail_level,
      object = object,
      default = "ensemble",
      data_element = "shap"
    )
    
    # Check whether results should be aggregated.
    aggregate_results <- .parse_aggregate_results(
      x = aggregate_results,
      object = object,
      default = FALSE,
      data_element = "shap"
    )
    
    # Test if models are properly loaded
    if (!is_model_loaded(object = object)) ..error_ensemble_models_not_loaded()
    
    # Test if any model in the ensemble was successfully trained.
    if (!model_is_trained(object = object)) return(NULL)
    
    # Check the number of important features.
    n_important_features <- .parse_n_important_features(
      x = n_important_features,
      object = object,
      default = 20,
      data_element = "shap"
    )
    
    # Set features to be assessed using SHAP.
    important_features <- .select_important_features(
      object = object,
      data = data,
      n_important_features = n_important_features
    )
    
    # Get and process the input data. Since we define SHAP values using the
    # features in their original scales, we need to apply only minimal
    # pre-processing.
    data <- process_input_data(
      object = object,
      data = data, 
      stop_at = "signature"
    )
    
    # Use sample limit to cap the number of samples that are assessed.
    data <- get_subsample(
      data = data,
      size = sample_limit,
      seed = 0L
    )
    
    # Generate a prototype data element.
    proto_data_element <- new(
      "familiarDataElementSHAP",
      detail_level = detail_level
    )
    
    # Generate elements to send to dispatch.
    shap_data <- extract_dispatcher(
      FUN = .extract_shap,
      has_internal_bootstrap = FALSE,
      cl = cl,
      object = object,
      data = data,
      features = features,
      n_sample_points = n_sample_points,
      tolerance = shap_tolerance,
      n_max_iter = shap_max_iterations,
      phi_0 = shap_phi_0,
      proto_data_element = proto_data_element,
      important_features = important_features,
      is_pre_processed = is_pre_processed,
      ensemble_method = ensemble_method,
      evaluation_times = evaluation_times,
      aggregate_results = aggregate_results,
      message_indent = message_indent + 1L,
      verbose = verbose
    )
    
    return(shap_data)
  }
)



# extract_shap (prediction table) ----------------------------------------------
setMethod(
  "extract_shap",
  signature(object = "familiarDataElementSHAP"),
  function(object, ...) {
    # SHAP values cannot be (re-)extracted from an existing data element: warn
    # the user and return nothing.
    # NOTE(review): the section header refers to prediction tables, but this
    # method dispatches on familiarDataElementSHAP. Confirm the intended class.
    ..warning_no_data_extraction_from_prediction_table("SHAP values")
    NULL
  }
)



.extract_shap <- function(
    object,
    data = NULL,
    proto_data_element,
    important_features,
    evaluation_times = NULL,
    features = NULL,
    n_sample_points,
    aggregate_results,
    is_pre_processed = FALSE,
    ensemble_method,
    cl,
    tolerance = 0.005,
    n_max_iter = 1000L,
    phi_0 = NULL,
    mapping_method = "fixed",
    sampling_method = "importance",
    message_indent = 0L,
    verbose = FALSE,
    progress_bar = FALSE,
    ...
) {
  # Compute SHAP values for the features in `important_features` and store
  # them in (a local copy of) `proto_data_element`.
  #
  # Returns the familiarDataElementSHAP data element, or NULL if SHAP values
  # cannot be computed. This happens when there are no important features,
  # when no predictions are made for the input data, when too few features
  # exist to form coalitions, or when no iteration yields SHAP values.
  #
  # `features`, `aggregate_results`, `is_pre_processed`, `message_indent`,
  # `verbose` and `progress_bar` are accepted, but not used in this function
  # body.
  #
  # Step 1: Determine feature values that are to be sampled for determining
  # SHAP values.
  #
  # Step 2: Determine the minimum sample set X required to determine SHAP
  # values. The number of samples (n) is equal to the number of values (m_i)
  # of the feature with the most values to sample. Feature values can be
  # randomly ordered for each feature. Features with m_i < n can randomly
  # draw additional feature values.
  #
  # Alternative: Use the actual dataset X (trigger on function argument).
  #
  # Step 3: Create coalition sets for coalitions with all but one off and all
  # but one on.
  #
  # Step 4: Iterate samples in sample set X. Generate samples corresponding to
  # coalitions for each sample. Concatenate generated samples and add to table
  # with previously generated samples X_gen.
  #
  # Step 5: Select samples without existing predictions.
  #
  # Step 6: Predict samples. Merge new predictions into existing predictions.
  #
  # Step 7: Compute average predicted value (phi_0).
  #
  # Step 8: For each sample in sample set X, determine coalition represented
  # by each sample in X_gen. Compute kernel weight based on coalition. Compute
  # SHAP values by solving linear equation.
  #
  # Step 9: Average SHAP value for each feature value.
  #
  # Step 10: Determine convergence and repeat steps 4-9 until convergence is
  # reached, or capacity is exhausted.
  #
  # Parallel processing: perform steps 4-6 multiple times within a parallel
  # loop. This allows for faster convergence.
  
  # Bind data.table column names used in non-standard evaluation, to avoid
  # R CMD check notes.
  shap_value <- NULL

  # Check that the model requires any features.
  if (is_empty(important_features)) return(NULL)
  
  # Get set of feature values.
  feature_set <- .get_shap_feature_set(
    data = data,
    features = object@model_features,
    feature_info = object@feature_info[object@model_features],
    n_sample_points = n_sample_points
  )
  
  # Generate data if absent.
  if (is_empty(data)) {
    data <- .get_shap_sample_set(
      object = object,
      feature_set = feature_set
    )
  }
  
  # Ensure that data is unique.
  data <- select_unique_data(data = data)
  
  # Get sample identifiers.
  sample_identifiers <- get_unique_row_names(data)
  
  # From here, work with mapping representations of the data (h).
  mapping_input <- .shap_data_to_mapping(
    data = data,
    feature_set = feature_set
  )
  
  # Predict outcome values from the input data. Output may be more than one
  # column.
  predicted_values_input <- .predict_from_coalition(
    mapping = mapping_input,
    feature_set = feature_set,
    object = object,
    ensemble_method = ensemble_method,
    evaluation_time = evaluation_times
  )
  
  if (is_empty(predicted_values_input)) return(NULL)
  
  # Compute phi_0 as the mean prediction (per output column), unless
  # provided.
  if (is.null(phi_0)) phi_0 <- colMeans(predicted_values_input)
  
  if (length(important_features) == 1L) {
    # Single feature (shapley) -------------------------------------------------
    shap_values <- .compute_shap_value_single_feature(
      important_features = important_features,
      mapping = mapping_input,
      sample_id = sample_identifiers,
      feature_set = feature_set,
      predicted_values = predicted_values_input,
      phi_0 = phi_0
    )

  } else {
    # Multiple features (kernel) -----------------------------------------------
    
    # Generate coalitions (Z)
    input_coalitions <- .get_shap_coalitions(
      important_features = important_features,
      depth = 1L
    )
    
    # Check that coalitions are not empty: this happens if the data contains a
    # single feature: SHAP values cannot be computed.
    if (is.null(input_coalitions)) return(NULL)
    
    # Provide the initial set of coalitions.
    coalitions <- list(input_coalitions)
    
    # Compute weights for each coalition in a coalition set.
    kernel_weights <- .compute_shap_kernel_weights(
      n = ncol(input_coalitions),
      individual_coalition = TRUE
    )
    
    # Looping variables.
    iter_id <- 0L
    all_shap_converged <- FALSE
    shap_values <- NULL
    n_parallel <- max(c(length(cl), 1L))
    
    while (!all_shap_converged && iter_id < n_max_iter) {
      # Determine the iteration identifiers.
      current_ids <- iter_id + seq_len(n_parallel)
      
      # Distribute computation.
      shap_values_iter <- fam_lapply(
        cl = cl,
        X = current_ids,
        FUN = .compute_shap_iterative,
        object = object,
        ensemble_method = ensemble_method,
        evaluation_times = evaluation_times,
        mapping_input = mapping_input,
        sample_identifiers = sample_identifiers,
        sample_predictions = predicted_values_input,
        important_features = important_features,
        phi_0 = phi_0,
        coalitions = coalitions,
        kernel_weights = kernel_weights,
        feature_set = feature_set,
        mapping_method = mapping_method
      )
      
      # Combine with previous shap_values.
      shap_values <- data.table::rbindlist(c(list(shap_values), shap_values_iter))
      
      # Stop if no iteration has produced SHAP values so far (e.g. prediction
      # or solving the linear system failed). The table is then empty and has
      # no columns, so the grouped variance computation below would error.
      if (nrow(shap_values) == 0L) break
      
      # Compute variance of each SHAP value.
      shap_variance <- shap_values[
        ,
        list("shap_var" = stats::var(shap_value), "n" = .N),
        by = c("sample_id", "feature_name", "feature_value_mapping", "shap_outcome")
      ]
      
      # Check convergence.
      all_shap_converged <- .evaluate_shap_convergence(
        shap_variance = shap_variance,
        tolerance = max(c(
          tolerance * diff(range(shap_values$shap_value)),
          tolerance * diff(range(c(predicted_values_input))) / sqrt(length(feature_set))
        ))
      )
      
      # Update coalitions.
      coalitions <- .sample_shap_coalitions(
        coalitions = input_coalitions,
        kernel_weights = kernel_weights,
        shap_variance = shap_variance,
        sampling_method = sampling_method,
        seed = 19L + iter_id
      )
      
      # Update iteration id.
      iter_id <- tail(current_ids, n = 1L)
    }
  }
  
  # No SHAP values could be computed.
  if (is.null(shap_values) || nrow(shap_values) == 0L) return(NULL)
  
  # Compute final SHAP values.
  shap_values <- shap_values[
    ,
    list("shap_value" = mean(shap_value)),
    by = c("sample_id", "feature_name", "feature_value_mapping", "shap_outcome")
  ]
  
  # Add model name to data element.
  proto_data_element <- add_model_name(
    proto_data_element,
    object = object
  )
  
  # Store data mapping of feature values for input data.
  proto_data_element@data_mapping <- mapping_input
  
  # Store lookup table to translate feature mapping back to feature values.
  proto_data_element@lookup_table <- feature_set

  # Add predictions for input data.
  proto_data_element@predicted_values <- predicted_values_input
  
  # Add sample identifiers.
  proto_data_element@sample_identifiers <- sample_identifiers
  
  # Store shap data. Value column is "shap_value", grouping columns are
  # "feature_name" and "feature_value_mapping". For multinomial and survival
  # outcomes, "shap_outcome" is an additional grouping column.
  if (object@outcome_type %in% c("multinomial", "survival")) {
    # Add phi_0: one value per evaluation time (survival) or outcome class
    # (multinomial).
    if (object@outcome_type == "survival") {
      phi_0_data <- data.table::data.table(
        "shap_outcome" = as.character(evaluation_times),
        "phi_0" = phi_0
      )
        
    } else if (object@outcome_type == "multinomial") {
      phi_0_data <- data.table::data.table(
        "shap_outcome" = get_outcome_class_levels(object),
        "phi_0" = phi_0
      )
    }
    
    shap_values <- merge(
      x = shap_values,
      y = phi_0_data,
      by = "shap_outcome"
    )
    
    proto_data_element@data <- data.table::copy(
      shap_values[, mget(c("feature_name", "feature_value_mapping", "shap_outcome", "shap_value", "phi_0", "sample_id"))]
    )
    
    # Add shap_outcome as additional grouping level.
    proto_data_element@grouping_column <- c(proto_data_element@grouping_column, "shap_outcome")
    
    if (object@outcome_type %in% c("multinomial")) {
      # Convert shap_outcome to categorical values corresponding to the levels
      # in the modelled endpoint.
      proto_data_element@data$shap_outcome <- factor(
        proto_data_element@data$shap_outcome,
        levels = get_outcome_class_levels(object)
      )
      
    } else if (object@outcome_type == "survival") {
      # Convert shap_outcome to categorical values corresponding to the
      # evaluation time
      proto_data_element@data$shap_outcome <- factor(
        proto_data_element@data$shap_outcome,
        levels = as.character(evaluation_times)
      )
    }
    
  } else {
    # Add phi_0.
    shap_values[, "phi_0" := phi_0]
    
    proto_data_element@data <- data.table::copy(
      shap_values[, mget(c("feature_name", "feature_value_mapping", "shap_value", "phi_0", "sample_id"))]
    )
  }
  
  return(proto_data_element)
}



.compute_shap_iterative <- function(
    iter_id,
    object,
    ensemble_method,
    evaluation_times,
    mapping_input,
    important_features,
    sample_identifiers,
    sample_predictions,
    phi_0,
    coalitions,
    kernel_weights,
    feature_set,
    mapping_method
) {
  # Run a single SHAP iteration. `iter_id` is used as the seed for drawing
  # perturbed mappings.
  #
  # Draws perturbed feature value mappings for the given coalitions, predicts
  # outcomes for them, and solves for SHAP values. Returns the result of
  # .compute_shap_value, or NULL if no predictions were obtained.

  # Perturbed mappings for this iteration.
  perturbed_mapping <- .shap_randomise_mapping_from_coalition(
    important_features = important_features,
    samples = mapping_input,
    coalitions = coalitions,
    feature_set = feature_set,
    seed = iter_id,
    mapping_method = mapping_method
  )

  perturbed_predictions <- .predict_from_coalition(
    mapping = perturbed_mapping,
    feature_set = feature_set,
    object = object,
    ensemble_method = ensemble_method,
    evaluation_time = evaluation_times
  )

  # Without predictions, nothing can be computed in this iteration.
  if (is.null(perturbed_predictions)) return(NULL)

  # Set up the A and b matrices of the linear system and solve it.
  matrices <- .compute_shap_matrices(
    important_features = important_features,
    samples = mapping_input,
    sample_predictions = sample_predictions,
    sample_id = sample_identifiers,
    mapping = perturbed_mapping,
    predicted_values = perturbed_predictions,
    phi_0 = phi_0,
    kernel_weights = kernel_weights
  )

  iteration_shap_values <- .compute_shap_value(shap_matrices = matrices)
  iteration_shap_values
}



.get_shap_feature_set <- function(
    data = NULL,
    features,
    feature_info,
    n_sample_points
) {
  # Build, for every feature in `features`, the set of candidate values used
  # for computing SHAP values. Returns a list named by feature.
  #
  # - Categorical (factor) features: all levels, as a factor.
  # - Numeric features: the unique values observed in `data` (if any). When
  #   fewer than `n_sample_points` values are observed, the set is
  #   supplemented with values interpolated from the feature's stored
  #   percentile distribution.
  value_sets <- list()

  for (current_feature in features) {
    info <- feature_info[[current_feature]]

    if (info@feature_type == "factor") {
      value_sets[[current_feature]] <- factor(info@levels, levels = info@levels)

    } else {
      observed <- if (is_empty(data)) NULL else unique_na(data@data[[current_feature]])
      n_missing <- n_sample_points - length(observed)

      if (n_missing > 0L) {
        # Monotone spline through the stored percentiles, evaluated at
        # n_missing evenly spread percentiles.
        percentiles <- info@distribution$pctl
        n_percentiles <- length(percentiles)
        interpolated <- stats::spline(
          x = (seq_len(n_percentiles) - 1L) / (n_percentiles - 1L),
          y = as.numeric(percentiles),
          xout = get_percentiles(n_missing),
          method = "hyman"
        )$y

        value_sets[[current_feature]] <- unique(c(observed, interpolated))

      } else {
        value_sets[[current_feature]] <- observed
      }
    }
  }

  value_sets
}



.get_shap_sample_set <- function(
    object,
    feature_set
) {
  # Generate a synthetic sample set from the candidate feature values.
  #
  # Each feature's candidate values are shuffled (seeded by the feature's
  # position) and recycled so that every feature fills as many rows as the
  # feature with the most candidate values. The result is converted to a data
  # object with batch "generated" and sequential sample identifiers.
  n_rows <- max(lengths(feature_set))
  feature_names <- names(feature_set)

  columns <- list()
  for (jj in seq_along(feature_set)) {
    current_name <- feature_names[jj]
    candidate_values <- feature_set[[current_name]]

    shuffled_index <- fam_sample(
      seq_along(candidate_values),
      n = length(candidate_values),
      replace = FALSE,
      seed = jj
    )

    columns[[current_name]] <- rep_len(
      candidate_values[shuffled_index],
      length.out = n_rows
    )
  }

  # Add batch and sample identifiers.
  sample_table <- data.table::as.data.table(columns)
  sample_table[, ":="(
    "batch_id" = "generated",
    "sample_id" = seq_len(nrow(sample_table))
  )]

  as_data_object(
    data = sample_table,
    object = object,
    batch_id_column = "batch_id",
    sample_id_column = "sample_id",
    check_stringency = "external"
  )
}



.get_shap_coalitions <- function(
    important_features,
    depth = 1L
) {
  # Build the initial coalition matrix.
  #
  # Each row is one coalition: a logical vector marking which of the
  # `important_features` are switched on. Rows are generated for every subset
  # of size 1 up to `depth`, smallest sizes first. Columns are named after the
  # features. Antithetic coalitions are created later, within
  # ...shap_randomise_mapping_from_coalition.
  #
  # Returns NULL if fewer than two features are available.
  n_features <- length(important_features)

  # Depth is capped at n_features - 1. Without a single valid coalition size,
  # no coalitions exist.
  depth <- if (depth >= n_features) n_features - 1L else depth
  if (depth < 1L) return(NULL)

  # Turn a vector of "on" feature indices into a logical indicator vector.
  as_indicator <- function(on_indices) {
    replace(logical(n_features), on_indices, TRUE)
  }

  coalition_rows <- unlist(
    lapply(
      seq_len(depth),
      function(size) {
        utils::combn(n_features, m = size, FUN = as_indicator, simplify = FALSE)
      }
    ),
    recursive = FALSE
  )

  coalition_matrix <- matrix(unlist(coalition_rows), ncol = n_features, byrow = TRUE)
  colnames(coalition_matrix) <- important_features

  coalition_matrix
}



.sample_shap_coalitions <- function(
    coalitions,
    kernel_weights,
    shap_variance,
    sampling_method,
    seed
) {
  # Select the coalition sets for the next SHAP iteration.
  #
  # coalitions: logical matrix (rows: coalitions, columns: features).
  # kernel_weights: weights indexed by coalition size + 1.
  # shap_variance: table with per-value SHAP variances ("shap_var") by
  #   "feature_name" (and other grouping columns).
  # sampling_method: "fixed" returns the input coalitions unchanged;
  #   "importance" draws coalitions with probability proportional to the
  #   kernel weight times the summed SHAP variance of their "on" features.
  # seed: seed for the uniform random draws.
  #
  # Returns a list of coalition matrices ("bags"). Each coalition appears at
  # most once per bag.

  # Bind data.table column names used in non-standard evaluation, to avoid
  # R CMD check notes.
  shap_var <- coalition_id <- NULL
  
  if (sampling_method == "fixed") {
    sampled_coalitions <- list(coalitions)
    
  } else if (sampling_method == "importance") {
    # Probabilistic selection of coalitions. This is very much similar to
    # importance sampling in MCMC. The central concept is that SHAP values for
    # some features are more difficult to determine than those of others. The
    # more difficult features will have a larger overall variance. Therefore we
    # want to focus more on coalitions where these features are present.
    
    # Check if shap_variance contains useful info. Particularly, it cannot be
    # empty, or contain NA or inf values. In that case, just return the fixed
    # coalitions
    if (is_empty(shap_variance)) return(list(coalitions))
    if (any(!is.finite(shap_variance$shap_var))) return(list(coalitions))
    
    # Determine the cost of each input coalition, normalised by their total
    # cost. We want to sample coalitions until the budget (1.0) is exceeded.
    # This is the same budget available to the input coalition. The cost is
    # equal to the corresponding kernel weight for the individual coalition.
    coalition_size <- rowSums(coalitions)
    coalition_weight <- kernel_weights[coalition_size + 1L]
    coalition_cost <- coalition_weight / sum(coalition_weight)
    
    # The selection probability for each individual coalition is the variance(s)
    # of the on-feature(s) times the kernel weight, normalised by the total
    # probability over all input coalitions.
    shap_variance <- shap_variance[, list("feature_shap_var" = sum(shap_var)), by = "feature_name"]
    
    # Ensure that feature shap variance and the respective features in the
    # coalition matrix are ordered the same way.
    feature_shap_variance <- shap_variance$feature_shap_var
    names(feature_shap_variance) <- shap_variance$feature_name
    feature_shap_variance <- feature_shap_variance[colnames(coalitions)]
    
    # Compute selection likelihood: each coalition row is scaled by its kernel
    # weight, and each feature by its summed SHAP variance.
    coalition_probability <- colSums(t(coalitions * coalition_weight) * feature_shap_variance)
    total_probability <- sum(coalition_probability)
    
    # If every variance is zero (e.g. the predictions do not depend on the
    # features), or a feature lacks a variance (NA), no valid probability
    # distribution exists. Fall back to the fixed coalitions rather than
    # dividing by zero.
    if (!is.finite(total_probability) || total_probability <= 0.0) return(list(coalitions))
    
    coalition_probability <- coalition_probability / total_probability
    
    # Draw 1 / min(coalition_cost) coalitions with resampling, and use these up to
    # and including the coalition where the budget is exceeded.
    n_to_sample <- ceiling(1.0 / min(coalition_cost))
    u_sampled <- fam_runif(n = n_to_sample, seed = seed)
    
    # Inverse-CDF selection: pick the first coalition whose cumulative
    # probability is at least the draw. Draws that exceed the final cumulative
    # probability due to rounding map onto the last coalition.
    selected_coalition <- findInterval(
      u_sampled,
      cumsum(coalition_probability),
      left.open = TRUE
    ) + 1L
    selected_coalition <- pmin(selected_coalition, nrow(coalitions))
    
    # Keep coalitions up to and including the one at which the budget is
    # reached. A small tolerance is needed because the floating-point sum of
    # equal costs may fall just short of 1.0 (e.g. ten costs of 0.1). If the
    # budget is never reached, all draws are kept.
    budget_reached <- cumsum(coalition_cost[selected_coalition]) >= 1.0 - sqrt(.Machine$double.eps)
    n_selected <- if (any(budget_reached)) which.max(budget_reached) else length(selected_coalition)
    
    selected_coalition <- head(selected_coalition, n = n_selected)
    
    # Form subsets of coalitions so that each individual coalition only appears
    # once in its bag. This is important so that a different sample will be drawn
    # when the fixed mapping method in ...shap_randomise_mapping_from_coalition is
    # used.
    coalition_table <- data.table::data.table("coalition_id" = selected_coalition)
    coalition_table[, "bag_id" := data.table::rowid(coalition_id)]
    
    sampled_coalitions <- lapply(
      split(coalition_table, by = "bag_id"),
      function(x, coalitions){
        coalitions[x$coalition_id, , drop = FALSE]
      },
      coalitions = coalitions
    )
    
    # Note that antithetic sampling is done later in
    # ...shap_randomise_mapping_from_coalition.
    
  } else {
    ..error_reached_unreachable_code(paste0("unknown sampling_method: ", sampling_method))
  }
  
  return(sampled_coalitions)
}



.shap_data_to_mapping <- function(
  data,
  feature_set
) {
  # Encode the data as a mapping matrix (h).
  #
  # Each feature value in `data@data` is replaced by its position within the
  # corresponding element of `feature_set` (values that are absent become NA,
  # following `match`). Returns an integer matrix with one column per feature
  # in `feature_set`, named after the features.
  index_columns <- lapply(
    names(feature_set),
    function(feature_name) {
      match(data@data[[feature_name]], feature_set[[feature_name]])
    }
  )

  h <- matrix(unlist(index_columns), ncol = length(feature_set))
  colnames(h) <- names(feature_set)

  h
}



.shap_mapping_to_data <- function(
    mapping,
    feature_set,
    object
) {
  # Decode a mapping matrix back into a data object.
  #
  # For every feature, the integer indices in the matching column of
  # `mapping` are looked up in that feature's candidate values in
  # `feature_set`. The resulting table is converted to a data object whose
  # pre-processing level is set to "signature".
  feature_names <- names(feature_set)

  decoded <- lapply(
    seq_along(feature_set),
    function(ii) feature_set[[ii]][mapping[, feature_names[ii]]]
  )
  names(decoded) <- feature_names

  decoded_table <- data.table::as.data.table(decoded)

  data <- as_data_object(
    data = decoded_table,
    object = object,
    check_stringency = "external"
  )

  # Only model features are present, so the data are at signature level
  # rather than unprocessed.
  data@preprocessing_level <- "signature"

  data
}



.shap_mapping_to_feature_list <- function(
    feature,
    mapping_value, 
    lookup_table
) {
  # Translate mapping indices of a single feature into feature values.
  #
  # Only the first element of `feature` is used to select the candidate
  # values from `lookup_table`. Returns a list with:
  #   feature_value: the numeric value (for factors: the integer level code,
  #     as a double);
  #   feature_label: the level label for factors, NA_character_ otherwise.
  candidates <- lookup_table[[feature[1L]]]

  if (!is.factor(candidates)) {
    return(list(
      "feature_value" = candidates[mapping_value],
      "feature_label" = rep_len(NA_character_, length(mapping_value))
    ))
  }

  list(
    "feature_value" = as.numeric(candidates)[mapping_value],
    "feature_label" = as.character(candidates)[mapping_value]
  )
}



.shap_randomise_mapping_from_coalition <- function(
    important_features,
    samples,
    coalitions,
    feature_set,
    seed,
    n_min_mappings = 300L,
    mapping_method = "fixed"
) {
  # Generate perturbed feature value mappings for every sample in `samples`
  # (a mapping matrix) and every coalition set in `coalitions`.
  #
  # Batches are generated until at least `n_min_mappings` rows have been
  # produced, which limits the relative overhead of the subsequent prediction
  # and SHAP computation steps. Random numbers are drawn from a stream seeded
  # by `seed`. Returns the de-duplicated matrix of all generated mappings.

  # Number of candidate values per feature.
  n_values_per_feature <- lengths(feature_set)

  random_stream <- .start_random_number_stream(seed)

  batches <- list()
  n_generated <- 0L

  repeat {
    if (n_generated >= n_min_mappings) break

    if (mapping_method == "fixed") {
      # One random number per sample, coalition set and important feature.
      n_draws <- nrow(samples) * length(coalitions) * length(important_features)

    } else if (mapping_method == "random") {
      # One random number per sample, individual coalition and important
      # feature. Antithetic sampling doubles the number of coalitions.
      n_draws <- nrow(samples) * 2L * sum(sapply(coalitions, nrow)) * length(important_features)

    } else {
      ..error_reached_unreachable_code(paste0("unknown mapping_method: ", mapping_method))
    }

    draws <- fam_runif(n = n_draws, rstream_object = random_stream)

    # Divide the random numbers over the samples.
    draws_per_sample <- split(draws, fam_cut(seq_along(draws), nrow(samples)))

    batch <- mapply(
      FUN = ..shap_randomise_mapping_from_coalition,
      x = asplit(samples, 1L),
      x_random = draws_per_sample,
      MoreArgs = list(
        "coalitions" = coalitions,
        "n_feature_values" = n_values_per_feature,
        "mapping_method" = mapping_method
      ),
      USE.NAMES = FALSE,
      SIMPLIFY = FALSE
    )

    batches <- c(batches, batch)
    n_generated <- n_generated + sum(sapply(batch, nrow))
  }

  # Stack all batches by row and drop duplicate mappings.
  unique(do.call(rbind, batches))
}



..shap_randomise_mapping_from_coalition <- function(
  x, 
  x_random,
  coalitions,
  n_feature_values,
  mapping_method
) {
  # Form the mapping rows of a single sample for all coalition sets.
  #
  # x: value indices of the sample, named by feature.
  # x_random: uniform random numbers reserved for this sample.
  # coalitions: list of logical coalition matrices (columns: features that
  #   take part in the coalitions).
  # n_feature_values: named number of available values for each feature.
  #
  # Returns the row-bound mapping matrices of all coalition sets.

  # Split the random values of this sample over the coalition sets.
  if (length(coalitions) == 1L) {
    x_random <- list(x_random)

  } else if (mapping_method == "fixed") {
    # Each coalition set requires the same number of random values.
    x_random <- split(x_random, fam_cut(seq_along(x_random), length(coalitions)))

  } else if (mapping_method == "random") {
    # Due to antithetic sampling, the number of coalitions is doubled. Random
    # values are only required for features that take part in the coalitions
    # (the columns of each coalition set), in line with the number of values
    # allocated per sample by the caller. Using the total number of features
    # here would misalign the split whenever features are excluded from the
    # coalitions, leaving later coalition sets without random values.
    n_random_per_set <- vapply(
      coalitions,
      function(coalition_set) 2.0 * nrow(coalition_set) * ncol(coalition_set),
      numeric(1L)
    )

    # Use explicit levels so that every coalition set receives an element,
    # even when it requires no random values.
    x_random <- split(
      x_random,
      factor(
        rep(seq_along(coalitions), times = n_random_per_set),
        levels = seq_along(coalitions)
      )
    )
  }
  
  # Loop over coalition sets.
  mapping <- mapply(
    FUN = ...shap_randomise_mapping_from_coalition,
    coalitions = coalitions,
    x_random = x_random,
    MoreArgs = list(
      "x" = x,
      "n_feature_values" = n_feature_values,
      "mapping_method" = mapping_method
    ),
    USE.NAMES = FALSE,
    SIMPLIFY = FALSE
  )
  
  return(do.call(rbind, mapping))
}



...shap_randomise_mapping_from_coalition <- function(
    coalitions,
    x_random,
    x,
    n_feature_values,
    mapping_method
) {
  # Form the mapping matrix of a single sample `x` for a single set of
  # coalitions. Each row is a perturbed version of `x`: in-coalition features
  # keep the value index of `x`, whereas off-coalition features receive a
  # different value index that is selected using the uniform random numbers
  # in `x_random`. Features that do not take part in the coalitions keep the
  # value of `x` in every row. Columns follow `names(n_feature_values)`.
  #
  # coalitions: logical matrix (rows: coalitions; columns: features that take
  #   part in the coalitions).
  # x: value indices of the sample, named by feature.
  # n_feature_values: named number of available values for each feature.
  # mapping_method:
  #   - "fixed": a single off-coalition value is drawn per feature, and used
  #     for every coalition in which that feature is absent.
  #   - "random": a fresh off-coalition value is drawn for every coalition in
  #     which the feature is absent.

  # Select values from available_feature_values with replacement, using the
  # random values (presumably uniform on [0, 1]) as quantiles.
  ..select_shap_feature_mapping <- function(
    available_feature_values,
    random_values
  ) {
    x_indices <- as.integer(ceiling(random_values * length(available_feature_values)))

    # A random value of exactly 0 would yield index 0; use index 1 instead.
    x_indices[x_indices == 0L] <- 1L
    
    return(available_feature_values[x_indices])
  }
  
  # Generate antithetic coalitions from input: the complements of the
  # coalitions are appended as additional rows.
  coalitions <- rbind(coalitions, !coalitions)
  n_coalitions <- nrow(coalitions)
  
  # Features in and outside of the coalitions.
  important_features <- colnames(coalitions)
  unimportant_features <- setdiff(names(n_feature_values), important_features)
  
  mapping <- list()
  if (mapping_method == "fixed") {
    for (ii in seq_along(important_features)) {
      feature <- important_features[ii]

      # In-coalition (on) value, and the values available off-coalition.
      on_feature_set <- unname(x[feature])
      off_feature_set <- seq_len(n_feature_values[feature])[-on_feature_set]
      
      # Sample a single off-coalition value. Index 1 of feature_set is the
      # in-coalition value, index 2 the off-coalition value.
      feature_set <- c(
        on_feature_set,
        ..select_shap_feature_mapping(
          available_feature_values = off_feature_set,
          random_values = x_random[ii]
        )
      )
      
      lookup_vector <- 1L + !coalitions[, feature]
      mapping[[feature]] <- feature_set[lookup_vector]
    }
      
  } else if (mapping_method == "random") {
    # Number of off-coalition values to draw for each feature. Due to
    # antithetic sampling, this equals the number of input coalitions.
    n_to_draw <- colSums(!coalitions)
    n_drawn <- 0L
    
    for (feature in important_features) {
      # In-coalition (on) value, and the values available off-coalition.
      on_feature_set <- unname(x[feature])
      off_feature_set <- seq_len(n_feature_values[feature])[-on_feature_set]

      # Random values reserved for this feature. seq_len() yields no indices
      # when nothing needs to be drawn, unlike 1L:0L which yields c(1, 0).
      draw_index <- n_drawn + seq_len(n_to_draw[[feature]])
      n_drawn <- n_drawn + n_to_draw[[feature]]
      
      # Sample the off-coalition values, and append to the in-coalition value.
      # This forms the look-up table for forming coalitions.
      feature_set <- c(
        on_feature_set,
        ..select_shap_feature_mapping(
          available_feature_values = off_feature_set,
          random_values = x_random[draw_index]
        )
      )
      
      # Accumulate off-coalition elements. E.g. with coalitions [1, 0, 0, 1],
      # the cumulative sum of off-coalition elements is [0, 1, 2, 2].
      lookup_vector <- cumsum(!coalitions[, feature])
      
      # Reset in-coalition elements, yielding [0, 1, 2, 0], and increment by 1.
      # The resulting indices [1, 2, 3, 1] refer to feature_set, with index 1
      # being the in-coalition value.
      lookup_vector <- 1L + lookup_vector * !coalitions[, feature]
      
      mapping[[feature]] <- feature_set[lookup_vector]
    }
    
  } else {
    ..error_reached_unreachable_code(paste0("unknown mapping_method: ", mapping_method))
  }

  # Features outside the coalitions are simply repeated.
  for (feature in unimportant_features) {
    mapping[[feature]] <- rep(unname(x[feature]), n_coalitions)
  }
  
  # Convert to matrix. Columns are first ordered correctly before being
  # flattened, and the matrix is then filled by column.
  mapping <- matrix(
    unlist(mapping[names(n_feature_values)]),
    ncol = length(n_feature_values)
  )
  colnames(mapping) <- names(n_feature_values)
  
  return(mapping)
}



.compute_shap_kernel_weights <- function(n, individual_coalition = FALSE) {
  # Shapley kernel weights, indexed by coalition size + 1: element k + 1 holds
  # the weight of a coalition with k of the n features present. The kernel
  # (n - 1) / (choose(n, k) * k * (n - k)) is not finite for the empty and the
  # full coalitions, which therefore receive weight 0. Weights are normalised
  # to sum to 1 over all coalition sizes. If `individual_coalition` is TRUE,
  # each weight is divided by the number of distinct coalitions of that size,
  # giving the weight of a single coalition.
  coalition_size <- seq_len(n + 1L) - 1L
  n_of_size <- choose(n, coalition_size)

  raw_weights <- (n - 1.0) / (n_of_size * coalition_size * (n - coalition_size))
  raw_weights <- ifelse(is.finite(raw_weights), raw_weights, 0.0)

  size_weights <- raw_weights / sum(raw_weights)

  if (!individual_coalition) {
    return(size_weights)
  }

  size_weights / n_of_size
}



.compute_shap_matrices <- function(
  important_features,
  samples,
  sample_predictions,
  sample_id,
  mapping,
  predicted_values,
  phi_0,
  kernel_weights
) {
  # Compute the weighted least-squares matrices (A, b) of every sample (row of
  # `samples`) using ..compute_shap_matrices. Returns an unnamed list with one
  # element per sample. Accumulation over iterations (.update_shap_matrices)
  # does not take place here.

  # Restrict samples and mapping to the important features. Missing values
  # are set to 0, so that a missing value in the mapping compares equal to a
  # missing value in the sample.
  sample_matrix <- samples[, important_features, drop = FALSE]
  sample_matrix[is.na(sample_matrix)] <- 0L

  mapping_matrix <- mapping[, important_features, drop = FALSE]
  mapping_matrix[is.na(mapping_matrix)] <- 0L

  shared_arguments <- list(
    "kernel_weights" = kernel_weights,
    "mapping" = mapping_matrix,
    "predicted_values" = predicted_values,
    "phi_0" = phi_0
  )

  mapply(
    FUN = ..compute_shap_matrices,
    x = asplit(sample_matrix, MARGIN = 1L),
    v_0 = asplit(sample_predictions, MARGIN = 1L),
    sample_id = sample_id,
    MoreArgs = shared_arguments,
    SIMPLIFY = FALSE,
    USE.NAMES = FALSE
  )
}



..compute_shap_matrices <- function(
    x,
    v_0,
    sample_id,
    kernel_weights,
    mapping,
    predicted_values,
    phi_0
) {
  # Compute the weighted least-squares matrices A = t(X) W X and b = t(X) W y
  # of a single sample, from which its SHAP values are solved.
  #
  # x: value indices of the sample (important features only).
  # v_0: predicted value(s) of the sample, one per outcome column.
  # kernel_weights: weights indexed by coalition size + 1.
  # mapping: matrix of value indices; one row per perturbed instance.
  # predicted_values: predictions for the mapping rows; one column per
  #   outcome.
  # phi_0: base value(s); presumably one per outcome column -- TODO confirm.
  #
  # Returns NULL if no coalition has a non-zero kernel weight.

  # Determine for each mapping row which features take the same value as the
  # sample (in-coalition). `==` recycles c(x) down the columns of t(mapping),
  # i.e. across the features of every mapping row; transposing back yields
  # one coalition per mapping row.
  coalitions <- t(t(mapping) == c(x))
  
  # Compute the number of features "present" in each coalition. 
  n_coalition_size <- rowSums(coalitions)
  weights <- kernel_weights[n_coalition_size + 1L]
  non_zero_weights <- weights > 0.0
  
  # Check for empty weights.  
  if (!any(non_zero_weights)) return(NULL)
  
  # Weighted least squares solves for coefficients beta as follows:
  # beta = (t(X) W X)^-1 t(X)W y
  # In the context of kernelSHAP, this means:
  #    beta = phi
  #       X = Z (coalitions),
  #       W = diag(pi) (kernel_weights)
  #   and y = f(h(z)) - phi_0
  X <- coalitions[non_zero_weights, , drop = FALSE]
  
  # Subtract phi_0[j] from outcome column j of every row. The predictions are
  # transposed *before* subtracting, so that phi_0 is recycled along outcomes.
  # Subtracting first would recycle phi_0 along rows (column-major order),
  # which is wrong whenever there is more than one outcome column.
  y <- t(t(predicted_values[non_zero_weights, , drop = FALSE]) - phi_0)
  
  # kernelSHAP originally was defined by sampling coalitions, with each
  # coalition drawn probabilistically according to the SHAP kernel weights. This
  # means that the appearance of coalitions would stochastically mirror their
  # probabilities. Here we have a fixed set of coalitions, and coalitions that
  # are randomly formed from the entire set of decisions that should be
  # explained (and their fixed coalitions). This means that our coalitions are
  # not distributed as expected, and cannot be used without updating the
  # weights.
  
  # First determine how often individual coalitions appear. Identical
  # coalitions receive the same dense rank; alternatives such as hashing each
  # row (rlang::hash, paste0) are several times slower, and encoding rows as
  # sums of powers of 2 runs into integer precision limits for large feature
  # sets. Counting ranks with tabulate is linear in the number of rows.
  coalition_id <- data.table::frank(data.table::as.data.table(X), ties.method = "dense")
  n_coalition_instances <- tabulate(coalition_id, nbins = max(coalition_id))[coalition_id]
  
  # Then re-weight probability of individual coalitions, and normalise.
  w <- weights[non_zero_weights]
  w <- w / n_coalition_instances
  w <- w / sum(w)
  
  # Instead of computing a diagonal matrix, we rely on equivalent element-wise
  # multiplications (which are considerably cheaper): w recycles down the
  # columns, i.e. row i is multiplied by w[i].
  return(list(
    "A" = t(X) %*% (X * w),
    "b" = t(X) %*% (y * w),
    "v_0" = as.numeric(v_0),
    "phi_0" = phi_0,
    "sample_id" = sample_id,
    "sample_mapping" = x
  ))
}



.update_shap_matrices <- function(
    old,
    new
) {
  # Accumulate the least-squares matrices of a new iteration into the running
  # totals of `old`, count the iteration, and keep the most recent phi_0. All
  # other elements of `old` are retained unchanged.
  updated <- old

  updated$A <- old$A + new$A
  updated$b <- old$b + new$b

  updated$n_iter <- old$n_iter + 1L
  updated$phi_0 <- new$phi_0

  updated
}



.compute_shap_value <- function(
  shap_matrices
) {
  # Solve the SHAP values of every sample from its (A, b) matrices and stack
  # the per-sample results into a single data.table. Samples without matrices
  # yield NULL, which rbindlist skips.
  per_sample_values <- lapply(X = shap_matrices, FUN = ..compute_shap_value)

  data.table::rbindlist(per_sample_values)
}



..compute_shap_value <- function(x) {
  # Solve the SHAP values of a single sample from its weighted least-squares
  # matrices, subject to local accuracy: for each outcome, the SHAP values sum
  # to v_0 - phi_0. Returns a list (one entry per feature and outcome), or
  # NULL when no matrices are available.
  if (is.null(x)) return(NULL)

  A_inv <- matrix_pseudo_inverse(x$A)
  n_features <- ncol(x$A)
  n_outcomes <- ncol(x$b)

  # Unconstrained solution.
  phi_unconstrained <- A_inv %*% x$b

  # Shift b per outcome so that the constrained solution satisfies local
  # accuracy. Transposing makes the per-outcome correction recycle along
  # outcomes.
  correction <- (colSums(phi_unconstrained) - (x$v_0 - x$phi_0)) / sum(A_inv)
  phi <- A_inv %*% t(t(x$b) - correction)

  # phi is features x outcomes; c(phi) enumerates features within outcomes.
  list(
    "sample_id" = x$sample_id,
    "feature_name" = rep(colnames(x$A), times = n_outcomes),
    "feature_value_mapping" = rep(x$sample_mapping, times = n_outcomes),
    "shap_value" = c(phi),
    "shap_outcome" = rep(colnames(x$b), each = n_features)
  )
}



.compute_shap_value_single_feature <- function(
    important_features,
    mapping,
    sample_id,
    feature_set,
    predicted_values,
    phi_0
) {
  # With a single important feature, its SHAP value equals the prediction
  # minus the base value phi_0 (per outcome column). Returns a data.table with
  # one row per sample and outcome. `feature_set` is not used.
  n_outcomes <- ncol(predicted_values)

  # Get value mapping of the single feature (one value per sample).
  feature_mapping <- mapping[, important_features]
  
  # All columns are laid out outcome by outcome (c() of a matrix is column-
  # major), so sample_id and the value mapping are repeated once per outcome.
  # Repeating sample_id explicitly avoids relying on data.table recycling a
  # vector longer than 1.
  data <- data.table::data.table(
    "sample_id" = rep(sample_id, times = n_outcomes),
    "feature_name" = important_features,
    "feature_value_mapping" = rep(feature_mapping, times = n_outcomes),
    "shap_value" = c(t(t(predicted_values) - phi_0)),
    "shap_outcome" = rep(colnames(predicted_values), each = nrow(predicted_values))
  )
  
  return(data)
}



.evaluate_shap_convergence <- function(
    shap_variance,
    tolerance
) {
  # SHAP values are considered converged when the standard error of the mean
  # of every SHAP value is at most `tolerance`. Any non-finite standard error
  # (e.g. n = 0 or a missing variance) means that convergence is not reached.
  standard_error <- sqrt(shap_variance$shap_var / shap_variance$n)

  if (!all(is.finite(standard_error))) {
    return(FALSE)
  }

  all(standard_error <= tolerance)
}



.predict_from_coalition <- function(
    mapping,
    feature_set,
    object,
    ensemble_method,
    evaluation_time
) {
  # Predict model responses for the rows of a SHAP mapping matrix.
  #
  # mapping: matrix of feature value indices, converted to a dataObject by
  #   .shap_mapping_to_data using feature_set.
  # evaluation_time: time points at which survival probabilities are
  #   predicted; only used for survival outcomes.
  #
  # Returns a matrix of predictions (presumably one row per mapping row) with
  # one column per explained outcome: one per evaluation time point (survival),
  # "predicted_outcome" (continuous), every outcome class level (multinomial),
  # or the last outcome class level (binomial). Returns NULL if any
  # prediction is invalid.

  # Set prediction type. A scalar if-else is used instead of ifelse(), which
  # is intended for vectorised conditions.
  prediction_type <- if (object@outcome_type %in% c("survival", "competing_risk")) {
    "survival_probability"
  } else {
    "default"
  }
  
  # Convert input to dataObject
  data <- .shap_mapping_to_data(
    mapping = mapping,
    feature_set = feature_set,
    object = object
  )
  
  if (object@outcome_type == "survival") {
    # One prediction column per evaluation time point.
    prediction_list <- vector("list", length(evaluation_time))
    
    for (ii in seq_along(evaluation_time)) {
      # Predict input data
      prediction_data <- predict(
        object = object,
        newdata = data,
        ensemble_method = ensemble_method,
        time = evaluation_time[ii],
        type = prediction_type,
        .as_prediction_table = TRUE
      )
      
      # Check if all predictions are valid.
      if (!all_predictions_valid(prediction_data)) return(NULL)
      
      # Convert to data.table. NOTE(review): this assumes the survival
      # probability values are available in a "predicted_outcome" column --
      # confirm against the prediction table class.
      prediction_data <- .as_data_table(prediction_data)[, mget(prediction_data@value_column)]
      prediction_list[[ii]] <- matrix(prediction_data$predicted_outcome, ncol = 1L)
    }
    
    prediction_data <- do.call(cbind, prediction_list)
    colnames(prediction_data) <- as.character(evaluation_time)
    
  } else {
    # Predict input data
    prediction_data <- predict(
      object = object,
      newdata = data,
      ensemble_method = ensemble_method,
      type = prediction_type,
      .as_prediction_table = TRUE
    )
    
    # Check if all predictions are valid.
    if (!all_predictions_valid(prediction_data)) return(NULL)
    
    # Convert to data.table.
    prediction_data <- .as_data_table(prediction_data)[, mget(prediction_data@value_column)]
    
    if (object@outcome_type == "continuous") {
      prediction_data <- matrix(prediction_data$predicted_outcome, ncol = 1L)
      colnames(prediction_data) <- "predicted_outcome"
      
    } else if (object@outcome_type %in% c("multinomial")) {
      probability_columns <- get_outcome_class_levels(object)
      prediction_data <- as.matrix(prediction_data[, mget(probability_columns)])
      
    } else if (object@outcome_type %in% c("binomial")) {
      # Only the probability of the last class level is explained.
      probability_column <- utils::tail(get_outcome_class_levels(object), n = 1L)
      prediction_data <- as.matrix(prediction_data[, mget(probability_column)])
      
    } else {
      ..error_outcome_type_not_implemented(object@outcome_type)
    }
  }
  
  return(prediction_data)
}



.hash_mapping <- function(x) {
  # Compute one rlang hash per row of x (a matrix of value mappings).
  row_hashes <- apply(
    X = x,
    MARGIN = 1L,
    FUN = function(mapping_row) rlang::hash(mapping_row),
    simplify = TRUE
  )

  row_hashes
}



# Create a familiarDataElementSHAPSummary object from a familiarDataElementSHAP
# object `x`, for SHAP summary plots. The SHAP values in `x@data` are joined
# with the value (and, for factor features, the label) that each sample has for
# each feature. The value-mapping indices in `x@data_mapping` are translated
# into values through `x@lookup_table` by .shap_mapping_to_feature_list.
#
# Returns the new data element; its data slot stays NULL if `x@data` is empty.
.extract_shap_summary <- function(
    x
) {
  # Prevent NOTES due to non-standard evaluation
  feature_name <- feature_value_mapping <- NULL
  
  # Generate object using the incoming familiarDataElementSHAP object as a
  # template.
  data_element <- methods::new(
    "familiarDataElementSHAPSummary",
    x
  )
  
  # Clean reporting elements.
  data_element@data <- NULL
  
  if (is_empty(x@data)) return(data_element)
  
  # Get data_mapping and turn into a long data.table, with one row per sample
  # and feature. Rows of data_mapping are assumed to align with
  # x@sample_identifiers -- TODO confirm. Note that melt() returns
  # feature_name as a factor by default.
  mapping_data <- data.table::as.data.table(x@data_mapping)
  mapping_data[, "sample_id" := x@sample_identifiers]
  mapping_data <- data.table::melt(
    data = mapping_data,
    id.vars = "sample_id",
    variable.name = "feature_name",
    value.name = "feature_value_mapping"
  )
  
  # Insert feature values in mapping data (by reference), one feature at a
  # time.
  mapping_data[
    ,
    c("feature_value", "feature_label") := .shap_mapping_to_feature_list(
      feature = feature_name,
      mapping_value = feature_value_mapping, 
      lookup_table = x@lookup_table
    ),
    by = "feature_name"
  ]
  
  # Cartesian merge. This is an inner join (merge's default), so SHAP values
  # without a matching mapping row are dropped.
  summary_data <- merge(
    x = x@data,
    y = mapping_data,
    by = c("feature_name", "feature_value_mapping", "sample_id"),
    allow.cartesian = TRUE
  )
  
  # Drop unused columns.
  summary_data[, feature_value_mapping := NULL]
  
  # Set data.
  data_element@data <- summary_data
  
  # Set identifiers: the value mapping index is replaced by the feature value
  # and label.
  data_element@grouping_column <- c(
    setdiff(data_element@grouping_column, "feature_value_mapping"),
    c("feature_value", "feature_label")
  )
  
  return(data_element)
}



# Create a familiarDataElementSHAPForce object from a familiarDataElementSHAP
# object `x`, for SHAP force plots. Like the summary data, SHAP values in
# `x@data` are joined with the feature value and label of each sample. In
# addition, the model prediction of each sample (from `x@predicted_values`) is
# added, matched per outcome when a shap_outcome column is present.
#
# Returns the new data element; its data slot stays NULL if `x@data` is empty.
.extract_shap_force <- function(
    x
) {
  # Prevent NOTES due to non-standard evaluation
  feature_name <- feature_value_mapping <- NULL
  
  # Generate object using the incoming familiarDataElementSHAP object as a
  # template.
  data_element <- methods::new(
    "familiarDataElementSHAPForce",
    x
  )
  
  # Clean reporting elements.
  data_element@data <- NULL
  
  if (is_empty(x@data)) return(data_element)
  
  # Get data_mapping and turn into a long data.table, with one row per sample
  # and feature. Rows of data_mapping are assumed to align with
  # x@sample_identifiers -- TODO confirm.
  mapping_data <- data.table::as.data.table(x@data_mapping)
  mapping_data[, "sample_id" := x@sample_identifiers]
  mapping_data <- data.table::melt(
    data = mapping_data,
    id.vars = "sample_id",
    variable.name = "feature_name",
    value.name = "feature_value_mapping"
  )
  
  # Insert feature values in mapping data (by reference), one feature at a
  # time.
  mapping_data[
    ,
    c("feature_value", "feature_label") := .shap_mapping_to_feature_list(
      feature = feature_name,
      mapping_value = feature_value_mapping, 
      lookup_table = x@lookup_table
    ),
    by = "feature_name"
  ]
  
  # Prediction data: one row per sample and outcome column of
  # predicted_values.
  prediction_data <- data.table::as.data.table(x@predicted_values)
  prediction_data[, "sample_id" := x@sample_identifiers]
  prediction_data <- data.table::melt(
    data = prediction_data,
    id.vars = "sample_id",
    variable.name = "shap_outcome",
    value.name = "prediction"
  )
  
  # Cartesian merge (inner join by merge's default).
  force_data <- merge(
    x = x@data,
    y = mapping_data,
    by = c("feature_name", "feature_value_mapping", "sample_id"),
    allow.cartesian = TRUE
  )
  
  # Match predictions per outcome when SHAP values are outcome-specific.
  merge_cols <- "sample_id"
  if ("shap_outcome" %in% colnames(force_data)) merge_cols <- c(merge_cols, "shap_outcome")
  
  force_data <- merge(
    x = force_data,
    y = prediction_data,
    by = merge_cols
  )
  
  # Drop unused columns.
  force_data[, feature_value_mapping := NULL]

  # Set data.
  data_element@data <- force_data
  
  # Set identifiers: the value mapping index is replaced by the feature value
  # and label, and the prediction is added.
  data_element@grouping_column <- c(
    setdiff(data_element@grouping_column, "feature_value_mapping"),
    c("feature_value", "feature_label", "prediction")
  )
  
  return(data_element)
}



.extract_shap_dependence <- function(
    x,
    feature_x,
    feature_y
) {
  # Create one SHAP dependence data element for every combination of a
  # feature in `feature_x` with a feature in `feature_y`. Features that are not
  # part of the model's lookup table are skipped with a warning; the y-features
  # are only checked for valid x-features. Returns NULL if no data element
  # could be created.

  # Check whether a feature was used by the model, and warn if it was not.
  is_model_feature <- function(feature) {
    model_features <- names(x@lookup_table)
    if (feature %in% model_features) return(TRUE)

    ..warning(paste0(
      feature, " is not part of the feature set used by the model. ",
      "The following features were used: ", paste_s(model_features)
    ))
    FALSE
  }

  data_element_list <- list()
  n_elements <- 0L

  for (current_feature_x in feature_x) {
    if (!is_model_feature(current_feature_x)) next

    for (current_feature_y in feature_y) {
      if (!is_model_feature(current_feature_y)) next

      n_elements <- n_elements + 1L
      data_element_list[[n_elements]] <- ..extract_shap_dependence(
        x = x,
        feature_x = current_feature_x,
        feature_y = current_feature_y
      )
    }
  }

  if (is_empty(data_element_list)) {
    return(NULL)
  }

  data_element_list
}



# Create a single familiarDataElementSHAPDependence object for the pair
# (feature_x, feature_y) from a familiarDataElementSHAP object `x`. Each row of
# the resulting data contains, per sample, the value and label of feature_x,
# the SHAP value(s) of feature_x, and the value and label of feature_y.
# feature_x and feature_y are added as identifiers of the data element.
#
# Returns the new data element; its data slot stays NULL if `x@data` is empty.
..extract_shap_dependence <- function(
    x,
    feature_x,
    feature_y
) {
  # Prevent NOTES due to non-standard evaluation
  feature_name <- feature_value_mapping <- NULL
  
  # Generate object using the incoming familiarDataElementSHAP object as a
  # template.
  data_element <- methods::new(
    "familiarDataElementSHAPDependence",
    x
  )
  
  # Add feature x and feature y as identifiers to the dataset.
  data_element <- add_data_element_identifier(
    x = data_element,
    feature_x = feature_x
  )[[1L]]
  data_element <- add_data_element_identifier(
    x = data_element,
    feature_y = feature_y
  )[[1L]]
  
  # Clean reporting elements.
  data_element@data <- NULL
  
  if (is_empty(x@data)) return(data_element)
  
  # Get data_mapping and turn into a long data.table, with one row per sample
  # and feature. Rows of data_mapping are assumed to align with
  # x@sample_identifiers -- TODO confirm.
  mapping_data <- data.table::as.data.table(x@data_mapping)
  mapping_data[, "sample_id" := x@sample_identifiers]
  mapping_data <- data.table::melt(
    data = mapping_data,
    id.vars = "sample_id",
    variable.name = "feature_name",
    value.name = "feature_value_mapping"
  )
  
  # Insert feature values in mapping data (by reference), one feature at a
  # time.
  mapping_data[
    ,
    c("feature_value", "feature_label") := .shap_mapping_to_feature_list(
      feature = feature_name,
      mapping_value = feature_value_mapping, 
      lookup_table = x@lookup_table
    ),
    by = "feature_name"
  ]
  
  # To create a dependence plot, we need for:
  # feature x: its value and its SHAP value.
  # feature y: its value
  # These are linked by the sample identifier.
  
  feature_x_data <- data.table::copy(mapping_data[feature_name == feature_x])
  feature_y_data <- data.table::copy(mapping_data[feature_name == feature_y])
  
  # Cartesian merge. As an inner join (merge's default), this keeps only the
  # SHAP values of feature x.
  dependence_data <- merge(
    x = x@data,
    y = feature_x_data,
    by = c("feature_name", "feature_value_mapping", "sample_id"),
    allow.cartesian = TRUE
  )
  
  # Update column names.
  data.table::setnames(
    dependence_data,
    old = c("feature_value", "feature_label"),
    new = c("feature_value_x", "feature_label_x")
  )
  dependence_data[, feature_value_mapping := NULL]
  dependence_data[, feature_name := NULL]
  
  # Add feature_y_data.
  dependence_data <- merge(
    x = dependence_data,
    y = feature_y_data,
    by = "sample_id"
  )
  
  # Update column names.
  data.table::setnames(
    dependence_data,
    old = c("feature_value", "feature_label"),
    new = c("feature_value_y", "feature_label_y")
  )
  dependence_data[, feature_value_mapping := NULL]
  dependence_data[, feature_name := NULL]
  
  # Update column order. Columns not listed here keep their relative order
  # after the listed columns.
  col_order <- c("sample_id", "feature_value_x", "feature_label_x")
  if ("shap_outcome" %in% colnames(dependence_data)) col_order <- c(col_order, "shap_outcome")
  col_order <- c(col_order, "shap_value", "feature_value_y", "feature_label_y")
  data.table::setcolorder(
    dependence_data,
    neworder = col_order
  )
  
  # Set data.
  data_element@data <- dependence_data
  
  # Update grouping columns.
  data_element@grouping_column <- c(
    setdiff(data_element@grouping_column, c("feature_name", "feature_value_mapping")),
    c("feature_value_x", "feature_label_x", "feature_value_y", "feature_label_y")
  )
  
  return(data_element)
}


# export_shap (generic) --------------------------------------------------------

#'@title Extract and export SHAP data.
#'
#'@description Extract and export data based on SHAP (SHapley Additive
#'  exPlanations) values, for use in SHAP summary, force and dependence plots.
#'
#'@param feature_x (*optional*) Feature(s) whose SHAP values are used for
#'  determining dependence.
#'@param feature_y (*optional*) Feature(s) whose values are used to show 
#'  interaction with the feature(s) in `feature_x`. Dependence data are only
#'  exported if both `feature_x` and `feature_y` are provided.
#'
#'@inheritParams export_all
#'@inheritParams export_univariate_analysis_data
#'
#'@inheritDotParams as_familiar_collection
#'
#'@details Data is usually collected from a `familiarCollection` object.
#'  However, you can also provide one or more `familiarData` objects, that will
#'  be internally converted to a `familiarCollection` object. It is also
#'  possible to provide a `familiarEnsemble` or one or more `familiarModel`
#'  objects together with the data from which data is computed prior to export.
#'  Paths to the previous files can also be provided.
#'
#'  All parameters aside from `object` and `dir_path` are only used if `object`
#'  is not a `familiarCollection` object, or a path to one.
#'
#'@return A list of data.tables (if `dir_path` is not provided), or nothing, as
#'  all data is exported to `csv` files.
#'@exportMethod export_shap
#'@md
#'@rdname export_shap-methods
setGeneric(
  "export_shap",
  function(
    object,
    dir_path = NULL,
    aggregate_results = TRUE,
    export_collection = FALSE,
    feature_x = NULL,
    feature_y = NULL,
    ...
  ) {
    standardGeneric("export_shap")
  }
)



# export_shap (collection) -----------------------------------------------------

#'@rdname export_shap-methods
setMethod(
  "export_shap",
  signature(object = "familiarCollection"),
  function(
    object,
    dir_path = NULL,
    aggregate_results = TRUE,
    export_collection = FALSE,
    feature_x = NULL,
    feature_y = NULL,
    ...
  ) {
    # Make sure the collection object is updated.
    object <- update_object(object = object)

    # Derive data elements from every SHAP data element of the collection using
    # `extractor`, and export them. SHAP data are always exported in
    # aggregated form.
    export_shap_elements <- function(extractor, object_class, subtype, ...) {
      derived_elements <- lapply(object@shap_data, extractor, ...)

      .export(
        x = object,
        data_elements = derived_elements,
        dir_path = dir_path,
        aggregate_results = TRUE,
        object_class = object_class,
        type = "explanation",
        subtype = subtype
      )
    }

    # Summary plot data.
    summary_data <- export_shap_elements(
      extractor = .extract_shap_summary,
      object_class = "familiarDataElementSHAPSummary",
      subtype = "shap_summary"
    )

    # Force plot data.
    force_data <- export_shap_elements(
      extractor = .extract_shap_force,
      object_class = "familiarDataElementSHAPForce",
      subtype = "shap_force"
    )

    # Dependence plot data require both feature_x and feature_y.
    dependence_data <- if (is.null(feature_x) || is.null(feature_y)) {
      NULL
    } else {
      export_shap_elements(
        extractor = .extract_shap_dependence,
        object_class = "familiarDataElementSHAPDependence",
        subtype = "shap_dependence",
        feature_x = feature_x,
        feature_y = feature_y
      )
    }

    data_list <- c(
      "shap_summary" = summary_data,
      "shap_force" = force_data,
      "shap_dependence" = dependence_data
    )

    if (!export_collection) {
      return(data_list)
    }

    c(data_list, list("collection" = object))
  }
)



# export_shap (general) ----------------------------------------------------

#'@rdname export_shap-methods
setMethod(
  "export_shap",
  signature(object = "ANY"),
  function(
    object,
    dir_path = NULL,
    aggregate_results = TRUE,
    export_collection = FALSE,
    feature_x = NULL,
    feature_y = NULL,
    ...
  ) {
    # Convert the input to a familiarCollection object, and export SHAP data
    # from it.
    #
    # feature_x and feature_y must be formal arguments of this method. If they
    # were omitted, S4 would wrap the method in a .local function that does not
    # receive them, silently dropping them, and dependence data could never be
    # exported through this method.
    
    # Attempt conversion to familiarCollection object.
    object <- do.call(
      as_familiar_collection,
      args = c(
        list(
          "object" = object,
          "data_element" = "export_shap",
          "aggregate_results" = aggregate_results
        ),
        list(...)
      )
    )
    
    return(do.call(
      export_shap,
      args = c(
        list(
          "object" = object,
          "dir_path" = dir_path,
          "aggregate_results" = aggregate_results,
          "export_collection" = export_collection,
          "feature_x" = feature_x,
          "feature_y" = feature_y
        ),
        list(...)
      )
    ))
  }
)

Try the familiar package in your browser

Any scripts or data that you put into this service are public.

familiar documentation built on May 23, 2026, 1:07 a.m.