Nothing
# Build a one-sided model formula (`~ v1 + v2 + ...`) from column names.
#
# .vars: character vector of column names to use as formula terms.
# .enc:  "dummy" (default) leaves the intercept in the formula; "onehot"
#        appends `- 1` so that the intercept is removed.
#
# Returns the formula produced by stats::as.formula().
make_form <- function(.vars, .enc = c("dummy", "onehot")) {
  .enc_type <- match.arg(.enc)
  # Names that are not syntactic R names (spaces, operators, reserved words)
  # must be backtick-quoted: pasted raw they either fail to parse or are read
  # as arithmetic (e.g. "x-y"). Syntactic names are passed through untouched,
  # so formulas for ordinary column names are exactly what they were before.
  needs_quote <- make.names(.vars) != .vars
  term_labels <- ifelse(needs_quote, paste0("`", .vars, "`"), .vars)
  rhs <- paste(term_labels, collapse = " + ")
  if (identical(.enc_type, "onehot")) {
    rhs <- paste(rhs, "- 1")
  }
  stats::as.formula(paste("~", rhs))
}
# Re-encode `.f` as a factor whose levels are exactly `.l`, in that order.
# Elements of `.f` that are not found in `.l` come out as NA, which is the
# standard behaviour of base::factor().
set_levels <- function(.f, .l) {
  releveled <- factor(x = .f, levels = .l)
  releveled
}
# Expand the columns `.cols` of `.df` into a numeric indicator matrix.
#
# .df:       data frame holding the columns to encode.
# .enc_type: "dummy" (default) or "onehot".
# .cols:     names of the columns to encode; defaults to every column.
# .levels:   optional list of level vectors, paired element-by-element with
#            `.cols`; when supplied, each column is re-levelled to it.
#
# Returns the model matrix with its "contrasts" and "assign" attributes and
# its row names removed. Rows with NA are kept (na.action = na.pass).
model_matrix <- function(.df,
                         .enc_type = c("dummy", "onehot"),
                         .cols = names(.df),
                         .levels = NULL) {
  .enc_type <- match.arg(.enc_type)

  # Every encoded column becomes a factor, optionally with imposed levels.
  .df[.cols] <- lapply(.df[.cols], as.factor)
  if (!is.null(.levels)) {
    .df[.cols] <- Map(set_levels, .f = .df[.cols], .l = .levels)
  }

  enc_formula <- make_form(.cols, .enc_type)
  frame <- stats::model.frame(
    enc_formula,
    data = .df,
    na.action = stats::na.pass
  )

  if (!identical(.enc_type, "onehot")) {
    # Dummy: default contrasts, then discard the intercept column.
    with_intercept <- stats::model.matrix(enc_formula, frame)
    keep <- colnames(with_intercept) != "(Intercept)"
    encoded <- with_intercept[, keep, drop = FALSE]
  } else {
    # One-hot: identity contrasts give one indicator column per level.
    full_contrasts <- lapply(.df[.cols], stats::contrasts, contrasts = FALSE)
    encoded <- stats::model.matrix(
      enc_formula,
      frame,
      contrasts.arg = full_contrasts
    )
  }

  # Strip model-matrix bookkeeping so only the plain numeric matrix remains.
  for (attr_name in c("contrasts", "assign")) {
    attr(encoded, attr_name) <- NULL
  }
  rownames(encoded) <- NULL
  encoded
}
# Set to NA every element of `.x` that is not a member of `.orig_levels`.
# All other elements are returned unchanged.
na_new_levels <- function(.x, .orig_levels) {
  unseen <- !is.element(.x, .orig_levels)
  .x[unseen] <- NA
  .x
}
# Replace the categorical columns `.cats` of `.df` by their encoded form.
#
# .df:   input data frame.
# .enc:  encoding type, forwarded to model_matrix() as `.enc_type`.
# .cats: names of the columns to encode.
# .levs: optional per-column levels, forwarded to model_matrix() as `.levels`.
#
# Returns the columns that were not encoded, followed by the encoded columns,
# combined with dplyr::bind_cols().
df_to_binary <- function(.df, .enc, .cats, .levs = NULL) {
  categorical <- .df[.cats]
  untouched_names <- setdiff(names(.df), .cats)
  passthrough <- .df[untouched_names]
  # Drop the local reference to the full input once it has been split.
  rm(.df)
  encoded <- model_matrix(categorical, .enc_type = .enc, .levels = .levs)
  dplyr::bind_cols(passthrough, as.data.frame(encoded))
}
# Factory for the data.frame encoders: returns a function that applies the
# encoding named by `.enc_type` ("dummy" or "onehot") to `train` and,
# optionally, to `test`.
#
# The returned function has the signature
#   function(train, ..., test = NULL, verbose = TRUE)
# and returns cattonum_df(<encoded train>) when `test` is NULL, otherwise
# cattonum_df2(train = <encoded train>, test = <encoded test>).
# `verbose` is accepted but not used in this function.
dummy_onehot <- function(.enc_type) {
  # Evaluate the argument now so the closure does not hold a lazy promise.
  force(.enc_type)
  function(train, ..., test = NULL, verbose = TRUE) {
    validate_col_types(train)
    test_also <- !is.null(test)
    if (test_also) check_train_test(train, test)
    cats <- pick_cols(train, deparse(substitute(train)), ...)
    train_expanded <- df_to_binary(train, .enc_type, cats)
    if (!test_also) {
      return(cattonum_df(train_expanded))
    }
    # Use the same factor levels that the encoding of `train` used
    # (as.factor() levels: sorted for character columns, existing levels for
    # factors). Levels taken from unique() follow order of first appearance
    # and drop unused factor levels, so `test` could be given differently
    # ordered columns and, for dummy encoding, a different reference level
    # than `train`.
    train_levels <- lapply(train[cats], function(col) levels(as.factor(col)))
    # Values of `test` never seen in `train` cannot be encoded: make them NA.
    test[cats] <- Map(na_new_levels,
      .x = test[cats],
      .orig_levels = train_levels
    )
    test_expanded <- df_to_binary(test, .enc_type, cats, train_levels)
    cattonum_df2(train = train_expanded, test = test_expanded)
  }
}
#' One-hot encoding
#'
#' S3 generic that dispatches on `train`. The `data.frame` method replaces
#' each selected column by indicator columns, one per level, built from an
#' intercept-free model matrix with identity contrasts. The encoded columns
#' are placed after the columns that were not encoded.
#'
#' @param train The training data, in a `data.frame` or `tibble`.
#' @param ... The columns to be encoded. If none are specified, then
#'   all character and factor columns are encoded.
#' @param test The test data, in a `data.frame` or `tibble`. Optional: the
#'   `data.frame` method defaults it to `NULL`. When supplied, values in the
#'   encoded columns of `test` that do not occur in `train` are set to `NA`
#'   before encoding.
#' @param verbose Should informative messages be printed? Defaults to
#'   `TRUE` (not yet used).
#' @return The encoded dataset in a `cattonum_df` if no test dataset was
#'   provided, and the encoded datasets in a `cattonum_df2` otherwise.
#' @examples
#' catto_onehot(iris)
#' @export
catto_onehot <- function(train, ..., test, verbose = TRUE) {
UseMethod("catto_onehot")
}
#' @export
# nolint start
# data.frame method: the encoder built by dummy_onehot() in "onehot" mode.
catto_onehot.data.frame <- dummy_onehot(.enc_type = "onehot")
# nolint end
#' Dummy encoding
#'
#' S3 generic that dispatches on `train`. The `data.frame` method replaces
#' each selected column by the columns of a model matrix built with the
#' contrasts currently in effect, with the intercept column removed. Under
#' R's default treatment contrasts this presumably means the first level of
#' each column gets no indicator column. The encoded columns are placed
#' after the columns that were not encoded.
#'
#' @param train The training data, in a `data.frame` or `tibble`.
#' @param ... The columns to be encoded. If none are specified, then
#'   all character and factor columns are encoded.
#' @param test The test data, in a `data.frame` or `tibble`. Optional: the
#'   `data.frame` method defaults it to `NULL`. When supplied, values in the
#'   encoded columns of `test` that do not occur in `train` are set to `NA`
#'   before encoding.
#' @param verbose Should informative messages be printed? Defaults to
#'   `TRUE` (not yet used).
#' @return The encoded dataset in a `cattonum_df` if no test dataset was
#'   provided, and the encoded datasets in a `cattonum_df2` otherwise.
#' @examples
#' catto_dummy(iris)
#' @export
catto_dummy <- function(train, ..., test, verbose = TRUE) {
UseMethod("catto_dummy")
}
#' @export
# nolint start
# data.frame method: the encoder built by dummy_onehot() in "dummy" mode.
catto_dummy.data.frame <- dummy_onehot(.enc_type = "dummy")
# nolint end
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.