R/PlotSensMLP.R

Defines functions PlotSensMLP

Documented in PlotSensMLP

#' Neural network structure sensitivity plot
#'
#' @description Plot a neural interpretation diagram colored by sensitivities
#' of the model
#' @param MLP.fit fitted neural network model
#' @param metric metric to plot in the NID. It can be "mean" (default), "median or "sqmean".
#' It can be any metric to combine the raw sensitivities
#' @param sens_neg_col \code{character} string indicating color of negative sensitivity
#'  measure, default 'red'. The same is passed to argument \code{neg_col} of
#'  \link[NeuralNetTools:plotnet]{plotnet}
#' @param sens_pos_col \code{character} string indicating color of positive sensitivity
#'  measure, default 'blue'. The same is passed to argument \code{pos_col} of
#'  \link[NeuralNetTools:plotnet]{plotnet}
#' @param ...	additional arguments passed to \link[NeuralNetTools:plotnet]{plotnet} and/or
#' \link[NeuralSens:SensAnalysisMLP]{SensAnalysisMLP}
#' @return A graphics object
#' @examples
#' ## Load data -------------------------------------------------------------------
#' data("DAILY_DEMAND_TR")
#' fdata <- DAILY_DEMAND_TR
#' ## Parameters of the NNET ------------------------------------------------------
#' hidden_neurons <- 5
#' iters <- 100
#' decay <- 0.1
#'
#' ################################################################################
#' #########################  REGRESSION NNET #####################################
#' ################################################################################
#' ## Regression dataframe --------------------------------------------------------
#' # Scale the data
#' fdata.Reg.tr <- fdata[,2:ncol(fdata)]
#' fdata.Reg.tr[,3] <- fdata.Reg.tr[,3]/10
#' fdata.Reg.tr[,1] <- fdata.Reg.tr[,1]/1000
#'
#' # Normalize the data for some models
#' preProc <- caret::preProcess(fdata.Reg.tr, method = c("center","scale"))
#' nntrData <- predict(preProc, fdata.Reg.tr)
#'
#' #' ## TRAIN nnet NNET --------------------------------------------------------
#' # Create a formula to train NNET
#' form <- paste(names(fdata.Reg.tr)[2:ncol(fdata.Reg.tr)], collapse = " + ")
#' form <- formula(paste(names(fdata.Reg.tr)[1], form, sep = " ~ "))
#'
#' set.seed(150)
#' nnetmod <- nnet::nnet(form,
#'                       data = nntrData,
#'                       linear.output = TRUE,
#'                       size = hidden_neurons,
#'                       decay = decay,
#'                       maxit = iters)
#' # Try SensAnalysisMLP
#' NeuralSens::PlotSensMLP(nnetmod, trData = nntrData)
#' @export PlotSensMLP
PlotSensMLP <- function(MLP.fit, metric = "mean",
                        sens_neg_col = "red", sens_pos_col = "blue",
                        ...) {

  # First obtain all derivatives of the model
  Derivatives <- SensAnalysisMLP(MLP.fit, plot = FALSE, ..., return_all_sens = TRUE)
  # Pull apart derivatives of layers and weights
  d <- Derivatives[[1]]
  mlpstr <- Derivatives[[2]]
  wts <- Derivatives[[3]]
  # Stored the length of colors needed
  color_lengths <- sapply(d, function(x){dim(x)[1]})
  sens <- list()
  for (i in 1:length(color_lengths)) {
    der <- aperm(d[[i]], c(3,1,2))
    if(is.function(metric)) {
      der <- apply(der, c(1,2), metric)
    } else if (metric == "mean") {
      der <- apply(der, c(1,2), mean, na.rm = TRUE)
    } else if (metric == "median") {
      der <- apply(der, c(1,2), stats::median, na.rm = TRUE)
    } else if (metric == "sqmean") {
      der <- apply(der, c(1,2), function(x){mean(x^2, na.rm = TRUE)})
    } else {
      stop("metric must be a function to combine rows")
    }
    # Apply metric to calculate
    if(is.function(metric)) {
      sens[[i]] <- apply(der, 2, metric)
    } else if (metric == "mean") {
      sens[[i]] <- apply(der, 2, mean, na.rm = TRUE)
    } else if (metric == "median") {
      sens[[i]] <- apply(der, c(2), stats::median, na.rm = TRUE)
    } else if (metric == "sqmean") {
      sens[[i]] <- apply(der, 2, function(x){mean(x^2, na.rm = TRUE)})
    } else {
      stop("metric must be a function to calculate over a column of sensitivities")
    }
  }
  # Collapse all sensitivities
  sens <- do.call("c",sens)

  # Rescale the sensitivities in order to obtain the colors
  sens_scaled <- sign(sens) *
    round(scales::rescale(abs(sens),c(1,max(ceiling(1/min(abs(sens))),50)))) +
    max(ceiling(1/min(abs(sens))),50) + 1

  colPal <- grDevices::colorRampPalette(c(sens_neg_col, "white", sens_pos_col))
  senscolors <- colPal(max(sens_scaled) + 1)[round(sens_scaled)]
  senscolors_list <- list()
  color_lengths <- c(0,color_lengths)
  for (i in 2:length(color_lengths)) {
    senscolors_list[[i-1]] <- senscolors[(cumsum(color_lengths)[i-1]+1):cumsum(color_lengths)[i]]
  }

  # Plot Neural network NID and legend scale
  graphics::layout(matrix(1:2,nrow =1), widths = c(0.8,0.2))
  op <- graphics::par(mar=c(5.1,1.1,4.1,2.1))
  on.exit(graphics::par(op))
  NeuralNetTools::plotnet(wts, mlpstr,
                          circle_col = senscolors_list,
                          pos_col = sens_pos_col,
                          neg_col = sens_neg_col,
                          bord_col = "black",
                          bias = FALSE,
                          x_names = Derivatives[[4]],
                          y_names = Derivatives[[5]],
                          ...)
  # Substitute the output points to avoid changing its colors
  args <- list(...)
  out_pos <- 0.4
  if ("pad_x" %in% names(args)) {
    out_pos <- args$pad_x * 0.4
  }
  # Default values for arguments
  cex_val <- 1
  circle_cex <- 5
  node_labs <- TRUE
  bord_col <- "black"
  in_col <- "white"
  line_stag <- NULL
  var_labs <- TRUE
  max_sp <- FALSE
  layer_points_args <- c("cex_val", "circle_cex", "bord_col","node_labs","line_stag",
                         "var_labs","max_sp")
  # Check if any argument must be changed
  for (i in which(layer_points_args %in% names(args))) {
    eval(paste0(layer_points_args[i]," <- args$",layer_points_args[i]))
  }
  x_range <- c(-1,1)
  y_range <- c(0,1)
  # Substitute output values
  layer <- mlpstr[length(mlpstr)]

  x <- rep(out_pos * diff(x_range), layer)
  if(max_sp){
    spacing <- diff(c(0 * diff(y_range), 0.9 * diff(y_range)))/layer
  } else {
    spacing <- diff(c(0 * diff(y_range), 0.9 * diff(y_range)))/max(mlpstr)
  }

  y <- seq(0.5 * (diff(y_range) + spacing * (layer - 1)), 0.5 * (diff(y_range) - spacing * (layer - 1)),
             length = layer)

  graphics::points(x, y, pch = 21, cex = circle_cex, col = bord_col, bg = in_col)
  if(node_labs) graphics::text(x, y, paste('O', 1:layer, sep = ''), cex = cex_val)
  graphics::text(x + line_stag * diff(x_range), y,  Derivatives[[5]], pos = 4, cex = cex_val)

  # Create color scale to know the sensitivities in the graph
  xl <- 1
  yb <- 1
  xr <- 1.5
  yt <- 2

  graphics::par(mar=c(5.1,0.5,4.1,0.5))
  on.exit(graphics::par(op))
  needed_colors <- sum(mlpstr[1:(length(mlpstr)-1)])
  graphics::plot(NA,type="n",ann=FALSE,xlim=c(1,2),ylim=c(1,2),xaxt="n",yaxt="n",bty="n")
  graphics::rect(
    xl,
    utils::head(seq(yb,yt,(yt-yb)/needed_colors),-1),
    xr,
    utils::tail(seq(yb,yt,(yt-yb)/needed_colors),-1),
    col=stats::na.omit(do.call("c",senscolors_list[1:(length(mlpstr)-1)])[order(sens[1:needed_colors])])
  )
  graphics::text(x = 1.25, y = rowMeans(cbind(utils::head(seq(yb,yt,(yt-yb)/needed_colors),-1),
                                              utils::tail(seq(yb,yt,(yt-yb)/needed_colors),-1))),
       label = c(paste0("I",1:mlpstr[1]),
                 paste0("H",1:sum(mlpstr[2:(length(mlpstr)-1)])))[order(sens[1:needed_colors])])
  graphics::mtext(round(sens[1:needed_colors],
                        digits = ifelse(any(abs(sens[1:needed_colors]) < 1),
                                        -floor(log10(min(abs(sens[1:needed_colors]))))+1,
                                        2))[order(sens[1:needed_colors])],
                  at=utils::tail(seq(yb,yt,(yt-yb)/sum(mlpstr[1:(length(mlpstr)-1)])),-1)-0.05,
                  side=2, las=2, cex=0.7)
  reset.graphics <- function(oldpar) {
    graphics::par(oldpar)
    graphics::layout(matrix(1,nrow =1), widths = 1)
  }
  on.exit(reset.graphics(op))
}

Try the NeuralSens package in your browser

Any scripts or data that you put into this service are public.

NeuralSens documentation built on June 22, 2024, 12:06 p.m.