R/predictLSTMmodel.R

#' @title predictLSTMmodel
#' @description get predictions for an LSTM recurrent network
#' @param model LSTM model as returned by \code{\link{mxLSTM}} or \code{\link{fitLSTMmodel}}
#' @param dat input data as provided by \code{\link{transformLSTMinput}} in the 'x' element of the list.
#'            A 3-d array whose first dimension holds the variables (matched by dimnames
#'            against the model's variable names) and whose second dimension is the sequence;
#'            its length must equal the model's sequence length.
#' @param fullSequence Boolean. If FALSE, only the last predicted element of a sequence is returned.
#'                              If TRUE, a prediction for each step in the sequence is returned.
#'                              If the model was trained with optimizeFullSequence = FALSE, 
#'                              this will be set to FALSE with a warning.
#' @return data.frame with predictions. Prediction columns are named \code{y} (single output)
#'         or \code{y1}, \code{y2}, ... (several outputs). Column \code{rowIndex} is
#'         1, 2, 3, ... for full-sequence output and seq.len, 2 * seq.len, ... when only
#'         the last element of each sequence is returned.
#' @details the sequence length and the batch size are inferred from the \code{model} argument.
#'          Padding samples added by the iterator to fill the last batch are removed
#'          from the result.
#' @seealso \code{\link{mxLSTM}}, \code{\link{fitLSTMmodel}}, \code{\link{getLSTMmodel}}
#' @export 
predictLSTMmodel <- function(model, dat, fullSequence = TRUE){

  if (!inherits(model, "mxLSTM")) stop("'model' must be an mxLSTM object")

  if (!all(model$varNames$x %in% dimnames(dat)[[1]])) {
    stop("Wrong variables in input data for prediction")
  }

  seq.len <- dim(model$arg.params$data)[2]
  ## check whether optimizeFullSequence was switched off, i.e. the label only
  ## covers one element per sequence. If so, a full-sequence prediction is not
  ## available from the network.
  optimizeFullSequence <-
    seq.len > 1 && isTRUE(dim(model$arg.params$label)[2] == seq.len)
  if (!optimizeFullSequence && fullSequence) {
    warning("fullSequence turned off because model was trained without full sequence support")
    ## actually switch it off (as documented); otherwise rowIndex below would be
    ## computed as if one row per sequence step had been returned
    fullSequence <- FALSE
  }

  if (dim(dat)[2] != seq.len) stop("Prediction data has a wrong sequence length")

  ## create executor from input symbol and parameters
  exec <- mxnet:::mx.symbol.bind(symbol     = model$symbol,     ctx        = mx.cpu(),
                                 arg.arrays = model$arg.params, aux.arrays = model$aux.params,
                                 grad.reqs  = rep("null", length(model$arg.params)) ## no gradients needed in testing
  )

  ## create init state arrays with all 0s for clearing after each batch
  init.states.name <- grep(".*\\.[ch]$", names(model$arg.params), value = TRUE)

  init.states.cleared <-
    lapply(model$arg.params[init.states.name], function(x) x * 0)

  ## get correct order of variables
  dat <- dat[model$varNames$x, , , drop = FALSE]

  batch.size <- dim(model$arg.params$data)[3]

  ## create dummy y variables as placeholder (the iterator requires a label)
  y <- array(0, dim = dim(dat))

  ## create the iterator over batches
  input <- mx.io.arrayiter(data       = dat,
                           label      = y,
                           batch.size = batch.size,
                           shuffle    = FALSE)

  input$reset()
  if (!input$iter.next()) stop("Cannot predict on empty iterator")

  input$reset()

  ## result container
  packer <- mxnet:::mx.nd.arraypacker()

  while (input$iter.next()) {

    ## clear initial states
    mx.exec.update.arg.arrays(exec, init.states.cleared, match.name = TRUE)

    ## set inputs
    mx.exec.update.arg.arrays(exec, list(data = input$value()$data),
                              match.name = TRUE)

    ## calculate outputs
    mx.exec.forward(exec, is.train = FALSE)

    out.pred <- exec$ref.outputs[[1]]

    ## num.pad indicates how many samples at the back of the batch are padding
    ## (repeated from the previous batch); they must not end up in the result
    padded <- input$num.pad()

    if (optimizeFullSequence) {
      ## output holds one element per step and sample: either reorder it
      ## sample-by-sample (fullSequence) or keep only the last step
      take.idx <- .lstmBatchIndices(seq.len, batch.size, padded, fullSequence)
      out.pred <- mx.nd.take(a = out.pred, indices = mx.nd.array(take.idx),
                             mode = "clip") # mode = "raise" would be preferred but does not work anymore.
    } else if (padded > 0) {
      ## one output per sample: only drop the padded samples
      out.pred <- mx.nd.take(a = out.pred,
                             indices = mx.nd.array(seq_len(batch.size - padded) - 1),
                             mode = "clip")
    }

    packer$push(out.pred)
  }

  input$reset()

  out <-
    packer$get() %>%
    t %>%
    data.table %>%
    setnames(if (length(.) > 1) paste0("y", seq_along(.)) else "y")

  ## get an index of row numbers
  index <- if (fullSequence) {
    seq_len(nrow(out))
  } else {
    seq(from = seq.len, by = seq.len, length.out = nrow(out))
  }

  out[, rowIndex := index]

  as.data.frame(out)
}

# Internal helper for predictLSTMmodel (not exported).
# Returns the 0-based positions to take from one batch of network output.
# The output is assumed to be laid out step-major, i.e. position
# t * batch.size + s holds step t + 1 of sample s + 1 (t, s 0-based); the
# trailing `padded` samples of the batch are skipped.
#   fullSequence = TRUE : all steps of sample 1, then all steps of sample 2, ...
#   fullSequence = FALSE: only the last step of each non-padded sample.
.lstmBatchIndices <- function(seq.len, batch.size, padded, fullSequence) {
  sample.offsets <- seq_len(batch.size - padded) - 1
  if (fullSequence) {
    step.offsets <- seq(0, by = batch.size, length.out = seq.len)
    ## column-major flattening of the step x sample matrix yields the
    ## sample-by-sample (time-ordered per sample) order
    as.vector(outer(step.offsets, sample.offsets, "+"))
  } else {
    (seq.len - 1) * batch.size + sample.offsets
  }
}
MarkusBonsch/mxLSTM documentation built on May 28, 2019, 6:40 a.m.