def test_inner_idx() -> None:
    """Check _inner_idx restarts its counter at every new segment id."""
    # (sorted segment ids, number of segments, expected local indices)
    cases = [
        # Example from the _inner_idx docstring.
        ([0, 0, 0, 1, 1, 2, 2, 2, 2], 3, [0, 1, 2, 0, 1, 0, 1, 2, 3]),
        # A single segment is a plain 0..n-1 count.
        ([0, 0, 0], 1, [0, 1, 2]),
        # Empty input yields empty output.
        ([], 0, []),
        # Segment 1 is absent: the gap must not disturb segment 2.
        ([0, 0, 2, 2, 2], 3, [0, 1, 0, 1, 2]),
    ]
    for segment_ids, n_segments, want in cases:
        got = _inner_idx(
            torch.tensor(segment_ids, dtype=torch.long), dim_size=n_segments
        )
        torch.testing.assert_close(got, torch.tensor(want, dtype=torch.long))
def test_build_triplets_simple() -> None:
    """A three-spoke star centred on atom 0 yields 3*2 = 6 ordered triplets."""
    # Spokes 1→0, 2→0, 3→0. Triplets are ordered pairs of distinct spokes,
    # so the count is deg*(deg-1), not the number of unordered combinations.
    spokes = torch.tensor([[1, 2, 3], [0, 0, 0]])  # [2, 3]

    triplets = build_triplets(spokes, 4)

    for key in ("trip_in", "trip_out", "center_atom"):
        assert len(triplets[key]) == 6
    # Every triplet is centred on the hub atom.
    assert torch.all(triplets["center_atom"] == 0)
    # The two legs of a triplet are never the same edge.
    assert torch.all(triplets["trip_in"] != triplets["trip_out"])


def test_build_triplets_empty() -> None:
    """No atom with two incoming edges means no triplets at all."""
    # Directed chain 0→1→2: atoms 1 and 2 each receive exactly one edge.
    chain = torch.tensor([[0, 1], [1, 2]])  # [2, 2]

    triplets = build_triplets(chain, 3)

    for key in ("trip_in", "trip_out", "center_atom", "trip_out_agg"):
        assert len(triplets[key]) == 0


def test_build_triplets_complex() -> None:
    """Four edges into one atom yield 4*3 = 12 distinct ordered triplets."""
    # Edges 0→1, 2→1, 4→1, 5→1; atom 3 has no edges.
    into_atom_1 = torch.tensor([[0, 2, 4, 5], [1, 1, 1, 1]])

    triplets = build_triplets(into_atom_1, 6)

    assert len(triplets["trip_in"]) == 12
    assert len(triplets["trip_out"]) == 12
    assert (triplets["center_atom"] == 1).all()

    # No (in, out) edge pair may appear twice.
    edge_pairs = torch.stack([triplets["trip_in"], triplets["trip_out"]], dim=0)
    assert torch.unique(edge_pairs, dim=1).shape[1] == 12
def test_build_mixed_triplets_to_outedge_false() -> None:
    """to_outedge=False pairs in-edges with out-edges that share a target atom."""
    # in-edges : 0→4, 1→4, 3→5
    # out-edges: 2→4, 2→5
    # Matching on the out-edge target gives
    #   (0→4, 2→4), (1→4, 2→4), (3→5, 2→5).
    in_edges = torch.tensor([[0, 1, 3], [4, 4, 5]])
    out_edges = torch.tensor([[2, 2], [4, 5]])

    triplets = build_mixed_triplets(in_edges, out_edges, 6, to_outedge=False)
    picked_in = triplets["trip_in"]
    picked_out = triplets["trip_out"]

    assert len(picked_in) == 3
    assert len(picked_out) == 3

    # Both legs of every triplet end on one of the shared atoms 4 / 5.
    in_ends = in_edges[1][picked_in]
    out_ends = out_edges[1][picked_out]
    assert torch.all((in_ends == 4) | (in_ends == 5))
    assert torch.all((out_ends == 4) | (out_ends == 5))


def test_build_mixed_triplets_to_outedge_true() -> None:
    """to_outedge=True pairs in-edges ending where the out-edge starts."""
    # in-edges : 0→2, 1→2, 3→2
    # out-edges: 2→4, 2→5
    # Every in-edge ends on atom 2, the source of both out-edges, so all
    # 3 x 2 = 6 combinations are produced.
    in_edges = torch.tensor([[0, 1, 3], [2, 2, 2]])
    out_edges = torch.tensor([[2, 2], [4, 5]])

    triplets = build_mixed_triplets(in_edges, out_edges, 6, to_outedge=True)
    picked_in = triplets["trip_in"]
    picked_out = triplets["trip_out"]

    assert len(picked_in) == 6
    assert len(picked_out) == 6

    # In-edges arrive at atom 2; out-edges leave from atom 2.
    assert torch.all(in_edges[1][picked_in] == 2)
    assert torch.all(out_edges[0][picked_out] == 2)
