Self-Attention Mechanism

Learn about the self-attention mechanism and how to compute the embeddings of the words in a sentence.

To understand how multi-head attention works, we first need to understand the self-attention mechanism.

Self-attention mechanism

Let's understand the self-attention mechanism with an example. Consider the following sentence:

A dog ate the food because it was hungry


In the preceding sentence, the pronoun 'it' could refer to either 'dog' or 'food'. By reading the sentence, we can easily understand that 'it' refers to the 'dog' and not the 'food'. But how does our model understand that, in the given sentence, the pronoun 'it' refers to the 'dog' and not the 'food'? Here is where the self-attention mechanism helps us.

Representation of the words

In the given sentence, ‘A dog ate the food because it was hungry’, first our model computes the representation of the word ‘A’, next it computes the representation of the word ‘dog’, then it computes the representation of the word ‘ate’, and so on. While computing the representation of each word, it relates each word to all other words in the sentence to understand more about the word.



Computing the representation of the words

For instance, while computing the representation of the word 'it', our model relates the word 'it' to all the words in the sentence to understand more about the word 'it'.

As shown in the following figure, in order to compute the representation of the word 'it', our model relates the word 'it' to all the words in the sentence. By doing so, our model can understand that 'it' is related to 'dog' and not to 'food'. As we can observe, the line connecting the word 'it' to 'dog' is thicker than the other lines, which indicates that, in the given sentence, the word 'it' refers to 'dog' and not 'food':

 


Okay, but how exactly does this work? Now that we have a basic idea of what the self-attention mechanism is, let's understand it in more detail.


What is an embedding?

Suppose our input sentence (source sentence) is 'I am good'. First, we get the embeddings for each word in our sentence. Note that an embedding is just the vector representation of a word, and the values of the embeddings are learned during training.

Let x1 be the embedding of the word 'I', x2 be the embedding of the word 'am', and x3 be the embedding of the word 'good'. Consider the following:

  • the embedding of the word 'I' is x1 = [1.76, 2.22, …, 6.66]
  • the embedding of the word 'am' is x2 = [7.77, 0.63, …, 5.35]
  • the embedding of the word 'good' is x3 = [11.44, 10.10, …, 3.33]


Then, we can represent our input sentence 'I am good' using the input matrix X (the embedding matrix or input embedding), as shown here:

X = [  1.76   2.22   …   6.66
       7.77   0.63   …   5.35
      11.44  10.10   …   3.33 ]

Note: The values used in the preceding matrix are arbitrary and used here just to give us a better understanding.

From the preceding input matrix X, we can see that the first row of the matrix is the embedding of the word 'I', the second row is the embedding of the word 'am', and the third row is the embedding of the word 'good'.
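
Here is a minimal sketch (using PyTorch, purely an illustrative choice) of stacking the three example embeddings into the input matrix X. The components elided with '…' above are simply dropped, so each vector has only three entries here:

```python
import torch

# Example embedding values taken from the text (elided components dropped)
x1 = torch.tensor([1.76, 2.22, 6.66])      # embedding of the word 'I'
x2 = torch.tensor([7.77, 0.63, 5.35])      # embedding of the word 'am'
x3 = torch.tensor([11.44, 10.10, 3.33])    # embedding of the word 'good'

# Stack the word embeddings row-wise to form the input matrix X
X = torch.stack([x1, x2, x3])

print(X.shape)  # torch.Size([3, 3]) -> [sentence length, embedding dimension]
print(X[0])     # first row  = embedding of 'I'
print(X[1])     # second row = embedding of 'am'
print(X[2])     # third row  = embedding of 'good'
```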

Dimensions of the embedding matrix

The dimension of the input matrix X will be:

[sentence length × embedding dimension]


The number of words in our sentence (sentence length) is 3. Let the embedding dimension be 512; then, our input matrix (input embedding) dimension will be:

[3×512]
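
In practice, such a matrix usually comes from a learnable embedding layer. The sketch below uses PyTorch; the vocabulary size and the token ids assigned to 'I', 'am', and 'good' are made-up assumptions, since tokenization is not covered in this lesson:

```python
import torch
import torch.nn as nn

vocab_size = 10_000    # assumed vocabulary size, for illustration only
embedding_dim = 512    # embedding dimension used in the text

# The embedding table is a learnable parameter; its values are updated
# during training, as noted earlier.
embedding = nn.Embedding(vocab_size, embedding_dim)

# Assumed token ids for 'I', 'am', 'good'
token_ids = torch.tensor([0, 1, 2])

# Look up the embedding of each word to form the input matrix X
X = embedding(token_ids)
print(X.shape)  # torch.Size([3, 512]) -> [sentence length x embedding dimension]
```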

Now, from the input matrix X, we create three new matrices: a query matrix Q, a key matrix K, and a value matrix V.

Wait. What are these three new matrices? And why do we need them? They are used in the self-attention mechanism. We will see exactly how these three matrices are computed and used in the next post.
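
That said, as a rough preview of the standard formulation (the details follow in the next post), each of the three matrices is typically obtained by multiplying X with its own learned weight matrix. In the sketch below, the projection dimension of 64 and the random initialization are assumptions used only for illustration:

```python
import torch

d_model = 512   # embedding dimension of the input matrix X
d_k = 64        # assumed dimension of the query/key/value vectors

X = torch.randn(3, d_model)   # stand-in input matrix for 'I am good'

# Learned weight matrices (randomly initialized here, learned during training)
W_q = torch.randn(d_model, d_k)
W_k = torch.randn(d_model, d_k)
W_v = torch.randn(d_model, d_k)

Q = X @ W_q   # query matrix, shape [3, 64]
K = X @ W_k   # key matrix,   shape [3, 64]
V = X @ W_v   # value matrix, shape [3, 64]
print(Q.shape, K.shape, V.shape)
```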
