With JAX you can use a familiar NumPy-style API to express your ML (or other!) computations as pure functions. These functions can then be transformed (higher-order derivatives, single-example to batch transformations, multi-device data-parallel processing, ...) and JIT-compiled to highly optimized GPU/TPU code using XLA.
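For example, a minimal sketch of these transformations (the `predict` and `loss` functions are illustrative placeholders, not taken from the workshop materials):

```python
import jax
import jax.numpy as jnp

# A pure function written against the NumPy-like jax.numpy API.
def predict(w, x):
    return jnp.tanh(x @ w)

def loss(w, x, y):
    return jnp.mean((predict(w, x) - y) ** 2)

grad_fn = jax.grad(loss)                                # gradient w.r.t. w
hessian_fn = jax.jacfwd(jax.grad(loss))                 # higher-order derivative
batched_predict = jax.vmap(predict, in_axes=(None, 0))  # single example -> batch
fast_loss = jax.jit(loss)                               # JIT-compile to XLA for GPU/TPU
```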
In this workshop we briefly talk about the inner workings of JAX and discuss JAX's function transformations and its relationship to XLA. We will then implement basic linear algebra computations and simple neural networks in a Jupyter/Colab notebook. In the last part we will implement more advanced models with Flax (a neural network library built on top of JAX) and train these models in the cloud.
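To give a flavour of that last part, here is a minimal Flax sketch (module and variable names are illustrative only, not from the workshop notebooks):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# A tiny two-layer MLP written with Flax's Linen API.
class MLP(nn.Module):
    hidden: int = 32
    out: int = 1

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(self.out)(x)

model = MLP()
x = jnp.ones((4, 8))                           # dummy batch of 4 examples
params = model.init(jax.random.PRNGKey(0), x)  # initialize parameters
y = model.apply(params, x)                     # forward pass
```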
After this workshop you will:
- Be able to run JAX in Jupyter, on Colab, and in the cloud.
- Know how to use JAX's function transformations.
- Have built a simple model using pure JAX (a minimal sketch follows this list).
- Have trained a more realistic model using Flax.
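As a taste of the pure-JAX part, a minimal sketch of a gradient-descent step on a toy linear model (the data and function names are illustrative, not from the workshop notebooks):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
def update(w, x, y, lr=0.1):
    # one step of plain gradient descent on the loss
    return w - lr * jax.grad(loss)(w, x, y)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (16, 3))   # toy inputs
w_true = jnp.array([1.0, -2.0, 0.5])
y = x @ w_true                        # toy targets
w = jnp.zeros(3)
for _ in range(100):
    w = update(w, x, y)
```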
Advanced level
Required knowledge:
- Experience building and training ML models using other frameworks.
Required equipment:
- Laptop with a working browser.
- If you want to run code locally: an up-to-date Jupyter installation.
- If you want to run code on a cloud other than Google Cloud: access to a cloud VM.