def test_build_mixed_triplets_self_loop_filtering() -> None:
    """A pair whose in-edge and out-edge start at the same atom is dropped."""
    # to_outedge=False matches on the out-edge target (atom 2).
    # in-edges : 0→2, 1→2      out-edge: 1→2
    # Candidate (1→2, 1→2) has src_in == src_out == 1 and is filtered,
    # leaving only (0→2, 1→2).
    edge_index_in = torch.tensor([[0, 1], [2, 2]])
    edge_index_out = torch.tensor([[1], [2]])

    result = build_mixed_triplets(edge_index_in, edge_index_out, 3, to_outedge=False)

    assert len(result["trip_in"]) == 1
    assert result["trip_in"][0] == 0  # only the in-edge 0→2 survives
    src_in = edge_index_in[0][result["trip_in"][0]]
    src_out = edge_index_out[0][result["trip_out"][0]]
    assert src_in != src_out


def test_build_mixed_triplets_with_cell_offsets() -> None:
    """Cell offsets never remove a triplet whose end atoms already differ.

    in-edges: e0=0→3, e1=1→3; out-edge: f0=2→3 (to_outedge=False, shared atom 3).
    The keep-mask is ``(src_in != src_out) | (cell_sum != 0)``. Here the sources
    {0, 1} differ from 2, so both candidates survive with or without offsets.
    """
    edge_index_in = torch.tensor([[0, 1], [3, 3]])
    edge_index_out = torch.tensor([[2], [3]])
    n_atoms = 4
    expected_pairs = {(0, 0), (1, 0)}

    without_offsets = build_mixed_triplets(
        edge_index_in, edge_index_out, n_atoms, to_outedge=False
    )
    assert len(without_offsets["trip_in"]) == 2

    # cell_sum = out - in = [1, 0, 0] for both candidates (non-zero).
    cell_offsets_in = torch.tensor([[0, 0, 0], [0, 0, 0]])
    cell_offsets_out = torch.tensor([[1, 0, 0]])
    with_offsets = build_mixed_triplets(
        edge_index_in,
        edge_index_out,
        n_atoms,
        to_outedge=False,
        cell_offsets_in=cell_offsets_in,
        cell_offsets_out=cell_offsets_out,
    )

    # Pin the exact result instead of the former vacuous ``len(...) >= 0``.
    for result in (without_offsets, with_offsets):
        pairs = set(
            zip(result["trip_in"].tolist(), result["trip_out"].tolist(), strict=True)
        )
        assert pairs == expected_pairs
def test_build_triplets_exact_values() -> None:
    """Pin the exact (trip_in, trip_out) pairs for a three-spoke star.

    Edges e0=0→3, e1=1→3, e2=2→3 all end on atom 3, so the triplets are every
    ordered pair of distinct edge indices:
        (e0,e1), (e0,e2), (e1,e0), (e1,e2), (e2,e0), (e2,e1)
    """
    star = torch.tensor([[0, 1, 2], [3, 3, 3]])
    triplets = build_triplets(star, n_atoms=4)

    got = set(
        zip(triplets["trip_in"].tolist(), triplets["trip_out"].tolist(), strict=True)
    )
    want = {(i, j) for i in range(3) for j in range(3) if i != j}
    assert got == want
    assert (triplets["center_atom"] == 3).all()


def test_build_triplets_two_centers() -> None:
    """Two separate centres contribute independent sets of triplets.

    Edges e0=0→4, e1=1→4, e2=2→5, e3=3→5.
    Centre 4 gives (e0,e1),(e1,e0); centre 5 gives (e2,e3),(e3,e2): 4 in total.
    """
    edges = torch.tensor([[0, 1, 2, 3], [4, 4, 5, 5]])
    triplets = build_triplets(edges, n_atoms=6)

    assert len(triplets["trip_in"]) == 4
    got = set(
        zip(triplets["trip_in"].tolist(), triplets["trip_out"].tolist(), strict=True)
    )
    assert got == {(0, 1), (1, 0), (2, 3), (3, 2)}

    # Both legs of each triplet must end on the reported centre atom.
    edge_targets = edges[1]
    assert torch.equal(edge_targets[triplets["trip_in"]], triplets["center_atom"])
    assert torch.equal(edge_targets[triplets["trip_out"]], triplets["center_atom"])
