R/binary_label_dataset_metric.R

Defines functions binary_label_dataset_metric num_positives num_negatives base_rate disparate_impact statistical_parity_difference consistency mean_difference

Documented in base_rate binary_label_dataset_metric consistency disparate_impact mean_difference num_negatives num_positives statistical_parity_difference

#' Class for computing metrics based on a single `aif360.datasets.BinaryLabelDataset`
#'
#' Wraps the Python `BinaryLabelDatasetMetric` class. Every element of the two
#' group lists is converted with `py_dict_conv()` before being passed to Python,
#' and the resulting object gets the S3 class `"binary_label_dataset_metric"`
#' prepended to its existing classes.
#'
#' @param data (BinaryLabelDataset) A BinaryLabelDataset
#' @param privileged_groups (list(list)) Privileged groups. Format is a list of `lists` where the keys are `protected_attribute_names` and the values are values in `protected_attributes`. Each `list` element describes a single group.
#' @param unprivileged_groups (list(list)) Unprivileged groups in the same format as `privileged_groups`
#' @return The Python `BinaryLabelDatasetMetric` object with class
#'   `"binary_label_dataset_metric"` prepended to its classes.
#' @export
#' @importFrom reticulate py_suppress_warnings
#'
binary_label_dataset_metric <- function(data,
                                        privileged_groups,
                                        unprivileged_groups) {
  # Convert each group description to the form expected on the Python side.
  p_dict <- lapply(privileged_groups, py_dict_conv)
  u_dict <- lapply(unprivileged_groups, py_dict_conv)

  bldm <- py_suppress_warnings(
    metrics$BinaryLabelDatasetMetric(
      data,
      privileged_groups = p_dict,
      unprivileged_groups = u_dict
    )
  )

  # Prepend an R-level class so S3 methods can dispatch on this wrapper.
  class(bldm) <- c("binary_label_dataset_metric", class(bldm))
  bldm
}

#' Compute the number of positives optionally conditioned on protected attributes
#' @param privileged (bool, optional) Boolean prescribing whether to condition this metric on the `privileged_groups`, if `TRUE`, or the `unprivileged_groups`, if `FALSE`. Defaults to `NULL` meaning this metric is computed over the entire dataset
#' @param binary_label_ds_metric binary label metric instance
#' @return The value returned by the Python `num_positives()` method of
#'   `binary_label_ds_metric`.
#' @export
#' @importFrom reticulate py_suppress_warnings
#'
num_positives <- function(binary_label_ds_metric, privileged = NULL) {
  # A vector or NA would be converted to a Python object whose truthiness
  # silently picks a group, so only a scalar flag or NULL is accepted.
  if (!(is.null(privileged) || isTRUE(privileged) || isFALSE(privileged))) {
    stop("`privileged` must be TRUE, FALSE or NULL.", call. = FALSE)
  }
  py_suppress_warnings(binary_label_ds_metric$num_positives(privileged))
}

#' Compute the number of negatives optionally conditioned on protected attributes
#' @param privileged (bool, optional) Boolean prescribing whether to condition this metric on the `privileged_groups`, if `TRUE`, or the `unprivileged_groups`, if `FALSE`. Defaults to `NULL` meaning this metric is computed over the entire dataset
#' @param binary_label_ds_metric binary label metric instance
#' @return The value returned by the Python `num_negatives()` method of
#'   `binary_label_ds_metric`.
#' @export
#' @importFrom reticulate py_suppress_warnings
#'
num_negatives <- function(binary_label_ds_metric, privileged = NULL) {
  # A vector or NA would be converted to a Python object whose truthiness
  # silently picks a group, so only a scalar flag or NULL is accepted.
  if (!(is.null(privileged) || isTRUE(privileged) || isFALSE(privileged))) {
    stop("`privileged` must be TRUE, FALSE or NULL.", call. = FALSE)
  }
  py_suppress_warnings(binary_label_ds_metric$num_negatives(privileged))
}

