SensMatPlot: Plot sensitivities of a neural network model

View source: R/SensMatPlot.R

SensMatPlot {NeuralSens}    R Documentation

Plot sensitivities of a neural network model

Description

Plots the second-order sensitivities (second partial derivatives of the output with respect to the inputs) computed by HessianMLP.

Usage

SensMatPlot(
  hess,
  sens = NULL,
  output = 1,
  metric = c("mean", "std", "meanSensSQ"),
  senstype = c("matrix", "interactions"),
  ...
)

Arguments

hess

HessMLP object created by HessianMLP.

sens

SensMLP object created by SensAnalysisMLP.

output

numeric or character specifying the output neuron or output name to be plotted. Defaults to the first output (output = 1).

metric

character specifying the metric to be plotted. It can be "mean", "std" or "meanSensSQ".
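As we understand it, each metric summarises, for every pair of inputs, the distribution of the second derivatives across the observations used to compute them. The base-R sketch below works on a hypothetical array of per-observation second derivatives (hess_obs is an illustrative name, not an object returned by HessianMLP). It assumes that "meanSensSQ" is the root mean square of the derivatives, following the definition given in the NeuralSens paper; check this against the package source before relying on it.

```r
set.seed(1)
n_obs  <- 200
inputs <- c("x1", "x2", "x3")

# Hypothetical second derivatives, one inputs x inputs slice per observation
hess_obs <- array(rnorm(3 * 3 * n_obs), dim = c(3, 3, n_obs),
                  dimnames = list(inputs, inputs, NULL))
# Symmetrise every slice, as the Hessian of a smooth function is symmetric
for (k in seq_len(n_obs)) {
  hess_obs[, , k] <- (hess_obs[, , k] + t(hess_obs[, , k])) / 2
}

# metric = "mean": average second derivative for each pair of inputs
mean_mat <- apply(hess_obs, c(1, 2), mean)
# metric = "std": standard deviation across observations
std_mat  <- apply(hess_obs, c(1, 2), sd)
# metric = "meanSensSQ" (assumed here to be the root mean square)
msq_mat  <- sqrt(apply(hess_obs^2, c(1, 2), mean))
```

Because the root mean square is never smaller than the absolute mean, "meanSensSQ" can reveal strong interactions whose positive and negative values cancel out in "mean".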

senstype

character specifying the type of plot: "matrix" or "interactions". If senstype = "matrix", only the second derivatives are plotted. If senstype = "interactions", the main diagonal shows instead the first derivatives with respect to each input variable, taken from sens, so sens must be supplied.
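Based on the argument description, the two layouts differ only on the main diagonal. The base-R sketch below builds both from hypothetical summary values (hess_mean and sens_mean are illustrative names, not fields of the HessMLP or SensMLP objects). It only shows which quantity ends up in which cell and does not reproduce the package's internal code.

```r
inputs <- c("x1", "x2", "x3")

# Hypothetical mean second derivatives (a symmetric Hessian summary)
hess_mean <- matrix(c( 0.8, -0.3,  0.1,
                      -0.3, -1.2,  0.4,
                       0.1,  0.4,  0.2),
                    nrow = 3, byrow = TRUE,
                    dimnames = list(inputs, inputs))
# Hypothetical mean first derivatives with respect to each input
sens_mean <- c(x1 = 0.5, x2 = -0.9, x3 = 0.05)

# senstype = "matrix": the second-derivative matrix is plotted as is
matrix_view <- hess_mean

# senstype = "interactions": cross derivatives stay off the diagonal,
# and the first derivatives replace the pure second derivatives on it
interactions_view <- hess_mean
diag(interactions_view) <- sens_mean[inputs]
```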

...

further arguments, analogous to the arguments of ggcorrplot().

Details

Most of the code of this function is based on the ggcorrplot() function from the ggcorrplot package. However, ggcorrplot() does not allow the limits of the color scale to be changed, and overriding its color scale afterwards keeps producing a warning; this is why the plotting code is reimplemented here rather than calling ggcorrplot() directly.
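The color-scale issue can be illustrated with plain ggplot2. A heatmap of sensitivities needs fill limits that follow the data, for instance symmetric around zero so that the sign is read from the hue, whereas a correlation plot's fill scale is designed for the fixed [-1, 1] range. The sketch below uses hypothetical values and sets the limits through ggplot2::scale_fill_gradient2(); it illustrates the idea only and is not the code used inside SensMatPlot().

```r
library(ggplot2)

# Hypothetical 3 x 3 matrix of mean second derivatives (illustrative values)
m <- matrix(c( 0.8, -0.3,  0.1,
              -0.3, -1.2,  0.4,
               0.1,  0.4,  0.2),
            nrow = 3, byrow = TRUE,
            dimnames = list(c("x1", "x2", "x3"), c("x1", "x2", "x3")))

# Long format: one row per (row variable, column variable) cell,
# following the column-major order of as.vector()
df <- data.frame(
  Var1  = factor(rep(rownames(m), times = ncol(m)), levels = rownames(m)),
  Var2  = factor(rep(colnames(m), each = nrow(m)), levels = colnames(m)),
  value = as.vector(m)
)

# Symmetric limits around zero, so equal magnitudes get equally intense colors
lim <- c(-1, 1) * max(abs(m))

p <- ggplot(df, aes(x = Var2, y = Var1, fill = value)) +
  geom_tile(color = "white") +
  scale_fill_gradient2(low = "blue", mid = "white", high = "red",
                       midpoint = 0, limits = lim) +
  labs(x = NULL, y = NULL, fill = "mean")
# print(p) draws the heatmap
```

Setting the limits when the fill scale is first created avoids adding a second fill scale to an existing plot, which is what makes ggplot2 report that it is replacing the existing scale.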

Value

a list of ggplots, one for each output neuron.

Examples

## Load data -------------------------------------------------------------------
data("DAILY_DEMAND_TR", package = "NeuralSens")
fdata <- DAILY_DEMAND_TR
## Parameters of the NNET ------------------------------------------------------
hidden_neurons <- 5
iters <- 100
decay <- 0.1

################################################################################
#########################  REGRESSION NNET #####################################
################################################################################
## Regression dataframe --------------------------------------------------------
# Drop the first column and rescale columns 1 and 3 to comparable magnitudes
fdata.Reg.tr <- fdata[,2:ncol(fdata)]
fdata.Reg.tr[,3] <- fdata.Reg.tr[,3]/10
fdata.Reg.tr[,1] <- fdata.Reg.tr[,1]/1000

# Normalize the data for some models
preProc <- caret::preProcess(fdata.Reg.tr, method = c("center","scale"))
nntrData <- predict(preProc, fdata.Reg.tr)

## Train nnet NNET -------------------------------------------------------------
# Create a formula to train NNET
form <- paste(names(fdata.Reg.tr)[2:ncol(fdata.Reg.tr)], collapse = " + ")
form <- formula(paste(names(fdata.Reg.tr)[1], form, sep = " ~ "))

set.seed(150)
nnetmod <- nnet::nnet(form,
                      data = nntrData,
                      linout = TRUE, # linear output unit for regression
                      size = hidden_neurons,
                      decay = decay,
                      maxit = iters)
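# Compute second-order sensitivities and plot the second-derivative matrix
H <- NeuralSens::HessianMLP(nnetmod, trData = nntrData, plot = FALSE)
NeuralSens::SensMatPlot(H)
# Compute first-order sensitivities and show them on the diagonal
S <- NeuralSens::SensAnalysisMLP(nnetmod, trData = nntrData, plot = FALSE)
NeuralSens::SensMatPlot(H, S, senstype = "interactions")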

NeuralSens documentation built on June 22, 2024, 12:06 p.m.