0. goal
to get a feel for what mechanistic interpretability is like by intuitively understanding QK and OV* circuits.
[github] | [demo site] | [video]
ps: * i write "OV circuit" because mechint usually deals with OV circuits, but in the tiny example i use, the matrix O is not present, so technically it's just V. many such hacks exist. if you're a pedant, check section 3.
1. motivation
recently, i worked with a professor on saliency maps for interpretable vision transformers. the idea of models that not only answer the question of what, but also shed light on the question of how, appealed to me. from a human perspective, knowing the how is what separates us from other species (ig; i'm not an expert on species, but to the best of my knowledge other animals don't know how digestion works scientifically, yet they DO eat). apart from this, artificial intelligence has become a widely used but misunderstood and feared black box. i feel this field of work can go a long way toward alleviating those concerns.
my trip down the rabbit hole of explainable models eventually led me to Neel Nanda and his work on mechanistic interpretability (it's a mouthful, so i'll call it mechint from now on). the almost neurosurgical analysis of models, and particularly the mapping of attention heads to human-understandable algorithms/subtasks, got me hooked on this topic.
i started with the Anthropic paper A Mathematical Framework for Transformer Circuits, which talks about QK and OV circuits, and paired it with Neel's video on the same paper. the video did an outstanding job of explaining the paper (i need to interrupt here and mention i still haven't gone through the entirety of the paper/video), but i had an itch to scratch.
mechint can be viewed as reverse engineering a trained model to see which attention head performs which subtask. what if i did the reverse of that? you may ask, isn't that just training a model? yes, but with a twist. i would define a model whose attention layers each perform a subtask that i assign, and i would set the weights using nothing but pen and paper. hence, the pen is mightier than the GPU.
2. designing the architecture
broadly, i need to decide the following things about the architecture:
- how to convert tokens (17+25) to embeddings (including the value of $d_{\text{model}}$)
- how many attention layers to stack on top of one another
- how to convert the final output into a proper sum
if you're observant, you'll notice there is no mention of the ffn or layer norm, which are usually present in a standard transformer. there are 2 reasons: firstly, they would increase the number of weights i have to manage by hand. secondly, from a mechint pov, i want to understand how individual attention layers can be dedicated to individual subtasks, so having the attention layers do the whole task on their own aligns better with my end goal.
for the tokens-to-embedding procedure, i settled on $d_{\text{model}} = 3$. why? i see three distinct pieces of information that need storage: the value of the digit itself, some information about the position, and whether there is a carry or not. for position, i used $-1.0$ for tokens at even positions and $+1.0$ for tokens at odd positions (one-indexing of positions). the digit is copied as is, and the dimension designated for carry is initially set to $0.0$ for all embeddings. concretely (0-indexed), column 0 is the carry slot, which later serves as the units register, column 1 holds the digit, and column 2 holds the position.
for layers, i chose 2. why? adding two two-digit numbers the way we were taught in elementary school has 2 distinct subtasks plus a 3rd linking task: first we add the units digits to get the final result's units digit and propagate the carry (if any), then we add the tens digits (combined with the carry) to get the final result's tens digit. this approach overlooks what happens when the tens-place addition itself produces a carry; we will see later how that case is handled.
finally, there's the question of the readout: how do we get a human-readable number out of these vectors? in a real transformer, you'd have a massive unembedding matrix that maps the final hidden state to a vocabulary of tokens. since i'm doing this by hand and didn't want to write out a huge matrix, i used a bit of a "python hack": instead of a proper linear layer, i just reach into the last token's vector, grab the sums we've accumulated, and use the standard + and * operators in python to reconstruct the number. it's technically cheating by transformer-purist standards, but it isolates the actual logic of the attention layers without getting bogged down in the readout. my original goal was to see whether attention layers could solve the units-place and tens-place subtasks separately, so if i can show that the layers do those subtasks properly, there's no harm in using the + operator to get to the final result quickly.
another nuance here is the sequence itself. i decided to omit the "+" sign entirely to keep things lean. the model doesn't need a symbol to know it's adding; the architecture itself is the calculator. the sequence is just four digits followed by a special "eos" (end of sentence) token. this eos token is the most important part of the whole setup: it acts like a blank bucket or a sponge. while the digit tokens just sit there holding their values, the eos token is designed to reach back into the past using attention and "accumulate" the sums from the other tokens into its own embedding. by the time we hit the last layer, the eos token's vector is the only thing we actually look at to get our answer.
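here is a minimal python sketch of the embedding step for $17+25$, assuming numpy and the column layout used in section 4 (column 0 = empty carry/units register, column 1 = digit, column 2 = position sign). the helper name `embed` is hypothetical, not taken from the forgeformer repo.

```python
import numpy as np

def embed(a: int, b: int) -> np.ndarray:
    """Hypothetical helper: embed a + b (both two-digit) as five 3-d token vectors.

    Token order: [a_tens, a_units, b_tens, b_units, eos]; there is no "+" token.
    Columns: 0 = carry / units register (starts at 0.0),
             1 = digit value (eos gets 0.0),
             2 = position sign, +1.0 at odd and -1.0 at even one-indexed positions.
    """
    digits = [a // 10, a % 10, b // 10, b % 10, 0]
    rows = []
    for position, digit in enumerate(digits, start=1):
        sign = 1.0 if position % 2 == 1 else -1.0
        rows.append([0.0, float(digit), sign])
    return np.array(rows)

X = embed(17, 25)
print(X)
```

the units digits 7 and 5 end up at even positions (sign $-1$), while the tens digits and the eos token share sign $+1$; that single column is all the attention layers below need to tell them apart.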
3. nuances for the pedantic
before we dive into the math, i gotta admit i took some shortcuts. first, the embedding and unembedding (readout) are hardcoded hacks; in a real model, these would be learned weights. second, i didn't use the $\frac{1}{\sqrt{d_k}}$ scaling factor, because with $d_k = 3$ it would only divide the scores by about $1.7$, which doesn't change anything here. third, i ignored layer norm and mlp blocks entirely, because i wanted to see the raw power of the attention heads. fourth, there is no output matrix $W_O$, so what i call the ov circuit is really just the value matrix $V$. also, my softmax might look a bit aggressive with those large weights, but that's deliberate: it makes the attention "binary" and clear for our human brains to follow. if you're a transformer purist, please don't sue me.
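to check the claim that dropping $\frac{1}{\sqrt{d_k}}$ is harmless, here is a small numpy sketch (an illustration, not code from the repo). it feeds the layer-1 scores of the eos token as worked out in section 4 ($+100$ for the two units digits, $-100$ for the tens digits and eos itself) through softmax with and without the scaling.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

# eos query vs. keys for [1, 7, 2, 5, eos] in layer 1
scores = np.array([-100.0, 100.0, -100.0, 100.0, -100.0])
d_k = 3

unscaled = softmax(scores)
scaled = softmax(scores / np.sqrt(d_k))  # scores become roughly +/-57.7

print(np.round(unscaled, 6))  # both are essentially [0, 0.5, 0, 0.5, 0]
print(np.round(scaled, 6))
```

with gaps this large, both versions put essentially all the weight on the units digits, so the scaling would only matter if the hand-picked weights were much smaller.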
4. pen is mightier than the gpu - figuring out the weights
this is where the magic happens. we are going to play god with these matrices. i needed two layers because addition is a two-step process in my head: handle the units, then handle the tens. let's look at how we force the model to behave using two analogies: the "man on the hill" (qk circuit) and the "librarian" (v circuit).
layer 1: the units circuit
for layer 1, our goal is to get the eos token to "look" only at the units digits (7 and 5 in our 17+25 example).
the man on the hill (qk circuit):
imagine the eos token is a man standing on a hill with a megaphone (the query, q). he's shouting, "i'm looking for anyone with a position value of -1!" down in the valley, all the other tokens have signs (the keys, k).
to make this happen, i set

$$W_{Q1} = \begin{pmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & 10 \end{pmatrix}, \qquad W_{K1} = \begin{pmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & -10 \end{pmatrix}$$

recall the attention formula $\text{attn} = \text{softmax}(QK^T)V$, with $Q = XW_Q$, $K = XW_K$, $V = XW_V$, where $X$ stacks the token embeddings as rows. [with $d_k = 3$, the usual division by $\sqrt{d_k}$ would only shrink the scores by a factor of about $1.7$, so i leave it out; see section 3.]
why these specific spots? because column 2 is where we stored our position data ($-1$ for units, $+1$ for tens/eos). by putting the $10$ and $-10$ at index $[2,2]$ (0-indexed) of these matrices, we isolate that position info: every other entry is zero, so the digit and carry columns have no effect on the scores.
when the megaphone (q) meets the sign (k), they multiply. the eos token's own position is $+1$, so its query is $10 \times 1 = 10$. for the units digits, the score is $10 \times (-10) \times (-1) = +100$ (a massive match!). for the tens digits, and for the eos token itself, it's $10 \times (-10) \times 1 = -100$ (a massive rejection). the softmax turns these scores into probabilities, so the eos token's attention is split almost exactly $0.5$/$0.5$ between the two units digits, with practically $0.0$ on everything else. i used $10$ because it's large enough to make the softmax "hard": it's not a suggestion; it's a command.
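the scores and the resulting attention pattern can be reproduced with a few lines of numpy. this is a sketch assuming the row-vector convention $Q = XW_Q$, $K = XW_K$ (the one under which $W_{V1}$ moves column 1 into column 0) and the column layout above; it only computes the eos token's row.

```python
import numpy as np

# 17+25 as tokens [1, 7, 2, 5, eos]
# columns (0-indexed): 0 = units register, 1 = digit, 2 = position sign
X = np.array([
    [0.0, 1.0, 1.0],
    [0.0, 7.0, -1.0],
    [0.0, 2.0, 1.0],
    [0.0, 5.0, -1.0],
    [0.0, 0.0, 1.0],   # eos
])

W_Q1 = np.zeros((3, 3)); W_Q1[2, 2] = 10.0
W_K1 = np.zeros((3, 3)); W_K1[2, 2] = -10.0

q_eos = X[-1] @ W_Q1          # the megaphone: [0, 0, 10]
K = X @ W_K1                  # the signs: [0, 0, -10 * position] per token
scores = K @ q_eos            # no 1/sqrt(d_k) scaling, as in the post
weights = np.exp(scores - scores.max())
weights /= weights.sum()      # softmax over the five tokens

print(scores)                 # +100 for the units digits 7 and 5, -100 elsewhere
print(np.round(weights, 3))   # about 0.5 on each units digit, ~0 elsewhere
```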
the librarian (v circuit):
now that the man on the hill has found his targets, the librarian (the value matrix, v) needs to decide what to do with their info.
$$W_{V1} = \begin{pmatrix} 0 & 0 & 0 \\ 2.0 & 0 & 0 \\ 0 & 0 & 0 \end{pmatrix}$$

the $2.0$ is at $[1,0]$. why? column 1 of our input is where the raw digit lives, and column 0 is our empty "units register." the librarian is essentially saying: "take whatever is in the digit slot (col 1) and move it to the units register (col 0)."
but why $2.0$ and not $1.0$? remember, attention is an averaging machine. if we attend to two tokens (7 and 5), the model does $(7 + 5) / 2 = 6$. we don't want the average; we want the sum! so we multiply by $2.0$ to counteract that division. boom. the units register in the eos token now holds a $12$.
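a sketch of the librarian's job in layer 1, again in numpy under the same assumptions (row vectors, no $1/\sqrt{d_k}$, no $W_O$). `eos_head_output` is a hypothetical helper that runs one attention head at the eos position only; the loop compares a gain of $1.0$ against $2.0$ in $W_{V1}[1,0]$.

```python
import numpy as np

# 17+25 as [1, 7, 2, 5, eos]; columns: 0 = units register, 1 = digit, 2 = position sign
X = np.array([
    [0.0, 1.0, 1.0],
    [0.0, 7.0, -1.0],
    [0.0, 2.0, 1.0],
    [0.0, 5.0, -1.0],
    [0.0, 0.0, 1.0],   # eos
])

W_Q1 = np.zeros((3, 3)); W_Q1[2, 2] = 10.0
W_K1 = np.zeros((3, 3)); W_K1[2, 2] = -10.0

def eos_head_output(X, W_Q, W_K, W_V):
    """Hypothetical helper: output of one attention head at the eos (last) position."""
    scores = (X[-1] @ W_Q) @ (X @ W_K).T   # eos query vs. every key
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                # softmax
    return weights @ (X @ W_V)              # weighted average of the value vectors

units_register = {}
for gain in (1.0, 2.0):
    W_V = np.zeros((3, 3)); W_V[1, 0] = gain  # digit (col 1) -> units register (col 0)
    units_register[gain] = float(eos_head_output(X, W_Q1, W_K1, W_V)[0])

print(units_register)  # gain 1.0 gives the average (6), gain 2.0 gives the sum (12)

# the hand-set W_V1 uses gain 2.0; the residual stream adds the head output to eos
W_V1 = np.zeros((3, 3)); W_V1[1, 0] = 2.0
eos_after_layer1 = X[-1] + eos_head_output(X, W_Q1, W_K1, W_V1)  # ~[12, 0, 1]
```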
layer 2: the tens circuit
now we move to the tens digits (1 and 2). the residual stream has carried our $12$ forward, and now layer 2 kicks in.
the man on the hill (qk circuit):
this time, the man on the hill is shouting for a different crowd: "i want everyone with a position of $+1$!"
$$W_{Q2} = \begin{pmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & 10 \end{pmatrix}, \qquad W_{K2} = \begin{pmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & 10 \end{pmatrix}$$

notice $W_{K2}$ changed from $-10$ to $10$. now a $+1$ position (the tens) gives $10 \times 10 \times 1 = +100$, while the units digits get $10 \times 10 \times (-1) = -100$. the tens tokens are now the stars of the show.
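the same kind of numpy sketch for the layer-2 scores. it assumes the residual stream after layer 1, where the eos token's column 0 now holds 12; the other tokens' column 0 is left at $0.0$ as a simplification. since $W_{Q2}$ and $W_{K2}$ only read column 2, neither choice affects these scores.

```python
import numpy as np

# residual stream entering layer 2 for 17+25 (simplified: only eos column 0 updated)
X = np.array([
    [0.0, 1.0, 1.0],
    [0.0, 7.0, -1.0],
    [0.0, 2.0, 1.0],
    [0.0, 5.0, -1.0],
    [12.0, 0.0, 1.0],   # eos, already holding the units sum
])

W_Q2 = np.zeros((3, 3)); W_Q2[2, 2] = 10.0
W_K2 = np.zeros((3, 3)); W_K2[2, 2] = 10.0   # sign flipped vs. layer 1

scores = (X[-1] @ W_Q2) @ (X @ W_K2).T       # +100 for position +1, -100 for position -1
weights = np.exp(scores - scores.max())
weights /= weights.sum()                      # softmax
print(np.round(weights, 3))  # roughly 1/3 each on the two tens digits and on eos itself
```

the three-way split on the last line is exactly what forces the $3.0$ gain in the value matrix of this layer.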
the librarian (v circuit):
the librarian for layer 2 has a different job. we set
$$W_{V2} = \begin{pmatrix} 0 & 0 & 0 \\ 0 & 3.0 & 0 \\ 0 & 0 & 0 \end{pmatrix}$$

this time, the $3.0$ sits at $[1,1]$: the librarian reads column 1 (the digits) and keeps the result in column 1 (the tens register).
but why $3.0$? this is the subtle part. in this layer, the eos token has a position of $+1$, so it actually attends to itself along with the two tens tokens. the attention is split three ways ($\frac{1}{3}$ each). the eos token's digit slot is $0$, while the tens are $1$ and $2$. the average is $(1 + 2 + 0) / 3 = 1.0$. to get back to our sum of $3$, we have to multiply by $3.0$.
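continuing in numpy under the same assumptions (row vectors, no $1/\sqrt{d_k}$, no $W_O$, the simplified layer-2 input with $12$ in the eos token's column 0), this sketch applies $W_{V2}$ and the residual connection at the eos position.

```python
import numpy as np

X = np.array([
    [0.0, 1.0, 1.0],
    [0.0, 7.0, -1.0],
    [0.0, 2.0, 1.0],
    [0.0, 5.0, -1.0],
    [12.0, 0.0, 1.0],   # eos after layer 1
])

W_Q2 = np.zeros((3, 3)); W_Q2[2, 2] = 10.0
W_K2 = np.zeros((3, 3)); W_K2[2, 2] = 10.0
W_V2 = np.zeros((3, 3)); W_V2[1, 1] = 3.0   # digit (col 1) stays in col 1, scaled by 3

scores = (X[-1] @ W_Q2) @ (X @ W_K2).T
weights = np.exp(scores - scores.max())
weights /= weights.sum()                    # ~1/3 on the tens digits and on eos

head_out = weights @ (X @ W_V2)             # column 1: (3*1 + 3*2 + 3*0) / 3 = 3
eos_final = X[-1] + head_out                # residual connection
print(eos_final)                            # approximately [12, 3, 1]
```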
by the end of this, our librarian has shelved everything perfectly: in the eos token, column 0 holds the units sum and column 1 holds the tens sum. no training, no gpus, just a very bossy man on a hill and an organized librarian.
5. getting the sum out of the forgeformer brain
now the eos token is "pregnant" with information. its final embedding vector is (up to negligible softmax leakage) $(12.0,\ 3.0,\ 1.0)$:
- column 0 ($12.0$) is our units sum.
- column 1 ($3.0$) is our tens sum.
- column 2 ($1.0$) is just the leftover position marker.
to get the final number, we just apply standard primary school carry logic.
is $12 \geq 10$? yes.
so carry $= 1$, and units becomes $2$.
then we just do: $(\text{tens\_sum} + \text{carry}) \times 10 + \text{units}$.
$(3 + 1) \times 10 + 2 = 42$.
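the readout "python hack" fits in a few lines. `readout` is a hypothetical name; it rounds the two registers to integers first (to absorb any tiny floating-point leakage from the softmax) and then applies the carry logic above, with `divmod` standing in for the "$\geq 10$?" check since a units sum is at most $18$.

```python
def readout(eos_vec):
    """Hypothetical readout: turn the eos vector [units_sum, tens_sum, position] into an int."""
    units_sum = int(round(float(eos_vec[0])))
    tens_sum = int(round(float(eos_vec[1])))
    carry, units = divmod(units_sum, 10)   # e.g. 12 -> carry 1, units 2
    return (tens_sum + carry) * 10 + units

print(readout([12.0, 3.0, 1.0]))   # 42
print(readout([18.0, 18.0, 1.0]))  # 99 + 99 = 198
```

because $(\text{tens\_sum} + \text{carry}) \times 10$ is plain python arithmetic, a carry into the hundreds place comes out right for free, which is exactly the kind of cheat the attention layers never have to handle.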
it's alive!
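to check that it's not a fluke of $17+25$, here is a self-contained numpy sketch of the whole pipeline with the hand-set weights, run over every pair of two-digit numbers. assumptions: row-vector convention; each head applied to every position without a causal mask (eos is the last token, so a mask would not change its row); no $1/\sqrt{d_k}$, no $W_O$, no ffn or layer norm; the python readout at the end. all helper names are hypothetical.

```python
import numpy as np

def embed(a, b):
    # columns: 0 = carry / units register, 1 = digit, 2 = position sign
    digits = [a // 10, a % 10, b // 10, b % 10, 0]            # last token is eos
    signs = [1.0 if i % 2 == 1 else -1.0 for i in range(1, 6)]  # one-indexed parity
    return np.array([[0.0, float(d), s] for d, s in zip(digits, signs)])

def head(X, W_Q, W_K, W_V):
    """One attention head at every position: row-wise softmax(Q K^T) V."""
    scores = (X @ W_Q) @ (X @ W_K).T
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return w @ (X @ W_V)

def single_entry(i, j, value):
    M = np.zeros((3, 3))
    M[i, j] = value
    return M

W_Q1, W_K1, W_V1 = single_entry(2, 2, 10.0), single_entry(2, 2, -10.0), single_entry(1, 0, 2.0)
W_Q2, W_K2, W_V2 = single_entry(2, 2, 10.0), single_entry(2, 2, 10.0), single_entry(1, 1, 3.0)

def forgeformer(a, b):
    X = embed(a, b)
    X = X + head(X, W_Q1, W_K1, W_V1)   # layer 1 (units) + residual
    X = X + head(X, W_Q2, W_K2, W_V2)   # layer 2 (tens) + residual
    units_sum = int(round(float(X[-1, 0])))
    tens_sum = int(round(float(X[-1, 1])))
    carry, units = divmod(units_sum, 10)
    return (tens_sum + carry) * 10 + units

mismatches = [(a, b) for a in range(10, 100) for b in range(10, 100)
              if forgeformer(a, b) != a + b]
print(forgeformer(17, 25), mismatches)  # expect 42 and an empty list
```

applying the heads to every position (not just eos) is closer to a real transformer; it is harmless here because layer 1 only writes column 0, while layer 2 only reads columns 1 and 2.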
6. takeaways
the coolest thing about this isn't that a computer added two numbers; my calculator from 1998 can do that. it's that we can see exactly how the "handshake" between queries and keys creates a specific algorithm. it's like watching the model's brain look first at the units place and then at the tens place.
the qk circuit is the "search engine" that decides who to talk to based on position or value. the v circuit (or ov circuit) is the retrieval mechanism that decides what information to actually carry forward. by setting these manually, we showed that, at least at this toy scale, transformers aren't random magic; they are really fast, high-dimensional spreadsheet workers.
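one way to see the "search engine" directly is to collapse the query and key matrices into a single bilinear form, in the spirit of the QK circuit from the Anthropic paper. under the row-vector convention used for the hand-set weights in section 4, the score between a query token $x_q$ and a key token $x_k$ is:

$$\text{score}(x_q, x_k) = (x_q W_Q)(x_k W_K)^\top = x_q \,\big(W_Q W_K^\top\big)\, x_k^\top$$

$$W_{Q1} W_{K1}^\top = \begin{pmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & -100 \end{pmatrix} \;\Longrightarrow\; \text{score}_1 = -100\, p_q p_k, \qquad W_{Q2} W_{K2}^\top = \begin{pmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & 100 \end{pmatrix} \;\Longrightarrow\; \text{score}_2 = 100\, p_q p_k$$

here $p_q$ and $p_k$ are the position entries (column 2). with the eos query at $p_q = +1$, layer 1 rewards $p_k = -1$ (the units) and layer 2 rewards $p_k = +1$ (the tens and eos itself); the digit and carry columns never enter the score. on the other side, with no $W_O$, the ov circuit is just $W_{V1}$ (digit to column 0, gain 2) and $W_{V2}$ (digit to column 1, gain 3).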
next time you see a massive model like gpt-4, just remember that somewhere inside those billions of parameters, there are probably tiny circuits just like this one, looking for a specific "handshake" to solve a subtask. turns out, the pen really is mightier than the gpu when it comes to understanding.
7. future tasks
- Properly implementing the carry into the hundreds place; right now the python readout cheats.
- Adding FFNs and Layer Norms (if needed, to see intuitively what they contribute)
- Generalizing an approach for $n$-digit $+$ $m$-digit, $n \neq m$ (using computer-trained models, of course)
- Carrying over our intuitions about what the attention heads actually do and checking them against open-source models
- Training a model for the same task as Forgeformer using PyTorch and comparing its weights with mine.
- Doing all this for other trivial tasks