# R/train.R

#' Train
#'
#' Train the brain.
#'
#' @inheritParams architecture
#' @param print Whether to print training error, iterations and elapsed time.
#' @param data A data.frame from which to select the input and output columns named in \code{...}.
#' @param rate Learning rate to train the network. It can be a static rate (a single number) or dynamic
#' (a list of numbers, which transition from one to the next according to the number of iterations).
#' @param iterations Maximum number of iterations.
#' @param shuffle If \code{TRUE}, the training set is shuffled after every iteration; this is useful for
#' training on data sequences whose order is not meaningful to networks with context memory, like \code{\link{lstm}}.
#' @param error Minimum error.
#' @param func Training cost function, one of \code{cross_entropy}, \code{mse} (mean squared error), or \code{binary}.
#' @param cost A cost function as returned by \code{\link{cost_function}}.
#' @param log Frequency at which to log training error and elapsed time (iterations), can be retrieved with
#' \code{\link{get_log}}.
#' @param scale Set to \code{TRUE} to scale the data with \code{\link{balance}}.
#' @param cv Set to \code{TRUE} to cross validate.
#' @param ... Bare column names from \code{data}.
#'
#' @section Functions:
#' \itemize{
#'   \item{\code{train}: Train the brain.}
#'   \item{\code{train_data}: Add training data.}
#'   \item{\code{train_opts}: Pass training options.}
#'   \item{\code{train_input}, \code{train_output}: Pass training input and output based on \code{data}.}
#'   \item{\code{cost_function}: Returns a cost function to use as \code{cost} in \code{train} function.}
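#'   \item{\code{get_training}: Returns training error, iterations and elapsed time.}
#'   \item{\code{get_log}: Returns the training log recorded when \code{train} is called with \code{log}.}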
#' }
#'
#' @section Cost functions:
#' \itemize{
#'   \item{\code{mse}: Mean squared error}
#'   \item{\code{binary}: Binary}
#'   \item{\code{cross_entropy}: Cross Entropy}
#' }
#'
#' @examples
#' df <- dplyr::tibble(
#'   x = c(0, 0, 1, 1),
#'   y = c(0, 1, 0, 1),
#'   z = c(0, 1, 1, 0)
#' )
#'
#' br <- brain() %>%
#'   perceptron(c(2,3,1)) %>%
#'   train_data(df) %>%
#'   train_input(x, y) %>%
#'   train_output(z) %>%
#'   train(
#'     cost = cost_function("cross_entropy"),
#'     log = 10
#'   )
#'
#' log <- get_log(br)
#' plot(log)
#'
#' @seealso \code{\link{network}} to manually create layers, connections and resulting brain, \code{\link{dsr}} for
#' distracted sequence recall.
#'
#' @name train
#' @export
train <- function(brain, cost, log = NULL, print = TRUE){

  if(missing(cost))
    stop("missing cost function", call. = FALSE)

  if(!inherits(cost, "cost_function"))
    stop("cost must be of class cost_function, see cost_function", call. = FALSE)

  input <- brain$opts$training$input
  output <- brain$opts$training$output

  # Validate the training pairs up front, so a forgotten train_input() /
  # train_output() call is reported clearly instead of surfacing as an
  # obscure failure from purrr::map2() or the JS trainer.
  if(is.null(input))
    stop("missing training input, see train_input", call. = FALSE)

  if(is.null(output))
    stop("missing training output, see train_output", call. = FALSE)

  if(length(input) != length(output))
    stop("training input and output must have the same number of rows", call. = FALSE)

  # Pair the i-th input row with the i-th output row, giving one
  # list(input = list(...), output = list(...)) per training example.
  io <- purrr::map2(input, output, ~c(.x, .y))

  # Keep the short cost name (e.g. "cross_entropy") by stripping the JS
  # assignment prefix that cost_function() builds.
  brain$opts$trainingcost <- gsub("opts\\['cost'\\] = synaptic.trainer.cost.", "", tolower(cost))

  opts <- brain$opts$training$options

  if(is.null(opts)) opts <- list()

  # Set up the JS context: a trainer for `net`, the training set and the
  # options. The cost statement writes into `opts`, so it runs after the
  # options are assigned.
  brain$brain$eval("var trainer = new synaptic.Trainer(net)")
  brain$brain$assign("trainingData", io)
  brain$brain$assign("opts", opts)
  brain$brain$eval(cost)

  if(!is.null(log)){
    lg <- .make_log(log)
    brain$brain$eval("var log = []")
    brain$brain$eval(lg)
  }

  # The trainer's result is kept as the JS variable `history`, which
  # get_training() reads back.
  brain$brain$eval("var history = trainer.train(trainingData, opts)")

  if(isTRUE(print)){
    out <- get_training(brain)
    .prt_train(out)
  }

  brain$opts$trained <- TRUE

  return(brain)

}

#' @rdname train
#' @export
train_input <- function(brain, ..., scale = FALSE, data = NULL){

  data <- .get_data(data, brain, "training")

  # One list(input = list(v1, v2, ...)) per row of the selected columns.
  brain$opts$training$input <- .training_rows(data, ..., scale = scale, key = "input")
  return(brain)
}

#' @rdname train
#' @export
train_output <- function(brain, ..., scale = FALSE, data = NULL){

  data <- .get_data(data, brain, "training")

  # One list(output = list(v1, v2, ...)) per row of the selected columns.
  brain$opts$training$output <- .training_rows(data, ..., scale = scale, key = "output")

  return(brain)
}

# Shared by train_input() and train_output(): select columns from `data`,
# optionally rescale each one with balance(), and turn every row into
# list(<key> = list(v1, v2, ...)), the shape train() pairs up with
# purrr::map2().
.training_rows <- function(data, ..., scale = FALSE, key){

  cols <- dplyr::select(data, ...)

  # Work on the columns / matrix directly instead of apply(): apply()
  # simplifies a single-row result to a plain vector, which made the
  # row-wise reshape fail for single-row training data.
  if(isTRUE(scale))
    cols[] <- lapply(cols, balance)

  mat <- unname(as.matrix(cols))

  lapply(seq_len(nrow(mat)), function(i){
    structure(list(as.list(mat[i, ])), names = key)
  })
}

#' @rdname train
#' @export
train_data <- function(brain, data) {
  # `data` has no default; refuse to continue without it.
  if (missing(data)) {
    stop("missing data", call. = FALSE)
  }

  # Drop row names so rows are addressed by position only.
  row.names(data) <- NULL
  brain$opts$training$data <- data

  brain
}

#' @rdname train
#' @export
train_opts <- function(brain, rate = NULL, iterations = NULL, error = NULL, shuffle = NULL, cv = NULL){

  # Only options the caller actually set are included; the resulting list
  # replaces any options stored by an earlier train_opts() call.
  opts <- list()

  if(!is.null(rate)) opts$rate <- rate
  if(!is.null(iterations)) opts$iterations <- iterations
  if(!is.null(shuffle)) opts$shuffle <- shuffle
  if(!is.null(error)) opts$error <- error
  # Gate on `cv` itself, so cross-validation can be requested without also
  # setting `error`.
  # NOTE(review): synaptic's crossValidate option may expect an object
  # (e.g. testSize/testError) rather than TRUE - confirm.
  if(!is.null(cv)) opts$crossValidate <- cv

  brain$opts$training$options <- opts
  return(brain)
}

#' @rdname train
#' @export
cost_function <- function(func) {
  key <- tolower(func)

  # Supported (lower-case) names and the synaptic cost each one selects.
  costs <- c(
    cross_entropy = "cost.CROSS_ENTROPY",
    mse = "cost.MSE",
    binary = "cost.BINARY"
  )

  if (!key %in% names(costs)) {
    stop("invalid func", call. = FALSE)
  }

  # A JS statement that train() evaluates to set the trainer's cost option.
  statement <- paste("opts['cost'] =", paste0("synaptic.Trainer.", costs[[key]]))
  structure(statement, class = "cost_function")
}

#' @rdname train
#' @export
get_training <- function(brain) {
  # `history` is the JS variable in which train() stores the trainer's result.
  history <- brain$brain$get("history")

  # Prefer a tibble; if the value cannot be converted, return it unchanged.
  tryCatch(
    dplyr::as_tibble(history),
    error = function(cnd) history
  )
}

#' @rdname train
#' @export
get_log <- function(brain) {
  # The JS `log` variable only exists when train() was given a `log` value,
  # so a failed lookup means no log was requested.
  entries <- tryCatch(
    brain$brain$get("log"),
    error = function(cnd) cnd
  )

  if (!inherits(entries, "error")) {
    return(structure(entries, class = c("log", "data.frame")))
  }

  stop("no log specified in train", call. = FALSE)
}

# Hopfield-flavoured alias: bound to the same function object as
# train_input() when this file is sourced.
#' @rdname hopfield
#' @export
learn_pattern <- train_input

# Hopfield-flavoured alias: bound to the same function object as
# train_data() when this file is sourced.
#' @rdname hopfield
#' @export
learn_data <- train_data

#' DSR
#'
#' \href{http://www.overcomplete.net/papers/nn2012.pdf}{Distracted Sequence Recall} for \code{\link{lstm}}.
#'
#' @inheritParams architecture
#' @param targets Target sequence, a \code{list}.
#' @param distractors Distractors sequence, a \code{list}.
#' @param prompts Prompts sequence, a \code{list}.
#' @param length Length of the input sequence, an integer, e.g. \code{10L}.
#' @param iterations Number of iterations.
#' @param rate Learning rate.
#' @param ... Additional arguments; currently ignored.
#'
#' @examples
#' \dontrun{
#' brain() %>%
#'   lstm(c(6, 7, 2)) %>%
#'   train(cost_function("cross_entropy")) %>%
#'   dsr(
#'     targets = list(2, 4),
#'     distractors = list(3, 5),
#'     prompts = list(0, 1),
#'     length = 10
#'   )
#' }
#'
#' @export
dsr <- function(brain, targets, distractors, prompts, length, iterations = NULL, rate = NULL, ...){

  # `prompts` is sent to the trainer like the other sequences, so it is
  # required as well; checking it here gives the same explicit message
  # instead of R's generic "argument is missing" error.
  if(missing(targets) || missing(distractors) || missing(prompts) || missing(length))
    stop("missing targets, distractors, prompts, or length", call. = FALSE)

  # Options object passed to the JS trainer.DSR(); `...` is accepted but
  # not used.
  opts <- list(
    targets = targets,
    distractors = distractors,
    prompts = prompts,
    length = length
  )

  if(!is.null(iterations)) opts$iterations <- iterations
  if(!is.null(rate)) opts$rate <- rate

  # Relies on the JS `trainer` object, which train() creates.
  brain$brain$assign("dsr", opts)
  brain$brain$eval("trainer.DSR(dsr)")

  return(brain)

}
# brain-r/brain documentation built on May 21, 2019, 4:05 a.m.