# Nothing
# Preparing data for evaluation
# Aggregate row-level predictions to one row per ID (within each group).
#
# Arguments:
#   data: Data frame holding the target, prediction, ID and group columns.
#   target_col: Name of the target column. Targets must be constant within
#     each ID (enforced through extract_id_classes()).
#   prediction_cols: Name(s) of the prediction column(s). For "multinomial"
#     these are also used as the class names.
#   family: "binomial" and "multinomial" get special handling. Any other
#     value is only usable with id_method = "mean".
#   id_col: Name of the ID column in 'data'. Must be a non-NULL string.
#   id_method: "mean" (average the predictions per ID and also compute their
#     standard deviation) or "majority" (majority vote per ID).
#   groups_col: Name of the grouping column; IDs are aggregated per group.
#   cutoff: Required when family is "binomial"; numeric scalar in [0, 1].
#   apply_softmax: Whether to apply softmax() to the prediction columns
#     before aggregating. Requires at least two prediction columns.
#   new_prediction_col_name, new_std_col_name: Names for the output
#     prediction / standard deviation columns.
#
# Returns:
#   For "multinomial": the result of prepare_multinomial_evaluation() on the
#   aggregated data. Otherwise the ungrouped, aggregated data frame; with
#   id_method = "mean" the "<prediction_col>_STD" column is renamed to
#   'new_std_col_name' via base_rename().
#
# Errors on invalid 'id_col', 'cutoff', 'id_method', on softmax with fewer
# than two prediction columns, and on unsupported family/method combinations.
prepare_id_level_evaluation <- function(data,
                                        target_col,
                                        prediction_cols,
                                        family,
                                        id_col,
                                        id_method,
                                        groups_col,
                                        cutoff = NULL,
                                        apply_softmax = length(prediction_cols) > 1,
                                        new_prediction_col_name = "prediction",
                                        new_std_col_name = "std") {

  # Validate the ID column
  if (is.null(id_col)) {
    stop("'id_col' was NULL.")
  }
  if (!is.character(id_col)) {
    stop("'id_col' must be either the name of a column in 'data' or NULL.")
  }
  if (id_col %ni% colnames(data)) {
    stop(paste0("could not find 'id_col', ", id_col, ", in 'data'."))
  }

  # Optionally convert the raw predictions to probabilities
  if (isTRUE(apply_softmax)) {
    if (length(prediction_cols) < 2) {
      stop("can only apply softmax when there are more than one 'prediction_cols'.")
    }
    data <- softmax(data, cols = prediction_cols)
  }

  # The binomial family needs a usable cutoff.
  # (The original repeated the NULL check twice and never verified the
  # type or range that the error message promises.)
  if (family == "binomial") {
    cutoff_is_valid <- !is.null(cutoff) &&
      is.numeric(cutoff) &&
      length(cutoff) == 1 &&
      !is.na(cutoff) &&
      cutoff >= 0 &&
      cutoff <= 1
    if (!cutoff_is_valid) {
      stop("when 'family' is 'binomial', 'cutoff' must be numeric between 0 and 1.")
    }
  }

  # One row per (group, ID) with the actual class
  id_classes <-
    extract_id_classes(
      data = data,
      groups_col = groups_col,
      id_col = id_col,
      target_col = target_col
    )

  if (id_method == "mean") {
    data_for_id_evaluation <- data %>%
      dplyr::group_by(!!as.name(groups_col), !!as.name(id_col))

    # Average prediction
    # The column names are passed as a character vector (as mutate_at() is
    # further down) so a data column named 'prediction_cols' can't shadow it.
    data_for_id_evaluation_mean <- data_for_id_evaluation %>%
      dplyr::summarise_at(prediction_cols, .funs = mean)

    # Standard deviation
    data_for_id_evaluation_std <- data_for_id_evaluation %>%
      dplyr::summarise_at(prediction_cols, .funs = sd)

    # Suffix the std columns so they don't clash with the means in the join
    colnames(data_for_id_evaluation_std) <- ifelse(
      colnames(data_for_id_evaluation_std) %in% prediction_cols,
      paste0(colnames(data_for_id_evaluation_std), "_STD"),
      colnames(data_for_id_evaluation_std)
    )

    # Combine means, standard deviations and actual classes
    data_for_id_evaluation <- data_for_id_evaluation_mean %>%
      dplyr::left_join(data_for_id_evaluation_std,
        by = c(groups_col, id_col)
      ) %>%
      dplyr::left_join(id_classes, by = c(id_col, groups_col)) %>%
      dplyr::ungroup()
  } else if (id_method == "majority") {

    ## By majority vote
    # If multiple classes share the majority, they also share the probability!

    if (family == "multinomial") {

      # Most probable class per row
      data[["predicted_class_index"]] <- data %>%
        base_select(cols = prediction_cols) %>%
        argmax()
      data[["predicted_class"]] <- purrr::map_chr(
        data[["predicted_class_index"]],
        .f = function(x) {
          prediction_cols[[x]]
        }
      )

      # Count the votes for each class per (group, ID)
      data_majority_count_by_id <- data %>%
        dplyr::group_by(!!as.name(groups_col), !!as.name(id_col)) %>%
        dplyr::count(.data$predicted_class) %>%
        dplyr::left_join(id_classes, by = c(id_col, groups_col)) %>%
        dplyr::ungroup()

      # Classes that were never predicted have no column after spreading
      classes_to_add <- setdiff(
        prediction_cols,
        data_majority_count_by_id[["predicted_class"]]
      )

      # Spread to a tibble of counts with NAs where a class got no votes
      majority_vote_probabilities <- data_majority_count_by_id %>%
        tidyr::spread(key = "predicted_class", value = "n")

      if (length(classes_to_add) > 0) {
        # Add the never-predicted classes as all-NA columns.
        # Assigning a scalar sizes the column from the spread tibble itself,
        # so this also works when not every ID occurs in every group (the
        # previous 'num_groups * number of IDs' estimate did not).
        for (missing_class in classes_to_add) {
          majority_vote_probabilities[[missing_class]] <- NA_real_
        }
        majority_vote_probabilities <- majority_vote_probabilities %>%
          base_select(cols = c(groups_col, id_col, target_col, prediction_cols))
      }

      # Set NAs to 0
      majority_vote_probabilities[is.na(majority_vote_probabilities)] <- 0

      # Make all the majority votes ~10000
      # so they will be extremely close to 1 when applying softmax
      # (softmax is applied in prepare_multinomial_evaluation() below)
      majority_vote_probabilities <- majority_vote_probabilities %>%
        dplyr::mutate_at(prediction_cols, function(x) {
          (x / (x + 1e-30)) * 1e+4
        })

      data_for_id_evaluation <- majority_vote_probabilities %>%
        dplyr::ungroup()
    } else if (family == "binomial") {
      if (length(prediction_cols) > 1) {
        stop("when 'family' is 'binomial', length of 'prediction_cols' should be 1.")
      }

      # Row-level class prediction
      data[["predicted_class"]] <- ifelse(data[[prediction_cols]] > cutoff, 1, 0)

      # Share of positive votes per ID, pushed towards 0 or 1
      data_for_id_evaluation <- data %>%
        dplyr::group_by(!!as.name(groups_col), !!as.name(id_col)) %>%
        dplyr::summarise(mean_prediction = mean(.data$predicted_class)) %>%
        dplyr::mutate(mean_prediction = dplyr::case_when(
          mean_prediction > cutoff ~ 1 - 1e-40, # Almost 1
          mean_prediction == cutoff ~ cutoff, # split decision (unlikely to happen)
          mean_prediction < cutoff ~ 1e-40 # Almost 0
        )) %>%
        dplyr::rename_at("mean_prediction", ~prediction_cols) %>%
        dplyr::left_join(id_classes, by = c(id_col, groups_col)) %>%
        dplyr::ungroup()
    } else {
      stop(paste0("family ", family, " not currently supported for majority vote aggregated ID evaluation."))
    }
  } else {
    # Previously this fell through to an "object not found" error
    stop("'id_method' must be either 'mean' or 'majority'.")
  }

  if (family == "multinomial") {
    # Nest the class probabilities (and standard deviations, when present)
    data_for_id_evaluation <- prepare_multinomial_evaluation(
      data_for_id_evaluation,
      target_col = target_col,
      prediction_cols = prediction_cols,
      apply_softmax = TRUE,
      new_prediction_col_name = new_prediction_col_name,
      new_std_col_name = new_std_col_name
    )
  } else {
    # Only the "mean" method produces a standard deviation column
    if (id_method == "mean") {
      data_for_id_evaluation <- data_for_id_evaluation %>%
        base_rename(
          before = paste0(prediction_cols, "_STD"),
          after = new_std_col_name
        )
    }
  }

  data_for_id_evaluation
}
# Replace the per-class probability columns in 'data' with a single nested
# column (and do the same for the standard deviation columns, when present).
#
# Arguments:
#   data: Data frame with the target column and one column per class.
#   target_col: Name of the target column.
#   prediction_cols: Names of the per-class probability columns.
#   apply_softmax: Whether to pass the probabilities through softmax() first.
#   new_prediction_col_name: Name of the nested probability column to add.
#   new_std_col_name: Name of the nested standard deviation column to add.
#
# Returns 'data' without the probability (and "_STD") columns but with the
# nested column(s) added. Errors when the target has levels that are not
# among 'prediction_cols'.
prepare_multinomial_evaluation <- function(data,
                                           target_col,
                                           prediction_cols,
                                           apply_softmax,
                                           new_prediction_col_name,
                                           new_std_col_name) {

  # Separate the probabilities and standard deviations from everything else
  parts <- extract_and_remove_probability_cols(
    data = data,
    prediction_cols = prediction_cols
  )
  probabilities <- parts[["predicted_probabilities"]]
  std_devs <- parts[["standard_deviations"]]
  remaining <- parts[["data"]]

  if (isTRUE(apply_softmax)) {
    probabilities <- softmax(probabilities)
  }

  # Nest row by row and attach as list-like columns
  remaining[[new_prediction_col_name]] <- nest_rowwise(probabilities)
  if (!is.null(std_devs)) {
    remaining[[new_std_col_name]] <- nest_rowwise(std_devs)
  }

  # TODO Should the dependent column be processed in any way, e.g. made
  # numeric or checked against the names in prediction_cols?
  # TODO If the classes are numeric, what will prediction_cols contain?
  unmatched_levels <- setdiff(
    levels_as_characters(remaining[[target_col]]),
    prediction_cols
  )
  if (length(unmatched_levels) > 0) {
    stop("Not all levels in 'target_col' was found in 'prediction_cols'.")
  }

  remaining
}
# Split 'data' into the probability columns, the matching standard deviation
# columns (if any) and the remaining columns.
#
# Arguments:
#   data: Data frame containing the 'prediction_cols'.
#   prediction_cols: Names of the probability columns.
#
# Standard deviations are looked up as "<prediction_col>_STD"; they are only
# extracted when the "_STD" column for the first prediction column exists.
# The extracted standard deviation columns are renamed to 'prediction_cols'.
#
# Returns a list with the elements "data", "predicted_probabilities" and
# "standard_deviations" (NULL when there are no "_STD" columns).
# Errors when a probability or standard deviation column is not numeric.
extract_and_remove_probability_cols <- function(data, prediction_cols) {

  # Split data into probabilities and the rest
  predicted_probabilities <- data %>%
    base_select(cols = prediction_cols)
  data <- data %>%
    base_deselect(cols = prediction_cols)

  # Split data into standard deviations and the rest
  if (paste0(prediction_cols[[1]], "_STD") %in% colnames(data)) {
    standard_deviations <- data %>%
      base_select(cols = paste0(prediction_cols, "_STD"))
    data <- data %>%
      base_deselect(cols = colnames(standard_deviations))
    colnames(standard_deviations) <- prediction_cols
  } else {
    standard_deviations <- NULL
  }

  # Test that the probability columns are numeric
  # vapply() keeps the result logical regardless of the input
  if (!all(vapply(predicted_probabilities, is.numeric, logical(1)))) {
    stop("the prediction columns must be numeric.")
  }

  # Only check the standard deviations when there are any.
  # (sapply() on NULL returns an empty list, which only passed the old
  # check because negating an empty list yields logical(0).)
  if (!is.null(standard_deviations) &&
    !all(vapply(standard_deviations, is.numeric, logical(1)))) {
    stop("the standard deviation columns must be numeric.")
  }

  list(
    "data" = data,
    "predicted_probabilities" = predicted_probabilities,
    "standard_deviations" = standard_deviations
  )
}
# Get the distinct combinations of group, ID and target in 'data'.
#
# Arguments:
#   data: Data frame containing the three columns.
#   groups_col, id_col, target_col: Names of the group, ID and target columns.
#
# Returns the distinct rows of those three columns. Errors (through
# check_constant_targets_within_ids()) when an ID has more than one
# distinct target within a group.
extract_id_classes <- function(data, groups_col, id_col, target_col) {
  key_cols <- c(groups_col, id_col, target_col)
  id_classes <- dplyr::distinct(base_select(data, cols = key_cols))

  # More than one distinct row per (group, ID) means the target varies
  check_constant_targets_within_ids(
    distinct_data = id_classes,
    groups_col = groups_col,
    id_col = id_col
  )

  id_classes
}
# Verify that every ID occurs only once per group in 'distinct_data'.
#
# Arguments:
#   distinct_data: Distinct rows of the group, ID and target columns.
#   groups_col, id_col: Names of the group and ID columns.
#
# Returns NULL invisibly when every (group, ID) pair has a single row.
# Otherwise stops with a message listing up to five of the offending IDs.
check_constant_targets_within_ids <- function(distinct_data, groups_col, id_col) {
  by_id <- dplyr::group_by(
    distinct_data,
    !!as.name(groups_col),
    !!as.name(id_col)
  )
  rows_per_id <- dplyr::summarize(by_id, n = dplyr::n())

  # Nothing to report when no ID has multiple rows
  if (!any(rows_per_id$n > 1)) {
    return(invisible(NULL))
  }

  offenders <- dplyr::filter(rows_per_id, .data$n > 1)
  non_constant_ids <- dplyr::pull(offenders, !!as.name(id_col))

  # Show at most five IDs and indicate when there are more
  shown_ids <- paste0(head(non_constant_ids, 5), collapse = ", ")
  more_marker <- if (length(non_constant_ids) > 5) ", ..." else ""

  stop(paste0(
    "The targets must be constant within the IDs with the current ID method. ",
    "These IDs had more than one unique value in the target column: ",
    shown_ids,
    more_marker,
    "."
  ))
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.