R/plotNN.R

Defines functions plot.nnet

Documented in plot.nnet

# ------------------------------------------------------------------------------
# this functio is found in the web 
#  require(scales)
# require(reshape)
#-------------------------------------------------------------------------------
plot.nnet <- function(x,      # neural network object or numeric vector of weights, if model
                                   # object must be from nnet, mlp, or neuralnet functions
                      nid = TRUE,  # logical value indicating if neural interpretation diagram is            
                                   # plotted, default T
                  all.out = TRUE,  # character string indicating names of response variables for which 
                                   # connections are plotted, default all
                   all.in = TRUE,  # character string indicating names of input variables for which 
                                   # connections are plotted, default all
                     bias = TRUE,  # logical value indicating if bias nodes and connections are  
                                   # plotted, not applicable for networks from mlp function, default T
                 wts.only = FALSE, # logical value indicating if connections weights are returned
                                   # rather than a plot, default F
                  rel.rsc = 5,     # numeric value indicating maximum width of connection lines,
                                   # default 5       
               circle.cex = 5,     # numeric value indicating size of nodes, passed to cex argument, 
                                   # default 5
                node.labs = TRUE,  # logical value indicating if text labels are plotted, default T
                 var.labs = TRUE , # logical value indicating if variable names are plotted next to 
                                   # nodes, default T
                    x.lab = NULL,  # character string indicating names for input variables, default
                                   # from model object
                    y.lab = NULL,  #  character string indicating names for output variables, default 
                                   # from model object
                line.stag = NULL,  # numeric value that specifies distance of connection weights from 
                                   # nodes
                   struct = NULL,  # numeric value of length three indicating network architecture(no.   
                                   # nodes for input, hidden, output), required only if mod.in is a 
                                   # numeric vector
                  cex.val = 1,     # numeric value indicating size of text labels, default 1
                alpha.val = 1,     # numeric value (0-1) indicating transparency of connections, 
                                   # default 1  
               circle.col = 'lightblue', #text value indicating color of nodes, default "lightgray"
                  pos.col = 'black',# text value indicating color of positive connection weights,  
                                    # default "black"
                  neg.col = 'grey', #text value indicating color of negative connection weights, 
                                   # default "gray"
               max.sp = FALSE,     #
               ...)
{
#-------------------------------------------------------------------------------  
#-------------------------------------------------------------------------------
#  local functions 
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
  #gets weights for neural network, output is list
  #if rescaled argument is true, weights are returned but rescaled based on abs value
  nnet.vals<-function(mod.in,nid,rel.rsc,struct.out=struct)
  {
#    require(scales)
    #    require(reshape)
    if('numeric' %in% class(mod.in))
    {
      struct.out <- struct
      wts <- mod.in
    }
    #neuralnet package----------------------------------------------------------
    if('nn' %in% class(mod.in))
    {
      struct.out<-unlist(lapply(mod.in$weights[[1]],ncol))
      struct.out<-struct.out[-length(struct.out)]
      struct.out<-c(
        length(mod.in$model.list$variables),
        struct.out,
        length(mod.in$model.list$response)
      )    		
      wts<-unlist(mod.in$weights[[1]])   
    }
    #nnet package----------------------------------------------------------------
    if('nnet' %in% class(mod.in))
    {
      struct.out<-mod.in$n
      wts<-mod.in$wts
    }  
    #RSNNS package--------------------------------------------------------------
#     if('mlp' %in% class(mod.in))
#     {
#       struct.out<-c(mod.in$nInputs,mod.in$archParams$size,mod.in$nOutputs)
#       hid.num<-length(struct.out)-2
#       wts<-mod.in$snnsObject$getCompleteWeightMatrix()
#       
#       #get all input-hidden and hidden-hidden wts-------------------------------
#       inps<-wts[grep('Input',row.names(wts)),grep('Hidden_2',colnames(wts)),drop=F]
#       inps<-melt(rbind(rep(NA,ncol(inps)),inps))$value
#       uni.hids<-paste0('Hidden_',1+seq(1,hid.num))
#       for(i in 1:length(uni.hids))
#       {
#         if(is.na(uni.hids[i+1])) break
#         tmp<-wts[grep(uni.hids[i],rownames(wts)),grep(uni.hids[i+1],colnames(wts)),drop=F]
#         inps<-c(inps,melt(rbind(rep(NA,ncol(tmp)),tmp))$value)
#       }
#       
#       #get connections from last hidden to output layers------------------------
#       outs<-wts[grep(paste0('Hidden_',hid.num+1),row.names(wts)),grep('Output',colnames(wts)),drop=F]
#       outs<-rbind(rep(NA,ncol(outs)),outs)
#       
#       #weight vector for all
#       wts<-c(inps,melt(outs)$value)
#       assign('bias',F,envir=environment(nnet.vals))
#    } # end RSNNS package ---------------------------------------------------  
    if(nid) wts<-rescale(abs(wts),c(1,rel.rsc))   
    #convert wts to list with appropriate names 
    hid.struct <- struct.out[-c(length(struct.out))]
    row.nms<-NULL
    for(i in 1:length(hid.struct))
    {
      if(is.na(hid.struct[i+1])) break
      row.nms<-c(row.nms,rep(paste('hidden',i,seq(1:hid.struct[i+1])),each=1+hid.struct[i]))
    }
    row.nms<-c(
      row.nms,
      rep(paste('out',seq(1:struct.out[length(struct.out)])),each=1+struct.out[length(struct.out)-1])
    )
    out.ls <- data.frame(wts,row.nms)
    out.ls$row.nms <- factor(row.nms,levels=unique(row.nms),labels=unique(row.nms))
    out.ls <-split(out.ls$wts,f=out.ls$row.nms)
    assign('struct',struct.out,envir=environment(nnet.vals))
    out.ls    
  }# end nnet.vals() ----------------------------------------------------------- 
#-------------------------------------------------------------------------------
# this function is from package scale
#-------------------------------------------------------------------------------
  alpha <- function (colour, alpha = NA) 
  {
    col <- col2rgb(colour, TRUE)/255
    if (length(colour) != length(alpha)) {
      if (length(colour) > 1 && length(alpha) > 1) {
        stop("Only one of colour and alpha can be vectorised")
      }
      if (length(colour) > 1) {
        alpha <- rep(alpha, length.out = length(colour))
      }
      else if (length(alpha) > 1) {
        col <- col[, rep(1, length(alpha)), drop = FALSE]
      }
    }
    alpha[is.na(alpha)] <- col[4, ][is.na(alpha)]
    new_col <- rgb(col[1, ], col[2, ], col[3, ], alpha)
    new_col[is.na(colour)] <- NA
    new_col
  }
#-------------------------------------------------------------------------------
rescale <- function (x, to = c(0, 1), from = range(x, na.rm = TRUE)) 
{
  if (zero_range(from) || zero_range(to)) 
    return(rep(mean(to), length(x)))
  (x - from[1])/diff(from) * diff(to) + to[1]
}
#-------------------------------------------------------------------------------
zero_range <- function (x, tol = .Machine$double.eps * 100) 
{
  if (length(x) == 1) 
    return(TRUE)
  if (length(x) != 2) 
    stop("x must be length 1 or 2")
  if (any(is.na(x))) 
    return(NA)
  if (x[1] == x[2]) 
    return(TRUE)
  if (all(is.infinite(x))) 
    return(FALSE)
  m <- min(abs(x))
  if (m == 0) 
    return(FALSE)
  abs((x[1] - x[2])/m) < tol
}
#-------------------------------------------------------------------------------  
bias.lines<-function(bias.x,mod.in,nid,rel.rsc,all.out,pos.col,neg.col,...)
{
  
  if(is.logical(all.out)) all.out<-1:struct[length(struct)]
  else all.out<-which(y.names==all.out)
  
  for(val in 1:length(bias.x)){
    
    wts<-nnet.vals(mod.in,nid=F,rel.rsc)
    wts.rs<-nnet.vals(mod.in,nid=T,rel.rsc)
    
    if(val != length(bias.x)){
      wts<-wts[grep('out',names(wts),invert=T)]
      wts.rs<-wts.rs[grep('out',names(wts.rs),invert=T)]
      sel.val<-grep(val,substr(names(wts.rs),8,8))
      wts<-wts[sel.val]
      wts.rs<-wts.rs[sel.val]
    }
    
    else{
      wts<-wts[grep('out',names(wts))]
      wts.rs<-wts.rs[grep('out',names(wts.rs))]
    }
    
    cols<-rep(pos.col,length(wts))
    cols[unlist(lapply(wts,function(x) x[1]))<0]<-neg.col
    wts.rs<-unlist(lapply(wts.rs,function(x) x[1]))
    
    if(nid==F){
      wts.rs<-rep(1,struct[val+1])
      cols<-rep('black',struct[val+1])
    }
    
    if(val != length(bias.x)){
      segments(
        rep(diff(x.range)*bias.x[val]+diff(x.range)*line.stag,struct[val+1]),
        rep(bias.y*diff(y.range),struct[val+1]),
        rep(diff(x.range)*layer.x[val+1]-diff(x.range)*line.stag,struct[val+1]),
        get.ys(struct[val+1]),
        lwd=wts.rs,
        col=cols
      )
    } 
    else{
      segments(
        rep(diff(x.range)*bias.x[val]+diff(x.range)*line.stag,struct[val+1]),
        rep(bias.y*diff(y.range),struct[val+1]),
        rep(diff(x.range)*layer.x[val+1]-diff(x.range)*line.stag,struct[val+1]),
        get.ys(struct[val+1])[all.out],
        lwd=wts.rs[all.out],
        col=cols[all.out]
      )
    }
  }
} # end of   bias.lines
#------------------------------------------------------------------------------- 
#-------------------------------------------------------------------------------
#-- MAIN function -------------------------------------------------------------- 
#-------------------------------------------------------------------------------
# require(scales)
#sanity checks-----------------------------------------------------------------
    mod.in <- x # mod.in is the original name
  if('mlp' %in% class(mod.in)) warning('Bias layer not applicable for rsnns object')
  if('numeric' %in% class(mod.in)){
    if(is.null(struct)) stop('Three-element vector required for struct')
    if(length(mod.in) != ((struct[1]*struct[2]+struct[2]*struct[3])+(struct[3]+struct[2])))
      stop('Incorrect length of weight matrix for given network structure')
  }
  if('train' %in% class(mod.in)){
    if('nnet' %in% class(mod.in$finalModel)){
      mod.in<-mod.in$finalModel
      warning('Using best nnet model from train output')
    }
    else stop('Only nnet method can be used with train object')
  }  
#-------------------------------------------------------------------------------
# get the weights
   wts <- nnet.vals(mod.in,nid=F)
  if(wts.only) return(wts)
  
  #circle colors for input, if desired, must be two-vector list, first vector is for input layer
  if(is.list(circle.col))
       {
          circle.col.inp<-circle.col[[1]]
          circle.col<-circle.col[[2]]
        }
  else circle.col.inp <- circle.col
  
  #initiate plotting
  x.range<-c(0,100)
  y.range<-c(0,100)
  #these are all proportions from 0-1
  if(is.null(line.stag)) line.stag <- 0.011*circle.cex/2
   layer.x <- seq(0.17,0.9, length=length(struct))
    bias.x <- layer.x[-length(layer.x)]+diff(layer.x)/2
    bias.y <- 0.95
circle.cex <- circle.cex
  #get variable names from mod.in object
  #change to user input if supplied
  if('numeric' %in% class(mod.in)){
    x.names<-paste0(rep('X',struct[1]),seq(1:struct[1]))
    y.names<-paste0(rep('Y',struct[3]),seq(1:struct[3]))
  }
  if('mlp' %in% class(mod.in)){
    all.names<-mod.in$snnsObject$getUnitDefinitions()
    x.names<-all.names[grep('Input',all.names$unitName),'unitName']
    y.names<-all.names[grep('Output',all.names$unitName),'unitName']
  }
  if('nn' %in% class(mod.in)){
    x.names<-mod.in$model.list$variables
    y.names<-mod.in$model.list$respons
  }
  if('xNames' %in% names(mod.in)){
    x.names<-mod.in$xNames
    y.names<-attr(terms(mod.in),'factor')
    y.names<-row.names(y.names)[!row.names(y.names) %in% x.names]
  }
  if(!'xNames' %in% names(mod.in) & 'nnet' %in% class(mod.in)){
    if(is.null(mod.in$call$formula)){
      x.names<-colnames(eval(mod.in$call$x))
      y.names<-colnames(eval(mod.in$call$y))
    }
    else{
        forms <- eval(mod.in$call$formula)
      x.names <- mod.in$coefnames
        facts <- attr(terms(mod.in),'factors')
      y.check <- mod.in$fitted
      if(ncol(y.check)>1) y.names<-colnames(y.check)
      else y.names<-as.character(forms)[2]
    } 
  }
  #change variables names to user sub 
  if(!is.null(x.lab)){
    if(length(x.names) != length(x.lab)) stop('x.lab length not equal to number of input variables')
    else x.names<-x.lab
  }
  if(!is.null(y.lab)){
    if(length(y.names) != length(y.lab)) stop('y.lab length not equal to number of output variables')
    else y.names<-y.lab
  }
  #initiate plot
  plot(x.range, y.range ,type='n', axes=F,ylab='', xlab='',...)
#-------------------------------------------------------------------------------  
  #function for getting y locations for input, hidden, output layers
  #input is integer value from 'struct'
  get.ys<-function(lyr, max_space = max.sp)
    {
  	if(max_space)
      { 
  		spacing <- diff(c(0*diff(y.range),0.9*diff(y.range)))/lyr
   	  } else {
    	spacing<-diff(c(0*diff(y.range),0.9*diff(y.range)))/max(struct)
   	  }
  		seq(0.5*(diff(y.range)+spacing*(lyr-1)),0.5*(diff(y.range)-spacing*(lyr-1)),
        length=lyr)
     }
#-------------------------------------------------------------------------------  
  #function for plotting nodes
  #'layer' specifies which layer, integer from 'struct'
  #'x.loc' indicates x location for layer, integer from 'layer.x'
  #'layer.name' is string indicating text to put in node
  layer.points<-function(layer,x.loc,layer.name,cex=cex.val)
    {
    x<-rep(x.loc*diff(x.range),layer)
    y<-get.ys(layer)
    points(x,y,pch=21,cex=circle.cex,col=in.col,bg=bord.col)
    if(node.labs) text(x,y,paste(layer.name,1:layer,sep=''),cex=cex.val)
    if(layer.name=='I' & var.labs) text(x-line.stag*diff(x.range),y,x.names,pos=2,cex=cex.val)      
    if(layer.name=='O' & var.labs) text(x+line.stag*diff(x.range),y,y.names,pos=4,cex=cex.val)
    } # end layer.points
#-------------------------------------------------------------------------------  
  #function for plotting bias points
  #'bias.x' is vector of values for x locations
  #'bias.y' is vector for y location
  #'layer.name' is  string indicating text to put in node
  bias.points <- function(bias.x,bias.y,layer.name,cex,...)
    {
    for(val in 1:length(bias.x))
      {
      points(
        diff(x.range)*bias.x[val],
        bias.y*diff(y.range),
        pch=21,col=in.col,bg=bord.col,cex=circle.cex
      )
      if(node.labs)
        text(
          diff(x.range)*bias.x[val],
          bias.y*diff(y.range),
          paste(layer.name,val,sep=''),
          cex=cex.val
        )
      }
     } # end bias.points--------------------------------------------------------  
#-------------------------------------------------------------------------------
  #function creates lines colored by direction and width as proportion of magnitude
  #use 'all.in' argument if you want to plot connection lines for only a single input node
  layer.lines <- function(mod.in,h.layer,layer1=1,layer2=2,out.layer=F,nid,rel.rsc,all.in,pos.col, neg.col,...)
    {    
    x0<-rep(layer.x[layer1]*diff(x.range)+line.stag*diff(x.range),struct[layer1])
    x1<-rep(layer.x[layer2]*diff(x.range)-line.stag*diff(x.range),struct[layer1])   
    if(out.layer==T)
      {      
      y0<-get.ys(struct[layer1])
      y1<-rep(get.ys(struct[layer2])[h.layer],struct[layer1])
      src.str<-paste('out',h.layer)     
      wts<-nnet.vals(mod.in,nid=F,rel.rsc)
      wts<-wts[grep(src.str,names(wts))][[1]][-1]
      wts.rs<-nnet.vals(mod.in,nid=T,rel.rsc)
      wts.rs<-wts.rs[grep(src.str,names(wts.rs))][[1]][-1]    
      cols<-rep(pos.col,struct[layer1])
      cols[wts<0]<-neg.col      
      if(nid) segments(x0,y0,x1,y1,col=cols,lwd=wts.rs)
      else segments(x0,y0,x1,y1)      
      }
    else{
      if(is.logical(all.in)) all.in<-h.layer
      else all.in<-which(x.names==all.in)
      
      y0<-rep(get.ys(struct[layer1])[all.in],struct[2])
      y1<-get.ys(struct[layer2])
      src.str<-paste('hidden',layer1)
      
      wts<-nnet.vals(mod.in,nid=F,rel.rsc)
      wts<-unlist(lapply(wts[grep(src.str,names(wts))],function(x) x[all.in+1]))
      wts.rs<-nnet.vals(mod.in,nid=T,rel.rsc)
      wts.rs<-unlist(lapply(wts.rs[grep(src.str,names(wts.rs))],function(x) x[all.in+1]))
      
      cols<-rep(pos.col,struct[layer2])
      cols[wts<0]<-neg.col
      
      if(nid) segments(x0,y0,x1,y1,col=cols,lwd=wts.rs)
      else segments(x0,y0,x1,y1)     
    }  
  } # end of layer.lines 
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
  #use functions to plot connections between layers
  #bias lines
  if(bias) bias.lines(bias.x,mod.in,nid=nid,rel.rsc=rel.rsc,all.out=all.out,pos.col=alpha(pos.col,alpha.val),
                      neg.col=alpha(neg.col,alpha.val))
  
  #layer lines, makes use of arguments to plot all or for individual layers
  #starts with input-hidden
  #uses 'all.in' argument to plot connection lines for all input nodes or a single node
  if(is.logical(all.in)){  
    mapply(
      function(x) layer.lines(mod.in,x,layer1=1,layer2=2,nid=nid,rel.rsc=rel.rsc,
        all.in=all.in,pos.col=alpha(pos.col,alpha.val),neg.col=alpha(neg.col,alpha.val)),
      1:struct[1]
    )
  }
  else{
    node.in<-which(x.names==all.in)
    layer.lines(mod.in,node.in,layer1=1,layer2=2,nid=nid,rel.rsc=rel.rsc,all.in=all.in,
                pos.col=alpha(pos.col,alpha.val),neg.col=alpha(neg.col,alpha.val))
  }
  #connections between hidden layers
  lays<-split(c(1,rep(2:(length(struct)-1),each=2),length(struct)),
              f=rep(1:(length(struct)-1),each=2))
  lays<-lays[-c(1,(length(struct)-1))]
  for(lay in lays){
    for(node in 1:struct[lay[1]]){
      layer.lines(mod.in,node,layer1=lay[1],layer2=lay[2],nid=nid,rel.rsc=rel.rsc,all.in=T,
                  pos.col=alpha(pos.col,alpha.val),neg.col=alpha(neg.col,alpha.val))
    }
  }
  #lines for hidden-output
  #uses 'all.out' argument to plot connection lines for all output nodes or a single node
  if(is.logical(all.out))
    mapply(
      function(x) layer.lines(mod.in,x,layer1=length(struct)-1,layer2=length(struct),out.layer=T,nid=nid,rel.rsc=rel.rsc,
                              all.in=all.in,pos.col=alpha(pos.col,alpha.val),neg.col=alpha(neg.col,alpha.val)),
      1:struct[length(struct)]
      )
  else{
    node.in<-which(y.names==all.out)
    layer.lines(mod.in,node.in,layer1=length(struct)-1,layer2=length(struct),out.layer=T,nid=nid,rel.rsc=rel.rsc,
                pos.col=pos.col,neg.col=neg.col,all.out=all.out)
  }
#-------------------------------------------------------------------------------  
  #use functions to plot nodes
  for(i in 1:length(struct)){
    in.col<-bord.col<-circle.col
    layer.name<-'H'
    if(i==1) { layer.name<-'I'; in.col<-bord.col<-circle.col.inp}
    if(i==length(struct)) layer.name<-'O'
    layer.points(struct[i],layer.x[i],layer.name)
    }

  if(bias) bias.points(bias.x,bias.y,'B')
  
} # end of plot.nnet------------------------------------------------------------

Try the gamlss.add package in your browser

Any scripts or data that you put into this service are public.

gamlss.add documentation built on Feb. 4, 2020, 9:08 a.m.