JAX: A Machine Learning Research Library

Nov 7, 2022, 4:30 pm to 6:00 pm
Audience: Princeton students, graduate students, researchers, faculty, and staff


Event Description
JAX is a system for high-performance machine learning research. It offers the familiarity of Python+NumPy and the speed of hardware accelerators, and it enables the definition and composition of function transformations useful for machine-learning programs. In particular, these transformations include automatic differentiation, automatic batching, end-to-end compilation (via XLA), and even parallelization across multiple accelerators. They are the key to JAX's power and to its relative simplicity.
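To make the idea of composable transformations concrete, here is a minimal sketch, assuming a recent JAX installation. It stacks three of the transformations mentioned above: jax.grad for automatic differentiation, jax.vmap for automatic batching, and jax.jit for XLA compilation. The squared-error loss, the weights, and the data are hypothetical illustrations, not material from the talk. Multi-accelerator parallelism (for example jax.pmap) is omitted because it needs more than one device.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: squared error of a linear model on one example.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# Automatic differentiation: gradient with respect to the first argument, w.
grad_loss = jax.grad(loss)

# Automatic batching: map over a batch of (x, y) pairs while sharing w.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# End-to-end compilation of the composed function via XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

# One gradient per example, computed without a Python loop.
per_example_grads = fast_batched_grad(w, xs, ys)
print(per_example_grads.shape)  # (3, 2)
```

Each transformation takes a function and returns a new function, so they nest freely. The per-example gradients come out of composing grad and vmap rather than from a hand-written batched derivative.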

JAX had its initial open-source release in December 2018 (https://github.com/google/jax). It is currently used by several groups of researchers for a wide range of advanced applications, ranging from studying the spectra of neural networks to probabilistic programming, Monte Carlo methods, and scientific applications in physics and biology. Users appreciate JAX most of all for its ease of use and flexibility.

This talk is an introduction to JAX and a description of some of its technical aspects. It also discusses JAX's current strengths and limitations, as well as our plans for the near future, which may be of interest to potential users and contributors.

The talk will end around 5:30 pm to allow time for one-on-one discussions with Peter until 6:00 pm.

Session format: Presentation and demo with open Q&A

Knowledge prerequisites: Attendees should have some knowledge of Python

Hardware/software prerequisites: None