While this is a bit too harsh - and the solution is naive at best - the problem is real.
The idea of bitwise reproducibility for floating-point computations is treated as completely laughable in every part of the DL landscape. Meanwhile, in just about every other area that uses floating-point computation, it's been the de facto standard for decades.
It starts with NVIDIA not guaranteeing bitwise reproducibility even on the same GPU: https://docs.nvidia.com/deeplearning/cudnn/backend/v9.17.0/d...
The frameworks are somehow even worse. The best you can do is rank them by how bad they are - with TensorFlow far down at the bottom and JAX (currently) at the top - and try to use the best one.
This is a huge issue for anyone serious about developing novel models, and I see no one talking about it, let alone trying to solve it.
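
For anyone wondering why this is hard at all: floating-point addition is not associative, so any reduction whose accumulation order can change between runs (parallel sums are the usual suspect) can produce different bits from the same inputs. Below is a minimal pure-Python sketch of just that effect. The helpers `bits` and `sum_in_order` are made up for illustration and have nothing to do with how cuDNN or any framework actually schedules its reductions.

```python
import struct

def bits(x: float) -> str:
    """Return the raw IEEE 754 double bit pattern of x as a hex string."""
    return struct.pack(">d", x).hex()

def sum_in_order(values, order):
    """Accumulate values left to right in the given index order."""
    total = 0.0
    for i in order:
        total += values[i]
    return total

# Same three numbers, two groupings, two different doubles.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c
right = a + (b + c)
print(bits(left), bits(right), left == right)

# Same inputs, different accumulation order: the 1.0 is absorbed
# when it is added to 1e17 first, and survives when added last.
values = [1e17, 1.0, -1e17]
absorbed = sum_in_order(values, [0, 1, 2])
kept = sum_in_order(values, [0, 2, 1])
print(absorbed, kept)

# A fixed order gives identical bits every time.
print(bits(sum_in_order(values, [0, 1, 2])) == bits(absorbed))
```

Pin the accumulation order and the result is bitwise stable; let it vary and it is not. That only illustrates the mechanism. It says nothing about what it would cost a given library to pin the order.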