Hacker News

cavisne today at 2:42 AM

From the paper, the speedup was on the XLA GPU kernel they wrote using JAX, which is probably not SOTA. I don't think JAX even has an official flash attention implementation.


Replies

yarri today at 3:58 PM

Not sure what “official” means, but I would direct you to the GCP MaxText [0] framework. It is not what this GDM paper is referring to, but the repo contains various attention implementations in MaxText/layers/attentions.py.

[0] https://github.com/AI-Hypercomputer/maxtext