Attention memory grows quadratically with sequence length, so keeping sequences short during fine-tuning is the main way to avoid running out of memory. On my 64GB RAM machine, I'm limited to input sequences of about 2,000 tokens, because my average output for the fine-tuning task is around 1,000 tokens (~3k tokens total).
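To make the quadratic part concrete, here is a rough sketch that counts only the attention score matrices. It assumes naive (non-fused) attention that keeps one L x L matrix per head per layer for the backward pass. The layer count, head count and 2-byte elements are hypothetical defaults, not the dimensions of any particular model.

```python
def attention_score_bytes(seq_len: int, n_layers: int = 32, n_heads: int = 32,
                          batch_size: int = 1, bytes_per_elem: int = 2) -> int:
    """Bytes for the seq_len x seq_len attention score matrices.

    Assumes naive attention that materializes one score matrix per head
    per layer and keeps it for the backward pass. All dimensions are
    hypothetical defaults chosen for illustration.
    """
    return n_layers * batch_size * n_heads * seq_len * seq_len * bytes_per_elem


for total_tokens in (1_000, 2_000, 3_000, 4_000):
    gb = attention_score_bytes(total_tokens) / 1e9
    print(f"{total_tokens:>5} tokens -> {gb:6.1f} GB of attention scores")
```

With these made-up dimensions, going from 2k to 4k tokens takes the score matrices from about 8.2 GB to about 32.8 GB: twice the length, four times the memory. Fused kernels that never materialize the full matrix avoid this particular term.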
Ah, that makes sense; quadratic scaling is brutal. So with 96GB I'd probably get somewhere around 4-5k total sequence length before hitting the wall, which is still pretty limiting for anything multimodal. Do you use gradient checkpointing, or is it not worth the speed trade-off at these sizes?
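A quick way to sanity-check an estimate like this is to fit a simple model to the one known data point (64GB for ~3k tokens) and extrapolate. The sketch below assumes total memory = fixed overhead + k * L², where the fixed overhead (weights, optimizer state) does not depend on sequence length; the `fixed_gb` values are hypothetical, since the real overhead isn't stated.

```python
import math


def max_seq_len(known_len: float, known_mem_gb: float, new_mem_gb: float,
                fixed_gb: float = 0.0) -> float:
    """Extrapolate usable sequence length to a larger memory budget.

    Assumes total memory = fixed_gb + k * L**2, i.e. everything except a
    fixed overhead scales quadratically with sequence length L.
    k is fitted from the single known (known_len, known_mem_gb) point.
    """
    if not 0 <= fixed_gb < known_mem_gb:
        raise ValueError("fixed_gb must be non-negative and below known_mem_gb")
    k = (known_mem_gb - fixed_gb) / known_len ** 2
    return math.sqrt((new_mem_gb - fixed_gb) / k)


print(round(max_seq_len(3_000, 64, 96)))               # purely quadratic
print(round(max_seq_len(3_000, 64, 96, fixed_gb=20)))  # hypothetical 20 GB overhead
```

This prints 3674 and 3943. If everything scaled quadratically, 1.5x the memory would buy only √1.5 ≈ 1.22x the length (about 3.7k tokens); a large fixed overhead, or parts of memory that grow only linearly with length, push the figure higher, toward the 4-5k range.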