Merged
4 changes: 2 additions & 2 deletions .translate/state/numpy_vs_numba_vs_jax.md.yml
@@ -1,5 +1,5 @@
-source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28
-synced-at: "2026-05-14"
+source-sha: d37b1d8adbf6e18b17e125cca761a6eb2ccd9041
+synced-at: "2026-06-19"
 model: claude-sonnet-4-6
 mode: UPDATE
 section-count: 3
17 changes: 10 additions & 7 deletions lectures/numpy_vs_numba_vs_jax.md
@@ -467,7 +467,10 @@ Numba handles this sequential operation very efficiently.
 ```{code-cell} ipython3
 cpu = jax.devices("cpu")[0]

-@partial(jax.jit, static_argnames=("n",), device=cpu)
+# Pin the input to the CPU, which keeps the whole computation there
+x0_cpu = jax.device_put(0.1, cpu)
+
+@partial(jax.jit, static_argnames=("n",))
 def qm_jax_fori(x0, n, α=4.0):

     x = jnp.empty(n + 1).at[0].set(x0)
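@@ -481,7 +484,7 @@ def qm_jax_fori(x0, n, α=4.0):
 ```

 * We make `n` static because it affects array sizes, and JAX wants to specialize the compiled code on its value.
-* We pin the computation to the CPU via `device=cpu`, because this sequential workload consists of many small operations and offers little opportunity to exploit GPU parallelism.
+* We pin the input to the CPU via `jax.device_put` (which keeps the whole computation on the CPU), because this sequential workload consists of many small operations and offers little opportunity to exploit GPU parallelism.

 Important: although `at[t].set` appears to create a new array at every step, inside a JIT-compiled function the compiler detects that the old array is no longer needed and performs the update in place!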
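
The pattern in the two bullets above can be tried in isolation. The following is a minimal sketch, not the lecture's code: the loop body is hidden in this diff, so the `body` function and the name `qm_fori_sketch` are assumptions written for illustration. It relies on JAX's documented behavior that a jitted function runs on the device its committed inputs live on.

```python
from functools import partial

import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

# Commit the initial condition to the CPU. A jitted function called with a
# committed input runs on that input's device.
x0_cpu = jax.device_put(0.1, cpu)


@partial(jax.jit, static_argnames=("n",))
def qm_fori_sketch(x0, n, α=4.0):  # hypothetical name, not from the PR
    # n is static, so the array length n + 1 is known at compile time.
    x = jnp.empty(n + 1).at[0].set(x0)

    def body(t, x):
        # Assumed update rule: x[t+1] = α * x[t] * (1 - x[t])
        return x.at[t + 1].set(α * x[t] * (1 - x[t]))

    return jax.lax.fori_loop(0, n, body, x)


x = qm_fori_sketch(x0_cpu, 5)
print(x.devices())  # the set of devices holding the result
```

Calling the function with a different `n` triggers a fresh compilation, because `n` is part of the compilation key when declared static.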

@@ -490,7 +493,7 @@ def qm_jax_fori(x0, n, α=4.0):
 ```{code-cell} ipython3
 with qe.Timer():
     # First run
-    x_jax = qm_jax_fori(0.1, n)
+    x_jax = qm_jax_fori(x0_cpu, n)
     # Hold interpreter
     x_jax.block_until_ready()
 ```
@@ -500,7 +503,7 @@ with qe.Timer():
 ```{code-cell} ipython3
 with qe.Timer():
     # Second run
-    x_jax = qm_jax_fori(0.1, n)
+    x_jax = qm_jax_fori(x0_cpu, n)
     # Hold interpreter
     x_jax.block_until_ready()
 ```
@@ -514,7 +517,7 @@ JAX is also quite efficient for this sequential operation!
 This alternative is arguably more in keeping with JAX's functional style, although the syntax is hard to remember.

 ```{code-cell} ipython3
-@partial(jax.jit, static_argnames=("n",), device=cpu)
+@partial(jax.jit, static_argnames=("n",))
 def qm_jax_scan(x0, n, α=4.0):
     def update(x, t):
         x_new = α * x * (1 - x)
@@ -531,7 +534,7 @@ def qm_jax_scan(x0, n, α=4.0):
 ```{code-cell} ipython3
 with qe.Timer():
     # First run
-    x_jax = qm_jax_scan(0.1, n)
+    x_jax = qm_jax_scan(x0_cpu, n)
     # Hold interpreter
     x_jax.block_until_ready()
 ```
@@ -541,7 +544,7 @@ with qe.Timer():
 ```{code-cell} ipython3
 with qe.Timer():
     # Second run
-    x_jax = qm_jax_scan(0.1, n)
+    x_jax = qm_jax_scan(x0_cpu, n)
     # Hold interpreter
     x_jax.block_until_ready()
 ```
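
For reference, the `lax.scan` variant touched by this change can be sketched end to end. This is a hedged reconstruction: the diff shows only the first lines of `update`, so its return value, the `jnp.arange(n)` scan input, the final concatenation, and the name `qm_scan_sketch` are all assumptions rather than the lecture's actual code.

```python
from functools import partial

import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
x0_cpu = jax.device_put(0.1, cpu)  # committed input keeps the work on the CPU


@partial(jax.jit, static_argnames=("n",))
def qm_scan_sketch(x0, n, α=4.0):  # hypothetical name, not from the PR
    def update(x, t):
        x_new = α * x * (1 - x)
        # scan expects (next carry, per-step output); both are x_new here.
        return x_new, x_new

    # Iterate n times; x_tail stacks x[1], ..., x[n].
    _, x_tail = jax.lax.scan(update, x0, jnp.arange(n))
    # Prepend the initial condition so the result has length n + 1.
    return jnp.concatenate([jnp.reshape(x0, (1,)), x_tail])


x_scan = qm_scan_sketch(x0_cpu, 5)
x_scan.block_until_ready()
```

Unlike the `fori_loop` version, no array is updated with `at[t].set`: the carry is a single scalar and `scan` assembles the output array itself.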