Add O(J^2)-per-test-point predictive mean/variance from the QSM Cholesky #272

Open: dfm wants to merge 2 commits into main from qsm-fast-predict

dfm (Owner) commented Jun 12, 2026

celerite's Cholesky carries a small $J\times J$ matrix $f_k$ that it updates at every step. Usually it is just scratch internal to the factorization, but it is also the accumulator that the predictive variance needs at each test point. Once the factorization is done, the predictive mean and variance both follow from it, and the variance only costs one extra backward pass. The mean was already computed this way; the variance is the new part, and it is nearly free.

The idea comes from the state-space (Kalman filter and smoother) picture of these kernels. There, prediction reads the variance off the filter covariance, which evolves by a Riccati recursion that turns out to match the celerite carry. That picture is only motivation; everything below is written with the quasiseparable generators alone.

Setup

Sort the training inputs $t_0,\dots,t_{N-1}$. The quasiseparable representation has low-rank generators $p_k,q_k\in\mathbb R^J$, transition matrices $a_k\in\mathbb R^{J\times J}$, and a diagonal $d_k$. tinygp builds them from a state-space model with observation vector $h_k=h(t_k)$, transition $a_k=\Phi(t_{k-1},t_k)$ (and $a_0=I$), and stationary covariance $P_\infty$:

$$q_k=h_k,\qquad p_k=a_k^\top P_\infty h_k,\qquad d_k=h_k^\top P_\infty h_k$$

with the observation noise $\eta_k^2$ added to $d_k$. Writing $\Psi_{m,j}=a_m a_{m-1}\cdots a_{j+1}$ with $\Psi_{j,j}=I$ for the propagator from point $j$ to $m$, the lower triangle of the kernel is $K_{ij}=p_i^\top \Psi_{i-1,j} q_j$ for $i\gt j$.
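
As a concrete check of this representation, here is a minimal NumPy sketch that builds the generators for a unit-amplitude Matern-3/2 kernel and rebuilds the dense matrix from $K_{ij}=p_i^\top\Psi_{i-1,j}q_j$. The kernel choice, the transposed form of the transition matrix, and the names `transition`, `generators`, and `dense_from_generators` are assumptions made for illustration; none of them are taken from this PR, and the loops stand in for what would be scans in JAX.

```python
import numpy as np

LAM = np.sqrt(3.0) / 1.5        # Matern-3/2 rate sqrt(3)/scale; scale = 1.5 is illustrative
PINF = np.diag([1.0, LAM**2])   # stationary covariance P_inf (unit amplitude)
H = np.array([1.0, 0.0])        # observation vector h, constant in time for this kernel

def transition(dt):
    """Transition matrix over a gap dt (convention assumed for this sketch)."""
    return np.exp(-LAM * dt) * np.array([[1 + LAM * dt, -LAM**2 * dt],
                                         [dt, 1 - LAM * dt]])

def generators(t):
    """Generators (p, q, a, d) for sorted inputs t, with a_0 = I and no noise."""
    N = len(t)
    a = np.stack([np.eye(2)] + [transition(t[k] - t[k - 1]) for k in range(1, N)])
    q = np.tile(H, (N, 1))                               # q_k = h_k
    p = np.stack([a[k].T @ PINF @ H for k in range(N)])  # p_k = a_k^T P_inf h_k
    d = np.full(N, H @ PINF @ H)                         # d_k = h_k^T P_inf h_k
    return p, q, a, d

def dense_from_generators(p, q, a, d):
    """Rebuild the symmetric matrix from K_ij = p_i^T Psi_{i-1,j} q_j for i > j."""
    N = len(d)
    K = np.diag(d).copy()
    for j in range(N):
        psi_q = q[j].copy()                  # Psi_{j,j} q_j = q_j
        for i in range(j + 1, N):
            K[i, j] = K[j, i] = p[i] @ psi_q
            psi_q = a[i] @ psi_q             # advance to Psi_{i,j} q_j
    return K

t = np.array([0.0, 0.4, 1.1, 1.5, 2.7])
p, q, a, d = generators(t)
K = dense_from_generators(p, q, a, d)

# Dense reference: k(tau) = (1 + lam |tau|) exp(-lam |tau|)
tau = np.abs(t[:, None] - t[None, :])
K_ref = (1 + LAM * tau) * np.exp(-LAM * tau)
print(np.max(np.abs(K - K_ref)))
```

The inner loop carries $\Psi_{i,j}q_j$ forward one transition at a time, which is the same pattern the scans below exploit.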

In the language of the celerite papers, $(d,p,q,a)$ are the semiseparable factors $(D,U,V,P)$: the diagonal $D$, the two low-rank generators $U$ and $V$, and the propagators $P$. The one structural difference is that celerite's $P$ is diagonal, a single scalar decay per mode, whereas $a$ here is a full transition matrix. The carry $f$ defined next is celerite's $S$.

The Cholesky $K=LL^\top$ (with the noise folded in, $K=K_{\rm train}+\eta^2 I$) is the celerite scan. It produces the factor's generators $c_k\in\mathbb R$ and $w_k\in\mathbb R^J$, with $L_{kk}=c_k$ and $L_{kj}=p_k^\top\Psi_{k-1,j}w_j$, and threads the carry

$$c_k=\sqrt{d_k-p_k^\top f_{k-1}p_k},\quad w_k=\frac{q_k-a_k f_{k-1}p_k}{c_k},\quad f_k=a_k f_{k-1}a_k^\top+w_k w_k^\top,\quad f_{-1}=0.$$

The variance reuses $f_k$, so it is worth holding onto.
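
The recursion can be sketched directly in NumPy. The block below keeps the inclusive carry $f_k$ at every step and compares the implied dense factor with `np.linalg.cholesky`. It repeats the hypothetical Matern-3/2 helpers so that it runs on its own; `qsm_cholesky` and `dense_factor` are illustrative names, not the signature of `ops.cholesky` in this PR.

```python
import numpy as np

LAM = np.sqrt(3.0) / 1.5        # illustrative Matern-3/2 rate
PINF = np.diag([1.0, LAM**2])
H = np.array([1.0, 0.0])

def transition(dt):
    """Transition matrix over a gap dt (convention assumed for this sketch)."""
    return np.exp(-LAM * dt) * np.array([[1 + LAM * dt, -LAM**2 * dt],
                                         [dt, 1 - LAM * dt]])

def generators(t, noise):
    """Generators (p, q, a, d) with the noise variance folded into d."""
    N = len(t)
    a = np.stack([np.eye(2)] + [transition(t[k] - t[k - 1]) for k in range(1, N)])
    q = np.tile(H, (N, 1))
    p = np.stack([a[k].T @ PINF @ H for k in range(N)])
    d = np.full(N, H @ PINF @ H + noise**2)
    return p, q, a, d

def qsm_cholesky(p, q, a, d):
    """Sequential Cholesky scan; returns (c, w) and the inclusive carry f_k."""
    N, J = p.shape
    c, w, f = np.empty(N), np.empty((N, J)), np.empty((N, J, J))
    f_prev = np.zeros((J, J))                               # f_{-1} = 0
    for k in range(N):
        c[k] = np.sqrt(d[k] - p[k] @ f_prev @ p[k])
        w[k] = (q[k] - a[k] @ f_prev @ p[k]) / c[k]
        f_prev = a[k] @ f_prev @ a[k].T + np.outer(w[k], w[k])
        f[k] = f_prev
    return c, w, f

def dense_factor(p, w, a, c):
    """Dense L with L_kk = c_k and L_kj = p_k^T Psi_{k-1,j} w_j for k > j."""
    N = len(c)
    L = np.diag(c).copy()
    for j in range(N):
        psi_w = w[j].copy()
        for k in range(j + 1, N):
            L[k, j] = p[k] @ psi_w
            psi_w = a[k] @ psi_w
    return L

t = np.array([0.0, 0.4, 1.1, 1.5, 2.7])
noise = 0.1
p, q, a, d = generators(t, noise)
c, w, f = qsm_cholesky(p, q, a, d)
L = dense_factor(p, w, a, c)

tau = np.abs(t[:, None] - t[None, :])
K = (1 + LAM * tau) * np.exp(-LAM * tau) + noise**2 * np.eye(len(t))
print(np.max(np.abs(L - np.linalg.cholesky(K))))
```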

Anchoring a test point

Take a test point $t_\ast$ and let $i=\max\{k:t_k\le t_\ast\}$ be its left training neighbour, found by binary search. The cross-covariance $k_\ast\in\mathbb R^N$ splits at $i$ into two $J$-vectors $\xi_\ast$ and $\zeta_\ast$, the only test-dependent quantities:

$$(k_\ast)_k=\xi_\ast^\top\Psi_{i,k} q_k\ \ (k\le i),\qquad (k_\ast)_k=p_k^\top\Psi_{k-1,i} \zeta_\ast\ \ (k\gt i).$$

They are built from the test observation $h_\ast=h(t_\ast)$, propagated to the two neighbours. With $\phi_L=\Phi(t_i,t_\ast)$ and $\phi_R=\Phi(t_\ast,t_{i+1})$, each spanning at most one gap,

$$\xi_\ast=\phi_L^\top P_\infty h_\ast,\qquad \zeta_\ast=a_{i+1}^{-1}\phi_R h_\ast.$$

The right anchor is formed at $t_{i+1}$ and pulled back one gap by $a_{i+1}^{-1}$. That cancels the leading transition baked into each $p_k$ and lines the contraction up with the $k\gt i$ split, and the only inverse it needs is over a single training gap rather than the (possibly large) test gap. For extrapolation, drop the missing side: $\xi_\ast=0$ when $t_\ast$ falls left of all the data, and $\zeta_\ast=0$ when it falls to the right.
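
A small sketch of the anchoring step for an interior test point, under the same illustrative Matern-3/2 assumptions as before (the helper names are hypothetical). It forms $\xi_\ast$ and $\zeta_\ast$, rebuilds every entry of $k_\ast$ from them, and compares with direct kernel evaluations.

```python
import numpy as np

LAM = np.sqrt(3.0) / 1.5        # illustrative Matern-3/2 rate
PINF = np.diag([1.0, LAM**2])
H = np.array([1.0, 0.0])

def transition(dt):
    """Transition matrix over a gap dt (convention assumed for this sketch)."""
    return np.exp(-LAM * dt) * np.array([[1 + LAM * dt, -LAM**2 * dt],
                                         [dt, 1 - LAM * dt]])

t = np.array([0.0, 0.4, 1.1, 1.5, 2.7])
N = len(t)
a = np.stack([np.eye(2)] + [transition(t[k] - t[k - 1]) for k in range(1, N)])
q = np.tile(H, (N, 1))
p = np.stack([a[k].T @ PINF @ H for k in range(N)])

t_star = 1.3
# Left neighbour i = max{k : t_k <= t_star}; assumed interior (0 <= i < N - 1)
i = int(np.searchsorted(t, t_star, side="right")) - 1

phi_L = transition(t_star - t[i])        # Phi(t_i, t_star)
phi_R = transition(t[i + 1] - t_star)    # Phi(t_star, t_{i+1})
xi = phi_L.T @ PINF @ H
zeta = np.linalg.solve(a[i + 1], phi_R @ H)   # a_{i+1}^{-1} phi_R h_star

k_star = np.empty(N)
row = xi.copy()                          # Psi_{i,i}^T xi = xi
for k in range(i, -1, -1):
    k_star[k] = row @ q[k]               # xi^T Psi_{i,k} q_k
    row = a[k].T @ row                   # step to Psi_{i,k-1}^T xi
col = zeta.copy()                        # Psi_{i,i} zeta = zeta
for k in range(i + 1, N):
    k_star[k] = p[k] @ col               # p_k^T Psi_{k-1,i} zeta
    col = a[k] @ col                     # step to Psi_{k,i} zeta

tau = np.abs(t - t_star)
k_ref = (1 + LAM * tau) * np.exp(-LAM * tau)
print(np.max(np.abs(k_star - k_ref)))
```

The only linear solve is against $a_{i+1}$, a single training gap, which is the point of forming the right anchor at $t_{i+1}$.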

Predictive mean

Precompute $\beta=K^{-1}y$ with two triangular solves. The two halves of $k_\ast^\top\beta$ are affine scans over the training data,

$$h^{\rm fwd}_k=a_k h^{\rm fwd}_{k-1}+q_k\beta_k\ \ (h^{\rm fwd}_{-1}=0),\qquad h^{\rm bwd}_{k-1}=a_k^\top h^{\rm bwd}_k+p_k\beta_k\ \ (h^{\rm bwd}_{N-1}=0),$$

and the prediction is a single contraction at the anchor, $\mu_\ast=\xi_\ast^\top h_i^{\rm fwd} + \zeta_\ast^\top h_i^{\rm bwd}$. celerite already computes its mean this way.
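
The two scans and the anchor contraction can be sketched as follows, again on the illustrative Matern-3/2 setup with hypothetical helper names. For brevity $\beta$ comes from a dense solve here rather than from the two triangular solves; that shortcut is an assumption of the sketch and would defeat the purpose in practice.

```python
import numpy as np

LAM = np.sqrt(3.0) / 1.5        # illustrative Matern-3/2 rate
PINF = np.diag([1.0, LAM**2])
H = np.array([1.0, 0.0])

def transition(dt):
    """Transition matrix over a gap dt (convention assumed for this sketch)."""
    return np.exp(-LAM * dt) * np.array([[1 + LAM * dt, -LAM**2 * dt],
                                         [dt, 1 - LAM * dt]])

def kernel(x1, x2):
    tau = np.abs(x1[:, None] - x2[None, :])
    return (1 + LAM * tau) * np.exp(-LAM * tau)

t = np.array([0.0, 0.4, 1.1, 1.5, 2.7, 3.0])
y = np.sin(t)
noise = 0.1
N, J = len(t), 2
a = np.stack([np.eye(J)] + [transition(t[k] - t[k - 1]) for k in range(1, N)])
q = np.tile(H, (N, 1))
p = np.stack([a[k].T @ PINF @ H for k in range(N)])

# beta = K^{-1} y (dense solve, for the sketch only)
beta = np.linalg.solve(kernel(t, t) + noise**2 * np.eye(N), y)

# Forward scan: h_fwd[k] = a_k h_fwd[k-1] + q_k beta_k, starting from zero
h_fwd = np.zeros((N, J))
carry = np.zeros(J)
for k in range(N):
    carry = a[k] @ carry + q[k] * beta[k]
    h_fwd[k] = carry

# Backward scan: h_bwd[k-1] = a_k^T h_bwd[k] + p_k beta_k, with h_bwd[N-1] = 0
h_bwd = np.zeros((N, J))
for k in range(N - 1, 0, -1):
    h_bwd[k - 1] = a[k].T @ h_bwd[k] + p[k] * beta[k]

# Anchor an interior test point and contract
t_star = 1.3
i = int(np.searchsorted(t, t_star, side="right")) - 1
xi = transition(t_star - t[i]).T @ PINF @ H
zeta = np.linalg.solve(a[i + 1], transition(t[i + 1] - t_star) @ H)
mu = xi @ h_fwd[i] + zeta @ h_bwd[i]

mu_dense = kernel(np.array([t_star]), t)[0] @ beta
print(abs(mu - mu_dense))
```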

Predictive variance

Write $\sigma_\ast^2=K_{\ast\ast}-\lVert v\rVert^2$ with $v=L^{-1}k_\ast$, and split $\lVert v\rVert^2=\sum_k v_k^2$ at the anchor.

For $k\le i$ the right-hand side $(k_\ast)_k$ is built from the kernel's own $q_k$, so the forward solve repeats the recursion that produced $L$, and its squared entries accumulate into the carry:

$$\sum_{k\le i}v_k^2=\xi_\ast^\top f_i \xi_\ast.$$

There is nothing new to compute here, since $f_i$ is the inclusive carry from the factorization. For $k\gt i$ the remaining terms collapse onto a single vector $\delta_\ast=\zeta_\ast-f_i\xi_\ast$ and one more train-only accumulator $P_k\in\mathbb R^{J\times J}$ (unrelated to $P_\infty$ and to celerite's propagators $P$),

$$\sum_{k\gt i}v_k^2=\delta_\ast^\top P_i \delta_\ast,\qquad P_{k-1}=A_k^\top P_k A_k+\tfrac1{c_k^2}p_kp_k^\top,\quad A_k=a_k-\tfrac1{c_k}w_kp_k^\top,\quad P_{N-1}=0.$$
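
Both identities can be checked by writing out the forward substitution $Lv=k_\ast$ row by row. The sketch below uses only the definitions above; the auxiliary state $s_k$ is introduced here for the argument and is not part of the PR.

For $k\le i$, try $v_k=w_k^\top\Psi_{i,k}^\top\xi_\ast$. Using $\Psi_{i,j}=\Psi_{i,k}a_k\Psi_{k-1,j}$ for $j\lt k$ and $f_{k-1}=\sum_{j\lt k}\Psi_{k-1,j}w_jw_j^\top\Psi_{k-1,j}^\top$, row $k$ reads

$$c_kv_k+\sum_{j\lt k}p_k^\top\Psi_{k-1,j}w_jv_j=\left(c_kw_k+a_kf_{k-1}p_k\right)^\top\Psi_{i,k}^\top\xi_\ast=q_k^\top\Psi_{i,k}^\top\xi_\ast=(k_\ast)_k,$$

so the guess solves the system, and $\sum_{k\le i}v_k^2=\xi_\ast^\top\big(\sum_{k\le i}\Psi_{i,k}w_kw_k^\top\Psi_{i,k}^\top\big)\xi_\ast=\xi_\ast^\top f_i\xi_\ast$.

For $k\gt i$, the already-solved block contributes $\sum_{j\le i}L_{kj}v_j=p_k^\top\Psi_{k-1,i}f_i\xi_\ast$, which turns the right-hand side into $p_k^\top\Psi_{k-1,i}\delta_\ast$. Define $s_k=\Psi_{k,i}\delta_\ast-\sum_{i\lt j\le k}\Psi_{k,j}w_jv_j$ with $s_i=\delta_\ast$. Then

$$c_kv_k=p_k^\top s_{k-1},\qquad s_k=a_ks_{k-1}-w_kv_k=A_ks_{k-1},\qquad v_k^2=s_{k-1}^\top\tfrac1{c_k^2}p_kp_k^\top s_{k-1},$$

and substituting $s_{k-1}=A_{k-1}\cdots A_{i+1}\delta_\ast$ and summing over $k\gt i$ reproduces the backward recursion for $P_i$.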

Together,

$$\boxed{\ \mu_\ast=\xi_\ast^\top h^{\rm fwd}_i+\zeta_\ast^\top h^{\rm bwd}_i,\qquad \sigma_\ast^2=K_{\ast\ast}-\xi_\ast^\top f_i \xi_\ast-\delta_\ast^\top P_i \delta_\ast,\qquad \delta_\ast=\zeta_\ast-f_i\xi_\ast\ }$$
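
To see the boxed variance formula in action, here is a minimal NumPy sketch that runs the Cholesky scan, the backward scan for $P$, and the anchor contraction, then compares with the dense expression $K_{\ast\ast}-k_\ast^\top K^{-1}k_\ast$. The Matern-3/2 setup and all function names are illustrative assumptions, and only an interior test point is handled.

```python
import numpy as np

LAM = np.sqrt(3.0) / 1.5        # illustrative Matern-3/2 rate
PINF = np.diag([1.0, LAM**2])
H = np.array([1.0, 0.0])

def transition(dt):
    """Transition matrix over a gap dt (convention assumed for this sketch)."""
    return np.exp(-LAM * dt) * np.array([[1 + LAM * dt, -LAM**2 * dt],
                                         [dt, 1 - LAM * dt]])

def kernel(x1, x2):
    tau = np.abs(x1[:, None] - x2[None, :])
    return (1 + LAM * tau) * np.exp(-LAM * tau)

t = np.array([0.0, 0.4, 1.1, 1.5, 2.7, 3.0])
noise = 0.1
N, J = len(t), 2
a = np.stack([np.eye(J)] + [transition(t[k] - t[k - 1]) for k in range(1, N)])
q = np.tile(H, (N, 1))
p = np.stack([a[k].T @ PINF @ H for k in range(N)])
d = np.full(N, H @ PINF @ H + noise**2)

# Cholesky scan, keeping the inclusive carry f_k
c, w, f = np.empty(N), np.empty((N, J)), np.empty((N, J, J))
f_prev = np.zeros((J, J))
for k in range(N):
    c[k] = np.sqrt(d[k] - p[k] @ f_prev @ p[k])
    w[k] = (q[k] - a[k] @ f_prev @ p[k]) / c[k]
    f_prev = a[k] @ f_prev @ a[k].T + np.outer(w[k], w[k])
    f[k] = f_prev

# Backward scan: P_{k-1} = A_k^T P_k A_k + p_k p_k^T / c_k^2, with P_{N-1} = 0
P = np.zeros((N, J, J))
for k in range(N - 1, 0, -1):
    A = a[k] - np.outer(w[k], p[k]) / c[k]
    P[k - 1] = A.T @ P[k] @ A + np.outer(p[k], p[k]) / c[k] ** 2

# Anchor an interior test point
t_star = 1.3
i = int(np.searchsorted(t, t_star, side="right")) - 1
xi = transition(t_star - t[i]).T @ PINF @ H
zeta = np.linalg.solve(a[i + 1], transition(t[i + 1] - t_star) @ H)

delta = zeta - f[i] @ xi
k_ss = H @ PINF @ H                      # K_** for the latent process (no noise)
var = k_ss - xi @ f[i] @ xi - delta @ P[i] @ delta

# Dense reference
K = kernel(t, t) + noise**2 * np.eye(N)
k_star = kernel(t, np.array([t_star]))[:, 0]
var_dense = k_ss - k_star @ np.linalg.solve(K, k_star)
print(abs(var - var_dense))
```

Everything above the anchoring step depends on the training data alone, so it is computed once and shared by all test points.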

Implementing it

Four arrays get precomputed over the training data: $h^{\rm fwd}$ and $h^{\rm bwd}$ for the mean, $f$ and $P$ for the variance. If the Cholesky only stored the exclusive carry $f_{k-1}$, the inclusive $f_k=a_k f_{k-1}a_k^\top+w_kw_k^\top$ is a vmap away with no extra scan.
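
The exclusive-to-inclusive step involves no dependence between indices, so it batches trivially. A minimal sketch, with NumPy broadcasting standing in for `vmap` and random stand-in arrays in place of real generators (both are assumptions for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
N, J = 6, 3
a = rng.normal(size=(N, J, J))          # stand-in transitions a_k
w = rng.normal(size=(N, J))             # stand-in factor generators w_k
B = rng.normal(size=(N, J, J))
f_excl = B @ np.swapaxes(B, 1, 2)       # f_excl[k] plays the role of f_{k-1}

# Batched f_k = a_k f_{k-1} a_k^T + w_k w_k^T, with no loop over k
f_incl = a @ f_excl @ np.swapaxes(a, 1, 2) + w[:, :, None] * w[:, None, :]

# Per-index reference for comparison
f_loop = np.stack([a[k] @ f_excl[k] @ a[k].T + np.outer(w[k], w[k])
                   for k in range(N)])
print(np.max(np.abs(f_incl - f_loop)))
```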

Pad each array to length $N+1$ and index it by the anchor $i$ from the binary search, so extrapolation falls out on its own (the missing side contributes zero). A test point then costs a binary search and two quadratic forms, $O(J^2)$, and the whole thing vmaps over the test set. There is no second factorization and no scan over a combined train-plus-test grid.
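
One possible padding convention, shown only for the forward carry $f$: put $f_{-1}=0$ in slot 0 so that the raw binary-search result indexes the padded array directly. The placeholder values and the layout are assumptions for illustration; the description does not spell out how the PR stores the padded arrays.

```python
import numpy as np

t = np.array([0.0, 1.0, 2.5, 4.0])      # sorted training inputs
N, J = len(t), 2

# Placeholder carries f_0..f_{N-1}; real values would come from the Cholesky
f = np.stack([(k + 1) * np.eye(J) for k in range(N)])
f_pad = np.concatenate([np.zeros((1, J, J)), f])   # slot 0 holds f_{-1} = 0

t_test = np.array([-1.0, 0.0, 0.5, 2.5, 3.0, 9.0])
slot = np.searchsorted(t, t_test, side="right")    # equals i + 1
anchor = slot - 1                                  # i = -1 means left of all data
f_at_anchor = f_pad[slot]                          # gathers f_i for every test point

print(anchor)
```

With this layout a test point to the left of the data picks up the zero slot, so its forward variance term vanishes without a branch, and the gather works on the whole test set at once.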

The one caveat is that the carry coincidence relies on solving against the kernel's own generators. Predicting a different process, such as one component of a kernel sum, a derivative or integral of the latent, or the same kernel with different hyperparameters, breaks it and brings in extra accumulators. So the fast path only applies when the prediction kernel is the one that was fit.

dfm and others added 2 commits June 11, 2026 20:14
After the Cholesky factorization at N training points, four train-only affine
scans give a state from which the predictive mean and diagonal variance at any
test point follow from a binary search plus an O(J^2) contraction.

- ops.cholesky returns the Riccati carry f alongside (c, w); QuasisepSolver
  stores it as cholesky_carry so it is computed exactly once.
- predict.py: precompute() runs two affine scans for the y-dependent
  h_fwd/h_bwd plus one backward congruence scan for P (reusing f for the
  forward variance term), bundled in a flat PredictState that also caches the
  sortable training coordinates and densified transitions. A single
  predict_mean_and_var() shares one binary-search/anchor pass between the mean
  and the variance. ConditionedMean and ConditionedKernel (a kernels.Conditioned
  subclass overriding only evaluate_diag) read from this state; ConditionedSolver
  lazily delegates the full conditional covariance to a DirectSolver.
- Solver.condition() returns a ConditionedComponents bundle and now takes
  kernel: Kernel | None, where None means "predict with the training kernel."
  QuasisepSolver.condition() takes the fast path when X_test is given and kernel
  is None; gating on None (rather than `kernel is self.kernel`) keeps the fast
  path live under jax.jit, where object identity is not preserved across the
  trace boundary. The shared mean/kernel construction lives in
  solver.conditioned_mean_parts, used by both DirectSolver and the
  QuasisepSolver fallback.
- gp.py: condition() passes the (possibly None) kernel through unresolved, and
  GaussianProcess gains a variance_value field so cond_gp.variance is a stored
  array, symmetric with mean_value.

Verified to machine precision against dense across Matern32/52/SHO/Cosine and
summed/scaled kernels, in both sequential and parallel modes, and the fast path
is checked to survive a jit/pytree round-trip.