# R/acc_parity.R

#' @title Accuracy parity
#'
#' @description
#' This function computes the Accuracy parity metric
#' 
#' Formula: (TP + TN) / (TP + FP + TN + FN)
#'
#' @details
#' This function computes the Accuracy parity metric as described by Friedler et al. (2018).
#' Accuracy is calculated by dividing the number of correctly predicted observations (the sum
#' of all true positives and true negatives) by the total number of predictions. In the returned
#' named vector, the reference group is assigned the value 1, while all other groups are assigned
#' values according to whether their accuracies are lower or higher than that of the reference
#' group. Values below 1 therefore indicate WORSE predictions for the corresponding subgroup.
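#'
#' As a worked example: if the reference group has an accuracy of 0.80 and another group
#' has an accuracy of 0.60, the returned parity values are 1 and 0.60 / 0.80 = 0.75.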
#'
#' @param data Data.frame that contains the necessary columns.
#' @param group Column name indicating the sensitive group (character).
#' @param base Base level of the sensitive group (character).
#' @param group_breaks If group is continuous (e.g., age): either a numeric vector of two or more unique cut points or a single number >= 2 giving the number of intervals into which the group feature is to be cut.
#' @param outcome Column name indicating the binary outcome variable (character).
#' @param outcome_base Base level of the outcome variable (i.e., negative class). Default is the first level of the outcome variable.
#' @param probs Column name or vector with the predicted probabilities (numeric between 0 and 1). Either probs or preds must be supplied.
#' @param preds Column name or vector with the predicted binary outcome (0 or 1). Either probs or preds must be supplied.
#' @param cutoff Cutoff to generate predicted outcomes from predicted probabilities. Default set to 0.5.
#'
#' @name acc_parity
#'
#' @return
#' \item{Metric}{Raw accuracy metrics for all groups and metrics standardized for the base group (accuracy parity metric). Lower values compared to the reference group mean lower accuracies in the selected subgroups}
#' \item{Metric_plot}{Bar plot of Accuracy parity metric}
#' \item{Probability_plot}{Density plot of predicted probabilities per subgroup. Only plotted if probabilities are defined}
#'
#' @examples
#' data(compas)
#' compas$Two_yr_Recidivism_01 <- ifelse(compas$Two_yr_Recidivism == 'yes', 1, 0) 
#' acc_parity(data = compas, outcome = 'Two_yr_Recidivism_01', group = 'ethnicity',
#' probs = 'probability', cutoff = 0.4, base = 'Caucasian')
#' acc_parity(data = compas, outcome = 'Two_yr_Recidivism_01', group = 'ethnicity',
#' preds = 'predicted', cutoff = 0.5, base = 'Hispanic')
#'
#' @export

acc_parity <- function(data, outcome, group,
                       probs        = NULL, 
                       preds        = NULL, 
                       outcome_base = NULL, 
                       cutoff       = 0.5, 
                       base         = NULL,
                       group_breaks = NULL) {
    
    # check if data is data.frame
    if (class(data)[1] != 'data.frame') {
        warning(paste0('Converting ', class(data)[1], ' to data.frame'))
        data <- as.data.frame(data)
    }
    
    # convert types, sync levels
    if (is.null(probs) && is.null(preds)) {
        stop('Either probs or preds must be supplied')
    }
    if (is.null(probs)) {
        if (length(preds) == 1) {
            preds <- data[, preds]
        }
        preds_status <- as.factor(preds)
    } else {
        if (length(probs) == 1) {
            probs <- data[, probs]
        }
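        # binarize probabilities at the cutoff (1 = above cutoff) and relabel
        # the resulting factor with the outcome levels so predictions and
        # outcomes can be compared on the same scale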
        preds_status         <- as.factor(as.numeric(probs > cutoff))
        levels(preds_status) <- levels(as.factor(data[, outcome]))
    }
    
    # check group feature and cut if needed
    if ((length(unique(data[, group])) > 10) && (is.null(group_breaks))) {
        warning('Number of unique group levels exceeds 10. Consider specifying `group_breaks`.')
    }
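    # bin a continuous group feature into intervals; `group_breaks` is passed
    # directly to cut(), so it may be a number of intervals or a vector of cut points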
    if (!is.null(group_breaks)) {
        if (is.numeric(data[, group])) {
            data[, group] <- cut(data[, group], breaks = group_breaks)
        } else {
            warning('Attempting to bin a non-numeric group feature; `group_breaks` is ignored.')
        }
    }
    
    # convert to factor
    group_status   <- as.factor(data[, group])
    outcome_status <- as.factor(data[, outcome])
    
    # check levels matching
    if (!identical(levels(outcome_status), levels(preds_status))) {
        warn_preds   <- paste0(levels(preds_status),   collapse = ', ')
        warn_outcome <- paste0(levels(outcome_status), collapse = ', ')
        stop(paste0('Levels of predictions and outcome do not match. ',
                    'Please relevel predictions or outcome.\n',
                    'Outcome levels: ', warn_outcome, '\n',
                    'Preds   levels: ', warn_preds))
    }
    
    # relevel preds & outcomes
    if (is.null(outcome_base)) {
        outcome_base <- levels(outcome_status)[1]
    } else {
        outcome_base <- as.character(outcome_base)
    }
    outcome_status   <- relevel(outcome_status, outcome_base)
    preds_status     <- relevel(preds_status,   outcome_base)
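    # after releveling, the base (negative) level comes first, so the remaining
    # second level is the positive class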
    outcome_positive <- levels(outcome_status)[2]
    
    # check lengths
    if ((length(outcome_status) != length(preds_status)) ||
        (length(outcome_status) != length(group_status))) {
        stop('Outcomes, predictions/probabilities and group status must be of the same length')
    }
    
    # relevel group
    if (is.null(base)) {base <- levels(group_status)[1]}
    group_status <- relevel(group_status, base)
    
    # placeholders
    val         <- rep(NA, length(levels(group_status)))
    names(val)  <- levels(group_status)
    sample_size <- val
    
    # compute value for all groups
    for (i in levels(group_status)) {
        cm <- caret::confusionMatrix(preds_status[group_status   == i], 
                                     outcome_status[group_status == i], 
                                     mode = 'everything',
                                     positive = outcome_positive)
        metric_i       <- cm$overall['Accuracy']
        val[i]         <- metric_i
        sample_size[i] <- sum(cm$table)
    }
    
    # aggregate results
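    # val / val[[1]] standardizes accuracies by the reference group (first level
    # after releveling), so the base group always equals 1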
    res_table <- rbind(val, val/val[[1]], sample_size)
    rownames(res_table) <- c('Accuracy', 'Accuracy Parity', 'Group size')
    
    # conversion of metrics to df
    val_df <- as.data.frame(res_table[2, ])
    colnames(val_df) <- c('val')
    val_df$groupst <- rownames(val_df)
    val_df$groupst <- as.factor(val_df$groupst)
    
    # relevel group so the reference group comes first
    # (base is guaranteed non-NULL here: it was set above when missing)
    val_df$groupst <- relevel(val_df$groupst, base)
    
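    # bar plot of the accuracy parity metric per group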
    p <- ggplot(val_df, aes(x = groupst, weight = val, fill = groupst)) + geom_bar(alpha = 0.5) +
        coord_flip() + theme(legend.position = 'none') + labs(x = '', y = 'Accuracy Parity')
    
    # density plot of predicted probabilities (only when probabilities are supplied)
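    # note: `probs` and `group_status` are vectors in the function environment, not
    # columns of `data`; ggplot2 falls back to that environment when evaluating aes()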
    if (!is.null(probs)) {
        q <- ggplot(data, aes(x = probs, fill = group_status)) + geom_density(alpha = 0.5) +
            labs(x = 'Predicted probabilities') + guides(fill = guide_legend(title = '')) +
            theme(plot.title = element_text(hjust = 0.5)) + xlim(0, 1) +
            geom_vline(xintercept = cutoff, linetype = 'dashed')
    }
    
    if (is.null(probs)) {
        list(Metric = res_table, Metric_plot = p)
    } else {
        list(Metric = res_table, Metric_plot = p, Probability_plot = q)
    }
    
}