Is there a mypy plugin or other tool that can check these annotations via static analysis, before runtime? As far as I know, jaxtyping annotations can only be checked at runtime.
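I doubt it, since jaxtyping supports some fairly advanced features, such as shape annotations that refer to the values of function arguments (or attributes of `self`), which can only be resolved once the function is actually called: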
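```python
import jax
from jaxtyping import Array, Float

# The axis size "{size}" is filled in from the value of the `size` argument.
def full(size: int, fill: float) -> Float[Array, "{size}"]:
    return jax.numpy.full((size,), fill)

class SomeClass:
    some_value = 5

    # Attributes of `self` can be referenced and combined with arithmetic.
    def full(self, fill: float) -> Float[Array, "{self.some_value}+3"]:
        return jax.numpy.full((self.some_value + 3,), fill)
```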