def test_build_mixed_triplets_exact_values_to_outedge_false() -> None:
    """Hand-checked pairs for to_outedge=False (match on out-edge target).

    in-edges:  e0=0→4, e1=1→4, e2=3→5
    out-edges: f0=2→4, f1=2→5

    f0 ends on 4, as do e0 and e1  → (e0,f0), (e1,f0)
    f1 ends on 5, as does e2       → (e2,f1)
    No in-edge source (0, 1, 3) equals the out-edge source 2, so nothing is
    removed by the self-loop filter.
    """
    in_edges = torch.tensor([[0, 1, 3], [4, 4, 5]])
    out_edges = torch.tensor([[2, 2], [4, 5]])

    triplets = build_mixed_triplets(in_edges, out_edges, n_atoms=6, to_outedge=False)

    got = set(
        zip(triplets["trip_in"].tolist(), triplets["trip_out"].tolist(), strict=True)
    )
    assert got == {(0, 0), (1, 0), (2, 1)}


def test_build_mixed_triplets_exact_values_to_outedge_true() -> None:
    """Hand-checked pairs for to_outedge=True (match on out-edge source).

    in-edges:  e0=0→2, e1=1→2, e2=3→2
    out-edges: f0=2→4, f1=2→5

    Both out-edges start at atom 2, where all three in-edges end, giving 3
    candidates per out-edge. The self-loop filter compares src_in ∈ {0,1,3}
    with tgt_out ∈ {4,5}; they never coincide, so all 6 pairs survive.
    """
    in_edges = torch.tensor([[0, 1, 3], [2, 2, 2]])
    out_edges = torch.tensor([[2, 2], [4, 5]])

    triplets = build_mixed_triplets(in_edges, out_edges, n_atoms=6, to_outedge=True)

    assert len(triplets["trip_in"]) == 6
    got = set(
        zip(triplets["trip_in"].tolist(), triplets["trip_out"].tolist(), strict=True)
    )
    assert got == {(e, f) for f in range(2) for e in range(3)}
def test_build_mixed_triplets_cell_offset_self_loop() -> None:
    """Same-atom pairs are kept only when their cell offsets differ.

    in-edges: e0 = 1→2 with offset [0,0,0], e1 = 1→2 with offset [1,0,0]
    out-edge: f0 = 1→2 with offset [0,0,0]

    (e0, f0): same source atom, cell_sum = out - in = [0,0,0]   → dropped.
    (e1, f0): same source atom, cell_sum = out - in = [-1,0,0]  → kept.
    Without any offsets both pairs look like self-loops and are dropped.
    """
    in_edges = torch.tensor([[1, 1], [2, 2]])  # e0, e1
    out_edges = torch.tensor([[1], [2]])  # f0
    in_shifts = torch.tensor([[0, 0, 0], [1, 0, 0]], dtype=torch.float)
    out_shifts = torch.tensor([[0, 0, 0]], dtype=torch.float)

    # Atom indices alone: both candidates share source atom 1.
    plain = build_mixed_triplets(in_edges, out_edges, 3, to_outedge=False)
    assert len(plain["trip_in"]) == 0

    # With offsets, only the periodic-image edge e1 survives.
    shifted = build_mixed_triplets(
        in_edges,
        out_edges,
        3,
        to_outedge=False,
        cell_offsets_in=in_shifts,
        cell_offsets_out=out_shifts,
    )
    assert len(shifted["trip_in"]) == 1
    assert shifted["trip_in"][0].item() == 1  # e1
def test_build_quadruplets_exact_torsion() -> None:
    """Exact output for the minimal torsion 0-1-2-3 with central edge 1→2.

    main edges (each bond in both directions):
        e0=0→1, e1=1→2, e2=2→3, e3=1→0, e4=2→1, e5=3→2
    qint edge: q0=1→2

    Input triplets, build_mixed_triplets(main, qint, to_outedge=True):
        shared atom = src(q0) = 1; main edges ending on 1 are e0 (0→1) and
        e4 (2→1). The filter compares src_in with tgt(q0)=2: e0 is kept
        (0≠2), e4 is dropped (2==2, zero cell sum). → [(e0, q0)]

    Output triplets, build_mixed_triplets(qint, main, to_outedge=False):
        shared atom = target of each main edge; q0 ends on 2, as do e1 (1→2)
        and e5 (3→2). The filter compares src(q0)=1 with src_out: e1 is
        dropped (1==1), e5 is kept (3≠1). → [(q0, e5)]

    Cartesian product 1 x 1 = 1; c=src(e5)=3 and d=src(e0)=0 differ, so the
    single quadruplet survives.
    """
    main = torch.tensor([[0, 1, 2, 1, 2, 3], [1, 2, 3, 0, 1, 2]])  # e0..e5
    qint = torch.tensor([[1], [2]])
    n_atoms = 4
    main_cell = torch.zeros(6, 3)
    qint_cell = torch.zeros(1, 3)

    quads = build_quadruplets(main, qint, n_atoms, main_cell, qint_cell)

    assert len(quads["quad_c_to_a_edge"]) == 1
    # The only c→a edge is e5, which ends on atom 2.
    assert quads["quad_c_to_a_edge"][0].item() == 5
    assert main[1, 5].item() == 2

    # quad_d_to_b_trip_idx points into the input-triplet list; the main edge
    # found there (d→b) must end on atom b = 1.
    input_triplets = build_mixed_triplets(
        main,
        qint,
        n_atoms,
        to_outedge=True,
        cell_offsets_in=main_cell,
        cell_offsets_out=qint_cell,
    )
    trip_pos = quads["quad_d_to_b_trip_idx"][0].item()
    d_to_b_edge = input_triplets["trip_in"][trip_pos]
    assert main[1, d_to_b_edge].item() == 1
