There are a lot of shortcomings in the current implementation that make it slow (though in my tree it is already 2x faster as we speak). For instance, activations aren't kept on the GPU, kernels are not fused, flash attention is not used, and there are many other issues. I'll now focus on those changes to get a little closer to PyTorch's numbers.
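For context, here is a minimal PyTorch sketch (illustrative only; the shapes and names are my assumptions, not this project's code) of what two of these fixes look like on the framework side: keeping tensors on the device, letting `scaled_dot_product_attention` dispatch to a flash-attention kernel, and using `torch.compile` for some of the kernel fusion.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: batch, heads, sequence length, head dim.
B, H, T, D = 8, 12, 1024, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Keep activations on the GPU: allocate tensors on-device once instead of
# round-tripping through host memory between ops.
q = torch.randn(B, H, T, D, device=device, dtype=dtype)
k = torch.randn(B, H, T, D, device=device, dtype=dtype)
v = torch.randn(B, H, T, D, device=device, dtype=dtype)

# scaled_dot_product_attention dispatches to a flash-attention kernel when
# the hardware and dtype allow it, never materializing the full T x T
# attention matrix.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# torch.compile can fuse surrounding elementwise ops into fewer kernels.
fused_attn = torch.compile(
    lambda q, k, v: F.scaled_dot_product_attention(q, k, v, is_causal=True)
)
out = fused_attn(q, k, v)
```

These are the baseline numbers a hand-rolled implementation is chasing: PyTorch gets its speed from exactly these kinds of device-resident data, fused kernels, and flash-attention paths.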