# @title ReinforceWithBaseline
# @format \code{\link{R6Class}} object
# @description REINFORCE with a learned state-value baseline: a critic $v_w$ estimates the state value, and the baseline-corrected return $\delta$ drives both the critic and the policy updates:
# $\delta = G_t - v_w(s_t)$
# $w = w + \beta * \delta * \nabla_w v_w(s_t)$
# $\theta = \theta + \alpha * \gamma^t * \delta * \nabla_{\theta} \log \pi_{\theta}(A_t|S_t)$
# @return [\code{\link{AgentPGBaseline}}].
AgentPGBaseline = R6::R6Class("AgentPGBaseline",
inherit = AgentPG,
public = list(
brain_actor = NULL, # policy network, trained with cross-entropy loss
brain_critic = NULL, # value network, trained with MSE loss
critic_yhat = NULL,
p_old_c = NULL, # critic predictions for the old states
p_next_c = NULL, # critic predictions for the next states
delta = NULL, # baseline-corrected return G_t - v_w(s_t)
list.rewards = NULL,
setBrain = function() {
self$task = "policy_fun" # build the actor from the policy architecture
self$brain_actor = SurroNN$new(self)
self$task = "value_fun" # build the critic from the value architecture
self$brain_critic = SurroNN$new(self)
self$model = self$brain_critic
},
getReplayYhat = function(batchsize) {
self$list.replay = self$mem$sample.fun(batchsize)
self$glogger$log.nn$info("replaying %s", self$mem$replayed.idx)
list.states.old = lapply(self$list.replay, ReplayMem$extractOldState)
list.states.next = lapply(self$list.replay, ReplayMem$extractNextState)
self$list.rewards = lapply(self$list.replay, ReplayMem$extractReward)
self$list.acts = lapply(self$list.replay, ReplayMem$extractAction)
self$model = self$brain_critic
self$p_old_c = self$getYhat(list.states.old)
self$p_next_c = self$getYhat(list.states.next)
temp = simplify2array(list.states.old) # R arrays are filled column-wise, so the sample index ends up in the last dimension
mdim = dim(temp)
norder = length(mdim)
self$replay.x = aperm(temp, c(norder, 1:(norder - 1))) # move the sample index to the first dimension
},
replay = function(batchsize) {
self$getReplayYhat(batchsize)
len = length(self$list.replay) # list.replay might be smaller than batchsize
self$setAmf(batchsize) # per-step discount weights gamma^t (see the update rule above)
self$delta = array(self$vec_dis_return, dim = dim(self$p_old_c)) - self$p_old_c # delta = G_t - v_w(s_t)
list.targets.actor = lapply(1:len, function(i) as.vector(self$extractActorTarget(i)))
list.targets.critic = lapply(1:len, function(i) as.vector(self$extractCriticTarget(i)))
y_actor = t(simplify2array(list.targets.actor))
y_actor = diag(self$amf) %*% y_actor # scale row i by gamma^t_i
y_actor = diag(as.vector(self$delta)) %*% y_actor # scale row i by delta_i
y_critic = array(unlist(list.targets.critic), dim = c(len, 1L))
self$brain_actor$batch_update(self$replay.x, y_actor) # update the policy (actor) model
self$brain_critic$batch_update(self$replay.x, y_critic) # update the value (critic) model
},
extractCriticTarget = function(i) {
y = self$p_old_c[i, ] + self$delta[i] # v_w(s_t) + delta = G_t, so the critic regresses toward the return
return(y)
},
extractActorTarget = function(i) {
act = self$list.acts[[i]]
vec.act = rep(0L, self$act_cnt)
vec.act[act] = 1.0 # one-hot encoding of the chosen action
target = vec.act # scaling by gamma^t and delta happens later in replay()
return(target)
},
adaptLearnRate = function() {
self$brain_actor$lr = self$brain_actor$lr * self$lr_decay
self$brain_critic$lr = self$brain_critic$lr * self$lr_decay
},
afterStep = function() {
self$policy$afterStep()
},
# @override
evaluateArm = function(state) {
state = array_reshape(state, c(1L, dim(state))) # add a batch dimension of size 1
self$vec.arm.q = self$brain_actor$pred(state)
self$glogger$log.nn$info("state: %s", paste(state, collapse = " "))
self$glogger$log.nn$info("prediction: %s", paste(self$vec.arm.q, collapse = " "))
},
afterEpisode = function() {
self$replay(self$interact$perf$total_steps) # replay the whole episode in one batch (batch size = steps taken): the key difference of this agent
}
) # public
)
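# A minimal standalone sketch (illustrative only, not part of the agent) of how
# replay() assembles the actor targets: one-hot action rows are scaled first by
# the discount weights gamma^t and then by delta = G_t - v_w(s_t). The function
# name sketchActorTargets and all numbers below are made up for illustration.
sketchActorTargets = function() {
  gamma = 0.99
  amf = gamma^(0:2) # gamma^t for steps t = 0, 1, 2
  delta = c(1.5, -0.3, 0.8) # baseline-corrected returns G_t - v_w(s_t)
  acts = c(1L, 3L, 2L) # sampled actions, with act_cnt = 3
  y_actor = t(sapply(acts, function(a) { v = rep(0, 3); v[a] = 1; v })) # one-hot rows
  y_actor = diag(amf) %*% y_actor # row i scaled by gamma^t_i
  y_actor = diag(delta) %*% y_actor # row i scaled by delta_i
  y_actor
}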
AgentPGBaseline$info = function() {
"Policy Gradient with Baseline"
}
quicktest = function() {
#pg.bl.agent.nn.arch.actor = list(nhidden = 64, act1 = "tanh", act2 = "softmax", loss = "categorical_crossentropy", lr = 25e-3, kernel_regularizer = "regularizer_l2(l=0.0001)", bias_regularizer = "regularizer_l2(l=0.0001)", decay = 0.9, clipnorm = 5)
#pg.bl.agent.nn.arch.critic = list(nhidden = 64, act1 = "tanh", act2 = "linear", loss = "mse", lr = 25e-3, kernel_regularizer = "regularizer_l2(l=0.0001)", bias_regularizer = "regularizer_l2(l=0)", decay = 0.9, clipnorm = 5)
#value_fun = makeNetFun(pg.bl.agent.nn.arch.critic, flag_critic = T)
#policy_fun = makeNetFun(pg.bl.agent.nn.arch.actor)
env = makeGymEnv("CartPole-v0")
conf = getDefaultConf("AgentPGBaseline")
agent = initAgent("AgentPGBaseline", env, conf, custom_brain = FALSE)
#agent$customizeBrain(list(value_fun = value_fun, policy_fun = policy_fun))
agent$learn(200L)
}
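# A hedged variant of quicktest() exercising the custom-brain path sketched in
# the commented lines above. The architecture lists mirror those comments; the
# function name quicktest_custom and the choice custom_brain = TRUE are
# assumptions for illustration, not tuned defaults.
quicktest_custom = function() {
  arch.actor = list(nhidden = 64, act1 = "tanh", act2 = "softmax",
    loss = "categorical_crossentropy", lr = 25e-3,
    kernel_regularizer = "regularizer_l2(l=0.0001)",
    bias_regularizer = "regularizer_l2(l=0.0001)", decay = 0.9, clipnorm = 5)
  arch.critic = list(nhidden = 64, act1 = "tanh", act2 = "linear", loss = "mse",
    lr = 25e-3, kernel_regularizer = "regularizer_l2(l=0.0001)",
    bias_regularizer = "regularizer_l2(l=0)", decay = 0.9, clipnorm = 5)
  policy_fun = makeNetFun(arch.actor)
  value_fun = makeNetFun(arch.critic, flag_critic = TRUE)
  env = makeGymEnv("CartPole-v0")
  conf = getDefaultConf("AgentPGBaseline")
  agent = initAgent("AgentPGBaseline", env, conf, custom_brain = TRUE)
  agent$customizeBrain(list(value_fun = value_fun, policy_fun = policy_fun))
  agent$learn(200L)
}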