We Need Context: From Last Word to Weighted Sum
Recap: The Last-Word-Only Model
In Module 2, we built a simple next-word predictor. It takes the embedding of the last word, multiplies it by a weight matrix (the LM head), and produces probabilities:
embedding = get_embedding(last_word)
logits = embedding @ lm_head
probs = softmax(logits)
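Here's a minimal runnable sketch of this baseline, assuming a toy vocabulary and randomly initialized embeddings and LM head (the names vocab, emb_table, and d_model are illustrative, not the ones from Module 2):

import numpy as np

rng = np.random.default_rng(0)
vocab = ["the", "big", "cat", "sat", "on"]           # toy vocabulary (assumed)
d_model = 4                                          # embedding size (assumed)
emb_table = rng.normal(size=(len(vocab), d_model))   # one row per word
lm_head = rng.normal(size=(d_model, len(vocab)))     # maps embedding -> vocab scores

def get_embedding(word):
    return emb_table[vocab.index(word)]

def softmax(x):
    e = np.exp(x - x.max())                          # subtract max for numerical stability
    return e / e.sum()

embedding = get_embedding("the")                     # only the last word is used
logits = embedding @ lm_head
probs = softmax(logits)                              # probabilities over the toy vocabulary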
This works for simple cases, but it ignores everything before the last word. Consider:
- "the big cat sat on the [?]" → model only sees "the"
- "the small dog ran to the [?]" → model also only sees "the"
Both inputs look identical to the model. It can't distinguish between them because it never sees "big cat" versus "small dog".
First Attempt: Average All Embeddings
The simplest fix: instead of using just the last word's embedding, average all the embeddings together.
context = (embedding("the")
+ embedding("big")
+ embedding("cat")
+ embedding("sat")) / 4
Now feed context to the LM head instead of just the last embedding:
logits = context @ lm_head
probs = softmax(logits)
This is better! "the big cat sat" and "the small dog ran" now produce different context vectors because they contain different words.
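Here's a self-contained sketch of the averaging idea, under the same toy assumptions as before (random embeddings, made-up helper names):

import numpy as np

rng = np.random.default_rng(0)
vocab = ["the", "big", "cat", "sat", "small", "dog", "ran", "to"]
emb_table = rng.normal(size=(len(vocab), 4))
lm_head = rng.normal(size=(4, len(vocab)))

def embedding(word):
    return emb_table[vocab.index(word)]

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def avg_context(words):
    # stack the embeddings and take the element-wise mean
    return np.mean([embedding(w) for w in words], axis=0)

c1 = avg_context(["the", "big", "cat", "sat"])
c2 = avg_context(["the", "small", "dog", "ran"])
print(np.allclose(c1, c2))            # False: different words give different contexts
probs = softmax(c1 @ lm_head)         # feed the averaged context to the LM head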
Why Averaging Isn't Enough
Simple averaging has two major problems:
All words treated equally. In "the big cat sat", the word "the" gets the same weight as "cat". But "cat" is clearly more important for predicting what comes next. Function words like "the" and "on" shouldn't contribute as much as content words like "cat" and "big".
Word order is ignored. Consider:
- "dog bites man"
- "man bites dog"
These have completely different meanings, but their average embeddings are identical - same words, just in different order.
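You can verify this in a few lines (toy random embeddings again; any embedding table would show the same thing, since averaging is permutation-invariant):

import numpy as np

rng = np.random.default_rng(0)
vocab = ["dog", "bites", "man"]
emb_table = rng.normal(size=(len(vocab), 4))

def embedding(word):
    return emb_table[vocab.index(word)]

avg1 = np.mean([embedding(w) for w in ["dog", "bites", "man"]], axis=0)
avg2 = np.mean([embedding(w) for w in ["man", "bites", "dog"]], axis=0)
print(np.allclose(avg1, avg2))        # True: averaging throws away word order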
We'll address word order later with positional encodings. For now, let's focus on the first problem: not all words should matter equally.
Weighted Sum: Some Words Matter More
Instead of equal averaging, let each word have its own importance weight:
context = w_the * embedding("the")
+ w_big * embedding("big")
+ w_cat * embedding("cat")
+ w_sat * embedding("sat")
For example, if we're predicting after "the big cat sat on the", we might want:
context = 0.05 * embedding("the")
+ 0.35 * embedding("big")
+ 0.45 * embedding("cat")
+ 0.10 * embedding("sat")
+ 0.03 * embedding("on")
+ 0.02 * embedding("the")
Now "cat" and "big" dominate the context, while "the" and "on" barely contribute.
Requirements for weights:
- Non-negative (negative weights don't make sense for mixing)
- Sum to 1 (think of them as percentages of attention)
This is exactly what we want! The model can focus on the words that matter most for the current prediction.
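One simple way to satisfy both requirements: start from any non-negative relevance scores and divide by their sum. The scores below are made up purely for illustration; this is not yet the mechanism the model will actually use.

import numpy as np

scores = np.array([0.5, 3.5, 4.5, 1.0, 0.3, 0.2])   # made-up non-negative relevance scores
weights = scores / scores.sum()                     # non-negative and summing to 1
print(weights, weights.sum())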
The Missing Piece: Where Do Weights Come From?
We've established that weighted sums are the right idea. But we hand-picked those weights (0.35 for "big", 0.45 for "cat", etc.). That won't work in practice.
The core question: How does the model automatically decide which words deserve high weights?
We need weights that:
- Depend on the current position (what word we're predicting from)
- Depend on each candidate word (what it offers)
- Are learned, not hand-coded
The next article shows how the model transforms embeddings into "queries" and "keys" that compute these weights automatically. Queries ask "what am I looking for?" and keys answer "what do I offer?" - their interaction determines the attention weights.
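As a preview, here is a minimal sketch of that dot-product idea with random stand-in vectors; the full formulation (how queries and keys are produced from embeddings, plus scaling) is covered in the next article:

import numpy as np

rng = np.random.default_rng(0)
d = 4
query = rng.normal(size=d)               # "what am I looking for?" (current position)
keys = rng.normal(size=(6, d))           # "what do I offer?" (one key per word)

scores = keys @ query                    # one relevance score per word
weights = np.exp(scores - scores.max())  # softmax: non-negative...
weights = weights / weights.sum()        # ...and summing to 1
print(weights)                           # attention-style weights (random here, learned in practice)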