library(tidyverse)
library(tidygraph)
library(ggraph)
library(successr)
kable <- knitr::kable

Updating equations

The Successor Representation (SR) encodes how likely it is that you'll end up in a particular state of the environment, given the state you're currently in. It's easiest to think of this as $p(state_B | state_A)$, though technically it's something closer to "discounted expected counts of future state occupancies." More on this later. Everything I'm saying now applies equally well to Successor Features (SFs), which are a generalization of the SR. (Or, we could alternatively say that the SR is a special case of SFs...)

Updates to the SR can be expressed as:

$M_{new}(s_{t-1}) \leftarrow M_{old}(s_{t-1}) + \alpha \delta$,

where $\delta = 1(s_t) + \gamma M(s_t) - M(s_{t-1})$.

$M$ is a square $N \times N$ matrix, where $N$ refers to the number of nodes/states in the network/environment. Note that while $M$ should technically be indexed as a matrix, I treat it as a single-input function that returns the entire row corresponding to its input. Therefore, updates to the SR are performed in a row-wise manner.

Occupied states are denoted as $s$, such that the current state at time $t$ is $s_t$, and the predecessor state is $s_{t-1}$.

The prediction error $\delta$ encodes the difference between what was expected and what was observed. The idea is that you'd like to push your successor estimates away from your expectations, and closer to observed reality... with the ultimate goal of learning successor estimates that minimize prediction errors in the long run. Each update is tempered by the learning rate $\alpha$, which dictates how strongly each observation pushes your estimates towards observed reality.

What counts as observation? In the definition of $\delta$, the one-hot term $1(s_t)$ is a function that returns a vector filled with zeroes (0), except at the index of $s_t$, which takes on the value one (1). The resulting one-hot vector has the same length as $M(s_t)$. You can think of this as an observation signal reflecting the transition between states $s_{t-1} \rightarrow s_t$.

What counts as expectation? In the definition of $\delta$, your most up-to-date estimate of the successor values is encoded by $M(s_{t-1})$. To measure the discrepancy between expectation and observation, you subtract your expectation from the observation signal.

Okay, what about that weird stuff in the middle of $\delta$? That is where the magic happens, and is the reason why this algorithm is called a successor representation. Let's say that we've got a three-person network consisting of Asami, Mako, and Korra (yes, I've been re-watching old cartoons, why do you ask?). If we observe an interaction between Asami and Mako, then an interaction between Mako and Korra, we might want to encode predictive knowledge that seeing Asami might ostensibly lead to seeing Korra. This is in spite of never having observed Asami and Korra interact with each other. Because Korra is the successor of Mako, who is the successor of Asami, we can start to encode these longer-range relationships by adding something to our observation (one-hot) signal: the term $M(s_t)$. Just as the learning rate $\alpha$ tempers how strongly prediction errors cause an update to the successor estimates, the lookahead parameter $\gamma$ controls how strongly successor knowledge is added to the observation signal.
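Here's a minimal base-R sketch of that three-person example, just to make the update rule concrete. The helper names (`one_hot`, `sr_update`), the zero-filled starting matrix, and the parameter values are my own assumptions for illustration; this is not how `successr` implements things internally.

```r
# Hypothetical setup: three states, successor matrix initialized to zeroes (an assumption)
states <- c("Asami", "Mako", "Korra")
M <- matrix(0, nrow = 3, ncol = 3, dimnames = list(states, states))

# One-hot vector 1(s): all zeroes, except a one at the index of state s
one_hot <- function(s) as.numeric(states == s)

# One row-wise update for the observed transition s_prev -> s_now
sr_update <- function(M, s_prev, s_now, alpha, gamma) {
  delta <- one_hot(s_now) + gamma * M[s_now, ] - M[s_prev, ]
  M[s_prev, ] <- M[s_prev, ] + alpha * delta
  M
}

# Observe Asami -> Mako, then Mako -> Korra, and repeat the pair once more
for (pass in 1:2) {
  M <- sr_update(M, "Asami", "Mako", alpha = 0.1, gamma = 0.5)
  M <- sr_update(M, "Mako", "Korra", alpha = 0.1, gamma = 0.5)
}

M["Asami", ]
```

After the second pass, Asami's row holds a small positive value for Korra (0.005 with these settings), even though Asami and Korra never interacted. Setting `gamma = 0` in the same loop leaves that entry at zero.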

SR predictive horizon

Therefore, with greater values of $\gamma$, the representation has a longer "lookahead" horizon, as illustrated here.

This means that if we know $\gamma$, we can compute the effective number of lookahead steps the SR is taking into the future, using the formula: $lookahead = \frac{1}{1 - \gamma}$. Note that the lookahead horizon does not increase linearly with $\gamma$: it grows slowly at first, then steeply as $\gamma$ approaches 1.

tibble(gamma = seq(0, 0.9, 0.1)) %>%
  mutate(lookahead = 1 / (1 - gamma)) %>%
  kable()
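One way to see where this formula comes from (my own reading, not something taken from the package): the immediate successor gets a weight of 1, the state after that gets $\gamma$, the next one $\gamma^2$, and so on. Adding up those weights gives a geometric series that converges to $\frac{1}{1 - \gamma}$. A quick base-R check, with the cutoff of 1000 steps chosen arbitrarily:

```r
gamma <- 0.9

# Discount weights applied to steps 1, 2, 3, ... into the future
discount_weights <- gamma^(0:999)

# Total weight across all future steps, compared against the closed form
total_weight <- sum(discount_weights)
closed_form <- 1 / (1 - gamma)

c(total_weight = total_weight, closed_form = closed_form)
```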

Using algebraic rearrangement, we can also compute what value of $\gamma$ is required to achieve some desired lookahead horizon $L$: $\gamma = 1 - L^{-1}$.

tibble(lookahead = 1:10) %>%
  mutate(gamma = 1 - (1 / lookahead)) %>%
  kable()

Counts and probabilities

There's a tight coupling between $\gamma$, the lookahead horizon, and what the successor values are actually encoding.

Normally, the SR computes expected counts, such that it's possible for an entry of $M(s_{t-1})$ to be $\ge 1$. Oftentimes, we care a little more about the probability $p(s_t | s_{t-1}, \gamma)$. We can convert counts into probabilities by multiplying by $(1 - \gamma)$, i.e., $M (1 - \gamma)$. Conversely, we can convert probabilities back into counts: $M \frac{1}{1 - \gamma}$.
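Here's a minimal sketch of why the $(1 - \gamma)$ scaling works, in base R. The assumptions are mine, not taken from `successr`: a made-up three-state environment with transition matrix `T_mat`, and the observation that the update rule settles (on average) at $M = T + \gamma T M$, which rearranges to $M = (I - \gamma T)^{-1} T$.

```r
# Hypothetical 3-state environment; each row holds one-step transition probabilities
T_mat <- matrix(c(0,   0.5, 0.5,
                  0.5, 0,   0.5,
                  0.5, 0.5, 0),
                nrow = 3, byrow = TRUE)
gamma <- 0.8

# Average fixed point of the update rule: solve (I - gamma * T) M = T for M
M_counts <- solve(diag(3) - gamma * T_mat, T_mat)

# Each row of expected counts sums to 1 / (1 - gamma)
rowSums(M_counts)

# Counts -> probabilities: each row now sums to 1
M_prob <- M_counts * (1 - gamma)
rowSums(M_prob)

# Probabilities -> counts recovers the original matrix
M_back <- M_prob / (1 - gamma)
```

Since every row of expected counts adds up to the lookahead horizon $\frac{1}{1 - \gamma}$, dividing by that horizon is what turns each row into something that sums to 1.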

Let's illustrate using a real network:

# Load network
karate_network <- karate %>%
  tbl_graph(edges = ., directed = FALSE)

# Simulate random walk through network (observations)
karate_walk <- karate_network %>%
  generate_random_walk(n_obs = 5000)

# Learn successor values using many parameter values
karate_successor <- karate_network %>%
  initialize_successor() %>%
  learn_from_param_sweep(observations = karate_walk,
                         alphas = 0.1,
                         gammas = seq(0, 0.8, 0.4),
                         bidirectional = TRUE)

# Plot raw counts
karate_successor %>%
  mutate(across(.cols = from:to, .fns = factor),
         from = fct_rev(from)) %>%
  mutate(alpha = str_c("alpha=", alpha),
         gamma = str_c("gamma=", gamma)) %>%
  ggplot(aes(x = to, y = from, fill = successor_value)) +
  facet_grid(rows = vars(alpha), cols = vars(gamma)) +
  geom_tile() +
  scale_fill_viridis_c() +
  coord_fixed() +
  theme(legend.position = "bottom",
        axis.text = element_blank()) +
  ggtitle("Raw counts")

# Plot (normalized) probabilities
karate_successor %>%
  counts_to_probability(successor_value) %>%
  mutate(across(.cols = from:to, .fns = factor),
         from = fct_rev(from)) %>%
  mutate(alpha = str_c("alpha=", alpha),
         gamma = str_c("gamma=", gamma)) %>%
  ggplot(aes(x = to, y = from, fill = successor_value)) +
  facet_grid(rows = vars(alpha), cols = vars(gamma)) +
  geom_tile() +
  scale_fill_viridis_c() +
  coord_fixed() +
  theme(legend.position = "bottom",
        axis.text = element_blank()) +
  ggtitle("Normalized probabilities")

We can also check whether these normalized probabilities sum to (approximately) 1.

karate_successor %>%
  counts_to_probability(successor_value) %>%
  group_by(alpha, gamma) %>%
  # Divide by 34 (the number of people in the network) to get the average row sum
  summarise(successor_sum = sum(successor_value) / 34,
            .groups = "drop")

Visualizing the SR at work

Let's take a look at a "linear" network.

linear_network <- create_path(n = 5, directed = TRUE) %>%
  mutate(name = row_number())

linear_network %>%
  ggraph(layout = "linear") +
  geom_edge_link(arrow = arrow(length = unit(4, "mm")), 
                 end_cap = circle(3, "mm")) + 
  geom_node_label(aes(label = name))

Let's simulate walking along this path many times, then see what the SR looks like. What we can appreciate from this visualization is that with increasing $\gamma$, the algorithm encodes a higher probability of ending up in the (terminal) state 5, even from the starting state 1.

linear_walk <- linear_network %>%
  generate_random_walk(n_obs = 5, start_here = 1) %>%
  expand_grid(rep = 1:1000, .)

linear_successor <- linear_network %>%
  initialize_successor() %>%
  learn_from_param_sweep(observations = linear_walk,
                         alphas = 0.1,
                         gammas = seq(0, 0.8, 0.4),
                         bidirectional = FALSE)

linear_successor %>%
  mutate(across(.cols = from:to, .fns = factor),
         from = fct_rev(from)) %>%
  mutate(alpha = str_c("alpha=", alpha),
         gamma = str_c("gamma=", gamma)) %>%
  ggplot(aes(x = to, y = from, fill = successor_value)) +
  facet_grid(rows = vars(alpha), cols = vars(gamma)) +
  geom_tile() +
  scale_fill_viridis_c() +
  coord_fixed() +
  theme(legend.position = "bottom")
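As a rough sanity check on what to expect from this plot, here's a base-R sketch that computes the "discounted expected counts" for the path directly, as the sum $T + \gamma T^2 + \gamma^2 T^3 + \dots$ The names `T_path` and `sr_path` are hypothetical, and I'm assuming that state 5 is terminal (i.e., no 5 -> 1 transition gets learned between repetitions). Whether the learned values match these numbers depends on that assumption and on how far learning has converged.

```r
n <- 5

# Deterministic transitions i -> i + 1; state 5 is terminal (row of zeroes)
T_path <- matrix(0, n, n)
T_path[cbind(1:(n - 1), 2:n)] <- 1

# Discounted expected counts: M = T + gamma * T^2 + gamma^2 * T^3 + ...
# The sum stops after n - 1 terms because no walk on this path is longer than that
sr_path <- function(gamma, n_steps = n - 1) {
  M <- matrix(0, n, n)
  T_power <- T_path  # T^1: where you are after one step
  for (k in seq_len(n_steps)) {
    M <- M + gamma^(k - 1) * T_power
    T_power <- T_power %*% T_path
  }
  M
}

# Successor value of state 5 from state 1, for the gammas used above
gammas <- c(0, 0.4, 0.8)
value_1_to_5 <- sapply(gammas, function(g) sr_path(g)[1, 5])
value_1_to_5
```

Under these assumptions, the value of state 5 from state 1 is $\gamma^3$ (state 5 is reached on the fourth step), which is zero when $\gamma = 0$ and grows quickly as $\gamma$ increases.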


psychNerdJae/successr documentation built on Dec. 22, 2021, 9:56 a.m.