R/prediction.R

Defines functions get_prediction plot_distribution get_important_variables get_confusion_matrix

Documented in get_confusion_matrix get_important_variables get_prediction plot_distribution

#' Functions for prediction
#' 
#' List of prediction functions
#' 
#' \itemize{
#'   \item \code{get_prediction} calculates predicted outcomes and propensity scores from the trained model, and adds them as new columns to the original data frame
#'   \item \code{plot_distribution} generates plots of distributions with classes
#'   \item \code{get_confusion_matrix} calculates a confusion matrix from prediction results by cross-tabulating the true labels against the predicted classes
#'   \item \code{get_important_variables} provides the list of important variables from the ready-built model object
#' }
#' 
#' @param dta A dataframe that represents model prediction results, score, truth, label to be predicted
#' @param model A ready-built model object
#' @param truth A column name (in dta) that represents true labels of the predicted target
#' @param predicted A column name (in dta) that represents predicted classes
#' @param target_name A string that represents label name of the predicted target 
#' @param rank_percentile A list of percentile to display
#' @param ... Names of the selected variables (columns of dta) whose distributions are plotted
#' 
#' 
#' @name prediction
#' 
#' @examples
#' library(tidymodel)
#' library(dplyr)
#' 
#' # Show data
#' as_tibble(dta)
#' 
#' get_important_variables(model)
#' 
#' get_confusion_matrix(dta, TARGET1, PREDICTED)
#'                             
#' out = plot_distribution(dta,
#'                         "TARGET1",
#'                         "VAR1",
#'                         "VAR2",
#'                         "VAR3",
#'                         "VAR4")
#'                   
#' out = plot_distribution(dta,
#'                         NULL,
#'                         "VAR1",
#'                         "VAR2",
#'                         "VAR3",
#'                         "VAR4")
#'                   
#' # Train data
#' dta2 = select(dta, TARGET1, VAR1:VAR20)
#' 
#'                       
###-----------------------------------------------------------------------------
NULL

###-----------------------------------------------------------------------------
#' @export
#' @return
#' \itemize{
#'   \item \code{get_confusion_matrix} returns a two-way contingency \code{table} with the true labels in rows and the predicted classes in columns.
#'  }
#'  
#' @rdname prediction
###-----------------------------------------------------------------------------
get_confusion_matrix = function(dta, truth, predicted)
{
  ### Capture the column names supplied by the caller
  truth_col     = dplyr::enquo(truth)
  predicted_col = dplyr::enquo(predicted)
  
  ### Extract both columns as plain vectors
  actual_values    = dplyr::pull(dta, !!truth_col)
  predicted_values = dplyr::pull(dta, !!predicted_col)
  
  ### Cross-tabulate truth (rows) against prediction (columns);
  ### deparse.level = 0 keeps the dimnames unlabelled
  table(actual_values, predicted_values, deparse.level = 0)
}

###-----------------------------------------------------------------------------
#' @export
#' @param plot A logical flag; when \code{TRUE} a bar chart of the scaled variable importance is drawn in addition to returning the table
#' @return 
#' \itemize{
#'   \item \code{get_important_variables} returns a tibble of important variables
#'  }
#' @rdname prediction
###-----------------------------------------------------------------------------
get_important_variables = function(model, plot = FALSE)
{
  ### Variable-importance table of the model
  out = h2o::h2o.varimp(model)
  
  ### Draw the chart only when requested and when there is something to draw
  ### NOTE(review): the NULL / zero-row guard assumes h2o.varimp() may return
  ### nothing for some model types -- confirm against the h2o documentation
  if (isTRUE(plot) && !is.null(out) && nrow(out) > 0)
  {
    ### Order the bars by scaled importance (sort high to low)
    plot_data = dplyr::mutate(tibble::as_tibble(out),
                              variable = stats::reorder(variable, scaled_importance))
    
    ### theme_minimal() is a complete theme and replaces earlier theme()
    ### settings, so it must come before the panel.spacing tweak
    importance_plot =
      ggplot2::ggplot(plot_data, ggplot2::aes(variable, scaled_importance)) +
      ggplot2::geom_bar(stat = "identity", show.legend = FALSE) +
      ggplot2::labs(x = NULL, y = "Normalized Importance",
                    title = "Variables of Importance",
                    subtitle = "Dormancy Model") +
      ggplot2::theme_minimal() +
      ggplot2::theme(panel.spacing = grid::unit(0, "lines")) +
      ggplot2::coord_flip()
    
    ### A ggplot object is only rendered when printed; inside a function
    ### this has to be done explicitly
    print(importance_plot)
  }
  
  ### Return tibble
  tibble::as_tibble(out)
}

