#' The internal function used in the cat2cat one
#' @description apply the ml models to the cat2cat data
#' @details For every category of `cat_var_target`, the `ml$data`
#' observations of the categories mapped to it in `mapp` (restricted to rows
#' with complete `features`) are used to fit every model in `ml$method`.
#' The predicted class probabilities are stored in the `wei_<method>_c2c`
#' columns of the matching `target_data` rows. Each `wei_<method>_c2c` column
#' starts as a copy of `wei_freq_c2c` and keeps that value wherever no
#' prediction is available: the category can not be modelled (fewer than two
#' observed classes, a single candidate `g_new_c2c`, ...), the model fails to
#' fit or predict, or the row has missing feature values.
#' Candidate classes of `g_new_c2c` unknown to a model get a zero probability.
#' @param ml `list` the same `ml` argument as provided to `cat2cat` function.
#' @param mapp `list` a mapping table
#' @param target_data `data.frame` with at least the `index_c2c`,
#' `g_new_c2c` and `wei_freq_c2c` columns.
#' @param cat_var_target `character(1)` name of the categorical variable
#' in the target period.
#' @return `list` with a single element `target_data`: the input data ordered
#' by `index_c2c` with the additional `wei_<method>_c2c` columns.
#' @keywords internal
cat2cat_ml <- function(ml, mapp, target_data, cat_var_target) {
  validate_ml(ml)
  stopifnot(all(ml$features %in% colnames(target_data)))
  stopifnot(cat_var_target %in% colnames(target_data))
  stopifnot(all(
    c("index_c2c", "g_new_c2c", "wei_freq_c2c") %in% colnames(target_data)
  ))
  stopifnot(all(vapply(
    target_data[, ml$features, drop = FALSE],
    function(x) is.numeric(x) || is.logical(x), logical(1)
  )))
  features <- unique(ml$features)
  methods <- unique(ml$method)
  ml_names <- paste0("wei_", methods, "_c2c")
  # the frequency weights are the fallback for every ml weight column
  target_data[, ml_names] <- target_data["wei_freq_c2c"]
  cat_ml_year_g <- split(
    ml$data[, c(features, ml$cat_var), drop = FALSE],
    factor(ml$data[[ml$cat_var]], exclude = NULL)
  )
  target_data_cat_c2c <- split(
    target_data,
    factor(target_data[[cat_var_target]], exclude = NULL)
  )
  for (i in seq_along(target_data_cat_c2c)) {
    target_data_cat <- target_data_cat_c2c[[i]]
    cat_name <- names(target_data_cat_c2c)[i]
    target_data_cat_c2c[[i]] <- tryCatch(
      {
        # ml data of all categories mapped to the current target category;
        # NULL when the category has no mapping
        dis <- do.call(
          rbind,
          cat_ml_year_g[mapp[[match(cat_name, names(mapp))]]]
        )
        cat2cat_ml_fill_group(
          target_data_cat = target_data_cat,
          dis = dis,
          ml = ml,
          features = features,
          methods = methods,
          ml_groups = names(cat_ml_year_g)
        )
      },
      # best effort: a failing category keeps its frequency weights
      error = function(e) target_data_cat
    )
  }
  target_data <- do.call(rbind, target_data_cat_c2c)
  target_data <- target_data[order(target_data[["index_c2c"]]), ]
  list(target_data = target_data)
}

