Ah that makes sense, quadratic scaling is brutal. So with 96gb i'd probably get somewhere around 4-5k total sequence length before hitting the wall, which is still pretty limiting for anything multimodal. Do you do any gradient checkpointing or is that not worth the speed tradeoff at these sizes?
Haven’t tried yet. That’s on the do list. But good suggestion.