Introduction to Encoder-Decoder
Models and Attention Mechanism
Dr. Dileep A. D.
Associate Professor,
Multimedia Analytics Networks And Systems (MANAS) Lab,
School of Computing and Electrical Engineering (SCEE),
Indian Institute of Technology Mandi, Kamand, H.P.
Email: [email protected]
[Figure: an RNN unrolled over time, generating the sentence "Group of people shopping vegetables" one word per time step]
• Using RNN:
  P(y_t = j | y_1, ..., y_{t-1}) = softmax(V s_t + c)_j
• s_t is the state vector (hidden representation) at time step t
• Recurrent connections ensure that information about the sequence y_1, y_2, ..., y_{t-1} is embedded in s_t
• Hence, P(y_t = j | y_1, ..., y_{t-1}) = P(y_t = j | s_t)
• Using RNN:
  P(y_t = j | s_t) = softmax(V s_t + c)_j
• Recurrent connections ensure that information about the sequence y_1, y_2, ..., y_{t-1} is embedded in s_t
• The state is updated recurrently:
  s_t = tanh(U y_t + W s_{t-1} + b),   i.e.,   s_t = RNN(s_{t-1}, y_t)
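A minimal NumPy sketch of one such decoder step (a sketch only; the variable names, shapes, and use of NumPy are assumptions, not the lecture's code):

import numpy as np

def softmax(z):
    z = z - z.max()                              # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def decoder_step(y_prev, s_prev, U, W, V, b, c):
    # One RNN decoder step.
    # y_prev: embedding of the word fed at time t; s_prev: previous state s_{t-1}
    s_t = np.tanh(U @ y_prev + W @ s_prev + b)   # s_t = tanh(U y_t + W s_{t-1} + b)
    p_t = softmax(V @ s_t + c)                   # P(y_t = j | s_t) = softmax(V s_t + c)_j
    return s_t, p_t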
• Loss: L(θ) = Σ_{t=1}^{T} L_t(θ), where
  L_t(θ) = −log P(y_t = l_t | y_1, ..., y_{t-1})
  – l_t is the true word at time step t
• One can also use an LSTM or a GRU in place of the vanilla RNN
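A sketch of this loss computation, reusing decoder_step and softmax from the sketch above; embeddings holds the word vectors fed at each step and targets the indices l_t of the true words (both hypothetical names):

def sequence_loss(embeddings, targets, s0, U, W, V, b, c):
    # L(θ) = Σ_t L_t(θ), with L_t(θ) = −log P(y_t = l_t | y_1, ..., y_{t-1})
    s, loss = s0, 0.0
    for y_prev, l_t in zip(embeddings, targets):
        s, p = decoder_step(y_prev, s, U, W, V, b, c)
        loss += -np.log(p[l_t])                  # cross-entropy term for the true word l_t
    return loss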
[Figure: image captioning encoder: a CNN encodes the input image into a representation h]
[Figure: image captioning: the image code h is fed to the RNN decoder at every time step while it generates the caption word by word]
• Model:
  – Encoder: h = CNN(I)
  – Decoder: s_t = RNN(s_{t-1}, [y_t^T, h^T]^T)
    P(y_t = j | s_t, I) = softmax(V s_t + c)_j
• The model is trained using backpropagation through time and backpropagation through the CNN
• Parameters: RNN parameters U, V, W, b, c and CNN parameters W_CNN
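A hedged sketch of this captioning decoder step: the image code h is concatenated with the word embedding at every time step (the concatenation is read off the equations above; softmax is the helper from the earlier sketch):

def caption_decoder_step(y_prev, h, s_prev, U, W, V, b, c):
    # s_t = RNN(s_{t-1}, [y_t^T, h^T]^T); U acts on the concatenated input
    x_t = np.concatenate([y_prev, h])
    s_t = np.tanh(U @ x_t + W @ s_prev + b)
    p_t = softmax(V @ s_t + c)                   # P(y_t = j | s_t, I)
    return s_t, p_t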
More Applications of
Encoder-Decoder Models
• Machine Translation:
  – Translating a sentence in one language to another
  – Encoder: RNN
  – Decoder: RNN
• Transliteration:
  – Converting the script of one language to the script of another language
  – Encoder: RNN
  – Decoder: RNN
• Image Question Answering:
  – Given an image and a question (sentence), generate the answer (a word)
  – Encoder: CNN + RNN
  – Decoder: FCNN
• Document Summarization:
  – Generating a summary of a document
  – Encoder: RNN
  – Decoder: RNN
• Video Captioning:
  – Generate a sentence describing a given video
  – Encoder: RNN(CNN)
  – Decoder: RNN
• And many more …
Attention Mechanism
[Figure: encoder-decoder model generating the caption "A bird flying over ..." word by word]
• Encoder-decoder models can be made even more expressive by adding an "attention" mechanism
Attention Mechanism: Image Captioning
[Figure: decoder generating "A bird flying over ..." with the single image code h = CNN(I) fed at each time step]
• Let us revisit the decoder that we have seen so far
• The entire image is encoded into a single vector representation
• We feed this encoded representation to the decoder at each time step
• Suppose there are J concept locations (objects) in an image
• Now suppose an oracle told you which location in the image to focus on at a given time step t
[Figure: the decoder now receives, at each time step t, a context vector c_t formed from the location representations h_1, h_2, h_3 (extracted by the CNN from the image locations LC(I)) weighted by a_1t, a_2t, a_3t]
• We could then just take a weighted average of the corresponding location representations (h_j) and feed it to the decoder:
  c_t = Σ_{j=1}^{J} a_jt h_j
• Intuitively this should work better because we are not overloading the decoder with irrelevant information
• How do we convert this intuition into a model?
• In practice we will not have information about the importance of each location
  – The machine will have to learn this from the data
• The importance of the concept location representation h_j in decoding and generating the word at time t is captured by an attention score:
  α_jt = f_ATT(s_{t-1}, h_j)
• The attention score α_jt captures the importance of the j-th concept location in the image for decoding the t-th output word
• The attention score is normalized using the softmax function to obtain the attention weight:
  a_jt = exp(α_jt) / Σ_{k=1}^{J} exp(α_kt),   where J = number of concept locations in the image
• Every location representation (h_j) at every time step t is associated with one attention weight a_jt
• The attention weight (a_jt), together with the location representation (h_j), is used to generate the context vector:
  c_t = Σ_{j=1}^{J} a_jt h_j
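A sketch of the score → weight → context-vector pipeline for one decoding step; f_att stands for any scoring function f_ATT(s_{t-1}, h_j) (the dot-product and multilayer perceptron choices are defined below), and softmax is the helper from the earlier sketch:

def context_vector(s_prev, H, f_att):
    # s_prev: decoder state s_{t-1}; H: (J, m) array whose rows are h_1, ..., h_J
    alpha = np.array([f_att(s_prev, h_j) for h_j in H])   # attention scores α_jt
    a = softmax(alpha)                                    # attention weights a_jt
    return a @ H                                          # c_t = Σ_j a_jt h_j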
• How do we define f_ATT(.)?
• Dot-product attention:
  – The attention score (α_jt) is the dot product between the state of the decoder (s_{t-1}) and the location representation (h_j):
    α_jt = f_ATT(s_{t-1}, h_j) = <s_{t-1}, h_j>
  – Limitation: applicable only when s_{t-1} and h_j have the same dimension
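A one-line sketch of the dot-product score, usable as f_att in the context_vector sketch above:

def dot_attention(s_prev, h_j):
    # α_jt = <s_{t-1}, h_j>; requires s_prev and h_j to have the same dimension
    return float(np.dot(s_prev, h_j))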
• Multilayer perceptron attention:
  – It is similar to the gates used in an LSTM:
    α_jt = f_ATT(s_{t-1}, h_j) = V_ATT^T Ω(U_ATT h_j + W_ATT s_{t-1} + b_ATT)
  – U_ATT, V_ATT and W_ATT are the parameters of the multilayer perceptron attention
  – Ω(.) can be either the logistic or the tan hyperbolic function
  – s_{t-1} and h_j need not be of the same dimension
• Then the softmax operation is applied to obtain the attention weights:
  a_jt = exp(α_jt) / Σ_{k=1}^{J} exp(α_kt),   where J = number of concept locations in the image
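A sketch of the multilayer perceptron score with Ω(.) taken to be tanh; the parameter names mirror U_ATT, W_ATT, b_ATT, V_ATT and their shapes are assumptions:

def mlp_attention(s_prev, h_j, U_att, W_att, b_att, v_att):
    # α_jt = V_ATT^T Ω(U_ATT h_j + W_ATT s_{t-1} + b_ATT)
    # U_att and W_att map h_j and s_{t-1} into a common k-dimensional space; v_att has shape (k,)
    z = np.tanh(U_att @ h_j + W_att @ s_prev + b_att)
    return float(v_att @ z)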
• The attention weights are then used to generate the context vector associated with time step t
• How do we get the location information?
• The image is encoded into the representation of its last convolutional layer
• This convolutional representation retains location information
[Figure: the last convolutional feature map is split into J = 196 location vectors h_1, ..., h_196, each of dimension 512]
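A sketch of how these location vectors can be obtained; the 14x14x512 feature-map shape is assumed here (it matches the 196 locations of dimension 512 shown above):

feature_map = np.random.randn(14, 14, 512)   # stand-in for the last convolutional layer output
H = feature_map.reshape(-1, 512)             # rows are h_1, ..., h_196, each of dimension 512
print(H.shape)                               # (196, 512)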
[Figure: Attention over time. As the model generates each word, its attention changes to reflect the relevant parts of the image (white indicates the attended regions) [3]]
[3] Kelvin Xu, Jimmy Ba, Ryan Kiros, Kyunghyun Cho, Aaron Courville, Ruslan Salakhutdinov, Richard Zemel, Yoshua Bengio, "Show, Attend and Tell: Neural Image Caption Generation with Visual Attention", in Proceedings of the 32nd International Conference on Machine Learning, PMLR vol. 37, pp. 2048-2057, 2015.
[Figure: encoder-decoder model for machine translation: the input sequence X ("<Go> My name is Dileep") is encoded into a representation s, from which the decoder produces the output sequence Y]
• Encoder is an RNN
  – The encoder reads the sentence only once and encodes it in its final state, s_E,Ts
• Decoder is also an RNN
  – At each time step, the decoder uses the embedding (s_E,Ts) from the encoder to produce a new word
• Take a weighted average of the states of the encoder at different instances of time j and feed it to the decoder at time step t
[Figure: attention weights a_1t, ..., a_5t over the encoder states s_E,1, ..., s_E,5 are combined into the context vector c_t]
• Context vector: c_t = Σ_j a_jt s_E,j
• The context vector (c_t) is the weighted sum of the hidden states of the encoder
• The context vector combines all the state information of the encoder and indicates how important each state is in the combination
• This context vector is used to perform the decoding operation at time step t
• Dot-product attention:
  – The attention score (α_jt) is the dot product between the state of the decoder (s_D,t-1) and the state of the encoder at time j (s_E,j):
    α_jt = f_ATT(s_D,t-1, s_E,j) = <s_D,t-1, s_E,j>
  – Limitation: applicable only when s_D,t-1 and s_E,j have the same dimension
  – That is, the number of nodes in the hidden layers of the RNN encoder and the RNN decoder must be the same
• Multilayer perceptron attention:
  – It is similar to the gates used in an LSTM:
    α_jt = f_ATT(s_D,t-1, s_E,j) = v_ATT^T Ω(U_ATT s_E,j + W_ATT s_D,t-1 + b_ATT)
  – U_ATT, v_ATT and W_ATT are the parameters of the multilayer perceptron attention
  – s_D,t-1 and s_E,j need not be of the same dimension
• Then the softmax operation is applied to obtain the attention weights
• The attention weights are then used to generate the context vector associated with time step t
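Putting the pieces together, a sketch of one attention-based decoder step for machine translation, reusing context_vector, mlp_attention and softmax from the earlier sketches; concatenating [y_t; c_t] as the decoder input and the parameter dictionary p are illustrative assumptions, not the lecture's exact formulation:

def mt_decoder_step(y_prev, sD_prev, S_enc, p):
    # S_enc: (Ts, n_E) array of encoder states s_E,1, ..., s_E,Ts
    score = lambda s, h: mlp_attention(s, h, p["U_att"], p["W_att"], p["b_att"], p["v_att"])
    c_t = context_vector(sD_prev, S_enc, score)        # c_t = Σ_j a_jt s_E,j
    x_t = np.concatenate([y_prev, c_t])                # feed [y_t; c_t] to the decoder
    sD_t = np.tanh(p["U_D"] @ x_t + p["W_D"] @ sD_prev + p["b_D"])
    prob = softmax(p["V_D"] @ sD_t + p["c_D"])         # P(y_t = j | y_1, ..., y_{t-1}, X)
    return sD_t, prob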
Attention-based Models:
Transformers
Text Books
1. Ian Goodfellow, Yoshua Bengio and Aaron Courville, Deep Learning, MIT Press, available online: http://www.deeplearningbook.org, 2016.
2. Charu C. Aggarwal, Neural Networks and Deep Learning, Springer, 2018.
3. B. Yegnanarayana, Artificial Neural Networks, Prentice-Hall of India, 1999.
4. Satish Kumar, Neural Networks - A Classroom Approach, Second Edition, Tata McGraw-Hill, 2013.
5. S. Haykin, Neural Networks and Learning Machines, Prentice Hall of India, 2010.
6. C. M. Bishop, Pattern Recognition and Machine Learning, Springer, 2006.
7. J. Han and M. Kamber, Data Mining: Concepts and Techniques, Third Edition, Morgan Kaufmann Publishers, 2011.
8. S. Theodoridis and K. Koutroumbas, Pattern Recognition, Academic Press, 2009.