Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions atlas-embeddings/python/invariant_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""
Invariant basis extractor for atlas-embeddings.

Extracts the stable invariant basis from a stream of evolving embedding
states using the covariance operator. The dominant eigenmodes define the
intrinsic geometric axes of the system, separating meaningful structure
from transient components.

This acts as a coordinate-locking layer: instead of operating in an
arbitrary embedding space, the system aligns itself to its own intrinsic
geometry.
"""

import numpy as np
from numpy.linalg import eigh


class InvariantExtractor:
"""Extract invariant basis from evolving embedding states.

Maintains a rolling buffer of observed states and computes the
covariance operator. The eigenvectors of this operator define the
invariant basis — the directions where the system's geometry lives.

Parameters
----------
dim : int
Dimensionality of the state vectors (must be >= 1).
memory : int, optional
Number of recent states to retain in the rolling buffer.
Must be >= 2 to compute a meaningful covariance. Default is 50.

Examples
--------
>>> extractor = InvariantExtractor(dim=4, memory=100)
>>> for state in state_stream:
... extractor.update(state)
>>> eigenvalues, basis = extractor.invariants()
>>> coords = extractor.project(state, k=2)
>>> reconstructed = extractor.reconstruct(coords)
"""

def __init__(self, dim: int, memory: int = 50):
if dim < 1:
raise ValueError(f"dim must be >= 1, got {dim}")
if memory < 2:
raise ValueError(f"memory must be >= 2, got {memory}")

self.dim = dim
self.memory = memory
self._buffer = np.zeros((memory, dim))
self._count = 0
self._basis = None
self._eigenvalues = None

@property
def full(self) -> bool:
"""Whether the rolling buffer has been completely filled."""
return self._count >= self.memory

def update(self, state) -> None:
"""Add a state observation to the rolling buffer.

Parameters
----------
state : array_like
State vector of length `dim`.

Raises
------
ValueError
If `state` has wrong dimensionality.
"""
state = np.asarray(state, dtype=float)
if state.shape != (self.dim,):
raise ValueError(
f"Expected state of shape ({self.dim},), got {state.shape}"
)

idx = self._count % self.memory
self._buffer[idx] = state
self._count += 1
# Invalidate cached decomposition
self._basis = None
self._eigenvalues = None

def _compute(self) -> None:
"""Compute eigendecomposition of the covariance operator."""
n = min(self._count, self.memory)
data = self._buffer[:n]
mean = data.mean(axis=0)
centered = data - mean
cov = (centered.T @ centered) / n
eigenvalues, eigenvectors = eigh(cov)
# Sort descending
order = np.argsort(eigenvalues)[::-1]
self._eigenvalues = eigenvalues[order]
self._basis = eigenvectors[:, order]

def invariants(self):
"""Return the invariant basis and associated eigenvalues.

Returns
-------
eigenvalues : ndarray, shape (dim,)
Eigenvalues in descending order (variance along each axis).
basis : ndarray, shape (dim, dim)
Columns are the invariant basis vectors, ordered by eigenvalue.
"""
if self._eigenvalues is None:
self._compute()
return self._eigenvalues.copy(), self._basis.copy()

def project(self, state, k: int = None):
"""Project a state onto the top-k invariant directions.

Parameters
----------
state : array_like
State vector of length `dim`.
k : int, optional
Number of dominant directions to keep. Defaults to `dim`.

Returns
-------
coords : ndarray, shape (k,)
Coordinates in the invariant basis (top-k components).
"""
if k is None:
k = self.dim
if self._basis is None:
self._compute()
state = np.asarray(state, dtype=float)
return (self._basis[:, :k].T @ state)

def reconstruct(self, coords):
"""Reconstruct a state from invariant-basis coordinates.

Parameters
----------
coords : array_like
Coordinates from `project()`.

