#' Plot a neural interpretation diagram of a fitted neural network
#'
#' Slightly modified program from the beckmw gist (see note below). Draws the
#' input (I), hidden (H), output (O) and bias (B) nodes of a network with one
#' hidden layer, together with the connections between them.
#'
#' @param mod.in An \code{nnet} object, a \code{nn} object (neuralnet), an
#'   \code{mlp} object (RSNNS), or a numeric weight vector (which then
#'   requires \code{struct}).
#' @param nid Logical. If \code{TRUE}, connection widths follow the absolute
#'   weight (rescaled to \code{c(1, rel.rsc)}) and colours follow its sign;
#'   otherwise plain lines are drawn.
#' @param all.out Logical to draw connections to every output node, or the
#'   name of a single output variable.
#' @param all.in Logical to draw connections from every input node, or the
#'   name of a single input variable.
#' @param bias Logical; draw bias nodes and their connections. Always treated
#'   as \code{FALSE} for \code{mlp} objects.
#' @param wts.only Logical; if \code{TRUE}, return the weights without
#'   plotting.
#' @param rel.rsc Maximum line width used when \code{nid = TRUE}.
#' @param circle.cex Size of the node symbols.
#' @param node.labs Logical; label nodes (I1, H1, O1, B1, ...).
#' @param var.labs Logical; label input and output nodes with variable names.
#' @param x.lab,y.lab Optional character vectors replacing the input/output
#'   variable names; their lengths must equal the number of inputs/outputs.
#' @param line.stag Horizontal gap between node centres and line ends, as a
#'   fraction of the plot width. Defaults to \code{0.011 * circle.cex / 2}.
#' @param struct Three-element vector \code{c(inputs, hidden, outputs)};
#'   required when \code{mod.in} is a numeric weight vector.
#' @param cex.val Text size of the labels.
#' @param alpha.val Opacity applied to \code{pos.col} and \code{neg.col}.
#' @param circle.col Fill and border colour of the nodes.
#' @param pos.col,neg.col Line colours for positive and negative weights.
#' @param ... Further arguments passed to \code{\link[graphics]{plot}}.
#' @return If \code{wts.only = TRUE}, a named list with one numeric vector per
#'   hidden node (\code{"hidden 1"}, ...) followed by one per output node
#'   (\code{"out 1"}, ...); the first element of each vector is the bias
#'   weight (\code{NA} for \code{mlp} objects). Otherwise \code{invisible(NULL)};
#'   the function is called for its side effect of drawing on the current
#'   graphics device.
#' @note \url{https://tinyurl.com/tuyvbkk} is slightly broken, and this is a simple repair.
#' The original link is from \url{https://www.r-bloggers.com/visualizing-neural-networks-in-r-update/}.
#' @export
plot.nnet <- function(mod.in, nid = TRUE, all.out = TRUE, all.in = TRUE,
                      bias = TRUE, wts.only = FALSE, rel.rsc = 5,
                      circle.cex = 5, node.labs = TRUE, var.labs = TRUE,
                      x.lab = NULL, y.lab = NULL, line.stag = NULL,
                      struct = NULL, cex.val = 1, alpha.val = 1,
                      circle.col = "lightblue", pos.col = "black",
                      neg.col = "grey", ...) {
  # Model data referenced by the fitting call (nnet's x/y arguments) is
  # evaluated in the caller's scope, never among this function's locals.
  caller.env <- parent.frame()
  from.vec <- inherits(mod.in, "numeric")
  from.mlp <- inherits(mod.in, "mlp")

  if (from.mlp) {
    warning("Bias layer not applicable for rsnns object")
    # RSNNS weight matrices contain no bias units, so none are drawn.
    bias <- FALSE
  }
  if (from.vec) {
    if (is.null(struct)) {
      stop("Three-element vector required for struct")
    }
    n.expected <- struct[1] * struct[2] + struct[2] * struct[3] +
      struct[3] + struct[2]
    if (length(mod.in) != n.expected) {
      stop("Incorrect length of weight matrix for given network structure")
    }
  }

  # ---- Flatten the model into one weight vector plus its structure ----
  # struct is c(inputs, hidden, outputs). Everything below assumes raw.wts
  # holds, for each hidden node, a bias followed by one weight per input,
  # then, for each output node, a bias followed by one weight per hidden node.
  if (from.vec) {
    raw.wts <- mod.in
  } else if (inherits(mod.in, "nn")) {
    first.layer <- mod.in$weights[[1]][[1]]
    struct <- c(nrow(first.layer) - 1, ncol(first.layer),
                ncol(mod.in$response))
    raw.wts <- unlist(mod.in$weights[[1]])
  } else if (inherits(mod.in, "nnet")) {
    struct <- mod.in$n
    raw.wts <- mod.in$wts
  } else if (from.mlp) {
    all.wts <- mod.in$snnsObject$getCompleteWeightMatrix()
    in.hid <- all.wts[grep("Input", row.names(all.wts)),
                      grep("Hidden", colnames(all.wts)), drop = FALSE]
    hid.out <- all.wts[grep("Hidden", row.names(all.wts)),
                       grep("Output", colnames(all.wts)), drop = FALSE]
    # Prepend an NA placeholder "bias" row so each node's column matches the
    # layout of the other model types; as.vector() flattens column by column.
    raw.wts <- c(as.vector(rbind(NA, in.hid)), as.vector(rbind(NA, hid.out)))
    struct <- c(mod.in$nInputs, mod.in$archParams$size, mod.in$nOutputs)
  } else {
    stop("mod.in must be a numeric weight vector or an 'nnet', 'nn' or ",
         "'mlp' object", call. = FALSE)
  }

  # Split the flat weights into a named list: "hidden 1", ... then "out 1",
  # .... With nid = TRUE the absolute weights are rescaled to c(1, rel.rsc)
  # so they can be used directly as line widths.
  nnet.vals <- function(nid) {
    wts <- if (nid) scales::rescale(abs(raw.wts), c(1, rel.rsc)) else raw.wts
    n.hid.wts <- struct[1] * struct[2] + struct[2]
    hid.idx <- matrix(seq_len(n.hid.wts), ncol = struct[2])
    out.idx <- matrix(seq(n.hid.wts + 1, length(wts)), ncol = struct[3])
    hidden <- lapply(seq_len(ncol(hid.idx)), function(i) wts[hid.idx[, i]])
    outputs <- lapply(seq_len(ncol(out.idx)), function(i) wts[out.idx[, i]])
    names(hidden) <- paste("hidden", seq_along(hidden))
    names(outputs) <- paste("out", seq_along(outputs))
    c(hidden, outputs)
  }

  if (wts.only) {
    return(nnet.vals(nid = FALSE))
  }

  if (!requireNamespace("scales", quietly = TRUE)) {
    stop("Package 'scales' is required to plot the network", call. = FALSE)
  }
  # Extracted once and reused by every line-drawing step below.
  wts.raw <- nnet.vals(nid = FALSE)
  wts.lwd <- nnet.vals(nid = TRUE)
  pos.line <- scales::alpha(pos.col, alpha.val)
  neg.line <- scales::alpha(neg.col, alpha.val)

  # ---- Plot geometry ----
  x.range <- c(0, 100)
  y.range <- c(0, 100)
  x.span <- diff(x.range)
  y.span <- diff(y.range)
  if (is.null(line.stag)) {
    line.stag <- 0.011 * circle.cex / 2
  }
  layer.x <- seq(0.17, 0.9, length.out = 3)
  bias.x <- c(mean(layer.x[1:2]), mean(layer.x[2:3]))
  bias.y <- 0.95
  in.col <- bord.col <- circle.col

  # ---- Variable names for the input and output nodes ----
  if (from.vec) {
    x.names <- paste0("X", seq_len(struct[1]))
    y.names <- paste0("Y", seq_len(struct[3]))
  } else if (from.mlp) {
    unit.defs <- mod.in$snnsObject$getUnitDefinitions()
    x.names <- unit.defs[grep("Input", unit.defs$unitName), "unitName"]
    y.names <- unit.defs[grep("Output", unit.defs$unitName), "unitName"]
  } else if (is.null(mod.in$call$formula)) {
    x.names <- colnames(eval(mod.in$call$x, caller.env))
    y.names <- colnames(eval(mod.in$call$y, caller.env))
  } else {
    facts <- attr(terms(formula(mod.in)), "factors")
    x.names <- colnames(facts)
    y.check <- if (inherits(mod.in, "nn")) {
      mod.in$response
    } else {
      mod.in$fitted.values
    }
    y.names <- if (ncol(y.check) > 1) {
      colnames(y.check)
    } else {
      row.names(facts)[!row.names(facts) %in% x.names]
    }
  }
  if (!is.null(x.lab)) {
    if (length(x.names) != length(x.lab)) {
      stop("x.lab length not equal to number of input variables")
    }
    x.names <- x.lab
  }
  if (!is.null(y.lab)) {
    if (length(y.names) != length(y.lab)) {
      stop("y.lab length not equal to number of output variables")
    }
    y.names <- y.lab
  }

  # Node indices whose connections are drawn (all, or one named variable).
  in.nodes <- if (is.logical(all.in)) {
    seq_len(struct[1])
  } else {
    which(x.names == all.in)
  }
  out.nodes <- if (is.logical(all.out)) {
    seq_len(struct[3])
  } else {
    which(y.names == all.out)
  }
  if (length(in.nodes) == 0) {
    stop("all.in must be logical or the name of an input variable",
         call. = FALSE)
  }
  if (length(out.nodes) == 0) {
    stop("all.out must be logical or the name of an output variable",
         call. = FALSE)
  }

  plot(x.range, y.range, type = "n", axes = FALSE, ylab = "", xlab = "", ...)

  # Vertical positions of a layer of `lyr` nodes, listed top to bottom,
  # centred on the plot and spaced 0.9 * height / max(struct) apart.
  get.ys <- function(lyr) {
    spacing <- 0.9 * y.span / max(struct)
    seq(0.5 * (y.span + spacing * (lyr - 1)),
        0.5 * (y.span - spacing * (lyr - 1)), length.out = lyr)
  }

  # First element (the bias term) of every node's weight vector.
  first.of <- function(wts) vapply(wts, function(w) w[1], numeric(1))

  # Draw one segment per weight in `w`. With nid = TRUE the colour follows
  # the weight's sign and the width comes from `w.lwd`. All coordinate
  # vectors have length(w) entries, so no segment is drawn twice.
  draw.links <- function(x0, y0, x1, y1, w, w.lwd) {
    if (nid) {
      cols <- rep(pos.line, length(w))
      cols[w < 0] <- neg.line
      segments(x0, y0, x1, y1, col = cols, lwd = w.lwd)
    } else {
      segments(x0, y0, x1, y1)
    }
  }

  # ---- Bias connections (drawn first, underneath everything else) ----
  if (bias) {
    is.out <- startsWith(names(wts.raw), "out")
    for (b in 1:2) {
      sel <- if (b == 1) !is.out else is.out
      n.nodes <- struct[b + 1]
      if (nid) {
        b.wts <- first.of(wts.raw[sel])
        b.lwd <- first.of(wts.lwd[sel])
        b.col <- rep(pos.line, length(b.wts))
        b.col[b.wts < 0] <- neg.line
      } else {
        b.lwd <- rep(1, n.nodes)
        b.col <- rep("black", n.nodes)
      }
      # The hidden-layer bias feeds every hidden node; the output-layer bias
      # feeds only the selected outputs.
      targets <- if (b == 1) seq_len(n.nodes) else out.nodes
      n.draw <- length(targets)
      segments(rep(x.span * bias.x[b] + x.span * line.stag, n.draw),
               rep(bias.y * y.span, n.draw),
               rep(x.span * layer.x[b + 1] - x.span * line.stag, n.draw),
               get.ys(n.nodes)[targets],
               lwd = b.lwd[targets], col = b.col[targets])
    }
  }

  # ---- Input -> hidden connections ----
  # Element j + 1 of each hidden node's weights belongs to input j
  # (element 1 is the bias).
  hid.sel <- startsWith(names(wts.raw), "hidden")
  for (j in in.nodes) {
    draw.links(
      x0 = rep(layer.x[1] * x.span + line.stag * x.span, struct[2]),
      y0 = rep(get.ys(struct[1])[j], struct[2]),
      x1 = rep(layer.x[2] * x.span - line.stag * x.span, struct[2]),
      y1 = get.ys(struct[2]),
      w = vapply(wts.raw[hid.sel], function(w) w[j + 1], numeric(1)),
      w.lwd = vapply(wts.lwd[hid.sel], function(w) w[j + 1], numeric(1))
    )
  }

  # ---- Hidden -> output connections (bias element dropped) ----
  for (k in out.nodes) {
    out.name <- paste("out", k)
    draw.links(
      x0 = rep(layer.x[2] * x.span + line.stag * x.span, struct[2]),
      y0 = get.ys(struct[2]),
      x1 = rep(layer.x[3] * x.span - line.stag * x.span, struct[2]),
      y1 = rep(get.ys(struct[3])[k], struct[2]),
      w = wts.raw[[out.name]][-1],
      w.lwd = wts.lwd[[out.name]][-1]
    )
  }

  # ---- Nodes and labels (drawn last so they sit on top of the lines) ----
  layer.points <- function(n.nodes, x.loc, layer.name) {
    x <- rep(x.loc * x.span, n.nodes)
    y <- get.ys(n.nodes)
    points(x, y, pch = 21, cex = circle.cex, col = in.col, bg = bord.col)
    if (node.labs) {
      text(x, y, paste0(layer.name, seq_len(n.nodes)), cex = cex.val)
    }
    if (var.labs && layer.name == "I") {
      text(x - line.stag * x.span, y, x.names, pos = 2, cex = cex.val)
    }
    if (var.labs && layer.name == "O") {
      text(x + line.stag * x.span, y, y.names, pos = 4, cex = cex.val)
    }
  }
  layer.points(struct[1], layer.x[1], "I")
  layer.points(struct[2], layer.x[2], "H")
  layer.points(struct[3], layer.x[3], "O")

  if (bias) {
    for (b in seq_along(bias.x)) {
      bx <- x.span * bias.x[b]
      by <- bias.y * y.span
      points(bx, by, pch = 21, col = in.col, bg = bord.col, cex = circle.cex)
      if (node.labs) {
        text(bx, by, paste0("B", b), cex = cex.val)
      }
    }
  }
  invisible(NULL)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.