# R/utils.R

# Architecture ids accepted by .architect(); .arch2fun() maps each one to the
# name of a synaptic.Architect constructor.
VALID_ARCHS <- c("perceptron", "lstm", "liquid", "hopfield")

# Translate a lower-case architecture id into the name of the corresponding
# synaptic.Architect constructor (e.g. "lstm" -> "LSTM"). Ids that are not
# listed give NULL.
.arch2fun <- function(x){
  constructors <- list(
    perceptron = "Perceptron",
    lstm       = "LSTM",
    liquid     = "Liquid",
    hopfield   = "Hopfield"
  )

  # `[[` on a list returns NULL for a name it does not contain.
  constructors[[x]]
}

# Build a synaptic.Architect network inside `brain` and record its settings.
#
# brain:  object whose `opts` list stores settings and whose `brain$eval()`
#         evaluates JavaScript (presumably a V8 context -- confirm in caller).
# arch:   a single architecture id from VALID_ARCHS, e.g. "perceptron".
# layers: non-empty numeric vector of layer sizes, passed in order as the
#         constructor arguments.
#
# Evaluates `var net = new synaptic.Architect.<Constructor>(<layers>);` and
# returns `brain` with `opts$architecture`, `opts$train$prt` (the layer sizes)
# and `opts$layers$prt` (one "logistic" label per layer) updated.
# Errors: "missing arch or layers", "invalid arch", or a layers message.
.architect <- function(brain, arch, layers){

  if(missing(arch) || missing(layers))
    stop("missing arch or layers", call. = FALSE)

  # Require a single string first: a vector `arch` would otherwise hit R's
  # "the condition has length > 1" error instead of this package's message.
  if(!is.character(arch) || length(arch) != 1L || !arch %in% VALID_ARCHS)
    stop("invalid arch", call. = FALSE)

  # `layers` is spliced verbatim into the JavaScript evaluated below, so
  # reject anything that is not a non-empty vector of finite numbers.
  if(!is.numeric(layers) || length(layers) == 0L || !all(is.finite(layers)))
    stop("layers must be a non-empty numeric vector", call. = FALSE)

  brain$opts$architecture <- arch
  brain$opts$train$prt <- layers

  # One squash label per layer; every layer is recorded as "logistic".
  brain$opts$layers$prt <- rep("logistic", length(layers))

  js <- paste0(
    "var net = new synaptic.Architect.", .arch2fun(arch),
    "(", paste0(layers, collapse = ","), ");"
  )

  brain$brain$eval(js)
  brain
}

# Pick the data set an operation should use.
#
# When `x` is supplied it wins and is returned with its row names removed.
# Otherwise the data stored at `brain$opts[[what]]$data` is returned, and an
# error "no data" is raised if that is empty or absent.
.get_data <- function(x, brain, what = "training"){

  if(is.null(x)){
    stored <- brain$opts[[what]]$data
    if(length(stored) == 0L)
      stop("no data", call. = FALSE)
    return(stored)
  }

  row.names(x) <- NULL
  x
}

# Validate a layer/object name: stop with "name cannot include spaces" when
# `x` (or any element of a character vector `x`) contains whitespace.
# Returns NULL invisibly when the name is acceptable.
.check_name <- function(x){
  # grepl() is vectorised; the previous isTRUE() check was always FALSE for
  # inputs of length > 1, so a vector containing a spaced name slipped
  # through. any() rejects it.
  if(any(grepl("[[:space:]]", x)))
    stop("name cannot include spaces", call. = FALSE)

  invisible(NULL)
}

# Translate an R-facing gate name into the JavaScript expression referencing
# synaptic's gate type constant, e.g. "input" -> "synaptic.Layer.gateType.INPUT".
#
# x: one of "input", "output", "one2one".
# Errors when `x` is not a single recognised gate name.
.get_gate <- function(x){
  # Keys of synaptic's Layer.gateType object. It defines INPUT, OUTPUT and
  # ONE_TO_ONE; the former INPUT_GATE / OUTPUT_GATE keys resolve to
  # `undefined` in JS, which makes Layer.gate() silently gate nothing.
  gate_keys <- c(input = "INPUT", output = "OUTPUT", one2one = "ONE_TO_ONE")

  # Previously an unknown name produced the dangling string
  # "synaptic.Layer.gateType." and a confusing failure later in JS.
  if(!is.character(x) || length(x) != 1L || !x %in% names(gate_keys))
    stop(
      "invalid gate, must be one of: ",
      paste(names(gate_keys), collapse = ", "),
      call. = FALSE
    )

  paste0("synaptic.Layer.gateType.", gate_keys[[x]])
}

