GRU units

Feed forward pass

par(mar=rep(0,4))
plot(c(0,10),c(0,10),type="n",
     # bty ="n",
     xlab ="", ylab= "",xaxt="n",yaxt="n")

text(x=0,y=9.75,labels="GRU unit feed forward",pos=4,cex=2)





symbols(x=5,y=5,rectangles = matrix(8,nrow=1,ncol=2),inches = F,add = T) # unit itself
text(x=c(0,0.5,10,9.25),y=c(8,0,8,10),
     labels = c(expression('h'[t-1]),expression('x'[t]),expression('h'[t]),expression('y'[t])))

symbols(x=c(3,3.5,4.5,5,6,6.5),y=rep(2.5,6),rectangles = matrix(c(0.45,0.6),nrow=6,ncol=2,byrow = T),inches = F,add = T) # weight matrix
text(x=c(3,3.5,4.5,5,6,6.5),y=rep(2.5,6),
     labels = c(expression('W'[r]),expression('U'[r]),
                expression('W'[z]),expression('U'[z]),expression('W'[g]),expression('U'[g]))
     ) # weight matrix

segments(x0=c(1),x1 = c(6.5),y0 = c(0),y1 = c(0)) # xt segment
arrows(x0=c(3.5,5,6.5),y0=rep(0,3),y1=rep(2.2,3),length = 0.1) # xt arrows

symbols(x=c(3.25,4.75,6.25),y=rep(4,3),circles = rep(0.15,3),inches = F,add = T) # addition of the weight matrix from x and hidden
text(x=c(3.25,4.75,6.25),y=rep(4,3),labels="+") # addition of the weight matrix from x and hidden

symbols(x=c(3.25,4.75,6.25),y=c(rep(5,3)),rectangles = matrix(c(1,0.5),nrow=3,ncol=2,byrow = T),inches = F,add = T) # linearisation
text(x=c(3.25,4.75,6.25),y=c(rep(5,3)),labels=c("sigmoid","sigmoid","tanh")) # linearisation

symbols(x=c(1.75,4.75,4.75,6.25,6.25),y=c(6,8,7,6,8),circles = rep(0.15,5),inches = F,add = T) # addition of the weight matrix from x and hidden
text(x=c(1.75,4.75,4.75,6.25,6.25),y=c(6,8,7,6,8),labels=c("x","x","1-","x","+")) # addition of the weight matrix from x and hidden

# 

# 

# # arrows(x0)
# 
segments(x0=c(1.25),x1 = c(1.25),y0 = c(8),y1 = c(1.5)) # ht-1 segment
segments(x0=c(1.25),x1 = c(4.5),y0 = c(1.5),y1 = c(1.5)) # ht-1 segment
# 

arrows(x0=c(3,4.5),y0=rep(1.5,2),y1=rep(2.2,2),length = 0.1) # ht-1 arrows
arrows(x0=c(6),y0=rep(1.25,1),y1=rep(2.2,1),length = 0.1) # ht-1 arrows
# 
arrows(x0=c(3,3.5,4.5,5,6,6.5),x1=c(3,3.5,4.5,5,6,6.5)+c(0.1,-0.1),y0=rep(2.86),y1=rep(3.85,6),length = 0.1) # before addition arrows
# 
arrows(x0=c(3.25,4.75,6.25),x1=c(3.25,4.75,6.25),y0=rep(4.3,3),y1=rep(4.75,3),length = 0.1) # before linearisation
# 
arrows(x0=c(0.5,4.9,6.4),x1=c(4.6,6.1,9.75),y0=rep(8,2),length = 0.1) # h t-1 to ht arrows
# 
arrows(x0=c(1.75),y0=rep(8,1),y1=c(6.3),length = 0.1)# after linearisation
segments(x0=c(1.75),y0=rep(5.7),y1=c(1.25))# after linearisation
segments(x0=c(1.75),x1 = 6,y0=c(1.25))# after linearisation
arrows(x0=c(4.75),y0=rep(5.25,1),y1=c(6.75),length = 0.1)# after linearisation


