Hacker News

kouteiheika today at 2:53 PM

> You can only give it a try, but don't get your hopes high on a large context.

You may or may not know this, but: when training off-the-shelf LLMs (i.e. ones with a huge vocabulary), what consumes a huge amount of memory is calculating the cross-entropy loss, and it gets worse the more tokens you stuff in your batch. So always use a fused cross-entropy kernel.
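A back-of-envelope calculation shows why the loss dominates: the logits tensor is [tokens, vocab], and with gradients plus a softmax intermediate you hold roughly three copies of it. (The vocabulary and batch sizes below are assumptions matching Gemma 2's ~256k vocab and the 8k-token batch mentioned further down.)

```python
# Back-of-envelope: memory for the cross-entropy step, assumed numbers.
VOCAB = 256_128       # assumed: Gemma 2 vocabulary size
TOKENS = 8_192        # assumed: tokens per batch
BYTES_FP32 = 4

logits_bytes = TOKENS * VOCAB * BYTES_FP32   # one [tokens, vocab] tensor
total_bytes = 3 * logits_bytes               # logits + grads + softmax copy

print(logits_bytes / 2**30)  # ~7.8 GiB for the logits alone
print(total_bytes / 2**30)   # ~23.4 GiB -- in the ballpark of the 24 GB quoted
```

The per-token logit row is the whole vocabulary, so batch size multiplies directly into this; a fused kernel avoids ever materializing the full [tokens, vocab] tensor.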

For example, for a Gemma 2 model with 2B parameters at a batch size of 8k tokens, this consumes 24GB of VRAM by default (!). You can fuse your cross-entropy loss with @torch.compile, which can cut the memory usage down to a few gigabytes, but with a dedicated kernel it becomes a few megabytes.
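The core trick behind these kernels is to never materialize the full [tokens, vocab] logits tensor: compute the final projection and the loss a chunk of rows at a time. A minimal NumPy sketch of the idea (not the actual Liger or torch.compile implementation, and without the backward pass a real kernel also fuses):

```python
import numpy as np

def chunked_linear_cross_entropy(hidden, weight, targets, chunk_size=1024):
    """Mean cross-entropy of softmax(hidden @ weight.T) vs. targets,
    materializing only `chunk_size` rows of logits at a time.

    hidden:  [n_tokens, d]   final hidden states
    weight:  [vocab, d]      output projection (tied or untied embedding)
    targets: [n_tokens]      target token ids
    """
    total = 0.0
    n = hidden.shape[0]
    for start in range(0, n, chunk_size):
        h = hidden[start:start + chunk_size]        # [c, d]
        t = targets[start:start + chunk_size]       # [c]
        logits = h @ weight.T                       # [c, vocab] -- only c rows live
        m = logits.max(axis=1, keepdims=True)       # max-subtraction for stability
        logsumexp = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
        total += (logsumexp - logits[np.arange(len(t)), t]).sum()
    return total / n
```

Peak extra memory drops from n_tokens × vocab to chunk_size × vocab, at the cost of looping over chunks; a fused GPU kernel does the same thing while also recomputing logits in the backward pass instead of storing them.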


Replies

gavinray today at 4:14 PM

I'd not heard of this before; a quick search turned up this 2025 post, which suggests a "fused cross-entropy loss" kernel was integrated into PyTorch:

https://pytorch.org/blog/peak-performance-minimized-memory/

> "The integration involves modifying the TransformerDecoder module in torchtune to bypass the linear layer computation, allowing the Liger Fused Linear Cross Entropy Loss to handle the forward projection weights."

Is this the same thing as you discuss above?
hirako2000 today at 4:01 PM

Activations would still require gigabytes even for a context of a few KB.

There are plenty of techniques to optimise. But the question is what an RTX 3080 can train before OOMing. The answer is: not that much.

It can barely do quantized fine-tuning, and even then only with a small context.
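A rough sketch of why activations blow the budget on a 10 GB card. All the numbers here are assumptions for illustration: a Gemma-2-2B-like model (26 layers, hidden size 2304), bf16 activations, and roughly 14 saved activation tensors of hidden width per layer without gradient checkpointing.

```python
# Assumed illustrative numbers -- not measured values.
LAYERS = 26              # assumed: Gemma-2-2B-like depth
HIDDEN = 2304            # assumed: hidden size
BYTES_BF16 = 2
TENSORS_PER_LAYER = 14   # assumed: saved activations per layer, hidden-width

per_token = LAYERS * HIDDEN * BYTES_BF16 * TENSORS_PER_LAYER  # ~1.6 MB/token

for ctx in (1_024, 4_096, 8_192):
    print(ctx, per_token * ctx / 2**30)  # GiB of activations per sequence
```

Under these assumptions an 8k-token sequence alone needs over 10 GiB of activations, before weights, optimizer state, or the loss; gradient checkpointing trades most of this for recomputation, which is why it's nearly mandatory on consumer cards.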
