The JAX AI Stack extends the core JAX numerical library into an end-to-end, open-source platform designed for training and deploying machine learning models at extreme scale on TPUs and GPUs. Its modular design lets teams combine only the components they need, from core model development tools to data pipelines, quantization, alignment, and inference runtimes.
Across the AI ecosystem, leading organizations such as Anthropic, xAI, and Apple use JAX as a foundation for building state-of-the-art language and vision models, a strong signal of the stack's maturity and flexibility for both research and production. Because the stack is developed in close alignment with Google Cloud TPUs, it can deliver strong performance, high hardware utilization, and an attractive total cost of ownership for large-scale workloads.
Architectural philosophy
The JAX AI Stack follows a modular, loosely coupled architecture, where each library focuses on a single responsibility such as model definition, optimization, checkpointing, or data loading. This approach allows developers to assemble a tailored ML stack that fits their workloads without being locked into a monolithic framework.
A key design goal is to offer a continuum of abstraction, from high-level automation that accelerates experimentation to low-level controls that unlock every last microsecond of performance. As new algorithms and kernels emerge, they can be integrated or swapped in quickly, enabling teams to keep pace with rapidly evolving AI research.
Core JAX AI Stack libraries
At the center of the ecosystem are four core libraries, JAX, Flax, Optax, and Orbax, packaged together as the jax-ai-stack metapackage for easy installation. Installing this core with a single command (pip install jax-ai-stack) provides a solid foundation for authoring, optimizing, and checkpointing large-scale training workflows.
- JAX provides the accelerator-oriented array programming model, combining composable function transformations with XLA compilation to scale workloads across diverse hardware and clusters.
- Flax adds an ergonomic, object-style API for building and modifying neural networks, while still leveraging JAX’s functional and performance characteristics.
- Optax offers a rich collection of composable gradient-transformation and optimization primitives, so teams can express advanced training strategies by chaining building blocks instead of hand-writing update rules and the state they carry.
- Orbax delivers “any-scale” checkpointing, including asynchronous and distributed checkpoints designed to protect long-running, multi-node training runs from hardware failures.
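To make the JAX layer of this division of labor concrete, here is a minimal sketch of the programming model on its own: jax.grad differentiates a loss, jax.vmap turns a per-example function into a batched one, and jax.jit compiles the whole training step with XLA. The toy data, hyperparameters, and helper names are made up for illustration, and the momentum update is hand-written on purpose. In a real project this is the kind of logic Optax provides as a composable optimizer (for example optax.sgd with momentum), while Flax would define the model and Orbax would save its checkpoints.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative hyperparameters, not recommendations
MOMENTUM = 0.9

def predict(params, x):
    # A single linear layer: x @ w + b.
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

# vmap maps a per-example function over the batch axis without a Python loop;
# in_axes=None broadcasts the same params to every example.
per_example_error = jax.vmap(
    lambda params, x, y: (predict(params, x) - y) ** 2, in_axes=(None, 0, 0)
)

@jax.jit  # XLA compiles gradient computation and update into one program
def train_step(params, velocity, x, y):
    grads = jax.grad(loss_fn)(params, x, y)  # pytree shaped like params
    # Hand-written heavy-ball momentum (the part Optax would normally supply).
    velocity = jax.tree_util.tree_map(lambda v, g: MOMENTUM * v + g, velocity, grads)
    params = jax.tree_util.tree_map(lambda p, v: p - LEARNING_RATE * v, params, velocity)
    return params, velocity

# Noise-free synthetic regression data.
x = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w + 0.3

params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
velocity = jax.tree_util.tree_map(jnp.zeros_like, params)
for _ in range(300):
    params, velocity = train_step(params, velocity, x, y)

print("final loss:", float(loss_fn(params, x, y)))
print("per-example errors shape:", per_example_error(params, x, y).shape)
```

Each transformation is an ordinary function wrapper, which is what lets the higher-level libraries layer on top of JAX without replacing it.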
Infrastructure: XLA and Pathways
Under the hood, the stack relies on XLA and Pathways to scale from a single accelerator to thousands of TPUs or GPUs. XLA is a domain-specific, hardware-agnostic compiler that analyzes whole programs, fusing operations and optimizing memory layouts, so that new model architectures get strong out-of-the-box performance without hand-written kernels.
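The whole-program compilation step can be observed directly with JAX's ahead-of-time API. In this minimal sketch, jax.jit(...).lower traces a Python function into the intermediate representation that XLA consumes, and .compile() runs XLA's optimization passes for whichever backend is available (a CPU is enough). The GELU-style function is an arbitrary chain of elementwise operations of the kind XLA typically fuses. The exact text of the optimized program varies by backend and JAX version, so treat the printed output as illustrative.

```python
import jax
import jax.numpy as jnp

def gelu_like(x):
    # A chain of elementwise ops; XLA can fuse chains like this into few kernels.
    return 0.5 * x * (1.0 + jnp.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

x = jnp.linspace(-3.0, 3.0, 1024)

lowered = jax.jit(gelu_like).lower(x)  # trace to the IR that XLA consumes
compiled = lowered.compile()           # run XLA's optimization passes for this backend

print(lowered.as_text()[:300])         # program before XLA optimizations
optimized = compiled.as_text()         # program after passes such as fusion
if optimized:
    print(optimized[:300])

y = compiled(x)                        # the compiled executable is directly callable
```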
Pathways acts as the unified runtime for large-scale distributed computation. Developers write code as if they were targeting a single powerful machine, while Pathways orchestrates the work across tens of thousands of chips. It also automates fault tolerance and recovery, which is essential when training frontier models with massive compute footprints.
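Pathways itself is a managed runtime and is not shown here. The single-machine programming style it supports can be sketched with JAX's public sharding API, which JAX programs typically use whether they run on local devices or across a large TPU deployment. This sketch assumes it runs locally, where the mesh may contain only one CPU device. The function is written as if for one array, and the compiler partitions it according to the input's sharding.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())           # one CPU device locally; many chips on a TPU slice
mesh = Mesh(devices, axis_names=("data",))  # 1-D logical mesh over all available devices

# Size the batch so every device receives an equal block of rows.
rows = 4 * len(devices)
batch = jnp.arange(rows * 8, dtype=jnp.float32).reshape(rows, 8)
batch = jax.device_put(batch, NamedSharding(mesh, P("data", None)))  # split rows over "data"

@jax.jit
def normalize(x):
    # Written as single-device code; XLA's partitioner inserts any cross-device
    # communication needed for the global mean and standard deviation.
    return (x - x.mean()) / x.std()

out = normalize(batch)
print(out.sharding)
```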
Advanced performance and kernel tooling
For teams that need to exceed what automated compilation alone can deliver, the stack includes specialized tools such as Pallas and Tokamax. Pallas extends JAX with a Pythonic way to write low-level custom kernels for TPUs and GPUs, giving direct control over memory layouts, tiling, and parallelism.
Tokamax complements Pallas as a curated library of state-of-the-art kernels—such as advanced attention implementations—that can be dropped into models for immediate performance gains. Together, they offer both a framework for authoring custom kernels and a ready-made catalog of high-performance implementations tuned for modern accelerator hardware.
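To give a sense of what Pallas code looks like, here is a minimal sketch modeled on the standard element-wise addition example. The kernel operates on memory references (Refs), and a grid plus BlockSpecs tile the input into blocks of four elements, which is where control over tiling comes from. The interpret=True flag runs the kernel in Pallas's interpreter so the sketch also works on a CPU. On real TPUs or GPUs you would drop that flag and choose block sizes that suit the hardware; the block size here is arbitrary. Tokamax's own API is not shown.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Each invocation sees one block of each input and writes one block of output.
    o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y, block=4):
    n = x.shape[0]
    assert n % block == 0, "this sketch assumes the length divides evenly into blocks"
    # Block i of the grid reads and writes elements [i * block, (i + 1) * block).
    spec = pl.BlockSpec(block_shape=(block,), index_map=lambda i: (i,))
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(n // block,),
        in_specs=[spec, spec],
        out_specs=spec,
        interpret=True,  # interpreter mode so the sketch runs without a TPU or GPU
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))
```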
Quantization and data pipelines
As models grow, quantization becomes a critical lever for reducing memory footprint and improving training and inference efficiency, and the stack addresses this with Qwix. Qwix is a comprehensive, JAX-native quantization library that supports techniques such as QLoRA, quantization-aware training, and post-training quantization across XLA and on-device runtimes with minimal changes to model code.
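Qwix's own API is not reproduced here. As a rough illustration of what post-training quantization does to a weight matrix, the following sketch applies plain symmetric per-channel int8 quantization in JAX. The function names are hypothetical, and this textbook scheme is not necessarily the one Qwix uses.

```python
import jax
import jax.numpy as jnp

def quantize_int8(w):
    # Symmetric per-output-channel scale: the largest |w| in each column maps to 127.
    scale = jnp.max(jnp.abs(w), axis=0, keepdims=True) / 127.0
    scale = jnp.where(scale == 0, 1.0, scale)  # avoid dividing by zero for all-zero columns
    q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return q, scale

def dequantize(q, scale):
    # Recover an approximation of the original float32 weights.
    return q.astype(jnp.float32) * scale

w = jax.random.normal(jax.random.PRNGKey(0), (256, 64))
q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)
print(q.dtype, "max abs error:", float(jnp.max(jnp.abs(w - w_hat))))
```

Storing q instead of w cuts weight memory by 4x, and rounding keeps each element's error within half a quantization step for its column.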
For data ingestion, Grain is a deterministic, high-performance data loading library that keeps the input pipeline from becoming a bottleneck in large-scale training. Grain integrates with Orbax so that the exact state of the data pipeline can be checkpointed alongside the model, which delivers bit-for-bit reproducibility of the input stream even after massive jobs are restarted.
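The idea behind checkpointable data pipelines can be illustrated with a toy loader. Its position is fully described by a small state dict, and its shuffle order is a pure function of a seed and the epoch number. Saving that dict next to a model checkpoint therefore lets a restarted job reproduce exactly the same batches. The ResumableLoader class is hypothetical and far simpler than Grain, and it does not use Grain's or Orbax's APIs.

```python
import jax
import numpy as np

class ResumableLoader:
    """Hypothetical loader whose entire position is a small, checkpointable dict."""

    def __init__(self, num_examples, batch_size, seed):
        self.num_examples = num_examples
        self.batch_size = batch_size
        self.seed = seed
        self.epoch = 0
        self.step = 0  # batches already produced in the current epoch

    def _order(self):
        # The shuffle depends only on (seed, epoch), so a restarted job can recompute it.
        # Recomputed on every call to keep the sketch short.
        key = jax.random.fold_in(jax.random.PRNGKey(self.seed), self.epoch)
        return np.asarray(jax.random.permutation(key, self.num_examples))

    def next_batch(self):
        steps_per_epoch = self.num_examples // self.batch_size  # drop the ragged tail
        if self.step == steps_per_epoch:
            self.epoch, self.step = self.epoch + 1, 0
        start = self.step * self.batch_size
        self.step += 1
        return self._order()[start:start + self.batch_size]

    def get_state(self):
        return {"epoch": self.epoch, "step": self.step}

    def set_state(self, state):
        self.epoch, self.step = state["epoch"], state["step"]

loader = ResumableLoader(num_examples=10, batch_size=3, seed=42)
first_run = [loader.next_batch() for _ in range(5)]
saved = loader.get_state()  # would be stored alongside the model checkpoint
continued = [loader.next_batch() for _ in range(3)]

# Simulate a restart: a fresh loader restored from the saved state.
restarted = ResumableLoader(num_examples=10, batch_size=3, seed=42)
restarted.set_state(saved)
resumed = [restarted.next_batch() for _ in range(3)]
print(all(np.array_equal(a, b) for a, b in zip(continued, resumed)))
```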
Training frameworks for LLMs and diffusion models
On top of the core stack, MaxText and MaxDiffusion offer production-ready starting points for training large language models and diffusion models on Google Cloud TPUs. These frameworks are optimized for goodput and Model FLOPs Utilization (MFU), and they build on JAX libraries such as Optax, Orbax, Qwix, and Tunix to deliver robust scaling behavior out of the box.
By adopting these reference implementations, teams can focus on modeling and data rather than low-level infrastructure, while still achieving competitive performance and cost efficiency. The shared model code also supports efficient serving workflows, bridging the gap between training and inference.
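Model FLOPs Utilization compares the FLOPs a model usefully performs per second with the hardware's peak. A common back-of-the-envelope estimate for dense transformer training counts about 6 FLOPs per parameter per token, and the sketch below applies it. The model size, throughput, and peak-FLOPs figures are made-up inputs for illustration, not measurements of MaxText or of any TPU generation.

```python
def model_flops_utilization(num_params, tokens_per_sec_per_chip, peak_flops_per_chip):
    # Dense-transformer training costs roughly 6 FLOPs per parameter per token
    # (about 2 for the forward pass and 4 for the backward pass); attention is ignored.
    achieved_flops = 6 * num_params * tokens_per_sec_per_chip
    return achieved_flops / peak_flops_per_chip

# Hypothetical inputs: an 8B-parameter model, 3,000 tokens/s per chip,
# and a peak of 400 TFLOP/s per chip.
mfu = model_flops_utilization(8e9, 3_000, 400e12)
print(f"MFU = {mfu:.0%}")
```

With these inputs the estimate comes to 36%. Because the 6N rule leaves out attention FLOPs, it slightly understates MFU for long-context training.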
Alignment and post-training with Tunix
Once models are pre-trained, they typically require alignment and post-training to meet application and safety requirements; this is where Tunix fits. Tunix is a JAX-native alignment and reinforcement learning library that implements supervised fine-tuning with LoRA or QLoRA, as well as GRPO, GSPO, DPO, and PPO, behind a unified interface.
MaxText integrates directly with Tunix to provide a scalable, high-performance post-training pipeline tailored to Google Cloud customers. This integration makes it easier to run large-scale alignment experiments on Cloud TPUs without building custom RL or fine-tuning tooling from scratch.
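DPO, one of the techniques Tunix implements, has a compact loss that is easy to express in JAX. The sketch below follows the standard DPO objective, -log σ(β[(log πθ(y_w) − log π_ref(y_w)) − (log πθ(y_l) − log π_ref(y_l))]). It is not Tunix's API. The inputs (per-response log-probabilities from the policy and from a frozen reference model) and the beta value are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Each argument: summed log-probability of a full response, shape (batch,).
    chosen_logratio = policy_chosen_logps - ref_chosen_logps
    rejected_logratio = policy_rejected_logps - ref_rejected_logps
    logits = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(z)) computed stably as softplus(-z).
    return jnp.mean(jax.nn.softplus(-logits))

# Made-up log-probabilities for a batch of two preference pairs.
policy_chosen = jnp.array([-12.0, -8.5])
policy_rejected = jnp.array([-14.0, -9.0])
ref_chosen = jnp.array([-13.0, -9.0])
ref_rejected = jnp.array([-13.5, -8.8])
print(float(dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected)))
```

When the policy and reference agree exactly, the loss equals log 2. It falls as the policy raises the chosen response's likelihood relative to the reference.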
Inference on TPUs with vLLM
For serving, the stack supports a dual-path deployment strategy centered on vLLM with a TPU backend. vLLM TPU is a high-performance inference stack that runs both PyTorch and JAX large language models efficiently on Cloud TPUs, achieving significant throughput gains by leveraging JAX primitives and TPU-specific optimizations.
The unified vLLM TPU backend lowers models through a single JAX-to-XLA path, while features such as prefix caching, speculative decoding, structured decoding, and quantized KV cache are tailored to Trillium and v5e TPU generations. This lets teams deploy production LLMs on Google Cloud with strong performance and a familiar vLLM development experience.
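Prefix caching is the easiest of these features to illustrate. Requests that share a prompt prefix, such as a common system prompt, can reuse the KV-cache blocks already computed for it. The sketch below captures the block-hashing idea in plain Python. Each full block of tokens is hashed together with the hash of everything before it, so a cached block is reused only when the entire prefix matches. This is a simplified illustration in the spirit of vLLM's automatic prefix caching, not its implementation: the block size, class, and function names are hypothetical, and eviction and reference counting are omitted.

```python
import hashlib

BLOCK_SIZE = 4  # tokens per KV-cache block (illustrative value)

def block_hashes(token_ids, block_size=BLOCK_SIZE):
    """Chain-hash each full block together with the hash of the block before it."""
    hashes, parent = [], b""
    full = len(token_ids) - len(token_ids) % block_size  # ignore a trailing partial block
    for start in range(0, full, block_size):
        block = token_ids[start:start + block_size]
        h = hashlib.sha256(parent + repr(block).encode()).digest()
        hashes.append(h)
        parent = h
    return hashes

class PrefixCache:
    def __init__(self):
        self.blocks = {}  # block hash -> hypothetical KV-cache block id

    def admit(self, token_ids):
        """Return how many leading blocks are reused; register the new ones."""
        hashes = block_hashes(token_ids)
        reused = 0
        for h in hashes:
            if h not in self.blocks:
                break  # chaining means no later block can match after a miss
            reused += 1
        for h in hashes[reused:]:
            self.blocks[h] = len(self.blocks)  # "allocate" a new block id
        return reused

system_prompt = [101, 7, 7, 9, 12, 5, 3, 8]  # two full blocks shared by both requests
cache = PrefixCache()
print(cache.admit(system_prompt + [1, 2, 3, 4]))     # nothing cached yet
print(cache.admit(system_prompt + [9, 9, 9, 9, 9]))  # reuses the two system-prompt blocks
```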
From research to production on TPUs
Taken together, the JAX AI Stack delivers a vertically integrated path from experimentation to production on Google Cloud TPUs. Researchers can prototype with JAX and Flax, scale training with Optax, Orbax, XLA, and Pathways, fine-tune and align with Tunix, and finally deploy with vLLM TPU for high-throughput inference.
Because each component is modular and open-source, teams can adopt the entire stack or selectively integrate specific libraries into existing workflows, whether they are building new foundation models or optimizing established production systems. This flexibility, combined with the performance of Cloud TPUs, positions the JAX AI Stack as a strong choice for organizations looking to industrialize their AI workloads.