Returns
-------
state : ndarray, shape (dim,)
Reconstructed state vector.
"""
coords = np.asarray(coords, dtype=float)
k = len(coords)
if self._basis is None:
self._compute()
return self._basis[:, :k] @ coords
130 changes: 130 additions & 0 deletions atlas-embeddings/python/test_invariant_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import numpy as np
from invariant_extractor import InvariantExtractor


def test_basic_functionality():
"""Test basic functionality of InvariantExtractor."""
print("Testing basic functionality...")

# Test initialization
extractor = InvariantExtractor(dim=4, memory=10)
assert extractor.dim == 4
assert extractor.memory == 10
assert not extractor.full
print("✓ Initialization works")

# Test update
state = np.array([1.0, 2.0, 3.0, 4.0])
extractor.update(state)
print("✓ Single update works")

# Test multiple updates to fill buffer
for i in range(10):
extractor.update(np.random.randn(4))
assert extractor.full
print("✓ Buffer filling works")

# Test invariants extraction
vals, vecs = extractor.invariants()
assert len(vals) == 4
assert vecs.shape == (4, 4)
assert np.all(vals >= 0) # eigenvalues should be non-negative
print("✓ Invariants extraction works")

# Test projection and reconstruction
test_state = np.array([1.0, 0.5, 0.1, 0.01])
coords = extractor.project(test_state, k=2)
reconstructed = extractor.reconstruct(coords)
assert len(coords) == 2
assert reconstructed.shape == (4,)
print("✓ Projection and reconstruction work")


def test_error_handling():
"""Test error handling."""
print("\nTesting error handling...")

# Test invalid dim
try:
InvariantExtractor(dim=0)
assert False, "Should have raised ValueError"
except ValueError:
print("✓ Invalid dim error handling works")

# Test invalid memory
try:
InvariantExtractor(dim=4, memory=1)
assert False, "Should have raised ValueError"
except ValueError:
print("✓ Invalid memory error handling works")

# Test wrong state dimension
extractor = InvariantExtractor(dim=3)
try:
extractor.update([1, 2, 3, 4]) # 4D state for 3D extractor
assert False, "Should have raised ValueError"
except ValueError:
print("✓ Wrong state dimension error handling works")


def test_anisotropic_data():
"""Test with anisotropic data like in the example."""
print("\nTesting with anisotropic data...")

extractor = InvariantExtractor(dim=4)

# Feed anisotropic data
for _ in range(500):
state = np.random.randn(4) * np.array([1.0, 0.5, 0.1, 0.01])
extractor.update(state)

vals, basis = extractor.invariants()

# Check that eigenvalues are in descending order
assert np.all(vals[:-1] >= vals[1:]), "Eigenvalues should be in descending order"

# Check that the first eigenvalue is much larger than the last
# (due to anisotropic scaling)
assert vals[0] > vals[-1] * 10, "First eigenvalue should be much larger"

print(f"✓ Anisotropic data test passed")
print(f" Eigenvalues: {vals}")
print(f" Ratio (first/last): {vals[0]/vals[-1]:.2f}")


def test_reconstruction_accuracy():
"""Test reconstruction accuracy."""
print("\nTesting reconstruction accuracy...")

extractor = InvariantExtractor(dim=4, memory=100)

# Generate some data
for _ in range(200):
state = np.random.randn(4) * np.array([2.0, 1.0, 0.5, 0.1])
extractor.update(state)

# Test reconstruction with different k values
original_state = np.array([1.5, 0.8, 0.2, 0.05])

for k in [1, 2, 3, 4]:
coords = extractor.project(original_state, k=k)
reconstructed = extractor.reconstruct(coords)

# Calculate relative error
error = np.linalg.norm(reconstructed - original_state)
relative_error = error / np.linalg.norm(original_state)
print(f" k={k}: relative error = {relative_error:.6f}")

# Error should decrease as k increases
if k < 4:
# Allow some tolerance for numerical precision
assert relative_error < 1.0, f"Reconstruction error too high for k={k}"


if __name__ == "__main__":
print("Running InvariantExtractor tests...\n")
test_basic_functionality()
test_error_handling()
test_anisotropic_data()
test_reconstruction_accuracy()
print("\n✅ All tests passed!")