JAX

  • J - Just-in-time
  • A - Autograd
  • X - XLA (Accelerated Linear Algebra)

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
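"Program transformation" means that JAX APIs such as `jax.grad`, `jax.jit` and `jax.make_jaxpr` take a Python function and return a new function. The minimal sketch below uses a hypothetical function `f`. It shows that a transformation can expose the traced program (a "jaxpr", JAX's intermediate representation) and that transformations compose.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example: an ordinary Python function built from jax.numpy ops.
    return jnp.sin(x) * 2.0

# make_jaxpr traces f with an abstract value and returns JAX's intermediate
# representation of the computation.
jaxpr = jax.make_jaxpr(f)(1.0)
print(jaxpr)

# Transformations return new functions, so they can be stacked.
df = jax.grad(f)        # derivative of f: 2 * cos(x)
fast_df = jax.jit(df)   # the derivative, compiled with XLA
print(fast_df(0.0))     # 2 * cos(0) = 2.0
```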

JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
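To illustrate the NumPy-like interface, the sketch below evaluates the same expression with `numpy` and `jax.numpy`. The arrays are illustrative. JAX places arrays on whatever default backend is available (CPU, GPU, or TPU). One real difference is shown at the end: JAX arrays are immutable, so element updates use the functional `.at[...]` syntax, which returns a new array.

```python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.linspace(0.0, 1.0, 5)
x_jx = jnp.linspace(0.0, 1.0, 5)

# The same expression, written against either module.
y_np = np.sum(np.exp(x_np) * x_np)
y_jx = jnp.sum(jnp.exp(x_jx) * x_jx)

# Lists the devices JAX found; the first one is used by default.
print(jax.devices())

# In-place assignment (x_jx[0] = 10.0) is not allowed on JAX arrays;
# .at[].set() returns an updated copy and leaves x_jx unchanged.
z = x_jx.at[0].set(10.0)
```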

JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.
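A minimal sketch of the two features named above, automatic differentiation and JIT compilation, using hypothetical functions `cube` and `mse`. `jax.grad` returns a function that computes the derivative, and `jax.jit` compiles a function with XLA the first time it is called with a given input shape and dtype. The two transformations can be applied to the same function.

```python
import jax
import jax.numpy as jnp

def cube(x):
    return x ** 3

# jax.grad(cube) computes d/dx x**3 = 3 * x**2.
dcube = jax.grad(cube)
print(dcube(2.0))  # 12.0

# jit traces mse once per input shape/dtype, compiles it, and reuses
# the compiled version on later calls with matching shapes.
@jax.jit
def mse(w, x, y):
    # Mean squared error of a linear model x @ w against targets y.
    return jnp.mean((x @ w - y) ** 2)

w = jnp.zeros(3)
x = jnp.ones((4, 3))
y = jnp.ones(4)

# Gradient with respect to the first argument (w) by default.
grads = jax.grad(mse)(w, x, y)
```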

  • JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.
  • JAX features built-in Just-In-Time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem.
  • JAX functions support efficient evaluation of gradients via its automatic differentiation transformations.
  • JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.
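The automatic vectorization described in the last point is exposed as `jax.vmap`. In the sketch below, the per-example function `predict` and the data are hypothetical. `predict` is written for a single input vector, and `vmap` maps it over the leading axis of a batch. `in_axes=(None, 0)` means "do not map over `w`, map over axis 0 of `batch`". The final lines compose `vmap` with `grad` and `jit` to get per-example gradients.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Written for one example: x has shape (3,).
    return jnp.dot(w, x)

w = jnp.array([1.0, 2.0, 3.0])
batch = jnp.arange(12.0).reshape(4, 3)  # 4 examples, each of shape (3,)

batched_predict = jax.vmap(predict, in_axes=(None, 0))
out = batched_predict(w, batch)
print(out.shape)  # (4,)

def sq_err(w, x, y):
    # Squared error for a single example.
    return (predict(w, x) - y) ** 2

# Per-example gradients w.r.t. w: one gradient row per example.
per_example_grads = jax.jit(
    jax.vmap(jax.grad(sq_err), in_axes=(None, 0, 0))
)(w, batch, jnp.zeros(4))
```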