# R/infogram.R

# This file is auto-generated by h2o-3/h2o-bindings/bin/gen_R.py
# Copyright 2016 H2O.ai;  Apache License Version 2.0 (see LICENSE for details) 
#'
# -------------------------- Infogram -------------------------- #
#'
#' H2O Infogram
#' 
#' The infogram is a graphical information-theoretic interpretability tool that allows the user to quickly spot the core, decision-making variables
#' that uniquely and safely drive the response in supervised classification problems. The infogram can significantly cut down the number of predictors needed to build
#' a model by identifying only the most valuable, admissible features. When protected variables such as race or gender are present in the data, the admissibility
#' of a variable is determined by a safety and relevancy index, and thus serves as a diagnostic tool for fairness. The safety of each feature can be quantified, and
#' variables that are unsafe will be considered inadmissible. Models built using only admissible features will naturally be more interpretable, given the reduced
#' feature set. Admissible models are also less susceptible to overfitting and train faster, while providing accuracy similar to models built using all available features.
#' 
#' The infogram allows the user to quickly spot the admissible decision-making variables that are driving the response.  
#' There are two types of infogram plots: Core and Fair Infogram.
#' 
#' The Core Infogram plots all the variables as points on a two-dimensional grid of total vs. net information.  The x-axis is total information,
#' a measure of how much the variable drives the response (the more predictive, the higher the total information).
#' The y-axis is net information, a measure of how unique the variable is.  The top right quadrant of the infogram plot is the admissible section; the variables
#' located in this quadrant are the admissible features.  In the Core Infogram, the admissible features are the strongest, unique drivers of
#' the response.
#' 
#' If sensitive or protected variables are present in the data, the user can specify which attributes should be protected while training using the \code{protected_columns}
#' argument. All non-protected predictor variables will be checked to make sure that there's no information pathway to the response through a protected feature, and
#' deemed inadmissible if they possess little or no informational value beyond their use as a dummy for protected attributes. The Fair Infogram plots all the features
#' as points on a two-dimensional grid of relevance vs. safety.  The x-axis is the relevance index, a measure of how much the variable drives the response (the more predictive,
#' the higher the relevance). The y-axis is the safety index, a measure of how much extra information the variable has that is not acquired through the protected variables.
#' In the Fair Infogram, the admissible features are the strongest, safest drivers of the response.
#' 
#'
#' @param x (Optional) A vector containing the names or indices of the predictor variables to use in building the model.
#'        If x is missing, then all columns except y are used.
#' @param y The name or column index of the response variable in the data. 
#'        The response must be either a numeric or a categorical/factor variable. 
#'        If the response is numeric, then a regression model will be trained; otherwise a classification model will be trained.
#' @param training_frame Id of the training data frame.
#' @param model_id Destination id for this model; auto-generated if not specified.
#' @param validation_frame Id of the validation data frame.
#' @param seed Seed for random numbers (affects certain parts of the algo that are stochastic and those might or might not be enabled by default).
#'        Defaults to -1 (time-based random number).
#' @param keep_cross_validation_models \code{Logical}. Whether to keep the cross-validation models. Defaults to TRUE.
#' @param keep_cross_validation_predictions \code{Logical}. Whether to keep the predictions of the cross-validation models. Defaults to FALSE.
#' @param keep_cross_validation_fold_assignment \code{Logical}. Whether to keep the cross-validation fold assignment. Defaults to FALSE.
#' @param nfolds Number of folds for K-fold cross-validation (0 to disable or >= 2). Defaults to 0.
#' @param fold_assignment Cross-validation fold assignment scheme, if fold_column is not specified. The 'Stratified' option will
#'        stratify the folds based on the response variable, for classification problems. Must be one of: "AUTO",
#'        "Random", "Modulo", "Stratified". Defaults to AUTO.
#' @param fold_column Column with cross-validation fold index assignment per observation.
#' @param ignore_const_cols \code{Logical}. Ignore constant columns. Defaults to TRUE.
#' @param score_each_iteration \code{Logical}. Whether to score during each iteration of model training. Defaults to FALSE.
#' @param offset_column Offset column. This will be added to the combination of columns before applying the link function.
#' @param weights_column Column with observation weights. Giving an observation a weight of zero is equivalent to excluding it from
#'        the dataset; giving an observation a relative weight of 2 is equivalent to repeating that row twice. Negative
#'        weights are not allowed. Note: Weights are per-row observation weights and do not increase the size of the
#'        data frame. This is typically the number of times a row is repeated, but non-integer values are supported as
#'        well. During training, rows with higher weights matter more, due to the larger loss function pre-factor. If
#'        you set weight = 0 for a row, the returned prediction frame at that row is zero and this is incorrect. To get
#'        an accurate prediction, remove all rows with weight == 0.
#' @param standardize \code{Logical}. Standardize numeric columns to have zero mean and unit variance. Defaults to FALSE.
#' @param distribution Distribution function. Must be one of: "AUTO", "bernoulli", "multinomial", "gaussian", "poisson", "gamma",
#'        "tweedie", "laplace", "quantile", "huber". Defaults to AUTO.
#' @param plug_values Plug Values (a single row frame containing values that will be used to impute missing values of the
#'        training/validation frame; use in conjunction with missing_values_handling = PlugValues).
#' @param max_iterations Maximum number of iterations. Defaults to 0.
#' @param stopping_rounds Early stopping based on convergence of stopping_metric. Stop if the simple moving average of length k of the
#'        stopping_metric does not improve for k := stopping_rounds scoring events (0 to disable). Defaults to 0.
#' @param stopping_metric Metric to use for early stopping (AUTO: logloss for classification, deviance for regression and anomaly_score
#'        for Isolation Forest). Note that custom and custom_increasing can only be used in GBM and DRF with the Python
#'        client. Must be one of: "AUTO", "deviance", "logloss", "MSE", "RMSE", "MAE", "RMSLE", "AUC", "AUCPR",
#'        "lift_top_group", "misclassification", "mean_per_class_error", "custom", "custom_increasing". Defaults to
#'        AUTO.
#' @param stopping_tolerance Relative tolerance for the metric-based stopping criterion (stop if the relative improvement is not at least this
#'        much). Defaults to 0.001.
#' @param balance_classes \code{Logical}. Balance training data class counts via over/under-sampling (for imbalanced data). Defaults to
#'        FALSE.
#' @param class_sampling_factors Desired over/under-sampling ratios per class (in lexicographic order). If not specified, sampling factors will
#'        be automatically computed to obtain class balance during training. Requires balance_classes.
#' @param max_after_balance_size Maximum relative size of the training data after balancing class counts (can be less than 1.0). Requires
#'        balance_classes. Defaults to 5.0.
#' @param max_runtime_secs Maximum allowed runtime in seconds for model training. Use 0 to disable. Defaults to 0.
#' @param custom_metric_func Reference to custom evaluation function, format: `language:keyName=funcName`
#' @param auc_type Set default multinomial AUC type. Must be one of: "AUTO", "NONE", "MACRO_OVR", "WEIGHTED_OVR", "MACRO_OVO",
#'        "WEIGHTED_OVO". Defaults to AUTO.
#' @param algorithm Type of machine learning algorithm used to build the infogram. Options include 'AUTO' (gbm), 'deeplearning'
#'        (Deep Learning with default parameters), 'drf' (Random Forest with default parameters), 'gbm' (GBM with
#'        default parameters), 'glm' (GLM with default parameters), or 'xgboost' (if available, XGBoost with default
#'        parameters). Must be one of: "AUTO", "deeplearning", "drf", "gbm", "glm", "xgboost". Defaults to AUTO.
#' @param algorithm_params Customized parameters for the machine learning algorithm specified in the algorithm parameter.
#' @param protected_columns Columns that contain features that are sensitive and need to be protected (legally, or otherwise), if
#'        applicable. These features (e.g. race, gender, etc) should not drive the prediction of the response.
#' @param total_information_threshold A number between 0 and 1 representing a threshold for total information. For a specific
#'        feature, if the total information is higher than this threshold, and the corresponding net information is also
#'        higher than net_information_threshold, that feature will be considered admissible. The total information is the
#'        x-axis of the Core Infogram. Defaults to -1, which is interpreted as 0.1.
#' @param net_information_threshold A number between 0 and 1 representing a threshold for net information. For a specific
#'        feature, if the net information is higher than this threshold, and the corresponding total information is also
#'        higher than total_information_threshold, that feature will be considered admissible. The net information is the
#'        y-axis of the Core Infogram. Defaults to -1, which is interpreted as 0.1.
#' @param relevance_index_threshold A number between 0 and 1 representing a threshold for the relevance index. This is only
#'        used when protected_columns is set by the user. For a specific feature, if the relevance index value is
#'        higher than this threshold, and the corresponding safety index is also higher than
#'        safety_index_threshold, that feature will be considered admissible. The relevance index is the x-axis of
#'        the Fair Infogram. Defaults to -1, which is interpreted as 0.1.
#' @param safety_index_threshold A number between 0 and 1 representing a threshold for the safety index. This is only used
#'        when protected_columns is set by the user. For a specific feature, if the safety index value is higher than
#'        this threshold, and the corresponding relevance index is also higher than relevance_index_threshold, that
#'        feature will be considered admissible. The safety index is the y-axis of the Fair Infogram. Defaults to -1,
#'        which is interpreted as 0.1.
#' @param data_fraction The fraction of the training frame to use to build the infogram model. Any value greater than
#'        0 and less than or equal to 1.0 is acceptable. Defaults to 1.
#' @param top_n_features An integer specifying the number of columns to evaluate in the infogram. The columns are ranked by variable
#'        importance, and the top N are evaluated. Defaults to 50.
#' @examples
#' \dontrun{
#' h2o.init()
#' 
#' # Convert iris dataset to an H2OFrame    
#' df <- as.h2o(iris)
#' 
#' # Infogram
#' ig <- h2o.infogram(y = "Species", training_frame = df) 
#' plot(ig)
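#' 
#' # Fair Infogram (a sketch, not run): declare protected columns so the
#' # infogram plots relevance vs. safety. The file path and the column names
#' # "label", "gender", and "race" are placeholders for your own data.
#' hf <- h2o.importFile("path/to/your/data.csv")
#' fair_ig <- h2o.infogram(y = "label", training_frame = hf,
#'                         protected_columns = c("gender", "race"))
#' plot(fair_ig)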
#' 
#' }
#' @export
h2o.infogram <- function(x,
                         y,
                         training_frame,
                         model_id = NULL,
                         validation_frame = NULL,
                         seed = -1,
                         keep_cross_validation_models = TRUE,
                         keep_cross_validation_predictions = FALSE,
                         keep_cross_validation_fold_assignment = FALSE,
                         nfolds = 0,
                         fold_assignment = c("AUTO", "Random", "Modulo", "Stratified"),
                         fold_column = NULL,
                         ignore_const_cols = TRUE,
                         score_each_iteration = FALSE,
                         offset_column = NULL,
                         weights_column = NULL,
                         standardize = FALSE,
                         distribution = c("AUTO", "bernoulli", "multinomial", "gaussian", "poisson", "gamma", "tweedie", "laplace", "quantile", "huber"),
                         plug_values = NULL,
                         max_iterations = 0,
                         stopping_rounds = 0,
                         stopping_metric = c("AUTO", "deviance", "logloss", "MSE", "RMSE", "MAE", "RMSLE", "AUC", "AUCPR", "lift_top_group", "misclassification", "mean_per_class_error", "custom", "custom_increasing"),
                         stopping_tolerance = 0.001,
                         balance_classes = FALSE,
                         class_sampling_factors = NULL,
                         max_after_balance_size = 5.0,
                         max_runtime_secs = 0,
                         custom_metric_func = NULL,
                         auc_type = c("AUTO", "NONE", "MACRO_OVR", "WEIGHTED_OVR", "MACRO_OVO", "WEIGHTED_OVO"),
                         algorithm = c("AUTO", "deeplearning", "drf", "gbm", "glm", "xgboost"),
                         algorithm_params = NULL,
                         protected_columns = NULL,
                         total_information_threshold = -1,
                         net_information_threshold = -1,
                         relevance_index_threshold = -1,
                         safety_index_threshold = -1,
                         data_fraction = 1,
                         top_n_features = 50)
{
  # Validate required training_frame first and other frame args: should be a valid key or an H2OFrame object
  training_frame <- .validate.H2OFrame(training_frame, required=TRUE)
  validation_frame <- .validate.H2OFrame(validation_frame, required=FALSE)

  # Validate other required args
  # If x is missing, then assume user wants to use all columns as features.
  if (missing(x)) {
     if (is.numeric(y)) {
         x <- setdiff(col(training_frame), y)
     } else {
         x <- setdiff(colnames(training_frame), y)
     }
  }

  # Build parameter list to send to model builder
  parms <- list()
  parms$training_frame <- training_frame
  args <- .verify_dataxy(training_frame, x, y)
  if (missing(protected_columns)) { 
    # core infogram
    if (!missing(safety_index_threshold)) {
      warning("Should not set safety_index_threshold for Core Infogram runs. Set net_information_threshold instead.")
    }
    if (!missing(relevance_index_threshold)) {
      warning("Should not set relevance_index_threshold for Core Infogram runs. Set total_information_threshold instead.")
    }
  } else { 
    # fair infogram
    if (!missing(net_information_threshold)) {
      warning("Should not set net_information_threshold for Fair Infogram runs. Set safety_index_threshold instead.")
    }
    if (!missing(total_information_threshold)) {
      warning("Should not set total_information_threshold for Fair Infogram runs. Set relevance_index_threshold instead.")
    }
  }

  if (!missing(offset_column) && !is.null(offset_column))  args$x_ignore <- args$x_ignore[!( offset_column == args$x_ignore )]
  if (!missing(weights_column) && !is.null(weights_column)) args$x_ignore <- args$x_ignore[!( weights_column == args$x_ignore )]
  if (!missing(fold_column) && !is.null(fold_column)) args$x_ignore <- args$x_ignore[!( fold_column == args$x_ignore )]
  parms$ignored_columns <- args$x_ignore
  parms$response_column <- args$y

  if (!missing(model_id))
    parms$model_id <- model_id
  if (!missing(validation_frame))
    parms$validation_frame <- validation_frame
  if (!missing(seed))
    parms$seed <- seed
  if (!missing(keep_cross_validation_models))
    parms$keep_cross_validation_models <- keep_cross_validation_models
  if (!missing(keep_cross_validation_predictions))
    parms$keep_cross_validation_predictions <- keep_cross_validation_predictions
  if (!missing(keep_cross_validation_fold_assignment))
    parms$keep_cross_validation_fold_assignment <- keep_cross_validation_fold_assignment
  if (!missing(nfolds))
    parms$nfolds <- nfolds
  if (!missing(fold_assignment))
    parms$fold_assignment <- fold_assignment
  if (!missing(fold_column))
    parms$fold_column <- fold_column
  if (!missing(ignore_const_cols))
    parms$ignore_const_cols <- ignore_const_cols
  if (!missing(score_each_iteration))
    parms$score_each_iteration <- score_each_iteration
  if (!missing(offset_column))
    parms$offset_column <- offset_column
  if (!missing(weights_column))
    parms$weights_column <- weights_column
  if (!missing(standardize))
    parms$standardize <- standardize
  if (!missing(distribution))
    parms$distribution <- distribution
  if (!missing(plug_values))
    parms$plug_values <- plug_values
  if (!missing(max_iterations))
    parms$max_iterations <- max_iterations
  if (!missing(stopping_rounds))
    parms$stopping_rounds <- stopping_rounds
  if (!missing(stopping_metric))
    parms$stopping_metric <- stopping_metric
  if (!missing(stopping_tolerance))
    parms$stopping_tolerance <- stopping_tolerance
  if (!missing(balance_classes))
    parms$balance_classes <- balance_classes
  if (!missing(class_sampling_factors))
    parms$class_sampling_factors <- class_sampling_factors
  if (!missing(max_after_balance_size))
    parms$max_after_balance_size <- max_after_balance_size
  if (!missing(max_runtime_secs))
    parms$max_runtime_secs <- max_runtime_secs
  if (!missing(custom_metric_func))
    parms$custom_metric_func <- custom_metric_func
  if (!missing(auc_type))
    parms$auc_type <- auc_type
  if (!missing(algorithm))
    parms$algorithm <- algorithm
  if (!missing(protected_columns))
    parms$protected_columns <- protected_columns
  if (!missing(total_information_threshold))
    parms$total_information_threshold <- total_information_threshold
  if (!missing(net_information_threshold))
    parms$net_information_threshold <- net_information_threshold
  if (!missing(relevance_index_threshold))
    parms$relevance_index_threshold <- relevance_index_threshold
  if (!missing(safety_index_threshold))
    parms$safety_index_threshold <- safety_index_threshold
  if (!missing(data_fraction))
    parms$data_fraction <- data_fraction
  if (!missing(top_n_features))
    parms$top_n_features <- top_n_features

  if (!missing(algorithm_params))
      parms$algorithm_params <- as.character(toJSON(algorithm_params, pretty = TRUE))

  # Error check and build model
  model <- .h2o.modelJob('infogram', parms, h2oRestApiVersion=3, verbose=FALSE)

  # Convert algorithm_params back to list if not NULL, added after obtaining model
  if (!missing(algorithm_params)) {
    model@parameters$algorithm_params <- list(fromJSON(model@parameters$algorithm_params))[[1]] # need the `[[ ]]` to avoid a nested list
  }

  # Wrap the generic model object in an H2OInfogram object
  infogram_model <- new("H2OInfogram", model_id = model@model_id)
  return(infogram_model)
}
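
# Example (not run): passing hyperparameters of the underlying algorithm via
# `algorithm_params`. The list is serialized to JSON before being sent to the
# backend (see the toJSON() call above); parameter names must match arguments
# of the chosen algorithm (here, GBM's ntrees and max_depth).
#
#   ig <- h2o.infogram(y = "Species", training_frame = df, algorithm = "gbm",
#                      algorithm_params = list(ntrees = 20, max_depth = 5))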
.h2o.train_segments_infogram <- function(x,
                                         y,
                                         training_frame,
                                         validation_frame = NULL,
                                         seed = -1,
                                         keep_cross_validation_models = TRUE,
                                         keep_cross_validation_predictions = FALSE,
                                         keep_cross_validation_fold_assignment = FALSE,
                                         nfolds = 0,
                                         fold_assignment = c("AUTO", "Random", "Modulo", "Stratified"),
                                         fold_column = NULL,
                                         ignore_const_cols = TRUE,
                                         score_each_iteration = FALSE,
                                         offset_column = NULL,
                                         weights_column = NULL,
                                         standardize = FALSE,
                                         distribution = c("AUTO", "bernoulli", "multinomial", "gaussian", "poisson", "gamma", "tweedie", "laplace", "quantile", "huber"),
                                         plug_values = NULL,
                                         max_iterations = 0,
                                         stopping_rounds = 0,
                                         stopping_metric = c("AUTO", "deviance", "logloss", "MSE", "RMSE", "MAE", "RMSLE", "AUC", "AUCPR", "lift_top_group", "misclassification", "mean_per_class_error", "custom", "custom_increasing"),
                                         stopping_tolerance = 0.001,
                                         balance_classes = FALSE,
                                         class_sampling_factors = NULL,
                                         max_after_balance_size = 5.0,
                                         max_runtime_secs = 0,
                                         custom_metric_func = NULL,
                                         auc_type = c("AUTO", "NONE", "MACRO_OVR", "WEIGHTED_OVR", "MACRO_OVO", "WEIGHTED_OVO"),
                                         algorithm = c("AUTO", "deeplearning", "drf", "gbm", "glm", "xgboost"),
                                         algorithm_params = NULL,
                                         protected_columns = NULL,
                                         total_information_threshold = -1,
                                         net_information_threshold = -1,
                                         relevance_index_threshold = -1,
                                         safety_index_threshold = -1,
                                         data_fraction = 1,
                                         top_n_features = 50,
                                         segment_columns = NULL,
                                         segment_models_id = NULL,
                                         parallelism = 1)
{
  # formally define variables that were excluded from function parameters
  model_id <- NULL
  verbose <- NULL
  destination_key <- NULL
  # Validate required training_frame first and other frame args: should be a valid key or an H2OFrame object
  training_frame <- .validate.H2OFrame(training_frame, required=TRUE)
  validation_frame <- .validate.H2OFrame(validation_frame, required=FALSE)

  # Validate other required args
  # If x is missing, then assume user wants to use all columns as features.
  if (missing(x)) {
     if (is.numeric(y)) {
         x <- setdiff(col(training_frame), y)
     } else {
         x <- setdiff(colnames(training_frame), y)
     }
  }

  # Build parameter list to send to model builder
  parms <- list()
  parms$training_frame <- training_frame
  args <- .verify_dataxy(training_frame, x, y)
  if (missing(protected_columns)) { 
    # core infogram
    if (!missing(safety_index_threshold)) {
      warning("Should not set safety_index_threshold for Core Infogram runs. Set net_information_threshold instead.")
    }
    if (!missing(relevance_index_threshold)) {
      warning("Should not set relevance_index_threshold for Core Infogram runs. Set total_information_threshold instead.")
    }
  } else { 
    # fair infogram
    if (!missing(net_information_threshold)) {
      warning("Should not set net_information_threshold for Fair Infogram runs. Set safety_index_threshold instead.")
    }
    if (!missing(total_information_threshold)) {
      warning("Should not set total_information_threshold for Fair Infogram runs. Set relevance_index_threshold instead.")
    }
  }

  if (!missing(offset_column) && !is.null(offset_column))  args$x_ignore <- args$x_ignore[!( offset_column == args$x_ignore )]
  if (!missing(weights_column) && !is.null(weights_column)) args$x_ignore <- args$x_ignore[!( weights_column == args$x_ignore )]
  if (!missing(fold_column) && !is.null(fold_column)) args$x_ignore <- args$x_ignore[!( fold_column == args$x_ignore )]
  parms$ignored_columns <- args$x_ignore
  parms$response_column <- args$y

  if (!missing(validation_frame))
    parms$validation_frame <- validation_frame
  if (!missing(seed))
    parms$seed <- seed
  if (!missing(keep_cross_validation_models))
    parms$keep_cross_validation_models <- keep_cross_validation_models
  if (!missing(keep_cross_validation_predictions))
    parms$keep_cross_validation_predictions <- keep_cross_validation_predictions
  if (!missing(keep_cross_validation_fold_assignment))
    parms$keep_cross_validation_fold_assignment <- keep_cross_validation_fold_assignment
  if (!missing(nfolds))
    parms$nfolds <- nfolds
  if (!missing(fold_assignment))
    parms$fold_assignment <- fold_assignment
  if (!missing(fold_column))
    parms$fold_column <- fold_column
  if (!missing(ignore_const_cols))
    parms$ignore_const_cols <- ignore_const_cols
  if (!missing(score_each_iteration))
    parms$score_each_iteration <- score_each_iteration
  if (!missing(offset_column))
    parms$offset_column <- offset_column
  if (!missing(weights_column))
    parms$weights_column <- weights_column
  if (!missing(standardize))
    parms$standardize <- standardize
  if (!missing(distribution))
    parms$distribution <- distribution
  if (!missing(plug_values))
    parms$plug_values <- plug_values
  if (!missing(max_iterations))
    parms$max_iterations <- max_iterations
  if (!missing(stopping_rounds))
    parms$stopping_rounds <- stopping_rounds
  if (!missing(stopping_metric))
    parms$stopping_metric <- stopping_metric
  if (!missing(stopping_tolerance))
    parms$stopping_tolerance <- stopping_tolerance
  if (!missing(balance_classes))
    parms$balance_classes <- balance_classes
  if (!missing(class_sampling_factors))
    parms$class_sampling_factors <- class_sampling_factors
  if (!missing(max_after_balance_size))
    parms$max_after_balance_size <- max_after_balance_size
  if (!missing(max_runtime_secs))
    parms$max_runtime_secs <- max_runtime_secs
  if (!missing(custom_metric_func))
    parms$custom_metric_func <- custom_metric_func
  if (!missing(auc_type))
    parms$auc_type <- auc_type
  if (!missing(algorithm))
    parms$algorithm <- algorithm
  if (!missing(protected_columns))
    parms$protected_columns <- protected_columns
  if (!missing(total_information_threshold))
    parms$total_information_threshold <- total_information_threshold
  if (!missing(net_information_threshold))
    parms$net_information_threshold <- net_information_threshold
  if (!missing(relevance_index_threshold))
    parms$relevance_index_threshold <- relevance_index_threshold
  if (!missing(safety_index_threshold))
    parms$safety_index_threshold <- safety_index_threshold
  if (!missing(data_fraction))
    parms$data_fraction <- data_fraction
  if (!missing(top_n_features))
    parms$top_n_features <- top_n_features

  if (!missing(algorithm_params))
      parms$algorithm_params <- as.character(toJSON(algorithm_params, pretty = TRUE))

  # Build segment-models specific parameters
  segment_parms <- list()
  if (!missing(segment_columns))
    segment_parms$segment_columns <- segment_columns
  if (!missing(segment_models_id))
    segment_parms$segment_models_id <- segment_models_id
  segment_parms$parallelism <- parallelism

  # Error check and build segment models
  segment_models <- .h2o.segmentModelsJob('infogram', segment_parms, parms, h2oRestApiVersion=3)
  return(segment_models)
}
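
# Example (not run): training one infogram per data segment with the internal
# helper above. A sketch, assuming a frame `hf` with a categorical column
# "region" to segment on; segment_columns, segment_models_id, and parallelism
# are forwarded as the segment-specific parameters built above.
#
#   segments <- .h2o.train_segments_infogram(y = "label", training_frame = hf,
#                                            segment_columns = "region")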


#' Plot an H2O Infogram
#'
#' Plots the Infogram for an H2OInfogram object.
#'
#' @param x A fitted \linkS4class{H2OInfogram} object.
#' @param ... additional arguments to pass on.
#' @return A ggplot2 object.
#' @seealso \code{\link{h2o.infogram}}
#' @examples
#' \dontrun{
#' h2o.init()
#' 
#' # Convert iris dataset to an H2OFrame
#' train <- as.h2o(iris)
#' 
#' # Create and plot infogram
#' ig <- h2o.infogram(y = "Species", training_frame = train)
#' plot(ig)
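#' 
#' # The plot method accepts a custom title through `...`
#' plot(ig, title = "Core Infogram for iris")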
#' 
#' }
#' @method plot H2OInfogram
#' @export
plot.H2OInfogram <- function(x, ...) {
  .check_for_ggplot2() # from explain.R
  .data <- NULL
  varargs <- list(...)
  if ("title" %in% names(varargs)) {
    title <- varargs$title
  } else {
    title <- "Infogram"
  }
  if ("total_information" %in% names(x@admissible_score)) {
    # core infogram
    xlab <- "Total Information"
    ylab <- "Net Information"
    xthresh <- x@total_information_threshold
    ythresh <- x@net_information_threshold
  } else {
    # fair infogram
    xlab <- "Relevance Index"
    ylab <- "Safety Index"
    xthresh <- x@relevance_index_threshold
    ythresh <- x@safety_index_threshold
  }
  df <- as.data.frame(x@admissible_score)
  # use generic names for x, y for easier ggplot code
  names(df) <- c("column",
                 "admissible",
                 "admissible_index",
                 "ig_x",
                 "ig_y",
                 "raw")
  ggplot2::ggplot(data = df, ggplot2::aes_(~ig_x, ~ig_y)) +
    ggplot2::geom_point() +
    ggplot2::geom_polygon(ggplot2::aes(.data$x_coordinates, .data$y_coordinates), data = data.frame(
      x_coordinates = c(xthresh, xthresh, -Inf, -Inf, Inf, Inf, xthresh),
      y_coordinates = c(ythresh, Inf, Inf, -Inf, -Inf, ythresh, ythresh)
    ), alpha = 0.1, fill = "#CC663E") +
    ggplot2::geom_path(ggplot2::aes(.data$x_coordinates, .data$y_coordinates), data = data.frame(
      x_coordinates = c(xthresh, xthresh, NA, xthresh, Inf),
      y_coordinates = c(ythresh,     Inf, NA, ythresh, ythresh)
    ), color = "red", linetype = "dashed") +
    ggplot2::geom_text(ggplot2::aes_(~ig_x, ~ig_y, label = ~column),
                       data = df[as.logical(df$admissible),], nudge_y = -0.0325,
                       color = "blue", size = 2.5) +
    ggplot2::xlab(xlab) +
    ggplot2::ylab(ylab) +
    ggplot2::coord_fixed(xlim = c(0, 1.1), ylim = c(0, 1.1), expand = FALSE) +
    ggplot2::theme_bw() +
    ggplot2::ggtitle(title) +
    ggplot2::theme(plot.title = ggplot2::element_text(hjust = 0.5))
}
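
# Example (not run): extracting the admissible feature names from a fitted
# infogram without plotting. A sketch, assuming the first two columns of the
# `admissible_score` frame are the feature name and the admissible flag, as
# the renaming inside plot.H2OInfogram implies.
#
#   scores <- as.data.frame(ig@admissible_score)
#   names(scores)[1:2] <- c("column", "admissible")
#   scores$column[as.logical(scores$admissible)]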