def test_build_quadruplets_multi_input_triplets() -> None:
    """Several input triplets on one central edge each pair with the output.

    main edges: e0=0→1, e1=2→1, e2=5→4, e3=1→4
    qint edge:  q0=1→4

    Input triplets (to_outedge=True, shared atom = src(q0) = 1):
        e0 and e1 end on atom 1; their sources 0 and 2 differ from tgt(q0)=4,
        so both are kept. → 2 input triplets.
    Output triplets (to_outedge=False, shared atom = target of the main edge):
        e2 (5→4) and e3 (1→4) end on atom 4 = tgt(q0). e3 is dropped because
        src(e3)=1 equals src(q0)=1; e2 is kept. → 1 output triplet.
    Cartesian product: 2 x 1 = 2. c=src(e2)=5 differs from d ∈ {0, 2}, so
    both quadruplets survive.
    """
    main = torch.tensor([[0, 2, 5, 1], [1, 1, 4, 4]])
    qint = torch.tensor([[1], [4]])
    n_atoms = 6
    main_cell = torch.zeros(4, 3)
    qint_cell = torch.zeros(1, 3)

    quads = build_quadruplets(main, qint, n_atoms, main_cell, qint_cell)

    assert len(quads["quad_c_to_a_edge"]) == 2
    # Every c→a edge ends on atom a = 4.
    assert (main[1][quads["quad_c_to_a_edge"]] == 4).all()

    # Every d→b edge, looked up through the input-triplet list, ends on b = 1.
    input_triplets = build_mixed_triplets(
        main,
        qint,
        n_atoms,
        to_outedge=True,
        cell_offsets_in=main_cell,
        cell_offsets_out=qint_cell,
    )
    d_to_b_edges = input_triplets["trip_in"][quads["quad_d_to_b_trip_idx"]]
    assert (main[1][d_to_b_edges] == 1).all()


def test_build_quadruplets_empty() -> None:
    """Main and internal graphs that share no atoms give zero quadruplets."""
    # Main edge 0→1 and internal edge 2→3 touch disjoint atoms.
    quads = build_quadruplets(
        torch.tensor([[0], [1]]),
        torch.tensor([[2], [3]]),
        4,
        torch.zeros(1, 3),
        torch.zeros(1, 3),
    )
    for key in ("quad_c_to_a_edge", "quad_d_to_b_trip_idx", "quad_c_to_a_trip_idx"):
        assert len(quads[key]) == 0
def test_build_quadruplets_cd_same_atom_different_cell() -> None:
    """c == d by atom index is dropped only when both are in the same cell.

    main: e0=0→1, e1=0→1 (periodic image, offset [1,0,0]), e2=0→3
    qint: q0=1→3

    Input triplets (to_outedge=True, shared atom = src(q0) = 1):
        e0 and e1 end on atom 1; src=0 differs from tgt(q0)=3 → 2 triplets.
    Output triplets (to_outedge=False, shared atom = target of the main edge):
        e2 ends on atom 3; src(q0)=1 differs from src(e2)=0 → 1 triplet.
    Cartesian product: 2 candidates, both with c = d = atom 0.
        cell_offset_cd = main_cell[d→b] + qint_cell[q0] - main_cell[c→a]
        via e0: [0,0,0] → same image → dropped.
        via e1: [1,0,0] → different image → kept.
    Result: exactly 1 quadruplet.
    """
    main = torch.tensor([[0, 0, 0], [1, 1, 3]])  # e0, e1, e2
    qint = torch.tensor([[1], [3]])
    n_atoms = 4
    main_cell = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 0, 0]], dtype=torch.float)
    qint_cell = torch.zeros(1, 3)

    result = build_quadruplets(main, qint, n_atoms, main_cell, qint_cell)
    assert len(result["quad_c_to_a_edge"]) == 1


