The Gray Scott School

Day 5 — Python on CPU

June 26, with Alice Faure, Jean-Marc Colley, Sébastien Valat and Nabil Garroum: profile Python, vectorize with NumPy, compile with Numba, then trace with JAX — up to ×18 without leaving Python.

June 26, 2026 · Speakers: Alice Faure, Jean-Marc Colley, Sébastien Valat & Nabil Garroum · Marcel Vivargent Auditorium + satellite sites (including CINERI). The hands-on is a series of six Jupyter notebooks (GrayScott2026/day-5/CPU/tutorial/, solutions included) — the GPU/ folder waits for Day 7.

Morning session — measure, then vectorize

1. Profile first — time and memory

True to Day 1's rule, the first notebook (1_Optimization) optimizes nothing: it measures. Timing comes first (timeit, cProfile), then the day's specialty (the slides are literally called gray-scott-python-mem): memory profiling with tracemalloc. In Python every temporary array is an allocation, and the naive Gray-Scott creates several per time step.
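A minimal sketch of that measure-first workflow, using only timeit, tracemalloc and NumPy. The function naive_update, the grid size and the constant 0.04 are hypothetical stand-ins and are not taken from the notebook.

```python
import timeit
import tracemalloc

import numpy as np


def naive_update(u, v):
    # Hypothetical reaction term: every arithmetic operator on whole
    # arrays may allocate a fresh temporary array.
    uvv = u * v * v
    return u - uvv + 0.04 * (1.0 - u)


rng = np.random.default_rng(0)
u = rng.random((512, 512))
v = rng.random((512, 512))

# Time: average wall-clock seconds per call over 20 calls.
elapsed = timeit.timeit(lambda: naive_update(u, v), number=20) / 20

# Memory: peak bytes allocated while one update runs.
tracemalloc.start()
result = naive_update(u, v)
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

print(f"time per call : {elapsed * 1e3:.3f} ms")
print(f"one array     : {u.nbytes / 2**20:.1f} MiB")
print(f"peak traced   : {peak / 2**20:.1f} MiB")
```

Comparing the peak with the size of a single array gives a rough count of how many array-sized buffers were alive at once during the update.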

2. The GIL — why threads do not save Python

The Global Interpreter Lock serializes pure Python threads. Real CPU parallelism goes through libraries that release the GIL during native computation — NumPy, Numba's prange, XLA under JAX — or through multiprocessing. None of today's speedups fights the GIL: they all dive below it, into compiled code.
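The effect can be observed with a small experiment, sketched here under the assumption of a standard CPython build with the GIL enabled. The function sum_of_squares and the task sizes are illustrative only, and the timings depend on the machine, so none are quoted.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def sum_of_squares(n):
    # Pure Python loop: the thread running it holds the GIL.
    total = 0
    for i in range(n):
        total += i * i
    return total


tasks = [200_000] * 4

start = time.perf_counter()
serial = [sum_of_squares(n) for n in tasks]
t_serial = time.perf_counter() - start

start = time.perf_counter()
with ThreadPoolExecutor(max_workers=4) as pool:
    threaded = list(pool.map(sum_of_squares, tasks))
t_threads = time.perf_counter() - start

# Same results either way; on a GIL build the threaded run is
# typically no faster than the serial one.
print(f"serial : {t_serial:.3f} s")
print(f"threads: {t_threads:.3f} s")
```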

3. From loops to arrays

Notebooks 2_Numpy and 3_Python_Implementation: the Laplacian is written with NumPy slices (u[:-2, 1:-1] + u[2:, 1:-1] + … − 4*u[1:-1, 1:-1]) — zero Python loops, the iterations move into NumPy's C loops. This is the day's benchmark baseline.
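As an illustration of the slicing idea, here is a sketch that assumes the elided terms are the left and right neighbours of a standard five-point Laplacian. The notebook's actual stencil and weights may differ.

```python
import numpy as np


def laplacian(u):
    # Assumed five-point stencil on the interior of u:
    # up + down + left + right - 4 * centre, with no Python loop.
    return (
        u[:-2, 1:-1] + u[2:, 1:-1]
        + u[1:-1, :-2] + u[1:-1, 2:]
        - 4.0 * u[1:-1, 1:-1]
    )


# Check on u(i, j) = i**2 + j**2, whose discrete Laplacian is exactly 4.
rows, cols = np.mgrid[0:6, 0:6]
u = (rows**2 + cols**2).astype(np.float64)
lap = laplacian(u)
print(lap.shape)             # (4, 4): the one-cell border is dropped
print(np.all(lap == 4.0))    # True
```

Each slice is a view, but every `+` still produces a temporary array, which is what the memory profiling of the first notebook makes visible.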

Afternoon session — compiling Python

4. Numba — compile the loop you already have

Notebook 4_Numba_Implementation: keep the explicit loop, add @njit, and Numba compiles the function through LLVM on the first call. Ideal when the algorithm naturally is a loop (stencils!), and prange parallelizes it once the function is compiled with parallel=True.

5. JAX — trace, then let XLA fuse

Notebook 5_JAX, the day's centerpiece. JAX trades discipline for speed: immutable arrays (u.at[i, j].set(v) instead of assignment), no index checking (silent errors lurk), and above all composable transformations (jit, vmap, grad) that only work on pure functions. The machinery behind jax.jit:

Python function (pure) → Tracer (abstract values) → jaxpr (intermediate program) → XLA (compiles + fuses) → cached binary (reused as-is)
Next calls go straight to the binary: the tracing cost is paid only once.
Constraints: immutable arrays, declared static args, fixed shapes.
jax.jit: the function is traced once with abstract values, compiled by XLA, then served from cache — hence the purity constraints
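The trace-once behaviour can be observed with a Python side effect, which only runs while JAX traces the function. This is a small sketch: the function scaled_sum and the list trace_log are hypothetical names chosen for the demonstration.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled only while JAX traces, never by the cached binary


@jax.jit
def scaled_sum(x):
    trace_log.append(x.shape)      # Python side effect: trace time only
    return jnp.sum(2.0 * x)


a = scaled_sum(jnp.ones(3))   # first call: trace, compile, run
b = scaled_sum(jnp.ones(3))   # same shape and dtype: cached binary
c = scaled_sum(jnp.ones(4))   # new shape: traced and compiled again
print(trace_log)              # [(3,), (4,)]

# The intermediate program (jaxpr) produced by tracing.
print(jax.make_jaxpr(lambda x: jnp.sum(2.0 * x))(jnp.ones(3)))

# Transformations compose on pure functions.
grad_fn = jax.jit(jax.grad(lambda x: jnp.sum(x**2)))
row_sums = jax.vmap(lambda row: jnp.sum(2.0 * row))
g = grad_fn(jnp.array([1.0, 2.0, 3.0]))   # gradient of sum(x**2) is 2*x
s = row_sums(jnp.ones((2, 3)))            # one result per row
```

The second call leaves trace_log untouched, which is why side effects such as printing or mutating globals are unreliable inside jitted code.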

The notebook details the constraints: static arguments to declare (static_argnums), fixed shapes (every new shape re-traces), debugging via jax.debug, and the control-flow operators (lax.cond, lax.fori_loop) that replace if/for inside traced code.
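A compact sketch of these constraints in practice. The functions decay and zero_if_negative_mean are hypothetical examples, not code from the notebook.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnums=(1,))
def decay(u, n_steps, rate):
    # n_steps is static: a compile-time constant, so each new value
    # triggers a new compilation. rate stays a traced value.
    def body(i, state):
        return state * (1.0 - rate)

    return lax.fori_loop(0, n_steps, body, u)   # replaces a Python for


@jax.jit
def zero_if_negative_mean(u):
    mean = jnp.mean(u)
    jax.debug.print("mean = {m}", m=mean)       # prints at run time
    # lax.cond replaces a Python if on a traced value.
    return lax.cond(mean < 0.0, lambda x: jnp.zeros_like(x), lambda x: x, u)


decayed = decay(jnp.ones(4), 3, 0.5)            # 0.5**3 = 0.125 per cell
kept = zero_if_negative_mean(jnp.array([1.0, 2.0]))
zeroed = zero_if_negative_mean(jnp.array([-1.0, -2.0]))

# Immutable arrays: .at[...].set(...) returns a new array.
grid = jnp.zeros((3, 3))
updated = grid.at[1, 1].set(1.0)
```

Calling decay with a different n_steps compiles a second version, while changing only rate reuses the existing one.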

6. Porting Gray-Scott to JAX

Notebook 6_JAX_Implementation: two competing versions — the generic stencil (any 3×3 convolution) and the specialized 3×3 stencil (the nine terms written by hand, which XLA fuses into a single kernel). They are the benchmark's two JAX columns.
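To make the contrast concrete, here is a sketch of both styles. It assumes an illustrative nine-point kernel (ones around a centre of −8) and uses jax.scipy.signal.convolve2d for the generic version; the notebook's actual weights and convolution call may differ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

# Hypothetical 3x3 kernel, chosen only for illustration.
KERNEL = jnp.array([[1.0, 1.0, 1.0],
                    [1.0, -8.0, 1.0],
                    [1.0, 1.0, 1.0]])


@jax.jit
def stencil_generic(u, kernel):
    # Generic version: works for any 3x3 kernel passed as data.
    return convolve2d(u, kernel, mode="valid")


@jax.jit
def stencil_3x3(u):
    # Specialized version: the nine terms written out by hand.
    return (
        u[:-2, :-2] + u[:-2, 1:-1] + u[:-2, 2:]
        + u[1:-1, :-2] - 8.0 * u[1:-1, 1:-1] + u[1:-1, 2:]
        + u[2:, :-2] + u[2:, 1:-1] + u[2:, 2:]
    )


u = jax.random.uniform(jax.random.PRNGKey(0), (16, 16))
generic = stencil_generic(u, KERNEL)
special = stencil_3x3(u)
print(generic.shape, special.shape)   # (14, 14) (14, 14)
```

Both functions compute the same interior values for this symmetric kernel. The specialized one trades generality for a fixed expression that the compiler sees in full.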

7. The verdict — one Gray-Scott, four speeds

Official numbers from the repo (CPU/Benchmarks.md), 32×1000 iterations:

CPU                       NumPy    Numba    JAX (generic)   JAX (3×3)
Intel Xeon Silver 4210R   7800 s   3257 s   1031 s          377 s
AMD EPYC 7313             2545 s   1219 s   386 s           141 s
NumPy: C loops over whole arrays, ×1
Numba: @njit, the Python loop JIT-compiled, ×2.1
JAX (generic): trace + XLA, operations get fused, ×6.6
JAX (3×3): kernel specialized for the stencil, ×18
Measured relative speeds: EPYC 7313, 32×1000 iterations
The day’s ladder: each rung swaps more interpreter for compiled code — up to ×18 without leaving Python
Figure: Gray-Scott Python, run times for NumPy / Numba / JAX (32×1000 iterations) on Xeon 4210R and EPYC 7313.

Official numbers from the course repo (CPU/Benchmarks.md). Shorter = better.
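The relative speeds follow directly from the run times. A short sketch of the arithmetic, with the timings copied from the benchmark above and NumPy taken as the ×1 baseline:

```python
# Run times in seconds, copied from the benchmark figures above.
timings_s = {
    "Intel Xeon Silver 4210R": {
        "NumPy": 7800, "Numba": 3257, "JAX (generic)": 1031, "JAX (3×3)": 377,
    },
    "AMD EPYC 7313": {
        "NumPy": 2545, "Numba": 1219, "JAX (generic)": 386, "JAX (3×3)": 141,
    },
}

# Speedup = NumPy time / implementation time, rounded to one decimal.
speedups = {
    cpu: {name: round(times["NumPy"] / t, 1) for name, t in times.items()}
    for cpu, times in timings_s.items()
}

for cpu, row in speedups.items():
    print(cpu, row)
```

On the EPYC 7313 this gives ×2.1, ×6.6 and ×18.0; on the Xeon Silver 4210R the same ratios come out slightly higher, at ×2.4, ×7.6 and ×20.7.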

8. The bridge to the GPU

This implementation is the reference that Day 7 ports to the accelerator: JAX replays the same code on GPU, joined by CuPy and cuNumeric. It is also exactly the approach of the SenLand project — profile PyTorch, port to JAX, compare honestly.

The hands-on — GrayScott2026/day-5/CPU/

Three ways to follow, your pick:

# 1) locally, environment pinned by pixi
git clone https://gitlab.in2p3.fr/alice.faure/gray-scott-python.git
cd gray-scott-python           # assumed: the pixi manifest sits at the repo root
pixi run jupyter-lab           # opens the tutorial/ notebooks

# 2) on the MUST cluster (LAPP): https://jupyter.must-dc.cloud
#    → "Gray-Scott Revolutions" → "Python CPU"

# 3) in a container (apptainer / podman / docker): the course vscode image

tutorial/ sets the exercises, solutions/ corrects them, scripts/gray_scott_utils.py provides shared I/O, and results/ ships a reference simulation (simulation.h5 + video).

On video — the official replay

Replay — Python On CPU (Gray Scott Thursdays)
