From 671321c593e4ec560f8a904f3a594a15794b8800 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Fri, 19 Jun 2026 12:17:17 +1000 Subject: [PATCH] [numpy_vs_numba_vs_jax] Fix deprecated device= arg on jax.jit JAX has deprecated the device/backend arguments to jax.jit, which emitted two DeprecationWarnings into the lecture's rendered HTML output: DeprecationWarning: backend and device argument on jit is deprecated. You can use jax.device_put(..., jax.local_devices(backend="cpu")[0]) on the inputs to the jitted function to get the same behavior. Pin the computation to the CPU the recommended way -- commit the input to the CPU with jax.device_put -- instead of the deprecated decorator argument. The behaviour (CPU execution for this sequential workload) is unchanged. Co-Authored-By: Claude Opus 4.8 --- lectures/numpy_vs_numba_vs_jax.md | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index f45df634..b5142c04 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -471,7 +471,10 @@ We'll apply a `lax.fori_loop`, which is a version of a for loop that can be comp ```{code-cell} ipython3 cpu = jax.devices("cpu")[0] -@partial(jax.jit, static_argnames=("n",), device=cpu) +# Pin the input to the CPU, which keeps the whole computation there +x0_cpu = jax.device_put(0.1, cpu) + +@partial(jax.jit, static_argnames=("n",)) def qm_jax_fori(x0, n, α=4.0): x = jnp.empty(n + 1).at[0].set(x0) @@ -485,7 +488,7 @@ def qm_jax_fori(x0, n, α=4.0): ``` * We hold `n` static because it affects array size and hence JAX wants to specialize on its value in the compiled code. -* We pin to the CPU via `device=cpu` because this sequential workload consists of many small operations, leaving little opportunity for GPU parallelism. +* We pin the input to the CPU with `jax.device_put` (which keeps the whole computation on the CPU) because this sequential workload consists of many small operations, leaving little opportunity for GPU parallelism. Important: Although `at[t].set` appears to create a new array at each step, inside a JIT-compiled function the compiler detects that the old array is no longer needed and performs the update in place! @@ -494,7 +497,7 @@ Let's time it with the same parameters: ```{code-cell} ipython3 with qe.Timer(): # First run - x_jax = qm_jax_fori(0.1, n) + x_jax = qm_jax_fori(x0_cpu, n) # Hold interpreter x_jax.block_until_ready() ``` @@ -504,7 +507,7 @@ Let's run it again to eliminate compilation overhead: ```{code-cell} ipython3 with qe.Timer(): # Second run - x_jax = qm_jax_fori(0.1, n) + x_jax = qm_jax_fori(x0_cpu, n) # Hold interpreter x_jax.block_until_ready() ``` @@ -521,7 +524,7 @@ although the syntax is difficult to remember. ```{code-cell} ipython3 -@partial(jax.jit, static_argnames=("n",), device=cpu) +@partial(jax.jit, static_argnames=("n",)) def qm_jax_scan(x0, n, α=4.0): def update(x, t): x_new = α * x * (1 - x) @@ -538,7 +541,7 @@ Let's time it with the same parameters: ```{code-cell} ipython3 with qe.Timer(): # First run - x_jax = qm_jax_scan(0.1, n) + x_jax = qm_jax_scan(x0_cpu, n) # Hold interpreter x_jax.block_until_ready() ``` @@ -548,7 +551,7 @@ Let's run it again to eliminate compilation overhead: ```{code-cell} ipython3 with qe.Timer(): # Second run - x_jax = qm_jax_scan(0.1, n) + x_jax = qm_jax_scan(x0_cpu, n) # Hold interpreter x_jax.block_until_ready() ```