JAX is a system for high-performance machine learning research. It offers the familiarity of Python+NumPy and the speed of hardware accelerators, and it enables the definition and composition of function transformations useful for machine-learning programs. In particular, these transformations include automatic differentiation, automatic batching, end-to-end compilation (via XLA), and even parallelization across multiple accelerators. They are the key to JAX's power and to its relative simplicity.
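As a minimal sketch of how such transformations compose, the example below applies `jax.grad` (automatic differentiation), `jax.vmap` (automatic batching), and `jax.jit` (compilation via XLA) to one small function. The `loss` function, the variable names, and the data are illustrative assumptions and are not taken from the talk.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Hypothetical example: squared error of a linear prediction
    # for a single (x, y) pair.
    return (jnp.dot(w, x) - y) ** 2


# Automatic differentiation: gradient with respect to the first argument, w.
grad_loss = jax.grad(loss)

# Automatic batching: map over the leading axis of x and y, sharing w.
# End-to-end compilation: jit compiles the batched gradient with XLA.
per_example_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

# One gradient per example, so the result has shape (3, 2).
grads = per_example_grads(w, xs, ys)
print(grads)
```

Each transformation takes a function and returns a function, which is why they can be nested in any order. Parallelization across multiple accelerators is not shown here, since it depends on the devices available.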
JAX had its initial open-source release in December 2018 (https://github.com/google/jax). It is currently being used by several groups of researchers for a wide range of advanced applications, from studying the spectra of neural networks, to probabilistic programming and Monte Carlo methods, to scientific applications in physics and biology. Users appreciate JAX most of all for its ease of use and flexibility.
This talk is an introduction to JAX and a description of some of its technical aspects. It also includes a discussion of current strengths and limitations, and of our plans for the near future, which may be of interest to potential users and contributors.