#' Replace the ml weights of one target category by model predictions
#' @param target_data_cat `data.frame` rows of the target data for a single
#' category, already holding the `wei_<method>_c2c` columns.
#' @param dis `data.frame` or `NULL` the `ml$data` rows of the mapped
#' categories.
#' @param ml,features,methods as used in `cat2cat_ml`.
#' @param ml_groups `character` categories present in `ml$data`.
#' @return `target_data_cat` with the `wei_<method>_c2c` columns updated
#' where a prediction is available.
#' @noRd
cat2cat_ml_fill_group <- function(target_data_cat, dis, ml, features,
                                  methods, ml_groups) {
  if (is.null(dis) || nrow(target_data_cat) == 0) {
    return(target_data_cat)
  }
  # the models can not be trained on missing feature values
  dis <- dis[stats::complete.cases(dis[, features, drop = FALSE]), ,
    drop = FALSE
  ]
  g_new <- unique(target_data_cat[["g_new_c2c"]])
  if (length(unique(dis[[ml$cat_var]])) <= 1 ||
    length(g_new) <= 1 ||
    !any(g_new %in% ml_groups)) {
    return(target_data_cat)
  }
  # one prediction per original observation (rows are replicated per class)
  base_ml <- target_data_cat[
    !duplicated(target_data_cat[["index_c2c"]]),
    c("index_c2c", features),
    drop = FALSE
  ]
  cc <- stats::complete.cases(base_ml[, features, drop = FALSE])
  if (!any(cc)) {
    return(target_data_cat)
  }
  new_x <- base_ml[cc, features, drop = FALSE]
  new_index <- base_ml[["index_c2c"]][cc]
  g_new_chr <- as.character(g_new[!is.na(g_new)])
  # rows are matched to predictions by their (index_c2c, g_new_c2c) pair,
  # so no assumption about the row order of target_data_cat is needed
  row_key <- paste(
    target_data_cat[["index_c2c"]], target_data_cat[["g_new_c2c"]],
    sep = "\r"
  )
  for (m in methods) {
    pp <- tryCatch(
      cat2cat_ml_predict_prob(m, dis, new_x, ml, features),
      # a failing model keeps the frequency weights, other models still run
      error = function(e) NULL
    )
    if (is.null(pp)) {
      next
    }
    # candidate classes unknown to the model get a zero probability
    missing_g <- setdiff(g_new_chr, colnames(pp))
    if (length(missing_g) > 0) {
      pp[missing_g] <- 0
    }
    pred_key <- paste(
      rep(new_index, times = ncol(pp)),
      rep(colnames(pp), each = nrow(pp)),
      sep = "\r"
    )
    pos <- match(row_key, pred_key)
    hit <- !is.na(pos)
    ml_name <- paste0("wei_", m, "_c2c")
    target_data_cat[[ml_name]][hit] <- unlist(pp, use.names = FALSE)[pos[hit]]
  }
  target_data_cat
}

