R/CatEncodeFit.R

Defines functions CatEncodeFit encode_binary_fit encode_onehot_fit encode_generic_fit

Documented in CatEncodeFit encode_binary_fit encode_generic_fit encode_onehot_fit

#' Encode a categorical vector using a fitted lookup
#'
#' Maps each value of a categorical vector to the encoding learned on the
#' training data (as stored in the list returned by "BestCatEncode").
#'
#' When `fit` is named, values are matched against `names(fit)` by label, so
#' the result stays correct when the new data contains only a subset of the
#' training levels (or extra ones); labels not present in `fit` become `NA`.
#' An unnamed `fit` is indexed positionally by the factor codes of `x`, which
#' is only correct when `x` has exactly the training levels.
#' @param x Any categorical vector which needs to be encoded.
#' @param fit The fitted encoding for this column, taken from the list
#'   returned by "BestCatEncode".
#' @return An unnamed vector (or list, if `fit` is a list) of encoded values,
#'   the same length as `x`.
#' @export
#' @usage encode_generic_fit(x, fit)
encode_generic_fit <- function(x, fit) {
  x <- factor(x)
  # Indexing with a factor uses its integer codes, not its labels, so a
  # positional lookup misaligns whenever the new data's levels differ from
  # the training levels. Look values up by label whenever `fit` is named.
  keys <- if (is.null(names(fit))) x else as.character(x)
  unname(fit[keys])
}

#' Apply a fitted one-hot encoding to a categorical column
#'
#' One-hot encodes `colname` in `df` and returns exactly the dummy columns
#' listed in the fit produced on the training data, in that order.
#'
#' Dummy columns for training levels that do not occur in `df` are added and
#' filled with zeros, so new data always gets the training column layout
#' instead of failing with "undefined columns selected". Dummy columns for
#' levels unseen during training are dropped. `sparsifyNAs = TRUE` asks
#' `one_hot()` to encode missing values as zeros.
#' @param df Any dataset with at least one categorical field.
#' @param colname A string naming the categorical field in `df`.
#' @param fit A list, indexed by `colname`, holding the dummy column names
#'   produced on the training data (from "BestCatEncode").
#' @return A data frame whose columns are named as in `fit[[colname]]`.
#' @export
#' @usage encode_onehot_fit(df, colname, fit)
encode_onehot_fit <- function(df, colname, fit) {
  # one_hot() comes from mltools and data.table() from data.table; both are
  # presumably imported via the package NAMESPACE.
  df <- data.frame(df)
  df[, colname] <- factor(df[, colname])
  new_df <- data.frame(one_hot(data.table(df), colname, sparsifyNAs = TRUE))

  wanted <- fit[[colname]]
  # A training level absent from the new data yields no dummy column; add it
  # as all zeros so the output matches the training layout.
  for (col in setdiff(wanted, names(new_df))) {
    new_df[[col]] <- integer(nrow(new_df))
  }
  # drop = FALSE keeps a data frame (with its column name) even when only a
  # single dummy column was kept during training.
  new_df[, wanted, drop = FALSE]
}

#' Apply a fitted binary encoding to a categorical vector
#'
#' Maps each value of `x` to the integer code learned on the training data
#' and expands those codes into binary columns via `binary_encoding()`.
#'
#' When `fit` is named, values are matched against `names(fit)` by label, so
#' the result stays correct when the new data contains only a subset of the
#' training levels; labels not present in `fit` become `NA` before binary
#' expansion. An unnamed `fit` is indexed positionally by the factor codes
#' of `x`.
#' @param x Any data vector which needs to be encoded.
#' @param colname A string naming the categorical field; used as the prefix
#'   (`"<colname>_"`) of the generated binary columns.
#' @param fit The fitted encoding for this column, taken from the list
#'   returned by "BestCatEncode".
#' @return The result of `binary_encoding()` on the looked-up codes.
#' @export
#' @usage encode_binary_fit(x, colname, fit)
encode_binary_fit <- function(x, colname, fit) {
  x <- factor(x)
  # Indexing with a factor uses its integer codes, not its labels, so a
  # positional lookup misaligns whenever the new data's levels differ from
  # the training levels. Look values up by label whenever `fit` is named.
  keys <- if (is.null(names(fit))) x else as.character(x)
  encoded <- unname(fit[keys])
  binary_encoding(encoded, name = paste0(colname, "_"))
}

#' Encode the categorical columns of new data using a fitted encoding
#'
#' Applies the encoding chosen on the training data (the list returned by
#' "BestCatEncode") to the categorical columns of `data`. Only the encoded
#' categorical columns are returned; other columns of `data` are not included.
#' @param data Any dataset whose categorical columns need to be encoded.
#' @param dv Name of the dependent variable in `data`. It is only read for
#'   "Leave One Out Encoding", so new data without the target column can be
#'   encoded with the other methods.
#' @param fit A list returned from "BestCatEncode": `fit$Method` names the
#'   encoding and `fit$fit` holds one fitted encoding per categorical column,
#'   named by column.
#' @return The encoded categorical columns (a data.table for
#'   "Leave One Out Encoding", otherwise the combined per-column results), or
#'   `NULL` when `fit` is `NULL`.
#' @export
#' @usage CatEncodeFit(data, dv, fit)
CatEncodeFit <- function(data, dv, fit) {
  if (is.null(fit)) {
    return(NULL)
  }
  chosen_method <- fit$Method
  if (!is.character(chosen_method) || length(chosen_method) != 1L ||
      is.na(chosen_method)) {
    stop("`fit$Method` must be a single encoding name.", call. = FALSE)
  }

  data <- data.frame(data)
  cat_cols <- names(fit$fit)
  # drop = FALSE keeps a data frame even with a single categorical column;
  # otherwise the per-column steps below would receive a bare vector.
  cat_data <- data[, cat_cols, drop = FALSE]

  if (chosen_method == "One-Hot Encoding") {
    encoded <- lapply(cat_cols, function(col) {
      encode_onehot_fit(cat_data, col, fit$fit[col])
    })
    # Column-bind the per-variable dummy blocks in training order.
    encoded_data <- if (length(encoded) == 0L) {
      data.frame()
    } else {
      Reduce(cbind, encoded)
    }
  } else if (chosen_method == "Binary Encoding") {
    encoded_data <- data.frame(
      mapply(encode_binary_fit, cat_data, cat_cols, fit$fit)
    )
  } else if (chosen_method == "Leave One Out Encoding") {
    # NOTE(review): this re-encodes from the target column of `data` rather
    # than from `fit$fit`, so it needs `dv` at prediction time - confirm this
    # is intended.
    dvcol <- data[, dv]
    dt <- data.table(cat_data)
    encoded_data <- dt[, lapply(.SD, function(x) {
      encode_leave_one_out(x, dvcol)$encoded
    }), .SDcols = cat_cols]
  } else {
    # SIMPLIFY = FALSE keeps one list element per column, so single-row data
    # still yields one column per variable instead of a collapsed vector.
    encoded_data <- data.frame(
      mapply(encode_generic_fit, cat_data, fit$fit, SIMPLIFY = FALSE)
    )
  }
  encoded_data
}
akunuriYoshitha/CatEncode documentation built on July 16, 2021, 4:16 p.m.