
Attention Is All You Need · 3 min read

Multi-head attention

A single attention operation computes one set of relationships across the sequence. But a sentence contains many different kinds of relationships simultaneously: grammatical agreement, causal links, coreference, proximity, and more. Multi-head attention runs several attention operations in parallel, each with its own learned projections, so the model can capture multiple relationship types at once.

One head vs eight heads

Left: a single attention block receives four input vectors and produces one output. Right: eight parallel attention blocks receive the same inputs and produce eight outputs, labeled 8 perspectives in parallel, which are then concatenated and projected.

In the base transformer, the model runs $h = 8$ attention heads in parallel. Each head receives the same input but applies its own learned query, key, and value projection matrices. Because the projections differ, each head attends to the sequence in a different way. The outputs of all eight heads are concatenated and passed through a single linear projection to produce the final multi-head attention output.
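The description above can be written compactly in the notation of the original transformer paper. The per-head projection symbols $W_i^Q$, $W_i^K$, and $W_i^V$ are not introduced in this lesson and are assumed here from that paper, as is the $\mathrm{Attention}$ function standing for a single attention operation:

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O,
\qquad
\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)
$$

Each $\mathrm{head}_i$ sees the same $Q$, $K$, and $V$; only the projection matrices differ from head to head.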

To keep the total computation comparable to a single full-dimension attention, each head operates in a reduced dimension $d_k = d_{\text{model}} / h = 512 / 8 = 64$. So eight heads of dimension 64 together cover the same parameter budget as one head of dimension 512.
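The parameter-budget claim can be checked with a few lines of arithmetic. This is a minimal sketch that counts only the query, key, and value projection weights and ignores biases; the variable names are illustrative and not taken from any library.

```python
d_model = 512  # model width in the base transformer
h = 8          # number of attention heads

d_k = d_model // h  # per-head dimension

# Each head has three projection matrices (query, key, value),
# each mapping d_model -> d_k. Biases are ignored in this count.
params_per_head = 3 * d_model * d_k
params_multi_head = h * params_per_head

# A single head working at full width would instead use
# three d_model x d_model matrices.
params_single_full_head = 3 * d_model * d_model

print(d_k)                                         # 64
print(params_multi_head, params_single_full_head)  # 786432 786432
```

Both counts come out equal because $h \times d_k = d_{\text{model}}$, so splitting into heads redistributes the projection weights rather than adding to them.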

What different heads learn

Three attention heads each draw a different arc from the word "it": Head 1 connects it to "animal" (coreference), Head 2 connects it to "tired" (cause), Head 3 connects it to "street" with a cross (ruled out).

Each head specializes during training without being explicitly told what to look for. In practice, different heads tend to capture different linguistic relationships. For the word “it” in the sentence “The animal did not cross the street because it was too tired”, one head may learn to link “it” back to “animal” (resolving what “it” refers to), another may link “it” to “tired” (the reason it did not cross), and another may learn that “street” is not the correct referent. No supervision directs this specialization: it emerges from the training objective alone.

Concatenation and final projection

Eight head output vectors are grouped by a bracket labeled concatenate, forming one tall vector, which passes through a linear projection box W_O to produce the multi-head attention output.

After all eight heads compute their outputs, the results are concatenated into a single long vector of dimension $h \times d_k = 8 \times 64 = 512$. This concatenated vector is then multiplied by the output projection matrix $W^O$, which has shape $512 \times 512$. The projection mixes information across all heads and produces the final multi-head attention output, which has the same shape as the original input. This shape invariance is what allows the same block design to be stacked repeatedly in the encoder and decoder.
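The whole pipeline (project, split into heads, attend, concatenate, project with $W^O$) can be sketched in NumPy. This is an illustrative sketch under several assumptions, not a reference implementation: the function and variable names are hypothetical, the weights are random rather than learned, each head is assumed to use scaled dot-product attention with a $1/\sqrt{d_k}$ scaling, and the eight per-head projection matrices are stored side by side in one $512 \times 512$ matrix per role, which is a common implementation shortcut.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, w_q, w_k, w_v, w_o, h):
    """Self-attention over x with h heads. x has shape (seq_len, d_model)."""
    seq_len, d_model = x.shape
    d_k = d_model // h

    def split_heads(t):
        # (seq_len, d_model) -> (h, seq_len, d_k)
        return t.reshape(seq_len, h, d_k).transpose(1, 0, 2)

    # Project at full width, then slice the result into h heads.
    q = split_heads(x @ w_q)
    k = split_heads(x @ w_k)
    v = split_heads(x @ w_v)

    # Scaled dot-product attention, computed independently per head.
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)  # (h, seq_len, seq_len)
    weights = softmax(scores, axis=-1)
    heads = weights @ v                               # (h, seq_len, d_k)

    # Concatenate the heads back to width h * d_k = d_model.
    concat = heads.transpose(1, 0, 2).reshape(seq_len, h * d_k)

    # The output projection mixes information across heads.
    return concat @ w_o

rng = np.random.default_rng(0)
seq_len, d_model, h = 4, 512, 8
x = rng.standard_normal((seq_len, d_model))
# Random stand-ins for the learned projection matrices.
w_q, w_k, w_v, w_o = (
    rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4)
)

out = multi_head_attention(x, w_q, w_k, w_v, w_o, h)
print(out.shape)  # (4, 512)
```

The output has the same shape as `x`, so it could be fed straight into another block of the same design.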