Simple Self-Attention from Scratch (rnn: Recurrent Neural Network)

```
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
library(rnn)
library(attention)
```

This vignette describes how to implement the attention mechanism, which forms the basis of transformers, in the R language.

We begin by generating encoder representations of four different words.

```
# encoder representations of four different words
word_1 = matrix(c(1,0,0), nrow=1)
word_2 = matrix(c(0,1,0), nrow=1)
word_3 = matrix(c(1,1,0), nrow=1)
word_4 = matrix(c(0,0,1), nrow=1)
```

Next, we stack the word embeddings into a single array (in this case a matrix) which we call `words`.

```
# stacking the word embeddings into a single array
words = rbind(word_1,
              word_2,
              word_3,
              word_4)
```

Let's see what this looks like.

```
print(words)
```

Next, we initialize the weight matrices `W_Q`, `W_K`, and `W_V` with random integers from `{0, 1, 2}`, obtained by rounding down uniform draws on `[0, 3)`.

```
# initializing the weight matrices (with random values)
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
```
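Because `floor()` rounds each uniform draw down, the entries can only be 0, 1, or 2. As a minimal sketch, the same set of values can be drawn more explicitly with `sample()`. This is an illustrative alternative rather than the vignette's method: `W_example` is a made-up name, and `sample()` consumes the random number stream differently, so it will not reproduce the matrices above.

```r
# Illustrative alternative (not the vignette's method): draw integer weights
# directly from the set {0, 1, 2} with sample().
set.seed(0)
W_example = matrix(sample(0:2, size = 9, replace = TRUE), nrow = 3, ncol = 3)

# every entry is one of 0, 1, 2
stopifnot(all(W_example %in% 0:2))
print(W_example)
```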

Next, we generate the Queries (`Q`), Keys (`K`), and Values (`V`) by multiplying `words` with each weight matrix. The `%*%` operator performs matrix multiplication; see the R help page via `help('%*%')` or the online manual An Introduction to R.

```
# generating the queries, keys and values
Q = words %*% W_Q
K = words %*% W_K
V = words %*% W_V
```
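To make the projection step concrete, here is a small self-contained sketch with a fixed weight matrix instead of random values. The names `words_demo`, `W_demo`, and `Q_demo` and the numbers in `W_demo` are made up for illustration. Since the word vectors contain only zeros and ones, each projected row is simply a sum of the selected rows of the weight matrix.

```r
# Four words with three features each (same values as `words` in the vignette).
words_demo = rbind(c(1, 0, 0), c(0, 1, 0), c(1, 1, 0), c(0, 0, 1))

# Made-up 3x3 weight matrix; matrix() fills column by column.
W_demo = matrix(c(2, 0, 1,
                  1, 1, 0,
                  0, 2, 1), nrow = 3, ncol = 3)

Q_demo = words_demo %*% W_demo
print(dim(Q_demo))  # 4 rows (one per word), 3 columns

# Word 1 is (1, 0, 0), so its projection is row 1 of W_demo.
stopifnot(all(Q_demo[1, ] == W_demo[1, ]))
# Word 3 is (1, 1, 0), so its projection is row 1 plus row 2 of W_demo.
stopifnot(all(Q_demo[3, ] == W_demo[1, ] + W_demo[2, ]))
```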

Following this, we score the query vectors (`Q`) against the key vectors (`K`), which are transposed for the multiplication using `t()` (see `help('t')` for more info).

```
# scoring the query vectors against all key vectors
scores = Q %*% t(K)
print(scores)
```
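Each entry of the score matrix is the dot product of one query vector with one key vector. The sketch below checks this on toy matrices; `Q_toy`, `K_toy`, and their values are made up for illustration and are not the matrices generated above.

```r
# Toy matrices (made up for illustration): 2 queries and 2 keys of length 3.
Q_toy = rbind(c(1, 0, 2), c(0, 1, 1))
K_toy = rbind(c(1, 1, 0), c(2, 0, 1))

scores_toy = Q_toy %*% t(K_toy)

# Entry [i, j] equals the dot product of query i and key j.
i = 1
j = 2
stopifnot(scores_toy[i, j] == sum(Q_toy[i, ] * K_toy[j, ]))
print(scores_toy)
```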

We now turn the scores into the `weights` matrix using `ComputeWeights()` from the attention package.

```
weights = attention::ComputeWeights(scores)
```
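This vignette does not show what `ComputeWeights()` does internally. In attention mechanisms, scores are commonly normalized with a row-wise softmax, so the following base R sketch assumes that behavior. The helper `row_softmax` is hypothetical and not part of the attention package, and it may differ from the package's implementation (for example, in whether the scores are scaled first).

```r
# Hypothetical helper (not part of the attention package): numerically stable
# softmax applied to each row of a score matrix.
row_softmax = function(scores) {
  # subtract each row's maximum before exponentiating to avoid overflow
  shifted = scores - apply(scores, 1, max)
  expd = exp(shifted)
  # divide each row by its own sum so that the row sums to 1
  expd / rowSums(expd)
}

# Toy scores (made up for illustration).
scores_toy = rbind(c(1, 4), c(1, 1))
weights_toy = row_softmax(scores_toy)
print(weights_toy)

# every row of the weights is a probability distribution
stopifnot(isTRUE(all.equal(rowSums(weights_toy), c(1, 1))))
```

Equal scores in a row, as in the second row here, receive equal weights.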

Let's have a look at the `weights` matrix.

```
print(weights)
```

Finally, we compute the `attention` as a weighted sum of the value vectors (which are combined in the matrix `V`).

```
# computing the attention by a weighted sum of the value vectors
attention = weights %*% V
```
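To see why this matrix product is a weighted sum, the sketch below recomputes one output row by hand. The names `weights_toy` and `V_toy` and their values are made up for illustration.

```r
# Toy values (made up): 2 query positions attending over 2 value vectors.
weights_toy = rbind(c(0.25, 0.75), c(0.5, 0.5))
V_toy = rbind(c(2, 0, 4), c(0, 4, 0))

attention_toy = weights_toy %*% V_toy

# Row 1 of the result is 0.25 * V_toy[1, ] + 0.75 * V_toy[2, ].
manual_row_1 = 0.25 * V_toy[1, ] + 0.75 * V_toy[2, ]
stopifnot(isTRUE(all.equal(attention_toy[1, ], manual_row_1)))
print(attention_toy)
```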

Now we can view the results using:

```
print(attention)
```

rnn documentation built on April 22, 2023, 1:12 a.m.