segments(x0=c(3.25),y0=c(5.25),y1=c(6))# after linearisation
arrows(x0=c(6.25),y0=c(5.25),y1=c(5.7),length = 0.1)# after linearisation

arrows(x0=c(4.75),y0=c(7.3),y1=c(7.7),length = 0.1) # still still after linearisation
# 
arrows(x0=c(3.25,4.75),x1=c(1.9,6.1),y0=c(6,6),length = 0.1) # still after linearisation
# 
arrows(x0=c(6.25),y0=c(6.3),y1=c(7.7),length = 0.1)# after linearisation
# 
arrows(x0=c(9.25),y0=c(8),y1=c(9.5),length = 0.1) # after X
# 
text(x=c(3.25,4.75,6.25),y=5.5,labels=c(expression('r'[t]),expression('z'[t]),expression('g'[t])),pos = 4)

$$ r_t = sigmoid(h_{t-1} * W_r + x_t * U_r) $$

$$ z_t = sigmoid(h_{t-1} * W_z + x_t * U_z) $$

$$ g_t = tanh(W_g * (h_{t-1} \cdot r_t) + x_t * U_g) $$

$$ h_t = y_t = h_{t-1} \cdot (1 - z_t) + (z_t \cdot g_t) $$

Back propagation pass

par(mar=rep(0,4))
plot(c(0,10),c(0,10),type="n",
     # bty ="n",
     xlab ="", ylab= "",xaxt="n",yaxt="n")

text(x=0,y=9.75,labels="GRU unit back propagation",pos=4,cex=2)





symbols(x=5,y=5,rectangles = matrix(8,nrow=1,ncol=2),inches = F,add = T) # unit itself
text(x=c(0,0.5,10,9.25),y=c(8,0,8,10),
     labels = c(expression('h'[t-1]),expression('x'[t]),expression('h'[t]),expression('y'[t])))

symbols(x=c(3,3.5,4.5,5,6,6.5),y=rep(2.5,6),rectangles = matrix(c(0.45,0.6),nrow=6,ncol=2,byrow = T),inches = F,add = T) # weight matrix
text(x=c(3,3.5,4.5,5,6,6.5),y=rep(2.5,6),
     labels = c(expression('W'[r]),expression('U'[r]),
                expression('W'[z]),expression('U'[z]),expression('W'[g]),expression('U'[g]))
     ) # weight matrix

arrows(x1=c(1),x0 = c(6.5),y0 = c(0),y1 = c(0),length = 0.1) # xt segment
arrows(x0=c(3.5,5,6.5),y1=rep(0,3),y0=rep(2.2,3),length = 0.1) # xt arrows

symbols(x=c(3.25,4.75,6.25),y=rep(4,3),circles = rep(0.15,3),inches = F,add = T) # addition of the weight matrix from x and hidden
text(x=c(3.25,4.75,6.25),y=rep(4,3),labels="+") # addition of the weight matrix from x and hidden

symbols(x=c(3.25,4.75,6.25),y=c(rep(5,3)),rectangles = matrix(c(1,0.5),nrow=3,ncol=2,byrow = T),inches = F,add = T) # linearisation
text(x=c(3.25,4.75,6.25),y=c(rep(5,3)),labels=c("sigmoid","sigmoid","tanh")) # linearisation

symbols(x=c(1.75,4.75,4.75,6.25,6.25),y=c(6,8,7,6,8),circles = rep(0.15,5),inches = F,add = T) # addition of the weight matrix from x and hidden
text(x=c(1.75,4.75,4.75,6.25,6.25),y=c(6,8,7,6,8),labels=c("x","x","1-","x","+")) # addition of the weight matrix from x and hidden

# # arrows(x0)
# 
arrows(x0=c(1.25),x1 = c(1.25),y1 = c(8),y0 = c(1.5),length = 0.1) # ht-1 segment
segments(x1=c(1.25),x0 = c(4.5),y0 = c(1.5),y1 = c(1.5)) # ht-1 segment
# 

