Hacker News

hirako2000 today at 2:30 PM

The article's claims assume far more compute and far more VRAM. While the trick reduces the back and forth, it doesn't eliminate it.

I doubt you meant 50M. Did you mean 50B?

You can only give it a try, but don't get your hopes high on a large context. Even if their technique works, I would guess an 8192-token context would still OOM. 2048 might fit.

I'm extrapolating from my own experiments, which did not use this paper's trick of leveraging system memory.


Replies

kouteiheika today at 2:53 PM

> You can only give it a try, but don't get your hopes high on a large context.

You may or may not know this, but when training off-the-shelf LLMs (i.e. ones with a huge vocabulary), calculating the cross-entropy loss consumes a huge amount of memory, because the logits hold one value per vocabulary entry for every token in the batch. It gets worse the more tokens you stuff into your batch, so always use a fused cross-entropy kernel.
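To illustrate why a fused kernel can avoid the big intermediate, here is a minimal pure-Python sketch of the general idea as I understand it: the log-sum-exp in the loss can be accumulated chunk by chunk over the vocabulary, so only a running maximum and a running sum need to be kept. The function names and the toy logits are hypothetical, and this is not the code of any particular kernel; a real one would also compute each chunk of logits on the fly instead of slicing an existing list.

```python
import math

def cross_entropy_full(logits, target):
    """Reference version: builds a vocabulary-sized list of exponentials."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # O(vocab) extra memory
    return m + math.log(sum(exps)) - logits[target]

def cross_entropy_streaming(logit_chunks, target):
    """Streaming version: keeps only a running max and a running sum.

    logit_chunks yields non-empty chunks of the logits in vocabulary order.
    """
    running_max = -math.inf
    running_sum = 0.0
    target_logit = None
    offset = 0
    for chunk in logit_chunks:
        new_max = max(running_max, max(chunk))
        # Rescale the old sum to the new maximum, then add this chunk.
        running_sum = running_sum * math.exp(running_max - new_max)
        running_sum += sum(math.exp(x - new_max) for x in chunk)
        running_max = new_max
        # Remember the target's logit when its chunk passes by.
        if offset <= target < offset + len(chunk):
            target_logit = chunk[target - offset]
        offset += len(chunk)
    return running_max + math.log(running_sum) - target_logit

def chunked(values, size):
    """Hypothetical helper: yield consecutive slices of length `size`."""
    for start in range(0, len(values), size):
        yield values[start:start + size]

# Toy example: 6 "vocabulary entries", processed 2 at a time.
toy_logits = [0.5, -1.2, 3.0, 2.2, -0.7, 1.1]
loss_full = cross_entropy_full(toy_logits, target=3)
loss_streaming = cross_entropy_streaming(chunked(toy_logits, 2), target=3)
```

Both functions return the same loss up to floating-point rounding; the streaming one simply never holds a vocabulary-sized intermediate.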

For example, for a Gemma 2 model with 2B parameters at a batch size of 8k tokens, this consumes 24GB of VRAM by default (!). Fusing the cross-entropy loss with @torch.compile can cut that down to something like a few gigabytes, and with a dedicated kernel it drops to a few megabytes.
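A rough back-of-the-envelope check of that figure, under assumptions that are mine and not stated in the comment: a vocabulary of 256,000 entries for Gemma 2, "8k" meaning 8192 tokens, fp32 logits, and three logits-sized tensors alive at once (the logits, the softmax output, and the gradient). The helper name is hypothetical.

```python
GIB = 1024 ** 3

def logits_bytes(num_tokens, vocab_size, bytes_per_element=4):
    """Size of one (tokens x vocab) tensor; 4 bytes per element assumes fp32."""
    return num_tokens * vocab_size * bytes_per_element

num_tokens = 8 * 1024   # assumed meaning of "8k"
vocab_size = 256_000    # assumed Gemma 2 vocabulary size

one_copy_gib = logits_bytes(num_tokens, vocab_size) / GIB
three_copies_gib = 3 * one_copy_gib  # assumed: logits + softmax + gradient

print(f"one logits-sized tensor: {one_copy_gib} GiB")      # 7.8125 GiB
print(f"three such tensors:      {three_copies_gib} GiB")  # 23.4375 GiB
```

Under those assumptions the total lands close to the quoted 24GB. The memory also scales linearly with the number of tokens in the batch, which is why it gets worse as the batch grows.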
