This design implicitly does something similar to what I sometimes think conventional transformers should try: allowing later layers to query the KV data from earlier layers. As far as I can tell, in a conventional transformer, if a later (and presumably higher-level-thinking) layer wants to take input, at earlier token positions, from something computed lower down, it has to pull that out of the lower layers' outputs and “remember” it by itself instead of just reading it directly.
But suppose an extra attention head were added that queried the KV data from lower layers. At the very least, I imagine this might cleanly solve the STRAWBERRY problem: whatever layer has figured out that the prompt wants to count instances of R could attend to lower layers that actually perceive those Rs.
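Here's a rough sketch of the kind of extra head I'm imagining, in pseudo-PyTorch. Everything here (module names, shapes, the single-head setup, ignoring masking and positional details) is my own illustrative assumption, not anything from the paper:

```python
# Sketch: a later block gets one extra attention head whose K/V come from an
# earlier layer's activations instead of its own input.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossLayerHead(nn.Module):
    def __init__(self, d_model: int, d_head: int):
        super().__init__()
        self.q = nn.Linear(d_model, d_head, bias=False)   # queries from the *current* layer
        self.k = nn.Linear(d_model, d_head, bias=False)   # keys/values from an *earlier* layer
        self.v = nn.Linear(d_model, d_head, bias=False)
        self.out = nn.Linear(d_head, d_model, bias=False)

    def forward(self, x_late: torch.Tensor, x_early: torch.Tensor) -> torch.Tensor:
        # x_late:  (batch, seq, d_model) activations entering the later block
        # x_early: (batch, seq, d_model) cached activations from a lower layer
        q = self.q(x_late)
        k = self.k(x_early)
        v = self.v(x_early)
        att = F.softmax(q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5, dim=-1)
        return x_late + self.out(att @ v)   # residual add; causal mask omitted for brevity
```

Usage would be: stash the activations when running layer i, then pass them into this head at some layer j > i.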
> extra attention head were added that queried the KV data from lower layers
Isn't this sort of similar to latent looping? E.g. [1]. But as [2] argues, even that wasn't a great experiment, because it fed back the very last hidden state, which is too close to the logits and loses most of the rich embedding structure. Perhaps you don't even need access to anything except the penultimate hidden layer: based on my vague reading of [3], the residual stream doesn't "lose information" as it passes through the deeper attention blocks, so each block may be manipulating a different subspace of the residual stream. (Rough sketch of what I mean below, after the links.)
[1] https://arxiv.org/abs/2412.06769
[2] https://snimu.github.io/2025/03/30/multi-layer-language-head...
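For concreteness, here's a very loose sketch of how I picture the latent-looping feedback step, with [2]'s tweak of feeding back an earlier hidden state. The embed/blocks interface and the exact feedback point are my own assumptions for illustration, not the papers' actual code:

```python
# Loose sketch of latent looping as I read [1], modified per [2] to feed back
# the penultimate layer's state rather than the final (logit-shaped) one.
import torch

def latent_step(embed, blocks, token_ids: torch.Tensor, latent=None) -> torch.Tensor:
    # token_ids: (batch, seq); latent: (batch, 1, d_model) or None
    h = embed(token_ids)                          # (batch, seq, d_model)
    if latent is not None:
        h = torch.cat([h, latent], dim=1)         # fed-back state acts as an extra "token"
    states = []
    for block in blocks:
        h = block(h)
        states.append(h)
    # [1] would feed back states[-1][:, -1:, :]; [2] suggests the last layer is
    # too close to the logits, so take the penultimate layer's state instead.
    return states[-2][:, -1:, :]
```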
This architecture does not allow later layers to directly query KV data from earlier layers. Each iteration of the loop uses the same layer parameters, so the KV data in later iterations may well end up being the same as in earlier ones, but only if the model stops changing it in response to other tokens in the context. That is also something a traditional multi-layer transformer could do (though it might not end up doing so, for lack of the corresponding inductive bias).
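To make the distinction concrete, a looped (weight-tied) forward pass looks roughly like this; note that nothing here reads an earlier iteration's K/V cache directly (illustrative pseudo-PyTorch, not the paper's code):

```python
# A looped model reuses one block's parameters across iterations, so iteration
# t+1 recomputes K/V from its own input; earlier iterations' K/V are not
# queryable unless the residual stream happens to carry them forward.
import torch

def looped_forward(block, x: torch.Tensor, n_loops: int) -> torch.Tensor:
    h = x
    for _ in range(n_loops):
        h = block(h)   # same parameters every iteration; K/V recomputed from h
    return h
```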
None of this helps with the strawberry problem, where the very first layer already gets a tokenized representation, so there is no layer that "actually perceives those Rs."
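For example (assuming the tiktoken package; the exact splits depend on the tokenizer), the model's input for "strawberry" is a handful of multi-letter subword IDs, not individual letters:

```python
# Why no layer "sees" the Rs: the input is already subword tokens.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
ids = enc.encode("strawberry")
print(ids, [enc.decode([i]) for i in ids])  # multi-letter chunks, not letters
```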