First off, let's discuss why the attention mechanism has become so prevalent across the deep learning space.
The main problem that self-attention solves is the long-range dependency problem.
Before the transformer and attention, the RNN was the most common architecture for language modelling, but it ran into trouble with long-range context: the words at the beginning of a long sentence are "forgotten" as the model processes more and more words. This happens because of the vanishing gradient problem.
Mathematically, we can see where this problem comes from.
Given an input sentence $x = (x_1, x_2, \dots, x_T)$, the model sequentially processes one word at each step:

$$h_t = \sigma(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$$

where

$t$ represents the step

$\sigma$ represents the activation function

$x_t$ is the single word that is processed at step $t$

You can think of $h_t$, the hidden state, as being the memory of the model.
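To make the recurrence concrete, here is a minimal NumPy sketch of this forward pass; the sizes, tanh activation, and random initialization are illustrative assumptions rather than anything prescribed by a specific model.

```python
import numpy as np

# The RNN recurrence h_t = sigma(W_hh @ h_{t-1} + W_xh @ x_t + b_h).
d_hidden, d_input, seq_len = 8, 8, 5
rng = np.random.default_rng(0)

W_hh = rng.normal(0, 1 / np.sqrt(d_hidden), (d_hidden, d_hidden))
W_xh = rng.normal(0, 1 / np.sqrt(d_input), (d_hidden, d_input))
b_h = np.zeros(d_hidden)

x = rng.normal(size=(seq_len, d_input))  # one embedded "word" per time step
h = np.zeros(d_hidden)                   # h_0: the model's "memory" starts empty

for t in range(seq_len):                 # strictly sequential: step t needs h_{t-1}
    h = np.tanh(W_hh @ h + W_xh @ x[t] + b_h)
```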
The vanishing problem arises as we backpropagate through the network in order to compute the gradient of the loss $L$ with respect to earlier hidden states:

$$\frac{\partial L}{\partial h_1} = \frac{\partial L}{\partial h_T} \prod_{t=2}^{T} \frac{\partial h_t}{\partial h_{t-1}}$$

It specifically lies here, in the Jacobian:

$$\frac{\partial h_t}{\partial h_{t-1}} = \operatorname{diag}\!\left(\sigma'(z_t)\right) W_{hh}$$

Where $z_t = W_{hh} h_{t-1} + W_{xh} x_t + b_h$ is the pre-activation neuron computation.
If either $\|W_{hh}\| < 1$ or the derivatives $\sigma'(z_t)$ are small, multiplying these Jacobians together over many time steps drives the overall product toward zero.

These are the root causes of vanishing gradients, and $W_{hh}$ and $\sigma'(z_t)$ are commonly susceptible to this because of the math which creates them.
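A toy experiment makes the effect visible: multiplying the per-step Jacobians together for many steps (with assumed small random weights and tanh-style derivatives) drives the product's norm toward zero.

```python
import numpy as np

# Toy illustration (assumed setup): accumulate the per-step Jacobians
# diag(sigma'(z_t)) @ W_hh backwards through time and watch the product shrink.
d, steps = 8, 50
rng = np.random.default_rng(1)
W_hh = rng.normal(0, 0.3 / np.sqrt(d), (d, d))      # small recurrent weights

jacobian_product = np.eye(d)
for t in range(steps):
    z_t = rng.normal(size=d)                        # stand-in pre-activations
    d_sigma = 1.0 - np.tanh(z_t) ** 2               # tanh'(z_t), always <= 1
    jacobian_product = np.diag(d_sigma) @ W_hh @ jacobian_product
    if t % 10 == 0:
        print(t, np.linalg.norm(jacobian_product))  # norm decays toward 0
```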
For $W_{hh}$, initialization schemes like He initialization and regularization techniques like L2 keep the matrix values small to prevent exploding gradients.

And the activation functions used for $\sigma$ (such as tanh, whose derivative is at most 1) are designed to prevent that same problem.
The fundamental problem with RNNs comes from the foundation they are built on. Modifying the initialization scheme or the activation functions would simply lead to other, larger problems (like exploding gradients, where gradient values grow far too large) and wouldn't fix anything.
Self-attention largely solves this by weighting each word with a score based on its importance.
An intuitive way to understand self-attention is to think about our own human attention: when you read a book, you don't memorize every single word you read, only the important parts: the plot, character names, personalities, etc.
Self-attention is a way to mathematically model the importance of words in a similar way to how humans do.
Formally, self-attention is denoted as:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V, \qquad Q = XW_Q,\; K = XW_K,\; V = XW_V$$
Where:
$X$ being the encoded embedding input sequence

$d_k$ = dimension size of the model (usually 512)

$W_Q, W_K, W_V$ are the projection matrices that project the input onto the 3 spaces (Query, Key, Value). They are usually initialized with He initialization, $W \sim \mathcal{N}\!\left(0, \tfrac{2}{n_{\text{in}}}\right)$ (where $n_{\text{in}}$ is the number of input neurons).
The numerator $QK^T$ essentially computes the attention scores; we say that this numerator is matching the Query representation of each token against the Key representations of the other tokens.

We then scale by $\sqrt{d_k}$ to ensure the values are not too large, preventing cases where only a few words are attended to (focused on, paid attention to).

The softmax is used to turn the attention scores into a probability distribution, ensuring each row adds up to 1.

At this point we have our attention scores for the sequence, but we must multiply them by $V$ so we can actually apply these attention scores to the words.
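Putting the formula together, here is a minimal NumPy sketch of single-head self-attention; the sequence length, dimensions, and random initialization are illustrative assumptions.

```python
import numpy as np

def softmax(scores, axis=-1):
    scores = scores - scores.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(scores)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Scaled dot-product self-attention over one sequence X of shape (n, d_model)."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])   # (n, n): every token scores every token
    weights = softmax(scores, axis=-1)        # each row sums to 1
    return weights @ V                        # apply the attention weights to the values

# Example: 4 tokens, model dimension 512 (sizes are illustrative)
n, d_model = 4, 512
rng = np.random.default_rng(2)
X = rng.normal(size=(n, d_model))
W_q, W_k, W_v = (rng.normal(0, np.sqrt(2 / d_model), (d_model, d_model)) for _ in range(3))
out = self_attention(X, W_q, W_k, W_v)        # shape (4, 512)
```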
Self-attention differs from RNNs in the sense that it is not sequential: instead of processing word by word, the entire input sequence is given to the mechanism at once. Self-attention recovers word order from positional encodings added to the input.
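As a concrete example, here is a sketch of the sinusoidal positional encoding from the original Transformer paper, which is simply added to the token embeddings before attention is applied.

```python
import numpy as np

def positional_encoding(n, d_model):
    """Sinusoidal positional encodings: one d_model-sized vector per position."""
    pos = np.arange(n)[:, None]                       # token positions 0..n-1
    i = np.arange(d_model)[None, :]                   # embedding dimensions
    angles = pos / np.power(10000, (2 * (i // 2)) / d_model)
    pe = np.zeros((n, d_model))
    pe[:, 0::2] = np.sin(angles[:, 0::2])             # even dimensions: sine
    pe[:, 1::2] = np.cos(angles[:, 1::2])             # odd dimensions: cosine
    return pe

# The encodings are added to the embedded input before attention:
# X_with_order = X + positional_encoding(n, d_model)
```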
Think of each matrix as learning this:

$Q$ (Query) learns and understands the question "What am I looking for in this token?" Each row of $Q$ represents what information the current token (word) is seeking.

For example, let's say the current token is "run". The row of $Q$ containing the representation of "run" might contain an encoded representation asking "What is running? Who is running?"

$K$ (Keys) learns the question "What information can this token provide?" Each row represents information about the token.

If we continue with the current token of "run", the row of $K$ containing the token "run" might say "This word is an action, something that a thing can do."

$V$ (Values) says "What information do I actually have?" Each row contains the actual information the token represents.
Think of the entire process of computing attention scores: initializing the projection matrices, creating the projected inputs $Q$, $K$, $V$, and then finally computing the attention scores using the attention formula; all of this together is one head.
And so, as the name suggests, multi-head attention is the combination of multiple of these attention heads.

A multi-head attention block contains $h$ heads. This means we initialize $h$ different sets of projection matrices ($W_Q^i, W_K^i, W_V^i$) and compute $h$ unique sets of attention scores.
For multi-head attention, each head computes attention in its own smaller subspace of dimension $d_k = d_{\text{model}} / h$:

$$\text{head}_i = \text{Attention}(XW_Q^i,\ XW_K^i,\ XW_V^i)$$

The output is still the same size as the input, because the $h$ heads are concatenated back together:

$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\,W_O$$

where $W_O$ is a learned output matrix of the same shape ($d_{\text{model}} \times d_{\text{model}}$) that we multiply by for the final output.
By using multiple unique attention heads, we can compute a more accurate and better attention score. Think about it this way: with each head having a different set of random projection matrices, each head will learn something different about the tokens and focus on a different aspect of each token in its own unique way. When they all come together through concatenation, each of those unique angles is added to the way the final attention is computed.
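Here is a minimal NumPy sketch of the whole multi-head computation; the number of heads, the sizes, and the initialization are illustrative assumptions.

```python
import numpy as np

def scaled_dot_product(Q, K, V):
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)       # row-wise softmax
    return weights @ V

def multi_head_attention(X, heads, W_o):
    # each head projects X into its own smaller Q, K, V space and attends independently
    outputs = [scaled_dot_product(X @ W_q, X @ W_k, X @ W_v) for (W_q, W_k, W_v) in heads]
    return np.concatenate(outputs, axis=-1) @ W_o        # concat heads, then mix with W_o

n, d_model, h = 4, 512, 8
d_head = d_model // h                                    # each head works in a smaller subspace
rng = np.random.default_rng(3)
X = rng.normal(size=(n, d_model))
heads = [tuple(rng.normal(0, np.sqrt(2 / d_model), (d_model, d_head)) for _ in range(3))
         for _ in range(h)]
W_o = rng.normal(0, np.sqrt(2 / d_model), (d_model, d_model))
out = multi_head_attention(X, heads, W_o)                # shape (4, 512), same as the input
```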
To compare attention with the RNN, let:

- $n$ - sequence length
- $d$ - hidden dimension size
The time complexity for RNNs is:

Training - $O(n \cdot d^2)$

Inference - $O(n \cdot d^2)$

because of the hidden state computation:

$$h_t = \sigma(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$$

For 1 time step ($t$) we must compute $W_{hh} h_{t-1}$, and since $W_{hh}$ has shape $(d \times d)$, we compute $O(d^2)$ operations; over the $n$ steps of the sequence this gives $O(n \cdot d^2)$.
Time complexity of self-attention:
Training - $O(n^2 \cdot d)$

Inference - $O(n^2 \cdot d)$

The $n^2$ comes from the fact that we compute $QK^T$, which has shape $(n \times n)$, given that both $Q$ and $K$ have shape $(n \times d)$. We compute $O(n^2 \cdot d)$ because each token (word) needs an attention score with respect to all other tokens.
Though the time complexities appear similar in theory, in practice the fact that RNNs are sequential, while attention allows for parallelization, makes attention much faster.
Parallelization means that while an RNN must take 100 sequential steps for an input sequence of 100 words no matter what, an attention mechanism can compute all of its matrix multiplications and attention scores at the same time.
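As a rough illustration (not a rigorous benchmark), the following sketch contrasts the RNN's 100 dependent steps with the single matrix multiplication that produces all of the attention scores at once; the sizes and random data are assumptions.

```python
import time
import numpy as np

n, d = 100, 512
rng = np.random.default_rng(4)
X = rng.normal(size=(n, d))
W_hh = rng.normal(0, 1 / np.sqrt(d), (d, d))
W_xh = rng.normal(0, 1 / np.sqrt(d), (d, d))

start = time.perf_counter()
h = np.zeros(d)
for t in range(n):                       # 100 dependent steps, one per word
    h = np.tanh(W_hh @ h + W_xh @ X[t])
rnn_time = time.perf_counter() - start

start = time.perf_counter()
scores = X @ X.T / np.sqrt(d)            # all n*n attention scores in one matmul
attn_time = time.perf_counter() - start

print(f"RNN loop: {rnn_time:.4f}s, attention scores: {attn_time:.4f}s")
```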