Add Lightcurve.acf() and plot_acf() for data-driven and GP-implied autocorrelation #146
Conversation
Commits:
- Agent-Logs-Url: https://github.com/ICSM/pgmuvi/sessions/26a6281f-1a5e-42c3-85e0-d816898da200 (Co-authored-by: sundarjhu <15619959+sundarjhu@users.noreply.github.com>)
- …D GP guard, fitted-state check, mock tests (Agent-Logs-Url: https://github.com/ICSM/pgmuvi/sessions/8ee59df0-dfab-4d1d-b0de-42d7488907c0; Co-authored-by: sundarjhu <15619959+sundarjhu@users.noreply.github.com>)
- …nce with normalize=True (Agent-Logs-Url: https://github.com/ICSM/pgmuvi/sessions/8e79cae9-eadf-4596-887d-3985c1bb6447; Co-authored-by: sundarjhu <15619959+sundarjhu@users.noreply.github.com>)
Pull request overview
Adds an autocorrelation-function API to pgmuvi.Lightcurve, supporting both a data-driven estimator (pairwise lag binning for uneven sampling) and a GP-implied ACF derived from the fitted covariance model, plus plotting and tests.
Changes:
- Introduces `ACFResult` plus `Lightcurve.acf()` and `Lightcurve.plot_acf()`.
- Implements `_acf_data()` (pairwise binned estimator) and `_acf_gp()` (covariance-derived ACF).
- Adds a new `tests/test_acf.py` covering core behaviors, 2D preconditions, and constant-series normalization errors.
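For orientation, a hypothetical usage sketch. `acf()`, `plot_acf()`, `n_lags`, and `normalize` come from the PR itself; the constructor call, the `method` keyword, and the `ACFResult` field names (`lags`, `acf`, `counts`) are assumptions for illustration:

```python
import torch
from pgmuvi.lightcurve import Lightcurve

# Unevenly sampled synthetic light curve.
t = torch.sort(torch.rand(200) * 100.0).values
y = torch.sin(2 * torch.pi * t / 12.0) + 0.1 * torch.randn(200)

# Constructor arguments shown positionally; see pgmuvi's Lightcurve docs.
lc = Lightcurve(t, y)

# Data-driven estimate: pairwise products of mean-subtracted fluxes,
# binned by |t_j - t_i|.  The `method` keyword is an assumption here.
result = lc.acf(method="data", n_lags=20, normalize=True)

# ACFResult field names are assumptions for this sketch.
print(result.lags, result.acf, result.counts)

# Quick-look plot of the same estimate.
lc.plot_acf(method="data", n_lags=20)
```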
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `pgmuvi/lightcurve.py` | Adds the ACF result container and the `acf()`/`plot_acf()` implementations (data + GP paths). |
| `tests/test_acf.py` | Adds unit tests for the new ACF API, including error cases and basic plotting. |
```python
self._validate_n_lags(n_lags)

t = self._get_time_axis().double()
y = self.ydata.double()
```
Already implemented in commit 0e137aa: `_validate_n_lags()` is now called only inside the `lag_edges is None` branch, so passing `lag_edges` with an irrelevant `n_lags` value no longer raises.
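For context, a minimal standalone sketch of that guard pattern (the helper name, default lag span, and the inline validation are assumptions; only the `lag_edges is None` branching mirrors the diff):

```python
import torch

def resolve_lag_edges(t, n_lags=None, lag_edges=None):
    """Sketch: n_lags is validated only when default bins are built."""
    if lag_edges is None:
        # Stand-in for _validate_n_lags(); only meaningful on this branch.
        if n_lags is None or n_lags < 1:
            raise ValueError("n_lags must be a positive integer.")
        # Default span (half the baseline) is an assumption for this sketch.
        max_lag = 0.5 * (t.max() - t.min()).item()
        return torch.linspace(0.0, max_lag, n_lags + 1, dtype=torch.float64)
    # Caller-supplied edges win; n_lags is never inspected, so a stale
    # value can no longer raise.
    return torch.as_tensor(lag_edges, dtype=torch.float64)
```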
```python
raise ValueError(
    "lag_edges must be strictly increasing."
)
max_lag = float(edges[-1])
```
Already implemented in commit 0e137aa: `float(edges[-1])` replaced with `edges[-1].detach().item()`.
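The distinction, in a minimal standalone example (both forms yield a Python float; the rewrite just makes the two steps explicit):

```python
import torch

edges = torch.linspace(0.0, 10.0, 11, requires_grad=True)

# .detach() leaves the autograd graph, .item() extracts the scalar;
# float(edges[-1]) does the same implicitly for a one-element tensor.
max_lag = edges[-1].detach().item()   # -> 10.0, a plain Python float
```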
```python
# ------------------------------------------------------------------
# Zero-lag entry: all n self-pairs
# ------------------------------------------------------------------
variance = float((y_centered**2).mean())
t_ref = float(t.mean()) if reference_time is None else float(reference_time)
```
Already implemented in commit 0e137aa: `float(t.mean())` replaced with `t.mean().detach().item()`.
```python
if normalize:
    x_ref = x1[:1]
    lazy_var = self.model.covar_module(x_ref, x_ref)
    variance = float(lazy_var.diagonal()[0].detach())
```
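For reference, the covariance-derived ACF reduces to evaluating the fitted kernel at a grid of lags and dividing by the zero-lag variance, r(τ) = k(τ)/k(0). A minimal sketch assuming a stationary gpytorch kernel (the function name, anchor point, and `.to_dense()` usage are illustrative, not the PR's code):

```python
import torch

def gp_implied_acf(covar_module, lags, normalize=True):
    """Sketch: evaluate a stationary gpytorch kernel at a grid of lags."""
    lags = torch.as_tensor(lags, dtype=torch.float64).reshape(-1, 1)
    x_ref = torch.zeros(1, 1, dtype=torch.float64)       # arbitrary anchor point
    # k(x_ref, x_ref + tau) for each lag tau; shape (1, n_lags) once dense.
    k_tau = covar_module(x_ref, x_ref + lags).to_dense().squeeze(0)
    if normalize:
        # r(tau) = k(tau) / k(0): unit value at zero lag by construction.
        k_zero = covar_module(x_ref, x_ref).to_dense().squeeze()
        k_tau = k_tau / k_zero
    return k_tau.detach()
```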
Commits:
- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
- … only when lag_edges is None (Agent-Logs-Url: https://github.com/ICSM/pgmuvi/sessions/abc3876e-6f2d-442d-85a8-cc0a6e833166; Co-authored-by: sundarjhu <15619959+sundarjhu@users.noreply.github.com>)
```python
# ------------------------------------------------------------------
# Vectorised pairwise computation (i < j pairs)
# ------------------------------------------------------------------
idx_i, idx_j = torch.triu_indices(n, n, offset=1, device=t.device)
dt = (t[idx_j] - t[idx_i]).abs()
vals = y_centered[idx_i] * y_centered[idx_j]

# Keep only pairs within max_lag
in_range = dt <= max_lag
dt = dt[in_range]
vals = vals[in_range]

# Bin dt into lag bins
bin_idx = torch.searchsorted(edges, dt) - 1
valid_bins = (bin_idx >= 0) & (bin_idx < n_bins)
bin_idx = bin_idx[valid_bins]
vals = vals[valid_bins]

bin_sums = torch.zeros(n_bins, dtype=torch.float64, device=t.device)
bin_counts = torch.zeros(n_bins, dtype=torch.float64, device=t.device)
bin_sums.scatter_add_(0, bin_idx, vals)
bin_counts.scatter_add_(0, bin_idx, torch.ones_like(vals))

# Normalise by counts
valid = bin_counts > 0
acf_vals = torch.zeros(n_bins, dtype=torch.float64, device=t.device)
acf_vals[valid] = bin_sums[valid] / bin_counts[valid]

if normalize:
    acf_vals = acf_vals / variance
    zero_lag_val = 1.0

lag_centers = (edges[:-1] + edges[1:]) / 2.0

# Prepend zero-lag point
zero_lag_tensor = torch.zeros(1, dtype=torch.float64, device=t.device)
zero_acf_tensor = torch.tensor(
    [zero_lag_val], dtype=torch.float64, device=t.device
)
zero_count_tensor = torch.tensor(
    [float(n)], dtype=torch.float64, device=t.device
)

full_lag = torch.cat([zero_lag_tensor, lag_centers])
full_acf = torch.cat([zero_acf_tensor, acf_vals])
full_counts = torch.cat([zero_count_tensor, bin_counts])
```
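As an independent sanity check of the same pairwise-binning scheme, a standalone sketch (not the PR's test suite): white noise on a regular grid should give a near-zero ACF at every nonzero lag.

```python
import torch

torch.manual_seed(0)
n = 500
t = torch.arange(n, dtype=torch.float64)          # regular sampling as a sanity case
y = torch.randn(n, dtype=torch.float64)
yc = y - y.mean()

# All i < j pairs, products of centred values, binned by |t_j - t_i|.
idx_i, idx_j = torch.triu_indices(n, n, offset=1)
dt = (t[idx_j] - t[idx_i]).abs()
vals = yc[idx_i] * yc[idx_j]

# Unit-width bins centred on integer lags 1..10.
edges = torch.arange(0.5, 11.5, 1.0, dtype=torch.float64)
bin_idx = torch.searchsorted(edges, dt) - 1
keep = (bin_idx >= 0) & (bin_idx < 10)
sums = torch.zeros(10, dtype=torch.float64).scatter_add_(0, bin_idx[keep], vals[keep])
counts = torch.zeros(10, dtype=torch.float64).scatter_add_(
    0, bin_idx[keep], torch.ones_like(vals[keep]))

# For white noise, every entry should be close to zero.
acf = (sums / counts) / yc.pow(2).mean()
print(acf)
```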
I notice this is the brute-force ACF calculation, which is an O(n²) algorithm. Is it worth switching to one of the O(n log n) algorithms? For example, pyro has one (https://docs.pyro.ai/en/stable/ops.html#pyro.ops.stats.autocorrelation), and since pyro is a (future) dependency for MCMC anyway, is it worth just using that? The complication is that that implementation doesn't know about the time sampling of the data. If that's a bottleneck, the FFT algorithm could perhaps be implemented using finufft? This might be a problem for a future issue, however.
Good point. I was working with 100-point LCs so I didn't even consider this. Which of the two would you like to consider as a replacement (or offer the user a choice)? We can add it as a feature request.
I guess we have irregularly-sampled times, so the pyro implementation is probably not an option? @copilot please comment on the practicality of different ways to implement O(n logn) ACF calculations, using either/both FFTs or convolutions instead of brute-force calculations.
You're right that the current estimator is O(n²). Here's a summary of the O(n log n) options for irregularly-sampled data:
- Pyro's `autocorrelation` uses an FFT via the Wiener–Khinchin theorem, but assumes evenly-spaced samples, so it's not directly applicable here without gridding first.
- Grid + FFT (O(n log n) after gridding): snap or kernel-average the irregular samples onto a regular time grid, then use `torch.fft.rfft`/`irfft` on the gridded series. This is practical and requires no new dependencies. The main caveat is that the gridding step introduces a bias proportional to the gap between the true sample times and the grid nodes, which worsens in sparsely or irregularly sampled regions.
- NUFFT via finufft computes the autocorrelation directly from the irregular sample positions (essentially evaluating the structure function on a frequency grid without gridding). It achieves true O(n log n) with no gridding bias, but adds finufft as a new dependency, and the API is considerably more involved.
- Lomb–Scargle + inverse FFT: the LS periodogram is already an O(n log n) spectral estimate for unevenly-sampled data, and by the Wiener–Khinchin theorem the ACF is (approximately) its inverse Fourier transform. Since pgmuvi already has `multiband_ls_significance.py`, this could reuse existing infrastructure. The catch is that the LS window function introduces sidelobes in the recovered ACF unless a debiasing step is added.

Recommendation for a future issue: for the typical astronomical light curve (n ~ 10²–10³), the O(n²) cost is rarely the bottleneck in practice. When it becomes one, the gridding + FFT route is the lowest-friction path (no new dependencies, easily expressed in pure PyTorch). The LS-based route is more principled for very sparse or irregular sampling. I'd suggest filing this as a feature request and noting gridding + FFT as the preferred first step, with finufft as the stretch goal for truly publication-quality bias control.
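To make the gridding + FFT route concrete, a minimal sketch assuming nearest-node snapping and zero-filled empty nodes (the helper name, grid size, and defaults are all illustrative; a production version would also track per-bin counts and debias empty nodes):

```python
import torch

def gridded_acf(t, y, n_grid=1024, n_lags=50):
    """O(n log n) ACF sketch: snap irregular samples to a grid, then FFT.

    Illustrates the Wiener-Khinchin route discussed above; the snapping
    step biases lags by up to half a grid cell, so n_grid sets the
    bias/speed trade-off.  Not the PR's implementation.
    """
    t = torch.as_tensor(t, dtype=torch.float64)
    y = torch.as_tensor(y, dtype=torch.float64)
    yc = y - y.mean()

    # Nearest-node gridding, averaging samples that land on the same node.
    frac = (t - t.min()) / (t.max() - t.min())
    idx = torch.round(frac * (n_grid - 1)).long()
    grid = torch.zeros(n_grid, dtype=torch.float64).scatter_add_(0, idx, yc)
    counts = torch.zeros(n_grid, dtype=torch.float64).scatter_add_(
        0, idx, torch.ones_like(yc)
    )
    filled = counts > 0
    grid[filled] = grid[filled] / counts[filled]  # empty nodes stay at zero

    # Wiener-Khinchin: the autocovariance is the inverse FFT of the power
    # spectrum.  Zero-pad to 2 * n_grid to avoid circular wrap-around.
    spec = torch.fft.rfft(grid, n=2 * n_grid)
    acov = torch.fft.irfft(spec * spec.conj(), n=2 * n_grid)[:n_lags]
    # Lag k corresponds to k * (t.max() - t.min()) / (n_grid - 1) time units.
    return acov / acov[0]  # normalized so r(0) = 1
```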
Remaining review items:
- `acf()` docstring: clarify the `band` param (required for 2D + data; 2D + gp is unsupported and raises `NotImplementedError`)
- `_acf_data()` when `normalize=True`