library(keras)
library(animation)  # provides saveGIF(), called near the end of this script

source("R/snake.R")
source("R/memory.R")
source("R/DQN.R")
# Network architecture ----
# A 12-feature state vector feeds a shared trunk; the trunk output is
# processed by two parallel dense "arms" (256-wide and 128-wide), whose
# outputs are concatenated before a 4-way head (one unit per move).
snake_model_input <- layer_input(shape = c(12))

# Shared trunk.
snake_model_base <- snake_model_input %>%
  layer_dense(units = 512, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 256, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 256, activation = "relu")

# Arm 1: three 256-unit layers with dropout between them.
snake_model_arm1 <- snake_model_base %>%
  layer_dense(units = 256, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 256, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 256, activation = "relu")

# Arm 2: same shape as arm 1 but 128 units wide.
snake_model_arm2 <- snake_model_base %>%
  layer_dense(units = 128, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 128, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 128, activation = "relu")

# NOTE(review): a softmax head trained with MSE is unusual for a DQN
# (Q-values are normally unbounded / linear) — kept as-is to preserve
# the trained behavior; confirm against R/DQN.R before changing.
snake_model_output <- layer_concatenate(
  list(snake_model_arm1, snake_model_arm2)
) %>%
  layer_dense(units = 4, activation = "softmax")

snake_model <- keras_model(snake_model_input, snake_model_output)

# `lr` (rather than `learning_rate`) is kept for compatibility with the
# keras version this script was written against.
optimizer <- optimizer_rmsprop(lr = 0.0005)
snake_model %>% compile(
  loss = "mse",
  optimizer = optimizer
)
# Instantiate the S4 game environment and the learning agent (classes
# come from R/snake.R and R/DQN.R), then attach the freshly built
# network as the agent's Q-model.
snake_game<-new("snake")
dqn_agent<-new("DQN_agent")
dqn_agent$add_Model(snake_model)
# Bookkeeping for phase-1 training: game counter plus per-game history.
# NOTE(review): `record`, `best_game`, and `fruitPos` are initialized
# here but never read later in this file — possibly vestigial.
counter_games<-1
record <- 0
best_game <- 0
records <- list()
fruitPos<- list()
# Training phase 1 ----
# 1999 games of epsilon-greedy exploration (next_step() handles the
# exploration/exploitation trade-off internally).
while (counter_games < 2000) {
  # Reset the environment and start the fruit-position log with the
  # first fruit placed by init().
  snake_game$init()
  food1 <- snake_game$food
  fruitRecord <- list(food1)

  # Perform the first (default) move so we have an initial transition
  # to remember and train on.
  state <- snake_game$run_iter(returnStatus = TRUE)
  dqn_agent$remember(state$state, state$reward, state$action, state$done, state$state_new)
  dqn_agent$train_on(state$state, state$reward, state$action, state$done, state$state_new)

  while (!snake_game$dead) {
    # The resulting state of the previous step is this step's input.
    state_old <- state$state_new

    # Agent scores the four moves; play the argmax.
    prediction <- dqn_agent$next_step(state_old)
    final_move <- c("up", "down", "left", "right")[which.max(prediction)]

    # Apply the chosen move and observe the new transition.
    state <- snake_game$run_iter(final_move, returnStatus = TRUE)

    # Record each new fruit position for later replay.
    # BUG FIX: `food` is a coordinate vector, so the original
    # `if (food1 != snake_game$food)` put a length-2 condition in
    # if() — a warning historically and an error in R >= 4.2.
    if (!identical(food1, snake_game$food)) {
      food1 <- snake_game$food
      fruitRecord <- c(fruitRecord, list(food1))
    }

    # Short-term training on the single new transition, then store it
    # in replay memory for train_long().
    dqn_agent$train_on(state$state, state$reward, state$action, state$done, state$state_new)
    dqn_agent$remember(state$state, state$reward, state$action, state$done, state$state_new)
  }

  # BUG FIX: read the final score AFTER the loop. The original only
  # assigned `score` inside the loop body, so a game where the snake
  # died on the very first move reported the previous game's score
  # (or hit an undefined `score` on game 1).
  score <- snake_game$score_total

  # Experience replay over the long-term memory (batch cap 3000).
  dqn_agent$train_long(3000)

  cat("Game", counter_games, "\tScore:", score, "\n")
  records <- c(records, list(list(
    game = counter_games,
    score = score,
    log = snake_game$log,
    fruit_positions = fruitRecord
  )))
  counter_games <- counter_games + 1
}
# Persist the phase-1 model, the agent, and the per-game records.
save_model_hdf5(dqn_agent$model[[1]],"snake_player_evenlonger.hd5")
saveRDS(dqn_agent,"snake_dqn_longest.rds")
saveRDS(records,"records_longest.rds")