arrows(x0=c(3,4.5),y1=rep(1.5,2),y0=rep(2.2,2),length = 0.1) # ht-1 arrows
arrows(x0=c(6),y1=rep(1.25,1),y0=rep(2.2,1),length = 0.1) # ht-1 arrows
# 
arrows(x1=c(3,3.5,4.5,5,6,6.5),x0=c(3,3.5,4.5,5,6,6.5)+c(0.1,-0.1),y1=rep(2.86),y0=rep(3.85,6),length = 0.1) # before addition arrows
# 
arrows(x1=c(3.25,4.75,6.25),x0=c(3.25,4.75,6.25),y1=rep(4.3,3),y0=rep(4.75,3),length = 0.1) # before linearisation
# 
arrows(x1=c(0.5,4.9,6.4),x0=c(4.6,6.1,9.75),y0=rep(8,2),length = 0.1) # h t-1 to ht arrows
# 
arrows(x0=c(1.75),y1=rep(8,1),y0=c(6.3),length = 0.1)# after linearisation
segments(x0=c(1.75),y1=rep(5.7),y0=c(1.25))# after linearisation
segments(x1=c(1.75),x0 = 6,y0=c(1.25))# after linearisation
arrows(x0=c(4.75),y1=rep(5.25,1),y0=c(6.75),length = 0.1)# after linearisation


arrows(x0=c(3.25),y1=c(5.25),y0=c(6),length = 0.1)# after linearisation
arrows(x0=c(6.25),y1=c(5.25),y0=c(5.7),length = 0.1)# after linearisation

arrows(x0=c(4.75),y1=c(7.3),y0=c(7.7),length = 0.1) # still still after linearisation
# 
arrows(x1=c(4.75),x0=c(6.1),y0=c(6),length = 0.1) # still after linearisation
segments(x1=c(3.25),x0=c(1.9),y0=c(6)) # still after linearisation
# 
arrows(x0=c(6.25),y1=c(6.3),y0=c(7.7),length = 0.1)# after linearisation
# 
arrows(x0=c(9.25),y1=c(8),y0=c(9.5),length = 0.1) # after X
# 
text(x=c(3.25,4.75,6.25),y=5.5,labels=c(expression('r'[t]),expression('z'[t]),expression('g'[t])),pos = 2)

text(x = c(9.25), y = c(9.25), labels=c(expression(delta~"1")), pos=4,cex=0.8)
text(x = c(9.75), y = c(8), labels=c(expression(delta~"2")), pos=3,cex=0.8)

text(x = c(6.25), y = c(7), labels=c(expression(delta~"3")), pos=4,cex=0.8)
text(x = c(7), y = c(8), labels=c(expression(delta~"3")), pos=3,cex=0.8)
text(x = c(5.5), y = c(8), labels=c(expression(delta~"3")), pos=3,cex=0.8)

text(x = c(4), y = c(8), labels=c(expression(delta~"4")), pos=3,cex=0.8)
text(x = c(4.75), y = c(7.55), labels=c(expression(delta~"5")), pos=4,cex=0.8)

text(x = c(4.75), y = c(6.5), labels=c(expression(delta~"6")), pos=4,cex=0.8)

text(x = c(5.75), y = c(6), labels=c(expression(delta~"7")), pos=3,cex=0.8)
text(x = c(6.25), y = c(5.5), labels=c(expression(delta~"8")),pos = 4,cex=0.8)
text(x = c(4.75), y = c(5.5), labels=c(expression(delta~"9")),pos = 4,cex=0.8)

text(x = c(6.25), y = c(4.35), labels=c(expression(delta~"10")), pos=4,cex=0.8)
text(x = c(6.25)+0.2, y = c(4.35)-1, labels=c(expression(delta~"10")), pos=4,cex=0.8)
text(x = c(4.75), y = c(4.35), labels=c(expression(delta~"11")), pos=4,cex=0.8)
text(x = c(4.75)+0.2, y = c(4.35)-1, labels=c(expression(delta~"11")), pos=4,cex=0.8)


text(x = c(6.5)-0.1, y = c(0.5), labels=c(expression(delta~"12")), pos=4,cex=0.8)
text(x = c(6)-0.1, y = c(1.85), labels=c(expression(delta~"13")), pos=4,cex=0.8)
text(x = c(5)-0.1, y = c(0.5), labels=c(expression(delta~"14")), pos=4,cex=0.8)
text(x = c(4.5)-0.1, y = c(1.85), labels=c(expression(delta~"15")), pos=4,cex=0.8)

