#' Functions for prediction
#'
#' List of prediction functions
#'
#' \itemize{
#' \item \code{get_prediction} calculates predicted outcomes and propensity scores from the trained model, and adds the prediction columns to the original data frame
#' \item \code{plot_distribution} generates plots of distributions with classes
#' \item \code{get_confusion_matrix} calculates confusion matrix from prediction results. This confusion matrix contains model accuracy, target coverage, lift up value at each rank percentile
#' \item \code{get_important_variables} provides the list of important variables from the ready-built model object
#' }
#'
#' @param dta A dataframe that represents model prediction results, score, truth, label to be predicted
#' @param model A ready-built model object
#' @param truth A column name (in dta) that represents true labels of the predicted target
#' @param predicted A column name (in dta) that represents predicted classes
#' @param target_name A string that represents label name of the predicted target
#' @param rank_percentile A list of percentile to display
#' @param ... A list of selected variables
#'
#'
#' @name prediction
#'
#' @examples
#' library(tidymodel)
#' library(dplyr)
#'
#' # Show data
#' as_tibble(dta)
#'
#' get_important_variables(model)
#'
#' get_confusion_matrix(dta, TARGET1, PREDICTED)
#'
#' out = plot_distribution(dta,
#' "TARGET1",
#' "VAR1",
#' "VAR2",
#' "VAR3",
#' "VAR4")
#'
#' out = plot_distribution(dta,
#' NULL,
#' "VAR1",
#' "VAR2",
#' "VAR3",
#' "VAR4")
#'
#' # Train data
#' dta2 = select(dta, TARGET1, VAR1:VAR20)
#'
#'
###-----------------------------------------------------------------------------
NULL
###-----------------------------------------------------------------------------
#' @export
#' @return
#' \itemize{
#' \item \code{get_confusion_matrix} returns a two-way contingency table (class \code{table}) with the values of \code{truth} in rows and the values of \code{predicted} in columns.
#' }
#'
#' @rdname prediction
###-----------------------------------------------------------------------------
get_confusion_matrix <- function(dta, truth, predicted)
{
  ### Capture both column arguments unevaluated so they can be unquoted below
  truth_col <- dplyr::enquo(truth)
  predicted_col <- dplyr::enquo(predicted)

  ### Extract the two columns of dta as plain vectors
  actual <- dplyr::pull(dta, !!truth_col)
  guessed <- dplyr::pull(dta, !!predicted_col)

  ### Cross-tabulate; deparse.level = 0 keeps the dimension names empty
  ### instead of labelling them after the local variables above
  table(actual, guessed, deparse.level = 0)
}
###-----------------------------------------------------------------------------
#' @export
#' @param plot If \code{TRUE}, a bar chart of the scaled variable importance is printed before the table is returned
#' @param subtitle A string used as the subtitle of the importance plot
#' @return
#' \itemize{
#' \item \code{get_important_variables} returns a tibble of the model's variable importance; when \code{plot = TRUE} a bar chart is printed as a side effect
#' }
#' @rdname prediction
###-----------------------------------------------------------------------------
get_important_variables <- function(model, plot = FALSE,
                                    subtitle = "Dormancy Model")
{
  ### Variable importance table of the fitted model
  importance <- h2o::h2o.varimp(model)

  if (isTRUE(plot))
  {
    ### Order the factor levels by importance so the bars come out sorted
    plot_data <- dplyr::mutate(
      importance,
      variable = stats::reorder(variable, scaled_importance)
    )

    ### theme_minimal() is a complete theme, so it has to come before the
    ### theme() tweak; in the opposite order the tweak is overwritten
    importance_plot <-
      ggplot2::ggplot(plot_data,
                      ggplot2::aes(x = variable, y = scaled_importance)) +
      ggplot2::geom_bar(stat = "identity", show.legend = FALSE) +
      ggplot2::labs(x = NULL, y = "Normalized Importance",
                    title = "Variables of Importance",
                    subtitle = subtitle) +
      ggplot2::theme_minimal() +
      ggplot2::theme(panel.spacing = ggplot2::unit(0, "lines")) +
      ggplot2::coord_flip()

    ### A ggplot object is only drawn when printed; inside a function this
    ### has to be done explicitly, otherwise the plot is silently discarded
    print(importance_plot)
  }

  ### Return tibble
  tibble::as_tibble(importance)
}
###-----------------------------------------------------------------------------
#' @export
#' @return
#' \itemize{
#' \item \code{plot_distribution} returns a ggplot object with one density ridge per selected variable; when \code{truth} (a column name given as a string) is supplied, the ridges are filled by that column
#' }
#' @rdname prediction
###-----------------------------------------------------------------------------
plot_distribution <- function(dta, truth = NULL, ...)
{
  ### Variables whose distributions are drawn
  selected <- rlang::quos(...)
  if (length(selected) == 0L)
  {
    stop("plot_distribution() needs at least one variable in `...`.",
         call. = FALSE)
  }

  if (is.null(truth))
  {
    ### No category: reshape the selected columns to long format
    ### (one row per variable/value pair)
    long <- tidyr::gather(dplyr::select(dta, !!!selected),
                          key = VARIABLE,
                          value = VALUE)
    mapping <- ggplot2::aes_(x = ~VALUE, y = ~VARIABLE, alpha = 0.1)
    fill_values <- c("gray")
  } else {
    ### Plot by categories: truth is the name of the class column
    if (!is.character(truth) || length(truth) != 1L)
    {
      stop("`truth` must be NULL or a single column name given as a string.",
           call. = FALSE)
    }
    truth <- rlang::sym(truth)

    ### Keep the class column aside while the other columns go to long format
    long <- tidyr::gather(dplyr::select(dta, !!truth, !!!selected),
                          key = VARIABLE,
                          value = VALUE,
                          -!!truth)
    mapping <- ggplot2::aes_(x = ~VALUE, y = ~VARIABLE,
                             fill = truth, alpha = 0.1)
    fill_values <- c("red", "green")
  }

  ### Both cases share the same ridge-plot layers; only the data, the
  ### aesthetic mapping and the fill colours differ
  out <-
    ggplot2::ggplot(long, mapping) +
    ggridges::geom_density_ridges() +
    ggridges::scale_fill_cyclical(values = fill_values) +
    ggridges::theme_ridges() +
    ggplot2::theme(legend.position = "none")

  ### Return plot
  out
}
###-----------------------------------------------------------------------------
#' @export
#' @param score A column name for the propensity score column that is added to \code{dta}
#' @return
#' \itemize{
#' \item \code{get_prediction} returns a tibble made of \code{dta} with additional columns: predicted outcomes and propensity score. For a model that is not handled, a tibble holding a single \code{NA} is returned with a warning
#' }
#' @rdname prediction
###-----------------------------------------------------------------------------
get_prediction <- function(dta, model, predicted, score, target_name)
{
  ### Set variable names
  score_name <- dplyr::enquo(score)
  predicted_name <- dplyr::enquo(predicted)
  target_name <- dplyr::enquo(target_name)
  out <- NA

  ### The "package" attribute is NULL when class(model) does not carry one;
  ### identical() keeps the test a single TRUE/FALSE instead of a
  ### zero-length condition that makes `if` fail
  model_package <- attr(class(model), "package")

  if (identical(model_package, "h2o"))
  {
    ### H2o Library
    cat("Running ", model_package, ": ", class(model)[1])

    ### Set the warning mode for the duration of the call and give the
    ### caller's setting back on exit
    old_options <- options(warn = 0)
    on.exit(options(old_options), add = TRUE)

    ### Start H2O cluster using all available CPU threads (output hidden)
    suppressWarnings(invisible(utils::capture.output(
      h2o::h2o.init(nthreads = -1)
    )))

    ### Convert to H2O dataframe
    invisible(utils::capture.output(
      dta.hex <- h2o::as.h2o(dta)
    ))

    ### Predict
    invisible(utils::capture.output(
      dta.predicted <- as.data.frame(stats::predict(model, dta.hex))
    ))

    ### Add the predicted class and the score of the target label to dta,
    ### using the column names supplied by the caller
    out <- dplyr::mutate(
      dta,
      !!rlang::quo_name(predicted_name) := dta.predicted$predict,
      !!rlang::quo_name(score_name) := dplyr::pull(dta.predicted, !!target_name)
    )
  } else if (inherits(model, "keras.models.Sequential")) {
    ### TensorFLow by Keras library: placeholder, nothing is computed here
  }

  ### Make the fall-through visible instead of silently handing back NA
  if (!is.data.frame(out))
  {
    warning("get_prediction(): no prediction was produced for this model; ",
            "returning NA.", call. = FALSE)
  }

  ### Return tibble
  tibble::as_tibble(out)
}
### Add the following code to your website.
### For more information on customizing the embed code, read Embedding Snippets.