You can do that in Python using https://github.com/patrick-kidger/torchtyping. It looks like this:

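from torchtyping import TensorType

def batch_outer_product(x: TensorType["batch", "x_channels"],
                        y: TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:
    # (batch, x_channels, 1) * (batch, 1, y_channels) broadcasts to (batch, x_channels, y_channels)
    return x.unsqueeze(-1) * y.unsqueeze(-2)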
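A minimal, dependency-free sketch of the idea behind this kind of runtime shape checking: each dimension name is bound to a size the first time it is seen, and every later use of the same name (in another argument or in the return value) must agree. `shape_of`, `bind_dims`, and this nested-list `batch_outer_product` are hypothetical illustrations, not torchtyping's API; real tensors and the library's typeguard integration are replaced by plain Python lists.

```python
# Hypothetical sketch, NOT the torchtyping/jaxtyping API: named dimensions are
# bound to sizes on first use and must match on every later use.
from typing import Dict, List, Sequence, Tuple


def shape_of(nested) -> Tuple[int, ...]:
    """Shape of a nested list, assuming it is rectangular."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


def bind_dims(bindings: Dict[str, int], names: Sequence[str],
              shape: Tuple[int, ...]) -> None:
    """Bind each name to a size, or check it matches an earlier binding."""
    if len(names) != len(shape):
        raise TypeError(f"expected dims {tuple(names)}, got shape {shape}")
    for name, size in zip(names, shape):
        # setdefault keeps the first size seen for this name
        if bindings.setdefault(name, size) != size:
            raise TypeError(f"dimension {name!r}: expected {bindings[name]}, got {size}")


def batch_outer_product(x: List[List[float]],
                        y: List[List[float]]) -> List[List[List[float]]]:
    dims: Dict[str, int] = {}
    bind_dims(dims, ("batch", "x_channels"), shape_of(x))
    bind_dims(dims, ("batch", "y_channels"), shape_of(y))
    # out[b][i][j] = x[b][i] * y[b][j], like x.unsqueeze(-1) * y.unsqueeze(-2)
    out = [[[xi * yj for yj in yb] for xi in xb] for xb, yb in zip(x, y)]
    bind_dims(dims, ("batch", "x_channels", "y_channels"), shape_of(out))
    return out


out = batch_outer_product([[1.0, 2.0]], [[3.0, 4.0, 5.0]])
print(shape_of(out))  # (1, 2, 3)

try:
    batch_outer_product([[1.0, 2.0]], [[3.0], [4.0]])  # batch 1 vs batch 2
except TypeError as err:
    print(err)  # dimension 'batch': expected 1, got 2
```

Because the bindings are only built while the call executes, a mismatch can only be reported at runtime, never before the code runs.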

There's also https://github.com/thomasahle/tensorgrad, which uses sympy symbols as "axis" dimension variables:

import sympy as sp
import tensorgrad as tg  # assumed import path matching the tg alias

b, x, y = sp.symbols("b x y")
X = tg.Variable("X", b, x)
Y = tg.Variable("Y", b, y)
W = tg.Variable("W", x, y)
XWmY = X @ W - Y

Is there a mypy plugin or other tool that checks this via static analysis, before runtime? To my knowledge, jaxtyping can only be checked at runtime.

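Short of a type-checker plugin, one partial approach is to propagate symbolic dimension names through the expression, roughly in the spirit of tensorgrad's sympy axes: a shape error surfaces when the expression is built, before any tensor data exists. This is still ordinary Python execution, not static analysis in the mypy sense. The `Sym` class below and its contraction rule (standard positional matmul) are hypothetical and are not tensorgrad's API.

```python
# Hypothetical sketch, NOT tensorgrad's API: shapes are tuples of symbolic
# dimension names, so mismatches are caught without any concrete sizes.
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Sym:
    name: str
    dims: Tuple[str, ...]

    def __matmul__(self, other: "Sym") -> "Sym":
        # Positional matmul: contract self's last dim with other's first dim.
        if self.dims[-1] != other.dims[0]:
            raise TypeError(f"cannot contract {self.dims[-1]!r} with {other.dims[0]!r}")
        return Sym(f"({self.name} @ {other.name})", self.dims[:-1] + other.dims[1:])

    def __sub__(self, other: "Sym") -> "Sym":
        # Elementwise ops require identical named dims.
        if self.dims != other.dims:
            raise TypeError(f"shape mismatch: {self.dims} vs {other.dims}")
        return Sym(f"({self.name} - {other.name})", self.dims)


X = Sym("X", ("b", "x"))
Y = Sym("Y", ("b", "y"))
W = Sym("W", ("x", "y"))

XWmY = X @ W - Y
print(XWmY.dims)  # ('b', 'y')

try:
    W @ X  # W's last dim 'y' does not match X's first dim 'b'
except TypeError as err:
    print(err)  # cannot contract 'y' with 'b'
```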
Quick heads-up that these days I recommend https://github.com/patrick-kidger/jaxtyping over the older repository you've linked there.
I learnt a lot the first time around, so the newer one is much better :)