#' Compute the base rate optionally conditioned on protected attributes
#' @param privileged (bool, optional) Boolean prescribing whether to condition this metric on the `privileged_groups`, if `TRUE`, or the `unprivileged_groups`, if `FALSE`. Defaults to `NULL` meaning this metric is computed over the entire dataset
#' @param binary_label_ds_metric binary label metric instance
#' @return The value returned by the Python `base_rate()` method of
#'   `binary_label_ds_metric`.
#' @export
#' @importFrom reticulate py_suppress_warnings
#'
base_rate <- function(binary_label_ds_metric, privileged = NULL) {
  # A vector or NA would be converted to a Python object whose truthiness
  # silently picks a group, so only a scalar flag or NULL is accepted.
  if (!(is.null(privileged) || isTRUE(privileged) || isFALSE(privileged))) {
    stop("`privileged` must be TRUE, FALSE or NULL.", call. = FALSE)
  }
  py_suppress_warnings(binary_label_ds_metric$base_rate(privileged))
}

#' Compute the disparate impact between the privileged and unprivileged groups
#'
#' Calls the Python `disparate_impact()` method of the metric object with any
#' Python warnings suppressed.
#'
#' @param binary_label_ds_metric binary label metric instance, e.g. as created
#'   by `binary_label_dataset_metric()`
#' @return The value returned by the Python `disparate_impact()` method.
#' @export
#' @importFrom reticulate py_suppress_warnings
#'
disparate_impact <- function(binary_label_ds_metric) {
  py_suppress_warnings({
    compute_di <- binary_label_ds_metric$disparate_impact
    compute_di()
  })
}

#' Compute the statistical parity difference between the privileged and unprivileged groups
#'
#' Delegates to the Python `statistical_parity_difference()` method of the
#' metric object, suppressing any Python warnings raised by the call.
#'
#' @param binary_label_ds_metric binary label metric instance, e.g. as created
#'   by `binary_label_dataset_metric()`
#' @return The value returned by the Python
#'   `statistical_parity_difference()` method.
#' @export
#' @importFrom reticulate py_suppress_warnings
#'
statistical_parity_difference <- function(binary_label_ds_metric) {
  py_suppress_warnings({
    spd_method <- binary_label_ds_metric$statistical_parity_difference
    spd_method()
  })
}

#' Individual fairness metric that measures how similar the labels are for similar instances
#' @param n_neighbors (int, optional) Number of neighbors for the knn computation.
#'   Must be a single finite whole number >= 1; defaults to 5.
#' @param binary_label_ds_metric binary label metric instance
#' @return The value returned by the Python `consistency()` method of
#'   `binary_label_ds_metric`.
#' @export
#' @importFrom reticulate py_suppress_warnings
#'
consistency <- function(binary_label_ds_metric, n_neighbors = 5) {
  # as.integer() would silently truncate (2.9 -> 2) or yield NA, so reject
  # anything that is not one finite positive whole number up front.
  if (!is.numeric(n_neighbors) || length(n_neighbors) != 1L ||
      !is.finite(n_neighbors) || n_neighbors < 1 ||
      n_neighbors > .Machine$integer.max ||
      n_neighbors != trunc(n_neighbors)) {
    stop("`n_neighbors` must be a single positive whole number.", call. = FALSE)
  }
  # Python expects an int, while R numeric literals such as 5 are doubles.
  py_suppress_warnings(
    binary_label_ds_metric$consistency(as.integer(n_neighbors))
  )
}

#' Alias of `statistical_parity_difference`.
#'
#' Invokes the Python `mean_difference()` method of the metric object while
#' suppressing Python warnings.
#'
#' @param binary_label_ds_metric binary label metric instance, e.g. as created
#'   by `binary_label_dataset_metric()`
#' @return The value returned by the Python `mean_difference()` method.
#' @export
#' @importFrom reticulate py_suppress_warnings
#'
mean_difference <- function(binary_label_ds_metric) {
  py_suppress_warnings({
    md_method <- binary_label_ds_metric$mean_difference
    md_method()
  })
}
SSaishruthi/raif-test documentation built on Oct. 30, 2019, 11:12 p.m.