Workshop / Overview

⚠️ A valid COVID certificate must be presented on site to enter the event. ⚠️

With JAX you can use a familiar NumPy-style API to express your ML (or other numerical!) computations as pure functions. These functions can then be transformed (higher-order derivatives, automatic batching of single-example code, multi-device data-parallel processing, ...) and JIT-compiled to highly optimized GPU/TPU code using XLA.
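
For example, a minimal sketch of what these transformations look like in code (the functions predict and loss are illustrative placeholders, not workshop material):

    import jax
    import jax.numpy as jnp

    # A pure function, written with the NumPy-style jax.numpy API.
    def predict(w, x):
        return jnp.tanh(jnp.dot(x, w))

    def loss(w, x, y):
        return jnp.mean((predict(w, x) - y) ** 2)

    grad_loss = jax.grad(loss)               # derivative w.r.t. w
    batched = jax.vmap(predict, (None, 0))   # map a single-example function over a batch
    fast_loss = jax.jit(loss)                # compile to XLA

Transformations compose, so e.g. jax.jit(jax.grad(loss)) compiles the gradient computation.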

In this workshop we briefly cover the inner workings of JAX and discuss its function transformations and its relationship to XLA. We will then implement basic linear algebra computations and simple neural networks in a Jupyter/Colab notebook. In the last part we will implement more advanced models with Flax (a neural network library built on top of JAX) and train them in the cloud.
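
To give a flavor of that last part, here is a minimal sketch of a Flax model using the flax.linen API (the module name MLP and the layer sizes are arbitrary choices for this example):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class MLP(nn.Module):
        """A small multi-layer perceptron."""

        @nn.compact
        def __call__(self, x):
            x = nn.relu(nn.Dense(features=128)(x))
            return nn.Dense(features=10)(x)

    model = MLP()
    # Initialize parameters from a PRNG key and a dummy input of the right shape.
    params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
    logits = model.apply(params, jnp.ones((1, 784)))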

Workshop / Outcome

After this workshop you will:

  • Be able to run JAX in Jupyter, on Colab, and in the cloud.
  • Know how to use JAX's function transformations.
  • Have built a simple model using pure JAX (see the sketch below).
  • Have trained a more realistic model using Flax.
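
As a sketch of the "pure JAX" outcome, here is a minimal linear-regression training step; the data, shapes, and learning rate are placeholder assumptions, not workshop content:

    import jax
    import jax.numpy as jnp

    def loss(params, x, y):
        w, b = params
        pred = jnp.dot(x, w) + b
        return jnp.mean((pred - y) ** 2)

    @jax.jit
    def train_step(params, x, y, lr=0.1):
        grads = jax.grad(loss)(params, x, y)
        # Plain SGD: step each parameter against its gradient.
        return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

    params = (jnp.zeros(3), jnp.array(0.0))   # weights, bias
    x, y = jnp.ones((8, 3)), jnp.ones(8)      # dummy batch
    params = train_step(params, x, y)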

Workshop / Difficulty

Advanced level

Workshop / Prerequisites

Required knowledge:

  • Experience building and training ML models using other frameworks.

Required equipment:

  • Laptop with a working browser.
  • If you want to run code locally: an up-to-date Jupyter installation.
  • If you want to run code on a cloud other than Google Cloud: access to a cloud VM.

Track / Co-organizers

Andreas Steiner

Software Engineer, Google

Avital Oliver

Software Engineer, Google