#' Fit one ml model and predict class probabilities
#' @param method `character(1)` one of `"knn"`, `"rf"`, `"lda"`.
#' @param train `data.frame` training data with `features` and `ml$cat_var`.
#' @param new_x `data.frame` observations to predict, `features` only.
#' @param ml,features as used in `cat2cat_ml`.
#' @return `data.frame` with one probability column per class of
#' `train[[ml$cat_var]]` and one row per row of `new_x`.
#' @noRd
cat2cat_ml_predict_prob <- function(method, train, new_x, ml, features) {
  x_train <- train[, features, drop = FALSE]
  y_train <- factor(train[[ml$cat_var]])
  probs <- switch(method,
    knn = {
      fit <- suppressWarnings(
        caret::knn3(
          x = x_train,
          y = y_train,
          k = min(ml$args$k, ceiling(nrow(train) / 4))
        )
      )
      stats::predict(fit, new_x, type = "prob")
    },
    rf = {
      fit <- suppressWarnings(
        randomForest::randomForest(
          y = y_train,
          x = x_train,
          ntree = min(ml$args$ntree, 100)
        )
      )
      stats::predict(fit, new_x, type = "prob")
    },
    lda = {
      fit <- suppressWarnings(
        MASS::lda(grouping = y_train, x = as.matrix(x_train))
      )
      stats::predict(fit, as.matrix(new_x))$posterior
    },
    stop("Unsupported ml method: ", method)
  )
  as.data.frame(probs)
}
#' Validate cat2cat ml
#' @description Check that the `ml` argument of `cat2cat` or
#' `cat2cat_ml_run` is well formed, signalling an error otherwise.
#' @details `ml` has to contain `data`, `cat_var`, `method` and `features`.
#' The packages backing the requested methods (`randomForest` for `"rf"`,
#' `caret` for `"knn"`, `MASS` for `"lda"`) have to be installed, and all
#' `features` have to be numeric or logical columns of `ml$data`.
#' @param ml `list` the `ml` argument.
#' @return `NULL` invisibly; called for its side effect.
#' @keywords internal
validate_ml <- function(ml) {
  stopifnot(is.list(ml))
  # cat_var is used unconditionally downstream; `NULL %in% x` is logical(0),
  # which stopifnot would accept, so its presence is checked explicitly
  stopifnot(all(c("method", "features", "data", "cat_var") %in% names(ml)))
  stopifnot(is.character(ml$cat_var), length(ml$cat_var) == 1L)
  stopifnot(all(ml$method %in% c("knn", "rf", "lda")))
  method_pkgs <- c(rf = "randomForest", knn = "caret", lda = "MASS")
  # check the optional dependencies up front: later failures are swallowed
  for (m in intersect(c("rf", "knn", "lda"), ml$method)) {
    delayed_package_load(
      method_pkgs[[m]],
      sprintf(
        "Please install %s package to use the %s model in the cat2cat function.",
        method_pkgs[[m]], m
      )
    )
  }
  stopifnot(ml$cat_var %in% colnames(ml$data))
  stopifnot(all(ml$features %in% colnames(ml$data)))
  stopifnot(all(vapply(
    ml$data[, ml$features, drop = FALSE],
    function(x) is.numeric(x) || is.logical(x), logical(1)
  )))
  invisible(NULL)
}
#' Delayed load of a package
#' @description Check that an optional dependency is installed before it is
#' used, signalling an informative error otherwise.
#' @param package `character(1)` name of the package.
#' @param msg `character(1)` error message used when `package` is not
#' available.
#' @return `NULL` invisibly; called for its side effect.
#' @keywords internal
delayed_package_load <- function(package, msg = sprintf("Please install %s package.", package)) {
  available <- suppressPackageStartupMessages(
    requireNamespace(package, quietly = TRUE)
  )
  if (isFALSE(available)) {
    stop(msg, call. = FALSE)
  }
  invisible(NULL)
}
#' Function to check cat2cat ml models performance
#' @description ml and mappings arguments in \code{\link{cat2cat}} function can be used to run cross validation across all groups in ml data.
#' @details For every category of the base period mapping, the `ml$data`
#' observations of the mapped categories with complete `features` are randomly
#' split into a train and a test sample. Each model in `ml$method` is fitted
#' on the train sample and its accuracy on the test sample is compared with an
#' equal-probability guess and with always predicting the most frequent train
#' category. A model that fails to fit or predict gets an `NA` accuracy
#' without affecting the other models. The split is random, so use
#' `set.seed` for reproducible results.
#' @param ... additional options.
#' \describe{
#' \item{test_prop}{ numeric(1) proportion of observations used as the test sample, Default 0.2}
#' \item{min_match}{ numeric(1) minimal share of the cat_var values present in the mappings, Default 0.5}
#' }
#' @return `cat2cat_ml_run` object, returned invisibly: a named list with one
#' element per category of the mapping. Each element contains `naive`
#' (expected accuracy of an equal-probability guess), `freq` (accuracy of the
#' most frequent category solution) and `acc` (named accuracies of the ml
#' models). Values which could not be computed are `NA`.
#' @inheritParams cat2cat
#' @seealso \code{\link{cat2cat}}
#' @export
#' @rdname cat2cat_ml_run
#' @examples
#' \dontrun{
#' library("cat2cat")
#' data("occup", package = "cat2cat")
#' data("trans", package = "cat2cat")
#'
#' occup_2006 <- occup[occup$year == 2006, ]
#' occup_2008 <- occup[occup$year == 2008, ]
#' occup_2010 <- occup[occup$year == 2010, ]
#' occup_2012 <- occup[occup$year == 2012, ]
#'
#' library("caret")
#' ml_setup <- list(
#'   data = rbind(occup_2010, occup_2012),
#'   cat_var = "code",
#'   method = c("knn", "rf", "lda"),
#'   features = c("age", "sex", "edu", "exp", "parttime", "salary"),
#'   args = list(k = 10, ntree = 50)
#' )
#' data <- list(
#'   old = occup_2008, new = occup_2010,
#'   cat_var_old = "code", cat_var_new = "code", time_var = "year"
#' )
#' mappings <- list(trans = trans, direction = "backward")
#' res <- cat2cat_ml_run(mappings, ml_setup, test_prop = 0.2)
#' res
#' }
#'
cat2cat_ml_run <- function(mappings, ml, ...) {
  stopifnot(is.list(ml))
  stopifnot(is.list(mappings))
  elargs <- list(...)
  # `[[` avoids the partial name matching of `$` on the options list
  test_prop <- if (is.null(elargs[["test_prop"]])) 0.2 else elargs[["test_prop"]]
  stopifnot(
    is.numeric(test_prop), length(test_prop) == 1,
    test_prop > 0, test_prop < 1
  )
  min_match <- if (is.null(elargs[["min_match"]])) 0.5 else elargs[["min_match"]]
  stopifnot(
    is.numeric(min_match), length(min_match) == 1,
    min_match >= 0, min_match < 1
  )
  validate_mappings(mappings)
  validate_ml(ml)
  base_name <- switch(mappings$direction,
    forward = "old",
    backward = "new",
    stop("mappings$direction has to be \"forward\" or \"backward\"")
  )
  mapp <- get_mappings(mappings$trans)[[paste0("to_", base_name)]]
  cat_var <- ml$data[[ml$cat_var]]
  cat_var_vals <- unlist(mappings$trans[, base_name])
  # isTRUE also rejects an empty ml$data (NaN share)
  if (!isTRUE(mean(cat_var %in% cat_var_vals) >= min_match)) {
    stop(
      paste0(
        "There is no mappings to group the cat_var variable.\n",
        "Probably you should change the direction in the mappings argument.\n"
      )
    )
  }
  features <- unique(ml$features)
  methods <- unique(ml$method)
  train_g <- split(
    ml$data[, c(features, ml$cat_var), drop = FALSE],
    factor(ml$data[[ml$cat_var]], exclude = NULL)
  )
  empty_entry <- list(
    naive = NA_real_,
    acc = stats::setNames(rep(NA_real_, length(methods)), methods),
    freq = NA_real_
  )
  res <- list()
  for (cat_name in names(mapp)) {
    # a missing category name can not be used as a result name
    if (is.na(cat_name)) {
      next
    }
    cat_nam <- if (cat_name == "") " " else cat_name
    res[[cat_nam]] <- tryCatch(
      cat2cat_ml_run_group(
        matched_cat = mapp[[match(cat_name, names(mapp))]],
        train_g = train_g,
        ml = ml,
        features = features,
        methods = methods,
        test_prop = test_prop,
        out = empty_entry
      ),
      # best effort: an unexpected failure is reported as failed models
      error = function(e) empty_entry
    )
  }
  invisible(structure(res, ml_models = methods, class = c("cat2cat_ml_run", "list")))
}