###-----------------------------------------------------------------------------
#' @export
#' @return 
#' \itemize{
#'   \item \code{plot_distribution} returns a ggplot object showing ridge-line density plots of the selected variables, optionally split by the classes in \code{truth}
#'  }
#' @rdname prediction
###-----------------------------------------------------------------------------
plot_distribution = function(dta, truth = NULL, ...)
{
  ### Set variable names (temporary solutions):
  ### the selected variables are captured as quosures
  var = rlang::quos(...)
  
  ### NOTE(review): in both branches alpha = 0.1 is part of the aesthetic
  ### mapping rather than a fixed layer parameter -- confirm this is intended
  
  if (is.null(truth))
  {
    ### No category: stack the selected columns into VARIABLE / VALUE pairs
    tmp = tidyr::gather(dplyr::select(dta, !!!var),
                        key   = VARIABLE,
                        value = VALUE)
    
    out =
      ggplot2::ggplot(tmp, ggplot2::aes_(x = ~VALUE, y = ~VARIABLE, alpha = 0.1)) +
      ggridges::geom_density_ridges() +
      ggridges::scale_fill_cyclical(values = c("gray")) +
      ggridges::theme_ridges() +
      ggplot2::theme(legend.position = "none")
  }
  else
  {
    ### Plot by categories: truth is given as a string, turn it into a symbol
    truth_sym = rlang::sym(truth)
    
    ### Keep the class column aside and stack the remaining selected columns
    tmp = tidyr::gather(dplyr::select(dta, !!truth_sym, !!!var),
                        key   = VARIABLE,
                        value = VALUE,
                        -!!truth_sym)
    
    out =
      ggplot2::ggplot(tmp, ggplot2::aes_(x = ~VALUE, y = ~VARIABLE, fill = truth_sym, alpha = 0.1)) +
      ggridges::geom_density_ridges() +
      ggridges::scale_fill_cyclical(values = c("red", "green")) +
      ggridges::theme_ridges() +
      ggplot2::theme(legend.position = "none")
  }
  
  ### Return plot
  out
}



###-----------------------------------------------------------------------------
#' @export
#' @param score A column name for the propensity score column that is added to the result
#' @return 
#' \itemize{
#'   \item \code{get_prediction} returns a dataframe with additional columns: predicted outcomes and propensity score
#'  }
#' @rdname prediction
###-----------------------------------------------------------------------------
get_prediction = function(dta, model, predicted, score, target_name)
{
  ### Set variable names
  score_name     = dplyr::enquo(score)
  predicted_name = dplyr::enquo(predicted)
  target_name    = dplyr::enquo(target_name)
  
  ### Placeholder result used when no backend handles the model
  out = NA
  
  ### The "package" attribute is absent for models without one, so it must
  ### be compared with identical(); `==` would yield logical(0) and make
  ### `if` fail
  model_package = attr(class(model), "package")
  
  ### H2o Library
  if (identical(model_package, "h2o"))
  {
    cat("Running ", model_package, ": ", class(model)[1])
    
    ### Use the default warning handling while predicting, and restore the
    ### caller's setting when the function exits
    old_options = options(warn = 0)
    on.exit(options(old_options), add = TRUE)
    
    # Start H2O cluster using all available CPU threads
    suppressWarnings(invisible(utils::capture.output(
      h2o::h2o.init(nthreads = -1)
    )))
    
    ### Convert to H2O dataframe
    invisible(utils::capture.output(
      dta.hex <- h2o::as.h2o(dta)
    ))
    
    ### Predict 
    invisible(utils::capture.output(
      dta.predicted <- as.data.frame(stats::predict(model, dta.hex))
    ))
    
    ### Add the predicted class and the score of the target label to dta
    out = dplyr::mutate(dta,
                        !!rlang::quo_name(predicted_name) := dta.predicted$predict,
                        !!rlang::quo_name(score_name)     := dplyr::pull(dta.predicted, !!target_name))
  }
  else if (identical(class(model)[1], "keras.models.Sequential"))
  {
    ### TensorFLow by Keras library: not implemented yet
    warning("get_prediction(): Keras models are not supported yet; ",
            "returning a placeholder NA tibble", call. = FALSE)
  }
  else
  {
    warning("get_prediction(): unsupported model class '", class(model)[1],
            "'; returning a placeholder NA tibble", call. = FALSE)
  }
  
  ### Return tibble
  tibble::as_tibble(out)
}
ldanai/tidymodel documentation built on Jan. 4, 2020, 6:25 a.m.