Hacker News

valine today at 8:27 AM

So let’s start with a really simple decoder transformer with a single layer and a single attention head, and train it to predict the next token in a sequence of text. To predict the next token you need a few things: a query for the very last token in the sequence, and a key and a value for every token up to and including that last one. You take your query and compute a dot product with every key (two large vectors in, one scalar attention score out). The scores then go through a softmax, which turns them into weights that sum to 1, and you use those weights to compute a weighted average of your values. That new value goes through the MLP, and the MLP output is projected into the logits from which you sample your next token (that’s the general idea, at least; I skipped a few steps).
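Here is a minimal pure-Python sketch of that attention step for one head, assuming random vectors in place of learned query/key/value projections. The division by sqrt(d) is the standard Transformer scaling and is an addition here; the MLP and the logit projection are left out, and names such as `attend` are hypothetical.

```python
import math
import random

def dot(a, b):
    """Dot product: two vectors in, one scalar out."""
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    """Normalize raw scores into weights that sum to 1 (max subtracted for numerical stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Single-head attention of the last token's query over every key/value in the sequence."""
    d = len(query)
    # One scalar score per key, scaled by sqrt(d) (assumed standard scaling).
    scores = [dot(query, k) / math.sqrt(d) for k in keys]
    weights = softmax(scores)
    # Weighted average of the value vectors, one output component at a time.
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]

random.seed(0)
d = 4        # hypothetical head dimension
seq_len = 5  # tokens in the sequence so far, including the last one
keys = [[random.gauss(0, 1) for _ in range(d)] for _ in range(seq_len)]
values = [[random.gauss(0, 1) for _ in range(d)] for _ in range(seq_len)]
query = [random.gauss(0, 1) for _ in range(d)]  # query for the last token only

out = attend(query, keys, values)
print(out)  # in the full model this vector would go on to the MLP and the logit projection
```

Note that `keys` and `values` are the only inputs that don't depend on the current query, which is why they can be cached and reused for the next token.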

The last query in the sequence will be new for every new token you predict, but the prior keys and values stay the same, i.e. keys and values are reusable. That’s what the key-value (KV) cache stores, and it gets bigger with every token you add to the sequence, which is where compression comes in. You have to keep the keys and values in VRAM, and you’d like to keep that footprint down by not storing the raw uncompressed tensors. To make this work well your compression needs two things: it needs to be fast, so that you can compress and decompress on the fly, and it needs to play well with softmax attention. Prior attempts at compression usually suck at one or the other: either decompression is too slow and your tokens/s takes a hit, or you lose important precision and the model’s output quality suffers. The claim in the paper is that they’ve made progress on both.
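To make the growth-plus-compression idea concrete, here is a minimal sketch of an append-only KV cache that stores each key and value as 8-bit integer codes plus one scale per vector. This per-vector absmax int8 quantization is a generic stand-in chosen for illustration, not the method from the paper; the class and function names and the 2-byte scale size are assumptions.

```python
import random

def quantize(vec):
    """Per-vector absmax int8 quantization: integer codes in [-127, 127] plus one float scale."""
    scale = max(abs(x) for x in vec) / 127 or 1.0  # fall back to 1.0 for an all-zero vector
    return [round(x / scale) for x in vec], scale

def dequantize(codes, scale):
    """Recover approximate floats; per-element error is at most scale / 2."""
    return [c * scale for c in codes]

class QuantizedKVCache:
    """Hypothetical append-only cache: one compressed (key, value) pair per token."""

    def __init__(self):
        self.keys = []
        self.values = []

    def append(self, key, value):
        # Compress once, when the token's key/value is first produced.
        self.keys.append(quantize(key))
        self.values.append(quantize(value))

    def __len__(self):
        return len(self.keys)

    def get(self, i):
        # Decompress on the fly whenever attention needs entry i.
        return dequantize(*self.keys[i]), dequantize(*self.values[i])

def kv_bytes(num_tokens, d, bytes_per_elem, scale_bytes=0):
    """Storage for keys AND values (factor 2) of one head in one layer."""
    return 2 * num_tokens * (d * bytes_per_elem + scale_bytes)

random.seed(0)
d = 64  # hypothetical head dimension
cache = QuantizedKVCache()
for _ in range(10):  # the cache grows by one entry per generated token
    cache.append([random.gauss(0, 1) for _ in range(d)],
                 [random.gauss(0, 1) for _ in range(d)])
k0, v0 = cache.get(0)
print(len(cache), kv_bytes(len(cache), d, 2), kv_bytes(len(cache), d, 1, scale_bytes=2))
```

With a 2-byte scale per vector, one head's per-token footprint drops from 256 bytes in fp16 to 132 bytes, roughly half, in exchange for a rounding error of at most half a quantization step per element. A real implementation would do this on GPU tensors across every layer and head.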


Replies

edg5000 today at 8:35 AM

So limiting max context length also reduces VRAM needs a bit? If the cache is 20% of the total, capping context at 1/10th of the maximum would mean an 18% reduction in total memory.
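The arithmetic behind that estimate, as a small sketch: it assumes the cache shrinks linearly with the context limit and that everything else (weights, activations, runtime overhead) stays fixed. The function name is hypothetical.

```python
def total_memory_reduction(cache_fraction, context_fraction):
    """Fraction of total memory saved when the KV cache shrinks in proportion to the context limit.

    Assumes cache size is linear in context length and all other memory is unchanged.
    """
    return cache_fraction * (1 - context_fraction)

# 20% cache share, context capped at 1/10th: the cache drops to 2% of the original total
print(total_memory_reduction(0.20, 0.10))  # roughly 0.18, i.e. an 18% reduction
```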