#' Evaluate the ml models for a single mapping group
#' @param matched_cat `character` ml categories mapped to the group.
#' @param train_g `list` of `ml$data` subsets split by `ml$cat_var`.
#' @param ml,features,methods as used in `cat2cat_ml_run`.
#' @param test_prop `numeric(1)` proportion of the test sample.
#' @param out `list` result entry with `NA` values, filled in place.
#' @return `out` with the computed `naive`, `freq` and `acc` values.
#' @noRd
cat2cat_ml_run_group <- function(matched_cat, train_g, ml, features,
                                 methods, test_prop, out) {
  data_small_g <- do.call(rbind, train_g[matched_cat])
  if (!is.null(data_small_g)) {
    # models can not use rows with missing features; evaluating everything on
    # the same complete rows keeps the baselines comparable with the models
    data_small_g <- data_small_g[
      stats::complete.cases(data_small_g[, features, drop = FALSE]), ,
      drop = FALSE
    ]
  }
  if (is.null(data_small_g) || nrow(data_small_g) < 5 ||
    length(matched_cat) < 2 ||
    sum(matched_cat %in% data_small_g[[ml$cat_var]]) == 1) {
    return(out)
  }
  out[["naive"]] <- 1 / length(matched_cat)
  index_tt <- sample(c(0, 1),
    nrow(data_small_g),
    prob = c(1 - test_prop, test_prop), replace = TRUE
  )
  data_test_small <- data_small_g[index_tt == 1, , drop = FALSE]
  data_train_small <- data_small_g[index_tt == 0, , drop = FALSE]
  if (nrow(data_test_small) == 0 || nrow(data_train_small) < 5) {
    return(out)
  }
  # compare as character: factor comparison fails on differing level sets
  truth <- as.character(data_test_small[[ml$cat_var]])
  gcounts <- table(data_train_small[[ml$cat_var]])
  gfreq <- names(gcounts)[which.max(gcounts)]
  out[["freq"]] <- mean(gfreq == truth)
  for (m in methods) {
    out[["acc"]][m] <- tryCatch(
      {
        pred <- cat2cat_ml_run_predict(
          m, data_train_small, data_test_small, ml, features
        )
        mean(as.character(pred) == truth)
      },
      # a failing model is reported as NA without affecting the others
      error = function(e) NA_real_
    )
  }
  out
}

