Type test_io, neighbors, and transforms #243
Conversation
Walkthrough

Tests were refactored to use concrete fixtures and order-insensitive distance comparisons; transforms tests use 0-D torch scalars and explicit int casts. torch_sim.neighbors and torch_sim.transforms were updated to accept optional cell/cell_shifts (None), with overloads/typing adjustments added and distance/shift logic branched accordingly.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Test as Test/Caller
    participant Neigh as torch_sim/neighbors.py
    participant Trans as torch_sim/transforms.py
    Test->>Neigh: strict_nl(cutoff, positions, cell|None, mapping, system_mapping, shifts_idx)
    alt cell is None
        Note over Neigh: skip cell-shift computation (None)
        Neigh->>Trans: compute_distances_with_cell_shifts(pos, mapping, cell_shifts=None)
        Trans-->>Neigh: distances (no shifts)
    else cell provided
        Note over Neigh: compute & apply cell shifts
        Neigh->>Trans: compute_cell_shifts(cell, shifts_idx, system_mapping)
        Trans-->>Neigh: cell_shifts
        Neigh->>Trans: compute_distances_with_cell_shifts(pos, mapping, cell_shifts)
        Trans-->>Neigh: distances (with shifts)
    end
    Neigh-->>Test: neighbor_pairs, distances, metadata
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers
Pre-merge checks: ✅ 3 passed
Commits:
- fix more type issues
- fixing types for creating cell shifts
- fix type
- defining the fixture to use
- neighbor improvements
- wip type runners
- fix types in trajectory
- backup before thinking of messing with autobatching
- transforms is typed
- made test_io conform to types
- lint
- lint fixes
- fix transforms code
- fix safemask type
- revert trajectory file
- rm runners changes
- fix desc for fn
- ignore call-arg for pbc
Actionable comments posted: 0
🧹 Nitpick comments (6)
tests/test_io.py (1)
264-274: ImportError tests: fixture-based calls are correct; minor hardening suggestion

Using real fixtures instead of None exercises imports earlier and is more realistic. To make the import-failure simulation a bit more robust (and concise), consider `monkeypatch.dict` on `sys.modules` for the targets instead of multiple `setitem` calls.

Example:

```diff
- monkeypatch.setitem(sys.modules, "ase", None)
- monkeypatch.setitem(sys.modules, "ase.data", None)
+ monkeypatch.dict(sys.modules, {"ase": None, "ase.data": None})
```

Also applies to: 276-287, 289-300, 302-312, 314-325, 327-337
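The mechanism behind the `monkeypatch.dict` suggestion above can be sketched with the stdlib alone (using `unittest.mock.patch.dict`, which behaves like pytest's `monkeypatch.dict` for this purpose): setting a `sys.modules` entry to None makes a subsequent import of that name raise ImportError.

```python
import sys
from unittest.mock import patch

# Blanking the sys.modules entries simulates the package being
# unavailable: the import machinery finds None and raises ImportError.
with patch.dict(sys.modules, {"ase": None, "ase.data": None}):
    try:
        import ase  # noqa: F401
        raised = False
    except ImportError:
        raised = True

print(raised)  # True: the import failed as the test intends
```

Outside the `with` block, `sys.modules` is restored automatically, which is the same cleanup guarantee `monkeypatch` provides.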
torch_sim/transforms.py (4)
352-355: wrap_positions: center scalar support is good; update docstring and keep type clarity

Allowing a scalar center is ergonomic. Please update the "Args" doc to reflect `tuple[float, float, float] | float` and avoid ambiguity.

```diff
- center (Tuple[float, float, float]): Center of the cell as
-     (x,y,z) tuple, defaults to (0.5, 0.5, 0.5).
+ center (tuple[float, float, float] | float): Center of the cell
+     as (x,y,z) tuple or scalar, defaults to (0.5, 0.5, 0.5).
```

Also applies to: 377-381, 389-399
469-496: Prefer public torch.dtype instead of private torch.types._dtype

Using the private `_dtype` type and a `type: ignore` can be avoided by switching to `torch.dtype`.

```diff
-from torch.types import _dtype
 ...
-def get_cell_shift_idx(num_repeats: torch.Tensor, dtype: _dtype) -> torch.Tensor:
+def get_cell_shift_idx(num_repeats: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
@@
-    )  # type: ignore[call-overload]
+    )
```

If `_dtype` isn't used elsewhere, remove the import to keep things clean.
538-565: compute_cell_shifts: doc return type/shape mismatch

The function can return None and returns per-pair shifts of shape (n_pairs, 3), not (n_systems, 3). Update the docstring to match the behavior.

```diff
-    Returns:
-        torch.Tensor: A tensor of shape (n_systems, 3) containing
-            the computed cell shifts.
+    Returns:
+        torch.Tensor | None: A tensor of shape (n_pairs, 3) containing
+            the computed cell shifts, or None if cell is None.
```
1165-1169: safe_mask: avoid evaluating fn on masked-out values

The current implementation may compute `fn` on zeros at masked-out positions (e.g., log(0) -> -inf), which is later hidden but can trigger warnings/NaNs in gradients. Apply `fn` only where the mask is True.

```diff
 def safe_mask(
     mask: torch.Tensor,
     fn: Callable[..., torch.Tensor],
     operand: torch.Tensor,
     placeholder: float = 0.0,
 ) -> torch.Tensor:
@@
-    masked = torch.where(mask, operand, torch.zeros_like(operand))
-    return torch.where(mask, fn(masked), torch.full_like(operand, placeholder))
+    out = torch.full_like(operand, placeholder)
+    if mask.any():
+        out[mask] = fn(operand[mask])
+    return out
```

tests/test_transforms.py (1)
701-702: Passing 0-dim tensors for r_onset/r_cutoff matches the API; optional ergonomics

Good mypy-friendly fix. Optionally, we could make `multiplicative_isotropic_cutoff` accept `float | Tensor` and call `as_tensor` internally to keep tests (and users) free to pass floats.
If desired, I can draft a minimal patch to transforms.multiplicative_isotropic_cutoff to accept floats without changing types at call sites.
Also applies to: 719-721, 744-746, 767-768, 785-786, 803-805
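The masked-apply pattern suggested for safe_mask above can be illustrated with a small numpy analogue (a sketch of the idea, not the torch_sim implementation): the function is only ever evaluated on masked-in entries, so e.g. `np.log` never sees the zeros that would produce -inf and runtime warnings.

```python
import numpy as np

def safe_mask_np(mask, fn, operand, placeholder=0.0):
    # Fill with the placeholder, then apply fn only where mask is True;
    # fn never touches masked-out values.
    out = np.full_like(operand, placeholder)
    if mask.any():
        out[mask] = fn(operand[mask])
    return out

x = np.array([1.0, 0.0, np.e])
# log(1) = 0.0, the zero entry becomes the placeholder, log(e) = 1.0
result = safe_mask_np(x > 0, np.log, x, placeholder=-1.0)
```

The same indexing approach carries over to torch tensors, with the gradient benefit noted in the review comment.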
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- tests/test_io.py (1 hunks)
- tests/test_neighbors.py (12 hunks)
- tests/test_transforms.py (7 hunks)
- torch_sim/neighbors.py (8 hunks)
- torch_sim/transforms.py (13 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-31T11:15:22.654Z
Learnt from: curtischong
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:1022-1024
Timestamp: 2025-08-31T11:15:22.654Z
Learning: In PyTorch code, avoid using .item() on tensors when performance is critical as it causes thread synchronization between GPU and CPU, breaking parallelism. Use more specific type ignore comments like `# type: ignore[arg-type]` instead of generic `# type: ignore` to satisfy linting rules while maintaining performance.
Applied to files:
torch_sim/neighbors.py
🧬 Code graph analysis (4)
tests/test_io.py (2)
tests/conftest.py (5)
- si_sim_state (142-144)
- si_atoms (89-91)
- device (24-25)
- si_phonopy_atoms (119-138)
- si_structure (101-115)

torch_sim/io.py (6)
- state_to_atoms (29-72)
- state_to_phonopy (131-177)
- state_to_structures (75-128)
- atoms_to_state (180-245)
- phonopy_to_state (316-391)
- structures_to_state (248-313)
torch_sim/neighbors.py (1)
torch_sim/transforms.py (1)
compute_cell_shifts_strict(567-588)
tests/test_neighbors.py (1)
torch_sim/neighbors.py (3)
- standard_nl (411-497)
- vesin_nl (571-637)
- vesin_nl_ts (501-568)
tests/test_transforms.py (1)
torch_sim/transforms.py (2)
- cutoff_fn (1120-1122)
- multiplicative_isotropic_cutoff (1067-1124)
🔇 Additional comments (16)
torch_sim/transforms.py (3)
498-536: compute_distances_with_cell_shifts: API widening looks good

`cell_shifts: Optional[...]` is a sensible, typed way to express "no shifts." Logic and validation are clear.
567-589: compute_cell_shifts_strict: clear separation, LGTM

The non-None contract is explicit and mirrors compute_cell_shifts behavior. Good addition.
1127-1162: high_precision_sum: widened dim types, LGTM

Accepting `int | list[int] | tuple[int, ...] | None` aligns with torch.sum. Behavior preserved.
tests/test_transforms.py (1)
458-466: Loop bound cast to int avoids range TypeError, LGTM

Explicit `int()` on the tensor count ensures compatibility with `range` and reads clearly.
tests/test_neighbors.py (9)
2-2: LGTM! Good addition of explicit typing import.

The import of `Callable` from `collections.abc` supports the new type annotations in the test functions.
108-123: LGTM! Type annotations for fixture functions are clear and correct.

The return type annotations `-> list[Atoms]` for both `periodic_atoms_set()` and `molecule_atoms_set()` fixtures properly document their return types and help with static type checking.

Also applies to: 127-130
135-141: Good parameterization change for fixture handling.

The switch from direct fixture usage to parameterized fixture names with `request.getfixturevalue()` provides better test control and clearer test naming.

Also applies to: 156-156
215-215: LGTM! Proper handling of distance comparison ordering.

The change to sort distances before comparison (`np.sort()`) ensures order-insensitive comparison, which is the correct approach for validating neighbor list implementations that may return neighbors in different orders.

Also applies to: 241-242, 245-245, 249-249
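The order-insensitive comparison described above can be shown in a few lines (a generic numpy sketch with made-up distances, not the actual test data):

```python
import numpy as np

# Two neighbor-list implementations may enumerate the same pairs in
# different orders; sorting both distance arrays before comparing
# makes the check independent of enumeration order.
d_ref = np.array([2.35, 1.0, 1.0, 2.35])   # reference implementation
d_impl = np.array([1.0, 2.35, 2.35, 1.0])  # implementation under test
same = np.allclose(np.sort(d_ref), np.sort(d_impl))
```

Direct elementwise comparison of the unsorted arrays would fail here even though the neighbor sets are identical.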
254-265: LGTM! Consistent fixture parameterization pattern.

The parameterization pattern is consistently applied across test functions, and the `Callable` type annotation for `nl_implementation` provides proper typing.

Also applies to: 273-273
292-292: LGTM! Consistent distance sorting for reliable comparisons.

The sorted distance comparisons ensure that neighbor list implementations are validated correctly regardless of the order in which they return neighbor pairs.
Also applies to: 314-321
333-333: LGTM! Proper type annotation for callable parameter.

The `Callable` type annotation for `nl_implementation` maintains consistency with other test functions.
359-359: LGTM! Consistent sorting pattern maintained.

The distance sorting approach is consistently applied across all test functions for reliable neighbor list validation.
Also applies to: 372-372, 375-375
568-575: All test call signatures align with the current neighbor-list definitions; no annotation mismatches detected.

torch_sim/neighbors.py (3)
177-177: LGTM! Appropriate type ignore annotations for PyTorch overload resolution.

The `# type: ignore[call-overload]` annotations are correctly applied to PyTorch API calls where the type checker cannot resolve the correct overload. These are necessary for maintaining type safety while avoiding performance-degrading `.item()` calls, as mentioned in the retrieved learnings.

Also applies to: 196-197, 230-231, 247-249
366-369: LGTM! Variable renaming improves code clarity.

The renaming from `bin_cnt` to `bin_cnt_sort_idx` better reflects the variable's purpose as sorting indices, making the code more self-documenting.
643-643: LGTM! Clean handling of optional cell parameter.

The addition of the `cell: torch.Tensor | None` parameter and the conditional logic properly handle both periodic and non-periodic systems:

- When `cell` is None, distances are computed directly without cell shifts
- When `cell` is provided, it uses the new `compute_cell_shifts_strict` function

The implementation maintains backward compatibility while extending functionality.

Also applies to: 661-662, 692-697
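The periodic/non-periodic branching described above can be sketched with a toy numpy helper (hypothetical names; the real torch_sim functions differ): without a cell, pair distances come straight from positions; with a cell, integer shift indices are mapped through the cell matrix first.

```python
import numpy as np

def pair_distances(positions, mapping, cell=None, shifts_idx=None):
    # mapping[0]/mapping[1] hold the indices of each pair's endpoints
    disp = positions[mapping[1]] - positions[mapping[0]]
    if cell is not None:
        # apply periodic cell shifts before taking norms
        disp = disp + shifts_idx @ cell
    return np.linalg.norm(disp, axis=1)

pos = np.array([[0.0, 0.0, 0.0], [0.9, 0.0, 0.0]])
mapping = np.array([[0], [1]])
cell = np.eye(3)  # unit cube

d_raw = pair_distances(pos, mapping)                                  # no cell
d_img = pair_distances(pos, mapping, cell, np.array([[-1.0, 0.0, 0.0]]))  # image one cell left
```

The unshifted distance is 0.9, while the periodic image one cell to the left is only 0.1 away, which is exactly the difference the two branches produce.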
Code under review (torch_sim/transforms.py):

```diff
 # Convert center to tensor
 if not hasattr(center, "__len__"):
-    center = (center,) * 3
-center = torch.tensor(center, dtype=positions.dtype, device=device)
+    center_pos = torch.tensor((center,) * 3, dtype=positions.dtype, device=device)
+else:
+    center_pos = torch.tensor(center, dtype=positions.dtype, device=device)
```
This feels like a vague way to check this? Maybe an improvement: check `isinstance(x, float)`, or `isinstance(x, tuple) and len(x) == 3 and all(isinstance(y, float) for y in x)`?
I think `isinstance(x, float)` is sufficient, because `center` is defined as `tuple[float, float, float] | float`.
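The `isinstance`-based normalization discussed above can be sketched as a small hypothetical helper (illustrative only; the real code builds a tensor rather than a tuple):

```python
def normalize_center(center):
    # center: tuple[float, float, float] | float
    # A plain isinstance check is more explicit than hasattr(center, "__len__").
    if isinstance(center, float):
        return (center,) * 3
    return tuple(float(c) for c in center)

a = normalize_center(0.5)            # scalar broadcast to all three axes
b = normalize_center((0.0, 0.0, 1.0))  # tuple passed through
```

Either result can then be handed to `torch.tensor(...)` with the desired dtype and device.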
Code under review (torch_sim/transforms.py):

```python
    return cell_shifts


def compute_cell_shifts_strict(
```
why not call this function from inside compute_cell_shifts? would be more DRY
Actionable comments posted: 0
♻️ Duplicate comments (1)
torch_sim/transforms.py (1)
558-570: Excellent separation of optional and strict variants

This is a well-designed solution to handle mypy type checking while maintaining flexibility:

- `compute_cell_shifts` handles the optional case by returning None when cell is None
- `compute_cell_shifts_strict` provides a non-optional variant for internal use where cell is guaranteed to be non-None

There is a past review comment suggesting to call `compute_cell_shifts_strict` from inside `compute_cell_shifts` to be more DRY. The current implementation already does this correctly on line 560.
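The optional/strict split praised above reduces to a simple delegation pattern, sketched here with toy functions (not the real torch_sim code, which works on tensors):

```python
def cell_shifts_strict(cell, shifts_idx):
    # Strict variant: cell is guaranteed non-None, so no branching needed.
    return [[s * c for s, c in zip(shift, cell)] for shift in shifts_idx]

def cell_shifts(cell, shifts_idx):
    # Optional variant: returns None for cell-free systems, otherwise
    # delegates to the strict variant (keeping the logic DRY).
    if cell is None:
        return None
    return cell_shifts_strict(cell, shifts_idx)

no_cell = cell_shifts(None, [[1, 0, 0]])
with_cell = cell_shifts([2.0, 2.0, 2.0], [[1, 0, 0]])
```

Callers that know they have a cell use the strict variant directly and keep a non-optional return type for mypy; everyone else goes through the optional wrapper.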
🧹 Nitpick comments (1)
torch_sim/transforms.py (1)
492-492: Address mypy type check warning

The type ignore comment suggests mypy is complaining about the `torch.arange` call. This is likely due to the `dtype` parameter type annotation.

Consider checking whether this can be resolved without the type ignore:

```python
r1 = torch.arange(
    -num_repeats[ii],
    num_repeats[ii] + 1,
    device=num_repeats.device,
    dtype=dtype,
)
```

If the issue persists, the type ignore is acceptable, but consider adding a more descriptive comment explaining why it's needed.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
torch_sim/transforms.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
torch_sim/transforms.py (1)
tests/test_transforms.py (3)
- test_wrap_positions_matches_ase (200-217)
- test_wrap_positions_basic (220-227)
- test_compute_distances_with_cell_shifts (1168-1179)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (33)
🔇 Additional comments (9)
torch_sim/transforms.py (9)
8-8: LGTM! Good addition for type annotation support

The import of `Callable` from `collections.abc` is appropriate for the type annotations used in function parameters later in the file.
377-380: Implementation handles both scalar and tuple center values correctly

The logic correctly handles both cases:
- Single float: creates a 3-tuple with the same value repeated
- Tuple: uses the provided tuple directly
The implementation preserves device and dtype consistency.
389-389: Consistent use of center_pos variable

Good refactoring to use the `center_pos` variable consistently instead of the original `center` parameter. This ensures type safety and consistent behavior regardless of the input format.

Also applies to: 397-397
501-501: Good type annotation for optional cell_shifts parameter

The addition of `| None` to the type annotation correctly reflects that `cell_shifts` can be None, improving type safety.

Also applies to: 517-517
539-540: Good type annotation updates for optional cell parameter

The type annotations correctly indicate that both the `cell` parameter and the return value can be None.
857-857: Consistent use of strict variant where cell is guaranteed non-None

Good choice to use `compute_cell_shifts_strict` here, since the cell parameter is guaranteed to be non-None in this context (it's accessed via `cell.view(-1, 3, 3)`).
1110-1110: Improved type flexibility for dimension parameter

The expanded type annotation `int | list[int] | tuple[int, ...] | None` better reflects the types actually accepted by PyTorch's sum function, improving type safety.
1147-1147: More flexible callable type annotation

Changing from `torch.jit.ScriptFunction` to `Callable[..., torch.Tensor]` is a good improvement that:
- Increases flexibility by accepting any callable that returns a tensor
- Improves type checking compatibility
- Maintains the same functional requirements
Also applies to: 1159-1159
352-352: Approve: backward compatibility for `center` verified

Tuple inputs remain supported (converted with torch.tensor) and the default is unchanged; no call sites in the repo pass `center=` explicitly. No action required.
Summary
The test_io, neighbors, and transforms files now pass mypy.
The main issue this revealed is that pbc is inconsistent in the codebase: we use bools, tuples of bools, tensors, etc. We should move to a tensor of bools later, but since we don't support axis-aware pbc yet, that's for a future PR.
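The unification proposed above could look roughly like the following numpy sketch (hypothetical helper; the eventual torch_sim version would produce a torch bool tensor instead): accept the pbc spellings currently scattered through the codebase and normalize them to one (3,) boolean array.

```python
import numpy as np

def normalize_pbc(pbc):
    # Accept bool, tuple/list of bools, or an existing array,
    # and always return a (3,) boolean array.
    if isinstance(pbc, bool):
        return np.full(3, pbc)
    return np.asarray(pbc, dtype=bool)

full = normalize_pbc(True)                 # fully periodic system
slab = normalize_pbc((True, True, False))  # axis-aware, e.g. a slab
```

Downstream code would then branch on a single representation, which is also the shape axis-aware pbc support would need.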
Checklist
Before a pull request can be merged, the following items must be checked:
We highly recommend installing the pre-commit hooks that run in CI locally to speed up the development process. Simply run `pip install pre-commit && pre-commit install` to install the hooks, which will check your code before each commit.

Summary by CodeRabbit
New Features
Tests
Refactor