From the paper, it was a speedup over the XLA GPU kernel they wrote using JAX, which is probably not SOTA. I don't think JAX even has an official flash attention implementation.
Not sure what “official” means here, but I'd point you to the GCP MaxText [0] framework. It's not what this GDM paper is referring to, but the repo contains various attention implementations in MaxText/layers/attentions.py (a rough sketch of that style of attention is below).
[0] https://github.com/AI-Hypercomputer/maxtext
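For context, here's a minimal sketch of plain dot-product attention in JAX, the kind of thing XLA compiles into a fused kernel. This is illustrative only, not MaxText's actual code and not the kernel the paper benchmarks; shapes and names are my own.

```python
# Illustrative dot-product attention in JAX -- not MaxText's implementation
# (see MaxText/layers/attentions.py for the real thing).
import jax
import jax.numpy as jnp

def dot_product_attention(q, k, v):
    # q, k, v: [batch, seq, heads, head_dim]
    d = q.shape[-1]
    # Attention logits, scaled by sqrt(head_dim).
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(d)
    weights = jax.nn.softmax(logits, axis=-1)
    # Weighted sum over values.
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)

key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (2, 128, 8, 64))
out = jax.jit(dot_product_attention)(q, q, q)
print(out.shape)  # (2, 128, 8, 64)
```

Whether XLA fuses this as well as a hand-written flash attention kernel depends on the backend and problem shape, which is presumably what the speedup comparison is about.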