# R/one_hot_decode.R

#' Decode one hot encoding
#'
#' @description Decodes one hot encoding generated by one_hot_encode-function
#'   utilizing matrix attributes generated in one_hot_encode-function.
#'
#' @details Encoded columns are located by name: every column of \code{mat}
#'   whose name starts with \code{<variable>_} is treated as one level of the
#'   categorical variable \code{<variable>}, and the remainder of the column
#'   name is used as the level. For each row the level whose encoded column
#'   holds the largest value is chosen (first one on ties); a row where all
#'   encoded values are \code{NA} decodes to \code{NA}. Because matching is by
#'   prefix, a variable also claims columns of any other variable whose name
#'   starts with \code{<variable>_}.
#'
#'   In the result the columns that were not one hot encoded come first
#'   (in their order in \code{mat}), followed by the decoded categorical
#'   variables.
#'
#' @param mat A matrix created by one_hot_encode-function and passed through
#'   impute_xgboost. It must carry the attributes
#'   \code{xgbimpute_original_class}, \code{xgbimpute_original_column_names}
#'   and \code{xgbimpute_original_column_classes}.
#' @param verbose Should decoding progress be printed. Progress is printed
#'   when \code{verbose > 0}.
#'
#' @import data.table
#'
#' @return A data.table if the original class given to one_hot_encode-function
#'   included \code{data.table}, otherwise a data.frame. Columns that were
#'   originally \code{factor} are returned as factors (levels rebuilt with
#'   \code{as.factor} from the decoded values), originally \code{character}
#'   columns as character and originally \code{integer} columns are rounded
#'   and converted to integer.
#' @export
#'
#' @examples
#'
#' mat = one_hot_encode(iris)
#' mis_mat = generate_na(mat)
#' imp_mat = impute_xgboost(mis_mat)
#' imp_iris = one_hot_decode(imp_mat)
#'


one_hot_decode = function(mat , verbose = 1){

  # ensure attrs exist

  original_class = attr(mat , 'xgbimpute_original_class')
  original_column_names = attr(mat , 'xgbimpute_original_column_names')
  original_column_classes = attr(mat , 'xgbimpute_original_column_classes')

  if(is.null(original_column_names) || is.null(original_column_classes) || is.null(original_class)){
    stop("Matrix needs to have attributes 'xgbimpute_original_class', 'xgbimpute_original_column_names' and 'xgbimpute_original_column_classes'")
  }

  # decode

  categorical_classes = c('factor' , 'character')
  variables_to_decode = original_column_names[ original_column_classes %in% categorical_classes ]

  # a matrix without column names cannot contain any encoded column
  encoded_names = colnames(mat)
  if(is.null(encoded_names)){
    encoded_names = rep(NA_character_ , ncol(mat))
  }

  # one flag per column of mat: TRUE once the column is claimed by a variable
  is_encoded_column = rep(FALSE , ncol(mat))

  decoded = vector('list' , length(variables_to_decode))

  if(verbose > 0){
    cat('\nstarting decoding\n')
  }

  for(i in seq_along(variables_to_decode)){

    variable_to_decode = variables_to_decode[i]

    if(verbose > 0){
      cat('.')
    }

    # literal prefix match: the variable name must not be interpreted as a regex
    prefix = paste0(variable_to_decode , '_')
    belongs_to_variable = !is.na(encoded_names) & startsWith(encoded_names , prefix)

    if(!any(belongs_to_variable)){
      stop("No encoded columns starting with '", prefix, "' found for variable '", variable_to_decode, "'")
    }

    is_encoded_column = is_encoded_column | belongs_to_variable

    # strip only the '<variable>_' prefix so levels containing '_' stay intact
    original_levels = substring(encoded_names[belongs_to_variable] , nchar(prefix) + 1L)

    # drop = FALSE keeps a matrix for single level variables and single row input
    scores = mat[ , belongs_to_variable , drop = FALSE]

    decoded[[i]] = vapply(seq_len(nrow(scores)) , function(row_index){
      best = which.max(scores[row_index , ])
      # which.max returns integer(0) when every value in the row is NA
      if(length(best) == 0L) NA_character_ else original_levels[best]
    } , character(1))
  }

  if(verbose > 0){
    cat('\ndecoding finished\n')
  }

  # collect: non-encoded columns first, decoded categoricals after them

  # logical indexing keeps all columns when nothing was encoded
  # (a negative index built from an empty vector would select no columns)
  out = as.data.frame(mat[ , !is_encoded_column , drop = FALSE] , stringsAsFactors = FALSE)
  rownames(out) = NULL

  for(i in seq_along(variables_to_decode)){
    out[[ variables_to_decode[i] ]] = decoded[[i]]
  }

  # to original types

  factor_variables = original_column_names[ which(original_column_classes == 'factor') ]
  for(factor_variable in factor_variables){
    out[[factor_variable]] = as.factor(out[[factor_variable]])
  }

  integer_variables = original_column_names[ which(original_column_classes == 'integer') ]
  for(integer_variable in integer_variables){
    out[[integer_variable]] = as.integer(round(out[[integer_variable]]))
  }

  # data.table is only needed when the original object was a data.table
  if('data.table' %in% original_class){
    out = data.table::as.data.table(out)
  }

  return(out)

}
# yatzy/xgbimpute documentation built on June 7, 2019, 8:16 p.m.