R/dataset_metric.R

Defines functions dataset_metric difference ratio num_instances

Documented in dataset_metric difference num_instances ratio

#' Class for computing metrics based on one StructuredDataset
#'
#' Each entry of `privileged_groups` and `unprivileged_groups` is converted
#' with `py_dict_conv()`. The converted groups and `data` are then passed to
#' `metrics$DatasetMetric()` while warnings are suppressed. The resulting
#' object has `"dataset_metric"` prepended to its class vector.
#'
#' @param data (StructuredDataset) A StructuredDataset
#' @param privileged_groups (list(list(list))) Privileged groups. Format is a
#'   list of `lists` where the keys are `protected_attribute_names` and the
#'   values are values in `protected_attributes`. Each `list` element
#'   describes a single group. See examples for more details.
#' @param unprivileged_groups (list(list(list))): Unprivileged groups in the
#'   same format as `privileged_groups`
#' @return The object built by `metrics$DatasetMetric()`, with class
#'   `"dataset_metric"` prepended to its existing classes.
#' @export
#' @importFrom reticulate py_suppress_warnings import
#'
dataset_metric <- function(data,
                           privileged_groups,
                           unprivileged_groups) {
  # py_dict_conv() is presumably the package's R-list -> Python-dict
  # converter; both group lists are converted before the constructor runs.
  priv_py <- lapply(privileged_groups, py_dict_conv)
  unpriv_py <- lapply(unprivileged_groups, py_dict_conv)

  metric_obj <- py_suppress_warnings(
    metrics$DatasetMetric(
      data,
      privileged_groups = priv_py,
      unprivileged_groups = unpriv_py
    )
  )

  # Tag the object so R-side S3 methods can dispatch on it.
  class(metric_obj) <- c("dataset_metric", class(metric_obj))
  metric_obj
}

#' Compute difference of the metric for unprivileged and privileged groups
#'
#' `metric_fun` is evaluated twice on `ds_metric`: first with
#' `privileged = FALSE`, then with `privileged = TRUE`.
#'
#' @param ds_metric dataset metric class instance
#' @param metric_fun dataset metric function, called as
#'   `metric_fun(ds_metric, privileged = <TRUE/FALSE>)`
#' @return The unprivileged value minus the privileged value.
#' @export
#' @importFrom reticulate py_suppress_warnings
#'
difference <- function(ds_metric, metric_fun) {
  unprivileged_value <- metric_fun(ds_metric, privileged = FALSE)
  privileged_value <- metric_fun(ds_metric, privileged = TRUE)
  unprivileged_value - privileged_value
}

#' Compute ratio of the metric for unprivileged and privileged groups
#'
#' `metric_fun` is evaluated twice on `ds_metric`: first with
#' `privileged = FALSE`, then with `privileged = TRUE`. Division follows
#' R's `/` rules, so a zero privileged value yields `Inf`, `-Inf` or `NaN`.
#'
#' @param ds_metric dataset metric class instance
#' @param metric_fun dataset metric function, called as
#'   `metric_fun(ds_metric, privileged = <TRUE/FALSE>)`
#' @return The unprivileged value divided by the privileged value.
#' @export
#' @importFrom reticulate py_suppress_warnings
#'
ratio <- function(ds_metric, metric_fun) {
  numerator <- metric_fun(ds_metric, privileged = FALSE)
  denominator <- metric_fun(ds_metric, privileged = TRUE)
  numerator / denominator
}

#' Compute the number of instances in the dataset conditioned on protected attributes if necessary
#'
#' @param ds_metric dataset metric class instance
#' @param privileged `NULL` (default), `TRUE` or `FALSE`. When `NULL`, the
#'   underlying `num_instances()` method is called with no arguments,
#'   exactly as before. Otherwise the value is forwarded as
#'   `num_instances(privileged = ...)`. In AIF360, `None` counts the whole
#'   dataset, while `True`/`False` condition on the privileged/unprivileged
#'   groups. Accepting this argument also lets `num_instances` be passed as
#'   `metric_fun` to `difference()` and `ratio()`, which always call
#'   `metric_fun(ds_metric, privileged = ...)`.
#' @return The value returned by `ds_metric$num_instances()`.
#' @export
#' @importFrom reticulate py_suppress_warnings
#'
num_instances <- function(ds_metric, privileged = NULL) {
  stopifnot(
    is.null(privileged) ||
      (is.logical(privileged) && length(privileged) == 1L && !is.na(privileged))
  )
  if (is.null(privileged)) {
    # Keep the original argument-less call for the default case.
    py_suppress_warnings(ds_metric$num_instances())
  } else {
    py_suppress_warnings(ds_metric$num_instances(privileged = privileged))
  }
}
SSaishruthi/raif-test documentation built on Oct. 30, 2019, 11:12 p.m.