bbminner, today at 4:44 AM

I still consider jax.vmap to be a little miracle: fn2 = vmap(fn, (1, 2)), if I remember correctly, traverses the computation graph of fn and correctly broadcasts all operations in a way that ensures that fn2 acts like fn applied in a loop across the second dimension of the first argument and the third dimension of the second argument (but accelerated, has auto-gradients, etc).
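
A rough sketch of what I mean, from memory. The toy fn, the shapes, and the names xs/ys are made up for illustration; the only real API here is jax.vmap with its in_axes argument and jax.grad.

```python
import jax
import jax.numpy as jnp


def fn(x, y):
    # Hypothetical toy function, written for a single example:
    # x has shape (3,), y has shape (2, 3), result has shape (2,).
    return jnp.tanh(y @ x)


# in_axes=(1, 2): map over axis 1 of the first argument
# and axis 2 of the second argument.
fn2 = jax.vmap(fn, in_axes=(1, 2))

# A batch of 4 examples, stored along axis 1 of xs and axis 2 of ys.
xs = jnp.arange(12.0).reshape(3, 4) / 12.0
ys = jnp.arange(24.0).reshape(2, 3, 4) / 24.0

# Batched call: the mapped axis ends up first in the output (out_axes=0).
batched = fn2(xs, ys)  # shape (4, 2)

# The explicit loop that fn2 is supposed to be equivalent to.
looped = jnp.stack([fn(xs[:, i], ys[:, :, i]) for i in range(xs.shape[1])])

# Gradients flow through the vmapped function like any other JAX function.
grads = jax.grad(lambda a: fn2(a, ys).sum())(xs)  # shape (3, 4)
```

If you only want to loop over the first argument and pass the second one through unchanged, in_axes=(1, None) should do that.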