Quick heads-up that these days I recommend https://github.com/patrick-kidger/jaxtyping over the older repository you've linked there.
I learnt a lot the first time around, so the newer one is much better :)
Ah, I would have never thought jaxtyping supports torch :)