#' Fit one ml model and predict classes
#' @param method `character(1)` one of `"knn"`, `"rf"`, `"lda"`.
#' @param train,test `data.frame` with `features` and `ml$cat_var`.
#' @param ml,features as used in `cat2cat_ml_run`.
#' @return vector of predicted classes, one per row of `test`.
#' @noRd
cat2cat_ml_run_predict <- function(method, train, test, ml, features) {
  x_train <- train[, features, drop = FALSE]
  y_train <- factor(train[[ml$cat_var]])
  x_test <- test[, features, drop = FALSE]
  switch(method,
    knn = {
      fit <- suppressWarnings(
        caret::knn3(
          x = x_train,
          y = y_train,
          k = min(ml$args$k, ceiling(nrow(train) / 4))
        )
      )
      stats::predict(fit, x_test, type = "class")
    },
    rf = {
      fit <- suppressWarnings(
        randomForest::randomForest(
          y = y_train,
          x = x_train,
          ntree = min(ml$args$ntree, 100)
        )
      )
      stats::predict(fit, x_test)
    },
    lda = {
      fit <- suppressWarnings(
        MASS::lda(grouping = y_train, x = as.matrix(x_train))
      )
      stats::predict(fit, as.matrix(x_test))$class
    },
    stop("Unsupported ml method: ", method)
  )
}
#' @rdname cat2cat_ml_run
#' @param x cat2cat_ml_run instance created with \code{\link{cat2cat_ml_run}} function.
#' @param ... other arguments
#' @return argument x invisibly
#' @method print cat2cat_ml_run
#' @export
print.cat2cat_ml_run <- function(x, ...) {
  # Accuracies are averaged over the mapping groups; note that every group is
  # a separate multi-class classification problem.
  ml_models <- attr(x, "ml_models")
  ml_message <- character(0)
  how_ml_message_n <- character(0)
  how_ml_message_f <- character(0)
  na_message <- character(0)
  for (m in ml_models) {
    acc <- mean(vapply(x, function(i) i$acc[m], numeric(1)), na.rm = TRUE)
    ml_message <- c(
      ml_message,
      sprintf("Average (groups) accuracy for %s ml models: %f", m, acc)
    )
    # groups with a failed model (NA accuracy) give NA and are skipped
    howaccn <- mean(
      vapply(x, function(i) unname(i$naive < i$acc[m]), logical(1)),
      na.rm = TRUE
    )
    how_ml_message_n <- c(
      how_ml_message_n,
      sprintf("How often %s ml model is better than naive guess: %f", m, howaccn)
    )
    howaccf <- mean(
      vapply(x, function(i) unname(i$freq < i$acc[m]), logical(1)),
      na.rm = TRUE
    )
    how_ml_message_f <- c(
      how_ml_message_f,
      sprintf("How often %s ml model is better than most frequent category solution: %f", m, howaccf)
    )
    nna <- vapply(x, function(i) is.na(i$acc[m]), logical(1))
    pna <- sum(nna) / length(nna) * 100
    na_message <- c(
      na_message,
      sprintf("Percent of failed %s ml models: %f", m, pna)
    )
  }
  acc_freq <- mean(vapply(x, function(i) i$freq, numeric(1)), na.rm = TRUE)
  acc_naive <- mean(vapply(x, function(i) i$naive, numeric(1)), na.rm = TRUE)
  cat(
    paste(
      c(
        "Selected prediction stats:",
        "",
        sprintf("Average naive (equal probabilities) guess: %f", acc_naive),
        sprintf("Average (groups) accuracy for most frequent category solution: %f", acc_freq),
        ml_message,
        "",
        na_message,
        "",
        how_ml_message_n,
        "",
        how_ml_message_f
      ),
      collapse = "\n"
    ),
    "\n"
  )
  invisible(x)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.