# Build a JavaScript statement assigning a training schedule to `opts`:
# every `x` iterations the callback pushes the iteration count, error and
# rate onto a JS array named `log`.
#
# Returns a string of the form "opts['schedule'] = {every: <x>, do: ...}".
.make_log <- function(x){
  schedule <- paste0("{
    	every: ", x, ",
    	do: function(data) {
        log.push({iteration: data.iterations, error: data.error, rate: data.rate})
    	}
    }")

  paste0("opts['schedule'] = ", schedule)
}

# Print a human-readable summary of `brain` to the console: architecture
# status, layer sizes (with input/hidden/output detail when available) and
# whether the network has been trained.
#
# Reads brain$opts$architecture, brain$opts$train$prt (layer sizes),
# brain$opts$layers$prt (squash labels) and brain$opts$trained /
# brain$opts$trainingcost. Output is produced with cat(); nothing is modified.
.prt <- function(brain){

  # Red cross while the architecture is still "undefined", green tick after.
  if(brain$opts$architecture == "undefined")
    arch <- crayon::red(cli::symbol$cross)
  else
    arch <- crayon::green(cli::symbol$tick)

  cat(
    cli::rule(left = "Brain"), "\n",
    arch,
    "Architecture:",
    brain$opts$architecture
  )

  # Layer sizes as given to .architect(), e.g. c(2, 3, 3, 1).
  sizes <- brain$opts$train$prt
  # Squash labels; printed below as input = [1], hidden = [2], output = [3].
  # NOTE(review): .architect() stores one label per layer -- confirm whether
  # other setters treat this as a fixed input/hidden/output triple.
  squashes <- brain$opts$layers$prt

  if(length(squashes) >= 3){

    n_sizes <- length(sizes)
    # Every layer between the first (input) and the last (output) is hidden.
    hidden <- paste0(
      "[",
      paste0(sizes[2:(n_sizes - 1)], collapse = ","),
      "]"
    )

    cat(
      "\n",
      crayon::yellow(cli::symbol$menu), "Layers:\n",
      "\t", cli::symbol$arrow_right, "input: ", squashes[1], "-", sizes[1], "\n",
      "\t", cli::symbol$double_line, "hidden:", squashes[2], "-", hidden, "\n",
      # The output layer is the LAST size, not the third: with four or more
      # layers sizes[3] is a hidden layer.
      "\t", cli::symbol$arrow_left, "output:", squashes[3], "-", sizes[n_sizes]
    )

  } else if(length(sizes)){
    cat(
      "\n",
      crayon::yellow(cli::symbol$menu),
      "Layers:",
      paste0("[", paste0(sizes, collapse = ","), "]")
    )
  }

  if(isTRUE(brain$opts$trained))
    cat("\n", crayon::green(cli::symbol$radio_on),"Trained:", brain$opts$trainingcost)
  else
    cat("\n", crayon::red(cli::symbol$radio_off),"Untrained")

  cat("\n")
}

# Print a short training report: a "Training" rule, then the error, the
# iteration count and the elapsed time (labelled in ms), each preceded by a
# yellow symbol. Reads `error`, `iterations` and `time` from `x`.
.prt_train <- function(x){
  mark <- function(symbol) crayon::yellow(symbol)

  cat(
    cli::rule(left = "Training"), "\n",
    mark(cli::symbol$cross), "Error:", x$error, "\n",
    mark(cli::symbol$info), "Iterations:", x$iterations, "\n",
    mark(cli::symbol$play), "Time:", x$time, "ms\n"
  )
}

# Map R's NULL to the JavaScript literal "null" so a value can be pasted into
# generated JS; every non-NULL value is returned unchanged.
.null2js <- function(x){
  if(!is.null(x)) x else "null"
}

# Set the squash function (and optional bias) of a JS-side layer or neuron.
#
# brain:  object whose `brain$eval()` evaluates JavaScript.
# name:   JS expression naming the target, evaluated as `<name>.set({...})`.
# squash: object of class "squash_function"; its text is spliced in as-is.
# bias:   optional bias; NULL is emitted as JavaScript `null`.
#
# Returns `brain` unchanged on the R side.
.squash <- function(brain, name, squash, bias = NULL){

  if(missing(squash))
    stop("missing squash function", call. = FALSE)

  if(!inherits(squash, "squash_function"))
    stop("squash must be a squash function, see squash_function", call. = FALSE)

  set_call <- paste0(
    name, ".set(",
    paste0("{squash: ", squash, ", bias: ", .null2js(bias), "}"),
    ")"
  )

  brain$brain$eval(set_call)
  brain
}
# brain-r/brain documentation built on May 21, 2019, 4:05 a.m.