Attention (machine learning)

From Wikipedia, the free encyclopedia

In neural networks, attention is a technique that mimics cognitive attention. It enhances some parts of the input data while diminishing others, so that the network devotes more focus to the small but important parts of the data. Which parts matter most depends on the context, and this dependence is learned by gradient descent.

Attention-like mechanisms were introduced in the 1990s under names such as multiplicative modules, sigma pi units, and hypernetworks.[1] The flexibility of attention comes from its role as "soft weights" that can change at runtime, in contrast to standard weights, which must remain fixed at runtime. Uses of attention include memory in Neural Turing Machines, reasoning tasks in Differentiable Neural Computers,[2] language processing in Transformers, and multi-sensory data processing (sound, images, video, text) in Perceivers.[3][4][5][6]

A language translation example

To build a machine that translates English to French (see diagram below), one starts with an Encoder-Decoder and grafts an attention unit onto it. In the simplest case, such as the example below, the attention unit consists only of dot products of recurrent layer states and does not need training. In practice, the attention unit consists of 3 fully connected neural network layers that need to be trained. The 3 layers are called Query, Key, and Value.

Encoder-Decoder with attention. This diagram uses specific values to reduce the clutter of an otherwise letter-heavy notation. The left part (in black) is the Encoder-Decoder, the middle part (in orange) is the attention unit, and the right part (in grey & colors) is the computed data. Grey regions in the H matrix and the w vector are zero values. Numerical subscripts are examples of vector sizes. Lettered subscripts i and i-1 indicate the time step.
Legend
label | description
100 | max sentence length
300 | embedding size (word dimension)
500 | length of hidden vector
9k, 10k | dictionary size of input & output languages, respectively
x, Y | 9k and 10k 1-hot dictionary vectors. The mapping from the 1-hot input vector to the embedding x is implemented as a lookup table rather than as a vector multiplication. Y is the 1-hot maximizer of the linear Decoder layer D; that is, it takes the argmax of D's linear layer output.
x | 300-long word embedding vector. The vectors are usually pre-calculated from other projects such as GloVe or Word2Vec.
h | 500-long encoder hidden vector. At each point in time, this vector summarizes all the words that precede it. The final h can be viewed as a "sentence" vector, or a thought vector as Hinton calls it.
s | 500-long decoder hidden state vector.
E | 500-neuron RNN encoder with 500 outputs. The input count is 800: 300 from the source embedding + 500 from the recurrent connections. The encoder feeds directly into the decoder only to initialize it, not thereafter; hence, that direct connection is shown very faintly.
D | 2-layer decoder. The recurrent layer has 500 neurons and the fully connected linear layer has 10k neurons (the size of the target vocabulary).[7] The linear layer alone has 5 million (500 * 10k) weights, about 10 times more than the recurrent layer.
score | 100-long alignment score
w | 100-long attention weight vector. These are "soft" weights, which change during the forward pass, in contrast to "hard" neuronal weights, which change during the learning phase.
A | Attention module. This can be a dot product of recurrent states or the Query-Key-Value fully connected layers. The output is a 100-long vector w.
H | 500x100 matrix: 100 hidden vectors h concatenated into a matrix.
c | 500-long context vector = H * w. c is a linear combination of the h vectors weighted by w.
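
The score, w, and c rows of the legend can be sketched for a single decoder step as follows. This is a minimal illustration of the untrained dot-product case, not the article's reference implementation: the function names, the random stand-in values, and the choice to mask unused positions with negative infinity (so that they receive exactly zero weight, matching the grey regions of the diagram) are all assumptions.

```python
import numpy as np

def softmax(v):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(v - np.max(v))
    return e / e.sum()

def attention_step(H, s, n_words):
    """One decoder step of dot-product attention (hypothetical helper).

    H: (500, 100) encoder hidden vectors, one per column.
    s: (500,) decoder hidden state.
    n_words: number of real source words; later positions are masked (assumption).
    """
    score = H.T @ s                # (100,) alignment scores, one dot product per position
    score[n_words:] = -np.inf      # unused positions get exactly zero weight
    w = softmax(score)             # (100,) attention weights, summing to 1
    c = H @ w                      # (500,) context vector: weighted sum of the h vectors
    return w, c

rng = np.random.default_rng(0)
H = np.zeros((500, 100))
H[:, :3] = rng.standard_normal((500, 3))   # stand-ins for the 3 words of "I love you"
s = rng.standard_normal(500)               # stand-in decoder state

w, c = attention_step(H, s, n_words=3)
print(w.shape, c.shape)
```

Because the weights outside the first three positions are zero, c depends only on the three hidden vectors that correspond to real words.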

This table shows the calculations at each time step. For clarity, it uses specific numerical values and shapes rather than letters. The nested shapes depict the summarizing nature of h, where each h contains a history of the words that came before it. Here, the attention scores were hand-picked to produce the desired attention weights.

step | x | h, H = encoder output (500x1 vectors represented as shapes) | s = decoder input to Attention | alignment score | w = attention weight = softmax(score) | c = context vector = H*w | y = decoder output
1 | I | red diamond = vector encoding for "I" | - | - | - | - | -
2 | love | green square = vector encoding for "I love" | - | - | - | - | -
3 | you | blue circle = vector encoding for "I love you" | - | - | - | - | -
4 | - | - | blue circle (the decoder state does not exist yet, so we use the encoder output h3 to kick off the decoder) | [.63 -3.2 -2.5 .5 .5 ...] | [.94 .02 .04 0 0 ...] | .94 * red diamond + .02 * green square + .04 * blue circle | je
5 | - | - | s4 | [-1.5 -3.9 .57 .5 .5 ...] | [.11 .01 .88 0 0 ...] | .11 * red diamond + .01 * green square + .88 * blue circle | t'
6 | - | - | s5 | [-2.8 .64 -3.2 .5 .5 ...] | [.03 .95 .02 0 0 ...] | .03 * red diamond + .95 * green square + .02 * blue circle | aime
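
The attention weights in the table can be checked against the alignment scores with a few lines of Python. The sketch below assumes that the softmax is taken over the first three positions only (the three real source words), since the table shows zero weight for every later position; the dictionary and function names are illustrative.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(x - m) for x in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Alignment scores for the three real source words ("I", "love", "you"),
# copied from decoder steps 4, 5, and 6 of the table.
scores_by_output = {
    "je":   [0.63, -3.2, -2.5],
    "t'":   [-1.5, -3.9, 0.57],
    "aime": [-2.8, 0.64, -3.2],
}

weights_by_output = {
    word: [round(p, 2) for p in softmax(scores)]
    for word, scores in scores_by_output.items()
}

for word, weights in weights_by_output.items():
    print(word, weights)
# je [0.94, 0.02, 0.04]
# t' [0.11, 0.01, 0.88]
# aime [0.03, 0.95, 0.02]
```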

Viewed as a matrix, the attention weights show how the network adjusts its focus according to context.

output | I | love | you
je | .94 | .02 | .04
t' | .11 | .01 | .88
aime | .03 | .95 | .02

This view of the attention weights addresses the "explainability" problem for which neural networks are criticized. Networks that perform verbatim translation without regard to word order would have a diagonally dominant matrix if they were analyzable in these terms. The off-diagonal dominance shows that the attention mechanism is more nuanced. On the first pass through the decoder, 94% of the attention weight is on the first English word "I", so the network offers the word "je". On the second pass of the decoder, 88% of the attention weight is on the third English word "you", so it offers "t'". On the last pass, 95% of the attention weight is on the second English word "love", so it offers "aime".
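
Reading the matrix this way amounts to taking, for each output word, the source word with the largest weight. A small sketch of that reading, with illustrative helper names that are not part of the article:

```python
source_words = ["I", "love", "you"]

# Attention weights from the matrix above, one row per output word.
attention = {
    "je":   [0.94, 0.02, 0.04],
    "t'":   [0.11, 0.01, 0.88],
    "aime": [0.03, 0.95, 0.02],
}

def argmax(values):
    """Index of the largest entry in a list."""
    return max(range(len(values)), key=values.__getitem__)

# For each output word, the source word that receives the most attention.
alignment = {out: source_words[argmax(w)] for out, w in attention.items()}

# A word-for-word translation in source order would put every maximum on the diagonal.
diagonal = all(argmax(w) == i for i, w in enumerate(attention.values()))

print(alignment)   # {'je': 'I', "t'": 'you', 'aime': 'love'}
print(diagonal)    # False
```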

Variants

There are many variants of attention: dot product, query-key-value, hard, soft, self, cross, Luong, and Bahdanau, to name a few. These variants recombine the encoder-side inputs to redistribute those effects to each target output. Often, a correlation-style matrix of dot products provides the re-weighting coefficients (see legend).
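
As a rough sketch of such a correlation-style matrix, the code below forms all pairwise dot products between encoder-side and decoder-side vectors and applies a column-wise softmax. Passing the same matrix for both arguments gives the encoder-only (self-attention) form. The function names, shapes, and random inputs are assumptions made for illustration.

```python
import numpy as np

def column_softmax(M):
    """Softmax applied independently to each column of M."""
    e = np.exp(M - M.max(axis=0, keepdims=True))
    return e / e.sum(axis=0, keepdims=True)

def dot_product_attention(H, S):
    """Dot-product attention over whole sentences (hypothetical helper).

    H: (d, n_src) encoder-side vectors, one word per column.
    S: (d, n_tgt) decoder-side vectors, one word per column.
    Returns W of shape (n_src, n_tgt) and the context vectors C of shape (d, n_tgt).
    """
    W = column_softmax(H.T @ S)   # W[i, j] is the weight of h_i for target position j
    C = H @ W                     # each column recombines the encoder-side inputs
    return W, C

rng = np.random.default_rng(0)
H = rng.standard_normal((4, 3))   # 3 source words, 4-dimensional states
S = rng.standard_normal((4, 2))   # 2 target positions

W, C = dot_product_attention(H, S)        # encoder-decoder form
W_self, C_self = dot_product_attention(H, H)   # encoder-only form: one input only
print(W.shape, C.shape, W_self.shape)
```

Each column of W sums to 1, so every target position receives a weighted average of the encoder-side vectors.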

  1. encoder-decoder dot product: Both Encoder & Decoder are needed to calculate Attention.[8]
  2. encoder-decoder QKV: Both Encoder & Decoder are needed to calculate Attention.[9]
  3. encoder-only dot product: Decoder is NOT used to calculate Attention. With only 1 input into corr, W is an auto-correlation of dot products. wij = xi * xj [10]
  4. encoder-only QKV: Decoder is NOT used to calculate Attention.[11]
  5. PyTorch tutorial: An FC layer is used to calculate Attention instead of dot product correlation.[12]
Legend
label | description
variables X, H, S, T | Upper case variables represent the entire sentence, not just the current word. For example, H is a matrix of the encoder hidden state, one word per column.
S, T | S = decoder hidden state, T = target word embedding. In the training phase of the PyTorch tutorial variant, T alternates between 2 sources depending on the level of teacher forcing used. T could be the embedding of the network's output word, i.e. embedding(argmax(FC output)). Alternatively, with teacher forcing, T could be the embedding of the known correct word, which can occur with a constant forcing probability, say 1/2.
X, H | H = encoder hidden state, X = input word embeddings
W | attention coefficients
Qw, Kw, Vw, FC | weight matrices for query, key, and value, respectively. FC is a fully connected weight matrix.
circled +, circled x | circled + = vector concatenation; circled x = matrix multiplication
corr | column-wise softmax(matrix of all combinations of dot products). The dot products are xi * xj in variant 3, hi * sj in variant 1, column i (Kw*H) * column j (Qw*S) in variant 2, and column i (Kw*X) * column j (Qw*X) in variant 4. Variant 5 uses a fully connected layer to determine the coefficients. If the variant is QKV, then the dot products are normalized by sqrt(d), where d is the height of the QKV matrices.
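
The QKV variants described by corr can be sketched as below. This is an illustrative reading of the legend, not a definitive implementation: the weight matrices are random rather than trained, the dimensions are arbitrary, and the function and argument names are assumptions. Passing the encoder states H and decoder states S corresponds to the encoder-decoder form; passing X for both corresponds to the encoder-only form.

```python
import numpy as np

def column_softmax(M):
    """Softmax applied independently to each column of M."""
    e = np.exp(M - M.max(axis=0, keepdims=True))
    return e / e.sum(axis=0, keepdims=True)

def qkv_attention(keys_in, queries_in, Qw, Kw, Vw):
    """Query-Key-Value attention with dot products normalized by sqrt(d).

    keys_in:    (d_in, n_k) vectors that are attended over (H or X), one per column.
    queries_in: (d_in, n_q) vectors that ask the question (S or X), one per column.
    Qw, Kw, Vw: weight matrices for query, key, and value.
    """
    Q = Qw @ queries_in                        # (d, n_q)
    K = Kw @ keys_in                           # (d, n_k)
    V = Vw @ keys_in                           # (d_v, n_k)
    d = Q.shape[0]                             # height of the Q and K matrices
    W = column_softmax(K.T @ Q / np.sqrt(d))   # (n_k, n_q) attention coefficients
    return W, V @ W                            # output has shape (d_v, n_q)

rng = np.random.default_rng(0)
d_in, d = 6, 4
H = rng.standard_normal((d_in, 3))   # 3 encoder hidden states
S = rng.standard_normal((d_in, 2))   # 2 decoder hidden states
Qw, Kw, Vw = (rng.standard_normal((d, d_in)) for _ in range(3))

W_cross, out_cross = qkv_attention(H, S, Qw, Kw, Vw)   # encoder-decoder QKV
W_self, out_self = qkv_attention(H, H, Qw, Kw, Vw)     # encoder-only QKV
print(W_cross.shape, out_cross.shape, W_self.shape, out_self.shape)
```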

See also

References

  1. Yann LeCun (2020). Deep Learning course at NYU, Spring 2020, video lecture Week 6. Event occurs at 53:00. Retrieved 2021-12-13.
  2. "Hybrid computing using a neural network with dynamic external memory". Nature. 538 (7626): 471–476. 2016-10-12. Bibcode:2016Natur.538..471G. doi:10.1038/nature20101. ISSN 1476-4687. PMID 27732574. S2CID 205251479.
  3. Vaswani, Ashish; Shazeer, Noam; Parmar, Niki; Uszkoreit, Jakob; Jones, Llion; Gomez, Aidan N.; Kaiser, Lukasz; Polosukhin, Illia (2017-12-05). "Attention Is All You Need". arXiv:1706.03762 [cs.CL].
  4. Ramachandran, Prajit; Parmar, Niki; Vaswani, Ashish; Bello, Irwan; Levskaya, Anselm; Shlens, Jonathon (2019-06-13). "Stand-Alone Self-Attention in Vision Models". arXiv:1906.05909 [cs.CV].
  5. Jaegle, Andrew; Gimeno, Felix; Brock, Andrew; Zisserman, Andrew; Vinyals, Oriol; Carreira, Joao (2021-06-22). "Perceiver: General Perception with Iterative Attention". arXiv:2103.03206 [cs.CV].
  6. Ray, Tiernan. "Google's Supermodel: DeepMind Perceiver is a step on the road to an AI machine that could process anything and everything". ZDNet. Retrieved 2021-08-19.
  7. "Pytorch.org seq2seq tutorial". Retrieved 2021-12-02.
  8. Luong, Minh-Thang (2015-09-20). "Effective Approaches to Attention-based Neural Machine Translation". arXiv:1508.04025v5.
  9. Neil Rhodes (2021). CS 152 NN—27: Attention: Keys, Queries, & Values. Event occurs at 06:30. Retrieved 2021-12-22.
  10. Alfredo Canziani & Yann LeCun (2021). NYU Deep Learning course, Spring 2020. Event occurs at 05:30. Retrieved 2021-12-22.
  11. Alfredo Canziani & Yann LeCun (2021). NYU Deep Learning course, Spring 2020. Event occurs at 20:15. Retrieved 2021-12-22.
  12. Robertson, Sean. "NLP From Scratch: Translation With a Sequence To Sequence Network and Attention". pytorch.org. Retrieved 2021-12-22.

External links