text(x = c(1.75), y = c(4.35), labels=c(expression(delta~"13")), pos=4,cex=0.8)

text(x = c(2.75), y = c(6), labels=c(expression(delta~"16")), pos=3,cex=0.8)
text(x = c(1.75), y = c(7), labels=c(expression(delta~"17")), pos=4,cex=0.8)

text(x = c(3.25), y = c(4.35), labels=c(expression(delta~"18")), pos=4,cex=0.8)
text(x = c(3.25)+0.2, y = c(4.35)-1, labels=c(expression(delta~"18")), pos=4,cex=0.8)

text(x = 1.5, y = c(8), labels=c(expression(delta~"19")), pos=3,cex=0.8)

text(x = c(3.5)-0.1, y = c(0.5), labels=c(expression(delta~"20")), pos=4,cex=0.8)
text(x = c(3)-0.1, y = c(1.85), labels=c(expression(delta~"21")), pos=4,cex=0.8)


text(x = 2.25, y = 1.5, labels=c(expression(delta~"22")), pos=3,cex=0.8)

text(x = 0.5, y = c(8), labels=c(expression(delta~"23")), pos=3,cex=0.8)
text(x = 1, y = 0, labels=c(expression(delta~"24")), pos=3,cex=0.8)

To perform the BPTT with a GRU unit, we have the eror comming from the top layer ($\delta 1$), the future hidden states ($\delta 2$). Also, we have stored during the feed forward the states at each step of the feeding. In the case of the future layer, this error is just set to zero if not calculated yet. For convention, $\cdot$ correspond to point wise multiplication, while $*$ correspond to matrix multiplication.

The rules on how to back prpagate come from this post.

$$\delta 3 = \delta 1 + \delta 2 $$

$$\delta 4 = (1 - z_t) \cdot \delta 3 $$

$$\delta 5 = \delta 3 \cdot h_{t-1} $$

$$\delta 6 = 1 - \delta 5 $$

$$\delta 7 = \delta 3 \cdot g_t $$

$$\delta 8 = \delta 3 \cdot z_t $$

$$\delta 9 = \delta 7 + \delta 8 $$

$$\delta 10 = \delta 8 \cdot tanh'(g_t) $$

$$\delta 11 = \delta 9 \cdot sigmoid'(z_t) $$

$$\delta 12 = \delta 10 * W_g^T $$ $$\delta 13 = \delta 10 * U_g^T $$ $$\delta 14 = \delta 11 * W_z^T $$ $$\delta 15 = \delta 11 * U_z^T $$

$$\delta 16 = \delta 13 \cdot h_{t-1} $$ $$\delta 17 = \delta 13 \cdot r_t $$

$$\delta 18 = \delta 17 \cdot sigmoid'(r_t) $$

$$\delta 19 = \delta 17 + \delta 4 $$

$$\delta 20 = \delta 18 * W_r^T $$ $$\delta 21 = \delta 18 * U_r^T $$

$$\delta 22 = \delta 21 + \delta 15 $$

$$\delta 23 = \delta 19 + \delta 22 $$

$$\delta 24 = \delta 12 + \delta 14 +\delta 20 $$

The error $\delta 23$ and $\delta 24$ are used for the next layers. Once all those errors are available, it is possible to calculate the weight update.

$$\delta W_r = \delta W_f + h_{t-1}^T * \delta 10 $$ $$\delta U_r = \delta U_f + x_{t}^T * \delta 10 $$

$$\delta W_z = \delta W_i + h_{t-1}^T * \delta 11 $$ $$\delta U_z = \delta U_i + x_{t}^T * \delta 11 $$

$$\delta W_g = \delta W_g + (h_{t-1}^T \cdot r_t) * \delta 18 $$ $$\delta U_g = \delta U_g + x_{t}^T * \delta 18 $$



Try the rnn package in your browser

Any scripts or data that you put into this service are public.

rnn documentation built on April 22, 2023, 1:12 a.m.