JAX is used by almost every large genAI player (Anthropic, Cohere, DeepMind, Midjourney, Character.ai, xAI, Apple, etc.). Its actual market share in foundation model development is something like 80%.
Are there any resources going into detail about why the big players prefer JAX? I've heard this before but have never seen explanations of why/how this happened.
It's all about cost and performance. If you can train a foundation model 2x faster with JAX on the same hardware, you are effectively cutting your training costs in half, which is significant for a multimillion-dollar training run.
The current SOTA models (GPT-4, DALL-E, Sora) were trained on GPUs. The next one (GPT-5) will be, too. And the one after that. Besides, only very few people train models that need more than a few hundred H100s at a time, and PyTorch works well at that scale. And when you train large-scale stuff, the scaling problems are demonstrably surmountable. Capacity problems are not: you will run into them if you need a ton of modern TPU quota, because Google itself is pretty compute-starved at the moment. Also, gone are the days when TPUs were significantly faster. GPUs have “TPUs” inside them, too, nowadays.
No, I am saying, with JAX you train on G.P.U., with a G, and your training runs are >2x faster, so your training costs are cut by more than half, which matters whether your training spend is $1k or $100M. You're not interested in that? That's ok, but most people are.
Have you actually tried that or are you just regurgitating Google’s marketing? I’ve seen JAX perform _slower_ than PyTorch on practical GPU workloads on the exact same machine, and not by a little, by something like 20%. I too thought I’d be getting great performance and “saving money”, but reality turned out to be a bit more complicated than that: you have to benchmark and tune.
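For anyone who wants to check the "benchmark and tune" point on their own machine, here is a rough sketch of how a JAX function might be timed. The `step` function and the array shapes are hypothetical stand-ins, not anyone's real workload, and `time_fn` is a helper made up for this example. The two things that tend to skew naive timings are JIT compilation on the first call and JAX's asynchronous dispatch, so the sketch warms up first and blocks on the result before stopping the clock.

```python
import time

import jax
import jax.numpy as jnp


def time_fn(fn, *args, warmup=3, iters=20):
    """Return the middle sample of wall-clock seconds per call of fn(*args)."""
    # Warm-up calls trigger JIT compilation so it is not counted below.
    for _ in range(warmup):
        jax.block_until_ready(fn(*args))
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        # JAX dispatches work asynchronously; wait for the result before
        # reading the clock, or only the dispatch overhead is measured.
        jax.block_until_ready(fn(*args))
        samples.append(time.perf_counter() - start)
    samples.sort()
    return samples[len(samples) // 2]


@jax.jit
def step(w, x):
    # Hypothetical stand-in for a training step: a matmul plus a nonlinearity.
    return jnp.tanh(x @ w).sum()


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (512, 512))
x = jax.random.normal(key_x, (256, 512))

seconds = time_fn(step, w, x)
print(f"backend={jax.default_backend()} middle sample={seconds * 1e3:.3f} ms/call")
```

A fair comparison would time the equivalent PyTorch step on the same machine with the same shapes and dtypes, and with its own device synchronization before reading the clock. A toy matmul like this says nothing about a real model, so treat the number as a sanity check of the harness rather than evidence for either side.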
Also, JAX is not just for TPU. It's mainly for GPU. It's usually 2-3x faster than PyTorch on GPU: https://keras.io/getting_started/benchmarks/
Far more industry users of JAX run it on GPU than on TPU.