# R/fixpl.nnet.R

#' Plot a single-hidden-layer neural network
#'
#' Slightly modified program from the beckmw gist (see note below). Input (I),
#' hidden (H), output (O) and bias (B) nodes are drawn as circles; when
#' \code{nid = TRUE}, connection line widths scale with the absolute weight
#' and line colour encodes the weight's sign.
#'
#' @param mod.in a fitted \code{nnet} object. \code{nn} (neuralnet) and
#'   \code{mlp} (RSNNS) objects are also handled, as is a plain numeric weight
#'   vector supplied together with \code{struct}.
#' @param nid logical; if \code{TRUE}, line width and colour reflect weight
#'   magnitude and sign, otherwise plain default lines are drawn.
#' @param all.out \code{TRUE} to draw connections to every output node, or the
#'   name of one output variable to draw only its connections.
#' @param all.in \code{TRUE} to draw connections from every input node, or the
#'   name of one input variable to draw only its connections.
#' @param bias logical; draw bias nodes and their connections. Forced to
#'   \code{FALSE} for \code{mlp} objects.
#' @param wts.only logical; if \code{TRUE}, return the weights instead of
#'   plotting.
#' @param rel.rsc maximum line width; absolute weights are rescaled to
#'   \code{c(1, rel.rsc)}.
#' @param circle.cex size of the node circles.
#' @param node.labs logical; label the nodes (\code{I1}, \code{H1}, \code{O1},
#'   \code{B1}, ...).
#' @param var.labs logical; print the input/output variable names beside the
#'   input and output nodes.
#' @param x.lab,y.lab optional replacement names for the input/output
#'   variables; lengths must equal the number of inputs/outputs.
#' @param line.stag horizontal gap between a node centre and its line ends, as
#'   a fraction of the x range; defaults to \code{0.011 * circle.cex / 2}.
#' @param struct three-element vector (number of inputs, hidden nodes,
#'   outputs); required when \code{mod.in} is a numeric weight vector.
#' @param cex.val text size for node and variable labels.
#' @param alpha.val opacity (alpha channel) applied to the line colours.
#' @param circle.col fill and border colour of the node circles.
#' @param pos.col,neg.col line colours for positive and negative weights.
#' @param ... further arguments passed to \code{plot()} when the empty canvas
#'   is created.
#' @return If \code{wts.only = TRUE}, a named list with one weight vector per
#'   hidden node (\code{"hidden 1"}, ...) and output node (\code{"out 1"},
#'   ...), whose first element is that node's bias weight (\code{NA} for
#'   \code{mlp} objects). Otherwise \code{NULL}, invisibly; the plot is drawn
#'   as a side effect.
#' @note \url{https://tinyurl.com/tuyvbkk} is slightly broken, and this is a
#'   simple repair. The original link is from
#'   \url{https://www.r-bloggers.com/visualizing-neural-networks-in-r-update/}.
#' @export
plot.nnet <- function(mod.in, nid = TRUE, all.out = TRUE, all.in = TRUE,
                      bias = TRUE, wts.only = FALSE, rel.rsc = 5,
                      circle.cex = 5, node.labs = TRUE, var.labs = TRUE,
                      x.lab = NULL, y.lab = NULL, line.stag = NULL,
                      struct = NULL, cex.val = 1, alpha.val = 1,
                      circle.col = "lightblue", pos.col = "black",
                      neg.col = "grey", ...) {
  # This call's frame: nnet.vals() stores the resolved `struct` (and, for
  # mlp objects, `bias = FALSE`) here so every helper below sees them.
  plot.env <- environment()
  # The model call's `x`/`y` symbols belong to the caller; evaluating them in
  # this frame would let locals such as `wts` or `struct` shadow them.
  caller.env <- parent.frame()

  if (inherits(mod.in, "mlp")) {
    warning("Bias layer not applicable for rsnns object")
  }
  if (inherits(mod.in, "numeric")) {
    if (is.null(struct) || length(struct) != 3) {
      stop("Three-element vector required for struct")
    }
    n.expected <- (struct[1] * struct[2] + struct[2] * struct[3]) +
      (struct[3] + struct[2])
    if (length(mod.in) != n.expected) {
      stop("Incorrect length of weight matrix for given network structure")
    }
  }

  # Split the flat weight vector into a named list with one element per
  # hidden node ("hidden i") and output node ("out i"); each element holds
  # the node's bias weight followed by the weights from the preceding layer.
  # With nid = TRUE the absolute weights are rescaled to c(1, rel.rsc) for
  # use as line widths.
  nnet.vals <- function(mod.in, nid, rel.rsc, struct.out = struct) {
    if (inherits(mod.in, "numeric")) {
      wts <- mod.in
      struct.out <- struct
    }
    if (inherits(mod.in, "nn")) {
      first.layer <- mod.in$weights[[1]][[1]]
      struct.out <- c(nrow(first.layer) - 1, ncol(first.layer),
                      ncol(mod.in$response))
      wts <- unlist(mod.in$weights[[1]])
    }
    if (inherits(mod.in, "nnet")) {
      struct.out <- mod.in$n
      wts <- mod.in$wts
    }
    if (inherits(mod.in, "mlp")) {
      wts <- mod.in$snnsObject$getCompleteWeightMatrix()
      inps <- wts[grep("Input", row.names(wts)),
                  grep("Hidden", colnames(wts)), drop = FALSE]
      outs <- wts[grep("Hidden", row.names(wts)),
                  grep("Output", colnames(wts)), drop = FALSE]
      # An NA row stands in for each node's bias weight; flatten column-major
      # (the same order reshape::melt(...)$value produced).
      inps <- rbind(rep(NA, ncol(inps)), inps)
      outs <- rbind(rep(NA, ncol(outs)), outs)
      wts <- c(as.vector(inps), as.vector(outs))
      struct.out <- c(mod.in$nInputs, mod.in$archParams$size,
                      mod.in$nOutputs)
      assign("bias", FALSE, envir = plot.env)
    }
    if (nid) {
      wts <- scales::rescale(abs(wts), c(1, rel.rsc))
    }
    # Hidden nodes: one column per node, (inputs + 1) weights each.
    indices <- matrix(seq_len(struct.out[1] * struct.out[2] + struct.out[2]),
                      ncol = struct.out[2])
    out.ls <- list()
    for (i in seq_len(ncol(indices))) {
      out.ls[[paste("hidden", i)]] <- wts[indices[, i]]
    }
    # Output nodes share the remaining weights, one column per node.
    if (struct.out[3] == 1) {
      out.ls[["out 1"]] <- wts[(max(indices) + 1):length(wts)]
    } else {
      out.indices <- matrix(seq(max(indices) + 1, length(wts)),
                            ncol = struct.out[3])
      for (i in seq_len(ncol(out.indices))) {
        out.ls[[paste("out", i)]] <- wts[out.indices[, i]]
      }
    }
    assign("struct", struct.out, envir = plot.env)
    out.ls
  }

  wts <- nnet.vals(mod.in, nid = FALSE)
  if (wts.only) {
    return(wts)
  }

  x.range <- c(0, 100)
  y.range <- c(0, 100)
  if (is.null(line.stag)) {
    line.stag <- 0.011 * circle.cex / 2
  }
  layer.x <- seq(0.17, 0.9, length.out = 3)
  bias.x <- c(mean(layer.x[1:2]), mean(layer.x[2:3]))
  bias.y <- 0.95
  in.col <- bord.col <- circle.col

  # Resolve input/output variable names for the labels.
  if (inherits(mod.in, "numeric")) {
    x.names <- paste0("X", seq_len(struct[1]))
    y.names <- paste0("Y", seq_len(struct[3]))
  } else if (inherits(mod.in, "mlp")) {
    all.names <- mod.in$snnsObject$getUnitDefinitions()
    x.names <- all.names[grep("Input", all.names$unitName), "unitName"]
    y.names <- all.names[grep("Output", all.names$unitName), "unitName"]
  } else if (is.null(mod.in$call$formula)) {
    x.names <- colnames(eval(mod.in$call$x, caller.env))
    y.names <- colnames(eval(mod.in$call$y, caller.env))
  } else {
    facts <- attr(terms(formula(mod.in)), "factors")
    x.names <- colnames(facts)
    # `fitted.values` spelled out: `$fitted` relied on partial matching.
    y.check <- if (inherits(mod.in, "nn")) {
      mod.in$response
    } else {
      mod.in$fitted.values
    }
    if (ncol(y.check) > 1) {
      y.names <- colnames(y.check)
    } else {
      y.names <- row.names(facts)[!row.names(facts) %in% x.names]
    }
  }
  if (!is.null(x.lab)) {
    if (length(x.names) != length(x.lab)) {
      stop("x.lab length not equal to number of input variables")
    }
    x.names <- x.lab
  }
  if (!is.null(y.lab)) {
    if (length(y.names) != length(y.lab)) {
      stop("y.lab length not equal to number of output variables")
    }
    y.names <- y.lab
  }
  # Fail early with a clear message instead of an obscure segments() error.
  if (!is.logical(all.in) &&
      (length(all.in) != 1 || !all.in %in% x.names)) {
    stop("all.in must be TRUE/FALSE or the name of one input variable")
  }
  if (!is.logical(all.out) &&
      (length(all.out) != 1 || !all.out %in% y.names)) {
    stop("all.out must be TRUE/FALSE or the name of one output variable")
  }

  plot(x.range, y.range, type = "n", axes = FALSE, ylab = "", xlab = "",
       ...)

  # Vertical node centres for a layer of `n.nodes` nodes, centred on the plot;
  # node spacing is 90% of the y range divided by the largest layer size.
  get.ys <- function(n.nodes) {
    spacing <- 0.9 * diff(y.range) / max(struct)
    span <- spacing * (n.nodes - 1)
    seq(0.5 * (diff(y.range) + span), 0.5 * (diff(y.range) - span),
        length.out = n.nodes)
  }

  # Draw one layer of nodes, with node labels and (for the input and output
  # layers) variable labels.
  layer.points <- function(layer, x.loc, layer.name) {
    x <- rep(x.loc * diff(x.range), layer)
    y <- get.ys(layer)
    points(x, y, pch = 21, cex = circle.cex, col = in.col, bg = bord.col)
    if (node.labs) {
      text(x, y, paste0(layer.name, seq_len(layer)), cex = cex.val)
    }
    if (layer.name == "I" && var.labs) {
      text(x - line.stag * diff(x.range), y, x.names, pos = 2, cex = cex.val)
    }
    if (layer.name == "O" && var.labs) {
      text(x + line.stag * diff(x.range), y, y.names, pos = 4, cex = cex.val)
    }
  }

  # Draw the bias nodes (one per entry of bias.x) at height bias.y.
  bias.points <- function(bias.x, bias.y, layer.name) {
    for (val in seq_along(bias.x)) {
      px <- diff(x.range) * bias.x[val]
      py <- bias.y * diff(y.range)
      points(px, py, pch = 21, col = in.col, bg = bord.col, cex = circle.cex)
      if (node.labs) {
        text(px, py, paste0(layer.name, val), cex = cex.val)
      }
    }
  }

  # Draw connections between two adjacent layers. For the input layer
  # (out.layer = FALSE) these are the lines from one input node to every
  # hidden node; for the output layer, from every hidden node to output node
  # `h.layer`. Exactly one segment per connection is drawn, so overlapping
  # duplicates never darken semi-transparent lines.
  layer.lines <- function(mod.in, h.layer, layer1 = 1, layer2 = 2,
                          out.layer = FALSE, nid, rel.rsc, all.in,
                          pos.col, neg.col) {
    n.seg <- if (out.layer) struct[layer1] else struct[layer2]
    x0 <- rep(layer.x[layer1] * diff(x.range) + line.stag * diff(x.range),
              n.seg)
    x1 <- rep(layer.x[layer2] * diff(x.range) - line.stag * diff(x.range),
              n.seg)
    raw <- nnet.vals(mod.in, nid = FALSE, rel.rsc)
    scaled <- nnet.vals(mod.in, nid = TRUE, rel.rsc)
    if (out.layer) {
      node <- paste("out", h.layer)
      y0 <- get.ys(struct[layer1])
      y1 <- rep(get.ys(struct[layer2])[h.layer], n.seg)
      # drop the leading bias weight
      wts <- raw[[node]][-1]
      wts.rs <- scaled[[node]][-1]
    } else {
      src <- if (is.logical(all.in)) h.layer else which(x.names == all.in)
      y0 <- rep(get.ys(struct[layer1])[src], n.seg)
      y1 <- get.ys(struct[layer2])
      hidden <- grep("hidden", names(raw))
      # +1 skips each hidden node's leading bias weight
      wts <- unlist(lapply(raw[hidden], function(w) w[src + 1]))
      wts.rs <- unlist(lapply(scaled[hidden], function(w) w[src + 1]))
    }
    cols <- rep(pos.col, n.seg)
    cols[wts < 0] <- neg.col
    if (nid) {
      segments(x0, y0, x1, y1, col = cols, lwd = wts.rs)
    } else {
      segments(x0, y0, x1, y1)
    }
  }

  # Draw bias connections: bias node 1 to every hidden node, bias node 2 to
  # the selected output node(s). Each connection is drawn exactly once.
  bias.lines <- function(bias.x, mod.in, nid, rel.rsc, all.out,
                         pos.col, neg.col) {
    out.idx <- if (is.logical(all.out)) {
      seq_len(struct[3])
    } else {
      which(y.names == all.out)
    }
    raw <- nnet.vals(mod.in, nid = FALSE, rel.rsc)
    scaled <- nnet.vals(mod.in, nid = TRUE, rel.rsc)
    is.out <- grepl("out", names(raw))
    for (val in seq_along(bias.x)) {
      wts <- raw
      wts.rs <- scaled
      if (val == 1) {
        wts <- raw[!is.out]
        wts.rs <- scaled[!is.out]
      }
      if (val == 2) {
        wts <- raw[is.out]
        wts.rs <- scaled[is.out]
      }
      n.to <- struct[val + 1]
      cols <- rep(pos.col, length(wts))
      cols[unlist(lapply(wts, function(w) w[1])) < 0] <- neg.col
      lwds <- unlist(lapply(wts.rs, function(w) w[1]))
      if (!nid) {
        lwds <- rep(1, n.to)
        cols <- rep("black", n.to)
      }
      keep <- if (val == 1) seq_len(n.to) else out.idx
      n.keep <- length(keep)
      segments(
        rep(diff(x.range) * bias.x[val] + diff(x.range) * line.stag, n.keep),
        rep(bias.y * diff(y.range), n.keep),
        rep(diff(x.range) * layer.x[val + 1] - diff(x.range) * line.stag,
            n.keep),
        get.ys(n.to)[keep],
        lwd = lwds[keep],
        col = cols[keep]
      )
    }
  }

  # The same semi-transparent colours are used for every connection.
  line.pos <- scales::alpha(pos.col, alpha.val)
  line.neg <- scales::alpha(neg.col, alpha.val)

  if (bias) {
    bias.lines(bias.x, mod.in, nid = nid, rel.rsc = rel.rsc,
               all.out = all.out, pos.col = line.pos, neg.col = line.neg)
  }
  if (is.logical(all.in)) {
    for (i in seq_len(struct[1])) {
      layer.lines(mod.in, i, layer1 = 1, layer2 = 2, nid = nid,
                  rel.rsc = rel.rsc, all.in = all.in,
                  pos.col = line.pos, neg.col = line.neg)
    }
  } else {
    layer.lines(mod.in, which(x.names == all.in), layer1 = 1, layer2 = 2,
                nid = nid, rel.rsc = rel.rsc, all.in = all.in,
                pos.col = line.pos, neg.col = line.neg)
  }
  if (is.logical(all.out)) {
    for (i in seq_len(struct[3])) {
      layer.lines(mod.in, i, layer1 = 2, layer2 = 3, out.layer = TRUE,
                  nid = nid, rel.rsc = rel.rsc, all.in = all.in,
                  pos.col = line.pos, neg.col = line.neg)
    }
  } else {
    # alpha.val now applies here too (previously raw colours were passed).
    layer.lines(mod.in, match(all.out, y.names), layer1 = 2, layer2 = 3,
                out.layer = TRUE, nid = nid, rel.rsc = rel.rsc,
                all.in = all.in, pos.col = line.pos, neg.col = line.neg)
  }

  # Nodes are drawn last so they sit on top of the connection lines.
  layer.points(struct[1], layer.x[1], "I")
  layer.points(struct[2], layer.x[2], "H")
  layer.points(struct[3], layer.x[3], "O")
  if (bias) {
    bias.points(bias.x, bias.y, "B")
  }
  invisible(NULL)
}
# vjcitn/edx_adv_bioc documentation built on March 17, 2020, 3:58 p.m.