@pytest.mark.parametrize(
    "device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
)
def test_build_triplets_device(device: str) -> None:
    """build_triplets returns its tensors on the device of its input."""
    dev = torch.device(device)
    edge_index = torch.tensor([[1, 2, 3], [0, 0, 0]], device=dev)
    n_atoms = 4

    result = build_triplets(edge_index, n_atoms)

    # Compare device *types*: a tensor created on "cuda" reports "cuda:0",
    # which does not compare equal to the index-less torch.device("cuda").
    for key in ("trip_in", "trip_out", "center_atom"):
        assert result[key].device.type == dev.type
@pytest.mark.parametrize(
    "device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
)
def test_build_quadruplets_device(device: str) -> None:
    """build_quadruplets returns its tensors on the device of its inputs."""
    dev = torch.device(device)
    main_edge_index = torch.tensor([[0, 1, 1], [1, 2, 3]], device=dev)
    internal_edge_index = torch.tensor([[1], [2]], device=dev)
    n_atoms = 4

    main_cell_offsets = torch.zeros(3, 3, device=dev)
    internal_cell_offsets = torch.zeros(1, 3, device=dev)

    result = build_quadruplets(
        main_edge_index,
        internal_edge_index,
        n_atoms,
        main_cell_offsets,
        internal_cell_offsets,
    )

    # Compare device *types*: a tensor created on "cuda" reports "cuda:0",
    # which does not compare equal to the index-less torch.device("cuda").
    for key in (
        "quad_c_to_a_edge",
        "quad_d_to_b_trip_idx",
        "d_to_b_edge",
        "c_to_a_edge",
    ):
        assert result[key].device.type == dev.type


def test_build_triplets_jit_script() -> None:
    """build_triplets compiles with torch.jit.script and matches eager mode."""
    edge_index = torch.tensor([[1, 2, 3], [0, 0, 0]])
    n_atoms = 4

    scripted = torch.jit.script(build_triplets)
    from_script = scripted(edge_index, n_atoms)
    from_eager = build_triplets(edge_index, n_atoms)

    assert len(from_script["trip_in"]) == len(from_eager["trip_in"])
    for key in ("trip_in", "trip_out", "center_atom"):
        torch.testing.assert_close(from_script[key], from_eager[key])
def test_build_mixed_triplets_jit_script() -> None:
    """build_mixed_triplets compiles under torch.jit.script via a wrapper."""
    in_edges = torch.tensor([[0, 1, 3], [4, 4, 5]])
    out_edges = torch.tensor([[2, 2], [4, 5]])
    n_atoms = 6

    # The scripted entry point takes positional arguments only and pins
    # to_outedge=False itself, so the script call needs no keyword arguments.
    def positional_entry(
        edge_index_in: torch.Tensor,
        edge_index_out: torch.Tensor,
        n_atoms: int,
    ) -> dict[str, torch.Tensor]:
        return build_mixed_triplets(
            edge_index_in, edge_index_out, n_atoms, to_outedge=False
        )

    scripted = torch.jit.script(positional_entry)
    from_script = scripted(in_edges, out_edges, n_atoms)
    from_eager = build_mixed_triplets(in_edges, out_edges, n_atoms, to_outedge=False)

    assert len(from_script["trip_in"]) == len(from_eager["trip_in"])
    for key in ("trip_in", "trip_out"):
        torch.testing.assert_close(from_script[key], from_eager[key])


def test_build_quadruplets_jit_script() -> None:
    """build_quadruplets compiles with torch.jit.script and matches eager mode."""
    call_args = (
        torch.tensor([[0, 2, 1, 1], [1, 1, 3, 4]]),  # main edges
        torch.tensor([[1], [3]]),  # internal edges
        5,  # n_atoms
        torch.zeros(4, 3),  # main cell offsets
        torch.zeros(1, 3),  # internal cell offsets
    )

    scripted = torch.jit.script(build_quadruplets)
    from_script = scripted(*call_args)
    from_eager = build_quadruplets(*call_args)

    for key in (
        "d_to_b_edge",
        "b_to_a_edge",
        "c_to_a_edge",
        "quad_c_to_a_edge",
        "quad_d_to_b_trip_idx",
        "quad_c_to_a_trip_idx",
    ):
        torch.testing.assert_close(from_script[key], from_eager[key])
def _inner_idx(sorted_idx: torch.Tensor, dim_size: int) -> torch.Tensor:
    """Number the entries of each contiguous segment from zero.

    Example: ``[0,0,0,1,1,2,2,2,2]`` → ``[0,1,2,0,1,0,1,2,3]``.

    Args:
        sorted_idx: 1-D tensor of segment ids. It **must be sorted**: the
            result is the global position minus the segment's start, which is
            only a per-segment counter when equal ids are adjacent.
        dim_size: Total number of segments (>= max(sorted_idx)+1).

    Returns:
        1-D tensor with the same length as *sorted_idx* holding the position
        of every entry inside its own segment.
    """
    segment_sizes = torch.bincount(sorted_idx, minlength=dim_size)
    # Exclusive prefix sum: global position at which each segment begins.
    segment_starts = torch.cumsum(segment_sizes, dim=0) - segment_sizes
    global_pos = torch.arange(sorted_idx.size(0), device=sorted_idx.device)
    return global_pos - segment_starts[sorted_idx]