# Locate the best phase-1 game: highest score, ties broken by the
# longest move log. Named access (x$score, x$log) replaces fragile
# positional [[ ]]; vapply replaces type-unstable sapply.
bestScore <- max(vapply(records, function(x) as.numeric(x$score), numeric(1)))
# BUG FIX: the original tie-break measured length(x[[1]]) — the scalar
# game counter, which always has length 1 — so ties were silently
# broken by list order instead of by game length.
bestPerf <- which.max(vapply(
  records,
  function(x) if (x$score == bestScore) as.numeric(length(x$log)) else 0,
  numeric(1)
))
steps <- records[[bestPerf]]$log
fruit_locs <- records[[bestPerf]]$fruit_positions
# NOTE(review): steps/fruit_locs are recomputed after phase 2 before
# being used; kept here for parity with the original script.
# Phase 2: a fresh game and a fresh agent that reuses the network
# trained in phase 1 (exploitation-focused fine-tuning below).
snake_game2<-new("snake")
dqn_agent2<-new("DQN_agent")
dqn_agent2$add_Model(dqn_agent$model[[1]])
# Reset bookkeeping; `records` is cleared, so phase-1 records survive
# only in records_longest.rds saved above.
counter_games<-1
record <- 0
best_game <- 0
records <- list()
fruitPos<- list()
# Training phase 2 ----
# 499 games with exploration disabled (randguess = FALSE): the agent
# always plays its current best guess while continuing to train.
while (counter_games < 500) {
  # Reset the environment and start the fruit-position log.
  snake_game2$init()
  food1 <- snake_game2$food
  fruitRecord <- list(food1)

  # First (default) move gives the initial transition.
  state <- snake_game2$run_iter(returnStatus = TRUE)
  dqn_agent2$remember(state$state, state$reward, state$action, state$done, state$state_new)
  dqn_agent2$train_on(state$state, state$reward, state$action, state$done, state$state_new)

  while (!snake_game2$dead) {
    # Previous step's resulting state is this step's input.
    state_old <- state$state_new

    # Pure exploitation: no random moves in this phase.
    prediction <- dqn_agent2$next_step(state_old, randguess = FALSE)
    final_move <- c("up", "down", "left", "right")[which.max(prediction)]

    # Apply the move and observe the transition.
    state <- snake_game2$run_iter(final_move, returnStatus = TRUE)

    # Record each new fruit position for later replay.
    # BUG FIX: identical() instead of a vectorized != inside if()
    # (length > 1 conditions are an error in R >= 4.2).
    if (!identical(food1, snake_game2$food)) {
      food1 <- snake_game2$food
      fruitRecord <- c(fruitRecord, list(food1))
    }

    # Short-term training on the new transition, then store it.
    dqn_agent2$train_on(state$state, state$reward, state$action, state$done, state$state_new)
    dqn_agent2$remember(state$state, state$reward, state$action, state$done, state$state_new)
  }

  # BUG FIX: read the final score AFTER the loop so a first-move death
  # does not report the previous game's score (or an undefined value
  # on game 1).
  score <- snake_game2$score_total

  cat("Game", counter_games, "\tScore:", score, "\n")
  # Replay batch grows with the game counter (100, 200, ...).
  dqn_agent2$train_long(100 * counter_games)
  records <- c(records, list(list(
    game = counter_games,
    score = score,
    log = snake_game2$log,
    fruit_positions = fruitRecord
  )))
  counter_games <- counter_games + 1
}
# Pick the best phase-2 game: highest score, ties broken by the
# longest move log.
bestScore <- max(vapply(records, function(x) as.numeric(x$score), numeric(1)))
# BUG FIX: the original tie-break used length(x[[1]]) — the scalar game
# counter, always length 1 — so it never actually compared game lengths.
bestPerf <- which.max(vapply(
  records,
  function(x) if (x$score == bestScore) as.numeric(length(x$log)) else 0,
  numeric(1)
))
steps <- records[[bestPerf]]$log
fruit_locs <- records[[bestPerf]]$fruit_positions

# Live replay of the best game, then render the same replay to a GIF.
# steps[-1] drops the initial (pre-move) log entry.
snake_game2$replay(steps[-1], fruit_locs, delay = .1)
# saveGIF() comes from the `animation` package (now loaded at the top
# of the file — the original script called it without attaching it).
res <- saveGIF(
  snake_game2$replay(steps[-1], fruit_locs, delay = .001),
  movie.name = "init_animation_longer_725.gif",
  interval = .07
)

# Persist the phase-2 model, agent, and records.
save_model_hdf5(dqn_agent2$model[[1]], "snake_player_larger2.hd5")
saveRDS(dqn_agent2, "snake_dqn2.rds")
saveRDS(records, "snake_dqn2_records.rds")
# NOTE(review): the two lines below are residue from the web page this
# script was scraped from, not part of the program; commented out so the
# file parses as valid R.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.