# R/one_hot.R

# Normalize the one-hot columns of a single categorical variable.
#
# Follows Aggarwal (2017), Outlier Analysis, section 8.3: every 0/1 column j is
# divided by sqrt(n * p_j * (1 - p_j)), where n is the number of columns in
# `data` (the categories of the originating variable) and p_j is the mean of
# column j, i.e. the proportion of rows holding a 1.
#
# data: a numeric data frame (or matrix) of indicator columns for one variable.
# Returns `data` with each column divided by its normalization factor.
scale_onehot <- function(data) {
    n_categories <- ncol(data)
    # vectorized equivalent of apply(data, 2, mean); errors on non-numeric input
    col_means <- colMeans(data)
    norm_factors <- sqrt(n_categories * col_means * (1 - col_means))
    # A constant column (p_j of exactly 0 or 1, e.g. a variable with a single
    # observed category) has a zero factor; dividing by it would yield Inf/NaN,
    # so such columns are left unscaled.
    norm_factors[norm_factors == 0] <- 1
    sweep(data, 2, norm_factors, "/")
}

# Drop the final column of `data`.
#
# Used to keep m - 1 indicator columns for a variable with m levels. Positional
# base-R subsetting avoids depending on dplyr functions that this file does not
# import, and avoids tidyselect resolving an environment variable that could
# clash with a column of the same name.
#
# data: a data frame or tibble.
# Returns `data` without its last column (row count is preserved).
remove_last_col <- function(data) {
    data[-ncol(data)]
}


# One-hot encode a single categorical vector.
#
# data: an atomic vector (e.g. one column of a data frame) holding one
#   categorical variable.
# Returns a tibble with one integer 0/1 column per distinct value, as laid out
#   by tidyr::spread(), and one row per element of `data`.
ohe <- function(data) {
    # Build the long table explicitly: as_tibble() on a bare vector is
    # discouraged by tibble, and seq_along() (unlike 1:nrow(.)) stays correct
    # for zero-length input.
    tibble::tibble(value = data,
                   one = 1L,
                   rowid = seq_along(data)) %>%
        # key = column 1 (value), value = column 2 (one); the unique rowid
        # keeps one output row per input element, and absent combinations
        # are filled with 0L.
        tidyr::spread(1, 2, fill = 0L) %>%
        dplyr::select(-rowid)
}

#' Categorical Data Matrix to One-Hot Binary Matrix
#'
#' @param data a categorical data matrix or data frame; each column is treated
#' as one categorical variable.
#' @param minus_level logical (default FALSE); if TRUE then create binary
#' encodings for m-1 levels of a variable with m original levels (the last
#' column produced for each variable is dropped).
#' @param clarify_levels logical (default TRUE); if TRUE then disambiguate the
#' resulting column names by prefixing each with its original variable name,
#' as \code{variable_level}.
#' @param scale logical (default FALSE); if FALSE then a binary matrix is
#' returned. If TRUE, then normalization (see details) is applied to each
#' binary transformed variable.
#' @return A tibble holding the one-hot encoded columns of every variable.
#' @details The normalization technique is taken from Outlier Analysis (Aggarwal, 2017),
#' section 8.3. For each column j in the binary transformed matrix, a normalization
#' factor is defined as sqrt(n_i \* p_j \* (1 - p_j)), where n_i is the number of distinct
#' categories in the reference variable i from the raw data set and p_j is the proportion
#' of records taking the value of 1 for the jth binary variable. Normalization is
#' computed before any level is dropped by \code{minus_level}.
#' @examples
#' df <- data.frame(gender = sample(c("male", "female"), 25, TRUE),
#'                  age = sample(c("young", "old", "unknown"), 25, TRUE))
#' make_onehot(data = df)
#' @importFrom magrittr %>%
#' @export
make_onehot <- function(data, minus_level = FALSE, clarify_levels = TRUE,
                        scale = FALSE) {

    # Treat every column as one categorical variable. Coercing to a data frame
    # makes purrr::map() iterate over columns even when `data` is a matrix
    # (mapping over a matrix would visit individual cells instead).
    data <- as.data.frame(data, stringsAsFactors = FALSE)

    # create list of OHE tibbles for each original categorical variable
    xi_onehot <- purrr::map(data, ohe)

    # disambiguate new column names as <variable>_<level>
    if (clarify_levels) {
        xi_onehot <- purrr::imap(xi_onehot, function(encoded, var_name) {
            colnames(encoded) <- paste0(var_name, "_", colnames(encoded))
            encoded
        })
    }

    # scaling runs before minus_level so each variable's normalization uses
    # its full set of levels
    if (scale) {
        xi_onehot <- purrr::map(xi_onehot, scale_onehot)
    }

    # minus one level
    if (minus_level) {
        xi_onehot <- purrr::map(xi_onehot, remove_last_col)
    }

    # output
    dplyr::bind_cols(xi_onehot) %>%
        tibble::as_tibble()
}
# dannymorris/smltools documentation built on May 15, 2019, 10:49 a.m.