Hacker News

thomasahle, last Sunday at 5:23 PM (2 replies)

You can do that in Python using https://github.com/patrick-kidger/torchtyping

It looks like this:

    from torchtyping import TensorType

    def batch_outer_product(x:   TensorType["batch", "x_channels"],
                            y:   TensorType["batch", "y_channels"]
                            ) -> TensorType["batch", "x_channels", "y_channels"]:
        return x.unsqueeze(-1) * y.unsqueeze(-2)

There's also https://github.com/thomasahle/tensorgrad, which uses SymPy for "axis" dimension variables:

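    # Import aliases inferred from the sp. / tg. prefixes used below
    import sympy as sp
    import tensorgrad as tg

    b, x, y = sp.symbols("b x y")
    X = tg.Variable("X", b, x)
    Y = tg.Variable("Y", b, y)
    W = tg.Variable("W", x, y)
    XWmY = X @ W - Y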

Replies

patrickkidger, last Sunday at 10:09 PM

Quick heads-up that these days I recommend https://github.com/patrick-kidger/jaxtyping over the older repository you've linked there.

I learnt a lot the first time around, so the newer one is much better :)

ydj, yesterday at 7:00 AM

Is there a mypy plugin or other tool to check this via static analysis before runtime? To my knowledge jaxtyping can only be checked at runtime.
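
For context on why this kind of check tends to happen at runtime: names like "batch" only get a concrete size once real tensors are passed in. Below is a rough, standard-library-only sketch of that idea. It is not how jaxtyping or torchtyping are implemented; `check_shapes`, the `returns` key, and `FakeTensor` are hypothetical names made up for illustration.

```python
import functools
import inspect


def check_shapes(**specs):
    """Hypothetical decorator. `specs` maps an argument name (or the
    special key 'returns') to a space-separated list of axis names."""
    def decorate(fn):
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            sizes = {}  # axis name -> size first seen during this call

            def check(label, value):
                names = specs[label].split()
                shape = tuple(value.shape)
                if len(shape) != len(names):
                    raise TypeError(
                        f"{label}: expected {len(names)} axes, got shape {shape}")
                for name, size in zip(names, shape):
                    # The first use of a name binds it; later uses must agree.
                    if sizes.setdefault(name, size) != size:
                        raise TypeError(
                            f"{label}: axis '{name}' is {size}, "
                            f"expected {sizes[name]}")

            for label, value in bound.arguments.items():
                if label in specs:
                    check(label, value)
            result = fn(*args, **kwargs)
            if "returns" in specs:
                check("returns", result)
            return result
        return wrapper
    return decorate


class FakeTensor:
    """Stand-in for a real tensor: it only carries a shape."""
    def __init__(self, *shape):
        self.shape = shape


@check_shapes(x="batch x_channels", y="batch y_channels",
              returns="batch x_channels y_channels")
def batch_outer_product(x, y):
    # Only the shape arithmetic of an outer product, no real data.
    return FakeTensor(x.shape[0], x.shape[1], y.shape[1])
```

Calling `batch_outer_product(FakeTensor(4, 3), FakeTensor(4, 5))` passes, while mismatched batch sizes such as `FakeTensor(4, 3)` and `FakeTensor(2, 5)` raise a `TypeError`. A static checker would have to prove the same consistency without ever seeing the sizes.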
