SensitivityPlots: Plot sensitivities of a neural network model

View source: R/SensitivityPlots.R

SensitivityPlots    R Documentation

Plot sensitivities of a neural network model

Description

Plots the sensitivities computed by SensAnalysisMLP, or the second-order sensitivities computed by HessianMLP.

Usage

SensitivityPlots(
  sens = NULL,
  der = TRUE,
  zoom = TRUE,
  quit.legend = FALSE,
  output = 1,
  plot_type = NULL,
  inp_var = NULL
)

Arguments

sens

SensAnalysisMLP object created by SensAnalysisMLP or HessMLP object created by HessianMLP.
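The following is a minimal sketch of the two accepted input types. It fits a small nnet regression on simulated data and passes both a SensAnalysisMLP result and a HessianMLP result to SensitivityPlots. The data frame `sim` and its columns `y`, `x1`, `x2` and `x3` are hypothetical. The sketch assumes that the nnet and NeuralSens packages are installed and that HessianMLP accepts the same `trData` and `plot` arguments that SensAnalysisMLP uses in the Examples section. `zoom = FALSE` is set so that ggforce is not needed.

```r
# Hypothetical simulated regression data: y depends nonlinearly on x1 and x2,
# while x3 is pure noise
set.seed(1)
n <- 300
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
sim$y <- sin(sim$x1) + 0.5 * sim$x2^2 + rnorm(n, sd = 0.1)

# Small MLP with linear output units (linout = TRUE), as suited to regression
fit <- nnet::nnet(y ~ x1 + x2 + x3, data = sim, size = 4, decay = 0.1,
                  linout = TRUE, maxit = 200, trace = FALSE)

# First-order sensitivities: a SensAnalysisMLP result
sens <- NeuralSens::SensAnalysisMLP(fit, trData = sim, plot = FALSE)
NeuralSens::SensitivityPlots(sens, zoom = FALSE)

# Second-order sensitivities: a HessMLP object, also accepted by SensitivityPlots
hess <- NeuralSens::HessianMLP(fit, trData = sim, plot = FALSE)
NeuralSens::SensitivityPlots(hess, zoom = FALSE)
```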

der

logical indicating whether density plots should be created. Default is TRUE.

zoom

logical indicating whether the distributions in the third plot should be zoomed when any of them is too narrow to be appreciated. Requires the facet_zoom function from the ggforce package. Default is TRUE.

quit.legend

logical indicating whether the legend of the third plot should be removed. Default is FALSE.
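This sketch shows how the logical flags der, zoom and quit.legend combine. It uses a hypothetical simulated data frame `sim` with inputs `x1` to `x3` and assumes that nnet and NeuralSens are installed. The first call drops the density plot entirely. The second call keeps the density plot but removes its legend and disables zooming, which avoids the ggforce dependency described for zoom.

```r
# Hypothetical simulated regression data and a small nnet fit
set.seed(1)
n <- 300
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
sim$y <- sin(sim$x1) + 0.5 * sim$x2^2 + rnorm(n, sd = 0.1)
fit <- nnet::nnet(y ~ x1 + x2 + x3, data = sim, size = 4, decay = 0.1,
                  linout = TRUE, maxit = 200, trace = FALSE)
sens <- NeuralSens::SensAnalysisMLP(fit, trData = sim, plot = FALSE)

# Summary plots only: no density plot of the raw sensitivities
NeuralSens::SensitivityPlots(sens, der = FALSE, zoom = FALSE)

# Density plot included, without zooming and without its legend
NeuralSens::SensitivityPlots(sens, der = TRUE, zoom = FALSE, quit.legend = TRUE)
```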

output

numeric or character specifying the output neuron (by index) or the output name to be plotted. Default is the first output (output = 1).

plot_type

character indicating which of the three plots to show. Useful when many input variables are analyzed. Acceptable values are 'mean_sd', 'square' and 'raw', corresponding to the first, second and third plot respectively. If NULL, all plots are shown together. Default is NULL.

inp_var

character naming the input variable to show in the density plot. Only useful with plot_type = 'raw', to show the density plot of a single input variable. If NULL, all input variables are shown in the density plot. Default is NULL.
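When many inputs make the combined figure crowded, the plots can be drawn one at a time. The following minimal sketch uses a hypothetical simulated data frame `sim` with inputs `x1` to `x3` and assumes that nnet and NeuralSens are installed. It requests each plot type in turn and then restricts the density plot to the single input `x1` through inp_var.

```r
# Hypothetical simulated regression data and a small nnet fit
set.seed(1)
n <- 300
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
sim$y <- sin(sim$x1) + 0.5 * sim$x2^2 + rnorm(n, sd = 0.1)
fit <- nnet::nnet(y ~ x1 + x2 + x3, data = sim, size = 4, decay = 0.1,
                  linout = TRUE, maxit = 200, trace = FALSE)
sens <- NeuralSens::SensAnalysisMLP(fit, trData = sim, plot = FALSE)

# One plot at a time: first ('mean_sd') and second ('square') plots
NeuralSens::SensitivityPlots(sens, plot_type = "mean_sd")
NeuralSens::SensitivityPlots(sens, plot_type = "square")

# Third plot ('raw') restricted to a single input variable
NeuralSens::SensitivityPlots(sens, plot_type = "raw", inp_var = "x1",
                             zoom = FALSE)
```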

Value

A list with the following plots for the output selected by output:

  • Plot 1 ('mean_sd'): mean versus standard deviation of the sensitivities of each input variable.

  • Plot 2 ('square'): squared sensitivities of each input variable.

  • Plot 3 ('raw'): density plot of the raw sensitivities of each input variable, created only if der is TRUE.

References

Pizarroso J, Portela J, Muñoz A (2022). NeuralSens: Sensitivity Analysis of Neural Networks. Journal of Statistical Software, 102(7), 1-36.

Examples

## Load data -------------------------------------------------------------------
data("DAILY_DEMAND_TR", package = "NeuralSens")
fdata <- DAILY_DEMAND_TR

## Parameters of the NNET ------------------------------------------------------
hidden_neurons <- 5
iters <- 250
decay <- 0.1

################################################################################
#########################  REGRESSION NNET #####################################
################################################################################
## Regression dataframe --------------------------------------------------------
# Scale the data
fdata.Reg.tr <- fdata[,2:ncol(fdata)]
fdata.Reg.tr[,3] <- fdata.Reg.tr[,3]/10
fdata.Reg.tr[,1] <- fdata.Reg.tr[,1]/1000

# Normalize the data for some models
preProc <- caret::preProcess(fdata.Reg.tr, method = c("center","scale"))
nntrData <- predict(preProc, fdata.Reg.tr)

## TRAIN nnet NNET -----------------------------------------------------------
# Create a formula to train NNET
form <- paste(names(fdata.Reg.tr)[2:ncol(fdata.Reg.tr)], collapse = " + ")
form <- formula(paste(names(fdata.Reg.tr)[1], form, sep = " ~ "))

set.seed(150)
nnetmod <- nnet::nnet(form,
                      data = nntrData,
                      linout = TRUE,  # linear output units for regression
                      size = hidden_neurons,
                      decay = decay,
                      maxit = iters)
# Compute the sensitivities with SensAnalysisMLP and plot them
sens <- NeuralSens::SensAnalysisMLP(nnetmod, trData = nntrData, plot = FALSE)
NeuralSens::SensitivityPlots(sens)

NeuralSens documentation built on July 9, 2023, 6:18 p.m.