Hacker News

ydj, yesterday at 7:00 AM

Is there a mypy plugin or other tool that can check this with static analysis, before runtime? To my knowledge, jaxtyping annotations can only be checked at runtime.


Replies

thomasahle, yesterday at 11:33 AM

I doubt it, since jaxtyping supports some quite advanced stuff, such as shapes that depend on argument values:

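    import jax
    from jaxtyping import Array, Float

    # When checked at runtime, "{size}" is filled in from the call's own
    # arguments, so the expected shape depends on the value passed in.
    def full(size: int, fill: float) -> Float[Array, "{size}"]:
        return jax.numpy.full((size,), fill)

    class SomeClass:
        some_value = 5

        # Dimension strings can reach into attributes and do arithmetic:
        # here the expected length is self.some_value + 3.
        def full(self, fill: float) -> Float[Array, "{self.some_value}+3"]:
            return jax.numpy.full((self.some_value + 3,), fill)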