# Nothing
#' Experience Replay
#'
#' Create a replay memory specification for experience replay.
#'
#' Sampling from the replay memory will be uniform.
#'
#' @param size \[`integer(1)`] \cr Size of replay memory. Must be at least 1.
#' @param batch.size \[`integer(1)`] \cr Batch size. Must be between 1 and
#'   `size`.
#'
#' @return \[`list(size, batch.size)`] with class `"ReplayMemory"`.
#' This list can then be passed onto [makeAgent], which will construct the
#' replay memory accordingly.
#'
#' @md
#' @aliases experience.replay replay.memory
#' @export
#'
#' @examples
#' memory = makeReplayMemory(size = 100L, batch.size = 16L)
makeReplayMemory = function(size = 100L, batch.size = 16L) { # TODO: add arguments for prioritization
  # Validate up front: a batch can never be larger than the memory it is
  # drawn from, hence `upper = size` for the batch size.
  checkmate::assertInt(size, lower = 1)
  checkmate::assertInt(batch.size, lower = 1, upper = size)
  structure(list(size = size, batch.size = batch.size), class = "ReplayMemory")
}
# Fixed-size replay memory backed by a pre-allocated list.
#
# Fields:
#   memory      list of length `size`; every slot is NULL until written.
#   size        capacity of the memory.
#   batch.size  default number of observations drawn by `sampleBatch()`.
#   index       slot written by the most recent `observe()` (0 before any write).
#   index.full  number of slots written so far, capped at `size`.
ReplayMemory = R6::R6Class("ReplayMemory",
  public = list(
    memory = NULL,
    size = NULL,
    batch.size = NULL,
    index = 0L,
    index.full = 0L,

    # fixme allow growing replay memory?
    # Store the configuration and pre-allocate `size` empty slots.
    initialize = function(size, batch.size) {
      self$size = size
      self$batch.size = batch.size
      self$memory = vector("list", length = size)
    },

    # # initialize following policy
    # initializeMemory = function(env, policy) {
    # for (i in seq_len(self$size)) {
    # action = policy$sampleAction()
    # env$step(action)
    # data = list(state = preprocessState(envir$previous.state), action = action,
    # reward = envir$reward, next.state = preprocessState(envir$state))
    # }
    # },

    # Record one transition: advance the write position, update the fill
    # counter (never beyond `size`), then store the observation in the slot
    # chosen by `getReplacementIndex()`.
    observe = function(state, action, reward, next.state) {
      self$index = self$index + 1L
      self$index.full = min(self$size, self$index.full + 1L)
      slot = self$getReplacementIndex()
      entry = self$getReplayObservation(state, action, reward, next.state)
      self$add(entry, slot)
    },

    # Bundle the four parts of a transition into a named list.
    getReplayObservation = function(state, action, reward, next.state) {
      list(state = state, action = action, reward = reward, next.state = next.state)
    },

    # Slot to overwrite next. Once the write position runs past the end of
    # the memory it wraps around to the first slot, i.e. the oldest entry
    # is replaced.
    getReplacementIndex = function() {
      past.end = self$index > self$size
      if (past.end) {
        self$index = 1L
      }
      self$index
    },

    # Write `observation` into slot `index` of the memory.
    add = function(observation, index) {
      self$memory[[index]] = observation
    },

    # TRUE when no slot of `memory` is NULL.
    isFull = function(memory = self$memory) {
      # maybe it is enough to check the last entry
      all(!purrr::map_lgl(memory, is.null))
    },

    # Pull element `member` out of every entry of `batch`; `fun` is the
    # apply-style function used for the iteration.
    extract = function(batch, member, fun = lapply) {
      fun(batch, "[[", member)
    },

    # checkMemory = function(memory = self$memory, batch.size = self$batch.size) {
    # if (!self$isFull()) {
    # if (self$index < batch.size) {
    # return(FALSE)
    # }
    # }
    # },

    # Draw `batch.size` entries from `memory` (by default only the slots
    # written so far) and return them transposed, i.e. as one list per field
    # instead of one list per observation. If there are too few entries a
    # message is emitted and nothing is sampled.
    sampleBatch = function(memory = self$memory[seq_len(self$index.full)], batch.size = self$batch.size) {
      if (length(memory) < batch.size) {
        message("Cannot sample from replay memory because batch size > number of non-empty entries in replay memory.")
      } else {
        picked = self$getIndices(length(memory), batch.size)
        purrr::transpose(memory[picked])
      }
    },

    # Uniformly sample `batch.size` distinct positions from 1..memory.size.
    getIndices = function(memory.size, batch.size) {
      sample(seq_len(memory.size), size = batch.size)
    }
  )
)
# Ideas: in the future the replay memory might be stored not as a list but
# as a hash table / dictionary etc., or as a data frame with list columns.
# FIXME: allow the length of the replay memory to be changed dynamically.
# Open question: store the preprocessed state?
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.