def build_triplets(
    edge_index: torch.Tensor,
    n_atoms: int,
) -> dict[str, torch.Tensor]:
    """Build triplet interaction indices from an edge list.

    For every pair of edges ``(b→a)`` and ``(c→a)`` that share the same target
    atom ``a`` with ``edge_ba ≠ edge_ca``, produces a triplet ``b→a←c``. An
    atom with ``deg`` incoming edges contributes ``deg*(deg-1)`` ordered
    triplets.

    Uses only ``argsort``, ``bincount``, ``cumsum``, ``repeat_interleave`` and
    tensor indexing.

    Args:
        edge_index: ``[2, n_edges]`` tensor where ``edge_index[0]`` are sources
            and ``edge_index[1]`` are targets.
        n_atoms: Total number of atoms (used for bincount sizing).

    Returns:
        Dict with keys:

        - ``"trip_in"`` — edge indices of the *incoming* edge ``b→a``, shape
          ``[n_triplets]``.
        - ``"trip_out"`` — edge indices of the *outgoing* edge ``c→a``, shape
          ``[n_triplets]``. Grouped by center atom but **not** sorted.
        - ``"trip_out_agg"`` — rank of each triplet among the triplets that
          share its ``trip_out`` edge (``0 .. deg(a)-2``), shape
          ``[n_triplets]``.
        - ``"center_atom"`` — atom index ``a`` for each triplet, shape
          ``[n_triplets]``.
    """
    targets = edge_index[1]  # target atoms
    device = targets.device

    # Stable sort by target atom so each atom's incoming edges form one
    # contiguous run; ``order`` maps a sorted position back to an edge index.
    order = torch.argsort(targets, stable=True)

    # Incoming degree per atom and number of ordered triplets per atom.
    deg = torch.bincount(targets, minlength=n_atoms)
    n_trip_per_atom = deg * (deg - 1)
    total_triplets = int(n_trip_per_atom.sum().item())

    if total_triplets == 0:
        empty = torch.empty(0, dtype=torch.long, device=device)
        return {
            "trip_in": empty,
            "trip_out": empty,
            "trip_out_agg": empty,
            "center_atom": empty,
        }

    # Only atoms with at least two incoming edges can centre a triplet.
    active = deg >= 2
    active_atoms = torch.where(active)[0]
    active_deg = deg[active]
    active_n_trip = n_trip_per_atom[active]
    # Exclusive prefix sum of degrees: where each atom's run starts in ``order``.
    active_offsets = (deg.cumsum(0) - deg)[active]

    # Expand per-atom quantities to one entry per triplet.
    base_off = torch.repeat_interleave(active_offsets, active_n_trip)
    dm1 = torch.repeat_interleave(active_deg, active_n_trip) - 1
    center = torch.repeat_interleave(active_atoms, active_n_trip)

    # Running index of each triplet inside its atom's block. Block sizes are
    # known, so subtracting the block start from the global position suffices.
    block_starts = torch.repeat_interleave(
        active_n_trip.cumsum(0) - active_n_trip, active_n_trip
    )
    local = torch.arange(total_triplets, device=device) - block_starts

    # Map the running index onto a deg x (deg-1) grid: ``row`` selects the
    # in-edge, ``col`` the out-edge with the diagonal (same edge) skipped.
    row = local // dm1
    col = local % dm1
    col = col + (col >= row).long()

    trip_in = order[base_off + row]
    trip_out = order[base_off + col]

    # Rank among triplets sharing the same out-edge. Those triplets come from
    # every row != col in increasing row order, so the rank is the number of
    # admissible rows before ``row``. ``trip_out`` is not sorted, so the
    # sorted-input helper ``_inner_idx`` must not be used here.
    trip_out_agg = row - (col < row).long()

    return {
        "trip_in": trip_in,
        "trip_out": trip_out,
        "trip_out_agg": trip_out_agg,
        "center_atom": center,
    }
