Agree, it's about time to make this built in. Other functional libraries, like Flax [0] in the JAX ecosystem, already use the concept, but they build it into their own library from scratch.
[0] https://flax.readthedocs.io/en/latest/api_reference/flax.cor...