Add O(J^2)-per-test-point predictive mean/variance from the QSM Cholesky #272
Open
dfm wants to merge 2 commits into
After the Cholesky factorization at N training points, four train-only affine scans give a state from which the predictive mean and diagonal variance at any test point follow from a binary search plus an O(J^2) contraction.

- ops.cholesky returns the Riccati carry f alongside (c, w); QuasisepSolver stores it as cholesky_carry so it is computed exactly once.
- predict.py: precompute() runs two affine scans for the y-dependent h_fwd/h_bwd plus one backward congruence scan for P (reusing f for the forward variance term), bundled in a flat PredictState that also caches the sortable training coordinates and densified transitions. A single predict_mean_and_var() shares one binary-search/anchor pass between the mean and the variance. ConditionedMean and ConditionedKernel (a kernels.Conditioned subclass overriding only evaluate_diag) read from this state; ConditionedSolver lazily delegates the full conditional covariance to a DirectSolver.
- Solver.condition() returns a ConditionedComponents bundle and now takes kernel: Kernel | None, where None means "predict with the training kernel." QuasisepSolver.condition() takes the fast path when X_test is given and kernel is None; gating on None (rather than `kernel is self.kernel`) keeps the fast path live under jax.jit, where object identity is not preserved across the trace boundary. The shared mean/kernel construction lives in solver.conditioned_mean_parts, used by both DirectSolver and the QuasisepSolver fallback.
- gp.py: condition() passes the (possibly None) kernel through unresolved, and GaussianProcess gains a variance_value field so cond_gp.variance is a stored array, symmetric with mean_value.

Verified to machine precision against dense across Matern32/52/SHO/Cosine and summed/scaled kernels, in both sequential and parallel modes, and the fast path is checked to survive a jit/pytree round-trip.
celerite's Cholesky carries a small $J\times J$ matrix $f_k$ that it updates at every step. Usually it is just scratch internal to the factorization, but it is also the accumulator that the predictive variance needs at each test point. Once the factorization is done, the predictive mean and variance both follow from it, and the variance only costs one extra backward pass. The mean was already computed this way; the variance is the new part, and it is nearly free.
The idea comes from the state-space (Kalman filter and smoother) picture of these kernels. There, prediction reads the variance off the filter covariance, which evolves by a Riccati recursion that turns out to match the celerite carry. That picture is only motivation; everything below is written with the quasiseparable generators alone.
Setup
Sort the training inputs $t_0,\dots,t_{N-1}$. The quasiseparable representation has low-rank generators $p_k,q_k\in\mathbb R^J$, transition matrices $a_k\in\mathbb R^{J\times J}$, and a diagonal $d_k$. tinygp builds them from a state-space model with observation vector $h_k=h(t_k)$, transition $a_k=\Phi(t_{k-1},t_k)$ (and $a_0=I$), and stationary covariance $P_\infty$:

$$q_k=h_k,\qquad p_k=a_k^\top P_\infty h_k,\qquad d_k=h_k^\top P_\infty h_k,$$

with the observation noise $\eta_k^2$ added to $d_k$. Writing $\Psi_{m,j}=a_m a_{m-1}\cdots a_{j+1}$ with $\Psi_{j,j}=I$ for the propagator from point $j$ to $m$, the lower triangle of the kernel is $K_{ij}=p_i^\top \Psi_{i-1,j} q_j$ for $i\gt j$.
In the language of the celerite papers, $(d,p,q,a)$ are the semiseparable factors $(D,U,V,P)$: the diagonal $D$, the two low-rank generators $U$ and $V$, and the propagators $P$. The one structural difference is that celerite's $P$ is diagonal, a single scalar decay per mode, whereas $a$ here is a full transition matrix. The carry $f$ defined next is celerite's $S$.
The Cholesky $K=LL^\top$ (with the noise folded in, $K=K_{\rm train}+\eta^2 I$) is the celerite scan. It produces the factor's generators $c_k\in\mathbb R$ and $w_k\in\mathbb R^J$, with $L_{kk}=c_k$ and $L_{kj}=p_k^\top\Psi_{k-1,j}w_j$ for $k\gt j$, and threads the carry

$$f_k=a_k f_{k-1}a_k^\top+w_kw_k^\top,\qquad f_{-1}=0.$$

The variance reuses $f_k$, so it is worth holding onto.
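The scan above can be checked against a dense factorization. What follows is a minimal NumPy sketch, not the PR's `ops.cholesky`: the Matern-3/2 state-space matrices are an assumption of the sketch, the helper names (`phi`, `generators`, `cholesky_with_carry`, `dense_factor`) are hypothetical, and a plain Python loop stands in for the scan.

```python
import numpy as np

# Assumed Matern-3/2 state-space model (J = 2); every name here is illustrative.
SCALE, SIGMA, NOISE = 1.5, 1.3, 0.3
F = np.sqrt(3.0) / SCALE
PINF = np.diag([1.0, F**2])      # stationary covariance P_inf
H = np.array([SIGMA, 0.0])       # observation vector h(t), constant for this kernel

def phi(dt):
    """Transition matrix over a gap dt >= 0 (assumed form)."""
    return np.exp(-F * dt) * np.array([[1 + F * dt, -F**2 * dt], [dt, 1 - F * dt]])

def generators(t):
    """Quasiseparable generators (d, p, q, a), with a_0 = I and noise added to d."""
    a = np.stack([np.eye(2)] + [phi(dt) for dt in np.diff(t)])
    q = np.tile(H, (len(t), 1))
    p = np.stack([ak.T @ PINF @ H for ak in a])
    d = np.full(len(t), H @ PINF @ H + NOISE**2)
    return d, p, q, a

def cholesky_with_carry(d, p, q, a):
    """Return (c, w) and the inclusive carry f_k at every step."""
    N, J = p.shape
    c, w, f = np.zeros(N), np.zeros((N, J)), np.zeros((N, J, J))
    f_prev = np.zeros((J, J))                        # f_{-1} = 0
    for k in range(N):
        c[k] = np.sqrt(d[k] - p[k] @ f_prev @ p[k])
        w[k] = (q[k] - a[k] @ f_prev @ p[k]) / c[k]
        f_prev = a[k] @ f_prev @ a[k].T + np.outer(w[k], w[k])
        f[k] = f_prev
    return c, w, f

def dense_factor(c, p, w, a):
    """Densify L from its generators: L_kk = c_k, L_kj = p_k^T Psi_{k-1,j} w_j."""
    N = len(c)
    L = np.diag(c).copy()
    for j in range(N):
        psi = np.eye(p.shape[1])                     # Psi_{j,j} = I
        for k in range(j + 1, N):
            L[k, j] = p[k] @ psi @ w[j]              # psi is Psi_{k-1,j} here
            psi = a[k] @ psi                         # advance to Psi_{k,j}
    return L

t = np.array([0.0, 0.4, 1.1, 1.5, 2.7, 3.0])
d, p, q, a = generators(t)
c, w, f = cholesky_with_carry(d, p, q, a)
r = np.abs(t[:, None] - t[None, :])
K = SIGMA**2 * (1 + F * r) * np.exp(-F * r) + NOISE**2 * np.eye(len(t))
L_qsm, L_dense = dense_factor(c, p, w, a), np.linalg.cholesky(K)
```

Returning `f` costs nothing beyond storage, since the loop already forms it at every step to compute `c[k]` and `w[k]`.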
Anchoring a test point
Take a test point $t_\ast$ and let $i=\max\{k:t_k\le t_\ast\}$ be its left training neighbour, found by binary search. The cross-covariance $k_\ast\in\mathbb R^N$ splits at $i$ into two $J$-vectors $\xi_\ast$ and $\zeta_\ast$, the only test-dependent quantities:

$$(k_\ast)_k=\xi_\ast^\top\Psi_{i,k}\,q_k\quad(k\le i),\qquad (k_\ast)_k=p_k^\top\Psi_{k-1,i}\,\zeta_\ast\quad(k\gt i).$$

They are built from the test observation $h_\ast=h(t_\ast)$, propagated to the two neighbours. With $\phi_L=\Phi(t_i,t_\ast)$ and $\phi_R=\Phi(t_\ast,t_{i+1})$, each spanning at most one gap,

$$\xi_\ast=\phi_L^\top P_\infty h_\ast,\qquad \zeta_\ast=a_{i+1}^{-1}\,\phi_R\,h_\ast.$$

The right anchor is formed at $t_{i+1}$ and pulled back one gap by $a_{i+1}^{-1}$. That cancels the leading transition baked into each $p_k$ and lines the contraction up with the $k\gt i$ split, and the only inverse it needs is over a single training gap rather than the (possibly large) test gap. For extrapolation, drop the missing side: $\xi_\ast=0$ when $t_\ast$ falls left of all the data, and $\zeta_\ast=0$ when it falls to the right.
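The anchoring step can be sanity-checked by rebuilding the whole cross-covariance vector from the two anchors and comparing it with direct kernel evaluations. This is a hedged NumPy sketch: the Matern-3/2 matrices are assumed, `anchor` and `cross_cov_from_anchors` are hypothetical names, and `np.searchsorted` plays the role of the binary search.

```python
import numpy as np

# Assumed Matern-3/2 state-space model (J = 2); every name here is illustrative.
SCALE, SIGMA = 1.5, 1.3
F = np.sqrt(3.0) / SCALE
PINF = np.diag([1.0, F**2])
H = np.array([SIGMA, 0.0])

def phi(dt):
    """Transition matrix over a gap dt >= 0 (assumed form)."""
    return np.exp(-F * dt) * np.array([[1 + F * dt, -F**2 * dt], [dt, 1 - F * dt]])

def anchor(t, a, t_star):
    """Return (i, xi, zeta); i = -1 means t_star lies left of all the data."""
    i = np.searchsorted(t, t_star, side="right") - 1   # max{k : t_k <= t_star}
    xi = phi(t_star - t[i]).T @ PINF @ H if i >= 0 else np.zeros(2)
    if i + 1 < len(t):
        # Form the anchor at t_{i+1}, then pull it back one training gap.
        zeta = np.linalg.solve(a[i + 1], phi(t[i + 1] - t_star) @ H)
    else:
        zeta = np.zeros(2)
    return i, xi, zeta

def cross_cov_from_anchors(t, p, q, a, t_star):
    """Rebuild k_* from (xi, zeta) and the training generators only."""
    i, xi, zeta = anchor(t, a, t_star)
    out = np.zeros(len(t))
    psi = np.eye(2)                         # Psi_{i,i}
    for k in range(i, -1, -1):
        out[k] = xi @ psi @ q[k]            # psi is Psi_{i,k}
        psi = psi @ a[k]                    # Psi_{i,k-1} = Psi_{i,k} a_k
    psi = np.eye(2)                         # Psi_{i,i}
    for k in range(i + 1, len(t)):
        out[k] = p[k] @ psi @ zeta          # psi is Psi_{k-1,i}
        psi = a[k] @ psi                    # advance to Psi_{k,i}
    return out

t = np.array([0.0, 0.4, 1.1, 1.5, 2.7, 3.0])
a = np.stack([np.eye(2)] + [phi(dt) for dt in np.diff(t)])
q = np.tile(H, (len(t), 1))
p = np.stack([ak.T @ PINF @ H for ak in a])

def dense_cross_cov(t_star):
    r = np.abs(t - t_star)
    return SIGMA**2 * (1 + F * r) * np.exp(-F * r)

test_points = [-0.5, 1.3, 1.5, 3.4]   # left of data, interior, on a training point, right
errors = [np.max(np.abs(cross_cov_from_anchors(t, p, q, a, ts) - dense_cross_cov(ts)))
          for ts in test_points]
```

The loops over `k` are there only to verify the split; the fast path never forms $k_\ast$ explicitly.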
Predictive mean
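Precompute $\beta=K^{-1}y$ with two triangular solves. The two halves of $k_\ast^\top\beta$ are affine scans over the training data,

$$h_k^{\rm fwd}=a_k h_{k-1}^{\rm fwd}+q_k\beta_k,\qquad h_k^{\rm bwd}=a_{k+1}^\top h_{k+1}^{\rm bwd}+p_{k+1}\beta_{k+1},$$

started from $h_{-1}^{\rm fwd}=0$ and $h_{N-1}^{\rm bwd}=0$, and the prediction is a single contraction at the anchor, $\mu_\ast=\xi_\ast^\top h_i^{\rm fwd} + \zeta_\ast^\top h_i^{\rm bwd}$. celerite already computes its mean this way.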
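A small numerical check of the mean, written as a NumPy sketch under stated assumptions: the Matern-3/2 matrices are assumed, `precompute_mean` and `predict_mean` are hypothetical names rather than the PR's `precompute()`, and $\beta$ is obtained from a dense solve for brevity instead of two triangular solves. The arrays are stored with a one-slot offset so that anchor $i=-1$ is a valid index.

```python
import numpy as np

# Assumed Matern-3/2 state-space model (J = 2); every name here is illustrative.
SCALE, SIGMA, NOISE = 1.5, 1.3, 0.3
F = np.sqrt(3.0) / SCALE
PINF = np.diag([1.0, F**2])
H = np.array([SIGMA, 0.0])

def phi(dt):
    """Transition matrix over a gap dt >= 0 (assumed form)."""
    return np.exp(-F * dt) * np.array([[1 + F * dt, -F**2 * dt], [dt, 1 - F * dt]])

def dense_kernel(t1, t2):
    r = np.abs(t1[:, None] - t2[None, :])
    return SIGMA**2 * (1 + F * r) * np.exp(-F * r)

def precompute_mean(t, y):
    """Train-only state for the mean: transitions plus the two affine scans."""
    N, J = len(t), 2
    a = np.stack([np.eye(J)] + [phi(dt) for dt in np.diff(t)])
    p = np.stack([ak.T @ PINF @ H for ak in a])
    K = dense_kernel(t, t) + NOISE**2 * np.eye(N)
    beta = np.linalg.solve(K, y)              # dense shortcut for K^{-1} y
    h_fwd = np.zeros((N + 1, J))              # h_fwd[i + 1] holds h^fwd_i
    h_bwd = np.zeros((N + 1, J))              # h_bwd[i + 1] holds h^bwd_i
    for k in range(N):                        # forward scan (q_k = H here)
        h_fwd[k + 1] = a[k] @ h_fwd[k] + H * beta[k]
    for k in range(N - 1, -1, -1):            # backward scan
        h_bwd[k] = a[k].T @ h_bwd[k + 1] + p[k] * beta[k]
    return a, h_fwd, h_bwd, beta

def predict_mean(t, state, t_star):
    a, h_fwd, h_bwd, _ = state
    i = np.searchsorted(t, t_star, side="right") - 1
    xi = phi(t_star - t[i]).T @ PINF @ H if i >= 0 else np.zeros(2)
    zeta = (np.linalg.solve(a[i + 1], phi(t[i + 1] - t_star) @ H)
            if i + 1 < len(t) else np.zeros(2))
    return xi @ h_fwd[i + 1] + zeta @ h_bwd[i + 1]

t = np.array([0.0, 0.4, 1.1, 1.5, 2.7, 3.0])
y = np.sin(t)
state = precompute_mean(t, y)
t_test = np.array([-0.5, 1.3, 1.5, 3.4])
mu_fast = np.array([predict_mean(t, state, ts) for ts in t_test])
mu_dense = dense_kernel(t_test, t) @ state[3]
```

Because the zero padding sits at both ends, the same expression handles interpolation and both kinds of extrapolation without branching on the result.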
Predictive variance
Write $\sigma_\ast^2=K_{\ast\ast}-\lVert v\rVert^2$ with $v=L^{-1}k_\ast$, and split $\lVert v\rVert^2=\sum_k v_k^2$ at the anchor.
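For $k\le i$ the right-hand side $(k_\ast)_k$ is built from the kernel's own $q_k$, so the forward solve repeats the recursion that produced $L$, and its squared entries accumulate into the carry:

$$v_k=\xi_\ast^\top\Psi_{i,k}\,w_k\quad(k\le i),\qquad \sum_{k\le i}v_k^2=\xi_\ast^\top f_i\,\xi_\ast.$$

There is nothing new to compute here, since $f_i$ is the inclusive carry from the factorization. For $k\gt i$ the remaining terms collapse onto a single vector $\delta_\ast=\zeta_\ast-f_i\xi_\ast$ and one more train-only accumulator $P_i$, which a backward congruence scan can build as

$$\sum_{k\gt i}v_k^2=\delta_\ast^\top P_i\,\delta_\ast,\qquad P_i=\frac{p_{i+1}p_{i+1}^\top}{c_{i+1}^2}+M_{i+1}^\top P_{i+1}M_{i+1},\qquad M_k=a_k-\frac{w_kp_k^\top}{c_k},$$

with $P_{N-1}=0$. Together,

$$\sigma_\ast^2=K_{\ast\ast}-\xi_\ast^\top f_i\,\xi_\ast-\delta_\ast^\top P_i\,\delta_\ast.$$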
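The variance can be verified end to end against a dense solve. The sketch below is a NumPy reference under explicit assumptions, not the PR's `predict.py`: the Matern-3/2 matrices are assumed, `precompute_var` and `predict_var` are hypothetical names, and the backward recursion for the accumulator `P` is one form derived from the forward-substitution recursion (with the per-step matrix `M` introduced only for this sketch).

```python
import numpy as np

# Assumed Matern-3/2 state-space model (J = 2); every name here is illustrative.
SCALE, SIGMA, NOISE = 1.5, 1.3, 0.3
F = np.sqrt(3.0) / SCALE
PINF = np.diag([1.0, F**2])
H = np.array([SIGMA, 0.0])

def phi(dt):
    """Transition matrix over a gap dt >= 0 (assumed form)."""
    return np.exp(-F * dt) * np.array([[1 + F * dt, -F**2 * dt], [dt, 1 - F * dt]])

def precompute_var(t):
    """Train-only state: transitions a, padded carry f, padded accumulator P."""
    N, J = len(t), 2
    a = np.stack([np.eye(J)] + [phi(dt) for dt in np.diff(t)])
    p = np.stack([ak.T @ PINF @ H for ak in a])
    d = H @ PINF @ H + NOISE**2
    c, w = np.zeros(N), np.zeros((N, J))
    f = np.zeros((N + 1, J, J))               # f[i + 1] holds f_i, f[0] = f_{-1} = 0
    for k in range(N):                        # forward Cholesky scan (q_k = H here)
        c[k] = np.sqrt(d - p[k] @ f[k] @ p[k])
        w[k] = (H - a[k] @ f[k] @ p[k]) / c[k]
        f[k + 1] = a[k] @ f[k] @ a[k].T + np.outer(w[k], w[k])
    P = np.zeros((N + 1, J, J))               # P[i + 1] holds P_i, P[N] = P_{N-1} = 0
    for k in range(N - 1, -1, -1):            # backward congruence scan
        M = a[k] - np.outer(w[k], p[k]) / c[k]
        P[k] = np.outer(p[k], p[k]) / c[k] ** 2 + M.T @ P[k + 1] @ M
    return a, f, P

def predict_var(t, state, t_star):
    a, f, P = state
    i = np.searchsorted(t, t_star, side="right") - 1
    xi = phi(t_star - t[i]).T @ PINF @ H if i >= 0 else np.zeros(2)
    zeta = (np.linalg.solve(a[i + 1], phi(t[i + 1] - t_star) @ H)
            if i + 1 < len(t) else np.zeros(2))
    delta = zeta - f[i + 1] @ xi
    return H @ PINF @ H - xi @ f[i + 1] @ xi - delta @ P[i + 1] @ delta

def dense_var(t, t_star):
    r = np.abs(t[:, None] - t[None, :])
    K = SIGMA**2 * (1 + F * r) * np.exp(-F * r) + NOISE**2 * np.eye(len(t))
    rs = np.abs(t - t_star)
    ks = SIGMA**2 * (1 + F * rs) * np.exp(-F * rs)
    return SIGMA**2 - ks @ np.linalg.solve(K, ks)

t = np.array([0.0, 0.4, 1.1, 1.5, 2.7, 3.0])
state = precompute_var(t)
t_test = [-0.5, 1.3, 1.5, 3.4]
var_fast = np.array([predict_var(t, state, ts) for ts in t_test])
var_dense = np.array([dense_var(t, ts) for ts in t_test])
```

Per test point the work is one `searchsorted`, one $J\times J$ solve over a single training gap, and two quadratic forms, with no pass over the training set.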
Implementing it
Four arrays get precomputed over the training data: $h^{\rm fwd}$ and $h^{\rm bwd}$ for the mean, $f$ and $P$ for the variance. If the Cholesky only stored the exclusive carry $f_{k-1}$, the inclusive $f_k=a_k f_{k-1}a_k^\top+w_kw_k^\top$ is a `vmap` away with no extra scan.

Pad each array to length $N+1$ and index it by the anchor $i$ from the binary search, so extrapolation falls out on its own (the missing side contributes zero). A test point then costs a binary search and two quadratic forms, $O(J^2)$, and the whole thing vmaps over the test set. There is no second factorization and no scan over a combined train-plus-test grid.
The one caveat is that the carry coincidence relies on solving against the kernel's own generators. Predicting a different process, such as one component of a kernel sum, a derivative or integral of the latent, or the same kernel with different hyperparameters, breaks it and brings in extra accumulators. So the fast path only applies when the prediction kernel is the one that was fit.