## examples
### tic tac toe
library(zobrist)
library(magrittr)
library(ggplot2)
library(tidyr)
library(reshape2)
## define the value function by a zobrist hash table of bitsize 20, where
##
## 1 ~ 9 ... indicates if the position is filled by X
## 11 ~ 19 ... indicates if the position is filled by O
## 10, 20 ... unused
##
## assume the indices are column-first, that is,
## index i corresponds to [(i-1) %% 3) + 1, (i+2) %/% 3]
## position (i, j) corresponds to 3*(j-1) + i
state_to_key <- function(state)
{
# returns the key for Q-function corresponding to the state
#
# state : 3x3 matrix of 0 (blank), 1 (X), and 2 (O)
#
# returns:
# integer vector of key
c(which(state == 1L), which(state == 2L)+10)
}
index_to_coord <- function(i)
{
c(((i-1) %% 3) + 1, (i+2) %/% 3)
}
legal_moves <- function(state)
{
# returns all legal moves in the state
# state : 3x3 matrix of 0 (blank), 1 (X), and 2 (O)
#
# returns:
# integer vector of indices of legal moves
which(state == 0L)
}
epsilon_greedy <- function(state, epsilon)
{
# choose a move using the epsilon greedy
#
# values : vector of values
# epsilon: probability of making a random choice
#
# returns:
# integer
choices <- legal_moves(state)
if (runif(1) < epsilon) {
# random choice
i <- sample.int(length(choices), 1)
return(choices[i])
}
return(optim_action(state, conservative = FALSE))
}
optim_action <- function(state, conservative = TRUE)
{
turn <- 2L - (sum(state == 0) %% 2)
choices <- legal_moves(state)
## get the current value of each choice
## if the value is undefined, impute zero
values <- rep(0, length(choices))
for (i in seq_along(choices))
{
v <- Vfunc$get(state_to_key(new_state(state, choices[i], turn)))
if (is.null(v)) {
if (conservative) {
# give a bit of penalty to undefined values
if (turn == 1L) {
v <- -5
} else {
v <- 5
}
} else {
# try undefined values
if (turn == 1L) {
v <- 5
} else {
v <- -5
}
}
}
values[i] <- v
}
if (turn == 2L) values <- -values
index <- which(values == max(values))
i <- index[sample.int(length(index), 1)]
choices[i]
}
check_win <- function(state, player)
{
# retrun TRUE if the player gets a line in the state
for (i in 1:3)
{
if (all(state[i,] == player)) return(TRUE)
if (all(state[,i] == player)) return(TRUE)
}
if (all(state[cbind(1:3,1:3)] == player)) return(TRUE)
if (all(state[cbind(1:3,3:1)] == player)) return(TRUE)
return(FALSE)
}
new_state <- function(state, move, turn)
{
coord <- index_to_coord(move)
state[coord[1], coord[2]] <- turn
state
}
show_state <- function(state)
{
o <- matrix("+", nrow = 3, ncol = 3)
o[state == 1L] <- "X"
o[state == 2L] <- "O"
for (i in 1:3) cat("\n ", paste(o[i,], collapse = " "))
cat("\n")
}
equivalent_states <- function(state)
{
# returns all equivalent states
s1 <- state[3:1, 1:3] # vertical flip
s2 <- state[1:3, 3:1] # horizontal flip
s3 <- state[3:1, 3:1]
s4 <- t(state) # transpose
s5 <- s4[3:1, 1:3]
s6 <- s4[1:3, 3:1]
s7 <- s4[3:1, 3:1]
list(state, s1, s2, s3, s4, s5, s6, s7)
}
### learning ###
train <- function(Vfunc, epsilon = 0.2, alpha = 0.9, gamma = 0.99, quiet = TRUE)
{
# initialization
turn <- 1L
state <- matrix(0L, nrow = 3, ncol = 3)
for (movenumber in 1:10)
{
if (!quiet) {
cat("\n=======================\n")
cat("Move:", movenumber, "\n")
}
## if the current state is winninf for the other player, update the value
## and finish
win <- check_win(state, 3L-turn)
if (win) {
# win = true if the other player wins
if (!quiet) cat("player", 3-turn, "wins!\n")
# if player 1 wins, positive, otherwise negative
value <- 100 * (3 - 2*turn) * (-1)
# update all equivalent states too
for (s in unique(equivalent_states(state)))
Vfunc$update(state_to_key(s), value)
# this episode ends
if (!quiet) cat("final state\n")
if (!quiet) show_state(state)
break
}
## if the move number exceeds 9, the game is over
## give the state the value of 0
if (movenumber > 9L) {
if (!quiet) cat("draw!\n")
# update all equivalent states too
for (s in unique(equivalent_states(state)))
Vfunc$update(state_to_key(s), 0)
# this episode ends
if (!quiet) cat("final state\n")
if (!quiet) show_state(state)
break
}
## get the current value of each choice
## if the value is undefined, impute zero
choices <- legal_moves(state)
cur_values <- rep(0, length(choices))
for (i in seq_along(choices))
{
v <- Vfunc$get(state_to_key(new_state(state, choices[i], turn)))
if (!is.null(v)) cur_values[i] <- unlist(v)
}
# update value
oldval <- Vfunc$get(state_to_key(state))
if (is.null(oldval)) oldval <- 0
if (turn == 1L) {
newval <- (1-alpha)*oldval + alpha*gamma*max(cur_values)
} else {
newval <- (1-alpha)*oldval + alpha*gamma*min(cur_values)
}
keys <- lapply(unique(equivalent_states(state)), state_to_key)
lapply(keys, Vfunc$update, newval)
# player 1 maximizes value, while player 2 minimizes
action <- epsilon_greedy(state, epsilon)
if (!quiet) {
cat("current state\n")
show_state(state)
cat("legal action\n")
cat(choices, "\n")
cat("action chosen\n")
cat(action, "\n")
}
# update state
state <- new_state(state, action, turn)
# switch turn
turn <- 3L - turn
}
}
simulate <- function(times = 1000)
{
# simulate game between the program
results <- rep(NA, times)
for (count in 1:times)
{
cat(sprintf("\rSimulating ...%4d/%4d", count, times))
turn <- 1L
state <- matrix(0L, nrow = 3, ncol = 3)
for (mv in 1:9)
{
action <- optim_action(state)
state <- new_state(state, action, turn)
win <- check_win(state, turn)
if (win) {
winner <- turn
break
} else if (mv == 9) {
winner <- 0L
break
}
turn <- 3L - turn
}
results[count] <- winner
}
cat("\n")
return(results)
}
set.seed(87)
Vfunc <- zobristht(20, 10, rehashable = TRUE)
## save low level AI
ttt_value1 <- Vfunc$clone()
devtools::use_data(ttt_value1, overwrite = TRUE, internal = TRUE)
rm(ttt_value1)
res <- simulate()
sim_res <- data.frame(prop.table(table(res)))
sim_res$simulation <- 0
max_train <- 5000
for (i in 1:max_train)
{
cat(sprintf("\rTraining ...%4d/%4d", i, max_train))
train(Vfunc, epsilon = 0.2)
if ((i %% 100) == 0) {
cat("\ntraining count:", i, "\n")
res <- simulate()
prop.table(table(res)) %>% print()
tmp <- data.frame(prop.table(table(res)))
tmp$simulation <- i
sim_res <- rbind(sim_res, tmp)
}
if (i == 1000) {
ttt_value2 <- Vfunc$clone()
devtools::use_data(ttt_value2, overwrite = TRUE, internal = TRUE)
rm(ttt_value2)
}
if (i == 1500) {
ttt_value3 <- Vfunc$clone()
devtools::use_data(ttt_value3, overwrite = TRUE, internal = TRUE)
rm(ttt_value3)
}
if (i == 5000) {
ttt_value4 <- Vfunc$clone()
devtools::use_data(ttt_value4, overwrite = TRUE, internal = TRUE)
rm(ttt_value4)
}
}
# graphing
dat <- spread(sim_res, key = res, value = Freq, fill = 0) %>%
melt(id.vars = "simulation", variable.name = "result_num", value.name = "frac")
dat$result <- ""
dat$result[dat$result_num == "0"] <- "Draw"
dat$result[dat$result_num == "1"] <- "Player1"
dat$result[dat$result_num == "2"] <- "Player2"
ggplot(dat, aes(simulation, frac, linetype = result, shape = result)) +
geom_line(size = 1, color = "grey25") +
geom_point(size = 2, color = "grey50") +
geom_point(size = 1.5, color = "grey10") +
xlab("number of training") + ylab("fraction") +
theme_bw()
if (!dir.exists("example/output")) dir.create("example/output")
ggsave("example/output/ttt-learning-curve.pdf", width = 10, height = 6)
ggsave("example/output/ttt-learning-curve.png", width = 10, height = 6)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.