R/plot.R

#' Plot brain
#'
#' Plot brain architecture as an interactive \code{sigmajs} graph.
#'
#' Neurons are laid out in columns by layer: inputs on the left, each hidden
#' layer in its own column, outputs on the right. The vertical position of each
#' neuron is drawn at random with \code{\link[stats]{runif}}, so the layout
#' changes from call to call. Each edge is a network connection and carries
#' its connection weight.
#'
#' @inheritParams architecture
#' @param color Color of neurons, one of \code{state}, \code{layer}, \code{bias}, \code{activation}, or \code{old}.
#' @param size Size of neurons, one of \code{state}, \code{bias}, \code{activation}, or \code{old}.
#'   The absolute value of the chosen field is rescaled to the range 1 to 5.
#'
#' @return A \code{sigmajs} htmlwidget.
#'
#' @examples
#' brain() %>%
#'   perceptron(c(2,3,1)) %>%
#'   plot_brain(size = "bias")
#'
#' @importFrom stats runif
#'
#' @export
plot_brain <- function(brain, color = NULL, size = NULL){

  # Require a single valid field name; a longer vector would otherwise make
  # `if()` fail with an unhelpful "condition has length > 1" error.
  if (!is.null(color)) {
    if (length(color) != 1 ||
        !color %in% c("state", "old", "activation", "bias", "layer"))
      stop("invalid color", call. = FALSE)
  }

  if (!is.null(size)) {
    if (length(size) != 1 ||
        !size %in% c("state", "old", "activation", "bias"))
      stop("invalid size", call. = FALSE)
  }

  net <- export(brain)

  # The `+ 1` maps exported connection endpoints onto the 1-based node ids
  # assigned to `nodes` below.
  source <- purrr::map_int(net$connections, "from") + 1
  target <- purrr::map_int(net$connections, "to") + 1
  # map_dbl() guarantees exactly one weight per connection (and numeric(0),
  # not NULL, when there are none) so the columns stay aligned.
  weight <- purrr::map_dbl(net$connections, "weight")

  edges <- cbind.data.frame(source, target, weight) %>%
    dplyr::mutate(
      # seq_len() rather than 1:n(): 1:0 would be c(1, 0) on zero rows.
      id = seq_len(dplyr::n()),
      type = "arrow"
    )

  if (!is.null(size)) {
    sz <- purrr::map(net$neurons, size) %>% unlist() %>% abs()
  } else {
    sz <- rep(1, length(net$neurons))
  }

  sz <- scales::rescale(sz, to = c(1, 5))

  if (!is.null(color)) {
    col <- purrr::map(net$neurons, color) %>% unlist()
  } else {
    col <- rep(1, length(net$neurons))
  }

  layer <- purrr::map(net$neurons, "layer") %>% unlist()
  is_input <- layer == "input"
  is_output <- layer == "output"
  is_hidden <- !is_input & !is_output

  # Every hidden layer is labelled "hidden"; exact matching on the layer name
  # keeps names such as "10" from being partially rewritten.
  labs <- ifelse(is_hidden, "hidden", layer)

  # Column per neuron: inputs at 0, hidden layer k at k + 1, outputs one
  # column after the last hidden layer. Computed per neuron, so it neither
  # depends on the export order nor drops neurons of additional hidden layers.
  # NOTE(review): assumes hidden layers are named by a numeric index
  # ("0", "1", ...); non-numeric names fall back to the first hidden column.
  hidden_index <- suppressWarnings(as.integer(layer[is_hidden]))
  hidden_index[is.na(hidden_index)] <- 0L
  column <- numeric(length(layer))
  column[is_hidden] <- hidden_index + 1
  column[is_output] <- max(c(0, column[is_hidden])) + 1

  nodes <- dplyr::tibble(
    x = scales::rescale(column, to = c(1, 20)),
    y = runif(length(column), 5, 15),
    size = sz,
    label = labs,
    color = scales::col_factor(c("#ff8c45", "#ffda45", "#45aeff"), domain = NULL)(col)
  ) %>%
    dplyr::mutate(
      id = seq_len(dplyr::n())
    )

  sigmajs::sigmajs() %>%
    sigmajs::sg_nodes(nodes, id, label, size, color, x, y) %>%
    sigmajs::sg_edges(edges, id, source, target, weight, type) %>%
    sigmajs::sg_settings(
      labelThreshold = 0,
      edgeColor = "default",
      minNodeSize = 2,
      maxNodeSize = 10
    ) %>%
    sigmajs::sg_drag_nodes()
}
brain-r/brain documentation built on May 21, 2019, 4:05 a.m.