def build_mixed_triplets(
    edge_index_in: torch.Tensor,
    edge_index_out: torch.Tensor,
    n_atoms: int,
    to_outedge: bool = False,  # noqa: FBT001, FBT002
    cell_offsets_in: torch.Tensor | None = None,
    cell_offsets_out: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
    """Pair edges of one graph with edges of another that meet at an atom.

    Every out-edge is combined with all in-edges whose *target* is the
    out-edge's pivot atom: its source when *to_outedge* is True, its target
    otherwise. A pair is discarded when the in-edge's source equals the
    out-edge's remaining end atom, unless cell offsets are supplied and show
    that the two ends lie in different periodic images.

    Args:
        edge_index_in: ``[2, n_edges_in]`` — input graph edges.
        edge_index_out: ``[2, n_edges_out]`` — output graph edges.
        n_atoms: Total number of atoms.
        to_outedge: Pivot on the out-edge *source* if True, on its *target*
            if False.
        cell_offsets_in: ``[n_edges_in, 3]`` periodic offsets for input graph.
        cell_offsets_out: ``[n_edges_out, 3]`` periodic offsets for output
            graph. Offsets are only taken into account when both are given.

    Returns:
        Dict with keys:

        - ``"trip_in"`` — indices into ``edge_index_in``.
        - ``"trip_out"`` — indices into ``edge_index_out``, non-decreasing.
        - ``"trip_out_agg"`` — running index within each ``trip_out`` group.
    """
    in_src = edge_index_in[0]
    in_tgt = edge_index_in[1]
    out_src = edge_index_out[0]
    out_tgt = edge_index_out[1]
    dev = in_src.device
    num_out = out_src.size(0)

    # CSR layout of the in-edges keyed by target atom: ``in_perm`` lists edge
    # indices grouped by target, ``in_ptr[a]`` is where atom a's group starts.
    in_perm = torch.argsort(in_tgt, stable=True)
    in_degree = torch.bincount(in_tgt, minlength=n_atoms)
    in_ptr = torch.zeros(n_atoms + 1, dtype=torch.long, device=dev)
    in_ptr[1:] = torch.cumsum(in_degree, dim=0)

    # Pivot atom shared with the in-edges, and the out-edge's other end.
    if to_outedge:
        pivot_atom = out_src
        far_atom_out = out_tgt
    else:
        pivot_atom = out_tgt
        far_atom_out = out_src

    # One candidate per (out-edge, in-edge arriving at its pivot atom).
    fan_in = in_degree[pivot_atom]  # [n_edges_out]
    trip_out = torch.repeat_interleave(torch.arange(num_out, device=dev), fan_in)
    group_start = torch.repeat_interleave(in_ptr[pivot_atom], fan_in)
    within_group = _inner_idx(trip_out, num_out)
    trip_in = in_perm[group_start + within_group]

    # Keep pairs whose outer atoms differ ...
    keep = in_src[trip_in] != far_atom_out[trip_out]
    if cell_offsets_in is not None and cell_offsets_out is not None:
        shift_in = cell_offsets_in[trip_in]
        shift_out = cell_offsets_out[trip_out]
        if to_outedge:
            net_shift = shift_out + shift_in
        else:
            net_shift = shift_out - shift_in
        # ... or whose outer atoms are separated by a non-zero cell shift.
        keep = keep | torch.any(net_shift != 0, dim=-1)

    trip_in = trip_in[keep]
    trip_out = trip_out[keep]

    return {
        "trip_in": trip_in,
        "trip_out": trip_out,
        # Masking preserves order, so trip_out is still sorted as required.
        "trip_out_agg": _inner_idx(trip_out, num_out),
    }
def build_quadruplets(
    main_edge_index: torch.Tensor,
    internal_edge_index: torch.Tensor,
    n_atoms: int,
    main_cell_offsets: torch.Tensor,
    internal_cell_offsets: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Build quadruplet interaction indices ``d→b→a←c`` from two edge sets.

    For each internal bond ``b→a``, pairs every main-graph edge ``d→b`` with
    every main-graph edge ``c→a``, excluding ``c == d`` in the same periodic
    image::

        d ——(main)——> b ===(internal)===> a <——(main)—— c

    Args:
        main_edge_index: ``[2, n_main]`` — outer graph edges.
        internal_edge_index: ``[2, n_internal]`` — central graph edges.
        n_atoms: Total number of atoms.
        main_cell_offsets: ``[n_main, 3]`` periodic cell offsets for main graph.
        internal_cell_offsets: ``[n_internal, 3]`` periodic cell offsets for
            internal graph.

    Returns:
        Dict with keys describing the quadruplet ``d→b→a←c``:

        - ``"d_to_b_edge"`` — main-edge indices for ``d→b``, shape ``[n_trip_in]``.
        - ``"b_to_a_edge"`` — internal-edge indices for the central bond ``b→a``,
          shape ``[n_trip_in]``.
        - ``"b_to_a_edge_agg"`` — local aggregation index within each ``b→a`` edge,
          shape ``[n_trip_in]``.
        - ``"c_to_a_edge"`` — main-edge indices for ``c→a``, shape ``[n_trip_out]``.
        - ``"c_to_a_edge_agg"`` — local aggregation index within each ``c→a`` edge,
          shape ``[n_trip_out]``.
        - ``"quad_c_to_a_edge"`` — main-edge index of the ``c→a`` bond for each
          quadruplet, shape ``[n_quads]``.
        - ``"quad_d_to_b_trip_idx"`` — index into ``d_to_b_edge`` / ``b_to_a_edge``
          for each quadruplet, shape ``[n_quads]``.
        - ``"quad_c_to_a_trip_idx"`` — index into ``c_to_a_edge`` for each
          quadruplet, shape ``[n_quads]``; by construction
          ``c_to_a_edge[quad_c_to_a_trip_idx] == quad_c_to_a_edge``.
        - ``"quad_c_to_a_agg"`` — local aggregation index within each ``c→a`` main
          edge across quadruplets, shape ``[n_quads]``.
    """
    src_main = main_edge_index[0]
    n_main_edges = src_main.size(0)
    n_internal_edges = internal_edge_index.size(1)
    device = src_main.device

    # Input triplets d→b→a: main edges arriving at b, paired with internal edge b→a.
    triplet_in = build_mixed_triplets(
        main_edge_index,
        internal_edge_index,
        n_atoms,
        to_outedge=True,
        cell_offsets_in=main_cell_offsets,
        cell_offsets_out=internal_cell_offsets,
    )

    # Output triplets c→a←b: internal edge b→a paired with main edges arriving at a.
    triplet_out = build_mixed_triplets(
        internal_edge_index,
        main_edge_index,
        n_atoms,
        to_outedge=False,
        cell_offsets_in=internal_cell_offsets,
        cell_offsets_out=main_cell_offsets,
    )

    d_to_b_edge = triplet_in["trip_in"]
    b_to_a_edge = triplet_in["trip_out"]
    c_to_a_edge = triplet_out["trip_out"]
    central_of_out = triplet_out["trip_in"]  # internal edge of each output triplet

    # CSR over input triplets keyed by their internal edge: how many there are
    # per internal edge, where each group starts, and the grouping permutation.
    n_trip_in_per_inter = torch.bincount(b_to_a_edge, minlength=n_internal_edges)
    csr_start = n_trip_in_per_inter.cumsum(0) - n_trip_in_per_inter
    order_ti = torch.argsort(b_to_a_edge, stable=True)

    # Cartesian product: each output triplet is repeated once per input triplet
    # on the same central edge. The expansion runs over *all* output triplets,
    # so the resulting positions index ``c_to_a_edge`` directly; output
    # triplets without a partner get zero repeats and simply vanish.
    n_in_per_out = n_trip_in_per_inter[central_of_out]
    quad_c_to_a_trip_idx = torch.repeat_interleave(
        torch.arange(c_to_a_edge.size(0), device=device), n_in_per_out
    )
    quad_c_to_a = c_to_a_edge[quad_c_to_a_trip_idx]
    central_edge = central_of_out[quad_c_to_a_trip_idx]

    # Position 0..n_in-1 of each candidate inside its output triplet's block.
    block_starts = torch.repeat_interleave(
        n_in_per_out.cumsum(0) - n_in_per_out, n_in_per_out
    )
    local = torch.arange(quad_c_to_a.size(0), device=device) - block_starts

    # Matching input triplet (as an index into d_to_b_edge / b_to_a_edge).
    quad_d_to_b_trip_idx = order_ti[csr_start[central_edge] + local]
    d_to_b = d_to_b_edge[quad_d_to_b_trip_idx]

    # Filter: c ≠ d (same atom in same periodic image is not a valid quadruplet)
    cell_offset_cd = (
        main_cell_offsets[d_to_b]
        + internal_cell_offsets[central_edge]
        - main_cell_offsets[quad_c_to_a]
    )
    mask = (src_main[quad_c_to_a] != src_main[d_to_b]) | torch.any(
        cell_offset_cd != 0, dim=-1
    )

    quad_c_to_a = quad_c_to_a[mask]
    quad_c_to_a_trip_idx = quad_c_to_a_trip_idx[mask]
    quad_d_to_b_trip_idx = quad_d_to_b_trip_idx[mask]

    return {
        "d_to_b_edge": d_to_b_edge,
        "b_to_a_edge": b_to_a_edge,
        "b_to_a_edge_agg": triplet_in["trip_out_agg"],
        "c_to_a_edge": c_to_a_edge,
        "c_to_a_edge_agg": triplet_out["trip_out_agg"],
        "quad_c_to_a_edge": quad_c_to_a,
        "quad_d_to_b_trip_idx": quad_d_to_b_trip_idx,
        "quad_c_to_a_trip_idx": quad_c_to_a_trip_idx,
        # quad_c_to_a stays sorted (sorted c_to_a_edge, order-preserving
        # expansion and masking), as _inner_idx requires.
        "quad_c_to_a_agg": _inner_idx(quad_c_to_a, n_main_edges),
    }