From 006c57383145df2b614ce448e05a5d1368f76f00 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 27 Sep 2025 15:27:16 -0700 Subject: [PATCH 01/40] big breaking API redesign --- .github/workflows/lint.yml | 24 +- CHANGELOG.md | 82 +- CONTRIBUTING.md | 8 - LICENSE | 2 +- README.md | 31 +- citation.cff | 6 +- docs/_static/draw_pkg_treemap.py | 7 +- docs/about/license.md | 2 +- docs/conf.py | 10 +- docs/dev/add_model.md | 6 +- docs/dev/dev_install.md | 10 +- docs/index.md | 6 +- docs/reference/index.rst | 2 +- docs/tutorials/index.rst | 4 +- docs/user/introduction.rst | 2 +- examples/readme.md | 7 +- .../1_Introduction/1.1_Lennard_Jones.py | 12 +- examples/scripts/1_Introduction/1.2_MACE.py | 15 +- .../scripts/1_Introduction/1.3_Fairchem.py | 37 +- .../2.1_Lennard_Jones_FIRE.py | 19 +- .../2.2_Soft_Sphere_FIRE.py | 19 +- .../2.3_MACE_Gradient_Descent.py | 26 +- .../2.4_MACE_FIRE.py | 22 +- ....5_MACE_UnitCellFilter_Gradient_Descent.py | 31 +- .../2.6_MACE_UnitCellFilter_FIRE.py | 25 +- .../2.7_MACE_FrechetCellFilter_FIRE.py | 24 +- .../scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 44 +- .../3.11_Lennard_Jones_NPT_Langevin.py | 44 +- .../3_Dynamics/3.12_MACE_NPT_Langevin.py | 60 +- .../3_Dynamics/3.13_MACE_NVE_non_pbc.py | 33 +- .../3_Dynamics/3.1_Lennard_Jones_NVE.py | 27 +- examples/scripts/3_Dynamics/3.2_MACE_NVE.py | 28 +- .../scripts/3_Dynamics/3.3_MACE_NVE_cueq.py | 28 +- .../3_Dynamics/3.4_MACE_NVT_Langevin.py | 38 +- .../3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py | 29 +- .../3.6_MACE_NVT_Nose_Hoover_temp_profile.py | 26 +- .../3.7_Lennard_Jones_NPT_Nose_Hoover.py | 38 +- .../3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py | 64 +- .../3.9_MACE_NVT_staggered_stress.py | 24 +- .../4_High_level_api/4.1_high_level_api.py | 43 +- .../4_High_level_api/4.2_auto_batching_api.py | 56 +- .../5_Workflow/5.1_a2c_silicon_batched.py | 77 +- .../scripts/5_Workflow/5.2_In_Flight_WBM.py | 21 +- examples/scripts/5_Workflow/5.3_Elastic.py | 21 +- .../scripts/6_Phonons/6.1_Phonons_MACE.py | 18 +- .../6_Phonons/6.2_QuasiHarmonic_MACE.py | 51 +- .../6_Phonons/6.3_Conductivity_MACE.py | 36 +- .../7_Others/7.1_Soft_sphere_autograd.py | 11 +- .../scripts/7_Others/7.2_Stress_autograd.py | 13 +- .../7_Others/7.3_Batched_neighbor_list.py | 8 +- .../7_Others/7.4_Velocity_AutoCorrelation.py | 10 +- .../7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 148 +- .../7_Others/7.7_Heat_flux_and_kappa.py | 151 ++ examples/tutorials/autobatching_tutorial.py | 52 +- examples/tutorials/diff_sim.py | 448 ++++ examples/tutorials/high_level_tutorial.py | 49 +- examples/tutorials/hybrid_swap_tutorial.py | 44 +- examples/tutorials/low_level_tutorial.py | 75 +- examples/tutorials/metatomic_tutorial.py | 9 +- examples/tutorials/reporting_tutorial.py | 19 +- examples/tutorials/state_tutorial.py | 26 +- examples/tutorials/using_graphpes_tutorial.py | 10 +- pyproject.toml | 60 +- tests/conftest.py | 169 +- tests/models/conftest.py | 21 +- tests/models/test_fairchem.py | 259 +- tests/models/test_graphpes.py | 95 +- tests/models/test_lennard_jones.py | 14 +- tests/models/test_mace.py | 81 +- tests/models/test_mattersim.py | 45 +- tests/models/test_metatomic.py | 21 +- tests/models/test_nequip_framework.py | 81 + tests/models/test_orb.py | 42 +- tests/models/test_sevennet.py | 55 +- tests/models/test_soft_sphere.py | 65 +- tests/test_autobatching.py | 86 +- tests/test_correlations.py | 133 +- tests/test_elastic.py | 49 +- tests/test_integrators.py | 200 +- tests/test_io.py | 126 +- tests/test_math.py | 157 +- tests/test_monte_carlo.py | 253 +- tests/test_neighbors.py | 256 +- tests/test_optimizer_states.py | 59 + tests/test_optimizers.py | 373 +-- tests/test_optimizers_vs_ase.py | 155 +- tests/test_quantities.py | 273 +- tests/test_runners.py | 231 +- tests/test_state.py | 232 +- tests/test_trajectory.py | 61 +- tests/test_transforms.py | 285 +- tests/test_voigt.py | 6 +- tests/workflows/test_a2c.py | 191 +- torch_sim/__init__.py | 60 +- torch_sim/autobatching.py | 241 +- torch_sim/elastic.py | 135 +- torch_sim/integrators/__init__.py | 59 +- torch_sim/integrators/md.py | 48 +- torch_sim/integrators/npt.py | 2310 ++++++++--------- torch_sim/integrators/nve.py | 186 +- torch_sim/integrators/nvt.py | 577 ++-- torch_sim/io.py | 80 +- torch_sim/math.py | 175 +- torch_sim/models/fairchem.py | 388 +-- torch_sim/models/graphpes.py | 62 +- torch_sim/models/interface.py | 16 +- torch_sim/models/lennard_jones.py | 37 +- torch_sim/models/mace.py | 61 +- torch_sim/models/mattersim.py | 17 +- torch_sim/models/metatomic.py | 49 +- torch_sim/models/morse.py | 55 +- torch_sim/models/nequip_framework.py | 379 +++ torch_sim/models/orb.py | 63 +- torch_sim/models/particle_life.py | 18 +- torch_sim/models/sevennet.py | 45 +- torch_sim/models/soft_sphere.py | 37 +- torch_sim/monte_carlo.py | 216 +- torch_sim/neighbors.py | 77 +- torch_sim/optimizers.py | 1743 ------------- torch_sim/optimizers/__init__.py | 38 + torch_sim/optimizers/cell_filters.py | 381 +++ torch_sim/optimizers/fire.py | 473 ++++ torch_sim/optimizers/gradient_descent.py | 129 + torch_sim/optimizers/state.py | 48 + torch_sim/properties/correlations.py | 121 +- torch_sim/quantities.py | 173 +- torch_sim/runners.py | 255 +- torch_sim/state.py | 226 +- torch_sim/trajectory.py | 24 +- torch_sim/transforms.py | 99 +- torch_sim/typing.py | 26 +- torch_sim/workflows/a2c.py | 219 +- 132 files changed, 7991 insertions(+), 7807 deletions(-) create mode 100644 examples/scripts/7_Others/7.7_Heat_flux_and_kappa.py create mode 100644 examples/tutorials/diff_sim.py create mode 100644 tests/models/test_nequip_framework.py create mode 100644 tests/test_optimizer_states.py create mode 100644 torch_sim/models/nequip_framework.py delete mode 100644 torch_sim/optimizers.py create mode 100644 torch_sim/optimizers/__init__.py create mode 100644 torch_sim/optimizers/cell_filters.py create mode 100644 torch_sim/optimizers/fire.py create mode 100644 torch_sim/optimizers/gradient_descent.py create mode 100644 torch_sim/optimizers/state.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 604ce1ee0..3c17bc5a8 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -6,26 +6,12 @@ on: pull_request: branches: [main] -permissions: - contents: read - -concurrency: - group: ${{ github.workflow }}-pr-${{ github.event.pull_request.number }} - cancel-in-progress: true - jobs: - lint: + prek: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" - - - name: Install dependencies - run: pip install pre-commit + - name: Check out repo + uses: actions/checkout@v5 - - name: Run pre-commit - run: pre-commit run --all-files --show-diff-on-failure + - name: Run prek + uses: j178/prek-action@v1 diff --git a/CHANGELOG.md b/CHANGELOG.md index 12046fe22..60b74931b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,89 +6,89 @@ ### πŸ’₯ Breaking Changes -* Remove higher level model imports by @CompRhys in https://github.com/TorchSim/torch-sim/pull/179 +* Remove higher level model imports by @CompRhys in #179 ### πŸ›  Enhancements -* Add per atom energies and stresses for batched LJ by @abhijeetgangan in https://github.com/TorchSim/torch-sim/pull/144 -* throw error if autobatcher type is wrong by @orionarcher in https://github.com/TorchSim/torch-sim/pull/167 +* Add per atom energies and stresses for batched LJ by @abhijeetgangan in #144 +* throw error if autobatcher type is wrong by @orionarcher in #167 ### πŸ› Bug Fixes -* Fix column->row cell vector mismatch in integrators by @CompRhys in https://github.com/TorchSim/torch-sim/pull/175 -* Mattersim fix tensors on wrong device (CPU->GPU) by @orionarcher in https://github.com/TorchSim/torch-sim/pull/154 -* fix `npt_langevin` by @jla-gardner in https://github.com/TorchSim/torch-sim/pull/153 -* Make sure to move data to CPU before calling vesin by @Luthaf in https://github.com/TorchSim/torch-sim/pull/156 -* Fix virial calculations in `optimizers` and `integrators` by @janosh in https://github.com/TorchSim/torch-sim/pull/163 -* Pad memory estimation by @orionarcher in https://github.com/TorchSim/torch-sim/pull/160 -* Refactor sevennet model by @YutackPark in https://github.com/TorchSim/torch-sim/pull/172 -* `io` optional dependencies in `pyproject.toml` by @curtischong in https://github.com/TorchSim/torch-sim/pull/185 +* Fix column->row cell vector mismatch in integrators by @CompRhys in #175 +* Mattersim fix tensors on wrong device (CPU->GPU) by @orionarcher in #154 +* fix `npt_langevin` by @jla-gardner in #153 +* Make sure to move data to CPU before calling vesin by @Luthaf in #156 +* Fix virial calculations in `optimizers` and `integrators` by @janosh in #163 +* Pad memory estimation by @orionarcher in #160 +* Refactor sevennet model by @YutackPark in #172 +* `io` optional dependencies in `pyproject.toml` by @curtischong in #185 ### πŸ“– Documentation -* (tiny) add graph-pes to README by @jla-gardner in https://github.com/TorchSim/torch-sim/pull/149 -* Better module fig by @janosh in https://github.com/TorchSim/torch-sim/pull/168 +* (tiny) add graph-pes to README by @jla-gardner in #149 +* Better module fig by @janosh in #168 ### πŸš€ Performance -* More efficient Orb `state_to_atoms_graph` calculation by @AdeeshKolluru in https://github.com/TorchSim/torch-sim/pull/165 +* More efficient Orb `state_to_atoms_graph` calculation by @AdeeshKolluru in #165 ### 🚧 CI -* Refactor `test_math.py` and `test_transforms.py` by @janosh in https://github.com/TorchSim/torch-sim/pull/151 +* Refactor `test_math.py` and `test_transforms.py` by @janosh in #151 ### πŸ₯ Package Health -* Try out hatchling for build vs setuptools by @CompRhys in https://github.com/TorchSim/torch-sim/pull/177 +* Try out hatchling for build vs setuptools by @CompRhys in #177 ### 🏷️ Type Hints -* Add `torch_sim/typing.py` by @janosh in https://github.com/TorchSim/torch-sim/pull/157 +* Add `torch-sim/typing.py` by @janosh in #157 ### πŸ“¦ Dependencies -* Bump `mace-torch` to v0.3.12 by @janosh in https://github.com/TorchSim/torch-sim/pull/170 -* Update metatrain dependency by @Luthaf in https://github.com/TorchSim/torch-sim/pull/186 +* Bump `mace-torch` to v0.3.12 by @janosh in #170 +* Update metatrain dependency by @Luthaf in #186 ## New Contributors -* @Luthaf made their first contribution in https://github.com/TorchSim/torch-sim/pull/156 -* @YutackPark made their first contribution in https://github.com/TorchSim/torch-sim/pull/172 -* @curtischong made their first contribution in https://github.com/TorchSim/torch-sim/pull/185 +* @Luthaf made their first contribution in #156 +* @YutackPark made their first contribution in #172 +* @curtischong made their first contribution in #185 -**Full Changelog**: https://github.com/TorchSim/torch-sim/compare/v0.2.0...v0.2.1 +**Full Changelog**: https://github.com/torchsim/torch-sim/compare/v0.2.0...v0.2.1 ## v0.2.0 ### Bug Fixes πŸ› -* Fix integrate reporting kwarg to arg error, https://github.com/TorchSim/torch-sim/issues/113 (raised by @hn-yu) -* Allow runners to take large initial batches, https://github.com/TorchSim/torch-sim/issues/128 (raised by @YutackPark) -* Add Fairchem model support for PBC, https://github.com/TorchSim/torch-sim/issues/111 (raised by @ryanliu30) +* Fix integrate reporting kwarg to arg error, #113 (raised by @hn-yu) +* Allow runners to take large initial batches, #128 (raised by @YutackPark) +* Add Fairchem model support for PBC, #111 (raised by @ryanliu30) ### Enhancements πŸ›  -* **breaking** Rename `HotSwappingAutobatcher` to `InFlightAutobatcher` and `ChunkingAutoBatcher` to `BinningAutoBatcher`, https://github.com/TorchSim/torch-sim/pull/143 @orionarcher -* Support for Orbv3, https://github.com/TorchSim/torch-sim/pull/140, @AdeeshKolluru -* Support metatensor models, https://github.com/TorchSim/torch-sim/pull/141, @frostedoyter @Luthaf -* Support for graph-pes models, https://github.com/TorchSim/torch-sim/pull/118 @jla-gardner -* Support MatterSim and fix ASE cell convention issues, https://github.com/TorchSim/torch-sim/pull/112 @CompRhys -* Implement positions only FIRE optimization, https://github.com/TorchSim/torch-sim/pull/139 @abhijeetgangan -* Allow different temperatures in batches, https://github.com/TorchSim/torch-sim/pull/123 @orionarcher -* FairChem model updates: PBC handling, test on OMat24 e-trained model, https://github.com/TorchSim/torch-sim/pull/126 @AdeeshKolluru -* FairChem model from_data_list support, https://github.com/TorchSim/torch-sim/pull/138 @ryanliu30 -* New correlation function module, https://github.com/TorchSim/torch-sim/pull/115 @stefanbringuier +* **breaking** Rename `HotSwappingAutobatcher` to `InFlightAutobatcher` and `ChunkingAutoBatcher` to `BinningAutoBatcher`, #143 @orionarcher +* Support for Orbv3, #140, @AdeeshKolluru +* Support metatensor models, #141, @frostedoyter @Luthaf +* Support for graph-pes models, #118 @jla-gardner +* Support MatterSim and fix ASE cell convention issues, #112 @CompRhys +* Implement positions only FIRE optimization, #139 @abhijeetgangan +* Allow different temperatures in batches, #123 @orionarcher +* FairChem model updates: PBC handling, test on OMat24 e-trained model, #126 @AdeeshKolluru +* FairChem model from_data_list support, #138 @ryanliu30 +* New correlation function module, #115 @stefanbringuier ### Documentation πŸ“– -* Improved model documentation, https://github.com/TorchSim/torch-sim/pull/121 @orionarcher -* Plot of TorchSim module graph in docs, https://github.com/TorchSim/torch-sim/pull/132 @janosh +* Improved model documentation, #121 @orionarcher +* Plot of TorchSim module graph in docs, #132 @janosh ### House-Keeping 🧹 -* Only install HF for fairchem tests, https://github.com/TorchSim/torch-sim/pull/134 @CompRhys -* Don't download MBD in CI, https://github.com/TorchSim/torch-sim/pull/135 @orionarcher -* Tighten graph-pes test bounds, https://github.com/TorchSim/torch-sim/pull/143 @orionarcher +* Only install HF for fairchem tests, #134 @CompRhys +* Don't download MBD in CI, #135 @orionarcher +* Tighten graph-pes test bounds, #143 @orionarcher ## v0.1.0 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5163c72b3..7d731a12a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,14 +2,6 @@ TorchSim is an experimental library and we would appreciate any feedback from the community. -## Contributor License Agreement (CLA) - -Before contributing, you'll need to sign our Contributor License Agreement (CLA). This is a one-time requirement that covers all Radical AI open source projects. The CLA allows you to maintain ownership of your contributions while granting Radical AI the necessary rights to use them. - -[Radical AI CLA](https://www.radical-ai.com/oss) - -Our CLA-bot will automatically verify your signature on pull requests. For questions about the CLA, contact cla@radical-ai.com. - ## Code Reviews All submissions require review by project maintainers before merging: diff --git a/LICENSE b/LICENSE index 6573a2ab4..33225395d 100644 --- a/LICENSE +++ b/LICENSE @@ -1,5 +1,5 @@ The MIT License (MIT) -Copyright 2025 Radical AI +Copyright 2025 Project TorchSim Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the β€œSoftware”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: diff --git a/README.md b/README.md index c1532874c..ef1025167 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ # TorchSim -[![CI](https://github.com/TorchSim/torch-sim/actions/workflows/test.yml/badge.svg)](https://github.com/TorchSim/torch-sim/actions/workflows/test.yml) -[![codecov](https://codecov.io/gh/radical-ai/torch-sim/branch/main/graph/badge.svg)](https://codecov.io/gh/radical-ai/torch-sim) -[![This project supports Python 3.11+](https://img.shields.io/badge/Python-3.11+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads) -[![PyPI](https://img.shields.io/pypi/v/torch_sim_atomistic?logo=pypi&logoColor=white)](https://pypi.org/project/torch_sim_atomistic) +[![CI](https://github.com/torchsim/torch-sim/actions/workflows/test.yml/badge.svg)](https://github.com/torchsim/torch-sim/actions/workflows/test.yml) +[![codecov](https://codecov.io/gh/torchsim/torch-sim/branch/main/graph/badge.svg)](https://codecov.io/gh/torchsim/torch-sim) +[![This project supports Python 3.12+](https://img.shields.io/badge/Python-3.12+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads) +[![PyPI](https://img.shields.io/pypi/v/torch-sim-atomistic?logo=pypi&logoColor=white)](https://pypi.org/project/torch-sim-atomistic) [![Zenodo](https://img.shields.io/badge/Zenodo-15127004-blue?logo=Zenodo&logoColor=white)][zenodo] [zenodo]: https://zenodo.org/records/15127004 @@ -81,27 +81,14 @@ To then relax those structures with FIRE is just a few more lines. relaxed_state = ts.optimize( system=final_state, model=mace_model, - optimizer=ts.frechet_cell_fire, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.frechet, autobatcher=True, ) print(relaxed_state.energy) ``` -## Speedup - -TorchSim achieves up to 100x speedup compared to ASE with popular MLIPs. - -Speedup comparison - -This figure compares the time per atom of ASE and `torch_sim`. Time per atom is defined -as the number of atoms / total time. While ASE can only run a single system of `n_atoms` -(on the $x$ axis), `torch_sim` can run as many systems as will fit in memory. On an H100 80 GB card, -the max atoms that could fit in memory was ~8,000 for [EGIP](https://github.com/FAIR-Chem/fairchem), -~10,000 for [MACE-MPA-0](https://github.com/ACEsuit/mace), ~22,000 for [Mattersim V1 1M](https://github.com/microsoft/mattersim), -~2,500 for [SevenNet](https://github.com/MDIL-SNU/SevenNet), and ~9000 for [PET-MAD](https://github.com/lab-cosmo/pet-mad). -This metric describes model performance by capturing speed and memory usage simultaneously. - ## Installation ### PyPI Installation @@ -113,7 +100,7 @@ pip install torch-sim-atomistic ### Installing from source ```sh -git clone https://github.com/TorchSim/torch-sim +git clone https://github.com/torchsim/torch-sim cd torch-sim pip install . ``` @@ -126,11 +113,11 @@ To understand how TorchSim works, start with the [comprehensive tutorials](https TorchSim's package structure is summarized in the [API reference](https://torchsim.github.io/torch-sim/reference/index.html) documentation and drawn as a treemap below. -![TorchSim package treemap](https://github.com/user-attachments/assets/1ccb3a15-233d-4bc0-b11c-35a676a2bcf3) +![TorchSim package treemap](https://github.com/user-attachments/assets/56f894ad-b995-4108-a6de-a48714276d89) ## License -TorchSim is released under an [MIT license](LICENSE). +TorchSim is released under an [MIT license](license). ## Citation diff --git a/citation.cff b/citation.cff index 16faa6b73..9ac5db9b8 100644 --- a/citation.cff +++ b/citation.cff @@ -15,8 +15,8 @@ authors: - family-names: Falletta given-names: Stefano license: MIT -license-url: https://github.com/TorchSim/torch-sim/blob/main/LICENSE -repository-code: https://github.com/TorchSim/torch-sim -url: https://github.com/TorchSim/torch-sim +license-url: https://github.com/torchsim/torch-sim/blob/main/license +repository-code: https://github.com/torchsim/torch-sim +url: https://github.com/torchsim/torch-sim type: software date-released: 2025-04-02 diff --git a/docs/_static/draw_pkg_treemap.py b/docs/_static/draw_pkg_treemap.py index 44e762e74..d10b88b8e 100644 --- a/docs/_static/draw_pkg_treemap.py +++ b/docs/_static/draw_pkg_treemap.py @@ -1,13 +1,10 @@ -"""Draw a treemap of the torch_sim package structure. +"""Draw a treemap of the torch-sim package structure. Run with `uv run docs/_static/draw_pkg_treemap.py` """ # /// script -# dependencies = [ -# "pymatviz>=0.17.1", -# "plotly>=6.3.0", -# ] +# dependencies = ["pymatviz>=0.17.1", "plotly>=6.3.0"] # /// import os diff --git a/docs/about/license.md b/docs/about/license.md index 8354221ec..8aacb2b17 100644 --- a/docs/about/license.md +++ b/docs/about/license.md @@ -1,7 +1,7 @@ # License The MIT License (MIT) -Copyright 2025 Radical AI +Copyright 2025 Project TorchSim Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the β€œSoftware”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: diff --git a/docs/conf.py b/docs/conf.py index d127c7715..d2fbc6fce 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -17,8 +17,8 @@ # -- Project information ----------------------------------------------------- -project = "torch-sim-atomistic" -copyright = "2025, Radical AI" # noqa: A001 +project = "torch_sim" +copyright = "2025, Project TorchSim" # noqa: A001 author = "Abhijeet Gangan, Orion Cohen, Janosh Riebesell" # The short X.Y version @@ -156,7 +156,7 @@ "footer_icons": [ { "name": "GitHub", - "url": "https://github.com/TorchSim/torch-sim", + "url": "https://github.com/torchsim/torch-sim", "html": """ @@ -165,7 +165,7 @@ "class": "", }, ], - "source_repository": "https://github.com/TorchSim/torch-sim/", + "source_repository": "https://github.com/torchsim/torch-sim", "source_branch": "main", "source_directory": "docs/", } @@ -174,7 +174,7 @@ # hide sphinx footer html_show_sphinx = False html_show_sourcelink = False -html_title = "torch-sim" +html_title = "TorchSim" # -- Options for intersphinx extension --------------------------------------- diff --git a/docs/dev/add_model.md b/docs/dev/add_model.md index 3b5c9052d..11aa6481b 100644 --- a/docs/dev/add_model.md +++ b/docs/dev/add_model.md @@ -1,12 +1,10 @@ # Adding New Models -## How to add a new model to torchsim +## How to add a new model to TorchSim We welcome the addition of new models to `torch_sim`. We want easy batched simulations to be available to the whole community of MLIP developers and users. -See https://github.com/TorchSim/torch-sim/discussions/120 for -our current posture on adding models to TorchSim. 1. Open a PR or an issue to get feedback. We are happy to take a look, even if you haven't finished your implementation yet. @@ -35,4 +33,4 @@ is being correctly included in the documentation. We are also happy for developers to implement model interfaces in their own codebases. Steps 1 & 2 should still be followed to ensure the model -implementation is compatible with the rest of torch-sim. +implementation is compatible with the rest of TorchSim. diff --git a/docs/dev/dev_install.md b/docs/dev/dev_install.md index 273bbb615..a08aa1239 100644 --- a/docs/dev/dev_install.md +++ b/docs/dev/dev_install.md @@ -4,7 +4,7 @@ You can install TorchSim with `pip` or from source. ## Install using pip -You can install the basic functionality of torch-sim using pip: +You can install the basic functionality of TorchSim using pip: ```bash pip install torch-sim-atomistic @@ -12,10 +12,10 @@ pip install torch-sim-atomistic ## Install from source -To install torch-sim from source, clone the repository from [github](https://github.com/TorchSim/torch-sim) +To install TorchSim from source, clone the repository from [github](https://github.com/torchsim/torch-sim) ```bash -git clone https://github.com/TorchSim/torch-sim +git clone https://github.com/torchsim/torch-sim cd torch-sim pip install . -e ``` @@ -32,7 +32,7 @@ pre-commit run --all-files ``` The `pre-commit` command will ensure that changes to the source code match the -torch-sim style guidelines by running code linters such as `black` and `ruff` automatically with each commit. +TorchSim style guidelines by running code linters such as `black` and `ruff` automatically with each commit. ## Running unit tests @@ -51,7 +51,7 @@ pytest ## Building the documentation -The torch-sim documentation can be built using the sphinx package. First, install the requirements: +The TorchSim documentation can be built using the sphinx package. First, install the requirements: ```bash pip install .[docs] diff --git a/docs/index.md b/docs/index.md index a7dc91d99..32fc04f4a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,6 +4,7 @@ :hidden: user/introduction user/overview +user/models tutorials/index ``` @@ -18,6 +19,7 @@ reference/index :hidden: dev/dev_install dev/add_model +dev/batching ``` ```{toctree} @@ -33,8 +35,8 @@ about/license **Date**: {sub-ref}`today` **Useful links**: -[Source Repository](https://github.com/TorchSim/torch-sim) | -[Issues & Ideas](https://github.com/TorchSim/torch-sim/issues) +[Source Repository](https://github.com/torchsim/torch-sim) | +[Issues & Ideas](https://github.com/torchsim/torch-sim/issues) TorchSim is a next-generation open-source atomistic simulation engine for the MLIP era. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index a21d7418b..2ac6fcaa5 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -5,7 +5,7 @@ API reference Overview of the TorchSim API. -.. currentmodule:: torch_sim +.. currentmodule:: torch-sim .. autosummary:: :recursive: diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index fd3ca21c5..7ea9a8370 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -5,9 +5,9 @@ Tutorials For a high-level overview of the tutorials see :doc:`../user/overview`. Runnable versions of the tutorials can also be found in the `torch-sim /examples/tutorials -`_ directory. +`_ directory. -.. currentmodule:: torch_sim +.. currentmodule:: torch-sim .. toctree:: :titlesonly: diff --git a/docs/user/introduction.rst b/docs/user/introduction.rst index 67833f0e1..910a3b816 100644 --- a/docs/user/introduction.rst +++ b/docs/user/introduction.rst @@ -3,6 +3,6 @@ Introduction ============ -.. include:: ../../README.md +.. include:: ../../readme.md :start-after: :parser: myst_parser.sphinx_ diff --git a/examples/readme.md b/examples/readme.md index f799f48fe..7e78979eb 100644 --- a/examples/readme.md +++ b/examples/readme.md @@ -1,9 +1,6 @@ -## Types of Examples +## Examples -Examples are provided in two foms: - -* Tutorials are intended to provide pedagogical walkthroughs of TorchSim's core functionality -* Scripts are a holdover from early torch-sim development, they are not currently recommended as a learning resource. See issue [issue 109](https://github.com/TorchSim/torch-sim/issues/109). +Tutorials are intended to provide pedagogical walkthroughs of TorchSim's core functionality ## Tutorial Formatting diff --git a/examples/scripts/1_Introduction/1.1_Lennard_Jones.py b/examples/scripts/1_Introduction/1.1_Lennard_Jones.py index cc91070c7..76d49d590 100644 --- a/examples/scripts/1_Introduction/1.1_Lennard_Jones.py +++ b/examples/scripts/1_Introduction/1.1_Lennard_Jones.py @@ -1,9 +1,8 @@ """Lennard-Jones simple single system example.""" + # /// script -# dependencies = [ -# "scipy>=1.15", -# ] +# dependencies = ["scipy>=1.15"] # /// import itertools @@ -14,7 +13,7 @@ # Set up the device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Create face-centered cubic (FCC) Argon @@ -75,10 +74,7 @@ # Batched state state = dict( - positions=positions, - cell=cell.unsqueeze(0), - atomic_numbers=atomic_numbers, - pbc=True, + positions=positions, cell=cell.unsqueeze(0), atomic_numbers=atomic_numbers, pbc=True ) # Run the simulation and get results diff --git a/examples/scripts/1_Introduction/1.2_MACE.py b/examples/scripts/1_Introduction/1.2_MACE.py index f627bb5db..ee3f23f26 100644 --- a/examples/scripts/1_Introduction/1.2_MACE.py +++ b/examples/scripts/1_Introduction/1.2_MACE.py @@ -1,9 +1,8 @@ """Minimal MACE batched example.""" + # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// import numpy as np @@ -15,20 +14,18 @@ # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 - # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load the compiled model from the local file -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Create diamond cubic Silicon si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2)) diff --git a/examples/scripts/1_Introduction/1.3_Fairchem.py b/examples/scripts/1_Introduction/1.3_Fairchem.py index b6f8dd5b9..23175c033 100644 --- a/examples/scripts/1_Introduction/1.3_Fairchem.py +++ b/examples/scripts/1_Introduction/1.3_Fairchem.py @@ -1,14 +1,8 @@ -# ruff: noqa: E501 """Minimal FairChem example demonstrating batching.""" # /// script -# dependencies = [ -# "fairchem-core==1.10.0", -# ] +# dependencies = ["fairchem-core>=2.2.0"] # /// - -import sys - import torch from ase.build import bulk @@ -16,41 +10,38 @@ from torch_sim.models.fairchem import FairChemModel -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 -try: - from fairchem.core.models.model_registry import model_name_to_local_file -except ImportError: - print("Skipping example due to missing fairchem dependency") - sys.exit(0) - -MODEL_PATH = model_name_to_local_file( - "EquiformerV2-31M-S2EF-OC20-All+MD", local_cache="." -) +# UMA = Unified Machine Learning for Atomistic simulations +MODEL_NAME = "uma-s-1" # Create diamond cubic Silicon si_dc = bulk("Si", "diamond", a=5.43).repeat((2, 2, 2)) atomic_numbers = si_dc.get_atomic_numbers() model = FairChemModel( - model=MODEL_PATH, + model=None, + model_name=MODEL_NAME, + task_name="omat", # Open Materials task for crystalline systems cpu=False, - seed=0, ) atoms_list = [si_dc, si_dc] -state = ts.io.atoms_to_state(atoms_list) +state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) results = model(state) print(results["energy"].shape) print(results["forces"].shape) -print(results["stress"].shape) +if stress := results.get("stress"): + print(stress.shape) print(f"Energy: {results['energy']}") print(f"Forces: {results['forces']}") -print(f"Stress: {results['stress']}") +if stress := results.get("stress"): + print(f"{stress=}") # Check if the energy, forces, and stress are the same for the Si system across the batch print(torch.max(torch.abs(results["energy"][0] - results["energy"][1]))) print(torch.max(torch.abs(results["forces"][0] - results["forces"][1]))) -print(torch.max(torch.abs(results["stress"][0] - results["stress"][1]))) +if stress := results.get("stress"): + print(torch.max(torch.abs(stress[0] - stress[1]))) diff --git a/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py b/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py index 9ddd8ba72..507367c87 100644 --- a/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py @@ -1,11 +1,8 @@ """Lennard-Jones FIRE optimization.""" # /// script -# dependencies = [ -# "scipy>=1.15", -# ] +# dependencies = ["scipy>=1.15"] # /// - import itertools import os @@ -13,11 +10,10 @@ import torch_sim as ts from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.optimizers import fire # Set up the device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Set up the random number generator @@ -85,7 +81,7 @@ ) # Create state with batch dimension -state = ts.state.SimState( +state = ts.SimState( positions=positions, masses=masses, cell=cell.unsqueeze(0), @@ -97,19 +93,14 @@ results = model(state) # Initialize FIRE optimizer -fire_init, fire_update = fire( - model=model, - dt_start=0.005, - dt_max=0.01, -) +state = ts.fire_init(model=model, state=state, dt_start=0.005) -state = fire_init(state=state) # Run optimization for N_steps for step in range(N_steps): if step % 100 == 0: print(f"{step=}: Potential energy: {state.energy[0].item()} eV") - state = fire_update(state) + state = ts.fire_step(model, state, dt_max=0.01) # Print max force after optimization print(f"Initial energy: {results['energy'][0].item()} eV") diff --git a/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py b/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py index 3956c9564..61eedb24e 100644 --- a/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py @@ -1,11 +1,8 @@ """Structural optimization with soft sphere potential using FIRE optimizer.""" # /// script -# dependencies = [ -# "scipy>=1.15", -# ] +# dependencies = ["scipy>=1.15"] # /// - import itertools import os @@ -13,11 +10,10 @@ import torch_sim as ts from torch_sim.models.soft_sphere import SoftSphereModel -from torch_sim.optimizers import fire # Set up the device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float64 # Set up the random number generator @@ -76,7 +72,7 @@ masses = torch.full((positions.shape[0],), 63.546, device=device, dtype=dtype) # Create state with batch dimension -state = ts.state.SimState( +state = ts.SimState( positions=positions, masses=masses, cell=cell.unsqueeze(0), @@ -98,19 +94,14 @@ results = model(state) # Initialize FIRE optimizer -fire_init, fire_update = fire( - model=model, - dt_start=0.005, - dt_max=0.01, -) +state = ts.fire_init(model=model, state=state, dt_start=0.005) -state = fire_init(state=state) # Run optimization for N_steps for step in range(N_steps): if step % 100 == 0: print(f"{step=}: Total energy: {state.energy[0].item()} eV") - state = fire_update(state) + state = ts.fire_step(model, state, dt_max=0.01) # Print max force after optimization print(f"Initial energy: {results['energy'][0].item()} eV") diff --git a/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py b/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py index 819ec6a49..37c6c6c21 100644 --- a/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py +++ b/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py @@ -1,11 +1,8 @@ """Batched MACE gradient descent example.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - import os import numpy as np @@ -15,24 +12,21 @@ import torch_sim as ts from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.optimizers import gradient_descent # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 - # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Number of steps to run SMOKE_TEST = os.getenv("CI") is not None @@ -115,18 +109,16 @@ learning_rate = 0.01 # Initialize batched gradient descent optimizer -gd_init, gd_update = gradient_descent( - model=batched_model, - lr=learning_rate, -) +state = ts.gradient_descent_init(model=batched_model, state=state) -state = gd_init(state) # Run batched optimization for a few steps print("\nRunning batched gradient descent:") for step in range(N_steps): if step % 10 == 0: print(f"Step {step}, Energy: {[res.item() for res in state.energy]} eV") - state = gd_update(state) + state = ts.gradient_descent_step( + model=batched_model, state=state, pos_lr=learning_rate + ) print(f"Initial energies: {[res.item() for res in results['energy']]} eV") print(f"Final energies: {[res.item() for res in state.energy]} eV") diff --git a/examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py b/examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py index 8ecc74bf4..c152385ac 100644 --- a/examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py @@ -1,11 +1,8 @@ """Batched MACE FIRE optimizer.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - import os import numpy as np @@ -15,24 +12,22 @@ import torch_sim as ts from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.optimizers import fire # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Number of steps to run SMOKE_TEST = os.getenv("CI") is not None @@ -78,11 +73,8 @@ results = model(state) # Initialize unit cell gradient descent optimizer -init_fn, update_fn = fire( - model=model, -) +state = ts.fire_init(model=model, state=state, dt_start=0.005) -state = init_fn(state) # Run optimization for a few steps print("\nRunning FIRE:") @@ -90,7 +82,7 @@ if step % 20 == 0: print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}") - state = update_fn(state) + state = ts.fire_step(model, state, dt_max=0.01) print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV") print(f"Final energies: {[energy.item() for energy in state.energy]} eV") diff --git a/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py b/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py index 5417724c4..c7b6bbe09 100644 --- a/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py +++ b/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py @@ -1,11 +1,8 @@ """Batched MACE unit cell filter with gradient descent optimizer.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - import os import numpy as np @@ -15,25 +12,23 @@ import torch_sim as ts from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.optimizers import unit_cell_gradient_descent from torch_sim.units import UnitConversion # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Number of steps to run SMOKE_TEST = os.getenv("CI") is not None @@ -79,21 +74,19 @@ results = model(state) # Use same learning rate for all batches -positions_lr = 0.01 -cell_lr = 0.1 +pos_lr, cell_lr = 0.01, 0.1 + -# Initialize unit cell gradient descent optimizer -gd_init, gd_update = unit_cell_gradient_descent( +state = ts.gradient_descent_init( model=model, + state=state, + cell_filter=ts.CellFilter.unit, cell_factor=None, # Will default to atoms per system hydrostatic_strain=False, constant_volume=False, scalar_pressure=0.0, - positions_lr=positions_lr, - cell_lr=cell_lr, ) -state = gd_init(state) # Run optimization for a few steps print("\nRunning batched unit cell gradient descent:") @@ -108,7 +101,9 @@ f"P1={P1:.4f} GPa, P2={P2:.4f} GPa, P3={P3:.4f} GPa" ) - state = gd_update(state) + state = ts.gradient_descent_step( + model=model, state=state, pos_lr=pos_lr, cell_lr=cell_lr + ) print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV") print(f"Final energies: {[energy.item() for energy in state.energy]} eV") diff --git a/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py index 85a7bd13a..60044fc2f 100644 --- a/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py @@ -1,11 +1,8 @@ """Batched MACE unit cell filter with FIRE optimizer.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - import os import numpy as np @@ -15,25 +12,23 @@ import torch_sim as ts from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.optimizers import unit_cell_fire from torch_sim.units import UnitConversion # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Number of steps to run SMOKE_TEST = os.getenv("CI") is not None @@ -78,17 +73,17 @@ # Run initial inference results = model(state) -# Initialize unit cell gradient descent optimizer -fire_init, fire_update = unit_cell_fire( +# Initialize FIRE optimizer with unit cell filter +state = ts.fire_init( model=model, + state=state, + cell_filter=ts.CellFilter.unit, cell_factor=None, # Will default to atoms per system hydrostatic_strain=False, constant_volume=False, scalar_pressure=0.0, ) -state = fire_init(state) - # Run optimization for a few steps print("\nRunning batched unit cell gradient descent:") for step in range(N_steps): @@ -102,7 +97,7 @@ f"P1={P1:.4f} GPa, P2={P2:.4f} GPa, P3={P3:.4f} GPa" ) - state = fire_update(state) + state = ts.fire_step(model=model, state=state) print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV") print(f"Final energies: {[energy.item() for energy in state.energy]} eV") diff --git a/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py index ba06f850d..56c8eeb6e 100644 --- a/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py @@ -1,11 +1,8 @@ """Batched MACE frechet cell filter with FIRE optimizer.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - import os import numpy as np @@ -19,20 +16,19 @@ # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Number of steps to run SMOKE_TEST = os.getenv("CI") is not None @@ -77,17 +73,17 @@ # Run initial inference results = model(state) -# Initialize unit cell gradient descent optimizer -fire_init, fire_update = ts.optimizers.frechet_cell_fire( +# Initialize FIRE optimizer with Frechet cell filter +state = ts.fire_init( model=model, + state=state, + cell_filter=ts.CellFilter.frechet, cell_factor=None, # Will default to atoms per system hydrostatic_strain=False, constant_volume=False, scalar_pressure=0.0, ) -state = fire_init(state) - # Run optimization for a few steps print("\nRunning batched frechet cell filter with FIRE:") for step in range(N_steps): @@ -101,7 +97,7 @@ f"P1={P1:.4f} GPa, P2={P2:.4f} GPa, P3={P3:.4f} GPa" ) - state = fire_update(state) + state = ts.fire_step(model=model, state=state) print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV") print(f"Final energies: {[energy.item() for energy in state.energy]} eV") diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index fb3f79830..3aaaa420c 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -1,12 +1,8 @@ """Hybrid swap Monte Carlo simulation.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# "pymatgen>=2025.2.18", -# ] +# dependencies = ["mace-torch>=0.3.12", "pymatgen>=2025.2.18"] # /// - from dataclasses import dataclass import torch @@ -14,14 +10,14 @@ from pymatgen.core import Structure import torch_sim as ts -from torch_sim.integrators import MDState, nvt_langevin +from torch_sim.integrators.md import MDState from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.monte_carlo import swap_monte_carlo from torch_sim.units import MetalUnits as Units -device = "cuda" if torch.cuda.is_available() else "cpu" -dtype = torch.float64 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +dtype = torch.float32 + kT = 1000 * Units.temperature @@ -29,13 +25,12 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) model = MaceModel( model=loaded_model, @@ -67,7 +62,7 @@ # %% @dataclass -class HybridSwapMCState(MDState): +class HybridSwapMCState(ts.SwapMCState, MDState): """State for Monte Carlo simulations. Attributes: @@ -77,19 +72,17 @@ class HybridSwapMCState(MDState): last_permutation: torch.Tensor _atom_attributes = ( - MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ts.SwapMCState._atom_attributes | {"last_permutation"} # noqa: SLF001 ) -nvt_init, nvt_step = nvt_langevin(model=model, dt=0.002, kT=kT) -md_state = nvt_init(state, seed=42) +md_state = ts.nvt_langevin_init(model=model, state=state, kT=torch.tensor(kT), seed=42) -swap_init, swap_step = swap_monte_carlo(model=model, kT=kT, seed=42) -swap_state = swap_init(md_state) +swap_state = ts.swap_mc_init(model=model, state=md_state) hybrid_state = HybridSwapMCState( **vars(md_state), - last_permutation=torch.zeros( - md_state.n_systems, device=md_state.device, dtype=torch.bool + last_permutation=torch.arange( + md_state.n_atoms, device=md_state.device, dtype=torch.long ), ) @@ -97,8 +90,13 @@ class HybridSwapMCState(MDState): generator.manual_seed(42) n_steps = 100 +dt = torch.tensor(0.002) for step in range(n_steps): if step % 10 == 0: - hybrid_state = swap_step(hybrid_state, kT=torch.tensor(kT), generator=generator) + hybrid_state = ts.swap_mc_step( + model=model, state=hybrid_state, kT=kT, seed=42 + step + ) else: - hybrid_state = nvt_step(hybrid_state, dt=torch.tensor(0.002), kT=torch.tensor(kT)) + hybrid_state = ts.nvt_langevin_update( + model=model, state=hybrid_state, dt=dt, kT=torch.tensor(kT) + ) diff --git a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py index 6f81519ed..1a1c47939 100644 --- a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py @@ -1,26 +1,21 @@ """Lennard-Jones simulation in NPT ensemble using Langevin thermostat.""" # /// script -# dependencies = [ -# "scipy>=1.15", -# ] +# dependencies = ["scipy>=1.15"] # /// - import itertools import os import torch import torch_sim as ts -from torch_sim.integrators import npt_langevin from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.quantities import calc_kinetic_energy, calc_kT, get_pressure from torch_sim.units import MetalUnits as Units from torch_sim.units import UnitConversion # Set up the device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Set random seed and deterministic behavior for reproducibility @@ -103,30 +98,28 @@ # Run initial simulation and get results results = model(state) -dt = 0.001 * Units.time # Time step (1 ps) -kT = 200 * Units.temperature # Temperature (200 K) +dt = torch.tensor(0.001 * Units.time, device=device, dtype=dtype) # Time step (1 ps) +kT = torch.tensor( + 200 * Units.temperature, device=device, dtype=dtype +) # Temperature (200 K) target_pressure = ( torch.tensor(10_000, device=device, dtype=dtype) * Units.pressure ) # Target pressure (10 kbar) -npt_init, npt_update = npt_langevin( - model=model, dt=dt, kT=kT, external_pressure=target_pressure -) - -state = npt_init(state=state, seed=1) +state = ts.npt_langevin_init(model=model, state=state, dt=dt, kT=kT, seed=1) # Run the simulation for step in range(N_steps): if step % 50 == 0: temp = ( - calc_kT( + ts.calc_kT( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) / Units.temperature ) - pressure = get_pressure( + pressure = ts.get_pressure( model(state)["stress"], - calc_kinetic_energy( + ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ), torch.linalg.det(state.cell), @@ -138,21 +131,30 @@ f"{pressure=:.4f}, " f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" ) - state = npt_update(state, kT=kT, external_pressure=target_pressure) + state = ts.npt_langevin_update( + model=model, + state=state, + dt=dt, + kT=kT, + external_pressure=target_pressure, + alpha=1.0 / (100 * dt), + cell_alpha=1.0 / (100 * dt), + b_tau=1 / (1000 * dt), + ) temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) + ts.calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) print(f"Final temperature: {temp.item():.4f}") stress = model(state)["stress"] -kinetic_energy = calc_kinetic_energy( +kinetic_energy = ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) volume = torch.linalg.det(state.cell) -pressure = get_pressure(stress, kinetic_energy, volume) +pressure = ts.get_pressure(stress, kinetic_energy, volume) pressure = pressure.item() / Units.pressure print(f"Final {pressure=:.4f}") print(stress * UnitConversion.eV_per_Ang3_to_GPa) diff --git a/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py b/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py index a07a4d749..3bf1b6ee0 100644 --- a/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py @@ -1,11 +1,8 @@ """NPT simulation with MACE and Nose-Hoover thermostat.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - import os import torch @@ -13,28 +10,24 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts -from torch_sim.integrators.npt import npt_langevin -from torch_sim.integrators.nvt import nvt_nose_hoover, nvt_nose_hoover_invariant from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.quantities import calc_kinetic_energy, calc_kT, get_pressure from torch_sim.units import MetalUnits as Units # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Create diamond cubic Silicon si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2)) @@ -56,36 +49,34 @@ SMOKE_TEST = os.getenv("CI") is not None N_steps_nvt = 20 if SMOKE_TEST else 2_000 N_steps_npt = 20 if SMOKE_TEST else 2_000 -dt = 0.001 * Units.time # Time step (1 ps) +dt = torch.tensor(0.001 * Units.time, device=device, dtype=dtype) # Time step (1 ps) kT = ( torch.tensor(300, device=device, dtype=dtype) * Units.temperature ) # Initial temperature (300 K) -target_pressure = 10_000 * Units.pressure # Target pressure (0 bar) +target_pressure = torch.tensor( + 10_000 * Units.pressure, device=device, dtype=dtype +) # Target pressure (0 bar) -nvt_init, nvt_update = nvt_nose_hoover(model=model, kT=kT, dt=dt) -state = nvt_init(state=state, seed=1) +state = ts.nvt_nose_hoover_init(model=model, state=state, kT=kT, dt=dt, seed=1) for step in range(N_steps_nvt): if step % 10 == 0: temp = ( - calc_kT( + ts.calc_kT( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) / Units.temperature ) - invariant = float(nvt_nose_hoover_invariant(state, kT=kT)) + invariant = float(ts.nvt_nose_hoover_invariant(state, kT=kT)) print(f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}, ") - state = nvt_update(state, kT=kT) + state = ts.nvt_nose_hoover_update(model=model, state=state, dt=dt, kT=kT) -npt_init, npt_update = npt_langevin( - model=model, kT=kT, dt=dt, external_pressure=target_pressure -) -state = npt_init(state=state, seed=1) +state = ts.npt_langevin_init(model=model, state=state, kT=kT, dt=dt, seed=1) for step in range(N_steps_npt): if step % 10 == 0: temp = ( - calc_kT( + ts.calc_kT( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) / Units.temperature @@ -93,9 +84,9 @@ stress = model(state)["stress"] volume = torch.det(state.cell) pressure = ( - get_pressure( + ts.get_pressure( stress, - calc_kinetic_energy( + ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx, @@ -110,19 +101,28 @@ f"pressure: {pressure:.4f}, " f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" ) - state = npt_update(state, kT=kT, external_pressure=target_pressure) + state = ts.npt_langevin_update( + model=model, + state=state, + dt=dt, + kT=kT, + external_pressure=target_pressure, + alpha=1.0 / (100 * dt), + cell_alpha=1.0 / (100 * dt), + b_tau=1 / (1000 * dt), + ) final_temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) + ts.calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) print(f"Final temperature: {final_temp.item():.4f} K") final_stress = model(state)["stress"] final_volume = torch.det(state.cell) final_pressure = ( - get_pressure( + ts.get_pressure( final_stress, - calc_kinetic_energy( + ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ), final_volume, diff --git a/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py b/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py index 07fcb4c8c..bd8228807 100644 --- a/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py +++ b/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py @@ -1,11 +1,8 @@ """NVE simulation with MACE.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - import os import time @@ -14,27 +11,24 @@ from mace.calculators.foundations_models import mace_off import torch_sim as ts -from torch_sim.integrators import nve from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.quantities import calc_kinetic_energy from torch_sim.units import MetalUnits as Units # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Option 1: Load the raw model from the downloaded model loaded_model = mace_off( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Number of steps to run SMOKE_TEST = os.getenv("CI") is not None @@ -52,7 +46,7 @@ enable_cueq=False, ) -state = ts.io.atoms_to_state(mol, device, dtype) +state = ts.io.atoms_to_state(mol, device=device, dtype=dtype) # Run initial inference results = model(state) @@ -61,31 +55,26 @@ kT = ( torch.tensor(300, device=device, dtype=dtype) * Units.temperature ) # Initial temperature (K) -dt = 0.002 * Units.time # Timestep (ps) +dt = torch.tensor(0.002 * Units.time, device=device, dtype=dtype) # Timestep (ps) # Initialize NVE integrator -nve_init, nve_update = nve( - model=model, - dt=dt, - kT=kT, -) -state = nve_init(state=state, seed=1) +state = ts.nve_init(model=model, state=state, kT=kT, seed=1) # Run MD simulation print("\nStarting NVE molecular dynamics simulation...") start_time = time.perf_counter() for step in range(N_steps): - total_energy = state.energy + calc_kinetic_energy( + total_energy = state.energy + ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") - state = nve_update(state=state, dt=dt) + state = ts.nve_update(model=model, state=state, dt=dt) end_time = time.perf_counter() # Report simulation results print("\nSimulation complete!") print(f"Time taken: {end_time - start_time:.2f} seconds") -print(f"Average time per step: {(end_time - start_time) / 1000:.4f} seconds") +print(f"Average time per step: {(end_time - start_time) / N_steps:.4f} seconds") print(f"Final total energy: {total_energy.item()} eV") diff --git a/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py b/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py index 9506e9d4d..54a3fab17 100644 --- a/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py +++ b/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py @@ -1,25 +1,20 @@ """NVE simulation with Lennard-Jones potential.""" # /// script -# dependencies = [ -# "scipy>=1.15", -# ] +# dependencies = ["scipy>=1.15"] # /// - import itertools import os import torch import torch_sim as ts -from torch_sim.integrators import nve from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.quantities import calc_kinetic_energy from torch_sim.units import MetalUnits as Units # Set up the device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Number of steps to run @@ -69,7 +64,9 @@ # Create the cell tensor cell = torch.tensor( - [[4 * a_len, 0, 0], [0, 4 * a_len, 0], [0, 0, 4 * a_len]], device=device, dtype=dtype + [[4 * a_len, 0, 0], [0, 4 * a_len, 0], [0, 0, 4 * a_len]], + device=device, + dtype=dtype, ) # Create the atomic numbers tensor (Argon = 18) @@ -102,27 +99,25 @@ # Set up NVE simulation # kT: initial temperature in metal units (K) # dt: timestep in metal units (ps) -kT = 80 * Units.temperature -dt = 0.001 * Units.time +kT = torch.tensor(80 * Units.temperature, device=device, dtype=dtype) +dt = torch.tensor(0.001 * Units.time, device=device, dtype=dtype) # Initialize NVE integrator -nve_init, nve_update = nve(model=model, dt=dt, kT=kT) - -state = nve_init(state=state) +state = ts.nve_init(model=model, state=state, kT=kT, seed=1) # Run NVE simulation for 1000 steps for step in range(N_steps): if step % 100 == 0: # Calculate total energy (potential + kinetic) - total_energy = state.energy + calc_kinetic_energy( + total_energy = state.energy + ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta ) print(f"{step=}: Total energy: {total_energy.item():.4f}") # Update state using NVE integrator - state = nve_update(state=state, dt=dt) + state = ts.nve_update(model=model, state=state, dt=dt) -final_total_energy = state.energy + calc_kinetic_energy( +final_total_energy = state.energy + ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta ) print(f"Final total energy: {final_total_energy.item():.4f}") diff --git a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py index 4151eb1a9..d32aa23c9 100644 --- a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py +++ b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py @@ -1,11 +1,8 @@ """NVE simulation with MACE.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - import os import time @@ -14,27 +11,25 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts -from torch_sim.integrators import nve from torch_sim.models.mace import MaceModel, MaceUrls from torch_sim.quantities import calc_kinetic_energy from torch_sim.units import MetalUnits as Units # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Number of steps to run SMOKE_TEST = os.getenv("CI") is not None @@ -68,16 +63,11 @@ # Setup NVE MD simulation parameters kT = torch.tensor(1000, device=device, dtype=dtype) * Units.temperature -dt = 0.002 * Units.time # Timestep (ps) +dt = torch.tensor(0.002 * Units.time, device=device, dtype=dtype) # Timestep (ps) # Initialize NVE integrator -nve_init, nve_update = nve( - model=model, - dt=dt, - kT=kT, -) -state = nve_init(state=state, seed=1) +state = ts.nve_init(model=model, state=state, kT=kT, seed=1) # Run MD simulation print("\nStarting NVE molecular dynamics simulation...") @@ -88,11 +78,11 @@ ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") - state = nve_update(state=state, dt=dt) + state = ts.nve_update(model=model, state=state, dt=dt) end_time = time.perf_counter() # Report simulation results print("\nSimulation complete!") print(f"Time taken: {end_time - start_time:.2f} seconds") -print(f"Average time per step: {(end_time - start_time) / 1000:.4f} seconds") +print(f"Average time per step: {(end_time - start_time) / N_steps:.4f} seconds") print(f"Final total energy: {total_energy.item()} eV") diff --git a/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py b/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py index 2dd507fec..f926df135 100644 --- a/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py +++ b/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py @@ -1,11 +1,8 @@ """NVE simulation with MACE and cuEquivariance enabled.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - import os import time @@ -14,27 +11,25 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts -from torch_sim.integrators import nve +from torch_sim.integrators import nve_init, nve_update from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.quantities import calc_kinetic_energy from torch_sim.units import MetalUnits as Units # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Number of steps to run SMOKE_TEST = os.getenv("CI") is not None @@ -59,27 +54,26 @@ # Setup NVE MD simulation parameters kT = torch.tensor(1000, device=device, dtype=dtype) * Units.temperature -dt = 0.002 * Units.time # Timestep (ps) +dt = torch.tensor(0.002 * Units.time, device=device, dtype=dtype) # Timestep (ps) # Initialize NVE integrator -nve_init, nve_update = nve(model=model, dt=dt, kT=kT) -state = nve_init(state=state, seed=1) +state = nve_init(model=model, state=state, kT=kT, seed=1) # Run MD simulation print("\nStarting NVE molecular dynamics simulation...") start_time = time.perf_counter() for step in range(N_steps): - total_energy = state.energy + calc_kinetic_energy( + total_energy = state.energy + ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") - state = nve_update(state=state, dt=dt) + state = nve_update(model=model, state=state, dt=dt) end_time = time.perf_counter() # Report simulation results print("\nSimulation complete!") print(f"Time taken: {end_time - start_time:.2f} seconds") -print(f"Average time per step: {(end_time - start_time) / 1000:.4f} seconds") +print(f"Average time per step: {(end_time - start_time) / N_steps:.4f} seconds") print(f"Final total energy: {total_energy.item()} eV") diff --git a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py index c998e9502..45dcff874 100644 --- a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py @@ -1,11 +1,8 @@ """MACE NVT Langevin dynamics.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - import os import torch @@ -13,27 +10,25 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts -from torch_sim.integrators import nvt_langevin +from torch_sim.integrators import nvt_langevin_init, nvt_langevin_update from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.quantities import calc_kT from torch_sim.units import MetalUnits as Units # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Number of steps to run SMOKE_TEST = os.getenv("CI") is not None @@ -62,33 +57,28 @@ positions=positions, masses=masses, cell=cell, atomic_numbers=atomic_numbers, pbc=True ) -dt = 0.002 * Units.time # Timestep (ps) +dt = torch.tensor(0.002 * Units.time, device=device, dtype=dtype) # Timestep (ps) kT = torch.tensor(1000, device=device, dtype=dtype) * Units.temperature -gamma = 10 / Units.time # Langevin friction coefficient (ps^-1) +gamma = torch.tensor( + 10 / Units.time, device=device, dtype=dtype +) # Langevin friction coefficient (ps^-1) # Initialize NVT Langevin integrator -langevin_init, langevin_update = nvt_langevin( - model=model, - kT=kT, - dt=dt, - gamma=gamma, -) - -state = langevin_init(state=state, seed=1) +state = nvt_langevin_init(model=model, state=state, kT=kT, seed=1) for step in range(N_steps): if step % 10 == 0: temp = ( - calc_kT( + ts.calc_kT( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) / Units.temperature ) print(f"{step=}: Temperature: {temp.item():.4f}") - state = langevin_update(state=state, kT=kT) + state = nvt_langevin_update(model=model, state=state, dt=dt, kT=kT, gamma=gamma) final_temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) + ts.calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) print(f"Final temperature: {final_temp.item():.4f}") diff --git a/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py index 97cb06e9c..b9fd7864e 100644 --- a/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py @@ -1,11 +1,8 @@ """NVT simulation with MACE and Nose-Hoover thermostat.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - import os import torch @@ -13,27 +10,24 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts -from torch_sim.integrators.nvt import nvt_nose_hoover, nvt_nose_hoover_invariant from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.quantities import calc_kT from torch_sim.units import MetalUnits as Units # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Number of steps to run SMOKE_TEST = os.getenv("CI") is not None @@ -56,29 +50,28 @@ # Run initial inference results = model(state) -dt = 0.002 * Units.time # Timestep (ps) +dt = torch.tensor(0.002 * Units.time, device=device, dtype=dtype) # Timestep (ps) kT = ( torch.tensor(1000, device=device, dtype=dtype) * Units.temperature ) # Initial temperature (K) -nvt_init, nvt_update = nvt_nose_hoover(model=model, kT=kT, dt=dt) -state = nvt_init(state=state, kT=kT, seed=1) +state = ts.nvt_nose_hoover_init(model=model, state=state, kT=kT, dt=dt) for step in range(N_steps): if step % 10 == 0: temp = ( - calc_kT( + ts.calc_kT( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) / Units.temperature ) - invariant = float(nvt_nose_hoover_invariant(state, kT=kT)) + invariant = float(ts.nvt_nose_hoover_invariant(state, kT=kT)) print(f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}") - state = nvt_update(state=state, kT=kT) + state = ts.nvt_nose_hoover_update(model=model, state=state, dt=dt, kT=kT) final_temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) + ts.calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) print(f"Final temperature: {final_temp.item():.4f}") diff --git a/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py b/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py index aafc28ac6..d13daa8ca 100644 --- a/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py +++ b/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py @@ -9,7 +9,6 @@ # "kaleido", # ] # /// - import os import numpy as np @@ -19,9 +18,7 @@ from plotly.subplots import make_subplots import torch_sim as ts -from torch_sim.integrators.nvt import nvt_nose_hoover, nvt_nose_hoover_invariant from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.quantities import calc_kT from torch_sim.units import MetalUnits as Units @@ -73,7 +70,7 @@ def get_kT( # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Model configuration @@ -81,8 +78,8 @@ def get_kT( loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file @@ -136,11 +133,10 @@ def get_kT( results = model(state) # Set up simulation parameters -dt = 0.002 * Units.time +dt = torch.tensor(0.002 * Units.time, device=device, dtype=dtype) kT = torch.tensor(init_temp, device=device, dtype=dtype) * Units.temperature -nvt_init, nvt_update = nvt_nose_hoover(model=model, kT=kT, dt=dt) -state = nvt_init(state, kT=kT, seed=1) +state = ts.nvt_nose_hoover_init(model=model, state=state, kT=kT, dt=dt, seed=1) # Run simulation with temperature profile actual_temps = np.zeros(n_steps) @@ -163,18 +159,24 @@ def get_kT( # Calculate current temperature and save data temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) + ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) / Units.temperature ) actual_temps[step] = temp expected_temps[step] = current_kT # Calculate invariant and progress report - invariant = float(nvt_nose_hoover_invariant(state, kT=current_kT * Units.temperature)) + invariant = float( + ts.nvt_nose_hoover_invariant(state, kT=current_kT * Units.temperature) + ) print(f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}") # Update simulation state - state = nvt_update(state, kT=current_kT * Units.temperature) + state = ts.nvt_nose_hoover_update( + model=model, state=state, dt=dt, kT=current_kT * Units.temperature + ) # Visualize temperature profile fig = make_subplots() diff --git a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py index 0c1ffa58c..c1d4a5db9 100644 --- a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py @@ -1,25 +1,20 @@ """Lennard-Jones simulation in NPT ensemble using Nose-Hoover chain.""" # /// script -# dependencies = [ -# "scipy>=1.15", -# ] +# dependencies = ["scipy>=1.15"] # /// - import itertools import os import torch import torch_sim as ts -from torch_sim.integrators.npt import npt_nose_hoover, npt_nose_hoover_invariant from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.quantities import calc_kinetic_energy, calc_kT, get_pressure from torch_sim.units import MetalUnits as Units # Set up the device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Set random seed and deterministic behavior for reproducibility @@ -102,39 +97,40 @@ # Run initial simulation and get results results = model(state) -dt = 0.001 * Units.time # Time step (1 ps) -kT = 200 * Units.temperature # Temperature (200 K) +dt = torch.tensor(0.001 * Units.time, device=device, dtype=dtype) # Time step (1 ps) +kT = torch.tensor( + 200 * Units.temperature, device=device, dtype=dtype +) # Temperature (200 K) target_pressure = ( torch.tensor(10, device=device, dtype=dtype) * Units.pressure ) # Target pressure (10 kbar) -npt_init, npt_update = npt_nose_hoover( +state = ts.npt_nose_hoover_init( model=model, + state=state, dt=dt, kT=kT, - external_pressure=target_pressure, chain_length=3, # Chain length chain_steps=1, sy_steps=1, ) -state = npt_init(state=state, seed=1) # Run the simulation for step in range(N_steps): if step % 50 == 0: temp = ( - calc_kT( + ts.calc_kT( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) / Units.temperature ) invariant = float( - npt_nose_hoover_invariant(state, kT=kT, external_pressure=target_pressure) + ts.npt_nose_hoover_invariant(state, kT=kT, external_pressure=target_pressure) ) - e_kin = calc_kinetic_energy( + e_kin = ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) - pressure = get_pressure( + pressure = ts.get_pressure( model(state)["stress"], e_kin, torch.det(state.current_cell) ) pressure = float(pressure) / Units.pressure @@ -144,17 +140,19 @@ f"{invariant=:.4f}, {pressure=:.4f}, " f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" ) - state = npt_update(state, kT=kT, external_pressure=target_pressure) + state = ts.npt_nose_hoover_update( + model=model, state=state, dt=dt, kT=kT, external_pressure=target_pressure + ) temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) + ts.calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) print(f"Final temperature: {temp.item():.4f}") -pressure = get_pressure( +pressure = ts.get_pressure( model(state)["stress"], - calc_kinetic_energy( + ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ), torch.det(state.current_cell), diff --git a/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py index 9dbc402a1..97d169de7 100644 --- a/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py @@ -1,11 +1,8 @@ """NPT simulation with MACE and Nose-Hoover thermostat.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - import os import torch @@ -13,27 +10,24 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts -from torch_sim.integrators.npt import npt_nose_hoover, npt_nose_hoover_invariant from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.quantities import calc_kinetic_energy, calc_kT, get_pressure from torch_sim.units import MetalUnits as Units # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Create diamond cubic Silicon si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2)) @@ -59,67 +53,75 @@ kT = ( torch.tensor(300, device=device, dtype=dtype) * Units.temperature ) # Initial temperature (300 K) -target_pressure = 0.0 * Units.pressure # Target pressure (0 bar) +target_pressure = torch.tensor( + 0.0 * Units.pressure, device=device, dtype=dtype +) # Target pressure (0 bar) -npt_init, npt_update = npt_nose_hoover( - model=model, kT=kT, dt=dt, external_pressure=target_pressure -) -state = npt_init(state=state, seed=1) +state = ts.npt_nose_hoover_init(model=model, state=state, kT=kT, dt=torch.tensor(dt)) for step in range(N_steps_nvt): if step % 10 == 0: temp = ( - calc_kT( + ts.calc_kT( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) / Units.temperature ) invariant = float( - npt_nose_hoover_invariant(state, kT=kT, external_pressure=target_pressure) + ts.npt_nose_hoover_invariant(state, kT=kT, external_pressure=target_pressure) ) print(f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}, ") - state = npt_update(state, kT=kT) + state = ts.npt_nose_hoover_update( + model=model, + state=state, + dt=torch.tensor(dt), + kT=kT, + external_pressure=target_pressure, + ) -npt_init, npt_update = npt_nose_hoover( - model=model, kT=kT, dt=dt, external_pressure=target_pressure -) -state = npt_init(state=state, seed=1) +state = ts.npt_nose_hoover_init(model=model, state=state, kT=kT, dt=torch.tensor(dt)) for step in range(N_steps_npt): if step % 10 == 0: temp = ( - calc_kT( + ts.calc_kT( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) / Units.temperature ) invariant = float( - npt_nose_hoover_invariant(state, kT=kT, external_pressure=target_pressure) + ts.npt_nose_hoover_invariant(state, kT=kT, external_pressure=target_pressure) ) stress = model(state)["stress"] volume = torch.det(state.current_cell) - e_kin = calc_kinetic_energy( + e_kin = ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) - pressure = float(get_pressure(stress, e_kin, volume)) + pressure = float(ts.get_pressure(stress, e_kin, volume)) xx, yy, zz = torch.diag(state.current_cell[0]) print( f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}, " f"{pressure=:.4f}, " f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" ) - state = npt_update(state, kT=kT, external_pressure=target_pressure) + state = ts.npt_nose_hoover_update( + model=model, + state=state, + dt=torch.tensor(dt), + kT=kT, + external_pressure=target_pressure, + ) final_temp = ( - calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) + ts.calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) / Units.temperature ) print(f"Final temperature: {final_temp.item():.4f}") final_stress = model(state)["stress"] final_volume = torch.det(state.current_cell) -final_pressure = get_pressure( +final_pressure = ts.get_pressure( final_stress, - calc_kinetic_energy( + ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ), final_volume, diff --git a/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py b/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py index 0106f3b8e..6f5ae7b26 100644 --- a/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py +++ b/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py @@ -1,11 +1,8 @@ """MACE NVT simulation with staggered stress calculation.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - import os import torch @@ -13,27 +10,25 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts -from torch_sim.integrators import nvt_langevin from torch_sim.models.mace import MaceModel, MaceUrls from torch_sim.quantities import calc_kT from torch_sim.units import MetalUnits as Units # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Option 2: Load from local file (comment out Option 1 to use this) -# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" -# loaded_model = torch.load(MODEL_PATH, map_location=device) +# loaded_model = torch.load("path/to/model.pt", map_location=device) # Number of steps to run SMOKE_TEST = os.getenv("CI") is not None @@ -59,8 +54,7 @@ torch.tensor(1000, device=device, dtype=dtype) * Units.temperature ) # Initial temperature (K) -nvt_init, nvt_update = nvt_langevin(model=model, kT=kT, dt=dt) -state = nvt_init(state, kT=kT, seed=1) +state = ts.nvt_langevin_init(model=model, state=state, kT=kT) stress = torch.zeros(N_steps // 10, 3, 3, device=device, dtype=dtype) for step in range(N_steps): @@ -70,12 +64,14 @@ ) # Calculate kinetic energy: KE = 0.5 * sum(p^2 / m) - kinetic_energy = 0.5 * torch.sum(state.momenta**2 / state.masses.unsqueeze(-1)) + kinetic_energy = 0.5 * torch.sum( + torch.pow(state.momenta, 2) / state.masses.unsqueeze(-1) + ) # Total energy = kinetic + potential invariant = float(kinetic_energy + state.energy) print(f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}") - state = nvt_update(state, kT=kT) + state = ts.nvt_langevin_update(model=model, state=state, dt=torch.tensor(dt), kT=kT) if step % 10 == 0: results = model(state) stress[step // 10] = results["stress"] diff --git a/examples/scripts/4_High_level_api/4.1_high_level_api.py b/examples/scripts/4_High_level_api/4.1_high_level_api.py index 396ca0350..5568b3bc4 100644 --- a/examples/scripts/4_High_level_api/4.1_high_level_api.py +++ b/examples/scripts/4_High_level_api/4.1_high_level_api.py @@ -3,12 +3,8 @@ """ # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# "pymatgen>=2025.2.18", -# ] +# dependencies = ["mace-torch>=0.3.12", "pymatgen>=2025.2.18"] # /// - import os import numpy as np @@ -18,11 +14,8 @@ from pymatgen.core import Structure import torch_sim as ts -from torch_sim.integrators import nvt_langevin from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.models.mace import MaceModel -from torch_sim.optimizers import unit_cell_fire -from torch_sim.quantities import calc_kinetic_energy from torch_sim.trajectory import TorchSimTrajectory, TrajectoryReporter from torch_sim.units import MetalUnits @@ -41,20 +34,19 @@ final_state = ts.integrate( system=si_atoms, model=lj_model, - integrator=nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=100 if SMOKE_TEST else 1000, temperature=2000, timestep=0.002, ) final_atoms = ts.io.state_to_atoms(final_state) - -trajectory_file = "lj_trajectory.h5md" +trajectory_file = "tmp/lj_trajectory.h5md" # report potential energy every 10 steps and kinetic energy every 20 steps prop_calculators = { 10: {"potential_energy": lambda state: state.energy}, 20: { - "kinetic_energy": lambda state: calc_kinetic_energy( + "kinetic_energy": lambda state: ts.calc_kinetic_energy( momenta=state.momenta, masses=state.masses ) }, @@ -70,7 +62,7 @@ final_state = ts.integrate( system=si_atoms, model=lj_model, - integrator=nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=100 if SMOKE_TEST else 1000, temperature=2000, timestep=0.002, @@ -89,7 +81,7 @@ ### basic mace example # cuda if available -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") mace = mace_mp(model="small", return_raw_model=True) @@ -110,7 +102,7 @@ final_state = ts.integrate( system=si_atoms, model=mace_model, - integrator=nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=100 if SMOKE_TEST else 1000, temperature=2000, timestep=0.002, @@ -128,7 +120,7 @@ final_state = ts.integrate( system=[si_atoms, fe_atoms, si_atoms_supercell, fe_atoms_supercell], model=mace_model, - integrator=nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=100 if SMOKE_TEST else 1000, temperature=2000, timestep=0.002, @@ -139,9 +131,9 @@ ### basic mace example with batching and reporting -systems = [si_atoms, fe_atoms, si_atoms_supercell, fe_atoms_supercell] +systems = (si_atoms, fe_atoms, si_atoms_supercell, fe_atoms_supercell) -filenames = [f"batch_traj_{i}.h5md" for i in range(len(systems))] +filenames = [f"tmp/batch_traj_{i}.h5md" for i in range(len(systems))] batch_reporter = TrajectoryReporter( filenames, state_frequency=100, @@ -150,7 +142,7 @@ final_state = ts.integrate( system=systems, model=mace_model, - integrator=nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=100 if SMOKE_TEST else 1000, temperature=2000, timestep=0.002, @@ -164,18 +156,14 @@ final_energies_per_atom.append(final_energy / len(traj.get_atoms(-1))) -systems = [si_atoms, fe_atoms, si_atoms_supercell, fe_atoms_supercell] - final_state = ts.optimize( system=systems, model=mace_model, - optimizer=unit_cell_fire, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.unit, max_steps=10 if SMOKE_TEST else 1000, ) - -systems = [si_atoms, fe_atoms, si_atoms_supercell, fe_atoms_supercell] - rng = np.random.default_rng() for system in systems: system.positions += rng.random(system.positions.shape) * 0.01 @@ -183,7 +171,8 @@ final_state = ts.optimize( system=systems, model=mace_model, - optimizer=unit_cell_fire, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.unit, convergence_fn=lambda state, last_energy: last_energy - state.energy < 1e-6 * MetalUnits.energy, max_steps=10 if SMOKE_TEST else 1000, @@ -206,7 +195,7 @@ final_state = ts.integrate( system=structure, model=lj_model, - integrator=nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=100 if SMOKE_TEST else 1000, temperature=2000, timestep=0.002, diff --git a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py index 66463d981..b3336a1cc 100644 --- a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py @@ -1,15 +1,8 @@ -"""Examples of using the auto-batching API.""" +"""Examples of using the auto-batching API. Meant to be run as an interactive script.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// - -"""Run as a interactive script.""" -# ruff: noqa: E402 - - # %% import os @@ -23,9 +16,7 @@ InFlightAutoBatcher, calculate_memory_scaler, ) -from torch_sim.integrators import nvt_langevin from torch_sim.models.mace import MaceModel -from torch_sim.optimizers import unit_cell_fire from torch_sim.runners import generate_force_convergence_fn from torch_sim.units import MetalUnits @@ -37,7 +28,7 @@ si_atoms = bulk("Si", "fcc", a=5.43, cubic=True).repeat((3, 3, 3)) fe_atoms = bulk("Fe", "fcc", a=5.43, cubic=True).repeat((3, 3, 3)) - +state: ts.FireState | None = None device = torch.device("cuda") mace = mace_mp(model="small", return_raw_model=True) @@ -51,10 +42,14 @@ si_state = ts.io.atoms_to_state(si_atoms, device=device, dtype=torch.float64) fe_state = ts.io.atoms_to_state(fe_atoms, device=device, dtype=torch.float64) -fire_init, fire_update = unit_cell_fire(mace_model) +state = ts.fire_init(model=mace_model, state=si_state, cell_filter=ts.CellFilter.unit) -si_fire_state = fire_init(si_state) -fe_fire_state = fire_init(fe_state) +si_fire_state = ts.fire_init( + model=mace_model, state=si_state, cell_filter=ts.CellFilter.unit +) +fe_fire_state = ts.fire_init( + model=mace_model, state=fe_state, cell_filter=ts.CellFilter.unit +) fire_states = [si_fire_state, fe_fire_state] * (2 if SMOKE_TEST else 20) fire_states = [state.clone() for state in fire_states] @@ -75,31 +70,36 @@ batcher.load_states(fire_states) all_completed_states, convergence_tensor, state = [], None, None while (result := batcher.next_batch(state, convergence_tensor))[0] is not None: - state, completed_states = result + state, completed_states = result[0], result[1] print(f"Starting new batch of {state.n_systems} states.") all_completed_states.extend(completed_states) - print("Total number of completed states", len(all_completed_states)) + print(f"Total number of completed states {len(all_completed_states)}") for _step in range(10): - state = fire_update(state) + state = ts.fire_step(model=mace_model, state=state) convergence_tensor = converge_max_force(state, last_energy=None) all_completed_states.extend(result[1]) -print("Total number of completed states", len(all_completed_states)) +print(f"Total number of completed states {len(all_completed_states)}") # %% run binning autobatcher -nvt_init, nvt_update = nvt_langevin( - model=mace_model, dt=0.001, kT=300 * MetalUnits.temperature +si_nvt_state = ts.nvt_langevin_init( + model=mace_model, + state=si_state, + dt=torch.tensor(0.001), + kT=torch.tensor(300 * MetalUnits.temperature), +) +fe_nvt_state = ts.nvt_langevin_init( + model=mace_model, + state=fe_state, + dt=torch.tensor(0.001), + kT=torch.tensor(300 * MetalUnits.temperature), ) - si_state = ts.io.atoms_to_state(si_atoms, device=device, dtype=torch.float64) fe_state = ts.io.atoms_to_state(fe_atoms, device=device, dtype=torch.float64) -si_nvt_state = nvt_init(si_state) -fe_nvt_state = nvt_init(fe_state) - nvt_states = [si_nvt_state, fe_nvt_state] * (2 if SMOKE_TEST else 20) nvt_states = [state.clone() for state in nvt_states] for state in nvt_states: @@ -113,9 +113,9 @@ max_memory_scaler=single_system_memory * 2.5 if SMOKE_TEST else None, ) batcher.load_states(nvt_states) -finished_states = [] -for batch in batcher: +finished_states: list[ts.SimState] = [] +for batch, _indices in batcher: for _ in range(100): - batch = nvt_update(batch) + batch = ts.nvt_langevin_update(model=mace_model, state=batch) finished_states.extend(batch.split()) diff --git a/examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py b/examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py index 5c6ff991e..9aae48b09 100644 --- a/examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py +++ b/examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py @@ -1,4 +1,4 @@ -"""Demo of the amorphous-to-crystalline (A2C) algorithm for a-Si, ported to torchsim from +"""Demo of the amorphous-to-crystalline (A2C) algorithm for a-Si, ported to TorchSim from jax-md https://github.com/jax-md/jax-md/blob/main/jax_md/a2c/a2c_workflow.py. """ @@ -9,7 +9,6 @@ # "pymatgen>=2025.2.18", # ] # /// - import os import time from collections import defaultdict @@ -24,11 +23,7 @@ from tqdm import tqdm import torch_sim as ts -from torch_sim.integrators.nvt import ( - NVTNoseHooverState, - nvt_nose_hoover, - nvt_nose_hoover_invariant, -) +from torch_sim.integrators.nvt import NVTNoseHooverState from torch_sim.models.mace import MaceModel, MaceUrls from torch_sim.units import MetalUnits as Units from torch_sim.workflows import a2c @@ -36,7 +31,7 @@ """ # Example of how to use random_packed_structure_multi -from torch_sim.utils.a2c import random_packed_structure_multi +from torch_sim.workflows.a2c import random_packed_structure_multi comp = Composition("Fe80B20") cell = torch.tensor( @@ -54,15 +49,22 @@ SMOKE_TEST = os.getenv("CI") is not None -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 -raw_model = mace_mp(model=MaceUrls.mace_mpa_medium, return_raw_model=True) +raw_model = mace_mp( + model=MaceUrls.mace_mpa_medium, + return_raw_model=True, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), +) # Define system and model comp = Composition("Si64") cell = torch.tensor( - [[11.1, 0.0, 0.0], [0.0, 11.1, 0.0], [0.0, 0.0, 11.1]], dtype=dtype, device=device + [[11.1, 0.0, 0.0], [0.0, 11.1, 0.0], [0.0, 0.0, 11.1]], + dtype=dtype, + device=device, ) atomic_numbers = [Element(el).Z for el in comp.get_el_amt_dict()] * int(comp.num_atoms) @@ -81,7 +83,7 @@ enable_cueq=False, ) # Workflow starts here -structure = a2c.random_packed_structure( +structure, _log = a2c.random_packed_structure( composition=comp, cell=cell, auto_diameter=True, @@ -89,7 +91,6 @@ dtype=dtype, max_iter=100, ) - # Relax structure in batches of 6 batch_size = 1 if SMOKE_TEST else 6 max_optim_steps = ( @@ -102,17 +103,11 @@ final_steps = 25 if SMOKE_TEST else 2500 # MD steps for amorphous phase equilibration T_high = 2000 # Melt temperature T_low = 300 # Quench to this temperature -dt = 0.002 * Units.time # time step = 2fs +dt = torch.tensor(0.002 * Units.time, device=device, dtype=dtype) # time step = 2fs tau = 40 * dt # oscillation period in Nose-Hoover thermostat simulation_steps = equi_steps + cool_steps + final_steps -nvt_nose_hoover_init, nvt_nose_hoover_update = nvt_nose_hoover( - model=model, - kT=T_high * Units.temperature, - dt=dt, -) - state_dict = { "positions": structure.positions, "masses": torch.tensor(atomic_masses, device=device, dtype=dtype), @@ -120,7 +115,13 @@ "pbc": True, "atomic_numbers": atomic_numbers, } -state = nvt_nose_hoover_init(state_dict) +state = ts.nvt_nose_hoover_init( + model=model, + state=state_dict, + kT=torch.tensor(T_high * Units.temperature, device=device, dtype=dtype), + dt=dt, + seed=1, +) logger = { "T": torch.zeros((simulation_steps, 1), device=device, dtype=dtype), @@ -137,10 +138,16 @@ def step_fn( ts.quantities.calc_kT(masses=state.masses, momenta=state.momenta) / Units.temperature ) - logger["H"][step] = nvt_nose_hoover_invariant( - state, kT=current_temp * Units.temperature + logger["H"][step] = ts.nvt_nose_hoover_invariant( + state, + kT=torch.tensor(current_temp * Units.temperature, device=device, dtype=dtype), ).item() - state = nvt_nose_hoover_update(state, kT=current_temp * Units.temperature) + state = ts.nvt_nose_hoover_update( + model=model, + state=state, + dt=dt, + kT=torch.tensor(current_temp * Units.temperature, device=device, dtype=dtype), + ) return state, logger @@ -209,43 +216,43 @@ def step_fn( enable_cueq=False, ) -pymatgen_relaxed_struct_list = [] +pymatgen_relaxed_struct_list: list[tuple[Structure, float, float]] = [] # Process structures in batches of 4 -for i in tqdm(range(0, len(pymatgen_struct_list), batch_size)): - batch_structs = pymatgen_struct_list[i : i + batch_size] +for batch_idx in tqdm(range(0, len(pymatgen_struct_list), batch_size)): + batch_structs = pymatgen_struct_list[batch_idx : batch_idx + batch_size] # Combine structures into a single batched state batch_state = ts.io.structures_to_state(batch_structs, device=device, dtype=dtype) final_state, logger, final_energy, final_pressure = ( - a2c.get_unit_cell_relaxed_structure( + a2c.get_frechet_cell_relaxed_structure( state=batch_state, model=model, max_iter=max_optim_steps, ) ) - final_struct_list = ts.io.state_to_structures(final_state) + final_structs = ts.io.state_to_structures(final_state) # NOTE: Possible OOM, so we don't store the logger # relaxed_structures.append((pymatgen_struct, logger, final_energy, final_pressure)) - for i, final_struct in enumerate(final_struct_list): + for sys_idx, final_struct in enumerate(final_structs): pymatgen_relaxed_struct_list.append( - (final_struct, final_energy[i], final_pressure[i]) + (final_struct, final_energy[sys_idx], final_pressure[sys_idx]) ) lowest_e_struct = sorted( pymatgen_relaxed_struct_list, key=lambda x: x[-2] / x[0].num_sites )[0] spg = SpacegroupAnalyzer(lowest_e_struct[0]) -print("Space group of predicted crystallization product:", spg.get_space_group_symbol()) +print(f"Space group of predicted crystallization product: {spg.get_space_group_symbol()}") spg_counter = defaultdict(int) for struct in pymatgen_relaxed_struct_list: sym_data = MoyoDataset(MoyoAdapter.from_py_obj(struct[0])) sp = (sym_data.number, SpaceGroupType(sym_data.number).arithmetic_symbol) -spg_counter[sp] += 1 + spg_counter[sp] += 1 -print("All space groups encountered:", dict(spg_counter)) +print(f"All space groups encountered: {dict(spg_counter)}") si_diamond = Structure( lattice=[ [0.0, 2.732954, 2.732954], @@ -257,7 +264,7 @@ def step_fn( coords_are_cartesian=False, ) struct_match = StructureMatcher().fit(lowest_e_struct[0], si_diamond) -print("Prediction matches diamond-cubic Si?", struct_match) +print(f"Prediction matches diamond-cubic Si? {struct_match}") end_time = time.perf_counter() print(f"Total time taken to run the workflow: {end_time - start_time:.2f} seconds") diff --git a/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py b/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py index 1003f6363..8bebb4f79 100644 --- a/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py +++ b/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py @@ -1,12 +1,8 @@ """Example script demonstrating batched MACE model optimization with hot-swapping.""" # /// script -# dependencies = [ -# "mace-torch>=0.3.10", -# "matbench-discovery>=1.3.1", -# ] +# dependencies = ["mace-torch>=0.3.10", "matbench-discovery>=1.3.1"] # /// - import os import time @@ -21,7 +17,7 @@ # --- Setup and Configuration --- # Device and data type configuration SMOKE_TEST = os.getenv("CI") is not None -device = torch.device("cpu") if SMOKE_TEST else torch.device("cuda") +device = torch.device("cpu" if SMOKE_TEST else "cuda") dtype = torch.float32 print(f"job will run on {device=}") @@ -64,9 +60,10 @@ # Statistics tracking # Initialize first batch -fire_init, fire_update = ts.optimizers.frechet_cell_fire(model=mace_model) -fire_states = fire_init( - ts.io.atoms_to_state(atoms=ase_atoms_list, device=device, dtype=dtype) +fire_states = ts.fire_init( + model=mace_model, + state=ts.io.atoms_to_state(atoms=ase_atoms_list, device=device, dtype=dtype), + cell_filter=ts.CellFilter.frechet, ) batcher = ts.autobatching.InFlightAutoBatcher( @@ -86,13 +83,13 @@ print(f"Starting new batch of {state.n_systems} states.") all_completed_states.extend(completed_states) - print("Total number of completed states", len(all_completed_states)) + print(f"Total number of completed states {len(all_completed_states)}") for _step in range(10): - state = fire_update(state) + state = ts.fire_step(model=mace_model, state=state) convergence_tensor = converge_max_force(state, last_energy=None) all_completed_states.extend(result[1]) -print("Total number of completed states", len(all_completed_states)) +print(f"Total number of completed states {len(all_completed_states)}") # --- Final Statistics --- end_time = time.perf_counter() diff --git a/examples/scripts/5_Workflow/5.3_Elastic.py b/examples/scripts/5_Workflow/5.3_Elastic.py index c49ec63b0..891439703 100644 --- a/examples/scripts/5_Workflow/5.3_Elastic.py +++ b/examples/scripts/5_Workflow/5.3_Elastic.py @@ -1,12 +1,8 @@ """Bulk and Shear modulus with MACE.""" # /// script -# dependencies = [ -# "ase>=3.24", -# "mace-torch>=0.3.12", -# ] +# dependencies = ["ase>=3.26", "mace-torch>=0.3.12"] # /// - import torch from ase.build import bulk from mace.calculators.foundations_models import mace_mp @@ -18,13 +14,14 @@ # Calculator unit_conv = ts.units.UnitConversion -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float64 + loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, enable_cueq=False, - device=device, - default_dtype="float64", + device=str(device), + default_dtype=str(dtype).lstrip("torch."), return_raw_model=True, ) @@ -43,10 +40,10 @@ fmax = 1e-3 # Relax positions and cell -fire_init, fire_update = ts.optimizers.frechet_cell_fire(model=model, scalar_pressure=0.0) - state = ts.io.atoms_to_state(atoms=struct, device=device, dtype=dtype) -state = fire_init(state=state) +state = ts.fire_init( + model=model, state=state, scalar_pressure=0.0, cell_filter=ts.CellFilter.frechet +) for step in range(300): pressure = -torch.trace(state.stress.squeeze()) / 3 * unit_conv.eV_per_Ang3_to_GPa @@ -58,7 +55,7 @@ ) if current_fmax < fmax and abs(pressure) < 1e-2: break - state = fire_update(state=state) + state = ts.fire_step(model=model, state=state) # Get bravais type bravais_type = get_bravais_type(state) diff --git a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py index 3968acdab..19e60d25e 100644 --- a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py +++ b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py @@ -10,7 +10,6 @@ # "ase", # ] # /// - import numpy as np import pymatviz as pmv import seekpath @@ -87,22 +86,22 @@ def get_labels_qpts(ph: Phonopy, n_points: int = 101) -> tuple[list[str], list[b # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 # Load the raw model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Structure and input parameters struct = bulk("Si", "diamond", a=5.431, cubic=True) # ASE structure supercell_matrix = 2 * np.eye(3) # supercell matrix for phonon calculation mesh = [20, 20, 20] # Phonon mesh -Nrelax = 300 # number of relaxation steps +max_steps = 300 # number of relaxation steps displ = 0.01 # atomic displacement for phonons (in Angstrom) # Relax atomic positions @@ -117,10 +116,11 @@ def get_labels_qpts(ph: Phonopy, n_points: int = 101) -> tuple[list[str], list[b final_state = ts.optimize( system=struct, model=model, - optimizer=ts.optimizers.frechet_cell_fire, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.frechet, + max_steps=max_steps, constant_volume=True, hydrostatic_strain=True, - max_steps=Nrelax, ) # Define atoms and Phonopy object @@ -130,6 +130,8 @@ def get_labels_qpts(ph: Phonopy, n_points: int = 101) -> tuple[list[str], list[b # Generate FC2 displacements ph.generate_displacements(distance=displ) supercells = ph.supercells_with_displacements +if supercells is None: + raise ValueError("supercells cannot be None") # Convert PhonopyAtoms to state state = ts.io.phonopy_to_state(supercells, device, dtype) @@ -160,7 +162,7 @@ def get_labels_qpts(ph: Phonopy, n_points: int = 101) -> tuple[list[str], list[b pbc=True, ) q_pts, connections = get_qpts_and_connections(ase_atoms) -ph.run_band_structure(q_pts, connections) +ph.run_band_structure(q_pts, path_connections=connections) # Define axis style for plots axis_style = dict( diff --git a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py index 4b4edea48..1031245fc 100644 --- a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py +++ b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py @@ -10,7 +10,6 @@ # "plotly!=6.2.0", # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635 # ] # /// - import os import numpy as np @@ -31,17 +30,17 @@ def get_relaxed_structure( struct: Atoms, model: ModelInterface, - Nrelax: int = 300, + max_steps: int = 300, fmax: float = 1e-3, *, use_autobatcher: bool = False, -) -> ts.state.SimState: +) -> ts.SimState: """Get relaxed structure. Args: struct: ASE structure model: MACE model - Nrelax: Maximum number of relaxation steps + max_steps: Maximum number of relaxation steps fmax: Force convergence criterion use_autobatcher: Whether to use automatic batching @@ -63,23 +62,23 @@ def get_relaxed_structure( final_state = ts.optimize( system=struct, model=model, - optimizer=ts.optimizers.frechet_cell_fire, - constant_volume=True, - hydrostatic_strain=True, - max_steps=Nrelax, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.frechet, + max_steps=max_steps, convergence_fn=converge_max_force, trajectory_reporter=reporter, autobatcher=use_autobatcher, + constant_volume=True, + hydrostatic_strain=True, ) - # Remove trajectory file os.remove(trajectory_file) return final_state def get_qha_structures( - state: ts.state.SimState, + state: ts.SimState, length_factors: np.ndarray, model: ModelInterface, Nmax: int = 300, @@ -117,12 +116,13 @@ def get_qha_structures( scaled_state = ts.optimize( system=scaled_structs, model=model, - optimizer=ts.optimizers.frechet_cell_fire, - constant_volume=True, - hydrostatic_strain=True, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.frechet, max_steps=Nmax, convergence_fn=ts.runners.generate_force_convergence_fn(force_tol=fmax), autobatcher=use_autobatcher, + constant_volume=True, + hydrostatic_strain=True, ) return scaled_state.to_phonopy() @@ -165,9 +165,9 @@ def get_qha_phonons( ) ph.generate_displacements(distance=displ) supercells = ph.supercells_with_displacements - n_atoms = sum(len(cell) for cell in supercells) + n_atoms = 0 if supercells is None else sum(len(cell) for cell in supercells) supercell_boundaries.append(supercell_boundaries[-1] + n_atoms) - supercells_flat.extend(supercells) + supercells_flat.extend([] if supercells is None else supercells) ph_sets.append(ph) # Run the model on flattened structure @@ -194,15 +194,15 @@ def get_qha_phonons( energies = ( torch.tensor([r["potential_energy"] for r in results]).detach().cpu().numpy() ) - for i, ph in enumerate(ph_sets): - start, end = supercell_boundaries[i], supercell_boundaries[i + 1] + for sys_idx, ph in enumerate(ph_sets): + start, end = supercell_boundaries[sys_idx], supercell_boundaries[sys_idx + 1] forces_i = forces[start:end] n_atoms = len(ph.supercell) n_displacements = len(ph.supercells_with_displacements) force_sets_i = [] - for j in range(n_displacements): - start_j = j * n_atoms - end_j = (j + 1) * n_atoms + for disp_idx in range(n_displacements): + start_j = disp_idx * n_atoms + end_j = (disp_idx + 1) * n_atoms force_sets_i.append(forces_i[start_j:end_j]) force_sets.append(force_sets_i) @@ -210,16 +210,17 @@ def get_qha_phonons( # Set device and data type -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float64 + autobatcher = False # Load the raw model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) model = MaceModel( model=loaded_model, @@ -244,7 +245,7 @@ def get_qha_phonons( # Relax initial structure state = get_relaxed_structure( - struct=struct, model=model, Nrelax=Nmax, fmax=fmax, use_autobatcher=autobatcher + struct=struct, model=model, max_steps=Nmax, fmax=fmax, use_autobatcher=autobatcher ) # Get relaxed structures at different volumes @@ -272,7 +273,7 @@ def get_qha_phonons( free_energies = [] entropies = [] heat_capacities = [] -n_displacements = len(ph_sets[0].supercells_with_displacements) +n_displacements = len(getattr(ph_sets[0], "supercells_with_displacements", [])) for i in range(len(ph_sets)): ph_sets[i].forces = force_sets[i] ph_sets[i].produce_force_constants() diff --git a/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py b/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py index ac52a574e..066a06add 100644 --- a/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py +++ b/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py @@ -9,22 +9,26 @@ # "pymatgen>=2025.2.18", # ] # /// - import os import time +from typing import TYPE_CHECKING, Literal, cast import numpy as np import plotly.graph_objects as go import torch -import tqdm from ase.build import bulk from mace.calculators.foundations_models import mace_mp from phono3py import Phono3py +from tqdm import tqdm import torch_sim as ts from torch_sim.models.mace import MaceModel, MaceUrls +if TYPE_CHECKING: + from phonopy.structure.atoms import PhonopyAtoms + + def print_relax_info(trajectory_file: str, device: torch.device) -> None: """Print relaxation information from trajectory file. @@ -47,15 +51,15 @@ def print_relax_info(trajectory_file: str, device: torch.device) -> None: start_time = time.perf_counter() -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float64 # Load the raw model from URL loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) model = MaceModel( model=loaded_model, @@ -72,10 +76,10 @@ def print_relax_info(trajectory_file: str, device: torch.device) -> None: # supercell matrix for phonon calculation (use larger cell for better accuracy) supercell_matrix = [1, 1, 1] supercell_matrix_fc2 = [2, 2, 2] # supercell matrix for FC2 calculation -Nrelax = 300 # number of relaxation steps +max_steps = 300 # number of relaxation steps fmax = 1e-3 # force convergence displ = 0.05 # atomic displacement for phonons (in Angstrom) -conductivity_type = "wigner" # "wigner", "kubo" +conductivity_type: Literal["wigner", "kubo"] = "wigner" temperatures = np.arange( 0, 1600, 10 ) # temperature range for thermal conductivity calculation @@ -96,12 +100,12 @@ def print_relax_info(trajectory_file: str, device: torch.device) -> None: final_state = ts.optimize( system=struct, model=model, - optimizer=ts.optimizers.frechet_cell_fire, - constant_volume=True, - hydrostatic_strain=True, - max_steps=Nrelax, + optimizer=ts.OptimFlavor.fire, + max_steps=max_steps, convergence_fn=converge_max_force, trajectory_reporter=reporter, + constant_volume=True, + hydrostatic_strain=True, ) print_relax_info(trajectory_file, device) @@ -117,12 +121,12 @@ def print_relax_info(trajectory_file: str, device: torch.device) -> None: # Calculate FC2 ph3.generate_fc2_displacements(distance=displ) supercells_fc2 = ph3.phonon_supercells_with_displacements -state = ts.io.phonopy_to_state(supercells_fc2, device, dtype) +state = ts.io.phonopy_to_state(supercells_fc2, device=device, dtype=dtype) results = model(state) n_atoms_per_supercell = [len(sc) for sc in supercells_fc2] force_sets = [] start_idx = 0 -for n_atoms in tqdm.tqdm(n_atoms_per_supercell, desc="FC2"): +for n_atoms in tqdm(n_atoms_per_supercell, desc="FC2"): end_idx = start_idx + n_atoms force_sets.append(results["forces"][start_idx:end_idx].detach().cpu().numpy()) start_idx = end_idx @@ -131,13 +135,13 @@ def print_relax_info(trajectory_file: str, device: torch.device) -> None: # Calculate FC3 ph3.generate_displacements(distance=displ) -supercells_fc3 = ph3.supercells_with_displacements -state = ts.io.phonopy_to_state(supercells_fc3, device, dtype) +supercells_fc3 = cast("list[PhonopyAtoms]", ph3.supercells_with_displacements) +state = ts.io.phonopy_to_state(supercells_fc3, device=device, dtype=dtype) results = model(state) n_atoms_per_supercell = [len(sc) for sc in supercells_fc3] force_sets = [] start_idx = 0 -for n_atoms in tqdm.tqdm(n_atoms_per_supercell, desc="FC3"): +for n_atoms in tqdm(n_atoms_per_supercell, desc="FC3"): end_idx = start_idx + n_atoms force_sets.append(results["forces"][start_idx:end_idx].detach().cpu().numpy()) start_idx = end_idx diff --git a/examples/scripts/7_Others/7.1_Soft_sphere_autograd.py b/examples/scripts/7_Others/7.1_Soft_sphere_autograd.py index 4a71bcdfd..beba61504 100644 --- a/examples/scripts/7_Others/7.1_Soft_sphere_autograd.py +++ b/examples/scripts/7_Others/7.1_Soft_sphere_autograd.py @@ -7,22 +7,19 @@ # "scipy>=1.15", # ] # /// - # ruff: noqa: RUF001 - -import numpy as np import torch from plotly.subplots import make_subplots from torch_sim.models.soft_sphere import soft_sphere_pair, soft_sphere_pair_force -sigma = 1.0 -epsilon = 1.0 -alpha = 2 +sigma = torch.tensor(1.0) +epsilon = torch.tensor(1.0) +alpha = torch.tensor(2) # Generate distance values from 0.1*sigma to 2*sigma -dr = np.linspace(0.1 * sigma, 2 * sigma, 1000) +dr = torch.linspace(0.1 * sigma, 2 * sigma, 1000) dr_tensor = torch.sqrt(torch.tensor(dr)) # Calculate potential energy diff --git a/examples/scripts/7_Others/7.2_Stress_autograd.py b/examples/scripts/7_Others/7.2_Stress_autograd.py index faac1b8e5..97c581209 100644 --- a/examples/scripts/7_Others/7.2_Stress_autograd.py +++ b/examples/scripts/7_Others/7.2_Stress_autograd.py @@ -8,11 +8,8 @@ """ # /// script -# dependencies = [ -# "scipy>=1.15", -# ] +# dependencies = ["scipy>=1.15"] # /// - import timeit import torch @@ -23,9 +20,9 @@ torch.set_default_tensor_type(torch.DoubleTensor) # Set simulation parameters n_steps = 10_000 -kT = 0.722 # Temperature in energy units -sigma = 1.0 # Length parameter -epsilon = 1.0 # Energy parameter +kT = torch.tensor(0.722) # Temperature in energy units +sigma = torch.tensor(1.0) # Length parameter +epsilon = torch.tensor(1.0) # Energy parameter # Grid initialization Nx = 10 @@ -93,7 +90,7 @@ def force_fn(R: torch.Tensor, box: torch.Tensor) -> torch.Tensor: return force_components.sum(dim=0) -def stress_fn(R: torch.Tensor, box: torch.Tensor) -> torch.Tensor: +def stress_fn(R: torch.Tensor, box: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Calculate stress using a brute force method.""" # Create displacement vectors for all pairs ri = R.unsqueeze(0) diff --git a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py index 2b845c074..489b9dd5a 100644 --- a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py +++ b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py @@ -1,12 +1,8 @@ """Batched neighbor list.""" # /// script -# dependencies = [ -# "ase>=3.24", -# "scipy>=1.15", -# ] +# dependencies = ["ase>=3.26", "scipy>=1.15"] # /// - import torch from ase.build import bulk @@ -16,7 +12,7 @@ atoms_list = [bulk("Si", "diamond", a=5.43), bulk("Ge", "diamond", a=5.65)] -state = ts.io.atoms_to_state(atoms_list, device="cpu", dtype=torch.float32) +state = ts.io.atoms_to_state(atoms_list, device=torch.device("cpu"), dtype=torch.float32) pos, cell, pbc = state.positions, state.cell, state.pbc system_idx, n_atoms = state.system_idx, state.n_atoms cutoff = 4.0 diff --git a/examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py b/examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py index faaa9be5d..50bb17a70 100644 --- a/examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py +++ b/examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py @@ -2,12 +2,11 @@ # /// script # dependencies = [ -# "ase>=3.24", +# "ase>=3.26", # "matplotlib", # "numpy", # ] # /// - from typing import Any import matplotlib.pyplot as plt @@ -72,9 +71,8 @@ def plot_results(*, time: np.ndarray, vacf: np.ndarray, window_count: int) -> No def main() -> None: """Run velocity autocorrelation simulation using Lennard-Jones model.""" - state, lj_model, dt, kT, device, dtype, timestep = prepare_system() - nve_init, nve_update = ts.integrators.nve(model=lj_model, dt=dt, kT=kT) - state = nve_init(state) # type: ignore[call-arg] + state, lj_model, dt, kT, device, _dtype, timestep = prepare_system() + state = ts.nve_init(model=lj_model, state=state, kT=kT) window_size = 150 # Length of correlation: dt * correlation_dt * window_size vacf_calc = VelocityAutoCorrelation( @@ -95,7 +93,7 @@ def main() -> None: num_steps = 15000 # NOTE: short run for step in range(num_steps): - state = nve_update(state) # type: ignore[call-arg] + state = ts.nve_update(model=lj_model, state=state, dt=dt) # type: ignore[call-arg] reporter.report(state, step) reporter.close() diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index 1eed7d9f5..819ebb68f 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -3,15 +3,12 @@ """ # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# "plotly>=6.0.0", -# ] +# dependencies = ["mace-torch>=0.3.12", "plotly>=6.0.0"] # /// - import os import time -from typing import Literal +from functools import partial +from typing import Literal, TypedDict import numpy as np import plotly.graph_objects as go @@ -25,13 +22,13 @@ import torch_sim as ts from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.optimizers import GDState, fire, frechet_cell_fire +from torch_sim.optimizers import OptimState from torch_sim.state import SimState # Set device, data type and unit conversion SMOKE_TEST = os.getenv("CI") is not None -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 unit_conv = ts.units.UnitConversion @@ -39,8 +36,8 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, - device=device, + default_dtype=str(dtype).lstrip("torch."), + device=str(device), ) # Number of steps to run @@ -127,20 +124,20 @@ def run_optimization_ts( # noqa: PLR0915 *, - initial_state: SimState, + initial_state: SimState | OptimState, ts_md_flavor: Literal["vv_fire", "ase_fire"], ts_use_frechet: bool, force_tol: float, max_iterations_ts: int, -) -> tuple[torch.Tensor, SimState | None]: - """Runs Torch-Sim optimization and returns convergence steps and final state.""" +) -> tuple[torch.Tensor, OptimState | None]: + """Runs torch-sim optimization and returns convergence steps and final state.""" print( - f"\n--- Running Torch-Sim optimization: flavor={ts_md_flavor}, " + f"\n--- Running torch-sim optimization: flavor={ts_md_flavor}, " f"frechet_cell_opt={ts_use_frechet}, force_tol={force_tol} ---" ) start_time = time.perf_counter() - print("Initial cell parameters (Torch-Sim):") + print("Initial cell parameters (torch-sim):") for k_idx in range(initial_state.n_systems): cell_tensor_k = initial_state.cell[k_idx].cpu().numpy() ase_cell_k = Cell(cell_tensor_k) @@ -151,20 +148,18 @@ def run_optimization_ts( # noqa: PLR0915 ) if ts_use_frechet: - init_fn_opt, update_fn_opt = frechet_cell_fire( - model=model, md_flavor=ts_md_flavor - ) + init_fn_opt = partial(ts.fire_init, cell_filter=ts.CellFilter.frechet) + step_fn_opt = ts.fire_step else: - init_fn_opt, update_fn_opt = fire(model=model, md_flavor=ts_md_flavor) + init_fn_opt, step_fn_opt = ts.fire_init, ts.fire_step - opt_state = init_fn_opt(initial_state.clone()) + opt_state = init_fn_opt(model=model, state=initial_state.clone()) batcher = ts.InFlightAutoBatcher( model=model, memory_scales_with="n_atoms", max_memory_scaler=1000, max_iterations=max_iterations_ts, - return_indices=True, ) batcher.load_states(opt_state) @@ -184,8 +179,9 @@ def run_optimization_ts( # noqa: PLR0915 last_active_state = opt_state while True: - result = batcher.next_batch(last_active_state, convergence_tensor_for_batcher) - opt_state, converged_states_from_batcher, current_indices_list = result + opt_state, converged_states_from_batcher = batcher.next_batch( + last_active_state, convergence_tensor_for_batcher + ) all_converged_states.extend(converged_states_from_batcher) if opt_state is None: @@ -194,12 +190,12 @@ def run_optimization_ts( # noqa: PLR0915 last_active_state = opt_state current_indices = torch.tensor( - current_indices_list, dtype=torch.long, device=device + batcher.current_idx, dtype=torch.long, device=device ) steps_this_round = 1 for _ in range(steps_this_round): - opt_state = update_fn_opt(opt_state) + opt_state = step_fn_opt(model=model, state=opt_state) global_step += steps_this_round convergence_tensor_for_batcher = convergence_fn(opt_state, None) @@ -229,7 +225,7 @@ def run_optimization_ts( # noqa: PLR0915 final_state_concatenated = ts.concatenate_states(final_states_list) if final_state_concatenated is not None and hasattr(final_state_concatenated, "cell"): - print("Final cell parameters (Torch-Sim):") + print("Final cell parameters (torch-sim):") for k_idx in range(final_state_concatenated.n_systems): cell_tensor_k = final_state_concatenated.cell[k_idx].cpu().numpy() ase_cell_k = Cell(cell_tensor_k) @@ -240,13 +236,13 @@ def run_optimization_ts( # noqa: PLR0915 ) else: print( - "Final cell parameters (Torch-Sim): Not available (final_state_concatenated " + "Final cell parameters (torch-sim): Not available (final_state_concatenated " "is None or has no cell)." ) end_time = time.perf_counter() print( - f"Finished Torch-Sim ({ts_md_flavor}, frechet={ts_use_frechet}) in " + f"Finished torch-sim ({ts_md_flavor}, frechet={ts_use_frechet}) in " f"{end_time - start_time:.2f} seconds." ) return convergence_steps, final_state_concatenated @@ -258,7 +254,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 ase_use_frechet_filter: bool, force_tol: float, max_steps_ase: int, -) -> tuple[torch.Tensor, GDState | None]: +) -> tuple[torch.Tensor, OptimState | None]: """Runs ASE optimization and returns convergence steps and final state.""" print( f"\n--- Running ASE optimization: frechet_filter={ase_use_frechet_filter}, " @@ -271,14 +267,14 @@ def run_optimization_ase( # noqa: C901, PLR0915 final_ase_atoms_list = [] convergence_steps_list = [] - for i, single_sim_state in enumerate(individual_initial_states): - print(f"Optimizing structure {i + 1}/{num_structures} with ASE...") + for sys_idx, single_sim_state in enumerate(individual_initial_states): + print(f"Optimizing structure {sys_idx + 1}/{num_structures} with ASE...") ase_atoms_orig = ts.io.state_to_atoms(single_sim_state)[0] initial_cell_ase = ase_atoms_orig.get_cell() initial_params_str = ", ".join([f"{p:.2f}" for p in initial_cell_ase.cellpar()]) print( - f" Initial cell (ASE Structure {i + 1}): " + f" Initial cell (ASE Structure {sys_idx + 1}): " f"Volume={initial_cell_ase.volume:.2f} Γ…Β³, " f"Params=[{initial_params_str}]" ) @@ -292,7 +288,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 optim_target_atoms = ase_atoms_orig if ase_use_frechet_filter: - print(f"Applying FrechetCellFilter to structure {i + 1}") + print(f"Applying FrechetCellFilter to structure {sys_idx + 1}") optim_target_atoms = FrechetCellFilter(ase_atoms_orig) dyn = ASEFIRE(optim_target_atoms, trajectory=None, logfile=None) @@ -301,40 +297,38 @@ def run_optimization_ase( # noqa: C901, PLR0915 dyn.run(fmax=force_tol, steps=max_steps_ase) if dyn.converged(): convergence_steps_list.append(dyn.nsteps) - print(f"ASE structure {i + 1} converged in {dyn.nsteps} steps.") + print(f"ASE structure {sys_idx + 1} converged in {dyn.nsteps} steps.") else: print( - f"ASE optimization for structure {i + 1} did not converge within " - f"{max_steps_ase} steps. Steps taken: {dyn.nsteps}." + f"ASE optimization for structure {sys_idx + 1} did not converge " + f"within {max_steps_ase} steps. Steps taken: {dyn.nsteps}." ) convergence_steps_list.append(-1) - except Exception as e: # noqa: BLE001 - print(f"ASE optimization failed for structure {i + 1}: {e}") + except Exception as exc: # noqa: BLE001 + print(f"ASE optimization failed for structure {sys_idx + 1}: {exc}") convergence_steps_list.append(-1) - final_ats_for_print = ( - optim_target_atoms.atoms if ase_use_frechet_filter else ase_atoms_orig - ) + final_ats_for_print = getattr(optim_target_atoms, "atoms", ase_atoms_orig) final_cell_ase = final_ats_for_print.get_cell() final_params_str = ", ".join([f"{p:.2f}" for p in final_cell_ase.cellpar()]) print( - f" Final cell (ASE Structure {i + 1}): " + f" Final cell (ASE Structure {sys_idx + 1}): " f"Volume={final_cell_ase.volume:.2f} Γ…Β³, " f"Params=[{final_params_str}]" ) final_ase_atoms_list.append(final_ats_for_print) - all_positions = [] - all_masses = [] - all_atomic_numbers = [] - all_cells = [] - all_systems_for_gd = [] - final_energies_ase = [] - final_forces_ase_tensors = [] + all_positions: list[torch.Tensor] = [] + all_masses: list[torch.Tensor] = [] + all_atomic_numbers: list[torch.Tensor] = [] + all_cells: list[torch.Tensor] = [] + all_systems_for_gd: list[torch.Tensor] = [] + final_energies_ase: list[float] = [] + final_forces_ase_tensors: list[torch.Tensor] = [] current_atom_offset = 0 - for system_idx, ats_final in enumerate(final_ase_atoms_list): + for sys_idx, ats_final in enumerate(final_ase_atoms_list): all_positions.append( torch.tensor(ats_final.get_positions(), device=device, dtype=dtype) ) @@ -351,17 +345,15 @@ def run_optimization_ase( # noqa: C901, PLR0915 num_atoms_in_current = len(ats_final) all_systems_for_gd.append( - torch.full( - (num_atoms_in_current,), system_idx, device=device, dtype=torch.long - ) + torch.full((num_atoms_in_current,), sys_idx, device=device, dtype=torch.long) ) current_atom_offset += num_atoms_in_current try: if ats_final.calc is None: print( - "Re-attaching ASE calculator for final energy/forces for " - f"structure {system_idx}." + "Re-attaching ASE calculator for final energy/forces for structure " + f"{sys_idx}." ) temp_calc = mace_mp_calculator_for_ase( model=MaceUrls.mace_mpa_medium, @@ -373,10 +365,8 @@ def run_optimization_ase( # noqa: C901, PLR0915 final_forces_ase_tensors.append( torch.tensor(ats_final.get_forces(), device=device, dtype=dtype) ) - except Exception as e: # noqa: BLE001 - print( - f"Couldn't get final energy/forces for an ASE structure {system_idx}: {e}" - ) + except Exception as exc: # noqa: BLE001 + print(f"Couldn't get final energy/forces for ASE structure {sys_idx}: {exc}") final_energies_ase.append(float("nan")) if all_positions and len(all_positions[-1]) > 0: final_forces_ase_tensors.append(torch.zeros_like(all_positions[-1])) @@ -386,7 +376,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 ) if not all_positions: # If all optimizations failed early - print("Warning: No successful ASE structures to form GDState.") + print("Warning: No successful ASE structures to form OptimState.") return torch.tensor(convergence_steps_list, dtype=torch.long, device=device), None # Concatenate all parts @@ -403,11 +393,11 @@ def run_optimization_ase( # noqa: C901, PLR0915 if torch.isnan(concatenated_energies).any(): print( "Warning: NaN values found in final ASE energies. " - "GDState energy tensor will contain NaNs." + "OptimState energy tensor will contain NaNs." ) - # Create GDState instance - final_state_as_gd = GDState( + # Create OptimState instance + final_state_as_gd = OptimState( positions=concatenated_positions, masses=concatenated_masses, cell=concatenated_cells, @@ -437,25 +427,25 @@ def run_optimization_ase( # noqa: C901, PLR0915 configs_to_run = [ { "name": "torch-sim VV-FIRE (PosOnly)", - "type": "torch_sim", + "type": "torch-sim", "ts_md_flavor": "vv_fire", "ts_use_frechet": False, }, { "name": "torch-sim ASE-FIRE (PosOnly)", - "type": "torch_sim", + "type": "torch-sim", "ts_md_flavor": "ase_fire", "ts_use_frechet": False, }, { "name": "torch-sim VV-FIRE (Frechet Cell)", - "type": "torch_sim", + "type": "torch-sim", "ts_md_flavor": "vv_fire", "ts_use_frechet": True, }, { "name": "torch-sim ASE-FIRE (Frechet Cell)", - "type": "torch_sim", + "type": "torch-sim", "ts_md_flavor": "ase_fire", "ts_use_frechet": True, }, @@ -471,7 +461,15 @@ def run_optimization_ase( # noqa: C901, PLR0915 }, ] -all_results = {} + +class ResultData(TypedDict): + """Result data for a single optimization run.""" + + steps: torch.Tensor + final_state: OptimState | None + + +all_results: dict[str, ResultData] = {} for config_run in configs_to_run: print(f"\n\nStarting configuration: {config_run['name']}") @@ -481,11 +479,11 @@ def run_optimization_ase( # noqa: C901, PLR0915 ase_use_frechet_filter_val = config_run.get("ase_use_frechet_filter", False) steps: torch.Tensor | None = None - final_state_opt: SimState | GDState | None = None + final_state_opt: OptimState | None = None - if optimizer_type_val == "torch_sim": + if optimizer_type_val == "torch-sim": if ts_md_flavor_val is None: - raise ValueError(f"{ts_md_flavor_val=} must be provided for torch_sim") + raise ValueError(f"{ts_md_flavor_val=} must be provided for torch-sim") steps, final_state_opt = run_optimization_ts( initial_state=state.clone(), ts_md_flavor=ts_md_flavor_val, @@ -557,7 +555,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 mean_displacements = [] for s1, s2 in zip(state1_list, state2_list, strict=True): - if s1.n_atoms == 0 or s2.n_atoms == 0: + if 0 in {s1.n_atoms, s2.n_atoms}: mean_displacements.append(float("nan")) continue pos1_centered = s1.positions - s1.positions.mean(dim=0, keepdim=True) @@ -736,11 +734,11 @@ def run_optimization_ase( # noqa: C901, PLR0915 "Setting energy diff to NaN." ) avg_energy_diffs_fig2.append(np.nan) - elif not processed_current_name and name not in [ + elif not processed_current_name and name not in ( n for n, v in zip(plot_names_fig2, avg_energy_diffs_fig2, strict=False) if not np.isnan(v) - ]: + ): print(f"Plot2: Fallback for {name}, setting energy diff to NaN.") avg_energy_diffs_fig2.append(np.nan) @@ -898,7 +896,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 fig3_plotly.add_bar(name=name, x=structure_names, y=disp_data_fig3[:, idx]) -title = "Mean Displacement of Torch-Sim Methods to ASE Counterparts (per Structure)" +title = "Mean Displacement of torch-sim Methods to ASE Counterparts (per Structure)" fig3_plotly.update_layout( barmode="group", title=dict(text=title, x=0.5, y=1), diff --git a/examples/scripts/7_Others/7.7_Heat_flux_and_kappa.py b/examples/scripts/7_Others/7.7_Heat_flux_and_kappa.py new file mode 100644 index 000000000..bd2d0ff53 --- /dev/null +++ b/examples/scripts/7_Others/7.7_Heat_flux_and_kappa.py @@ -0,0 +1,151 @@ +"""Heat flux and thermal conductivity example with Lennard-Jones potential.""" + +# /// script +# dependencies = [ +# "ase>=3.26", +# "matplotlib", +# "numpy", +# ] +# /// +import os + +import matplotlib.pyplot as plt +import numpy as np +import torch +from ase.build import bulk + +import torch_sim as ts +from torch_sim.elastic import full_3x3_to_voigt_6_stress +from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.properties.correlations import HeatFluxAutoCorrelation +from torch_sim.units import MetalUnits as Units + + +SMOKE_TEST = os.getenv("CI") is not None + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +dtype = torch.float64 + +# Using solid Ar w/ LJ for ease +atoms = bulk("Ar", crystalstructure="fcc", a=5.376, cubic=True) +N_repeats = 3 if SMOKE_TEST else 4 +atoms = atoms.repeat((N_repeats, N_repeats, N_repeats)) +state = ts.io.atoms_to_state(atoms, device=device, dtype=dtype) + +# Simulation parameters +# See https://docs.lammps.org/compute_heat_flux.html for more details +epsilon = 0.0104 # eV +sigma = 3.405 # Γ… +cutoff = 13 # Γ… +temperature = 70.0 # Kelvin +timestep = 0.004 # ps (4 fs) +num_steps_equilibration = 1000 if SMOKE_TEST else 8000 +num_steps_production = 2000 if SMOKE_TEST else 100000 +window_size = 200 # Length of correlation: dt * correlation_dt * window_size +correlation_dt = 10 # Step delta between correlations + +# Lennard-Jones model +lj_model = LennardJonesModel( + sigma=sigma, + epsilon=epsilon, + cutoff=cutoff, + device=device, + dtype=dtype, + compute_forces=True, + compute_stress=True, + per_atom_energies=True, + per_atom_stresses=True, +) + +dt = torch.tensor(timestep * Units.time, device=device, dtype=dtype) +kT = torch.tensor(temperature * Units.temperature, device=device, dtype=dtype) +state = ts.nvt_langevin_init(model=lj_model, state=state, kT=kT) + +# Short equilibration run +# Shape: (num_steps, batch, dim) +heat_flux = torch.zeros((num_steps_equilibration, 3), device=device, dtype=dtype) + +for step in range(num_steps_equilibration): + state = ts.nvt_langevin_update(model=lj_model, state=state, dt=dt, kT=kT) + results = lj_model(state) + J = ts.quantities.calc_heat_flux( + momenta=state.momenta, + masses=state.masses, + velocities=None, + energies=results["energies"], + stresses=full_3x3_to_voigt_6_stress(results["stresses"]), + batch=state.system_idx, + is_centroid_stress=False, + is_virial_only=False, + ) + heat_flux[step] = J + if step % 1000 == 0: + print(f"Step {step} | {state.energy.item():.4f} eV") + +state = ts.nvt_langevin_init(model=lj_model, state=state, kT=kT) + +hfacf_calc = HeatFluxAutoCorrelation( + model=lj_model, + window_size=window_size, + device=device, + use_running_average=True, + normalize=False, +) + +# Sampling freq is controlled by prop_calculators +# trajectory = "kappa_example.h5" + +reporter = ts.TrajectoryReporter( + None, # add trajectory name here if you want to save the trajectory to disk + state_frequency=100, + prop_calculators={correlation_dt: {"hfacf": hfacf_calc}}, +) + +# Short production run +for step in range(num_steps_production): + state = ts.nvt_langevin_update(model=lj_model, state=state, dt=dt, kT=kT) + reporter.report(state, step) + if step % 1000 == 0: + print(f"Step {step} | {state.energy.item():.4f} eV") + +reporter.close() + +# HFACF results and plot +# Timesteps -> Time in fs +time_steps = np.arange(window_size) +time_fs = time_steps * correlation_dt * timestep * 1000 +hface_numpy = hfacf_calc.hfacf.detach().cpu().numpy() + +# Calculate kappa +integral = np.trapezoid(hface_numpy) +constant = ( + state.volume.item() + / (3 * temperature * temperature * Units.temperature) + * timestep + * correlation_dt +) +kappa = constant * integral +print(f"kappa: {kappa:.8f} (eV/ps/Ang^2/K)") + +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) + +ax1.plot(heat_flux[:, 0].detach().cpu().numpy(), "b-", linewidth=2, label=r"$J_x$") +ax1.plot(heat_flux[:, 1].detach().cpu().numpy(), "r-", linewidth=2, label=r"$J_y$") +ax1.plot(heat_flux[:, 2].detach().cpu().numpy(), "g-", linewidth=2, label=r"$J_z$") +ax1.set_xlabel("Time (fs)", fontsize=12) +ax1.set_ylabel(r"$J$ (eV/ps $\AA^2$)", fontsize=12) +ax1.set_title("Heat Flux for Ar (LJ)", fontsize=14) +ax1.axhline(y=0, color="k", linestyle="--", alpha=0.3) +ax1.legend(fontsize=12) + +ax2.plot(time_fs, hface_numpy, "b-", linewidth=2) +ax2.set_xlabel("Time (fs)", fontsize=12) +ax2.set_ylabel(r"$\langle \vec{J}(0) \cdot \vec{J}(t) \rangle$", fontsize=12) +ax2.set_title( + rf"$\kappa$ = {kappa:.8f} (eV/ps $\AA^2$ K) (Average of {hfacf_calc._window_count} windows)", # noqa: E501, SLF001 + fontsize=14, +) +ax2.axhline(y=0, color="k", linestyle="--", alpha=0.3) + +plt.tight_layout() +plt.savefig("heat_flux_and_kappa.pdf") diff --git a/examples/tutorials/autobatching_tutorial.py b/examples/tutorials/autobatching_tutorial.py index fec399458..317567dcd 100644 --- a/examples/tutorials/autobatching_tutorial.py +++ b/examples/tutorials/autobatching_tutorial.py @@ -1,12 +1,7 @@ -# %% [markdown] -#
-# Dependencies +# %% # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// -#
# %% [markdown] @@ -26,7 +21,7 @@ ## Introduction Simulating many molecular systems on GPUs can be challenging when the total number of -atoms exceeds available GPU memory. The `torch_sim.autobatching` module solves this by: +atoms exceeds available GPU memory. The `ts.autobatching` module solves this by: 1. Automatically determining optimal batch sizes based on GPU memory constraints 2. Providing two complementary strategies: binning and in-flight @@ -43,11 +38,7 @@ import torch_sim as ts -def mock_determine_max_batch_size(*args, **kwargs): - return 3 - - -ts.autobatching.determine_max_batch_size = mock_determine_max_batch_size +ts.autobatching.determine_max_batch_size = lambda *args, **kwargs: 3 # type: ignore[invalid-assignment] # %% [markdown] @@ -152,7 +143,7 @@ def process_batch(batch): # Process each batch processed_batches = [] -for batch in batcher: +for batch, _indices in batcher: # Process the batch (e.g., run dynamics or optimization) batch = process_batch(batch) processed_batches.append(batch) @@ -186,10 +177,7 @@ def process_batch(batch): """ # %% Initialize nvt langevin integrator -nvt_init, nvt_update = ts.nvt_langevin(mace_model, dt=0.001, kT=0.01) - -# Prepare states for optimization -nvt_state = nvt_init(state) +nvt_state = ts.nvt_langevin_init(mace_model, state, kT=0.01) # Initialize the batcher batcher = ts.BinningAutoBatcher( @@ -199,15 +187,15 @@ def process_batch(batch): max_memory_scaler = batcher.load_states(nvt_state) print(f"Max memory scaler: {max_memory_scaler}") -print("There are ", len(batcher.index_bins), " bins") -print("The indices of the states in each bin are: ", batcher.index_bins) +print(f"There are {len(batcher.index_bins)} bins") +print(f"The indices of the states in each bin are: {batcher.index_bins}") # Run optimization on each batch finished_states = [] -for batch in batcher: - # Run 5 steps of FIRE optimization +for batch, _indices in batcher: + # Run 5 steps of NVT dynamics for _ in range(5): - batch = nvt_update(batch) + batch = ts.nvt_langevin_update(mace_model, batch, dt=0.001, kT=0.01) finished_states.append(batch) @@ -232,8 +220,9 @@ def process_batch(batch): """ # %% -fire_init, fire_update = ts.frechet_cell_fire(mace_model) -fire_state = fire_init(state) +fire_state = ts.fire_init( + model=mace_model, state=state, cell_filter=ts.CellFilter.frechet +) # Initialize the batcher batcher = ts.InFlightAutoBatcher( @@ -263,7 +252,7 @@ def process_batch(batch): # optimize the batch, we stagger the steps to avoid state processing overhead for _ in range(10): - fire_state = fire_update(fire_state) + fire_state = ts.fire_step(model=mace_model, state=fire_state) # Check which states have converged convergence_tensor = convergence_fn(fire_state, None) @@ -295,17 +284,14 @@ def process_batch(batch): using the `TrajectoryReporter`, because the files must be regularly updated. """ -# %% Initialize with return_indices=True +# %% Initialize batcher batcher = ts.BinningAutoBatcher( - model=mace_model, - memory_scales_with="n_atoms", - max_memory_scaler=80, - return_indices=True, + model=mace_model, memory_scales_with="n_atoms", max_memory_scaler=80 ) batcher.load_states(state) -# Iterate with indices -for batch, indices in batcher: +# Iterate over batches +for idx, (batch, indices) in enumerate(batcher): print(f"Processing states with original indices: {indices}") # Process batch... diff --git a/examples/tutorials/diff_sim.py b/examples/tutorials/diff_sim.py new file mode 100644 index 000000000..8fbe01a85 --- /dev/null +++ b/examples/tutorials/diff_sim.py @@ -0,0 +1,448 @@ +# %% +# /// script +# dependencies = ["matplotlib"] +# /// + +# %% +import torch_sim as ts +from typing import cast +from numpy.typing import NDArray +import torch +import matplotlib.pyplot as plt +from torch_sim.models.interface import ModelInterface +from torch_sim.models.soft_sphere import ( + soft_sphere_pair, + DEFAULT_SIGMA, + DEFAULT_EPSILON, + DEFAULT_ALPHA, +) +from torch_sim import transforms +from collections.abc import Callable +from dataclasses import dataclass +from torch._functorch import config + +config.donated_buffer = False + +# %% [markdown] +""" +# Differentiable Simulation + +In this tutorial, we will explore how to use TorchSim to perform differentiable simulations. +This tutorial will reproduce the bubble raft example from [JAX-MD](https://github.com/jax-md/jax-md/blob/main/notebooks/meta_optimization.ipynb) +and perform meta-optimization to find the optimal diameter. +""" + + +# %% +def draw_system( + R: torch.Tensor, box_size: float, marker_size: float, color: list[float] | None = None +): + """Draw a system of particles on the plot.""" + if color == None: + color = [64 / 256] * 3 + ms = marker_size / box_size + + R = torch.tensor(R) + + marker_style = dict( + linestyle="none", + markeredgewidth=3, + marker="o", + markersize=ms, + color=color, + fillstyle="none", + ) + + plt.plot(R[:, 0], R[:, 1], **marker_style) + plt.plot(R[:, 0] + box_size, R[:, 1], **marker_style) + plt.plot(R[:, 0], R[:, 1] + box_size, **marker_style) + plt.plot(R[:, 0] + box_size, R[:, 1] + box_size, **marker_style) + plt.plot(R[:, 0] - box_size, R[:, 1], **marker_style) + plt.plot(R[:, 0], R[:, 1] - box_size, **marker_style) + plt.plot(R[:, 0] - box_size, R[:, 1] - box_size, **marker_style) + + plt.xlim([0, box_size]) + plt.ylim([0, box_size]) + plt.axis("off") + plt.gca().set_facecolor([1, 1, 1]) + + +# %% [markdown] +""" +## Soft Sphere potential + +We will use the soft sphere potential as our model. + +$$ +U(r_{ij}) = \begin{cases} + \left(1 - \frac{r_{ij}}{\sigma_{ij}}\right)^2 & \text{if } r_{ij} < \sigma_{ij} \\ + 0 & \text{if } r_{ij} \geq \sigma_{ij} +\end{cases} +$$ +""" +# %% +plt.gca().axhline(y=0, color="k") +plt.xlim([0, 1.5]) +plt.ylim([-0.2, 0.8]) + +# model = SoftSphereMultiModel(sigma_matrix=torch.tensor([1.0])) +dr = torch.linspace(0, 3.0, 80) +plt.plot(dr, soft_sphere_pair(dr, sigma=1), "b-", linewidth=3) +plt.fill_between(dr, soft_sphere_pair(dr), alpha=0.4) + +plt.xlabel(r"$r$", fontsize=20) +plt.ylabel(r"$U(r)$", fontsize=20) + +plt.show() + +# %% [markdown] +"""## Define the simple TorchSim model for the soft sphere potential.""" + + +# %% +@dataclass +class BaseState: + """Simple simulation state""" + + positions: torch.Tensor + cell: torch.Tensor + pbc: bool + species: torch.Tensor + + +class SoftSphereMultiModel(ModelInterface): + """Soft sphere potential""" + + def __init__( + self, + species: torch.Tensor | None = None, + sigma_matrix: torch.Tensor | None = None, + epsilon_matrix: torch.Tensor | None = None, + alpha_matrix: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, + *, # Force keyword-only arguments + pbc: bool = True, + cutoff: float | None = None, + ) -> None: + """Initialize a soft sphere model for multi-component systems.""" + super().__init__() + self.device = device or torch.device("cpu") + self.dtype = dtype + self.pbc = pbc + + # Store species list and determine number of unique species + self.species = species + n_species = len(torch.unique(species)) + + # Initialize parameter matrices with defaults if not provided + default_sigma = DEFAULT_SIGMA.to(device=self.device, dtype=self.dtype) + default_epsilon = DEFAULT_EPSILON.to(device=self.device, dtype=self.dtype) + default_alpha = DEFAULT_ALPHA.to(device=self.device, dtype=self.dtype) + + # Validate matrix shapes match number of species + if sigma_matrix is not None and sigma_matrix.shape != (n_species, n_species): + raise ValueError(f"sigma_matrix must have shape ({n_species}, {n_species})") + if epsilon_matrix is not None and epsilon_matrix.shape != ( + n_species, + n_species, + ): + raise ValueError(f"epsilon_matrix must have shape ({n_species}, {n_species})") + if alpha_matrix is not None and alpha_matrix.shape != (n_species, n_species): + raise ValueError(f"alpha_matrix must have shape ({n_species}, {n_species})") + + # Create parameter matrices, using defaults if not provided + self.sigma_matrix = ( + sigma_matrix + if sigma_matrix is not None + else default_sigma + * torch.ones((n_species, n_species), dtype=dtype, device=device) + ) + self.epsilon_matrix = ( + epsilon_matrix + if epsilon_matrix is not None + else default_epsilon + * torch.ones((n_species, n_species), dtype=dtype, device=device) + ) + self.alpha_matrix = ( + alpha_matrix + if alpha_matrix is not None + else default_alpha + * torch.ones((n_species, n_species), dtype=dtype, device=device) + ) + + # Ensure parameter matrices are symmetric (required for energy conservation) + for matrix_name in ("sigma_matrix", "epsilon_matrix", "alpha_matrix"): + matrix = getattr(self, matrix_name) + if not torch.allclose(matrix, matrix.T): + raise ValueError(f"{matrix_name} is not symmetric") + + # Set interaction cutoff distance + self.cutoff = torch.tensor( + cutoff or float(self.sigma_matrix.max()), dtype=dtype, device=device + ) + + def forward( + self, custom_state: BaseState, species: torch.Tensor | None = None + ) -> dict[str, torch.Tensor]: + """Compute energies and forces for a single unbatched system with multiple + species.""" + # Convert inputs to proper device/dtype and handle species + positions = custom_state.positions.requires_grad_(True) + cell = custom_state.cell + species = custom_state.species + + if species is not None: + species = species.to(device=self.device, dtype=torch.long) + else: + species = self.species + + species_idx = species + + # Direct N^2 computation of all pairs (minimum image convention) + dr_vec, distances = transforms.get_pair_displacements( + positions=positions, + cell=cell, + pbc=self.pbc, + ) + # Remove self-interactions and apply cutoff + mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) + distances = distances.masked_fill(mask, float("inf")) + mask = distances < self.cutoff + + # Get valid pairs and their displacements + i, j = torch.where(mask) + mapping = torch.stack([j, i]) + dr_vec = dr_vec[mask] + distances = distances[mask] + + # Look up species-specific parameters for each interacting pair + pair_species_1 = species_idx[mapping[0]] # Species of first atom in pair + pair_species_2 = species_idx[mapping[1]] # Species of second atom in pair + + # Get interaction parameters from parameter matrices + pair_sigmas = self.sigma_matrix[pair_species_1, pair_species_2] + pair_epsilons = self.epsilon_matrix[pair_species_1, pair_species_2] + pair_alphas = self.alpha_matrix[pair_species_1, pair_species_2] + + # Calculate pair energies using species-specific parameters + pair_energies = soft_sphere_pair( + distances, sigma=pair_sigmas, epsilon=pair_epsilons, alpha=pair_alphas + ) + + # Initialize results with total energy (divide by 2 to avoid double counting) + potential_energy = pair_energies.sum() / 2 + + grad_outputs: list[torch.Tensor | None] = [torch.ones_like(potential_energy)] + grad = torch.autograd.grad( + outputs=[potential_energy], + inputs=[positions], + grad_outputs=grad_outputs, + create_graph=False, + retain_graph=True, + ) + + force_grad = grad[0] + if force_grad is not None: + forces = torch.neg(force_grad) + + return {"energy": potential_energy, "forces": forces} + + +# %% [markdown] +""" +## Gradient Descent + +We will use a simple gradient descent to optimize the positions of the particles. +""" + + +# %% [markdown] +"""## Setup the simulation environment.""" + + +# %% +def box_size_at_number_density( + particle_count: int, number_density: torch.Tensor +) -> torch.Tensor: + return (particle_count / number_density) ** 0.5 + + +def box_size_at_packing_fraction( + diameter: torch.Tensor, packing_fraction: float +) -> torch.Tensor: + bubble_volume = N_2 * torch.pi * (torch.square(diameter) + 1) / 4 + return torch.sqrt(bubble_volume / packing_fraction) + + +def species_sigma(diameter: torch.Tensor) -> torch.Tensor: + d_AA = diameter + d_BB = 1 + d_AB = 0.5 * (diameter + 1) + return torch.tensor([[d_AA, d_AB], [d_AB, d_BB]]) + + +N = 128 +N_2 = N // 2 +species = torch.tensor([0] * (N_2) + [1] * (N_2), dtype=torch.int32) +simulation_steps = 1000 +packing_fraction = 0.98 +markersize = 260 + + +# %% +def simulation( + diameter: torch.Tensor, seed: int = 42 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Create the simulation environment. + box_size = box_size_at_packing_fraction(diameter, packing_fraction) + cell = torch.eye(2) * box_size + # Create the energy function. + sigma = species_sigma(diameter) + model = SoftSphereMultiModel(sigma_matrix=sigma, species=species) + model = cast(SoftSphereMultiModel, torch.compile(model)) + # Randomly initialize the system. + # Fix seed for reproducible random positions + torch.manual_seed(seed) + R = torch.rand(N, 2) * box_size + + # Minimize to the nearest minimum. + custom_state = BaseState(positions=R, cell=cell, species=species, pbc=True) + state = ts.gradient_descent_init(model, state=custom_state) + + for _ in range(simulation_steps): + state = ts.gradient_descent_step(model, state, pos_lr=0.1) + return box_size, model(state)["energy"], state.positions + + +# %% [markdown] +"""## Packing at different diameters.""" + +# %% +plt.subplot(1, 2, 1) + +box_size, raft_energy, bubble_positions = simulation(torch.tensor(1.0)) +draw_system(bubble_positions, box_size.numpy(), markersize) + +plt.subplot(1, 2, 2) + +box_size, raft_energy, bubble_positions = simulation(torch.tensor(0.8)) +draw_system(bubble_positions[:N_2], box_size.numpy(), 0.8 * markersize) +draw_system(bubble_positions[N_2:], box_size.numpy(), markersize) +# %% [markdown] +"""## Forward simulation for different diameters and seeds.""" + +# %% +diameters = torch.linspace(0.4, 1.0, 10) +seeds = torch.arange(1, 6) +box_size_tensor = torch.zeros(len(diameters), len(seeds)) +raft_energy_tensor = torch.zeros(len(diameters), len(seeds)) +bubble_positions_tensor = torch.zeros(len(diameters), len(seeds), N, 2) +for i, d in enumerate(diameters): + for j, s in enumerate(seeds): + box_size, raft_energy, bubble_positions = simulation(d, s) + box_size_tensor[i, j] = box_size + raft_energy_tensor[i, j] = raft_energy.detach() + bubble_positions_tensor[i, j] = bubble_positions + print(f"Finished simulation for diameter {d}, final energy: {raft_energy.detach()}") +# %% +U_mean = torch.mean(raft_energy_tensor, dim=1) +U_std = torch.std(raft_energy_tensor, dim=1) +plt.plot(diameters.detach().numpy(), U_mean, linewidth=3) +plt.fill_between(diameters.detach().numpy(), U_mean + U_std, U_mean - U_std, alpha=0.4) + +plt.xlim([0.4, 1.0]) +plt.xlabel(r"$D$", fontsize=20) +plt.ylabel(r"$U$", fontsize=20) +plt.show() +# %% +ms = 185 +for i, d in enumerate(diameters): + plt.subplot(2, 5, i + 1) + c = min(1, max(0, (U_mean[i].detach().numpy() - 0.4) * 4)) + color = [c, 0, 1 - c] + draw_system( + bubble_positions_tensor[i, 0, :N_2].detach().numpy(), + box_size_tensor[i, 0].detach().numpy(), + d * ms, + color=color, + ) + draw_system( + bubble_positions_tensor[i, 0, N_2:].detach().numpy(), + box_size_tensor[i, 0].detach().numpy(), + ms, + color=color, + ) + + +# %% [markdown] +"""## Meta-optimization with differentiable simulation.""" + +# %% +short_simulation_steps = 10 + + +def short_simulation( + diameter: torch.Tensor, R: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + diameter = diameter.requires_grad_(True) + box_size = box_size_at_packing_fraction(diameter, packing_fraction) + cell = torch.eye(2) * box_size + # Create the energy function. + sigma = species_sigma(diameter) + model = SoftSphereMultiModel(sigma_matrix=sigma, species=species) + + # Minimize to the nearest minimum. + custom_state = BaseState(positions=R, cell=cell, species=species, pbc=True) + state = ts.gradient_descent_init(model, state=custom_state) + + for i in range(short_simulation_steps): + state = ts.gradient_descent_step(model, state, pos_lr=0.1) + + grad_outputs: list[torch.Tensor | None] = [ + torch.ones_like( + diameter, + ) + ] + grad = torch.autograd.grad( + outputs=[ + model(state)["energy"], + ], + inputs=[diameter], + grad_outputs=grad_outputs, + create_graph=True, + retain_graph=False, + ) + + dU_dd = grad[0] + return model(state)["energy"], dU_dd + + +# %% +dU_dD = torch.zeros(len(diameters), len(seeds)) +for i, d in enumerate(diameters): + for j, s in enumerate(seeds): + _, dU_dD[i, j] = short_simulation(d, bubble_positions_tensor[i, j]) + +# %% +plt.subplot(2, 1, 1) +dU_dD = dU_dD.detach() +dU_mean = torch.mean(dU_dD, dim=1) +dU_std = torch.std(dU_dD, dim=1) +plt.plot(diameters.detach().numpy(), dU_mean, linewidth=3) +plt.fill_between( + diameters.detach().numpy(), dU_mean + dU_std, dU_mean - dU_std, alpha=0.4 +) + + +plt.xlim([0.4, 1.0]) +plt.xlabel(r"$D$", fontsize=20) +plt.ylabel(r"$\langle{dU}/{dD}\rangle$", fontsize=20) + +plt.subplot(2, 1, 2) +plt.plot(diameters.detach().numpy(), U_mean, linewidth=3) +plt.fill_between(diameters.detach().numpy(), U_mean + U_std, U_mean - U_std, alpha=0.4) + +plt.xlim([0.4, 1.0]) +plt.xlabel(r"$D$", fontsize=20) +plt.ylabel(r"$U$", fontsize=20) diff --git a/examples/tutorials/high_level_tutorial.py b/examples/tutorials/high_level_tutorial.py index cf0debb4d..834cac7f1 100644 --- a/examples/tutorials/high_level_tutorial.py +++ b/examples/tutorials/high_level_tutorial.py @@ -1,14 +1,11 @@ -# %% [markdown] -#
-# Dependencies +# %% # /// script # dependencies = [ # "mace-torch>=0.3.12", # "pymatgen>=2025.2.18", -# "ase>=3.23.1", +# "ase>=3.26", # ] # /// -#
# %% [markdown] @@ -47,8 +44,8 @@ """ # %% -import torch import torch_sim as ts +import torch from ase.build import bulk from torch_sim.models.lennard_jones import LennardJonesModel @@ -76,7 +73,7 @@ final_state = ts.integrate( system=cu_atoms, # Input atomic system model=lj_model, # Energy/force model - integrator=ts.nvt_langevin, # Integrator to use + integrator=ts.MdFlavor.nvt_langevin, # Integrator to use n_steps=n_steps, # Number of MD steps temperature=2000, # Target temperature (K) timestep=0.002, # Integration timestep (ps) @@ -101,7 +98,7 @@ final_state = ts.integrate( system=cu_atoms, model=lj_model, - integrator=ts.nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=n_steps, temperature=2000, timestep=0.002, @@ -154,7 +151,7 @@ final_state = ts.integrate( system=cu_atoms, model=lj_model, - integrator=ts.nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=n_steps, temperature=2000, timestep=0.002, @@ -197,7 +194,7 @@ from torch_sim.models.mace import MaceModel # Use CUDA if available -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load the MACE "small" foundation model mace = mace_mp(model="small", return_raw_model=True) @@ -212,7 +209,7 @@ final_state = ts.integrate( system=cu_atoms, model=mace_model, - integrator=ts.nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=n_steps, temperature=2000, timestep=0.002, @@ -248,7 +245,7 @@ final_state = ts.integrate( system=systems, # List of systems to simulate model=mace_model, # Single model for all systems - integrator=ts.nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=n_steps, temperature=2000, timestep=0.002, @@ -267,7 +264,7 @@ """ # %% Create individual filenames for each system -filenames = [f"batch_traj_{i}.h5" for i in range(len(systems))] +filenames = [f"tmp/batch_traj_{i}.h5" for i in range(len(systems))] # Create a reporter that handles multiple trajectories batch_reporter = ts.TrajectoryReporter( @@ -280,7 +277,7 @@ final_state = ts.integrate( system=systems, model=mace_model, - integrator=ts.nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=n_steps, temperature=2000, timestep=0.002, @@ -295,12 +292,14 @@ # %% Calculate final energy per atom for each system final_energies_per_atom = [] -for i, filename in enumerate(filenames): +for sys_idx, filename in enumerate(filenames): with ts.TorchSimTrajectory(filename) as traj: final_energy = traj.get_array("potential_energy")[-1].item() n_atoms = len(traj.get_atoms(-1)) final_energies_per_atom.append(final_energy / n_atoms) - print(f"System {i}: {final_energy:.6f} eV, {final_energy / n_atoms:.6f} eV/atom") + print( + f"System {sys_idx}: {final_energy:.6f} eV, {final_energy / n_atoms:.6f} eV/atom" + ) # %% [markdown] @@ -317,11 +316,7 @@ # %% -def mock_determine_max_batch_size(*args, **kwargs): - return 10 - - -ts.autobatching.determine_max_batch_size = mock_determine_max_batch_size +ts.autobatching.determine_max_batch_size = lambda *args, **kwargs: 10 # type: ignore[invalid-assignment] # %% [markdown] @@ -332,7 +327,7 @@ def mock_determine_max_batch_size(*args, **kwargs): final_state = ts.integrate( system=systems, model=mace_model, - integrator=ts.nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=n_steps, temperature=2000, timestep=0.002, @@ -361,7 +356,8 @@ def mock_determine_max_batch_size(*args, **kwargs): final_state = ts.optimize( system=systems, model=mace_model, - optimizer=ts.unit_cell_fire, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.unit, ) final_atoms = final_state.to_atoms() @@ -392,7 +388,7 @@ def default_energy_convergence(state, last_energy): # we arbitrarily add energy so nothing is converged convergence_tensor = default_energy_convergence(final_state, final_state.energy + 1) -print("Any converged?", torch.any(convergence_tensor).item()) +print(f"Any converged? {torch.any(convergence_tensor).item()}") # %% [markdown] @@ -408,7 +404,8 @@ def default_energy_convergence(state, last_energy): final_state = ts.optimize( system=systems, model=mace_model, - optimizer=ts.unit_cell_fire, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.unit, convergence_fn=force_convergence_fn, # Custom convergence function ) @@ -484,7 +481,7 @@ def default_energy_convergence(state, last_energy): final_state = ts.integrate( system=structure, model=lj_model, - integrator=ts.nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=n_steps, temperature=2000, timestep=0.002, diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index 08c3914b6..74b8f8e5d 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -1,13 +1,7 @@ -# %% [markdown] -#
-# Dependencies +# %% # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# "pymatgen>=2025.2.18", -# ] +# dependencies = ["mace-torch>=0.3.12", "pymatgen>=2025.2.18"] # /// -#
# %% [markdown] @@ -34,15 +28,16 @@ """ # %% -from typing import ClassVar +from typing import ClassVar, cast import torch import torch_sim as ts from mace.calculators.foundations_models import mace_mp from torch_sim.integrators.md import MDState from torch_sim.models.mace import MaceModel +from torch_sim.monte_carlo import SwapMCState # Initialize the mace model -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") mace = mace_mp(model="small", return_raw_model=True) mace_model = MaceModel(model=mace, device=device) @@ -95,7 +90,7 @@ @dataclass -class HybridSwapMCState(ts.integrators.MDState): +class HybridSwapMCState(SwapMCState, MDState): """State for hybrid MD-Monte Carlo simulations. This state class extends the standard MDState with: @@ -127,18 +122,16 @@ class HybridSwapMCState(ts.integrators.MDState): kT = 1000 * MetalUnits.temperature # Initialize NVT Langevin dynamics state -nvt_init, nvt_step = ts.nvt_langevin(model=mace_model, dt=0.002, kT=kT, seed=42) -md_state = nvt_init(state) +md_state = ts.nvt_langevin_init(model=mace_model, state=state, kT=kT, seed=42) # Initialize swap Monte Carlo state -swap_init, swap_step = ts.swap_monte_carlo(model=mace_model, kT=kT, seed=42) -swap_state = swap_init(md_state) +swap_state = ts.swap_mc_init(model=mace_model, state=md_state) # Create hybrid state combining both hybrid_state = HybridSwapMCState( **vars(md_state), - last_permutation=torch.zeros( - md_state.n_systems, device=md_state.device, dtype=torch.bool + last_permutation=torch.arange( + md_state.n_atoms, device=md_state.device, dtype=torch.long ), ) @@ -159,12 +152,14 @@ class HybridSwapMCState(ts.integrators.MDState): # %% Run the hybrid simulation n_steps = 100 for step in range(n_steps): - if step % 10 == 0: - # Attempt swap Monte Carlo move - hybrid_state = swap_step(hybrid_state, kT=torch.tensor(kT)) - else: - # Perform MD step - hybrid_state = nvt_step(hybrid_state, dt=torch.tensor(0.002), kT=torch.tensor(kT)) + if step % 10 == 0: # Attempt swap Monte Carlo move + hybrid_state = ts.swap_mc_step( + model=mace_model, state=hybrid_state, kT=kT, seed=42 + step + ) + else: # Perform MD step + hybrid_state = ts.nvt_langevin_update( + model=mace_model, state=hybrid_state, dt=0.002, kT=kT + ) if step % 20 == 0: print(f"Step {step}: Energy = {hybrid_state.energy.item():.3f} eV") @@ -172,13 +167,10 @@ class HybridSwapMCState(ts.integrators.MDState): # %% [markdown] """ -## Concluding Remarks - This tutorial demonstrated how to combine different TorchSim components to create new simulation methods. Key takeaways: 1. TorchSim's components (integrators, MC movers, etc.) are designed to be modular 2. Custom state objects can combine features from different simulation types 3. Complex simulation workflows can be built by mixing and matching components - """ diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index 99c8702d4..71049eb91 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -1,12 +1,7 @@ -# %% [markdown] -#
-# Dependencies +# %% # /// script -# dependencies = [ -# "mace-torch>=0.3.12", -# ] +# dependencies = ["mace-torch>=0.3.12"] # /// -#
# %% [markdown] @@ -27,7 +22,9 @@ """ ## Setting up the system -TorchSim's state aka `SimState` is a class that contains the information of the +TorchSim's state aka `SimState` is a +import torch_sim as ts +class that contains the information of the system like positions, cell, etc. of the system(s). All the models in the TorchSim package take in a `SimState` as an input and return the properties of the system(s). @@ -44,7 +41,7 @@ fe_bcc = bulk("Fe", "bcc", a=2.8665, cubic=True).repeat((3, 3, 3)) atoms_list = [si_dc, fe_bcc] -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 state = ts.initialize_state(atoms_list, device=device, dtype=dtype) @@ -71,7 +68,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=dtype, + default_dtype=str(dtype).lstrip("torch."), device=device, ) @@ -94,13 +91,11 @@ """ # %% -print("Model device:", model.device) -print("Model dtype:", model.dtype) -print("Model compute_forces:", model.compute_forces) -print("Model compute_stress:", model.compute_stress) - -# see the autobatching tutorial for more details -print("Model memory_scales_with:", model.memory_scales_with) +print(f"{model.device=}") +print(f"{model.dtype=}") +print(f"{model.compute_forces=}") +print(f"{model.compute_stress=}") +print(f"{model.memory_scales_with=}") # see the autobatching tutorial for more details # %% [markdown] @@ -115,7 +110,6 @@ # %% model_outputs = model(state) print(f"Model outputs: {', '.join(list(model_outputs))}") - print(f"Energy is a systemwise property with shape: {model_outputs['energy'].shape}") print(f"Forces are an atomwise property with shape: {model_outputs['forces'].shape}") print(f"Stress is a systemwise property with shape: {model_outputs['stress'].shape}") @@ -126,17 +120,16 @@ ## Optimizers and Integrators All optimizers and integrators share a similar interface. They accept a model and -return two functions: `init_fn` and `update_fn`. The `init_fn` function returns the -initialized optimizer-specific state, while the `update_fn` function updates the +return two functions: `init_fn` and step_fn`. The `init_fn` function returns the +initialized optimizer-specific state, while the step_fn` function updates the simulation state. ### Unit Cell Fire -We will walk through the `unit_cell_fire` optimizer as an example. +We will walk through the fire optimizer with unit cell filter as an example. """ # %% -fire_init_fn, fire_update_fn = ts.unit_cell_fire(model=model) # %% [markdown] @@ -148,34 +141,31 @@ """ # %% -state = fire_init_fn(state=state) +state = ts.fire_init(model=model, state=state, cell_filter=ts.CellFilter.unit) # add a little noise so we have something to relax state.positions = state.positions + torch.randn_like(state.positions) * 0.05 for step in range(20): - state = fire_update_fn(state=state) + state = ts.fire_step(model=model, state=state) print(f"{step=}: Total energy: {state.energy} eV") # %% [markdown] """ -In general, you can set the optimizer-specific arguments in the `optimize` function -(e.g. `unit_cell_fire`) and they will be baked into the returned functions. Fixed +You can set the optimizer-specific arguments in the `optimize` function +optimizer=ts.OptimFlavor.fire, cell_filter=ts.CellFilter.unit. Fixed parameters can usually be passed to the `init_fn` and parameters that vary over -the course of the simulation can be passed to the `update_fn`. +the course of the simulation can be passed to the step_fn`. """ # %% -fire_init_fn, fire_update_fn = ts.unit_cell_fire( - model=model, - dt_max=0.1, - dt_start=0.02, +state = ts.fire_init( + model=model, state=state, dt_start=0.02, cell_filter=ts.CellFilter.unit ) -state = fire_init_fn(state=state) for step in range(5): - state = fire_update_fn(state=state) + state = ts.fire_step(model=model, state=state, dt_max=0.1) # %% [markdown] @@ -199,17 +189,11 @@ # %% [markdown] """ - -Like the `unit_cell_fire` optimizer, the `nvt_langevin` integrator accepts -a model and configuration kwargs and returns an `init_fn` and `update_fn`. +Like the `fire` optimizer with unit cell filter, the `nvt_langevin` integrator accepts +a model, state and config kwargs. """ -# %% -nvt_langevin_init_fn, nvt_langevin_update_fn = ts.nvt_langevin( - model=model, dt=dt, kT=kT, gamma=gamma -) - -# we'll also reinialize the state to clean up the previous state +# %% we'll also reinitialize the state to clean up the previous state state = ts.initialize_state(atoms_list, device=device, dtype=dtype) @@ -222,11 +206,14 @@ """ # %% -state = nvt_langevin_init_fn(state=state) +state = ts.nvt_langevin_init(model=model, state=state, kT=kT) initial_kT = kT for step in range(30): - state = nvt_langevin_update_fn(state=state, kT=initial_kT * (1 + step / 30)) + current_kT = initial_kT * (1 + step / 30) + state = ts.nvt_langevin_update( + model=model, state=state, dt=dt, kT=current_kT, gamma=gamma + ) if step % 5 == 0: temp_E_units = ts.calc_kT( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx diff --git a/examples/tutorials/metatomic_tutorial.py b/examples/tutorials/metatomic_tutorial.py index 54479439c..64a89bcbb 100644 --- a/examples/tutorials/metatomic_tutorial.py +++ b/examples/tutorials/metatomic_tutorial.py @@ -1,6 +1,4 @@ -# %% [markdown] -#
-# Dependencies +# %% # /// script # dependencies = [ # "metatrain[pet]==2025.7", @@ -8,7 +6,6 @@ # "vesin-torch>=0.3.7", # ] # /// -#
# %% [markdown] @@ -48,7 +45,7 @@ equilibrated_state = ts.integrate( system=atoms, model=model, - integrator=ts.nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=100, temperature=300, # K timestep=0.001, # ps @@ -57,7 +54,7 @@ final_state = ts.integrate( system=equilibrated_state, model=model, - integrator=ts.nve, + integrator=ts.MdFlavor.nve, n_steps=100, temperature=300, # K timestep=0.001, # ps diff --git a/examples/tutorials/reporting_tutorial.py b/examples/tutorials/reporting_tutorial.py index 474776134..a9e724e43 100644 --- a/examples/tutorials/reporting_tutorial.py +++ b/examples/tutorials/reporting_tutorial.py @@ -1,14 +1,11 @@ -# %% [markdown] -#
-# Dependencies +# %% # /// script # dependencies = [ # "mace-torch>=0.3.12", # "pymatgen>=2025.2.18", -# "ase>=3.23.1", +# "ase>=3.26", # ] # /// -#
# %% [markdown] @@ -45,7 +42,7 @@ a simple interface for storing and retrieving trajectory data from HDF5 files. Through the power of HDF5, the TorchSimTrajectory supports: * Saving arbitrary arrays from the user in a natural way -* First class support for `torch_sim.state.SimState` objects +* First class support for `ts.SimState` objects * Binary encoding + compression for minimal file sizes * Easy interoperability with ASE and pymatgen @@ -99,16 +96,16 @@ # Create a bulk Si diamond structure state = ts.initialize_state( - bulk("Si", "diamond", a=5.43), device="cpu", dtype=torch.float64 + bulk("Si", "diamond", a=5.43), device=torch.device("cpu"), dtype=torch.float64 ) # Open a new trajectory file in a context manager with ts.TorchSimTrajectory("random_state.h5", mode="w") as traj: # Write the state with additional options - for i in range(5): + for step in range(5): traj.write_state( state, - steps=i + 1, + steps=step + 1, save_velocities=False, # our basic state doesn't have velocities save_forces=False, # our basic state doesn't have forces variable_cell=False, # True for an NPT simulation, where the cell changes @@ -210,12 +207,12 @@ # Define some property calculators -def calculate_com(state: ts.state.SimState) -> torch.Tensor: +def calculate_com(state: ts.SimState) -> torch.Tensor: """Calculate center of mass - only needs state""" return torch.mean(state.positions * state.masses.unsqueeze(1), dim=0) -def calculate_energy(state: ts.state.SimState, model: ModelInterface) -> torch.Tensor: +def calculate_energy(state: ts.SimState, model: ModelInterface) -> torch.Tensor: """Calculate energy - needs both state and model""" return model(state)["energy"] diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py index 5cd43b176..97d02d450 100644 --- a/examples/tutorials/state_tutorial.py +++ b/examples/tutorials/state_tutorial.py @@ -1,15 +1,12 @@ -# %% [markdown] -#
-# Dependencies +# %% # /// script # dependencies = [ # "mace-torch>=0.3.12", -# "pymatgen>=2024.11.3", -# "ase>=3.24", +# "pymatgen>=2025.6.14", +# "ase>=3.26", # "phonopy>=2.37.0", # ] # /// -#
# %% [markdown] @@ -32,7 +29,6 @@ * Periodic boundary conditions * Atomic numbers (elements) * System indices (for processing multiple systems simultaneously) - """ # %% [markdown] @@ -44,7 +40,6 @@ New SimStates can be either created manually or from existing atomistic objects. Here we'll start by creating an ase atoms object and converting it to a SimState. The `initialize_state` function can take in pymatgen Structure, PhonopyAtoms, or other SimStates and convert them into a single SimState. - """ # %% @@ -95,12 +90,11 @@ # loop through each attribute: for attr_name, attr_value in get_attrs_for_scope(si_state, "per-atom"): - print(f"per-atom attribute: {attr_name}") - print(f"value: {attr_value}") + print(f"per-atom attribute: {attr_name} = {attr_value}") # or access the attributes via a dict: -print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system"))) # noqa: E501 -print("Global attributes:", dict(get_attrs_for_scope(si_state, "global"))) +print(f"Per-system attributes: {dict(get_attrs_for_scope(si_state, 'per-system'))}") # noqa: E501 +print(f"Global attributes: {dict(get_attrs_for_scope(si_state, 'global'))}") # %% [markdown] """ @@ -247,7 +241,7 @@ ## Extending SimState: The MDState -MDState is defined in the `torch_sim.integrators` module. It is a subclass of SimState +MDState is defined in the `ts.integrators` module. It is a subclass of SimState for molecular dynamics simulations. It includes additional properties like momenta, forces, and energy. Here, we instantiate an MDState from a SimState by zeroing out the additional properties. @@ -269,9 +263,9 @@ ) print("MDState properties:") -print("Per-atom attributes:", dict(get_attrs_for_scope(si_state, "per-atom"))) -print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system"))) -print("Global attributes:", dict(get_attrs_for_scope(si_state, "global"))) +print(f"Per-atom attributes: {dict(get_attrs_for_scope(si_state, 'per-atom'))}") +print(f"Per-system attributes: {dict(get_attrs_for_scope(si_state, 'per-system'))}") +print(f"Global attributes: {dict(get_attrs_for_scope(si_state, 'global'))}") # %% [markdown] diff --git a/examples/tutorials/using_graphpes_tutorial.py b/examples/tutorials/using_graphpes_tutorial.py index 7b84bf411..edb6ae462 100644 --- a/examples/tutorials/using_graphpes_tutorial.py +++ b/examples/tutorials/using_graphpes_tutorial.py @@ -1,6 +1,4 @@ -# %% [markdown] -#
-# Dependencies +# %% # /// script # dependencies = [ # "graph-pes>=0.0.30", @@ -8,7 +6,6 @@ # "vesin-torch>=0.3.7", # ] # /// -#
# %% [markdown] @@ -38,7 +35,7 @@ # here, we just create a TensorNet model with random weights model = TensorNet(cutoff=5.0) -print("Number of parameters:", sum(p.numel() for p in model.parameters())) +print(f"Number of parameters: {sum(p.numel() for p in model.parameters()):,}") # %% [markdown] """ @@ -65,6 +62,7 @@ Now that we have a model, we can drive MD simulations with it. For this, we will use the `integrate` function. """ + # %% from ase.build import molecule import torch_sim as ts @@ -76,7 +74,7 @@ final_state = ts.integrate( system=atoms, model=ts_model, - integrator=ts.nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=50, temperature=300, timestep=0.001, diff --git a/pyproject.toml b/pyproject.toml index 8667285de..4fc030287 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,15 +1,14 @@ [project] -name = "torch_sim_atomistic" +name = "torch-sim-atomistic" version = "0.3.0" description = "A pytorch toolkit for calculating material properties using MLIPs" authors = [ { name = "Abhijeet Gangan", email = "abhijeetgangan@g.ucla.edu" }, - { name = "Janosh Riebesell", email = "jriebesell@radical-ai.com" }, - { name = "Orion Cohen", email = "orion@radical-ai.com" }, - { name = "Radical AI", email = "info@radical.ai" }, + { name = "Janosh Riebesell", email = "janosh.riebesell@gmail.com" }, + { name = "Orion Cohen", email = "orioncohen@berkeley.edu" }, ] -readme = "README.md" -license = { file = "LICENSE" } +readme = "readme.md" +license = { file = "license" } keywords = [ "chemistry", "interatomic-potentials", @@ -19,38 +18,41 @@ keywords = [ classifiers = [ "Intended Audience :: Science/Research", "Operating System :: OS Independent", - "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Chemistry", "Topic :: Scientific/Engineering :: Physics", ] -requires-python = ">=3.11" +requires-python = ">=3.12" dependencies = [ "h5py>=3.12.1", - "numpy>=1.26", + "numpy>=2", "tables>=3.10.2", "torch>=2", "tqdm>=4.67", - "vesin[torch]>=0.3.7", + "vesin-torch>=0.3.7", + "vesin>=0.3.7", ] [project.optional-dependencies] test = [ - "ase>=3.24", + "ase>=3.26", "phonopy>=2.37.0", "psutil>=7.0.0", - "pymatgen>=2024.11.3", + "pymatgen>=2025.6.14", "pytest-cov>=6", "pytest>=8", ] -io = ["ase>=3.24", "phonopy>=2.37.0", "pymatgen>=2024.11.3"] +io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] mace = ["mace-torch>=0.3.12"] mattersim = ["mattersim>=0.1.2"] metatomic = ["metatomic-torch>=0.1.1,<0.2", "metatrain[pet]==2025.7"] orb = ["orb-models>=0.5.2"] sevenn = ["sevenn>=0.11.0"] graphpes = ["graph-pes>=0.0.34", "mace-torch>=0.3.12"] +nequip = ["nequip>=0.12.0"] docs = [ "autodoc_pydantic==2.2.0", "furo==2024.8.6", @@ -60,14 +62,14 @@ docs = [ "jupytext==1.16.7", "myst_parser==4.0.0", "nbsphinx>=0.9.7", - "numpydoc==1.6.0", + "numpydoc==1.8.0", "sphinx-copybutton==0.5.2", - "sphinx==7.2.6", + "sphinx==8.1.3", "sphinx_design==0.6.1", ] [project.urls] -Repo = "https://github.com/TorchSim/torch-sim" +Repo = "https://github.com/torchsim/torch-sim" [build-system] requires = ["uv_build>=0.7.12"] @@ -78,7 +80,7 @@ module-name = "torch_sim" module-root = "" [tool.ruff] -target-version = "py311" +target-version = "py312" line-length = 90 output-format = "concise" @@ -97,6 +99,7 @@ ignore = [ "FIX002", # Line contains TODO, consider resolving the issue "N803", # Variable name should be lowercase "N806", # Uppercase letters in variable names + "PLC0415", # import` should be at the top-level of a file "PLR0912", # too many branches "PLR0913", # too many function arguments "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable @@ -121,15 +124,6 @@ pep8-naming.ignore-names = ["get_kT", "kT"] [tool.ruff.format] docstring-code-format = true -[tool.mypy] -warn_unused_configs = true -ignore_missing_imports = true -check_untyped_defs = true -explicit_package_bases = true -warn_unreachable = true -warn_redundant_casts = true -warn_unused_ignores = true - [tool.codespell] check-filenames = true ignore-words-list = ["convertor"] @@ -159,3 +153,17 @@ conflicts = [ { extra = "sevenn" }, ], ] + +[dependency-groups] +dev = ["pre-commit>=4.3.0", "ty>=0.0.1a20"] + +[tool.ty.rules] +# TODO: Unable to work with **kwargs: https://github.com/astral-sh/ty/issues/247 +missing-argument = "ignore" + +[[tool.ty.overrides]] +include = ["tests/models/**/*.py", "torch_sim/models/**/*.py"] + +# TODO would be nice to only ignore unresolved model imports but fail on all other packages +[tool.ty.overrides.rules] +unresolved-import = "ignore" diff --git a/tests/conftest.py b/tests/conftest.py index 9a19696e7..6e6f5ac6b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ -from typing import TYPE_CHECKING, Any +from typing import Any import numpy as np import pytest import torch +import torch.distributions.weibull from ase import Atoms from ase.build import bulk, molecule from ase.spacegroup import crystal @@ -10,69 +11,29 @@ from pymatgen.core import Structure import torch_sim as ts -from torch_sim.io import atoms_to_state from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.state import concatenate_states +from torch_sim.models.mace import MaceModel -if TYPE_CHECKING: - from mace.calculators import MACECalculator +DEVICE = torch.device("cpu") +DTYPE = torch.float64 @pytest.fixture -def device() -> torch.device: - return torch.device("cpu") - - -@pytest.fixture -def dtype() -> torch.dtype: - return torch.float64 - - -@pytest.fixture -def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel: +def lj_model() -> LennardJonesModel: """Create a Lennard-Jones model with reasonable parameters for Ar.""" return LennardJonesModel( use_neighbor_list=True, sigma=3.405, epsilon=0.0104, - device=device, - dtype=dtype, + device=DEVICE, + dtype=DTYPE, compute_forces=True, compute_stress=True, cutoff=2.5 * 3.405, ) -@pytest.fixture -def ase_mace_mpa() -> "MACECalculator": - """Provides an ASE MACECalculator instance using mace_mp.""" - from mace.calculators.foundations_models import mace_mp - - # Ensure dtype matches the one used in the torchsim fixture (float64) - return mace_mp(model=MaceUrls.mace_mp_small, default_dtype="float64") - - -@pytest.fixture -def torchsim_mace_mpa() -> MaceModel: - """Provides a MACE MP model instance for the optimizer tests.""" - from mace.calculators.foundations_models import mace_mp - - # Use float64 for potentially higher precision needed in optimization - dtype = getattr(torch, dtype_str := "float64") - raw_mace = mace_mp( - model=MaceUrls.mace_mp_small, return_raw_model=True, default_dtype=dtype_str - ) - return MaceModel( - model=raw_mace, - device="cpu", - dtype=dtype, - compute_forces=True, - compute_stress=True, - ) - - @pytest.fixture def ar_atoms() -> Atoms: """Create a face-centered cubic (FCC) Argon structure.""" @@ -139,41 +100,41 @@ def si_phonopy_atoms() -> Any: @pytest.fixture -def si_sim_state(si_atoms: Any, device: torch.device, dtype: torch.dtype) -> Any: +def si_sim_state(si_atoms: Any) -> Any: """Create a basic state from si_structure.""" - return ts.io.atoms_to_state(si_atoms, device, dtype) + return ts.io.atoms_to_state(si_atoms, DEVICE, DTYPE) @pytest.fixture -def cu_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: +def cu_sim_state() -> ts.SimState: """Create crystalline copper using ASE.""" atoms = bulk("Cu", "fcc", a=3.58, cubic=True) - return ts.io.atoms_to_state(atoms, device, dtype) + return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) @pytest.fixture -def mg_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: +def mg_sim_state() -> ts.SimState: """Create crystalline magnesium using ASE.""" atoms = bulk("Mg", "hcp", a=3.17, c=5.14) - return ts.io.atoms_to_state(atoms, device, dtype) + return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) @pytest.fixture -def sb_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: +def sb_sim_state() -> ts.SimState: """Create crystalline antimony using ASE.""" atoms = bulk("Sb", "rhombohedral", a=4.58, alpha=60) - return ts.io.atoms_to_state(atoms, device, dtype) + return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) @pytest.fixture -def ti_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: +def ti_sim_state() -> ts.SimState: """Create crystalline titanium using ASE.""" atoms = bulk("Ti", "hcp", a=2.94, c=4.64) - return ts.io.atoms_to_state(atoms, device, dtype) + return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) @pytest.fixture -def tio2_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: +def tio2_sim_state() -> ts.SimState: """Create crystalline TiO2 using ASE.""" a, c = 4.60, 2.96 basis = [("Ti", 0.5, 0.5, 0), ("O", 0.695679, 0.695679, 0.5)] @@ -183,11 +144,11 @@ def tio2_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: spacegroup=136, # P4_2/mnm cellpar=[a, a, c, 90, 90, 90], ) - return ts.io.atoms_to_state(atoms, device, dtype) + return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) @pytest.fixture -def ga_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: +def ga_sim_state() -> ts.SimState: """Create crystalline Ga using ASE.""" a, b, c = 4.43, 7.60, 4.56 basis = [("Ga", 0, 0.344304, 0.415401)] @@ -197,11 +158,11 @@ def ga_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: spacegroup=64, # Cmce cellpar=[a, b, c, 90, 90, 90], ) - return ts.io.atoms_to_state(atoms, device, dtype) + return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) @pytest.fixture -def niti_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: +def niti_sim_state() -> ts.SimState: """Create crystalline NiTi using ASE.""" a, b, c = 2.89, 3.97, 4.83 alpha, beta, gamma = 90.00, 105.23, 90.00 @@ -215,11 +176,11 @@ def niti_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: spacegroup=11, cellpar=[a, b, c, alpha, beta, gamma], ) - return ts.io.atoms_to_state(atoms, device, dtype) + return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) @pytest.fixture -def sio2_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: +def sio2_sim_state() -> ts.SimState: """Create an alpha-quartz SiO2 system for testing.""" atoms = crystal( symbols=["O", "Si"], @@ -227,15 +188,11 @@ def sio2_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: spacegroup=152, cellpar=[4.9019, 4.9019, 5.3988, 90, 90, 120], ) - return ts.io.atoms_to_state(atoms, device, dtype) + return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) @pytest.fixture -def rattled_sio2_sim_state( - sio2_sim_state: ts.SimState, - device: torch.device, - dtype: torch.dtype, -) -> ts.SimState: +def rattled_sio2_sim_state(sio2_sim_state: ts.SimState) -> ts.SimState: """Create a rattled SiO2 system for testing.""" sim_state = sio2_sim_state.clone() @@ -245,9 +202,9 @@ def rattled_sio2_sim_state( # Temporarily set a fixed seed torch.manual_seed(3) weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1) - rnd = torch.randn_like(sim_state.positions, device=device, dtype=dtype) - rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True).to(device=device) - shifts = weibull.sample(rnd.shape).to(device=device) * rnd + rnd = torch.randn_like(sim_state.positions, device=DEVICE, dtype=DTYPE) + rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True).to(device=DEVICE) + shifts = weibull.sample(rnd.shape).to(device=DEVICE) * rnd sim_state.positions = sim_state.positions + shifts finally: # Restore the original RNG state @@ -257,7 +214,29 @@ def rattled_sio2_sim_state( @pytest.fixture -def casio3_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: +def rattled_si_sim_state(si_sim_state: ts.SimState) -> ts.SimState: + """Create a rattled Si system for testing.""" + sim_state = si_sim_state.clone() + + # Store the current RNG state + rng_state = torch.random.get_rng_state() + try: + # Temporarily set a fixed seed + torch.manual_seed(3) + weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1) + rnd = torch.randn_like(sim_state.positions, device=DEVICE, dtype=DTYPE) + rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True).to(device=DEVICE) + shifts = weibull.sample(rnd.shape).to(device=DEVICE) * rnd + sim_state.positions = sim_state.positions + shifts + finally: + # Restore the original RNG state + torch.random.set_rng_state(rng_state) + + return sim_state + + +@pytest.fixture +def casio3_sim_state() -> ts.SimState: a, b, c = 7.9258, 7.3202, 7.0653 alpha, beta, gamma = 90.055, 95.217, 103.426 basis = [ @@ -283,46 +262,40 @@ def casio3_sim_state(device: torch.device, dtype: torch.dtype) -> ts.SimState: spacegroup=2, cellpar=[a, b, c, alpha, beta, gamma], ) - return ts.io.atoms_to_state(atoms, device, dtype) + return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) @pytest.fixture -def benzene_sim_state( - benzene_atoms: Any, device: torch.device, dtype: torch.dtype -) -> Any: +def benzene_sim_state(benzene_atoms: Any) -> Any: """Create a basic state from benzene_atoms.""" - return ts.io.atoms_to_state(benzene_atoms, device, dtype) + return ts.io.atoms_to_state(benzene_atoms, DEVICE, DTYPE) @pytest.fixture -def fe_supercell_sim_state( - fe_atoms: Atoms, device: torch.device, dtype: torch.dtype -) -> Any: +def fe_supercell_sim_state(fe_atoms: Atoms) -> Any: """Create a face-centered cubic (FCC) iron structure with 4x4x4 supercell.""" - return ts.io.atoms_to_state(fe_atoms.repeat([4, 4, 4]), device, dtype) + return ts.io.atoms_to_state(fe_atoms.repeat([4, 4, 4]), DEVICE, DTYPE) @pytest.fixture -def ar_supercell_sim_state( - ar_atoms: Atoms, device: torch.device, dtype: torch.dtype -) -> ts.SimState: +def ar_supercell_sim_state(ar_atoms: Atoms) -> ts.SimState: """Create a face-centered cubic (FCC) Argon structure with 2x2x2 supercell.""" - return ts.io.atoms_to_state(ar_atoms.repeat([2, 2, 2]), device, dtype) + return ts.io.atoms_to_state(ar_atoms.repeat([2, 2, 2]), DEVICE, DTYPE) @pytest.fixture def ar_double_sim_state(ar_supercell_sim_state: ts.SimState) -> ts.SimState: """Create a batched state from ar_fcc_sim_state.""" - return concatenate_states( + return ts.concatenate_states( [ar_supercell_sim_state, ar_supercell_sim_state], device=ar_supercell_sim_state.device, ) @pytest.fixture -def si_double_sim_state(si_atoms: Atoms, device: torch.device, dtype: torch.dtype) -> Any: +def si_double_sim_state(si_atoms: Atoms) -> Any: """Create a basic state from si_structure.""" - return ts.io.atoms_to_state([si_atoms, si_atoms], device, dtype) + return ts.io.atoms_to_state([si_atoms, si_atoms], DEVICE, DTYPE) @pytest.fixture @@ -330,14 +303,14 @@ def mixed_double_sim_state( ar_supercell_sim_state: ts.SimState, si_sim_state: ts.SimState ) -> ts.SimState: """Create a batched state from ar_fcc_sim_state.""" - return concatenate_states( + return ts.concatenate_states( [ar_supercell_sim_state, si_sim_state], device=ar_supercell_sim_state.device, ) @pytest.fixture -def osn2_sim_state(torchsim_mace_mpa: MaceModel) -> ts.state.SimState: +def osn2_sim_state(ts_mace_mpa: MaceModel) -> ts.SimState: """Provides an initial SimState for rhombohedral OsN2.""" # For pymatgen Structure initialization from pymatgen.core import Lattice, Structure @@ -348,14 +321,12 @@ def osn2_sim_state(torchsim_mace_mpa: MaceModel) -> ts.state.SimState: frac_coords = [[0.75, 0.7501, -0.25], [0, 0, 0]] # Slightly perturbed structure = Structure(lattice, species, frac_coords, coords_are_cartesian=False) return ts.initialize_state( - structure, dtype=torchsim_mace_mpa.dtype, device=torchsim_mace_mpa.device + structure, dtype=ts_mace_mpa.dtype, device=ts_mace_mpa.device ) @pytest.fixture -def distorted_fcc_al_conventional_sim_state( - torchsim_mace_mpa: MaceModel, -) -> ts.state.SimState: +def distorted_fcc_al_conventional_sim_state(ts_mace_mpa: MaceModel) -> ts.SimState: """Initial SimState for a slightly distorted FCC Al conventional cell (4 atoms).""" # Create a standard 4-atom conventional FCC Al cell atoms_fcc = bulk("Al", crystalstructure="fcc", a=4.05, cubic=True) @@ -373,7 +344,7 @@ def distorted_fcc_al_conventional_sim_state( positions += np_rng.normal(scale=0.01, size=positions.shape) atoms_fcc.set_positions(positions) - dtype = torchsim_mace_mpa.dtype - device = torchsim_mace_mpa.device + dtype = ts_mace_mpa.dtype + device = ts_mace_mpa.device # Convert the ASE Atoms object to SimState (will be a single batch with 4 atoms) - return atoms_to_state(atoms_fcc, device=device, dtype=dtype) + return ts.io.atoms_to_state(atoms_fcc, device=device, dtype=dtype) diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 7ef843041..5dc23fab8 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -5,6 +5,7 @@ import torch import torch_sim as ts +from tests.conftest import DEVICE from torch_sim.elastic import full_3x3_to_voigt_6_stress @@ -23,6 +24,7 @@ "niti_sim_state", "ti_sim_state", "si_sim_state", + "rattled_si_sim_state", "sio2_sim_state", "rattled_sio2_sim_state", "ar_supercell_sim_state", @@ -37,6 +39,8 @@ def make_model_calculator_consistency_test( model_fixture_name: str, calculator_fixture_name: str, sim_state_names: tuple[str, ...], + device: torch.device = DEVICE, + dtype: torch.dtype = torch.float64, energy_rtol: float = 1e-5, energy_atol: float = 1e-5, force_rtol: float = 1e-5, @@ -61,10 +65,7 @@ def make_model_calculator_consistency_test( @pytest.mark.parametrize("sim_state_name", sim_state_names) def test_model_calculator_consistency( - sim_state_name: str, - request: pytest.FixtureRequest, - device: torch.device, - dtype: torch.dtype, + sim_state_name: str, request: pytest.FixtureRequest ) -> None: """Test consistency between model and calculator implementations.""" # Get the model and calculator fixtures dynamically @@ -124,6 +125,8 @@ def test_model_calculator_consistency( def make_validate_model_outputs_test( model_fixture_name: str, + device: torch.device = DEVICE, + dtype: torch.dtype = torch.float64, ): """Factory function to create model output validation tests. @@ -132,11 +135,7 @@ def make_validate_model_outputs_test( model_fixture_name: Name of the model fixture to validate """ - def test_model_output_validation( - request: pytest.FixtureRequest, - device: torch.device, - dtype: torch.dtype, - ) -> None: + def test_model_output_validation(request: pytest.FixtureRequest) -> None: """Test that a model implementation follows the ModelInterface contract.""" # Get the model fixture dynamically model: ModelInterface = request.getfixturevalue(model_fixture_name) @@ -170,7 +169,7 @@ def test_model_output_validation( og_positions = sim_state.positions.clone() og_cell = sim_state.cell.clone() og_batch = sim_state.system_idx.clone() - og_atomic_numbers = sim_state.atomic_numbers.clone() + og_atomic_nums = sim_state.atomic_numbers.clone() model_output = model.forward(sim_state) @@ -178,7 +177,7 @@ def test_model_output_validation( assert torch.allclose(og_positions, sim_state.positions) assert torch.allclose(og_cell, sim_state.cell) assert torch.allclose(og_batch, sim_state.system_idx) - assert torch.allclose(og_atomic_numbers, sim_state.atomic_numbers) + assert torch.allclose(og_atomic_nums, sim_state.atomic_numbers) # assert model output has the correct keys assert "energy" in model_output diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index b04ad9c85..2a5c46fe0 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -1,18 +1,15 @@ -import os - import pytest import torch -from tests.models.conftest import ( - consistency_test_simstate_fixtures, - make_model_calculator_consistency_test, - make_validate_model_outputs_test, -) +import torch_sim as ts +from tests.conftest import DEVICE, DTYPE +from tests.models.conftest import make_validate_model_outputs_test try: - from fairchem.core import OCPCalculator - from fairchem.core.models.model_registry import model_name_to_local_file + from collections.abc import Callable + + from ase.build import bulk, fcc100, molecule from huggingface_hub.utils._auth import get_token from torch_sim.models.fairchem import FairChemModel @@ -21,77 +18,201 @@ pytest.skip("FairChem not installed", allow_module_level=True) -@pytest.fixture(scope="session") -def model_path_oc20(tmp_path_factory: pytest.TempPathFactory) -> str: - tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints") - model_name = "EquiformerV2-31M-S2EF-OC20-All+MD" - return model_name_to_local_file(model_name, local_cache=str(tmp_path)) - - @pytest.fixture -def eqv2_oc20_model_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel: - cpu = device.type == "cpu" - return FairChemModel(model=model_path_oc20, cpu=cpu, seed=0, pbc=True) +def eqv2_uma_model_pbc() -> FairChemModel: + """UMA model for periodic boundary condition systems.""" + cpu = DEVICE.type == "cpu" + return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu) -@pytest.fixture -def eqv2_oc20_model_non_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel: - cpu = device.type == "cpu" - return FairChemModel(model=model_path_oc20, cpu=cpu, seed=0, pbc=False) +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +@pytest.mark.parametrize("task_name", ["omat", "omol", "oc20"]) +def test_task_initialization(task_name: str) -> None: + """Test that different UMA task names work correctly.""" + model = FairChemModel(model=None, model_name="uma-s-1", task_name=task_name, cpu=True) + assert model.task_name + assert str(model.task_name.value) == task_name + assert hasattr(model, "predictor") -if get_token(): +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +@pytest.mark.parametrize( + ("task_name", "systems_func"), + [ + ( + "omat", + lambda: [ + bulk("Si", "diamond", a=5.43), + bulk("Al", "fcc", a=4.05), + bulk("Fe", "bcc", a=2.87), + bulk("Cu", "fcc", a=3.61), + ], + ), + ( + "omol", + lambda: [molecule("H2O"), molecule("CO2"), molecule("CH4"), molecule("NH3")], + ), + ], +) +def test_homogeneous_batching(task_name: str, systems_func: Callable) -> None: + """Test batching multiple systems with the same task.""" + systems = systems_func() + + # Add molecular properties for molecules + if task_name == "omol": + for mol in systems: + mol.info |= {"charge": 0, "spin": 1} + + model = FairChemModel( + model=None, model_name="uma-s-1", task_name=task_name, cpu=DEVICE.type == "cpu" + ) + state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) + results = model(state) + + # Check batch dimensions + assert results["energy"].shape == (4,) + assert results["forces"].shape[0] == sum(len(s) for s in systems) + assert results["forces"].shape[1] == 3 + + # Check that different systems have different energies + energies = results["energy"] + uniq_energies = torch.unique(energies, dim=0) + assert len(uniq_energies) > 1, "Different systems should have different energies" + + +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +def test_heterogeneous_tasks() -> None: + """Test different task types work with appropriate systems.""" + # Test molecule, material, and catalysis systems separately + test_cases = [ + ("omol", [molecule("H2O")]), + ("omat", [bulk("Pt", cubic=True)]), + ("oc20", [fcc100("Cu", (2, 2, 3), vacuum=8, periodic=True)]), + ] + + for task_name, systems in test_cases: + if task_name == "omol": + systems[0].info |= {"charge": 0, "spin": 1} + + model = FairChemModel( + model=None, + model_name="uma-s-1", + task_name=task_name, + cpu=DEVICE.type == "cpu", + ) + state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) + results = model(state) + + assert results["energy"].shape[0] == 1 + assert results["forces"].dim() == 2 + assert results["forces"].shape[1] == 3 + + +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +@pytest.mark.parametrize( + ("systems_func", "expected_count"), + [ + (lambda: [bulk("Si", "diamond", a=5.43)], 1), # Single system + ( + lambda: [ + bulk("H", "bcc", a=2.0), + bulk("Li", "bcc", a=3.0), + bulk("Si", "diamond", a=5.43), + bulk("Al", "fcc", a=4.05).repeat((2, 1, 1)), + ], + 4, + ), # Mixed sizes + ( + lambda: [ + bulk(element, "fcc", a=4.0) + for element in ("Al", "Cu", "Ni", "Pd", "Pt") * 3 + ], + 15, + ), # Large batch + ], +) +def test_batch_size_variations(systems_func: Callable, expected_count: int) -> None: + """Test batching with different numbers and sizes of systems.""" + systems = systems_func() - @pytest.fixture(scope="session") - def model_path_omat24(tmp_path_factory: pytest.TempPathFactory) -> str: - tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints") - model_name = "EquiformerV2-31M-OMAT24-MP-sAlex" - return model_name_to_local_file(model_name, local_cache=str(tmp_path)) + model = FairChemModel( + model=None, model_name="uma-s-1", task_name="omat", cpu=DEVICE.type == "cpu" + ) + state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) + results = model(state) - @pytest.fixture - def eqv2_omat24_model_pbc( - model_path_omat24: str, device: torch.device - ) -> FairChemModel: - cpu = device.type == "cpu" - return FairChemModel(model=model_path_omat24, cpu=cpu, seed=0, pbc=True) + assert results["energy"].shape == (expected_count,) + assert results["forces"].shape[0] == sum(len(s) for s in systems) + assert results["forces"].shape[1] == 3 + assert torch.isfinite(results["energy"]).all() + assert torch.isfinite(results["forces"]).all() -@pytest.fixture -def ocp_calculator(model_path_oc20: str) -> OCPCalculator: - return OCPCalculator(checkpoint_path=model_path_oc20, cpu=False, seed=0) - - -test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test( - test_name="fairchem_ocp", - model_fixture_name="eqv2_oc20_model_pbc", - calculator_fixture_name="ocp_calculator", - sim_state_names=consistency_test_simstate_fixtures[:-1], - energy_rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models - energy_atol=5e-4, - force_rtol=5e-4, - force_atol=5e-4, - stress_rtol=5e-4, - stress_atol=5e-4, +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" ) - -test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test( - test_name="fairchem_non_pbc_benzene", - model_fixture_name="eqv2_oc20_model_non_pbc", - calculator_fixture_name="ocp_calculator", - sim_state_names=["benzene_sim_state"], - energy_rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models - energy_atol=5e-4, - force_rtol=5e-4, - force_atol=5e-4, - stress_rtol=5e-4, - stress_atol=5e-4, +@pytest.mark.parametrize("compute_stress", [True, False]) +def test_stress_computation(*, compute_stress: bool) -> None: + """Test stress tensor computation.""" + systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)] + + model = FairChemModel( + model=None, + model_name="uma-s-1", + task_name="omat", + cpu=DEVICE.type == "cpu", + compute_stress=compute_stress, + ) + state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) + results = model(state) + + if compute_stress: + assert "stress" in results + assert results["stress"].shape == (2, 3, 3) + assert torch.isfinite(results["stress"]).all() + else: + assert "stress" not in results + + +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" ) +def test_device_consistency() -> None: + """Test device consistency between model and data.""" + cpu = DEVICE.type == "cpu" + + model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu) + system = bulk("Si", "diamond", a=5.43) + state = ts.io.atoms_to_state([system], device=DEVICE, dtype=DTYPE) + results = model(state) + assert results["energy"].device == DEVICE + assert results["forces"].device == DEVICE -# Skip this test due to issues with how the older models -# handled supercells (see related issue here: https://github.com/facebookresearch/fairchem/issues/428) -test_fairchem_ocp_model_outputs = pytest.mark.skipif( - os.environ.get("HF_TOKEN") is None, - reason="Issues in graph construction of older models", -)(make_validate_model_outputs_test(model_fixture_name="eqv2_omat24_model_pbc")) +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +def test_empty_batch_error() -> None: + """Test that empty batches raise appropriate errors.""" + model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=True) + with pytest.raises((ValueError, RuntimeError, IndexError)): + model(ts.io.atoms_to_state([], device=torch.device("cpu"), dtype=torch.float32)) + + +test_fairchem_uma_model_outputs = pytest.mark.skipif( + get_token() is None, + reason="Requires HuggingFace authentication for UMA model access", +)( + make_validate_model_outputs_test( + model_fixture_name="eqv2_uma_model_pbc", device=DEVICE, dtype=DTYPE + ) +) diff --git a/tests/models/test_graphpes.py b/tests/models/test_graphpes.py index 1470d3bec..6a48601ff 100644 --- a/tests/models/test_graphpes.py +++ b/tests/models/test_graphpes.py @@ -3,6 +3,7 @@ from ase.build import bulk, molecule import torch_sim as ts +from tests.conftest import DEVICE from tests.models.conftest import ( consistency_test_simstate_fixtures, make_model_calculator_consistency_test, @@ -18,15 +19,11 @@ except ImportError: pytest.skip("graph-pes not installed", allow_module_level=True) - -@pytest.fixture -def dtype() -> torch.dtype: - """Fixture to provide the default dtype for testing.""" - return torch.float32 +DTYPE = torch.float32 -def test_graphpes_isolated(device: torch.device): - # test that the raw model and torch_sim wrapper give the same results +def test_graphpes_isolated(): + # test that the raw model and torch-sim wrapper give the same results # for an isolated, unbatched structure water_atoms = molecule("H2O") @@ -38,20 +35,20 @@ def test_graphpes_isolated(device: torch.device): ts_model = GraphPESWrapper( gp_model, - device=device, - dtype=torch.float32, + device=DEVICE, + dtype=DTYPE, compute_forces=True, compute_stress=False, ) - ts_output = ts_model(ts.io.atoms_to_state([water_atoms], device, torch.float32)) + ts_output = ts_model(ts.io.atoms_to_state([water_atoms], DEVICE, DTYPE)) assert set(ts_output) == {"energy", "forces"} assert ts_output["energy"].shape == (1,) assert gp_energy.item() == pytest.approx(ts_output["energy"].item(), abs=1e-5) -def test_graphpes_periodic(device: torch.device): - # test that the raw model and torch_sim wrapper give the same results +def test_graphpes_periodic(): + # test that the raw model and torch-sim wrapper give the same results # for a periodic, unbatched structure bulk_atoms = bulk("Al", "hcp", a=4.05) @@ -63,12 +60,12 @@ def test_graphpes_periodic(device: torch.device): ts_model = GraphPESWrapper( gp_model, - device=device, - dtype=torch.float32, + device=DEVICE, + dtype=DTYPE, compute_forces=True, compute_stress=True, ) - ts_output = ts_model(ts.io.atoms_to_state([bulk_atoms], device, torch.float32)) + ts_output = ts_model(ts.io.atoms_to_state([bulk_atoms], DEVICE, DTYPE)) assert set(ts_output) == {"energy", "forces", "stress"} assert ts_output["energy"].shape == (1,) assert ts_output["forces"].shape == (len(bulk_atoms), 3) @@ -77,9 +74,9 @@ def test_graphpes_periodic(device: torch.device): torch.testing.assert_close(ts_output["forces"].to("cpu"), gp_forces) -def test_batching(device: torch.device): - # test that the raw model and torch_sim wrapper give the same results - # when batching is done via torch_sim's atoms_to_state function +def test_batching(): + # test that the raw model and torch-sim wrapper give the same results + # when batching is done via torch-sim's atoms_to_state function water = molecule("H2O") methane = molecule("CH4") @@ -94,12 +91,12 @@ def test_batching(device: torch.device): ts_model = GraphPESWrapper( gp_model, - device=device, - dtype=torch.float32, + device=DEVICE, + dtype=DTYPE, compute_forces=True, compute_stress=True, ) - ts_output = ts_model(ts.io.atoms_to_state(systems, device, torch.float32)) + ts_output = ts_model(ts.io.atoms_to_state(systems, DEVICE, DTYPE)) assert set(ts_output) == {"energy", "forces", "stress"} assert ts_output["energy"].shape == (2,) @@ -110,13 +107,13 @@ def test_batching(device: torch.device): @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) -def test_graphpes_dtype(device: torch.device, dtype: torch.dtype): +def test_graphpes_dtype(dtype: torch.dtype): water = molecule("H2O") model = SchNet() - ts_wrapper = GraphPESWrapper(model, device=device, dtype=dtype, compute_stress=False) - ts_output = ts_wrapper(ts.io.atoms_to_state([water], device, dtype)) + ts_wrapper = GraphPESWrapper(model, device=DEVICE, dtype=dtype, compute_stress=False) + ts_output = ts_wrapper(ts.io.atoms_to_state([water], DEVICE, dtype)) assert ts_output["energy"].dtype == dtype assert ts_output["forces"].dtype == dtype @@ -125,18 +122,15 @@ def test_graphpes_dtype(device: torch.device, dtype: torch.dtype): @pytest.fixture -def ts_nequip_model(device: torch.device, dtype: torch.dtype): +def ts_nequip_model(): return GraphPESWrapper( - _nequip_model, - device=device, - dtype=dtype, - compute_stress=False, + _nequip_model, device=DEVICE, dtype=DTYPE, compute_stress=False ) @pytest.fixture -def ase_nequip_calculator(device: torch.device, dtype: torch.dtype): - return _nequip_model.to(device, dtype).ase_calculator(skin=0.0) +def ase_nequip_calculator(): + return _nequip_model.to(DEVICE, DTYPE).ase_calculator(skin=0.0) test_graphpes_nequip_consistency = make_model_calculator_consistency_test( @@ -144,26 +138,34 @@ def ase_nequip_calculator(device: torch.device, dtype: torch.dtype): model_fixture_name="ts_nequip_model", calculator_fixture_name="ase_nequip_calculator", sim_state_names=consistency_test_simstate_fixtures, + device=DEVICE, + dtype=DTYPE, + energy_rtol=1e-3, + energy_atol=1e-3, + force_rtol=1e-3, + force_atol=1e-3, + stress_rtol=1e-3, + stress_atol=1e-3, ) test_graphpes_nequip_model_outputs = make_validate_model_outputs_test( - model_fixture_name="ts_nequip_model", + model_fixture_name="ts_nequip_model", device=DEVICE, dtype=DTYPE ) @pytest.fixture -def ts_mace_model(device: torch.device, dtype: torch.dtype): +def ts_mace_model(): return GraphPESWrapper( mace_mp("medium-mpa-0"), - device=device, - dtype=dtype, + device=DEVICE, + dtype=DTYPE, compute_stress=False, ) @pytest.fixture -def ase_mace_calculator(device: torch.device, dtype: torch.dtype): - return mace_mp("medium-mpa-0").to(device, dtype).ase_calculator(skin=0.0) +def ase_mace_calculator(): + return mace_mp("medium-mpa-0").to(DEVICE, DTYPE).ase_calculator(skin=0.0) test_graphpes_mace_consistency = make_model_calculator_consistency_test( @@ -171,10 +173,14 @@ def ase_mace_calculator(device: torch.device, dtype: torch.dtype): model_fixture_name="ts_mace_model", calculator_fixture_name="ase_mace_calculator", sim_state_names=consistency_test_simstate_fixtures, + device=DEVICE, + dtype=DTYPE, ) test_graphpes_mace_model_outputs = make_validate_model_outputs_test( model_fixture_name="ts_mace_model", + device=DEVICE, + dtype=DTYPE, ) @@ -182,18 +188,13 @@ def ase_mace_calculator(device: torch.device, dtype: torch.dtype): @pytest.fixture -def ts_lj_model(device: torch.device, dtype: torch.dtype): - return GraphPESWrapper( - _lj_model, - device=device, - dtype=dtype, - compute_stress=False, - ) +def ts_lj_model(): + return GraphPESWrapper(_lj_model, device=DEVICE, dtype=DTYPE, compute_stress=False) @pytest.fixture -def ase_lj_calculator(device: torch.device, dtype: torch.dtype): - return _lj_model.to(device, dtype).ase_calculator(skin=0.0) +def ase_lj_calculator(): + return _lj_model.to(DEVICE, DTYPE).ase_calculator(skin=0.0) test_graphpes_lj_consistency = make_model_calculator_consistency_test( @@ -201,4 +202,6 @@ def ase_lj_calculator(device: torch.device, dtype: torch.dtype): model_fixture_name="ts_lj_model", calculator_fixture_name="ase_lj_calculator", sim_state_names=consistency_test_simstate_fixtures, + device=DEVICE, + dtype=DTYPE, ) diff --git a/tests/models/test_lennard_jones.py b/tests/models/test_lennard_jones.py index 205603cf1..52435a754 100644 --- a/tests/models/test_lennard_jones.py +++ b/tests/models/test_lennard_jones.py @@ -1,10 +1,11 @@ -"""Cheap integration tests ensuring different parts of torchsim work together.""" +"""Cheap integration tests ensuring different parts of TorchSim work together.""" import pytest import torch from ase.build import bulk import torch_sim as ts +from tests.conftest import DEVICE from torch_sim.models.interface import validate_model_outputs from torch_sim.models.lennard_jones import ( LennardJonesModel, @@ -136,11 +137,11 @@ def test_lennard_jones_force_energy_consistency() -> None: # is not used in the neighbor list calculation. So to get correct results, # we need a system that is large enough (2*cutoff). @pytest.fixture -def ar_supercell_sim_state_large(device: torch.device) -> ts.SimState: +def ar_supercell_sim_state_large() -> ts.SimState: """Create a face-centered cubic (FCC) Argon structure.""" # Create FCC Ar using ASE, with 4x4x4 supercell ar_atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([4, 4, 4]) - return ts.io.atoms_to_state(ar_atoms, device, torch.float64) + return ts.io.atoms_to_state(ar_atoms, DEVICE, torch.float64) @pytest.fixture @@ -230,9 +231,6 @@ def test_stress_tensor_symmetry( assert torch.allclose(stress_tensor, stress_tensor.T, atol=1e-10) -def test_validate_model_outputs( - lj_model: LennardJonesModel, - device: torch.device, -) -> None: +def test_validate_model_outputs(lj_model: LennardJonesModel) -> None: """Test that the model outputs are valid.""" - validate_model_outputs(lj_model, device, torch.float64) + validate_model_outputs(lj_model, DEVICE, torch.float64) diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 427ef0647..50246ecb0 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -3,6 +3,7 @@ from ase.atoms import Atoms import torch_sim as ts +from tests.conftest import DEVICE from tests.models.conftest import ( consistency_test_simstate_fixtures, make_model_calculator_consistency_test, @@ -20,32 +21,25 @@ pytest.skip("MACE not installed", allow_module_level=True) -mace_model = mace_mp(model=MaceUrls.mace_mp_small, return_raw_model=True) -mace_off_model = mace_off(model=MaceUrls.mace_off_small, return_raw_model=True) - - -@pytest.fixture -def dtype() -> torch.dtype: - """Fixture to provide the default dtype for testing.""" - return torch.float32 +raw_mace_mp = mace_mp(model=MaceUrls.mace_mp_small, return_raw_model=True) +raw_mace_off = mace_off(model=MaceUrls.mace_off_small, return_raw_model=True) +DTYPE = torch.float32 @pytest.fixture def ase_mace_calculator() -> MACECalculator: + dtype = str(DTYPE).lstrip("torch.") return mace_mp( - model=MaceUrls.mace_mp_small, - device="cpu", - default_dtype="float32", - dispersion=False, + model=MaceUrls.mace_mp_small, device="cpu", default_dtype=dtype, dispersion=False ) @pytest.fixture -def torchsim_mace_model(device: torch.device, dtype: torch.dtype) -> MaceModel: +def ts_mace_model() -> MaceModel: return MaceModel( - model=mace_model, - device=device, - dtype=dtype, + model=raw_mace_mp, + device=DEVICE, + dtype=DTYPE, compute_forces=True, compute_stress=True, ) @@ -53,36 +47,32 @@ def torchsim_mace_model(device: torch.device, dtype: torch.dtype) -> MaceModel: test_mace_consistency = make_model_calculator_consistency_test( test_name="mace", - model_fixture_name="torchsim_mace_model", + model_fixture_name="ts_mace_model", calculator_fixture_name="ase_mace_calculator", sim_state_names=consistency_test_simstate_fixtures, + dtype=DTYPE, ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) -def test_mace_dtype_working( - si_atoms: Atoms, dtype: torch.dtype, device: torch.device -) -> None: +def test_mace_dtype_working(si_atoms: Atoms, dtype: torch.dtype) -> None: model = MaceModel( - model=mace_model, - device=device, + model=raw_mace_mp, + device=DEVICE, dtype=dtype, compute_forces=True, ) - state = ts.io.atoms_to_state([si_atoms], device, dtype) - + state = ts.io.atoms_to_state([si_atoms], DEVICE, dtype) model.forward(state) @pytest.fixture -def benzene_system( - benzene_atoms: Atoms, device: torch.device, dtype: torch.dtype -) -> dict: +def benzene_system(benzene_atoms: Atoms) -> dict: atomic_numbers = benzene_atoms.get_atomic_numbers() - positions = torch.tensor(benzene_atoms.positions, device=device, dtype=dtype) - cell = torch.tensor(benzene_atoms.cell.array, device=device, dtype=dtype) + positions = torch.tensor(benzene_atoms.positions, device=DEVICE, dtype=DTYPE) + cell = torch.tensor(benzene_atoms.cell.array, device=DEVICE, dtype=DTYPE) return { "positions": positions, @@ -96,46 +86,35 @@ def benzene_system( def ase_mace_off_calculator() -> MACECalculator: return mace_off( model=MaceUrls.mace_off_small, - device="cpu", - default_dtype="float32", + device=str(DEVICE), + default_dtype=str(DTYPE).lstrip("torch."), dispersion=False, ) @pytest.fixture -def torchsim_mace_off_model(device: torch.device, dtype: torch.dtype) -> MaceModel: - return MaceModel( - model=mace_off_model, - device=device, - dtype=dtype, - compute_forces=True, - ) +def ts_mace_off_model() -> MaceModel: + return MaceModel(model=raw_mace_off, device=DEVICE, dtype=DTYPE, compute_forces=True) test_mace_off_consistency = make_model_calculator_consistency_test( test_name="mace_off", - model_fixture_name="torchsim_mace_off_model", + model_fixture_name="ts_mace_off_model", calculator_fixture_name="ase_mace_off_calculator", - sim_state_names=["benzene_sim_state"], + sim_state_names=("benzene_sim_state",), + dtype=DTYPE, ) test_mace_off_model_outputs = make_validate_model_outputs_test( - model_fixture_name="torchsim_mace_model" + model_fixture_name="ts_mace_model", dtype=DTYPE ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) -def test_mace_off_dtype_working( - benzene_atoms: Atoms, dtype: torch.dtype, device: torch.device -) -> None: - model = MaceModel( - model=mace_off_model, - device=device, - dtype=dtype, - compute_forces=True, - ) +def test_mace_off_dtype_working(benzene_atoms: Atoms, dtype: torch.dtype) -> None: + model = MaceModel(model=raw_mace_off, device=DEVICE, dtype=dtype, compute_forces=True) - state = ts.io.atoms_to_state([benzene_atoms], device, dtype) + state = ts.io.atoms_to_state([benzene_atoms], DEVICE, dtype) model.forward(state) diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index a137ed788..57a6a521c 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -3,8 +3,8 @@ import ase.spacegroup import ase.units import pytest -import torch +from tests.conftest import DEVICE from tests.models.conftest import ( consistency_test_simstate_fixtures, make_model_calculator_consistency_test, @@ -21,59 +21,36 @@ pytest.skip("mattersim not installed", allow_module_level=True) -@pytest.fixture -def dtype() -> torch.dtype: - """Fixture to provide the default dtype for testing.""" - return torch.float32 - - -@pytest.fixture -def model_name() -> str: - """Fixture to provide the model name for testing. Load smaller 1M model - for testing purposes. - """ - return "mattersim-v1.0.0-1m.pth" +model_name = "mattersim-v1.0.0-1m.pth" @pytest.fixture -def pretrained_mattersim_model(device: torch.device, model_name: str): +def pretrained_mattersim_model(): """Load a pretrained MatterSim model for testing.""" return Potential.from_checkpoint( load_path=model_name, model_name="m3gnet", - device=device, + device=DEVICE, load_training_state=False, ) @pytest.fixture -def mattersim_model( - pretrained_mattersim_model: Potential, device: torch.device -) -> MatterSimModel: +def mattersim_model(pretrained_mattersim_model: Potential) -> MatterSimModel: """Create an MatterSimModel wrapper for the pretrained model.""" - return MatterSimModel( - model=pretrained_mattersim_model, - device=device, - ) + return MatterSimModel(model=pretrained_mattersim_model, device=DEVICE) @pytest.fixture -def mattersim_calculator( - pretrained_mattersim_model: Potential, device: torch.device -) -> MatterSimCalculator: +def mattersim_calculator(pretrained_mattersim_model: Potential) -> MatterSimCalculator: """Create an MatterSimCalculator for the pretrained model.""" - return MatterSimCalculator(pretrained_mattersim_model, device=device) + return MatterSimCalculator(pretrained_mattersim_model, device=DEVICE) -def test_mattersim_initialization( - pretrained_mattersim_model: Potential, device: torch.device -) -> None: +def test_mattersim_initialization(pretrained_mattersim_model: Potential) -> None: """Test that the MatterSim model initializes correctly.""" - model = MatterSimModel( - model=pretrained_mattersim_model, - device=device, - ) - assert model._device == device # noqa: SLF001 + model = MatterSimModel(model=pretrained_mattersim_model, device=DEVICE) + assert model.device == DEVICE assert model.stress_weight == ase.units.GPa diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index f467e4e75..7145bbd94 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -1,6 +1,7 @@ import pytest import torch +from tests.conftest import DEVICE from tests.models.conftest import ( consistency_test_simstate_fixtures, make_model_calculator_consistency_test, @@ -18,38 +19,32 @@ @pytest.fixture -def dtype() -> torch.dtype: - """Fixture to provide the default dtype for testing.""" - return torch.float32 - - -@pytest.fixture -def metatomic_calculator(device: torch.device): +def metatomic_calculator(): """Load a pretrained metatomic model for testing.""" return ase_calculator.MetatomicCalculator( model=load_model( "https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt" ).export(), - device=device, + device=DEVICE, ) @pytest.fixture -def metatomic_model(device: torch.device) -> MetatomicModel: +def metatomic_model() -> MetatomicModel: """Create an MetatomicModel wrapper for the pretrained model.""" return MetatomicModel( model="pet-mad", - device=device, + device=DEVICE, ) -def test_metatomic_initialization(device: torch.device) -> None: +def test_metatomic_initialization() -> None: """Test that the metatomic model initializes correctly.""" model = MetatomicModel( model="pet-mad", - device=device, + device=DEVICE, ) - assert model.device == device + assert model.device == DEVICE assert model.dtype == torch.float32 diff --git a/tests/models/test_nequip_framework.py b/tests/models/test_nequip_framework.py new file mode 100644 index 000000000..97c5fef2f --- /dev/null +++ b/tests/models/test_nequip_framework.py @@ -0,0 +1,81 @@ +import urllib.request +from enum import StrEnum +from pathlib import Path + +import pytest + +from tests.conftest import DEVICE +from tests.models.conftest import make_model_calculator_consistency_test + + +try: + from nequip.ase import NequIPCalculator + + from torch_sim.models.nequip_framework import ( + NequIPFrameworkModel, + from_compiled_model, + ) +except ImportError: + pytest.skip("nequip not installed", allow_module_level=True) + + +class NequIPUrls(StrEnum): + """Checkpoint download URLs for NequIP models.""" + + Si = "https://github.com/abhijeetgangan/pt_model_checkpoints/raw/refs/heads/main/nequip/Si.nequip.pth" + + +@pytest.fixture(scope="session") +def model_path_nequip(tmp_path_factory: pytest.TempPathFactory) -> Path: + tmp_path = tmp_path_factory.mktemp("nequip_checkpoints") + model_name = "Si.nequip.pth" + model_path = Path(tmp_path) / model_name + + if not model_path.is_file(): + urllib.request.urlretrieve(NequIPUrls.Si, model_path) # noqa: S310 + + return model_path + + +@pytest.fixture +def nequip_model(model_path_nequip: Path) -> NequIPFrameworkModel: + """Create an NequIPModel wrapper for the pretrained model.""" + compiled_model, (r_max, type_names) = from_compiled_model( + model_path_nequip, device=DEVICE + ) + return NequIPFrameworkModel( + model=compiled_model, + r_max=r_max, + type_names=type_names, + device=DEVICE, + ) + + +@pytest.fixture +def nequip_calculator(model_path_nequip: Path) -> NequIPCalculator: + """Create an NequIPCalculator for the pretrained model.""" + return NequIPCalculator.from_compiled_model(str(model_path_nequip), device=DEVICE) + + +def test_nequip_initialization(model_path_nequip: Path) -> None: + """Test that the NequIP model initializes correctly.""" + compiled_model, (r_max, type_names) = from_compiled_model( + model_path_nequip, device=DEVICE + ) + model = NequIPFrameworkModel( + model=compiled_model, + r_max=r_max, + type_names=type_names, + device=DEVICE, + ) + assert model._device == DEVICE # noqa: SLF001 + + +test_nequip_consistency = make_model_calculator_consistency_test( + test_name="nequip", + model_fixture_name="nequip_model", + calculator_fixture_name="nequip_calculator", + sim_state_names=("si_sim_state", "rattled_si_sim_state"), +) + +# TODO (AG): Test multi element models diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index 5c75d4bdc..c6559fb1c 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -1,6 +1,6 @@ import pytest -import torch +from tests.conftest import DEVICE from tests.models.conftest import ( consistency_test_simstate_fixtures, make_model_calculator_consistency_test, @@ -18,41 +18,33 @@ @pytest.fixture -def orbv3_conservative_inf_omat_model(device: torch.device) -> OrbModel: +def orbv3_conservative_inf_omat_model() -> OrbModel: orb_ff = pretrained.orb_v3_conservative_inf_omat( - device=device, - precision="float32-high", + device=DEVICE, precision="float32-high" ) - return OrbModel(model=orb_ff, device=device) + return OrbModel(model=orb_ff, device=DEVICE) @pytest.fixture -def orbv3_direct_20_omat_model(device: torch.device) -> OrbModel: - orb_ff = pretrained.orb_v3_direct_20_omat( - device=device, - precision="float32-high", - ) - return OrbModel(model=orb_ff, device=device) +def orbv3_direct_20_omat_model() -> OrbModel: + orb_ff = pretrained.orb_v3_direct_20_omat(device=DEVICE, precision="float32-high") + return OrbModel(model=orb_ff, device=DEVICE) @pytest.fixture -def orbv3_conservative_inf_omat_calculator(device: torch.device) -> ORBCalculator: +def orbv3_conservative_inf_omat_calculator() -> ORBCalculator: """Create an ORBCalculator for the pretrained model.""" orb_ff = pretrained.orb_v3_conservative_inf_omat( - device=device, - precision="float32-high", + device=DEVICE, precision="float32-high" ) - return ORBCalculator(model=orb_ff, device=device) + return ORBCalculator(model=orb_ff, device=DEVICE) @pytest.fixture -def orbv3_direct_20_omat_calculator(device: torch.device) -> ORBCalculator: +def orbv3_direct_20_omat_calculator() -> ORBCalculator: """Create an ORBCalculator for the pretrained model.""" - orb_ff = pretrained.orb_v3_direct_20_omat( - device=device, - precision="float32-high", - ) - return ORBCalculator(model=orb_ff, device=device) + orb_ff = pretrained.orb_v3_direct_20_omat(device=DEVICE, precision="float32-high") + return ORBCalculator(model=orb_ff, device=DEVICE) test_orb_conservative_consistency = make_model_calculator_consistency_test( @@ -60,8 +52,8 @@ def orbv3_direct_20_omat_calculator(device: torch.device) -> ORBCalculator: model_fixture_name="orbv3_conservative_inf_omat_model", calculator_fixture_name="orbv3_conservative_inf_omat_calculator", sim_state_names=consistency_test_simstate_fixtures, - energy_rtol=5e-5, - energy_atol=5e-5, + energy_rtol=1e-3, + energy_atol=1e-3, ) test_orb_direct_consistency = make_model_calculator_consistency_test( @@ -69,8 +61,8 @@ def orbv3_direct_20_omat_calculator(device: torch.device) -> ORBCalculator: model_fixture_name="orbv3_direct_20_omat_model", calculator_fixture_name="orbv3_direct_20_omat_calculator", sim_state_names=consistency_test_simstate_fixtures, - energy_rtol=5e-5, - energy_atol=5e-5, + energy_rtol=1e-3, + energy_atol=1e-3, ) test_validate_conservative_model_outputs = make_validate_model_outputs_test( diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index 25bd310a9..482dde114 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -1,6 +1,7 @@ import pytest import torch +from tests.conftest import DEVICE from tests.models.conftest import ( consistency_test_simstate_fixtures, make_model_calculator_consistency_test, @@ -19,26 +20,13 @@ pytest.skip("sevenn not installed", allow_module_level=True) -@pytest.fixture -def dtype() -> torch.dtype: - """Fixture to provide the default dtype for testing.""" - return torch.float32 - - -@pytest.fixture -def model_name() -> str: - """Fixture to provide the model name for testing.""" - return "sevennet-mf-ompa" +model_name = "sevennet-mf-ompa" +modal_name = "mpa" +DTYPE = torch.float32 @pytest.fixture -def modal_name() -> str: - """Fixture to provide the modal name for testing.""" - return "mpa" - - -@pytest.fixture -def pretrained_sevenn_model(device: torch.device, model_name: str): +def pretrained_sevenn_model(): """Load a pretrained SevenNet model for testing.""" cp = sevenn.util.load_checkpoint(model_name) @@ -46,41 +34,27 @@ def pretrained_sevenn_model(device: torch.device, model_name: str): model_loaded = cp.build_model(backend) model_loaded.set_is_batch_data(True) - return model_loaded.to(device) + return model_loaded.to(DEVICE) @pytest.fixture -def sevenn_model( - pretrained_sevenn_model: AtomGraphSequential, device: torch.device, modal_name: str -) -> SevenNetModel: +def sevenn_model(pretrained_sevenn_model: AtomGraphSequential) -> SevenNetModel: """Create an SevenNetModel wrapper for the pretrained model.""" - return SevenNetModel( - model=pretrained_sevenn_model, - modal=modal_name, - device=device, - ) + return SevenNetModel(model=pretrained_sevenn_model, modal=modal_name, device=DEVICE) @pytest.fixture -def sevenn_calculator( - device: torch.device, model_name: str, modal_name: str -) -> SevenNetCalculator: +def sevenn_calculator() -> SevenNetCalculator: """Create an SevenNetCalculator for the pretrained model.""" - return SevenNetCalculator(model_name, modal=modal_name, device=device) + return SevenNetCalculator(model_name, modal=modal_name, device=DEVICE) -def test_sevennet_initialization( - pretrained_sevenn_model: AtomGraphSequential, device: torch.device -) -> None: +def test_sevennet_initialization(pretrained_sevenn_model: AtomGraphSequential) -> None: """Test that the SevenNet model initializes correctly.""" - model = SevenNetModel( - model=pretrained_sevenn_model, - modal="omat24", - device=device, - ) + model = SevenNetModel(model=pretrained_sevenn_model, modal="omat24", device=DEVICE) # Check that properties were set correctly assert model.modal == "omat24" - assert model._device == device # noqa: SLF001 + assert model.device == DEVICE # NOTE: we take [:-1] to skipbenzene due to eps volume giving numerically @@ -90,9 +64,10 @@ def test_sevennet_initialization( model_fixture_name="sevenn_model", calculator_fixture_name="sevenn_calculator", sim_state_names=consistency_test_simstate_fixtures[:-1], + dtype=DTYPE, ) test_sevennet_model_outputs = make_validate_model_outputs_test( - model_fixture_name="sevenn_model", + model_fixture_name="sevenn_model", dtype=DTYPE ) diff --git a/tests/models/test_soft_sphere.py b/tests/models/test_soft_sphere.py index f14cf833f..1f89b8f48 100644 --- a/tests/models/test_soft_sphere.py +++ b/tests/models/test_soft_sphere.py @@ -1,10 +1,11 @@ -"""Tests for soft sphere models ensuring different parts of torchsim work together.""" +"""Tests for soft sphere models ensuring different parts of TorchSim work together.""" import pytest import torch import torch_sim as ts -import torch_sim.models.soft_sphere as ss +import torch_sim.models.soft_sphere as fss +from tests.conftest import DEVICE from torch_sim.models.interface import validate_model_outputs @@ -22,8 +23,8 @@ def models( "compute_stress": True, } - model_nl = ss.SoftSphereModel(use_neighbor_list=True, **calc_params) - model_direct = ss.SoftSphereModel(use_neighbor_list=False, **calc_params) + model_nl = fss.SoftSphereModel(use_neighbor_list=True, **calc_params) + model_direct = fss.SoftSphereModel(use_neighbor_list=False, **calc_params) return model_nl(fe_supercell_sim_state), model_direct(fe_supercell_sim_state) @@ -44,8 +45,8 @@ def models_with_per_atom( "per_atom_stresses": True, } - model_nl = ss.SoftSphereModel(use_neighbor_list=True, **calc_params) - model_direct = ss.SoftSphereModel(use_neighbor_list=False, **calc_params) + model_nl = fss.SoftSphereModel(use_neighbor_list=True, **calc_params) + model_direct = fss.SoftSphereModel(use_neighbor_list=False, **calc_params) return model_nl(fe_supercell_sim_state), model_direct(fe_supercell_sim_state) @@ -68,7 +69,7 @@ def small_sim_state(small_system: tuple[torch.Tensor, torch.Tensor]) -> ts.SimSt return ts.SimState( positions=positions, cell=cell, - pbc=torch.tensor([True, True, True]), + pbc=True, masses=torch.ones(positions.shape[0], dtype=torch.float64), atomic_numbers=torch.ones(positions.shape[0], dtype=torch.long), ) @@ -110,7 +111,7 @@ def test_stress_tensor_symmetry( assert torch.allclose(results_nl["stress"], results_nl["stress"].T, atol=1e-10) -def test_validate_model_outputs(device: torch.device) -> None: +def test_validate_model_outputs() -> None: """Test that the model outputs are valid.""" model_params = { "sigma": 3.405, # Γ…, typical for Ar @@ -121,10 +122,10 @@ def test_validate_model_outputs(device: torch.device) -> None: "compute_stress": True, } - model_nl = ss.SoftSphereModel(use_neighbor_list=True, **model_params) - model_direct = ss.SoftSphereModel(use_neighbor_list=False, **model_params) - for out in [model_nl, model_direct]: - validate_model_outputs(out, device, torch.float64) + model_nl = fss.SoftSphereModel(use_neighbor_list=True, **model_params) + model_direct = fss.SoftSphereModel(use_neighbor_list=False, **model_params) + for out in (model_nl, model_direct): + validate_model_outputs(out, DEVICE, torch.float64) @pytest.mark.parametrize( @@ -165,7 +166,7 @@ def test_soft_sphere_pair_single( distance: float, sigma: float, epsilon: float, alpha: float, expected: float ) -> None: """Test the soft sphere pair calculation for single values.""" - energy = ss.soft_sphere_pair( + energy = fss.soft_sphere_pair( torch.tensor(distance), torch.tensor(sigma), torch.tensor(epsilon), @@ -176,13 +177,13 @@ def test_soft_sphere_pair_single( def test_model_initialization_defaults() -> None: """Test initialization with default parameters.""" - model = ss.SoftSphereModel() + model = fss.SoftSphereModel() # Check default parameters are used - assert torch.allclose(model.sigma, ss.DEFAULT_SIGMA) - assert torch.allclose(model.epsilon, ss.DEFAULT_EPSILON) - assert torch.allclose(model.alpha, ss.DEFAULT_ALPHA) - assert torch.allclose(model.cutoff, ss.DEFAULT_SIGMA) # Default cutoff is sigma + assert torch.allclose(model.sigma, fss.DEFAULT_SIGMA) + assert torch.allclose(model.epsilon, fss.DEFAULT_EPSILON) + assert torch.allclose(model.alpha, fss.DEFAULT_ALPHA) + assert torch.allclose(model.cutoff, fss.DEFAULT_SIGMA) # Default cutoff is sigma @pytest.mark.parametrize( @@ -199,7 +200,7 @@ def test_model_initialization_custom_params( ) -> None: """Test initialization with custom parameters.""" params = {param_name: param_value, "dtype": expected_dtype} - model = ss.SoftSphereModel(**params) + model = fss.SoftSphereModel(**params) param_tensor = getattr(model, param_name) assert torch.allclose(param_tensor, torch.tensor(param_value, dtype=expected_dtype)) @@ -218,7 +219,7 @@ def test_model_initialization_custom_params( ) def test_model_initialization_custom_flags(*, flag_name: str, flag_value: bool) -> None: """Test initialization with custom flags.""" - model = ss.SoftSphereModel(**{flag_name: flag_value}) + model = fss.SoftSphereModel(**{flag_name: flag_value}) # For compute_forces and compute_stress, we need to check the private attributes if flag_name == "compute_forces": @@ -232,7 +233,7 @@ def test_model_initialization_custom_flags(*, flag_name: str, flag_value: bool) @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_model_dtype(dtype: torch.dtype) -> None: """Test model with different dtypes.""" - model = ss.SoftSphereModel(dtype=dtype) + model = fss.SoftSphereModel(dtype=dtype) assert model.sigma.dtype == dtype assert model.epsilon.dtype == dtype @@ -244,7 +245,8 @@ def test_multispecies_initialization_defaults() -> None: """Test initialization of multi-species model with defaults.""" # Create with minimal parameters species = torch.tensor([0, 1], dtype=torch.long) - model = ss.SoftSphereMultiModel(species=species) + dtype = torch.float32 + model = fss.SoftSphereMultiModel(species=species, dtype=dtype) # Check matrices are created with defaults assert model.sigma_matrix.shape == (2, 2) @@ -252,12 +254,13 @@ def test_multispecies_initialization_defaults() -> None: assert model.alpha_matrix.shape == (2, 2) # Check default values - assert torch.allclose(model.sigma_matrix, ss.DEFAULT_SIGMA * torch.ones(2, 2)) - assert torch.allclose(model.epsilon_matrix, ss.DEFAULT_EPSILON * torch.ones(2, 2)) - assert torch.allclose(model.alpha_matrix, ss.DEFAULT_ALPHA * torch.ones(2, 2)) + ones = torch.ones(2, 2, dtype=dtype) + assert torch.allclose(model.sigma_matrix, fss.DEFAULT_SIGMA * ones) + assert torch.allclose(model.epsilon_matrix, fss.DEFAULT_EPSILON * ones) + assert torch.allclose(model.alpha_matrix, fss.DEFAULT_ALPHA * ones) # Check cutoff is max sigma - assert model.cutoff.item() == ss.DEFAULT_SIGMA.item() + assert model.cutoff.item() == fss.DEFAULT_SIGMA.item() def test_multispecies_initialization_custom() -> None: @@ -267,7 +270,7 @@ def test_multispecies_initialization_custom() -> None: epsilon_matrix = torch.tensor([[1.0, 0.5], [0.5, 1.5]], dtype=torch.float64) alpha_matrix = torch.tensor([[2.0, 3.0], [3.0, 4.0]], dtype=torch.float64) - model = ss.SoftSphereMultiModel( + model = fss.SoftSphereMultiModel( species=species, sigma_matrix=sigma_matrix, epsilon_matrix=epsilon_matrix, @@ -295,7 +298,7 @@ def test_multispecies_matrix_validation() -> None: # Should raise ValueError due to matrix size mismatch with pytest.raises(ValueError, match="sigma_matrix must have shape"): - ss.SoftSphereMultiModel( + fss.SoftSphereMultiModel( species=species, sigma_matrix=sigma_matrix, epsilon_matrix=epsilon_matrix, @@ -329,7 +332,7 @@ def test_matrix_symmetry_validation(matrix_name: str, matrix: torch.Tensor) -> N # Should raise ValueError due to asymmetric matrix with pytest.raises(ValueError, match="is not symmetric"): - ss.SoftSphereMultiModel(**params) + fss.SoftSphereMultiModel(**params) def test_multispecies_cutoff_default() -> None: @@ -338,7 +341,7 @@ def test_multispecies_cutoff_default() -> None: species = torch.tensor([0, 1, 2], dtype=torch.long) sigma_matrix = torch.tensor([[1.0, 1.5, 2.0], [1.5, 2.0, 2.5], [2.0, 2.5, 3.0]]) - model = ss.SoftSphereMultiModel(species=species, sigma_matrix=sigma_matrix) + model = fss.SoftSphereMultiModel(species=species, sigma_matrix=sigma_matrix) # Cutoff should default to max value in sigma_matrix assert model.cutoff.item() == 3.0 @@ -361,7 +364,7 @@ def test_multispecies_model_flags(*, flag_name: str, flag_value: bool) -> None: """Test flags of the SoftSphereMultiModel.""" species = torch.tensor([0, 1], dtype=torch.long) - model = ss.SoftSphereMultiModel(species=species, **{flag_name: flag_value}) + model = fss.SoftSphereMultiModel(species=species, **{flag_name: flag_value}) # For SoftSphereMultiModel, we don't need to convert attribute names # as it uses public attribute names for all flags diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 81d904721..dd0ff94ed 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -12,7 +12,6 @@ to_constant_volume_bins, ) from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.optimizers import unit_cell_fire def test_exact_fit(): @@ -24,8 +23,9 @@ def test_exact_fit(): def test_weight_pos(): values = [[1, "x"], [2, "y"], [1, "z"]] bins = to_constant_volume_bins(values, 2, weight_pos=0) - for bin_ in bins: - for item in bin_: + for vol_bin in bins: + for item in vol_bin: + assert isinstance(item, list) assert isinstance(item[0], int) assert isinstance(item[1], str) @@ -34,8 +34,9 @@ def test_key_func(): values = [{"x": "a", "y": 1}, {"x": "b", "y": 5}, {"x": "b", "y": 3}] bins = to_constant_volume_bins(values, 2, key=lambda x: x["y"]) - for bin_ in bins: - for item in bin_: + for vol_bin in bins: + for item in vol_bin: + assert isinstance(item, dict) assert "x" in item assert "y" in item @@ -147,7 +148,7 @@ def test_binning_auto_batcher( assert batcher.memory_scalers[1] == fe_supercell_sim_state.n_atoms # Get batches until None is returned - batches = list(batcher) + batches = [batch for batch, _ in batcher] # Check we got the expected number of systems assert len(batches) == len(batcher.batched_states) @@ -172,9 +173,9 @@ def test_binning_auto_batcher_auto_metric( monkeypatch: pytest.MonkeyPatch, ) -> None: """Test BinningAutoBatcher with different states.""" - # monkeypath determine max memory scaler + # monkeypatch determine max memory scaler monkeypatch.setattr( - "torch_sim.autobatching.determine_max_batch_size", + "ts.autobatching.determine_max_batch_size", lambda *args, **kwargs: 50, # noqa: ARG005 ) @@ -194,7 +195,7 @@ def test_binning_auto_batcher_auto_metric( assert batcher.memory_scalers[1] == fe_supercell_sim_state.n_atoms # Get batches until None is returned - batches = list(batcher) + batches = [batch for batch, _ in batcher] # Check we got the expected number of batches assert len(batches) == len(batcher.batched_states) @@ -217,18 +218,17 @@ def test_binning_auto_batcher_with_indices( fe_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, ) -> None: - """Test BinningAutoBatcher with return_indices=True.""" + """Test BinningAutoBatcher with indices tracking.""" states = [si_sim_state, fe_supercell_sim_state] batcher = BinningAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260.0, - return_indices=True, ) batcher.load_states(states) - # Get batches with indices + # Get batches and track indices manually batches_with_indices = [] for batch, indices in batcher: batches_with_indices.append((batch, indices)) @@ -237,8 +237,8 @@ def test_binning_auto_batcher_with_indices( assert len(batches_with_indices) == len(batcher.batched_states) # Check that the indices match the expected bin indices - for i, (_, indices) in enumerate(batches_with_indices): - assert indices == batcher.index_bins[i] + for idx, (_, indices) in enumerate(batches_with_indices): + assert indices == batcher.index_bins[idx] def test_binning_auto_batcher_restore_order_with_split_states( @@ -258,14 +258,9 @@ def test_binning_auto_batcher_restore_order_with_split_states( ) batcher.load_states(states) - # Get batches until None is returned + # loop through all batches to test we're restore order correctly batches = [] - while True: - batch = batcher.next_batch() - if batch is None: - break - # Split each batch into individual states to simulate processing - # split_batch = split_state(batch) + for batch, _indices in batcher: batches.append(batch) # Test restore_original_order with split states @@ -318,23 +313,21 @@ def test_in_flight_auto_batcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260, # Set a small value to force multiple batches - return_indices=True, ) batcher.load_states(states) # Get the first batch - first_batch, [], _ = batcher.next_batch(states, None) + first_batch, [] = batcher.next_batch(states, None) assert isinstance(first_batch, ts.SimState) # Create a convergence tensor where the first state has converged convergence = torch.tensor([True]) # Get the next batch - next_batch, popped_batch, idx = batcher.next_batch(first_batch, convergence) + next_batch, popped_batch = batcher.next_batch(first_batch, convergence) assert isinstance(next_batch, ts.SimState) assert isinstance(popped_batch, list) assert isinstance(popped_batch[0], ts.SimState) - assert idx == [1] # Check that the converged state was removed assert len(batcher.current_scalers) == 1 @@ -345,7 +338,7 @@ def test_in_flight_auto_batcher( convergence = torch.tensor([True]) # Get the next batch, which should be None since all states have converged - final_batch, popped_batch, _ = batcher.next_batch(next_batch, convergence) + final_batch, popped_batch = batcher.next_batch(next_batch, convergence) assert final_batch is None # Check that all states are marked as completed @@ -363,9 +356,7 @@ def test_determine_max_batch_size_fibonacci( def mock_measure(*_args: Any, **_kwargs: Any) -> float: return 0.1 # Return a small constant memory usage - monkeypatch.setattr( - "torch_sim.autobatching.measure_model_memory_forward", mock_measure - ) + monkeypatch.setattr("ts.autobatching.measure_model_memory_forward", mock_measure) # Test with a small max_atoms value to limit the sequence max_size = determine_max_batch_size(si_sim_state, lj_model, max_atoms=10) @@ -384,9 +375,7 @@ def test_determine_max_batch_size_small_scale_factor_no_infinite_loop( scale_factor: float, ) -> None: """Test determine_max_batch_size doesn't infinite loop with small scale factors.""" - monkeypatch.setattr( - "torch_sim.autobatching.measure_model_memory_forward", lambda *_: 0.1 - ) + monkeypatch.setattr("ts.autobatching.measure_model_memory_forward", lambda *_: 0.1) max_size = determine_max_batch_size( si_sim_state, lj_model, max_atoms=20, scale_factor=scale_factor @@ -452,8 +441,8 @@ def test_in_flight_auto_batcher_restore_order( "num_steps_per_batch", [ 5, # At 5 steps, not every state will converge before the next batch. - # This tests the merging of partially converged states with new states - # which has been a bug in the past. See https://github.com/TorchSim/torch-sim/pull/219 + # This tests the merging of partially converged states with new states + # which has been a bug in the past. 10, # At 10 steps, all states will converge before the next batch ], ) @@ -463,10 +452,10 @@ def test_in_flight_with_fire( lj_model: LennardJonesModel, num_steps_per_batch: int, ) -> None: - fire_init, fire_update = unit_cell_fire(lj_model) - - si_fire_state = fire_init(si_sim_state) - fe_fire_state = fire_init(fe_supercell_sim_state) + si_fire_state = ts.fire_init(lj_model, si_sim_state, cell_filter=ts.CellFilter.unit) + fe_fire_state = ts.fire_init( + lj_model, fe_supercell_sim_state, cell_filter=ts.CellFilter.unit + ) fire_states = [si_fire_state, fe_fire_state] * 5 fire_states = [state.clone() for state in fire_states] @@ -481,7 +470,7 @@ def test_in_flight_with_fire( ) batcher.load_states(fire_states) - def convergence_fn(state: ts.SimState) -> bool: + def convergence_fn(state: ts.FireState) -> torch.Tensor: system_wise_max_force = torch.zeros( state.n_systems, device=state.device, dtype=torch.float64 ) @@ -500,7 +489,7 @@ def convergence_fn(state: ts.SimState) -> bool: break for _ in range(num_steps_per_batch): - state = fire_update(state) + state = ts.fire_step(lj_model, state) convergence_tensor = convergence_fn(state) assert len(all_completed_states) == len(fire_states) @@ -511,10 +500,10 @@ def test_binning_auto_batcher_with_fire( fe_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, ) -> None: - fire_init, fire_update = unit_cell_fire(lj_model) - - si_fire_state = fire_init(si_sim_state) - fe_fire_state = fire_init(fe_supercell_sim_state) + si_fire_state = ts.fire_init(lj_model, si_sim_state, cell_filter=ts.CellFilter.unit) + fe_fire_state = ts.fire_init( + lj_model, fe_supercell_sim_state, cell_filter=ts.CellFilter.unit + ) fire_states = [si_fire_state, fe_fire_state] * 5 fire_states = [state.clone() for state in fire_states] @@ -532,10 +521,10 @@ def test_binning_auto_batcher_with_fire( finished_states = [] n_systems = 0 - for batch in batcher: + for batch, _ in batcher: n_systems += 1 for _ in range(5): - batch = fire_update(batch) + batch = ts.fire_step(lj_model, batch) finished_states.extend(batch.split()) @@ -568,6 +557,7 @@ def test_in_flight_max_iterations( # Get the first batch state, [] = batcher.next_batch(None, None) + assert state is not None # Create a convergence tensor that never converges convergence_tensor = torch.zeros(state.n_systems, dtype=torch.bool) @@ -595,5 +585,5 @@ def test_in_flight_max_iterations( assert iteration_count == 3 # Verify swap_attempts tracking - for i in range(len(states)): - assert batcher.swap_attempts[i] == max_attempts + for idx in range(len(states)): + assert batcher.swap_attempts[idx] == max_attempts diff --git a/tests/test_correlations.py b/tests/test_correlations.py index 31624819d..a26dc0d68 100644 --- a/tests/test_correlations.py +++ b/tests/test_correlations.py @@ -12,6 +12,8 @@ import pytest import torch +import torch_sim as ts +from tests.conftest import DEVICE from torch_sim.properties.correlations import ( CircularBuffer, CorrelationCalculator, @@ -43,24 +45,24 @@ def split(self) -> list["MockState"]: @pytest.fixture -def buffer(device: torch.device) -> CircularBuffer: +def buffer() -> CircularBuffer: """Fixture for CircularBuffer instance.""" - return CircularBuffer(size=10, device=device) + return CircularBuffer(size=10, device=DEVICE) @pytest.fixture -def mock_state_factory(device: torch.device) -> Callable[[torch.Tensor], MockState]: +def mock_state_factory() -> Callable[[torch.Tensor], MockState]: """Factory fixture for creating mock state objects.""" def create_mock_state(velocities: torch.Tensor) -> MockState: """Create mock state with given data tensor.""" - return MockState(velocities, device) + return MockState(velocities, DEVICE) return create_mock_state @pytest.fixture -def corr_calc(device: torch.device) -> CorrelationCalculator: +def corr_calc() -> CorrelationCalculator: """Fixture for creating a CorrelationCalculator instance.""" window_size = 5 @@ -72,7 +74,7 @@ def velocity_getter(state: MockState) -> torch.Tensor: return CorrelationCalculator( window_size=window_size, properties=properties, - device=device, + device=DEVICE, normalize=True, ) @@ -80,13 +82,13 @@ def velocity_getter(state: MockState) -> torch.Tensor: class TestCircularBuffer: """Test suite for CircularBuffer functionality.""" - def test_circular_buffer_operations(self, device: torch.device) -> None: + def test_circular_buffer_operations(self) -> None: """Test core buffer operations including append, retrieval, and wraparound. Tests initialization, data append, retrieval and circular wrapping. """ - buffer = CircularBuffer(size=3, device=device) + buffer = CircularBuffer(size=3, device=DEVICE) # Test initialization state assert buffer.size == 3 @@ -96,26 +98,26 @@ def test_circular_buffer_operations(self, device: torch.device) -> None: assert not buffer.is_full # Test append and retrieval - buffer.append(torch.tensor([1.0], device=device)) - buffer.append(torch.tensor([2.0], device=device)) + buffer.append(torch.tensor([1.0], device=DEVICE)) + buffer.append(torch.tensor([2.0], device=DEVICE)) assert buffer.count == 2 assert buffer.head == 2 result = buffer.get_array() - expected = torch.tensor([[1.0], [2.0]], device=device) + expected = torch.tensor([[1.0], [2.0]], device=DEVICE) assert torch.allclose(result, expected) # Test wraparound behavior - buffer.append(torch.tensor([3.0], device=device)) + buffer.append(torch.tensor([3.0], device=DEVICE)) assert buffer.is_full - buffer.append(torch.tensor([4.0], device=device)) + buffer.append(torch.tensor([4.0], device=DEVICE)) assert buffer.count == 3 assert buffer.head == 1 result = buffer.get_array() - expected = torch.tensor([[2.0], [3.0], [4.0]], device=device) + expected = torch.tensor([[2.0], [3.0], [4.0]], device=DEVICE) assert torch.allclose(result, expected) @@ -156,9 +158,7 @@ def test_update_frequency( corr_calc.update(state1) assert corr_calc.buffers["velocity"].count == 2 - def test_constant_signal( - self, device: torch.device, mock_state_factory: Callable - ) -> None: + def test_constant_signal(self, mock_state_factory: Callable) -> None: """Test correlation of constant signals. Mean-centered constant signals should have zero autocorrelation. @@ -167,12 +167,12 @@ def test_constant_signal( corr_calc = CorrelationCalculator( window_size=win_size, properties={"velocity": lambda s: s.velocities}, - device=device, + device=DEVICE, normalize=False, ) # Constant signal - const_vel = torch.ones((2, 3), device=device) + const_vel = torch.ones((2, 3), device=DEVICE) # Identical states for _ in range(win_size): @@ -183,9 +183,7 @@ def test_constant_signal( acf = corr_calc.get_auto_correlations()["velocity"] assert torch.allclose(acf, torch.zeros_like(acf), atol=1e-5) - def test_white_noise( - self, device: torch.device, mock_state_factory: Callable - ) -> None: + def test_white_noise(self, mock_state_factory: Callable) -> None: """Test autocorrelation of white noise. White noise should have a delta function as its autocorrelation. @@ -194,7 +192,7 @@ def test_white_noise( corr_calc = CorrelationCalculator( window_size=win_size, properties={"velocity": lambda s: s.velocities}, - device=device, + device=DEVICE, normalize=True, ) @@ -202,7 +200,7 @@ def test_white_noise( # White noise for _ in range(win_size): - noise = torch.randn(4, 3, device=device) + noise = torch.randn(4, 3, device=DEVICE) state = mock_state_factory(noise) corr_calc.update(state) @@ -211,29 +209,29 @@ def test_white_noise( acf_mean = torch.mean(acf, dim=(1, 2)) # Delta function - assert torch.isclose(acf_mean[0], torch.tensor(1.0, device=device)) + assert torch.isclose(acf_mean[0], torch.tensor(1.0, device=DEVICE)) assert torch.all(torch.abs(acf_mean[1:]) < 0.3) - def test_sinusoidal(self, device: torch.device, mock_state_factory: Callable) -> None: + def test_sinusoidal(self, mock_state_factory: Callable) -> None: """Test autocorrelation of sinusoidal signals. Sine waves should have a cosine-like acf. """ - win_size = 32 + window_size = 32 period = 8 corr_calc = CorrelationCalculator( - window_size=win_size, + window_size=window_size, properties={"velocity": lambda s: s.velocities}, - device=device, + device=DEVICE, normalize=True, ) - t = torch.arange(win_size, dtype=torch.float32, device=device) + t = torch.arange(window_size, device=DEVICE) freq = 2 * math.pi / period # Sine - for i in range(win_size): - phase = freq * t[i] + for idx in range(window_size): + phase = freq * t[idx] signal_val = torch.sin(phase) # Expand to shape [2, 3] @@ -244,7 +242,7 @@ def test_sinusoidal(self, device: torch.device, mock_state_factory: Callable) -> acf = corr_calc.get_auto_correlations()["velocity"] acf_mean = torch.mean(acf, dim=(1, 2)) - assert torch.isclose(acf_mean[0], torch.tensor(1.0, device=device)) + assert torch.isclose(acf_mean[0], torch.tensor(1.0, device=DEVICE)) half_period = period // 2 assert acf_mean[half_period] < 0 @@ -268,9 +266,7 @@ def test_reset( assert corr_calc.buffers["velocity"].count == 0 assert corr_calc.correlations == {} - def test_normalization( - self, device: torch.device, mock_state_factory: Callable - ) -> None: + def test_normalization(self, mock_state_factory: Callable) -> None: """Test normalization of correlation functions. Validates that normalized correlations have first lag = 1.0. @@ -278,21 +274,21 @@ def test_normalization( corr_calc_norm = CorrelationCalculator( window_size=5, properties={"velocity": lambda s: s.velocities}, - device=device, + device=DEVICE, normalize=True, ) corr_calc_no_norm = CorrelationCalculator( window_size=5, properties={"velocity": lambda s: s.velocities}, - device=device, + device=DEVICE, normalize=False, ) torch.manual_seed(42) for _ in range(5): - vel = torch.randn((2, 3), device=device) + vel = torch.randn((2, 3), device=DEVICE) # Reuse data state = mock_state_factory(vel) @@ -303,7 +299,7 @@ def test_normalization( corr_no_norm = corr_calc_no_norm.get_auto_correlations()["velocity"] norm_first = torch.mean(corr_norm[0]) - assert torch.isclose(norm_first, torch.tensor(1.0, device=device)) + assert torch.isclose(norm_first, torch.tensor(1.0, device=DEVICE)) no_norm_first = torch.mean(corr_no_norm[0]) assert not torch.allclose(no_norm_first, torch.ones_like(no_norm_first)) @@ -315,18 +311,16 @@ def test_normalization( expected = corr_no_norm[:, a, d] / scale_factor assert torch.allclose(corr_norm[:, a, d], expected, atol=1e-5) - def test_cross_correlation_basics( - self, device: torch.device, mock_state_factory: Callable - ) -> None: + def test_cross_correlation_basics(self, mock_state_factory: Callable) -> None: """Test basic cross-correlation.""" - win_size = 10 + window_size = 10 corr_calc = CorrelationCalculator( - window_size=win_size, + window_size=window_size, properties={ "signal_a": lambda s: s.velocities[:1], "signal_b": lambda s: s.velocities[1:], }, - device=device, + device=DEVICE, normalize=True, ) @@ -334,14 +328,14 @@ def test_cross_correlation_basics( torch.manual_seed(42) # Initialize prev_signal_a - prev_signal_a = torch.randn(1, 3, device=device) + prev_signal_a = torch.randn(1, 3, device=DEVICE) - for i in range(win_size): - signal_a = torch.randn(1, 3, device=device) - if i > 0: - signal_b = prev_signal_a * 0.7 + torch.randn(1, 3, device=device) * 0.3 + for idx in range(window_size): + signal_a = torch.randn(1, 3, device=DEVICE) + if idx > 0: + signal_b = prev_signal_a * 0.7 + torch.randn(1, 3, device=DEVICE) * 0.3 else: - signal_b = torch.randn(1, 3, device=device) + signal_b = torch.randn(1, 3, device=DEVICE) prev_signal_a = signal_a.clone() @@ -353,7 +347,7 @@ def test_cross_correlation_basics( assert ("signal_a", "signal_b") in cross_corrs cross_corr = cross_corrs[("signal_a", "signal_b")] - assert len(cross_corr) == win_size + assert len(cross_corr) == window_size @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_device_migration(self, mock_state_factory: Callable) -> None: @@ -374,7 +368,7 @@ def test_device_migration(self, mock_state_factory: Callable) -> None: for _ in range(3): corr_calc.update(state) - cuda_device = torch.device("cuda") + cuda_device = torch.device("cuda:0") corr_calc = corr_calc.to(cuda_device) assert corr_calc.device == cuda_device @@ -383,9 +377,7 @@ def test_device_migration(self, mock_state_factory: Callable) -> None: assert corr_calc.buffers["velocity"].buffer.device == cuda_device -def test_velocity_autocorrelation( - device: torch.device, mock_state_factory: Callable -) -> None: +def test_velocity_autocorrelation(mock_state_factory: Callable) -> None: """Test VACF calculation with cosine pattern velocities. Test checks: @@ -393,25 +385,24 @@ def test_velocity_autocorrelation( 2. Expected periodicity 3. Exhibits sign changes at specific locations """ - window_size = 32 - period = 8 + window_size, period = 32, 8 vacf_calc = VelocityAutoCorrelation( window_size=window_size, - device=device, + device=DEVICE, use_running_average=False, normalize=True, ) # Cosine velocity pattern - t = torch.arange(window_size, dtype=torch.float32, device=device) + t = torch.arange(window_size, device=DEVICE) freq = 2 * math.pi / period velocities = [] - for i in range(window_size): + for idx in range(window_size): # cos(Ο‰t) pattern - val = torch.cos(freq * t[i]) - vel = torch.tensor([[val, val, val]], device=device) + val = torch.cos(freq * t[idx]) + vel = torch.tensor([[val, val, val]], device=DEVICE) velocities.append(vel) for vel in velocities: @@ -422,7 +413,7 @@ def test_velocity_autocorrelation( assert vacf is not None # 1. First lag is 1.0 - assert torch.isclose(vacf[0], torch.tensor(1.0, device=device)) + assert torch.isclose(vacf[0], torch.tensor(1.0)) # 2. Check periodicity expect # positive peaks at t=0, t=8, t=16, ... @@ -445,7 +436,7 @@ def test_velocity_autocorrelation( def test_velocity_autocorrelation_with_trajectory_reporter( - device: torch.device, mock_state_factory: Callable + mock_state_factory: Callable, ) -> None: """Test VelocityAutoCorrelation integration with TrajectoryReporter. @@ -453,17 +444,13 @@ def test_velocity_autocorrelation_with_trajectory_reporter( 1. ``VelocityAutoCorrelation`` as a property calculator 2. ``TrajectoryReporter`` calls correctly """ - from torch_sim.properties.correlations import VelocityAutoCorrelation - from torch_sim.trajectory import TrajectoryReporter window_size = 20 vacf_calc = VelocityAutoCorrelation( - window_size=window_size, - device=device, - use_running_average=True, + window_size=window_size, device=DEVICE, use_running_average=True ) - reporter = TrajectoryReporter( + reporter = ts.TrajectoryReporter( None, # Don't write file state_frequency=100, prop_calculators={5: {"vacf": vacf_calc}}, @@ -473,7 +460,7 @@ def test_velocity_autocorrelation_with_trajectory_reporter( n_steps = 25 for step in range(n_steps): # Mock state - velocities = torch.randn(4, 3, device=device) + velocities = torch.randn(4, 3, device=DEVICE) state = mock_state_factory(velocities) props = reporter.report(state, step) diff --git a/tests/test_elastic.py b/tests/test_elastic.py index 91a063da0..07531adde 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -2,6 +2,7 @@ import torch import torch_sim as ts +from tests.conftest import DEVICE from torch_sim.elastic import ( calculate_elastic_moduli, calculate_elastic_tensor, @@ -10,7 +11,6 @@ get_elementary_deformations, get_strain, ) -from torch_sim.optimizers import frechet_cell_fire from torch_sim.typing import BravaisType from torch_sim.units import UnitConversion @@ -219,7 +219,7 @@ def test_get_elementary_deformations_strain_consistency( n_deform=n_deform, max_strain_normal=max_strain_normal, max_strain_shear=max_strain_shear, - bravais_type=BravaisType.TRICLINIC, # Test all axes + bravais_type=BravaisType.triclinic, # Test all axes ) # Should generate deformations for all 6 axes (triclinic) @@ -232,11 +232,11 @@ def test_get_elementary_deformations_strain_consistency( # Check that each deformed state produces a strain with expected dominant component axis_to_strain_idx = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5} # axis -> Voigt index - for i, deformed_state in enumerate(deformed_states): + for def_idx, deformed_state in enumerate(deformed_states): strain = get_strain(deformed_state, cu_sim_state) # Determine which axis this deformation corresponds to - axis = i // strains_per_axis # Integer division to get axis index + axis = def_idx // strains_per_axis # Integer division to get axis index strain_idx = axis_to_strain_idx[axis] # The strain component corresponding to this axis should be the largest @@ -253,13 +253,13 @@ def test_get_elementary_deformations_strain_consistency( @pytest.fixture -def mace_model(device: torch.device) -> MaceModel: +def mace_model() -> MaceModel: """Create a MACE model fixture for testing.""" mace_model = mace_mp(model="medium", default_dtype="float64", return_raw_model=True) return MaceModel( model=mace_model, - device=device, + device=DEVICE, dtype=torch.float64, compute_forces=True, compute_stress=True, @@ -269,12 +269,12 @@ def mace_model(device: torch.device) -> MaceModel: @pytest.mark.parametrize( ("sim_state_name", "expected_bravais_type", "atol"), [ - ("cu_sim_state", BravaisType.CUBIC, 2e-1), - ("mg_sim_state", BravaisType.HEXAGONAL, 5e-1), - ("sb_sim_state", BravaisType.TRIGONAL, 5e-1), - ("tio2_sim_state", BravaisType.TETRAGONAL, 5e-1), - ("ga_sim_state", BravaisType.ORTHORHOMBIC, 5e-1), - ("niti_sim_state", BravaisType.MONOCLINIC, 5e-1), + ("cu_sim_state", BravaisType.cubic, 2e-1), + ("mg_sim_state", BravaisType.hexagonal, 5e-1), + ("sb_sim_state", BravaisType.trigonal, 5e-1), + ("tio2_sim_state", BravaisType.tetragonal, 5e-1), + ("ga_sim_state", BravaisType.orthorhombic, 5e-1), + ("niti_sim_state", BravaisType.monoclinic, 5e-1), ], ) def test_elastic_tensor_symmetries( @@ -305,8 +305,12 @@ def test_elastic_tensor_symmetries( ) # Relax positions and cell - fire_init, fire_update = frechet_cell_fire(model=model, scalar_pressure=0.0) - state = fire_init(state=state) + state = ts.fire_init( + model=model, + state=state, + scalar_pressure=0.0, + cell_filter=ts.CellFilter.frechet, + ) fmax = 1e-5 for _ in range(300): @@ -316,7 +320,7 @@ def test_elastic_tensor_symmetries( current_fmax = torch.max(torch.abs(state.forces.squeeze())) if current_fmax < fmax and abs(pressure) < 1e-2: break - state = fire_update(state=state) + state = ts.fire_step(model=model, state=state) # Verify the Bravais type of the relaxed structure actual_bravais_type = get_bravais_type(state) @@ -331,7 +335,7 @@ def test_elastic_tensor_symmetries( * UnitConversion.eV_per_Ang3_to_GPa ) C_triclinic = ( - calculate_elastic_tensor(model, state=state, bravais_type=BravaisType.TRICLINIC) + calculate_elastic_tensor(model, state=state, bravais_type=BravaisType.triclinic) * UnitConversion.eV_per_Ang3_to_GPa ) @@ -348,8 +352,12 @@ def test_copper_elastic_properties( """Test calculation of elastic properties for copper.""" # Relax positions and cell - fire_init, fire_update = frechet_cell_fire(model=mace_model, scalar_pressure=0.0) - state = fire_init(state=cu_sim_state) + state = ts.fire_init( + model=mace_model, + state=cu_sim_state, + scalar_pressure=0.0, + cell_filter=ts.CellFilter.frechet, + ) fmax = 1e-5 for _ in range(300): pressure = ( @@ -358,7 +366,7 @@ def test_copper_elastic_properties( current_fmax = torch.max(torch.abs(state.forces.squeeze())) if current_fmax < fmax and abs(pressure) < 1e-2: break - state = fire_update(state=state) + state = ts.fire_step(model=mace_model, state=state) # Calculate elastic tensor bravais_type = get_bravais_type(state) @@ -372,8 +380,7 @@ def test_copper_elastic_properties( # Calculate elastic moduli bulk_modulus, shear_modulus, _, _ = calculate_elastic_moduli(elastic_tensor) - device = state.device - dtype = state.dtype + device, dtype = state.device, state.dtype # Expected values expected_elastic_tensor = torch.tensor( diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 67621a5bb..fd39ec58e 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -2,114 +2,114 @@ import torch import torch_sim as ts -from torch_sim.integrators import ( - NPTLangevinState, - calculate_momenta, - npt_langevin, - nve, - nvt_langevin, -) +from tests.conftest import DEVICE, DTYPE +from torch_sim.integrators import calculate_momenta +from torch_sim.integrators.npt import _compute_cell_force from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.quantities import calc_kT -from torch_sim.state import concatenate_states from torch_sim.units import MetalUnits -def test_calculate_momenta_basic(device: torch.device): +def test_calculate_momenta_basic(): """Test basic functionality of calculate_momenta.""" seed = 42 - dtype = torch.float64 # Create test inputs for 3 systems with 2 atoms each n_atoms = 8 - positions = torch.randn(n_atoms, 3, dtype=dtype, device=device) - masses = torch.rand(n_atoms, dtype=dtype, device=device) + 0.5 + positions = torch.randn(n_atoms, 3, dtype=DTYPE, device=DEVICE) + masses = torch.rand(n_atoms, dtype=DTYPE, device=DEVICE) + 0.5 system_idx = torch.tensor( - [0, 0, 1, 1, 2, 2, 3, 3], device=device + [0, 0, 1, 1, 2, 2, 3, 3], device=DEVICE ) # 3 systems with 2 atoms each - kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device) + kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=DTYPE, device=DEVICE) # Run the function momenta = calculate_momenta(positions, masses, system_idx, kT, seed=seed) # Basic checks assert momenta.shape == positions.shape - assert momenta.dtype == dtype - assert momenta.device == device + assert momenta.dtype == DTYPE + assert momenta.device == DEVICE # Check that each system has zero center of mass momentum - for b in range(4): - system_mask = system_idx == b + for sys_idx in range(4): + system_mask = system_idx == sys_idx system_momenta = momenta[system_mask] com_momentum = torch.mean(system_momenta, dim=0) assert torch.allclose( - com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10 + com_momentum, torch.zeros(3, dtype=DTYPE, device=DEVICE), atol=1e-10 ) -def test_calculate_momenta_single_atoms(device: torch.device): +def test_calculate_momenta_single_atoms(): """Test that calculate_momenta preserves momentum for systems with single atoms.""" seed = 42 - dtype = torch.float64 # Create test inputs with some systems having single atoms - positions = torch.randn(5, 3, dtype=dtype, device=device) - masses = torch.rand(5, dtype=dtype, device=device) + 0.5 + positions = torch.randn(5, 3, dtype=DTYPE, device=DEVICE) + masses = torch.rand(5, dtype=DTYPE, device=DEVICE) + 0.5 system_idx = torch.tensor( - [0, 1, 1, 2, 3], device=device + [0, 1, 1, 2, 3], device=DEVICE ) # systems 0, 2, and 3 have single atoms - kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device) + kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=DTYPE, device=DEVICE) # Generate momenta and save the raw values before COM correction - generator = torch.Generator(device=device).manual_seed(seed) + generator = torch.Generator(device=DEVICE).manual_seed(seed) raw_momenta = torch.randn( - positions.shape, device=device, dtype=dtype, generator=generator + positions.shape, device=DEVICE, dtype=DTYPE, generator=generator ) * torch.sqrt(masses * kT[system_idx]).unsqueeze(-1) # Run the function momenta = calculate_momenta(positions, masses, system_idx, kT, seed=seed) # Check that single-atom systems have unchanged momenta - for b in [0, 2, 3]: # Single atom systems - system_mask = system_idx == b + for sys_idx in (0, 2, 3): # Single atom systems + system_mask = system_idx == sys_idx # The momentum should be exactly the same as the raw value for single atoms assert torch.allclose(momenta[system_mask], raw_momenta[system_mask]) # Check that multi-atom systems have zero COM - for b in [1]: # Multi-atom systems - system_mask = system_idx == b + for sys_idx in (1,): # Multi-atom systems + system_mask = system_idx == sys_idx system_momenta = momenta[system_mask] com_momentum = torch.mean(system_momenta, dim=0) assert torch.allclose( - com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10 + com_momentum, torch.zeros(3, dtype=DTYPE, device=DEVICE), atol=1e-10 ) -def test_npt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): - dtype = torch.float64 +def test_npt_langevin( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: n_steps = 200 - dt = torch.tensor(0.001, dtype=dtype) - kT = torch.tensor(100.0, dtype=dtype) * MetalUnits.temperature - external_pressure = torch.tensor(0.0, dtype=dtype) * MetalUnits.pressure - - # Initialize integrator - init_fn, update_fn = npt_langevin( - model=lj_model, - dt=dt, - kT=kT, - external_pressure=external_pressure, - alpha=40 * dt, + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(100.0, dtype=DTYPE) * MetalUnits.temperature + external_pressure = torch.tensor(0.0, dtype=DTYPE) * MetalUnits.pressure + alpha = 40 * dt + cell_alpha = alpha + b_tau = 1 / (1000 * dt) + + # Initialize integrator using new direct API + state = ts.npt_langevin_init( + model=lj_model, state=ar_double_sim_state, dt=dt, kT=kT, alpha=alpha, seed=42 ) # Run dynamics for several steps - state = init_fn(state=ar_double_sim_state, seed=42) energies = [] temperatures = [] for _step in range(n_steps): - state = update_fn(state=state) + state = ts.npt_langevin_update( + model=lj_model, + state=state, + dt=dt, + kT=kT, + external_pressure=external_pressure, + alpha=alpha, + cell_alpha=cell_alpha, + b_tau=b_tau, + ) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT( + temp = ts.calc_kT( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) energies.append(state.energy) @@ -151,30 +151,36 @@ def test_npt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo def test_npt_langevin_multi_kt( ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel ): - dtype = torch.float64 n_steps = 200 - dt = torch.tensor(0.001, dtype=dtype) - kT = torch.tensor([300, 10_000], dtype=dtype) * MetalUnits.temperature - external_pressure = torch.tensor(0, dtype=dtype) * MetalUnits.pressure - - # Initialize integrator - init_fn, update_fn = npt_langevin( - model=lj_model, - dt=dt, - kT=kT, - external_pressure=external_pressure, - alpha=40 * dt, + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor([300, 10_000], dtype=DTYPE) * MetalUnits.temperature + external_pressure = torch.tensor(0, dtype=DTYPE) * MetalUnits.pressure + alpha = 40 * dt + cell_alpha = alpha + b_tau = 1 / (1000 * dt) + + # Initialize integrator using new direct API + state = ts.npt_langevin_init( + model=lj_model, state=ar_double_sim_state, dt=dt, kT=kT, alpha=alpha, seed=42 ) # Run dynamics for several steps - state = init_fn(state=ar_double_sim_state, seed=42) energies = [] temperatures = [] for _step in range(n_steps): - state = update_fn(state=state) + state = ts.npt_langevin_update( + model=lj_model, + state=state, + dt=dt, + kT=kT, + external_pressure=external_pressure, + alpha=alpha, + cell_alpha=cell_alpha, + b_tau=b_tau, + ) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT( + temp = ts.calc_kT( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) energies.append(state.energy) @@ -197,27 +203,21 @@ def test_npt_langevin_multi_kt( def test_nvt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): - dtype = torch.float64 n_steps = 100 - dt = torch.tensor(0.001, dtype=dtype) - kT = torch.tensor(300, dtype=dtype) * MetalUnits.temperature + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature # Initialize integrator - init_fn, update_fn = nvt_langevin( - model=lj_model, - dt=dt, - kT=kT, + state = ts.nvt_langevin_init( + model=lj_model, state=ar_double_sim_state, kT=kT, seed=42 ) - - # Run dynamics for several steps - state = init_fn(state=ar_double_sim_state, seed=42) energies = [] temperatures = [] for _step in range(n_steps): - state = update_fn(state=state) + state = ts.nvt_langevin_update(model=lj_model, state=state, dt=dt, kT=kT) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT( + temp = ts.calc_kT( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) energies.append(state.energy) @@ -259,27 +259,21 @@ def test_nvt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo def test_nvt_langevin_multi_kt( ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel ): - dtype = torch.float64 n_steps = 200 - dt = torch.tensor(0.001, dtype=dtype) - kT = torch.tensor([300, 10_000], dtype=dtype) * MetalUnits.temperature + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor([300, 10_000], dtype=DTYPE) * MetalUnits.temperature # Initialize integrator - init_fn, update_fn = nvt_langevin( - model=lj_model, - dt=dt, - kT=kT, + state = ts.nvt_langevin_init( + model=lj_model, state=ar_double_sim_state, kT=kT, seed=42 ) - - # Run dynamics for several steps - state = init_fn(state=ar_double_sim_state, seed=42) energies = [] temperatures = [] for _step in range(n_steps): - state = update_fn(state=state) + state = ts.nvt_langevin_update(model=lj_model, state=state, dt=dt, kT=kT) # Calculate instantaneous temperature from kinetic energy - temp = calc_kT( + temp = ts.calc_kT( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) energies.append(state.energy) @@ -302,19 +296,17 @@ def test_nvt_langevin_multi_kt( def test_nve(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): - dtype = torch.float64 n_steps = 100 - dt = torch.tensor(0.001, dtype=dtype) - kT = torch.tensor(100.0, dtype=dtype) * MetalUnits.temperature + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(100.0, dtype=DTYPE) * MetalUnits.temperature # Initialize integrator - nve_init, nve_update = nve(model=lj_model, dt=dt, kT=kT) - state = nve_init(state=ar_double_sim_state, seed=42) + state = ts.nve_init(model=lj_model, state=ar_double_sim_state, kT=kT, seed=42) # Run dynamics for several steps energies = [] for _step in range(n_steps): - state = nve_update(state=state, dt=dt) + state = ts.nve_update(model=lj_model, state=state, dt=dt) energies.append(state.energy) @@ -335,7 +327,7 @@ def test_compare_single_vs_batched_integrators( ) -> None: """Test NVE single vs batched for a tilted cell to verify PBC wrapping. - NOTE: added triclinic cell after https://github.com/TorchSim/torch-sim/issues/171. + NOTE: added triclinic cell after #171. Although the addition doesn't fail if we do not add the changes suggested in issue. """ sim_state = request.getfixturevalue(sim_state_fixture_name) @@ -343,7 +335,7 @@ def test_compare_single_vs_batched_integrators( initial_states = { "single": sim_state, - "batched": concatenate_states([sim_state, sim_state]), + "batched": ts.concatenate_states([sim_state, sim_state]), } final_states = {} @@ -352,14 +344,15 @@ def test_compare_single_vs_batched_integrators( kT = torch.tensor(100.0) * MetalUnits.temperature dt = torch.tensor(0.001) # Small timestep for stability - nve_init, nve_update = nve(model=lj_model, dt=dt, kT=kT) # Initialize momenta (even if zero) and get forces - state = nve_init(state=state, seed=42) # kT is ignored if momenta are set below + state = ts.nve_init( + model=lj_model, state=state, kT=kT, seed=42 + ) # kT is ignored if momenta are set below # Ensure momenta start at zero AFTER init which might randomize them based on kT state.momenta = torch.zeros_like(state.momenta) # Start from rest for _step in range(n_steps): - state = nve_update(state=state, dt=dt) + state = ts.nve_update(model=lj_model, state=state, dt=dt) final_states[state_name] = state @@ -369,7 +362,7 @@ def test_compare_single_vs_batched_integrators( batched_state_1 = final_states["batched"][1] # Compare single state results with each part of the batched state - for final_state in [batched_state_0, batched_state_1]: + for final_state in (batched_state_0, batched_state_1): # Check positions first - most likely to fail with incorrect PBC torch.testing.assert_close(single_state.positions, final_state.positions) # Check other state components @@ -381,15 +374,12 @@ def test_compare_single_vs_batched_integrators( def test_compute_cell_force_atoms_per_system(): - """Test that compute_cell_force correctly scales by number of atoms per system. - - Covers fix in https://github.com/TorchSim/torch-sim/pull/153.""" - from torch_sim.integrators.npt import _compute_cell_force + """Test that compute_cell_force correctly scales by number of atoms per system.""" # Setup minimal state with two systems having 8:1 atom ratio s1, s2 = torch.zeros(8, dtype=torch.long), torch.ones(64, dtype=torch.long) - state = NPTLangevinState( + state = ts.NPTLangevinState( positions=torch.zeros((72, 3)), velocities=torch.zeros((72, 3)), energy=torch.zeros(2), diff --git a/tests/test_io.py b/tests/test_io.py index e90f6ab59..26e4a34c7 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -9,19 +9,21 @@ from pymatgen.core import Structure import torch_sim as ts +from tests.conftest import DEVICE, DTYPE +from torch_sim.state import SimState -def test_single_structure_to_state(si_structure: Structure, device: torch.device) -> None: +def test_single_structure_to_state(si_structure: Structure) -> None: """Test conversion from pymatgen Structure to state tensors.""" - state = ts.io.structures_to_state(si_structure, device, torch.float64) + state = ts.io.structures_to_state(si_structure, DEVICE, torch.float64) # Check basic properties - assert isinstance(state, ts.SimState) + assert isinstance(state, SimState) assert all( - t.device.type == device.type for t in [state.positions, state.masses, state.cell] + t.device.type == DEVICE.type for t in (state.positions, state.masses, state.cell) ) assert all( - t.dtype == torch.float64 for t in [state.positions, state.masses, state.cell] + t.dtype == torch.float64 for t in (state.positions, state.masses, state.cell) ) assert state.atomic_numbers.dtype == torch.int @@ -31,18 +33,16 @@ def test_single_structure_to_state(si_structure: Structure, device: torch.device assert torch.all(state.atomic_numbers == 14) # Si atomic number assert torch.allclose( state.cell, - torch.diag(torch.full((3,), 5.43, device=device, dtype=torch.float64)), + torch.diag(torch.full((3,), 5.43, device=DEVICE, dtype=torch.float64)), ) -def test_multiple_structures_to_state( - si_structure: Structure, device: torch.device -) -> None: +def test_multiple_structures_to_state(si_structure: Structure) -> None: """Test conversion from list of pymatgen Structure to state tensors.""" - state = ts.io.structures_to_state([si_structure, si_structure], device, torch.float64) + state = ts.io.structures_to_state([si_structure, si_structure], DEVICE, torch.float64) # Check basic properties - assert isinstance(state, ts.SimState) + assert isinstance(state, SimState) assert state.positions.shape == (16, 3) assert state.masses.shape == (16,) assert state.cell.shape == (2, 3, 3) @@ -51,16 +51,16 @@ def test_multiple_structures_to_state( assert state.system_idx.shape == (16,) assert torch.all( state.system_idx - == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8) + == torch.repeat_interleave(torch.tensor([0, 1], device=DEVICE), 8) ) -def test_single_atoms_to_state(si_atoms: Atoms, device: torch.device) -> None: +def test_single_atoms_to_state(si_atoms: Atoms) -> None: """Test conversion from ASE Atoms to state tensors.""" - state = ts.io.atoms_to_state(si_atoms, device, torch.float64) + state = ts.io.atoms_to_state(si_atoms, DEVICE, torch.float64) # Check basic properties - assert isinstance(state, ts.SimState) + assert isinstance(state, SimState) assert state.positions.shape == (8, 3) assert state.masses.shape == (8,) assert state.cell.shape == (1, 3, 3) @@ -70,12 +70,12 @@ def test_single_atoms_to_state(si_atoms: Atoms, device: torch.device) -> None: assert torch.all(state.system_idx == 0) -def test_multiple_atoms_to_state(si_atoms: Atoms, device: torch.device) -> None: +def test_multiple_atoms_to_state(si_atoms: Atoms) -> None: """Test conversion from ASE Atoms to state tensors.""" - state = ts.io.atoms_to_state([si_atoms, si_atoms], device, torch.float64) + state = ts.io.atoms_to_state([si_atoms, si_atoms], DEVICE, torch.float64) # Check basic properties - assert isinstance(state, ts.SimState) + assert isinstance(state, SimState) assert state.positions.shape == (16, 3) assert state.masses.shape == (16,) assert state.cell.shape == (2, 3, 3) @@ -84,11 +84,11 @@ def test_multiple_atoms_to_state(si_atoms: Atoms, device: torch.device) -> None: assert state.system_idx.shape == (16,) assert torch.all( state.system_idx - == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8), + == torch.repeat_interleave(torch.tensor([0, 1], device=DEVICE), 8), ) -def test_state_to_structure(ar_supercell_sim_state: ts.SimState) -> None: +def test_state_to_structure(ar_supercell_sim_state: SimState) -> None: """Test conversion from state tensors to list of pymatgen Structure.""" structures = ts.io.state_to_structures(ar_supercell_sim_state) assert len(structures) == 1 @@ -96,7 +96,7 @@ def test_state_to_structure(ar_supercell_sim_state: ts.SimState) -> None: assert len(structures[0]) == 32 -def test_state_to_multiple_structures(ar_double_sim_state: ts.SimState) -> None: +def test_state_to_multiple_structures(ar_double_sim_state: SimState) -> None: """Test conversion from state tensors to list of pymatgen Structure.""" structures = ts.io.state_to_structures(ar_double_sim_state) assert len(structures) == 2 @@ -106,7 +106,7 @@ def test_state_to_multiple_structures(ar_double_sim_state: ts.SimState) -> None: assert len(structures[1]) == 32 -def test_state_to_atoms(ar_supercell_sim_state: ts.SimState) -> None: +def test_state_to_atoms(ar_supercell_sim_state: SimState) -> None: """Test conversion from state tensors to list of ASE Atoms.""" atoms = ts.io.state_to_atoms(ar_supercell_sim_state) assert len(atoms) == 1 @@ -114,7 +114,7 @@ def test_state_to_atoms(ar_supercell_sim_state: ts.SimState) -> None: assert len(atoms[0]) == 32 -def test_state_to_multiple_atoms(ar_double_sim_state: ts.SimState) -> None: +def test_state_to_multiple_atoms(ar_double_sim_state: SimState) -> None: """Test conversion from state tensors to list of ASE Atoms.""" atoms = ts.io.state_to_atoms(ar_double_sim_state) assert len(atoms) == 2 @@ -124,29 +124,29 @@ def test_state_to_multiple_atoms(ar_double_sim_state: ts.SimState) -> None: assert len(atoms[1]) == 32 -def test_to_atoms(ar_supercell_sim_state: ts.SimState) -> None: +def test_to_atoms(ar_supercell_sim_state: SimState) -> None: """Test conversion from SimState to list of ASE Atoms.""" atoms = ts.io.state_to_atoms(ar_supercell_sim_state) assert isinstance(atoms[0], Atoms) -def test_to_structures(ar_supercell_sim_state: ts.SimState) -> None: +def test_to_structures(ar_supercell_sim_state: SimState) -> None: """Test conversion from SimState to list of Pymatgen Structure.""" structures = ts.io.state_to_structures(ar_supercell_sim_state) assert isinstance(structures[0], Structure) -def test_single_phonopy_to_state(si_phonopy_atoms: Any, device: torch.device) -> None: +def test_single_phonopy_to_state(si_phonopy_atoms: Any) -> None: """Test conversion from PhonopyAtoms to state tensors.""" - state = ts.io.phonopy_to_state(si_phonopy_atoms, device, torch.float64) + state = ts.io.phonopy_to_state(si_phonopy_atoms, DEVICE, torch.float64) # Check basic properties - assert isinstance(state, ts.SimState) + assert isinstance(state, SimState) assert all( - t.device.type == device.type for t in [state.positions, state.masses, state.cell] + t.device.type == DEVICE.type for t in (state.positions, state.masses, state.cell) ) assert all( - t.dtype == torch.float64 for t in [state.positions, state.masses, state.cell] + t.dtype == torch.float64 for t in (state.positions, state.masses, state.cell) ) assert state.atomic_numbers.dtype == torch.int @@ -156,18 +156,18 @@ def test_single_phonopy_to_state(si_phonopy_atoms: Any, device: torch.device) -> assert torch.all(state.atomic_numbers == 14) # Si atomic number assert torch.allclose( state.cell, - torch.diag(torch.full((3,), 5.43, device=device, dtype=torch.float64)), + torch.diag(torch.full((3,), 5.43, device=DEVICE, dtype=torch.float64)), ) -def test_multiple_phonopy_to_state(si_phonopy_atoms: Any, device: torch.device) -> None: +def test_multiple_phonopy_to_state(si_phonopy_atoms: Any) -> None: """Test conversion from multiple PhonopyAtoms to state tensors.""" state = ts.io.phonopy_to_state( - [si_phonopy_atoms, si_phonopy_atoms], device, torch.float64 + [si_phonopy_atoms, si_phonopy_atoms], DEVICE, torch.float64 ) # Check basic properties - assert isinstance(state, ts.SimState) + assert isinstance(state, SimState) assert state.positions.shape == (16, 3) assert state.masses.shape == (16,) assert state.cell.shape == (2, 3, 3) @@ -176,11 +176,11 @@ def test_multiple_phonopy_to_state(si_phonopy_atoms: Any, device: torch.device) assert state.system_idx.shape == (16,) assert torch.all( state.system_idx - == torch.repeat_interleave(torch.tensor([0, 1], device=device), 8), + == torch.repeat_interleave(torch.tensor([0, 1], device=DEVICE), 8), ) -def test_state_to_phonopy(ar_supercell_sim_state: ts.SimState) -> None: +def test_state_to_phonopy(ar_supercell_sim_state: SimState) -> None: """Test conversion from state tensors to list of PhonopyAtoms.""" phonopy_atoms = ts.io.state_to_phonopy(ar_supercell_sim_state) assert len(phonopy_atoms) == 1 @@ -188,7 +188,7 @@ def test_state_to_phonopy(ar_supercell_sim_state: ts.SimState) -> None: assert len(phonopy_atoms[0]) == 32 -def test_state_to_multiple_phonopy(ar_double_sim_state: ts.SimState) -> None: +def test_state_to_multiple_phonopy(ar_double_sim_state: SimState) -> None: """Test conversion from state tensors to list of PhonopyAtoms.""" phonopy_atoms = ts.io.state_to_phonopy(ar_double_sim_state) assert len(phonopy_atoms) == 2 @@ -220,11 +220,7 @@ def test_state_to_multiple_phonopy(ar_double_sim_state: ts.SimState) -> None: ), ) def test_state_round_trip( - sim_state_name: str, - conversion_functions: tuple, - request: pytest.FixtureRequest, - device: torch.device, - dtype: torch.dtype, + sim_state_name: str, conversion_functions: tuple, request: pytest.FixtureRequest ) -> None: """Test round-trip conversion from SimState through various formats and back. @@ -232,20 +228,18 @@ def test_state_round_trip( sim_state_name: Name of the sim_state fixture to test conversion_functions: Tuple of (to_format, from_format) conversion functions request: Pytest fixture request object to get dynamic fixtures - device: Device to run tests on - dtype: Data type to use """ # Get the sim_state fixture dynamically using the name - sim_state: ts.SimState = request.getfixturevalue(sim_state_name) + sim_state: SimState = request.getfixturevalue(sim_state_name) to_format_fn, from_format_fn = conversion_functions - unique_systems = torch.unique(sim_state.system_idx) + uniq_systems = torch.unique(sim_state.system_idx) # Convert to intermediate format intermediate_format = to_format_fn(sim_state) - assert len(intermediate_format) == len(unique_systems) + assert len(intermediate_format) == len(uniq_systems) # Convert back to state - round_trip_state: ts.SimState = from_format_fn(intermediate_format, device, dtype) + round_trip_state: SimState = from_format_fn(intermediate_format, DEVICE, DTYPE) # Check that all properties match assert torch.allclose(sim_state.positions, round_trip_state.positions) @@ -261,21 +255,17 @@ def test_state_round_trip( assert torch.allclose(sim_state.masses, round_trip_state.masses) -def test_state_to_atoms_importerror( - monkeypatch: pytest.MonkeyPatch, si_sim_state: ts.SimState -) -> None: +def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setitem(sys.modules, "ase", None) monkeypatch.setitem(sys.modules, "ase.data", None) with pytest.raises( ImportError, match="ASE is required for state_to_atoms conversion" ): - ts.io.state_to_atoms(si_sim_state) + ts.io.state_to_atoms(None) # type: ignore[arg-type] -def test_state_to_phonopy_importerror( - monkeypatch: pytest.MonkeyPatch, si_sim_state: ts.SimState -) -> None: +def test_state_to_phonopy_importerror(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setitem(sys.modules, "phonopy", None) monkeypatch.setitem(sys.modules, "phonopy.structure", None) monkeypatch.setitem(sys.modules, "phonopy.structure.atoms", None) @@ -283,12 +273,10 @@ def test_state_to_phonopy_importerror( with pytest.raises( ImportError, match="Phonopy is required for state_to_phonopy conversion" ): - ts.io.state_to_phonopy(si_sim_state) + ts.io.state_to_phonopy(None) # type: ignore[arg-type] -def test_state_to_structures_importerror( - monkeypatch: pytest.MonkeyPatch, si_sim_state: ts.SimState -) -> None: +def test_state_to_structures_importerror(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setitem(sys.modules, "pymatgen", None) monkeypatch.setitem(sys.modules, "pymatgen.core", None) monkeypatch.setitem(sys.modules, "pymatgen.core.structure", None) @@ -296,24 +284,20 @@ def test_state_to_structures_importerror( with pytest.raises( ImportError, match="Pymatgen is required for state_to_structures conversion" ): - ts.io.state_to_structures(si_sim_state) + ts.io.state_to_structures(None) # type: ignore[arg-type] -def test_atoms_to_state_importerror( - monkeypatch: pytest.MonkeyPatch, si_atoms: Atoms -) -> None: +def test_atoms_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setitem(sys.modules, "ase", None) monkeypatch.setitem(sys.modules, "ase.data", None) with pytest.raises( ImportError, match="ASE is required for atoms_to_state conversion" ): - ts.io.atoms_to_state(si_atoms, torch.device("cpu"), torch.float64) + ts.io.atoms_to_state(None, None, None) # type: ignore[arg-type] -def test_phonopy_to_state_importerror( - monkeypatch: pytest.MonkeyPatch, si_phonopy_atoms: PhonopyAtoms -) -> None: +def test_phonopy_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setitem(sys.modules, "phonopy", None) monkeypatch.setitem(sys.modules, "phonopy.structure", None) monkeypatch.setitem(sys.modules, "phonopy.structure.atoms", None) @@ -321,12 +305,10 @@ def test_phonopy_to_state_importerror( with pytest.raises( ImportError, match="Phonopy is required for phonopy_to_state conversion" ): - ts.io.phonopy_to_state(si_phonopy_atoms, torch.device("cpu"), torch.float64) + ts.io.phonopy_to_state(None, None, None) # type: ignore[arg-type] -def test_structures_to_state_importerror( - monkeypatch: pytest.MonkeyPatch, si_structure: Structure -) -> None: +def test_structures_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setitem(sys.modules, "pymatgen", None) monkeypatch.setitem(sys.modules, "pymatgen.core", None) monkeypatch.setitem(sys.modules, "pymatgen.core.structure", None) @@ -334,4 +316,4 @@ def test_structures_to_state_importerror( with pytest.raises( ImportError, match="Pymatgen is required for structures_to_state conversion" ): - ts.io.structures_to_state(si_structure, torch.device("cpu"), torch.float64) + ts.io.structures_to_state(None, None, None) # type: ignore[arg-type] diff --git a/tests/test_math.py b/tests/test_math.py index 51f0fe9fc..21f42f114 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -1,4 +1,5 @@ -"""Tests for the math module.""" +"""Tests for the math module. Adapted from https://github.com/abhijeetgangan/torch_matfunc""" + # ruff: noqa: SLF001 @@ -9,13 +10,11 @@ import torch from numpy.testing import assert_allclose -import torch_sim.math as tsm +import torch_sim.math as fm +from tests.conftest import DTYPE device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -dtype = torch.float64 - -"""Below tests are adapted from https://github.com/abhijeetgangan/torch_matfunc""" class TestExpmFrechet: @@ -26,18 +25,16 @@ def test_expm_frechet(self): M = np.array( [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2], [0, 0, 5, 6]], dtype=np.float64 ) - A_np = np.array([[1, 2], [5, 6]], dtype=np.float64) - E_np = np.array([[3, 4], [7, 8]], dtype=np.float64) - expected_expm = scipy.linalg.expm(A_np) + A = np.array([[1, 2], [5, 6]], dtype=np.float64) + E = np.array([[3, 4], [7, 8]], dtype=np.float64) + expected_expm = scipy.linalg.expm(A) expected_frechet = scipy.linalg.expm(M)[:2, 2:] - A = torch.from_numpy(A_np).to(device=device) - E = torch.from_numpy(E_np).to(device=device) - for method in ("SPS", "blockEnlarge"): + A = torch.from_numpy(A).to(device=device) + E = torch.from_numpy(E).to(device=device) + for kwargs in ({}, {"method": "SPS"}, {"method": "blockEnlarge"}): # Convert it to numpy arrays before passing it to the function - observed_expm, observed_frechet = tsm.expm_frechet_with_matrix_exp( - A, E, method=method - ) + observed_expm, observed_frechet = fm.expm_frechet(A, E, **kwargs) assert_allclose(expected_expm, observed_expm.cpu().numpy()) assert_allclose(expected_frechet, observed_frechet.cpu().numpy()) @@ -62,10 +59,10 @@ def test_small_norm_expm_frechet(self): E = scale * E_original expected_expm = scipy.linalg.expm(A) expected_frechet = scipy.linalg.expm(M)[:2, 2:] - A = torch.from_numpy(A).to(device=device, dtype=dtype) - E = torch.from_numpy(E).to(device=device, dtype=dtype) + A = torch.from_numpy(A).to(device=device, dtype=DTYPE) + E = torch.from_numpy(E).to(device=device, dtype=DTYPE) # Convert it to numpy arrays before passing it to the function - observed_expm, observed_frechet = tsm.expm_frechet_with_matrix_exp(A, E) + observed_expm, observed_frechet = fm.expm_frechet(A, E) assert_allclose(expected_expm, observed_expm.cpu().numpy()) assert_allclose(expected_frechet, observed_frechet.cpu().numpy()) @@ -93,26 +90,26 @@ def test_fuzz(self): M = np.vstack([np.hstack([A, E]), np.hstack([np.zeros_like(A), A])]) expected_expm = scipy.linalg.expm(A) expected_frechet = scipy.linalg.expm(M)[:n, n:] - A = torch.from_numpy(A).to(device=device, dtype=dtype) - E = torch.from_numpy(E).to(device=device, dtype=dtype) + A = torch.from_numpy(A).to(device=device, dtype=DTYPE) + E = torch.from_numpy(E).to(device=device, dtype=DTYPE) # Convert it to numpy arrays before passing it to the function - observed_expm, observed_frechet = tsm.expm_frechet_with_matrix_exp(A, E) + observed_expm, observed_frechet = fm.expm_frechet(A, E) assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=5e-8) assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-7) def test_problematic_matrix(self): """Test a specific matrix that previously uncovered a bug.""" - A_np = np.array( + A = np.array( [[1.50591997, 1.93537998], [0.41203263, 0.23443516]], dtype=np.float64 ) - E_np = np.array( + E = np.array( [[1.87864034, 2.07055038], [1.34102727, 0.67341123]], dtype=np.float64 ) - A = torch.from_numpy(A_np).to(device=device, dtype=dtype) - E = torch.from_numpy(E_np).to(device=device, dtype=dtype) + A = torch.from_numpy(A).to(device=device, dtype=DTYPE) + E = torch.from_numpy(E).to(device=device, dtype=DTYPE) # Convert it to numpy arrays before passing it to the function - sps_expm, sps_frechet = tsm.expm_frechet_with_matrix_exp(A, E, method="SPS") - blockEnlarge_expm, blockEnlarge_frechet = tsm.expm_frechet_with_matrix_exp( + sps_expm, sps_frechet = fm.expm_frechet(A, E, method="SPS") + blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet( A, E, method="blockEnlarge" ) assert_allclose(sps_expm.cpu().numpy(), blockEnlarge_expm.cpu().numpy()) @@ -122,14 +119,14 @@ def test_medium_matrix(self): """Test with a medium-sized matrix to compare performance between methods.""" n = 1000 rng = np.random.default_rng() - A_np = rng.exponential(size=(n, n)) - E_np = rng.exponential(size=(n, n)) + A = rng.exponential(size=(n, n)) + E = rng.exponential(size=(n, n)) - A = torch.from_numpy(A_np).to(device=device, dtype=dtype) - E = torch.from_numpy(E_np).to(device=device, dtype=dtype) + A = torch.from_numpy(A).to(device=device, dtype=DTYPE) + E = torch.from_numpy(E).to(device=device, dtype=DTYPE) # Convert it to numpy arrays before passing it to the function - sps_expm, sps_frechet = tsm.expm_frechet_with_matrix_exp(A, E, method="SPS") - blockEnlarge_expm, blockEnlarge_frechet = tsm.expm_frechet_with_matrix_exp( + sps_expm, sps_frechet = fm.expm_frechet(A, E, method="SPS") + blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet( A, E, method="blockEnlarge" ) assert_allclose(sps_expm.cpu().numpy(), blockEnlarge_expm.cpu().numpy()) @@ -143,18 +140,16 @@ def test_expm_frechet(self): """Test basic functionality of expm_frechet against torch.linalg.matrix_exp.""" M = torch.tensor( [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2], [0, 0, 5, 6]], - dtype=dtype, + dtype=DTYPE, device=device, ) - A = torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device) - E = torch.tensor([[3, 4], [7, 8]], dtype=dtype, device=device) + A = torch.tensor([[1, 2], [5, 6]], dtype=DTYPE, device=device) + E = torch.tensor([[3, 4], [7, 8]], dtype=DTYPE, device=device) expected_expm = torch.linalg.matrix_exp(A) expected_frechet = torch.linalg.matrix_exp(M)[:2, 2:] - for method in ("SPS", "blockEnlarge"): - observed_expm, observed_frechet = tsm.expm_frechet_with_matrix_exp( - A, E, method=method - ) + for kwargs in ({}, {"method": "SPS"}, {"method": "blockEnlarge"}): + observed_expm, observed_frechet = fm.expm_frechet(A, E, **kwargs) torch.testing.assert_close(expected_expm, observed_expm) torch.testing.assert_close(expected_frechet, observed_frechet) @@ -167,17 +162,17 @@ def test_small_norm_expm_frechet(self): [0, 0, 1, 2], [0, 0, 5, 6], ], - dtype=dtype, + dtype=DTYPE, device=device, ) - A_original = torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device) - E_original = torch.tensor([[3, 4], [7, 8]], dtype=dtype, device=device) + A_original = torch.tensor([[1, 2], [5, 6]], dtype=DTYPE, device=device) + E_original = torch.tensor([[3, 4], [7, 8]], dtype=DTYPE, device=device) A_original_norm_1 = torch.linalg.norm(A_original, 1) selected_m_list = [1, 3, 5, 7, 9, 11, 13, 15] m_neighbor_pairs = itertools.pairwise(selected_m_list) for ma, mb in m_neighbor_pairs: - ell_a = tsm.ell_table_61[ma] - ell_b = tsm.ell_table_61[mb] + ell_a = fm.ell_table_61[ma] + ell_b = fm.ell_table_61[mb] target_norm_1 = 0.5 * (ell_a + ell_b) scale = target_norm_1 / A_original_norm_1 M = scale * M_original @@ -185,7 +180,7 @@ def test_small_norm_expm_frechet(self): E = scale * E_original expected_expm = torch.linalg.matrix_exp(A) expected_frechet = torch.linalg.matrix_exp(M)[:2, 2:] - observed_expm, observed_frechet = tsm.expm_frechet_with_matrix_exp(A, E) + observed_expm, observed_frechet = fm.expm_frechet(A, E) torch.testing.assert_close(expected_expm, observed_expm) torch.testing.assert_close(expected_frechet, observed_frechet) @@ -222,7 +217,7 @@ def test_fuzz(self): ) expected_expm = torch.linalg.matrix_exp(A) expected_frechet = torch.linalg.matrix_exp(M)[:n, n:] - observed_expm, observed_frechet = tsm.expm_frechet_with_matrix_exp(A, E) + observed_expm, observed_frechet = fm.expm_frechet(A, E) torch.testing.assert_close(expected_expm, observed_expm, atol=5e-8, rtol=1e-5) torch.testing.assert_close( expected_frechet, observed_frechet, atol=1e-7, rtol=1e-5 @@ -232,16 +227,16 @@ def test_problematic_matrix(self): """Test a specific matrix that previously uncovered a bug using torch tensors.""" A = torch.tensor( [[1.50591997, 1.93537998], [0.41203263, 0.23443516]], - dtype=dtype, + dtype=DTYPE, device=device, ) E = torch.tensor( [[1.87864034, 2.07055038], [1.34102727, 0.67341123]], - dtype=dtype, + dtype=DTYPE, device=device, ) - sps_expm, sps_frechet = tsm.expm_frechet_with_matrix_exp(A, E, method="SPS") - blockEnlarge_expm, blockEnlarge_frechet = tsm.expm_frechet_with_matrix_exp( + sps_expm, sps_frechet = fm.expm_frechet(A, E, method="SPS") + blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet( A, E, method="blockEnlarge" ) torch.testing.assert_close(sps_expm, blockEnlarge_expm) @@ -256,8 +251,8 @@ def test_medium_matrix(self): A = torch.tensor(rng.exponential(size=(n, n))) E = torch.tensor(rng.exponential(size=(n, n))) - sps_expm, sps_frechet = tsm.expm_frechet_with_matrix_exp(A, E, method="SPS") - blockEnlarge_expm, blockEnlarge_frechet = tsm.expm_frechet_with_matrix_exp( + sps_expm, sps_frechet = fm.expm_frechet(A, E, method="SPS") + blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet( A, E, method="blockEnlarge" ) torch.testing.assert_close(sps_expm, blockEnlarge_expm) @@ -271,19 +266,19 @@ def test_expm_frechet(self): """Test gradient computation for matrix exponential and its Frechet derivative.""" M = torch.tensor( [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2], [0, 0, 5, 6]], - dtype=dtype, + dtype=DTYPE, device=device, ) - A = torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device) - E = torch.tensor([[3, 4], [7, 8]], dtype=dtype, device=device) + A = torch.tensor([[1, 2], [5, 6]], dtype=DTYPE, device=device) + E = torch.tensor([[3, 4], [7, 8]], dtype=DTYPE, device=device) expected_expm = torch.linalg.matrix_exp(A) expected_frechet = torch.linalg.matrix_exp(M)[:2, 2:] # expm will use the SPS method as default - observed_expm = tsm.expm.apply(A) + observed_expm = fm.expm.apply(A) torch.testing.assert_close(expected_expm, observed_expm) # Compute the Frechet derivative in the direction of grad_output A.requires_grad = True - observed_expm = tsm.expm.apply(A) + observed_expm = fm.expm.apply(A) (observed_frechet,) = torch.autograd.grad(observed_expm, A, E, retain_graph=True) torch.testing.assert_close(expected_frechet, observed_frechet) @@ -323,26 +318,26 @@ def test_logm_33_reference(self): e_val = torch.exp(torch.tensor(1.0)) # e = exp(1) T_1b = torch.tensor( [[e_val, 1.0, 0.0], [0.0, e_val, 0.0], [0.0, 0.0, e_val]], - dtype=dtype, + dtype=DTYPE, device=device, ) # Expected solution: log T = [[1, 1/e, 0], [0, 1, 0], [0, 0, 1]] expected_1b = torch.tensor( [[1.0, 1.0 / e_val, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], - dtype=dtype, + dtype=DTYPE, device=device, ) # Compute using our implementation and compare - result_1b = tsm._matrix_log_33(T_1b) + result_1b = fm._matrix_log_33(T_1b) ( torch.testing.assert_close(result_1b, expected_1b, rtol=rtol, atol=atol), f"Case 1b failed: \nExpected:\n{expected_1b}\nGot:\n{result_1b}", ) # Compare with scipy - scipy_result_1b = tsm.matrix_log_scipy(T_1b) + scipy_result_1b = fm.matrix_log_scipy(T_1b) msg = ( f"Case 1b differs from scipy: Expected:\n{scipy_result_1b}\nGot:\n{result_1b}" ) @@ -354,7 +349,7 @@ def test_logm_33_reference(self): # Example: T = [[e, 1, 1], [0, e, 1], [0, 0, e]] T_1c = torch.tensor( [[e_val, 1.0, 1.0], [0.0, e_val, 1.0], [0.0, 0.0, e_val]], - dtype=dtype, + dtype=DTYPE, device=device, ) @@ -365,17 +360,17 @@ def test_logm_33_reference(self): [0.0, 1.0, 1.0 / e_val], [0.0, 0.0, 1.0], ], - dtype=dtype, + dtype=DTYPE, device=device, ) # Compute using our implementation and compare - result_1c = tsm._matrix_log_33(T_1c) + result_1c = fm._matrix_log_33(T_1c) msg = f"Case 1c failed: \nExpected:\n{expected_1c}\nGot:\n{result_1c}" torch.testing.assert_close(result_1c, expected_1c, rtol=rtol, atol=atol, msg=msg) # Compare with scipy - scipy_result_1c = tsm.matrix_log_scipy(T_1c) + scipy_result_1c = fm.matrix_log_scipy(T_1c) msg = ( f"Case 1c differs from scipy: Expected:\n{scipy_result_1c}\nGot:\n{result_1c}" ) @@ -389,7 +384,7 @@ def test_logm_33_reference(self): e_cubed = e_squared * e_val T_2b = torch.tensor( [[e_val, 1.0, 1.0], [0.0, e_squared, 1.0], [0.0, 0.0, e_squared]], - dtype=dtype, + dtype=DTYPE, device=device, ) @@ -405,17 +400,17 @@ def test_logm_33_reference(self): [0.0, 2.0, 1.0 / e_squared], [0.0, 0.0, 2.0], ], - dtype=dtype, + dtype=DTYPE, device=device, ) # Compute using our implementation and compare - result_2b = tsm._matrix_log_33(T_2b) + result_2b = fm._matrix_log_33(T_2b) msg = f"Case 2b failed: \nExpected:\n{expected_2b}\nGot:\n{result_2b}" torch.testing.assert_close(result_2b, expected_2b, rtol=rtol, atol=atol, msg=msg) # Compare with scipy - scipy_result_2b = tsm.matrix_log_scipy(T_2b) + scipy_result_2b = fm.matrix_log_scipy(T_2b) msg = ( f"Case 2b differs from scipy: Expected:\n{scipy_result_2b}\nGot:\n{result_2b}" ) @@ -424,19 +419,19 @@ def test_logm_33_reference(self): ) # Additional test: identity matrix (should return zero matrix) - identity = torch.eye(3, dtype=dtype, device=device) - log_identity = tsm._matrix_log_33(identity) - expected_log_identity = torch.zeros((3, 3), dtype=dtype, device=device) + identity = torch.eye(3, dtype=DTYPE, device=device) + log_identity = fm._matrix_log_33(identity) + expected_log_identity = torch.zeros((3, 3), dtype=DTYPE, device=device) msg = f"log(I) failed: \nExpected:\n{expected_log_identity}\nGot:\n{log_identity}" torch.testing.assert_close( log_identity, expected_log_identity, rtol=rtol, atol=atol, msg=msg ) # Additional test: diagonal matrix with distinct eigenvalues (Case 3) - D = torch.diag(torch.tensor([2.0, 3.0, 4.0], dtype=dtype, device=device)) - log_D = tsm._matrix_log_33(D) + D = torch.diag(torch.tensor([2.0, 3.0, 4.0], dtype=DTYPE, device=device)) + log_D = fm._matrix_log_33(D) expected_log_D = torch.diag( - torch.log(torch.tensor([2.0, 3.0, 4.0], dtype=dtype, device=device)) + torch.log(torch.tensor([2.0, 3.0, 4.0], dtype=DTYPE, device=device)) ) msg = f"log(diag) failed: \nExpected:\n{expected_log_D}\nGot:\n{log_D}" torch.testing.assert_close(log_D, expected_log_D, rtol=rtol, atol=atol, msg=msg) @@ -449,11 +444,11 @@ def test_random_float(self): """ torch.manual_seed(1234) n = 3 - M = torch.randn(n, n, dtype=dtype, device=device) - M_logm = tsm.matrix_log_33(M) + M = torch.randn(n, n, dtype=DTYPE, device=device) + M_logm = fm.matrix_log_33(M) scipy_logm = scipy.linalg.logm(M.cpu().numpy()) torch.testing.assert_close( - M_logm, torch.tensor(scipy_logm, dtype=dtype, device=device) + M_logm, torch.tensor(scipy_logm, dtype=DTYPE, device=device) ) def test_nearly_degenerate(self): @@ -466,11 +461,11 @@ def test_nearly_degenerate(self): eps = 1e-6 M = torch.tensor( [[1.0, 1.0, eps], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], - dtype=dtype, + dtype=DTYPE, device=device, ) - M_logm = tsm._matrix_log_33(M) + M_logm = fm._matrix_log_33(M) scipy_logm = scipy.linalg.logm(M.cpu().numpy()) torch.testing.assert_close( - M_logm, torch.tensor(scipy_logm, dtype=dtype, device=device) + M_logm, torch.tensor(scipy_logm, dtype=DTYPE, device=device) ) diff --git a/tests/test_monte_carlo.py b/tests/test_monte_carlo.py index 479552c07..2414e8b5b 100644 --- a/tests/test_monte_carlo.py +++ b/tests/test_monte_carlo.py @@ -3,18 +3,21 @@ from pymatgen.core import Structure import torch_sim as ts +from tests.conftest import DEVICE from torch_sim.models.interface import ModelInterface from torch_sim.monte_carlo import ( SwapMCState, generate_swaps, - swap_monte_carlo, + metropolis_criterion, + swap_mc_init, + swap_mc_step, swaps_to_permutation, - validate_permutation, ) @pytest.fixture -def diverse_structure() -> Structure: +def batched_diverse_state() -> ts.SimState: + """Create a batched state with diverse atomic species for testing.""" lattice = [[5.43, 0, 0], [0, 5.43, 0], [0, 0, 5.43]] species = ["H", "He", "Li", "Be", "B", "C", "N", "O"] coords = [ @@ -27,140 +30,174 @@ def diverse_structure() -> Structure: [0.5, 0.5, 0.0], [0.75, 0.75, 0.25], ] - return Structure(lattice, species, coords) + structure = Structure(lattice, species, coords) + return ts.io.structures_to_state([structure] * 2, device=DEVICE, dtype=torch.float64) -@pytest.fixture -def generator(device: torch.device) -> torch.Generator: - generator = torch.Generator(device=device) - generator.manual_seed(42) - return generator - - -@pytest.fixture -def batched_diverse_state( - diverse_structure: Structure, device: torch.device -) -> ts.SimState: - return ts.io.structures_to_state( - [diverse_structure] * 2, device=device, dtype=torch.float64 - ) - +@pytest.mark.parametrize("use_generator", [True, False]) +def test_generate_swaps(batched_diverse_state: ts.SimState, *, use_generator: bool): + """Test swap generation with and without generator.""" + generator = torch.Generator(device=DEVICE) if use_generator else None + if generator: + generator.manual_seed(42) -def test_generate_permutation( - batched_diverse_state: ts.SimState, generator: torch.Generator -): swaps = generate_swaps(batched_diverse_state, generator=generator) - permutation = swaps_to_permutation(swaps, batched_diverse_state.n_atoms) - validate_permutation(permutation, batched_diverse_state.system_idx) - -def test_generate_swaps(batched_diverse_state: ts.SimState, generator: torch.Generator): - swaps = generate_swaps(batched_diverse_state, generator=generator) - - # Check shape and type + # Basic validation assert isinstance(swaps, torch.Tensor) assert swaps.shape[1] == 2 - - # Check swaps are within valid range assert torch.all(swaps >= 0) assert torch.all(swaps < batched_diverse_state.n_atoms) - # Check swaps are within same system + # System consistency system_idx = batched_diverse_state.system_idx assert torch.all(system_idx[swaps[:, 0]] == system_idx[swaps[:, 1]]) + # Different atomic numbers + atomic_numbers = batched_diverse_state.atomic_numbers + for swap in swaps: + assert atomic_numbers[swap[0]] != atomic_numbers[swap[1]] -def test_swaps_to_permutation( - batched_diverse_state: ts.SimState, generator: torch.Generator -): - swaps = generate_swaps(batched_diverse_state, generator=generator) + # Test reproducibility with generator + if use_generator and generator is not None: + generator.manual_seed(42) + swaps2 = generate_swaps(batched_diverse_state, generator=generator) + assert torch.equal(swaps, swaps2) + + +@pytest.mark.parametrize("n_swaps", [0, 1, 3]) +def test_swaps_to_permutation(batched_diverse_state: ts.SimState, *, n_swaps: int): + """Test permutation generation with different numbers of swaps.""" n_atoms = batched_diverse_state.n_atoms - permutation = swaps_to_permutation(swaps, n_atoms) + generator = torch.Generator(device=DEVICE) + generator.manual_seed(42) + + if n_swaps == 0: + combined_swaps = torch.empty((0, 2), dtype=torch.long, device=DEVICE) + else: + all_swaps = [ + generate_swaps(batched_diverse_state, generator=generator) + for _ in range(n_swaps) + ] + combined_swaps = torch.cat(all_swaps, dim=0) + + permutation = swaps_to_permutation(combined_swaps, n_atoms) - # Check shape and type + # Validation assert isinstance(permutation, torch.Tensor) assert permutation.shape == (n_atoms,) + expected_range = torch.arange(n_atoms, device=permutation.device) + assert torch.sort(permutation)[0].equal(expected_range) + + # Test permutation preserves system assignments + original_system = batched_diverse_state.system_idx + assert torch.all(original_system == original_system[permutation]) + + +@pytest.mark.parametrize( + ("energy_old", "energy_new", "kT", "expected_rate"), + [ + ([10.0, 20.0], [5.0, 15.0], 1.0, 1.0), # Energy decreases + ([5.0, 15.0], [25.0, 35.0], 0.1, 0.0), # Energy increases significantly + ([10.0, 20.0], [10.0, 20.0], 1.0, 1.0), # Energy stays same + ([10.0, 20.0], [15.0, 25.0], 1000.0, 1.0), # Very high temperature + ([10.0, 20.0], [15.0, 25.0], 0.001, 0.0), # Very low temperature + ], +) +def test_metropolis_criterion( + *, + energy_old: list[float], + energy_new: list[float], + kT: float, + expected_rate: float, +): + """Test metropolis criterion with different energy scenarios.""" + energy_old_tensor = torch.tensor(energy_old, device=DEVICE) + energy_new_tensor = torch.tensor(energy_new, device=DEVICE) + + if expected_rate in [0.0, 1.0]: + # Deterministic cases + accepted = metropolis_criterion(energy_new_tensor, energy_old_tensor, kT) + actual_rate = accepted.float().mean().item() + assert abs(actual_rate - expected_rate) < 0.1 + else: + # Statistical test + generator = torch.Generator(device=DEVICE) + generator.manual_seed(42) + total_accepted = sum( + metropolis_criterion( + energy_new_tensor, energy_old_tensor, kT, generator=generator + ) + .sum() + .item() + for _ in range(1000) + ) + actual_rate = total_accepted / (1000 * len(energy_old)) + assert abs(actual_rate - expected_rate) < 0.15 - # Check permutation contains all indices - assert torch.sort(permutation)[0].equal( - torch.arange(n_atoms, device=permutation.device) - ) - - # Check swapped pairs - for i, j in swaps: - assert permutation[i] == j - assert permutation[j] == i +def test_metropolis_criterion_randomness(): + """Test that different generators produce different results.""" + energy_old = torch.tensor([10.0, 20.0], device=DEVICE) + energy_new = torch.tensor([11.0, 21.0], device=DEVICE) # ~37% acceptance -def test_validate_permutation(batched_diverse_state: ts.SimState): - # Valid permutation - swaps = generate_swaps(batched_diverse_state) - permutation = swaps_to_permutation(swaps, batched_diverse_state.n_atoms) - validate_permutation( - permutation, batched_diverse_state.system_idx - ) # Should not raise + gen1 = torch.Generator(device=DEVICE) + gen1.manual_seed(42) + gen2 = torch.Generator(device=DEVICE) + gen2.manual_seed(43) - # Invalid permutation (swap between batches) - invalid_perm = permutation.clone() - if batched_diverse_state.n_atoms > 2: - # Swap first atom with last atom (different batches) - invalid_perm[0] = batched_diverse_state.n_atoms - 1 - invalid_perm[batched_diverse_state.n_atoms - 1] = 0 + accepted1 = metropolis_criterion(energy_new, energy_old, kT=1.0, generator=gen1) + accepted2 = metropolis_criterion(energy_new, energy_old, kT=1.0, generator=gen2) + accepted3 = metropolis_criterion(energy_new, energy_old, kT=1.0, generator=None) - with pytest.raises(ValueError, match="Swaps must be between"): - validate_permutation(invalid_perm, batched_diverse_state.system_idx) + different_results = not torch.equal(accepted1, accepted2) or not torch.equal( + accepted1, accepted3 + ) + assert different_results -def test_monte_carlo( +@pytest.mark.parametrize(("kT", "n_steps"), [(0.1, 3), (1.0, 5), (10.0, 2)]) +def test_monte_carlo_integration( batched_diverse_state: ts.SimState, lj_model: ModelInterface, + *, + kT: float, + n_steps: int, ): - """Test the monte_carlo function that returns a step function and initial state.""" - # Call monte_carlo to get the initial state and step function - init_state_fn, monte_carlo_step_fn = swap_monte_carlo(model=lj_model, kT=1.0, seed=42) - initial_state = init_state_fn(batched_diverse_state) - - # Verify the returned values - assert isinstance(initial_state, SwapMCState) - assert callable(monte_carlo_step_fn) - - # Verify the initial state has the expected attributes - assert hasattr(initial_state, "energy") - assert hasattr(initial_state, "last_permutation") - - # Make a copy of the initial state for comparison - initial_positions = initial_state.positions.clone() - - # Get the current state - current_state = initial_state - - # Run multiple Monte Carlo steps - n_steps = 5 + """Test the complete Monte Carlo workflow.""" + # Initialize + mc_state = swap_mc_init(model=lj_model, state=batched_diverse_state) + assert isinstance(mc_state, SwapMCState) + assert mc_state.energy.shape == (batched_diverse_state.n_systems,) + assert mc_state.last_permutation.shape == (batched_diverse_state.n_atoms,) + expected_identity = torch.arange(batched_diverse_state.n_atoms, device=DEVICE) + assert torch.equal(mc_state.last_permutation, expected_identity) + + # Run steps for step in range(n_steps): - # Create a new generator for each step - step_generator = torch.Generator(device=batched_diverse_state.device) - step_generator.manual_seed(42 + step + 1) # Different seed for each step - - # Run a Monte Carlo step - current_state = monte_carlo_step_fn(current_state, generator=step_generator) - - # Verify the state is an MCState - assert isinstance(current_state, SwapMCState) - - # Verify the state has changed after multiple steps - assert not torch.allclose(current_state.positions, initial_positions) - - # Verify system_idx assignments remain unchanged - assert torch.all(current_state.system_idx == batched_diverse_state.system_idx) + mc_state = swap_mc_step(model=lj_model, state=mc_state, kT=kT, seed=42 + step) + assert isinstance(mc_state, SwapMCState) + + # Verify conservation properties + assert torch.all(mc_state.system_idx == batched_diverse_state.system_idx) + for sys_idx in torch.unique(mc_state.system_idx): + orig_mask = batched_diverse_state.system_idx == sys_idx + result_mask = mc_state.system_idx == sys_idx + orig_counts = torch.bincount(batched_diverse_state.atomic_numbers[orig_mask]) + result_counts = torch.bincount(mc_state.atomic_numbers[result_mask]) + assert torch.all(orig_counts == result_counts) - # Verify atomic numbers distribution remains the same per system - for idx in torch.unique(current_state.system_idx): - system_mask_orig = batched_diverse_state.system_idx == idx - system_mask_result = current_state.system_idx == idx - orig_counts = torch.bincount( - batched_diverse_state.atomic_numbers[system_mask_orig] - ) - result_counts = torch.bincount(current_state.atomic_numbers[system_mask_result]) +def test_swap_mc_state_attributes(): + """Test SwapMCState class structure and inheritance.""" + from torch_sim.state import SimState - assert torch.all(orig_counts == result_counts) + assert issubclass(SwapMCState, SimState) + assert "last_permutation" in SwapMCState._atom_attributes # noqa: SLF001 + assert "energy" in SwapMCState._system_attributes # noqa: SLF001 + atom_attrs = SwapMCState._atom_attributes # noqa: SLF001 + system_attrs = SwapMCState._system_attributes # noqa: SLF001 + parent_atom_attrs = SimState._atom_attributes # noqa: SLF001 + parent_system_attrs = SimState._system_attributes # noqa: SLF001 + assert atom_attrs >= parent_atom_attrs + assert system_attrs >= parent_system_attrs diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index dbf625101..8a1e0f7a5 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -1,5 +1,6 @@ import time from collections.abc import Callable +from typing import Any, cast import numpy as np import psutil @@ -9,18 +10,12 @@ from ase.build import bulk, molecule from ase.neighborlist import neighbor_list +from tests.conftest import DEVICE, DTYPE from torch_sim import neighbors, transforms -@pytest.fixture -def dtype() -> torch.dtype: - return torch.float64 - - def ase_to_torch_batch( - atoms_list: list[Atoms], - device: torch.device, - dtype: torch.dtype = torch.float32, + atoms_list: list[Atoms], device: torch.device, dtype: torch.dtype = torch.float32 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Convert a list of ASE Atoms objects into tensors suitable for PyTorch. @@ -105,7 +100,7 @@ def ase_to_torch_batch( @pytest.fixture -def periodic_atoms_set() -> list[Atoms]: +def periodic_atoms_set(): return [ bulk("Si", "diamond", a=6, cubic=True), bulk("Si", "diamond", a=6), @@ -124,7 +119,7 @@ def periodic_atoms_set() -> list[Atoms]: @pytest.fixture -def molecule_atoms_set() -> list[Atoms]: +def molecule_atoms_set() -> list: return [ *map(molecule, ("CH3CH2NH2", "H2O", "methylenecyclopropane", "OCHCHO", "C3H9C")), ] @@ -132,28 +127,19 @@ def molecule_atoms_set() -> list[Atoms]: @pytest.mark.parametrize("cutoff", [1, 3, 5, 7]) @pytest.mark.parametrize("use_jit", [True, False]) -@pytest.mark.parametrize( - "atoms_list_fixture", ["periodic_atoms_set", "molecule_atoms_set"] -) +@pytest.mark.parametrize("atoms_list", ["periodic_atoms_set", "molecule_atoms_set"]) def test_primitive_neighbor_list( - *, - cutoff: float, - atoms_list_fixture: str, - device: torch.device, - dtype: torch.dtype, - use_jit: bool, - request: pytest.FixtureRequest, + *, cutoff: float, atoms_list: str, use_jit: bool, request: pytest.FixtureRequest ) -> None: """Check that primitive_neighbor_list gives the same NL as ASE by comparing the resulting sorted list of distances between neighbors. Args: cutoff: Cutoff distance for neighbor search - device: Torch device to use - dtype: Torch dtype to use + atoms_list: List of atoms to test use_jit: Whether to use the jitted version or disable JIT """ - atoms_list = request.getfixturevalue(atoms_list_fixture) + atoms_list = request.getfixturevalue(atoms_list) # Create a non-jitted version of the function if requested if use_jit: @@ -168,10 +154,10 @@ def test_primitive_neighbor_list( # Import the function again to get the non-jitted version from importlib import reload - import torch_sim.neighbors + import torch_sim as ts - reload(torch_sim.neighbors) - neighbor_list_fn = torch_sim.neighbors.primitive_neighbor_list + reload(ts.neighbors) + neighbor_list_fn = ts.neighbors.primitive_neighbor_list # Restore JIT setting after test if old_jit_setting is not None: @@ -181,8 +167,8 @@ def test_primitive_neighbor_list( for atoms in atoms_list: # Convert to torch tensors - pos = torch.tensor(atoms.positions, device=device, dtype=dtype) - row_vector_cell = torch.tensor(atoms.cell.array, device=device, dtype=dtype) + pos = torch.tensor(atoms.positions, device=DEVICE, dtype=DTYPE) + row_vector_cell = torch.tensor(atoms.cell.array, device=DEVICE, dtype=DTYPE) pbc = atoms.pbc.any() @@ -193,9 +179,9 @@ def test_primitive_neighbor_list( positions=pos, cell=row_vector_cell, pbc=(pbc, pbc, pbc), - cutoff=torch.tensor(cutoff, dtype=dtype, device=device), - device=device, - dtype=dtype, + cutoff=torch.tensor(cutoff, dtype=DTYPE, device=DEVICE), + device=DEVICE, + dtype=DTYPE, self_interaction=False, use_scaled_positions=False, max_n_bins=int(1e6), @@ -205,14 +191,14 @@ def test_primitive_neighbor_list( mapping = torch.stack((idx_i, idx_j), dim=0) # Convert shifts_tensor to the same dtype as cell before matrix multiplication - shifts_tensor = shifts_tensor.to(dtype=dtype) + shifts_tensor = shifts_tensor.to(dtype=DTYPE) # Calculate distances with cell shifts cell_shifts_prim = torch.mm(shifts_tensor, row_vector_cell) dds_prim = transforms.compute_distances_with_cell_shifts( pos, mapping, cell_shifts_prim ) - dds_prim_sorted = np.sort(dds_prim.numpy()) + dds_prim = np.sort(dds_prim.numpy()) # Get the neighbor list from ase idx_i_ref, idx_j_ref, shifts_ref, dist_ref = neighbor_list( @@ -224,12 +210,12 @@ def test_primitive_neighbor_list( ) # Convert to torch tensors - idx_i_ref = torch.tensor(idx_i_ref, dtype=torch.long, device=device) - idx_j_ref = torch.tensor(idx_j_ref, dtype=torch.long, device=device) + idx_i_ref = torch.tensor(idx_i_ref, dtype=torch.long, device=DEVICE) + idx_j_ref = torch.tensor(idx_j_ref, dtype=torch.long, device=DEVICE) # Create mapping and shifts mapping_ref = torch.stack((idx_i_ref, idx_j_ref), dim=0) - shifts_ref = torch.tensor(shifts_ref, dtype=dtype, device=device) + shifts_ref = torch.tensor(shifts_ref, dtype=DTYPE, device=DEVICE) # Calculate distances with cell shifts cell_shifts_ref = torch.mm(shifts_ref, row_vector_cell) @@ -238,22 +224,20 @@ def test_primitive_neighbor_list( ) # Sort the distances - dds_ref_sorted = np.sort(dds_ref.numpy()) - dist_ref_sorted = np.sort(dist_ref) + dds_ref = np.sort(dds_ref.numpy()) + dist_ref = np.sort(dist_ref) - # Check that the distances are the same with ase and torchsim logic - np.testing.assert_allclose(dds_ref_sorted, dist_ref_sorted) + # Check that the distances are the same with ase and TorchSim logic + np.testing.assert_allclose(dds_ref, dist_ref) # Check that the primitive_neighbor_list distances match ASE's np.testing.assert_allclose( - dds_prim_sorted, dist_ref_sorted, err_msg=f"Failed with use_jit={use_jit}" + dds_prim, dist_ref, err_msg=f"Failed with use_jit={use_jit}" ) @pytest.mark.parametrize("cutoff", [1, 3, 5, 7]) -@pytest.mark.parametrize( - "atoms_list_fixture", ["periodic_atoms_set", "molecule_atoms_set"] -) +@pytest.mark.parametrize("atoms_list", ["periodic_atoms_set", "molecule_atoms_set"]) @pytest.mark.parametrize( "nl_implementation", [neighbors.standard_nl, neighbors.vesin_nl, neighbors.vesin_nl_ts], @@ -261,21 +245,19 @@ def test_primitive_neighbor_list( def test_neighbor_list_implementations( *, cutoff: float, - atoms_list_fixture: str, - nl_implementation: Callable, - device: torch.device, - dtype: torch.dtype, + atoms_list: str, + nl_implementation: Callable[..., tuple[torch.Tensor, torch.Tensor]], request: pytest.FixtureRequest, ) -> None: """Check that different neighbor list implementations give the same results as ASE by comparing the resulting sorted list of distances between neighbors. """ - atoms_list = request.getfixturevalue(atoms_list_fixture) + atoms_list = request.getfixturevalue(atoms_list) for atoms in atoms_list: # Convert to torch tensors - pos = torch.tensor(atoms.positions, device=device, dtype=dtype) - row_vector_cell = torch.tensor(atoms.cell.array, device=device, dtype=dtype) + pos = torch.tensor(atoms.positions, device=DEVICE, dtype=DTYPE) + row_vector_cell = torch.tensor(atoms.cell.array, device=DEVICE, dtype=DTYPE) pbc = atoms.pbc.any() # Get the neighbor list from the implementation being tested @@ -283,13 +265,13 @@ def test_neighbor_list_implementations( positions=pos, cell=row_vector_cell, pbc=pbc, - cutoff=torch.tensor(cutoff, dtype=dtype, device=device), + cutoff=torch.tensor(cutoff, dtype=DTYPE, device=DEVICE), ) # Calculate distances with cell shifts cell_shifts = torch.mm(shifts, row_vector_cell) dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts) - dds_sorted = np.sort(dds.numpy()) + dds = np.sort(dds.numpy()) # Get the reference neighbor list from ASE idx_i, idx_j, shifts_ref, dist = neighbor_list( @@ -301,23 +283,21 @@ def test_neighbor_list_implementations( ) # Convert to torch tensors and calculate reference distances - idx_i = torch.tensor(idx_i, dtype=torch.long, device=torch.device("cpu")) - idx_j = torch.tensor(idx_j, dtype=torch.long, device=torch.device("cpu")) + idx_i = torch.tensor(idx_i, dtype=torch.long, device=DEVICE) + idx_j = torch.tensor(idx_j, dtype=torch.long, device=DEVICE) mapping_ref = torch.stack((idx_i, idx_j), dim=0) - shifts_ref = torch.tensor( - shifts_ref, dtype=torch.float64, device=torch.device("cpu") - ) + shifts_ref = torch.tensor(shifts_ref, dtype=torch.float64, device=DEVICE) cell_shifts_ref = torch.mm(shifts_ref, row_vector_cell) dds_ref = transforms.compute_distances_with_cell_shifts( pos, mapping_ref, cell_shifts_ref ) - dds_ref_sorted = np.sort(dds_ref.numpy()) - dist_ref_sorted = np.sort(dist) + dds_ref = np.sort(dds_ref.numpy()) + dist_ref = np.sort(dist) # Verify results - np.testing.assert_allclose(dds_ref_sorted, dist_ref_sorted) - np.testing.assert_allclose(dds_sorted, dds_ref_sorted) - np.testing.assert_allclose(dds_sorted, dist_ref_sorted) + np.testing.assert_allclose(dds_ref, dist_ref) + np.testing.assert_allclose(dds, dds_ref) + np.testing.assert_allclose(dds, dist_ref) @pytest.mark.parametrize("cutoff", [1, 3, 5, 7]) @@ -330,9 +310,7 @@ def test_torch_nl_implementations( *, cutoff: float, self_interaction: bool, - nl_implementation: Callable, - device: torch.device, - dtype: torch.dtype, + nl_implementation: Callable[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]], molecule_atoms_set: list[Atoms], periodic_atoms_set: list[Atoms], ) -> None: @@ -343,12 +321,17 @@ def test_torch_nl_implementations( # NOTE we can't use atoms_to_state here because we want to test mixed # periodic and non-periodic systems pos, row_vector_cell, pbc, batch, _ = ase_to_torch_batch( - atoms_list, device=device, dtype=dtype + atoms_list, device=DEVICE, dtype=DTYPE ) # Get the neighbor list from the implementation being tested mapping, mapping_system, shifts_idx = nl_implementation( - cutoff, pos, row_vector_cell, pbc, batch, self_interaction + cutoff=torch.tensor(cutoff, dtype=DTYPE, device=DEVICE), + positions=pos, + cell=row_vector_cell, + pbc=pbc, + system_idx=batch, + self_interaction=self_interaction, ) # Calculate distances @@ -356,7 +339,7 @@ def test_torch_nl_implementations( row_vector_cell, shifts_idx, mapping_system ) dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts) - dds_sorted = np.sort(dds.numpy()) + dds = np.sort(dds.numpy()) # Get reference results from ASE dd_ref = [] @@ -369,62 +352,56 @@ def test_torch_nl_implementations( max_nbins=1e6, ) dd_ref.extend(dist) - dd_ref_sorted = np.sort(dd_ref) + dd_ref = np.sort(dd_ref) # Verify results - np.testing.assert_allclose(dd_ref_sorted, dds_sorted) + np.testing.assert_allclose(dd_ref, dds) -def test_primitive_neighbor_list_edge_cases( - device: torch.device, - dtype: torch.dtype, -) -> None: +def test_primitive_neighbor_list_edge_cases() -> None: """Test edge cases for primitive_neighbor_list.""" # Test different PBC combinations - pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=device, dtype=dtype) - cell = torch.eye(3, device=device, dtype=dtype) * 2.0 - cutoff = torch.tensor(1.5, device=device, dtype=dtype) + pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=DEVICE, dtype=DTYPE) + cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 2.0 + cutoff = torch.tensor(1.5, device=DEVICE, dtype=DTYPE) # Test all PBC combinations for pbc in [(True, False, False), (False, True, False), (False, False, True)]: - idx_i, idx_j, shifts = neighbors.primitive_neighbor_list( + idx_i, idx_j, _shifts = neighbors.primitive_neighbor_list( quantities="ijS", positions=pos, cell=cell, pbc=pbc, cutoff=cutoff, - device=device, - dtype=dtype, + device=DEVICE, + dtype=DTYPE, ) assert len(idx_i) > 0 # Should find at least one neighbor # Test self-interaction - idx_i, idx_j, shifts = neighbors.primitive_neighbor_list( + idx_i, idx_j, _shifts = neighbors.primitive_neighbor_list( quantities="ijS", positions=pos, cell=cell, pbc=(True, True, True), cutoff=cutoff, - device=device, - dtype=dtype, + device=DEVICE, + dtype=DTYPE, self_interaction=True, ) # Should find self-interactions assert torch.any(idx_i == idx_j) -def test_standard_nl_edge_cases( - device: torch.device, - dtype: torch.dtype, -) -> None: +def test_standard_nl_edge_cases() -> None: """Test edge cases for standard_nl.""" - pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=device, dtype=dtype) - cell = torch.eye(3, device=device, dtype=dtype) * 2.0 - cutoff = torch.tensor(1.5, device=device, dtype=dtype) + pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=DEVICE, dtype=DTYPE) + cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 2.0 + cutoff = torch.tensor(1.5, device=DEVICE, dtype=DTYPE) # Test different PBC combinations for pbc in (True, False): - mapping, shifts = neighbors.standard_nl( + mapping, _shifts = neighbors.standard_nl( positions=pos, cell=cell, pbc=pbc, @@ -433,7 +410,7 @@ def test_standard_nl_edge_cases( assert len(mapping[0]) > 0 # Should find neighbors # Test sort_id - mapping, shifts = neighbors.standard_nl( + mapping, _shifts = neighbors.standard_nl( positions=pos, cell=cell, pbc=True, @@ -444,34 +421,22 @@ def test_standard_nl_edge_cases( assert torch.all(mapping[0][1:] >= mapping[0][:-1]) -def test_vesin_nl_edge_cases( - device: torch.device, - dtype: torch.dtype, -) -> None: +def test_vesin_nl_edge_cases() -> None: """Test edge cases for vesin_nl and vesin_nl_ts.""" - pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=device, dtype=dtype) - cell = torch.eye(3, device=device, dtype=dtype) * 2.0 - cutoff = torch.tensor(1.5, device=device, dtype=dtype) + pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=DEVICE, dtype=DTYPE) + cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 2.0 + cutoff = torch.tensor(1.5, device=DEVICE, dtype=DTYPE) # Test both implementations for nl_fn in (neighbors.vesin_nl, neighbors.vesin_nl_ts): # Test different PBC combinations for pbc in (True, False): - mapping, shifts = nl_fn( - positions=pos, - cell=cell, - pbc=pbc, - cutoff=cutoff, - ) + mapping, _shifts = nl_fn(positions=pos, cell=cell, pbc=pbc, cutoff=cutoff) assert len(mapping[0]) > 0 # Should find neighbors # Test sort_id - mapping, shifts = nl_fn( - positions=pos, - cell=cell, - pbc=True, - cutoff=cutoff, - sort_id=True, + mapping, _shifts = nl_fn( + positions=pos, cell=cell, pbc=True, cutoff=cutoff, sort_id=True ) # Check if indices are sorted assert torch.all(mapping[0][1:] >= mapping[0][:-1]) @@ -480,31 +445,24 @@ def test_vesin_nl_edge_cases( if nl_fn == neighbors.vesin_nl: # vesin_nl_ts doesn't support float32 pos_f32 = pos.to(dtype=torch.float32) cell_f32 = cell.to(dtype=torch.float32) - cutoff_f32 = cutoff.to(dtype=torch.float32) - mapping, shifts = nl_fn( - positions=pos_f32, - cell=cell_f32, - pbc=True, - cutoff=cutoff_f32, + mapping, _shifts = nl_fn( + positions=pos_f32, cell=cell_f32, pbc=True, cutoff=cutoff ) assert len(mapping[0]) > 0 # Should find neighbors -def test_strict_nl_edge_cases( - device: torch.device, - dtype: torch.dtype, -) -> None: +def test_strict_nl_edge_cases() -> None: """Test edge cases for strict_nl.""" - pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=device, dtype=dtype) + pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=DEVICE, dtype=DTYPE) # Create a cell tensor for each batch - cell = torch.eye(3, device=device, dtype=torch.long).repeat(2, 1, 1) * 2 + cell = torch.eye(3, device=DEVICE, dtype=torch.long).repeat(2, 1, 1) * 2 # Test with no cell shifts - mapping = torch.tensor([[0], [1]], device=device, dtype=torch.long) - system_mapping = torch.tensor([0], device=device, dtype=torch.long) - shifts_idx = torch.zeros((1, 3), device=device, dtype=torch.long) + mapping = torch.tensor([[0], [1]], device=DEVICE, dtype=torch.long) + system_mapping = torch.tensor([0], device=DEVICE, dtype=torch.long) + shifts_idx = torch.zeros((1, 3), device=DEVICE, dtype=torch.long) - new_mapping, new_batch, new_shifts = neighbors.strict_nl( + new_mapping, _new_batch, _new_shifts = neighbors.strict_nl( cutoff=1.5, positions=pos, cell=cell, @@ -515,11 +473,11 @@ def test_strict_nl_edge_cases( assert len(new_mapping[0]) > 0 # Should find neighbors # Test with different batch mappings - mapping = torch.tensor([[0, 1], [1, 0]], device=device, dtype=torch.long) - system_mapping = torch.tensor([0, 1], device=device, dtype=torch.long) - shifts_idx = torch.zeros((2, 3), device=device, dtype=torch.long) + mapping = torch.tensor([[0, 1], [1, 0]], device=DEVICE, dtype=torch.long) + system_mapping = torch.tensor([0, 1], device=DEVICE, dtype=torch.long) + shifts_idx = torch.zeros((2, 3), device=DEVICE, dtype=torch.long) - new_mapping, new_batch, new_shifts = neighbors.strict_nl( + new_mapping, _new_batch, _new_shifts = neighbors.strict_nl( cutoff=1.5, positions=pos, cell=cell, @@ -530,51 +488,47 @@ def test_strict_nl_edge_cases( assert len(new_mapping[0]) > 0 # Should find neighbors -def test_neighbor_lists_time_and_memory( - device: torch.device, - dtype: torch.dtype, -) -> None: +def test_neighbor_lists_time_and_memory() -> None: """Test performance and memory characteristics of neighbor list implementations.""" # Create a smaller system to reduce memory usage n_atoms = 100 - pos = torch.rand(n_atoms, 3, device=device, dtype=dtype) - cell = torch.eye(3, device=device, dtype=dtype) * 10.0 - cutoff = torch.tensor(2.0, device=device, dtype=dtype) + pos = torch.rand(n_atoms, 3, device=DEVICE, dtype=DTYPE) + cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 10.0 + cutoff = torch.tensor(2.0, device=DEVICE, dtype=DTYPE) # Test different implementations for nl_fn in ( neighbors.standard_nl, - neighbors.vesin_nl, neighbors.vesin_nl_ts, neighbors.torch_nl_n2, neighbors.torch_nl_linked_cell, + cast("Callable[..., Any]", neighbors.vesin_nl), ): # Get initial memory usage process = psutil.Process() initial_cpu_memory = process.memory_info().rss # in bytes - if device.type == "cuda": + if DEVICE.type == "cuda": torch.cuda.reset_peak_memory_stats() initial_gpu_memory = torch.cuda.memory_allocated() # Time the execution start_time = time.perf_counter() - if nl_fn in [neighbors.torch_nl_n2, neighbors.torch_nl_linked_cell]: - system_idx = torch.zeros(n_atoms, dtype=torch.long, device=device) + if nl_fn in (neighbors.torch_nl_n2, neighbors.torch_nl_linked_cell): + system_idx = torch.zeros(n_atoms, dtype=torch.long, device=DEVICE) # Fix pbc tensor shape - pbc = torch.tensor([[True, True, True]], device=device) - mapping, mapping_system, shifts_idx = nl_fn( - cutoff=cutoff, + pbc = torch.tensor([[True, True, True]], device=DEVICE) + _mapping, _mapping_system, _shifts_idx = nl_fn( positions=pos, cell=cell, - # TODO: standardize all pbc so we either use tensors/booleans/tuples. - pbc=pbc, # type: ignore[arg-type] + pbc=pbc, + cutoff=cutoff, system_idx=system_idx, - self_interaction=False, # type: ignore[call-arg, misc] + self_interaction=False, ) else: - mapping, shifts = nl_fn(positions=pos, cell=cell, pbc=True, cutoff=cutoff) + _mapping, _shifts = nl_fn(positions=pos, cell=cell, pbc=True, cutoff=cutoff) end_time = time.perf_counter() execution_time = end_time - start_time @@ -585,7 +539,7 @@ def test_neighbor_lists_time_and_memory( fn_name = str(nl_fn) # Warning: cuda case was never tested, to be tweaked later - if device.type == "cuda": + if DEVICE.type == "cuda": final_gpu_memory = torch.cuda.memory_allocated() gpu_memory_used = final_gpu_memory - initial_gpu_memory assert execution_time < 0.01, f"{fn_name} took too long: {execution_time}s" diff --git a/tests/test_optimizer_states.py b/tests/test_optimizer_states.py new file mode 100644 index 000000000..bdd002e4b --- /dev/null +++ b/tests/test_optimizer_states.py @@ -0,0 +1,59 @@ +"""Unit tests for optimizer state classes.""" + +import pytest +import torch + +from torch_sim.optimizers.state import FireState, OptimState +from torch_sim.state import SimState + + +@pytest.fixture +def sim_state() -> SimState: + """Basic SimState for testing.""" + return SimState( + positions=torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], dtype=torch.float64), + masses=torch.tensor([1.0, 2.0], dtype=torch.float64), + cell=torch.eye(3, dtype=torch.float64).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 6], dtype=torch.int64), + system_idx=torch.zeros(2, dtype=torch.int64), + ) + + +@pytest.fixture +def optim_data() -> dict: + """Optimizer state data.""" + return { + "forces": torch.tensor( + [[0.1, -0.2, 0.3], [-0.1, 0.2, -0.3]], dtype=torch.float64 + ), + "energy": torch.tensor([1.5], dtype=torch.float64), + "stress": torch.zeros(1, 3, 3, dtype=torch.float64), + } + + +def test_optim_state_init(sim_state: SimState, optim_data: dict) -> None: + """Test OptimState initialization.""" + state = OptimState(**vars(sim_state), **optim_data) + assert torch.equal(state.forces, optim_data["forces"]) + assert torch.equal(state.energy, optim_data["energy"]) + assert torch.equal(state.stress, optim_data["stress"]) + + +def test_fire_state_custom_values(sim_state: SimState, optim_data: dict) -> None: + """Test FireState with custom values.""" + fire_data = { + "velocities": torch.tensor( + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float64 + ), + "dt": torch.tensor([0.01], dtype=torch.float64), + "alpha": torch.tensor([0.1], dtype=torch.float64), + "n_pos": torch.tensor([5], dtype=torch.int32), + } + + state = FireState(**vars(sim_state), **optim_data, **fire_data) + + assert torch.equal(state.velocities, fire_data["velocities"]) + assert torch.equal(state.dt, fire_data["dt"]) + assert torch.equal(state.alpha, fire_data["alpha"]) + assert torch.equal(state.n_pos, fire_data["n_pos"]) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 141bf5ee1..a4bce4fcc 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -1,30 +1,20 @@ import copy +from collections.abc import Callable from dataclasses import fields -from typing import get_args +from functools import partial +from typing import Any, get_args import pytest import torch import torch_sim as ts from torch_sim.models.interface import ModelInterface -from torch_sim.optimizers import ( - FireState, - FrechetCellFIREState, - GDState, - MdFlavor, - UnitCellFireState, - UnitCellGDState, - fire, - frechet_cell_fire, - gradient_descent, - unit_cell_fire, - unit_cell_gradient_descent, -) -from torch_sim.state import concatenate_states +from torch_sim.optimizers import FireState, MdFlavor, OptimState +from torch_sim.state import SimState def test_gradient_descent_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface + ar_supercell_sim_state: SimState, lj_model: ModelInterface ) -> None: """Test that the Gradient Descent optimizer actually minimizes energy.""" # Add some random displacement to positions @@ -37,14 +27,14 @@ def test_gradient_descent_optimization( initial_state = ar_supercell_sim_state # Initialize Gradient Descent optimizer - init_fn, update_fn = gradient_descent(model=lj_model, lr=0.01) - - state = init_fn(ar_supercell_sim_state) + state = ts.gradient_descent_init( + model=lj_model, state=ar_supercell_sim_state, lr=0.01 + ) # Run optimization for a few steps energies = [1000, state.energy.item()] while abs(energies[-2] - energies[-1]) > 1e-6: - state = update_fn(state) + state = ts.gradient_descent_step(model=lj_model, state=state, pos_lr=0.01) energies.append(state.energy.item()) energies = energies[1:] @@ -63,7 +53,7 @@ def test_gradient_descent_optimization( def test_unit_cell_gradient_descent_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface + ar_supercell_sim_state: SimState, lj_model: ModelInterface ) -> None: """Test that the Gradient Descent optimizer actually minimizes energy.""" # Add some random displacement to positions @@ -75,19 +65,17 @@ def test_unit_cell_gradient_descent_optimization( ar_supercell_sim_state.positions = perturbed_positions initial_state = ar_supercell_sim_state - # Initialize Gradient Descent optimizer - init_fn, update_fn = unit_cell_gradient_descent( - model=lj_model, - positions_lr=0.01, - cell_lr=0.1, + # Initialize Gradient Descent optimizer with unit cell filter + state = ts.gradient_descent_init( + model=lj_model, state=ar_supercell_sim_state, cell_filter=ts.CellFilter.unit ) - state = init_fn(ar_supercell_sim_state) - # Run optimization for a few steps energies = [1000, state.energy.item()] while abs(energies[-2] - energies[-1]) > 1e-6: - state = update_fn(state) + state = ts.gradient_descent_step( + model=lj_model, state=state, pos_lr=0.01, cell_lr=0.1 + ) energies.append(state.energy.item()) energies = energies[1:] @@ -112,7 +100,7 @@ def test_unit_cell_gradient_descent_optimization( @pytest.mark.parametrize("md_flavor", get_args(MdFlavor)) def test_fire_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface, md_flavor: MdFlavor + ar_supercell_sim_state: SimState, lj_model: ModelInterface, md_flavor: MdFlavor ) -> None: """Test that the FIRE optimizer actually minimizes energy.""" # Add some random displacement to positions @@ -123,7 +111,7 @@ def test_fire_optimization( + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 ) - current_sim_state = ts.SimState( + current_sim_state = SimState( positions=current_positions, masses=ar_supercell_sim_state.masses.clone(), cell=ar_supercell_sim_state.cell.clone(), @@ -135,21 +123,14 @@ def test_fire_optimization( initial_state_positions = current_sim_state.positions.clone() # Initialize FIRE optimizer - init_fn, update_fn = fire( - model=lj_model, - dt_max=0.3, - dt_start=0.1, - md_flavor=md_flavor, - ) - - state = init_fn(current_sim_state) + state = ts.fire_init(lj_model, current_sim_state, md_flavor=md_flavor, dt_start=0.1) # Run optimization for a few steps energies = [1000, state.energy.item()] max_steps = 1000 # Add max step to prevent infinite loop steps_taken = 0 while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: - state = update_fn(state) + state = ts.fire_step(lj_model, state, dt_max=0.3) energies.append(state.energy.item()) steps_taken += 1 @@ -180,59 +161,65 @@ def test_fire_optimization( @pytest.mark.parametrize( ("optimizer_fn", "expected_state_type"), - [(fire, FireState), (gradient_descent, GDState)], + [ + (ts.OptimFlavor.fire, FireState), + (ts.OptimFlavor.gradient_descent, OptimState), + ], ) def test_simple_optimizer_init_with_dict( - optimizer_fn: callable, - expected_state_type: type, - ar_supercell_sim_state: ts.SimState, + optimizer_fn: ts.OptimFlavor, + expected_state_type: FireState | OptimState, + ar_supercell_sim_state: SimState, lj_model: ModelInterface, ) -> None: - """Test simple optimizer init_fn with a ts.SimState dictionary.""" + """Test simple optimizer init_fn with a SimState dictionary.""" state_dict = { - f.name: getattr(ar_supercell_sim_state, f.name) - for f in fields(ar_supercell_sim_state) + field.name: getattr(ar_supercell_sim_state, field.name) + for field in fields(ar_supercell_sim_state) } - init_fn, _ = optimizer_fn(model=lj_model) - opt_state = init_fn(state_dict) + init_fn, _ = ts.OPTIM_REGISTRY[optimizer_fn] + opt_state = init_fn(model=lj_model, state=state_dict) assert isinstance(opt_state, expected_state_type) assert opt_state.energy is not None assert opt_state.forces is not None -@pytest.mark.parametrize("optimizer_func", [fire, unit_cell_fire, frechet_cell_fire]) +@pytest.mark.parametrize( + "optim_func", + [ts.fire_init, partial(ts.fire_init, cell_filter=ts.CellFilter.unit)], +) def test_optimizer_invalid_md_flavor( - optimizer_func: callable, lj_model: ModelInterface + optim_func: Callable[..., Any], + lj_model: ModelInterface, + ar_supercell_sim_state: SimState, ) -> None: """Test optimizer with an invalid md_flavor raises ValueError.""" with pytest.raises(ValueError, match="Unknown md_flavor"): - optimizer_func(model=lj_model, md_flavor="invalid_flavor") + optim_func( + model=lj_model, state=ar_supercell_sim_state, md_flavor="invalid_flavor" + ) def test_fire_ase_negative_power_branch( - ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface + ar_supercell_sim_state: SimState, lj_model: ModelInterface ) -> None: """Test that the ASE FIRE P<0 branch behaves as expected.""" f_dec = 0.5 # Default from fire optimizer alpha_start = 0.1 # Default from fire optimizer dt_start_val = 0.1 - init_fn, update_fn = fire( + state = ts.fire_init( model=lj_model, + state=ar_supercell_sim_state, md_flavor="ase_fire", - f_dec=f_dec, alpha_start=alpha_start, dt_start=dt_start_val, - dt_max=1.0, - max_step=10.0, # Large max_step to not interfere with velocity check ) - # Initialize state (forces are computed here) - state = init_fn(ar_supercell_sim_state) # Save parameters from initial state initial_dt_batch = state.dt.clone() # per-system dt - # Manipulate state to ensure P < 0 for the update_fn step + # Manipulate state to ensure P < 0 for the step_fn # Ensure forces are non-trivial state.forces += torch.sign(state.forces + 1e-6) * 1e-2 state.forces[torch.abs(state.forces) < 1e-3] = 1e-3 @@ -242,9 +229,15 @@ def test_fire_ase_negative_power_branch( # Store forces that will be used in the power calculation and v += dt*F step forces_at_power_calc = state.forces.clone() - # Deepcopy state as update_fn modifies it in-place + # Deepcopy state as step_fn modifies it in-place state_to_update = copy.deepcopy(state) - updated_state = update_fn(state_to_update) + updated_state = ts.fire_step( + lj_model, + state_to_update, + f_dec=f_dec, + dt_max=1.0, + max_step=10.0, # Large max_step to not interfere with velocity check + ) # Assertions for P < 0 branch being taken # Check for a single-batch state (ar_supercell_sim_state is single batch) @@ -273,7 +266,7 @@ def test_fire_ase_negative_power_branch( def test_fire_vv_negative_power_branch( - ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface + ar_supercell_sim_state: SimState, lj_model: ModelInterface ) -> None: """Attempt to trigger and test the VV FIRE P<0 branch.""" f_dec = 0.5 @@ -282,22 +275,25 @@ def test_fire_vv_negative_power_branch( dt_start_val = 2.0 dt_max_val = 2.0 - init_fn, update_fn = fire( + state = ts.fire_init( model=lj_model, + state=ar_supercell_sim_state, md_flavor="vv_fire", - f_dec=f_dec, alpha_start=alpha_start, dt_start=dt_start_val, - dt_max=dt_max_val, - n_min=0, # Allow dt to change immediately ) - state = init_fn(ar_supercell_sim_state) initial_dt_batch = state.dt.clone() initial_alpha_batch = state.alpha.clone() # Already alpha_start state_to_update = copy.deepcopy(state) - updated_state = update_fn(state_to_update) + updated_state = ts.fire_step( + lj_model, + state_to_update, + f_dec=f_dec, + dt_max=dt_max_val, + n_min=0, # Allow dt to change immediately + ) # Check if the P<0 branch was likely hit (params changed accordingly for batch 0) expected_dt_val = initial_dt_batch[0] * f_dec @@ -326,7 +322,7 @@ def test_fire_vv_negative_power_branch( @pytest.mark.parametrize("md_flavor", get_args(MdFlavor)) def test_unit_cell_fire_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface, md_flavor: MdFlavor + ar_supercell_sim_state: SimState, lj_model: ModelInterface, md_flavor: MdFlavor ) -> None: """Test that the Unit Cell FIRE optimizer actually minimizes energy.""" @@ -340,7 +336,7 @@ def test_unit_cell_fire_optimization( + torch.randn_like(ar_supercell_sim_state.cell) * 0.01 ) - current_sim_state = ts.SimState( + current_sim_state = SimState( positions=current_positions, masses=ar_supercell_sim_state.masses.clone(), cell=current_cell, @@ -352,23 +348,22 @@ def test_unit_cell_fire_optimization( initial_state_positions = current_sim_state.positions.clone() initial_state_cell = current_sim_state.cell.clone() - # Initialize FIRE optimizer - init_fn, update_fn = unit_cell_fire( + # Initialize FIRE optimizer with unit cell filter + state = ts.fire_init( model=lj_model, - dt_max=0.3, + state=current_sim_state, dt_start=0.1, md_flavor=md_flavor, + cell_filter=ts.CellFilter.unit, ) - state = init_fn(current_sim_state) - # Run optimization for a few steps energies = [1000.0, state.energy.item()] max_steps = 1000 steps_taken = 0 while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: - state = update_fn(state) + state = ts.fire_step(lj_model, state, dt_max=0.3) energies.append(state.energy.item()) steps_taken += 1 @@ -403,18 +398,24 @@ def test_unit_cell_fire_optimization( @pytest.mark.parametrize( - ("optimizer_fn", "expected_state_type", "cell_factor_val"), + ("optimizer_fn", "cell_filter", "expected_state_type", "cell_factor_val"), [ - (unit_cell_fire, UnitCellFireState, 100), - (unit_cell_gradient_descent, UnitCellGDState, 50.0), - (frechet_cell_fire, FrechetCellFIREState, 75.0), + (ts.OptimFlavor.fire, ts.CellFilter.unit, ts.CellFireState, 100), + ( + ts.OptimFlavor.gradient_descent, + ts.CellFilter.unit, + ts.CellOptimState, + 50.0, + ), + (ts.OptimFlavor.fire, ts.CellFilter.frechet, ts.CellFireState, 75.0), ], ) def test_cell_optimizer_init_with_dict_and_cell_factor( - optimizer_fn: callable, - expected_state_type: type, + optimizer_fn: ts.OptimFlavor, + expected_state_type: OptimState, + cell_filter: ts.CellFilter, cell_factor_val: float, - ar_supercell_sim_state: ts.SimState, + ar_supercell_sim_state: SimState, lj_model: ModelInterface, ) -> None: """Test cell optimizer init_fn with dict state and explicit cell_factor.""" @@ -422,13 +423,19 @@ def test_cell_optimizer_init_with_dict_and_cell_factor( f.name: getattr(ar_supercell_sim_state, f.name) for f in fields(ar_supercell_sim_state) } - init_fn, _ = optimizer_fn(model=lj_model, cell_factor=cell_factor_val) - opt_state = init_fn(state_dict) + init_fn, _ = ts.OPTIM_REGISTRY[optimizer_fn] + opt_state = init_fn( + model=lj_model, + state=state_dict, + cell_factor=cell_factor_val, + cell_filter=cell_filter, + ) assert isinstance(opt_state, expected_state_type) assert opt_state.energy is not None assert opt_state.forces is not None assert opt_state.stress is not None + # Check cell_factor is stored in cell_state expected_cf_tensor = torch.full( (opt_state.n_systems, 1, 1), float(cell_factor_val), # Ensure float for comparison if int is passed @@ -439,27 +446,46 @@ def test_cell_optimizer_init_with_dict_and_cell_factor( @pytest.mark.parametrize( - ("optimizer_fn", "expected_state_type"), + ("optimizer_fn", "cell_filter", "expected_state_type"), [ - (unit_cell_fire, UnitCellFireState), - (frechet_cell_fire, FrechetCellFIREState), + (ts.OptimFlavor.fire, ts.CellFilter.unit, ts.CellFireState), + (ts.OptimFlavor.fire, ts.CellFilter.frechet, ts.CellFireState), + ( + ts.OptimFlavor.gradient_descent, + ts.CellFilter.unit, + ts.CellOptimState, + ), + ( + ts.OptimFlavor.gradient_descent, + ts.CellFilter.frechet, + ts.CellOptimState, + ), ], ) def test_cell_optimizer_init_cell_factor_none( - optimizer_fn: callable, - expected_state_type: type, - ar_supercell_sim_state: ts.SimState, + optimizer_fn: ts.OptimFlavor, + cell_filter: ts.CellFilter, + expected_state_type: OptimState, + ar_supercell_sim_state: SimState, lj_model: ModelInterface, ) -> None: """Test cell optimizer init_fn with cell_factor=None.""" - init_fn, _ = optimizer_fn(model=lj_model, cell_factor=None) + init_fn, _ = ts.OPTIM_REGISTRY[optimizer_fn] + opt_state = init_fn( + model=lj_model, + state=ar_supercell_sim_state, + cell_factor=None, + cell_filter=cell_filter, + ) # Ensure n_systems > 0 for cell_factor calculation from counts assert ar_supercell_sim_state.n_systems > 0 - opt_state = init_fn(ar_supercell_sim_state) # Uses ts.SimState directly assert isinstance(opt_state, expected_state_type) _, counts = torch.unique(ar_supercell_sim_state.system_idx, return_counts=True) expected_cf_tensor = counts.to(dtype=lj_model.dtype).view(-1, 1, 1) + + # Check cell_factor is stored in cell_state for new API assert torch.allclose(opt_state.cell_factor, expected_cf_tensor) + assert opt_state.energy is not None assert opt_state.forces is not None assert opt_state.stress is not None @@ -467,11 +493,11 @@ def test_cell_optimizer_init_cell_factor_none( @pytest.mark.filterwarnings("ignore:WARNING: Non-positive volume detected") def test_unit_cell_fire_ase_non_positive_volume_warning( - ar_supercell_sim_state: ts.SimState, + ar_supercell_sim_state: SimState, lj_model: ModelInterface, capsys: pytest.CaptureFixture, ) -> None: - """Attempt to trigger non-positive volume warning in unit_cell_fire ASE.""" + """Attempt to trigger non-positive volume warning in ASE unit cell fire.""" # Use a state that might lead to cell inversion with aggressive steps # Make a copy and slightly perturb the cell to make it prone to issues perturbed_state = ar_supercell_sim_state.clone() @@ -482,20 +508,24 @@ def test_unit_cell_fire_ase_non_positive_volume_warning( if torch.linalg.det(perturbed_state.cell[0]) < 1.0: perturbed_state.cell[0] *= 2.0 - init_fn, update_fn = unit_cell_fire( + state = ts.fire_init( model=lj_model, + state=perturbed_state, md_flavor="ase_fire", - dt_max=5.0, # Large dt - max_step=2.0, # Large max_step dt_start=1.0, - f_dec=0.99, # Slow down dt decrease alpha_start=0.99, # Aggressive alpha + cell_filter=ts.CellFilter.unit, ) - state = init_fn(perturbed_state) # Run a few steps hoping to trigger the warning for _ in range(5): - state = update_fn(state) + state = ts.fire_step( + lj_model, + state, + dt_max=5.0, # Large dt + max_step=2.0, # Large max_step + f_dec=0.99, # Slow down dt decrease + ) if "WARNING: Non-positive volume detected" in capsys.readouterr().err: break # Warning captured @@ -504,7 +534,7 @@ def test_unit_cell_fire_ase_non_positive_volume_warning( @pytest.mark.parametrize("md_flavor", get_args(MdFlavor)) def test_frechet_cell_fire_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface, md_flavor: MdFlavor + ar_supercell_sim_state: SimState, lj_model: ModelInterface, md_flavor: MdFlavor ) -> None: """Test that the Frechet Cell FIRE optimizer actually minimizes energy for different md_flavors.""" @@ -520,7 +550,7 @@ def test_frechet_cell_fire_optimization( + torch.randn_like(ar_supercell_sim_state.cell) * 0.01 ) - current_sim_state = ts.SimState( + current_sim_state = SimState( positions=current_positions, masses=ar_supercell_sim_state.masses.clone(), cell=current_cell, @@ -532,23 +562,21 @@ def test_frechet_cell_fire_optimization( initial_state_positions = current_sim_state.positions.clone() initial_state_cell = current_sim_state.cell.clone() - # Initialize FIRE optimizer - init_fn, update_fn = frechet_cell_fire( + state = ts.fire_init( model=lj_model, - dt_max=0.3, + state=current_sim_state, dt_start=0.1, md_flavor=md_flavor, + cell_filter=ts.CellFilter.frechet, ) - state = init_fn(current_sim_state) - # Run optimization for a few steps energies = [1000.0, state.energy.item()] # Ensure float for comparison max_steps = 1000 steps_taken = 0 while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: - state = update_fn(state) + state = ts.fire_step(model=lj_model, state=state, dt_max=0.3) energies.append(state.energy.item()) steps_taken += 1 @@ -589,10 +617,13 @@ def test_frechet_cell_fire_optimization( ) -@pytest.mark.parametrize("optimizer_func", [fire, unit_cell_fire, frechet_cell_fire]) +@pytest.mark.parametrize( + "filter_func", + [None, ts.CellFilter.unit, ts.CellFilter.frechet], +) def test_optimizer_batch_consistency( - optimizer_func: callable, - ar_supercell_sim_state: ts.SimState, + filter_func: ts.CellFilter | None, + ar_supercell_sim_state: SimState, lj_model: ModelInterface, ) -> None: """Test batched optimizer is consistent with individual optimizations.""" @@ -610,7 +641,7 @@ def test_optimizer_batch_consistency( ) * 0.1 ) - if optimizer_func in (unit_cell_fire, frechet_cell_fire): + if filter_func: generator.manual_seed(44) # Reset seed for cell state1_orig.cell += ( torch.randn( @@ -624,68 +655,71 @@ def test_optimizer_batch_consistency( final_individual_states = [] - def energy_converged(current_e: torch.Tensor, prev_e: torch.Tensor) -> bool: + def energy_converged(e_current: torch.Tensor, e_prev: torch.Tensor) -> bool: """Check for energy convergence (scalar energies).""" - return not torch.allclose(current_e, prev_e, atol=1e-6) + return not torch.allclose(e_current, e_prev, atol=1e-6) for state_for_indiv_opt in [state1_orig.clone(), state2_orig.clone()]: - init_fn_indiv, update_fn_indiv = optimizer_func( - model=lj_model, dt_max=0.3, dt_start=0.1 + init_fn_indiv, step_fn_indiv = ts.OPTIM_REGISTRY[ts.OptimFlavor.fire] + opt_state_indiv = init_fn_indiv( + model=lj_model, + state=state_for_indiv_opt, + dt_start=0.1, + cell_filter=filter_func, ) - opt_state_indiv = init_fn_indiv(state_for_indiv_opt) current_e_indiv = opt_state_indiv.energy # Ensure prev_e_indiv is different to start the loop - prev_e_indiv = current_e_indiv + torch.tensor( + e_prev_indiv = current_e_indiv + torch.tensor( 1.0, device=current_e_indiv.device, dtype=current_e_indiv.dtype ) steps_indiv = 0 - while energy_converged(current_e_indiv, prev_e_indiv): - prev_e_indiv = current_e_indiv - opt_state_indiv = update_fn_indiv(opt_state_indiv) + while energy_converged(current_e_indiv, e_prev_indiv): + e_prev_indiv = current_e_indiv + opt_state_indiv = step_fn_indiv( + model=lj_model, state=opt_state_indiv, dt_max=0.3 + ) current_e_indiv = opt_state_indiv.energy steps_indiv += 1 if steps_indiv > 1000: raise ValueError( - f"Individual opt for {optimizer_func.__name__} did not converge" + f"Individual opt for {filter_func.name} did not converge" ) final_individual_states.append(opt_state_indiv) # Batched optimization - multi_state_initial = concatenate_states( + multi_state_initial = ts.concatenate_states( [state1_orig.clone(), state2_orig.clone()], device=ar_supercell_sim_state.device, ) - init_fn_batch, update_fn_batch = optimizer_func( - model=lj_model, dt_max=0.3, dt_start=0.1 + init_fn_batch, step_fn_batch = ts.OPTIM_REGISTRY[ts.OptimFlavor.fire] + batch_opt_state = init_fn_batch( + model=lj_model, state=multi_state_initial, cell_filter=filter_func ) - batch_opt_state = init_fn_batch(multi_state_initial) - current_energies_batch = batch_opt_state.energy.clone() - # Ensure prev_energies_batch requires update and has same shape - prev_energies_batch = current_energies_batch + torch.tensor( - 1.0, device=current_energies_batch.device, dtype=current_energies_batch.dtype + e_current_batch = batch_opt_state.energy.clone() + # Ensure e_prev_batch requires update and has same shape + e_prev_batch = e_current_batch + torch.tensor( + 1.0, device=e_current_batch.device, dtype=e_current_batch.dtype ) steps_batch = 0 # Converge when all batch energies have converged - while not torch.allclose(current_energies_batch, prev_energies_batch, atol=1e-6): - prev_energies_batch = current_energies_batch.clone() - batch_opt_state = update_fn_batch(batch_opt_state) - current_energies_batch = batch_opt_state.energy.clone() + while not torch.allclose(e_current_batch, e_prev_batch, atol=1e-6): + e_prev_batch = e_current_batch.clone() + batch_opt_state = step_fn_batch(model=lj_model, state=batch_opt_state) + e_current_batch = batch_opt_state.energy.clone() steps_batch += 1 if steps_batch > 1000: - raise ValueError( - f"Batched opt for {optimizer_func.__name__} did not converge" - ) + raise ValueError(f"Batched opt for {filter_func.name} did not converge") individual_final_energies = [s.energy.item() for s in final_individual_states] for idx, indiv_energy in enumerate(individual_final_energies): - assert abs(batch_opt_state.energy[idx].item() - indiv_energy) < 1e-4, ( - f"Energy batch {idx} ({optimizer_func.__name__}): " - f"{batch_opt_state.energy[idx].item()} vs indiv {indiv_energy}" + assert abs(e_current_batch[idx].item() - indiv_energy) < 1e-4, ( + f"Energy batch {idx} ({filter_func=}): " + f"{e_current_batch[idx].item()} vs indiv {indiv_energy}" ) # Check positions changed for both parts of the batch @@ -694,21 +728,21 @@ def energy_converged(current_e: torch.Tensor, prev_e: torch.Tensor) -> bool: batch_opt_state.positions[:n_atoms_first_state], multi_state_initial.positions[:n_atoms_first_state], atol=1e-5, # Added tolerance as in original frechet test - ), f"{optimizer_func.__name__} positions batch 0 did not change." + ), f"{filter_func=} positions batch 0 did not change." assert not torch.allclose( batch_opt_state.positions[n_atoms_first_state:], multi_state_initial.positions[n_atoms_first_state:], atol=1e-5, - ), f"{optimizer_func.__name__} positions batch 1 did not change." + ), f"{filter_func=} positions batch 1 did not change." - if optimizer_func in (unit_cell_fire, frechet_cell_fire): + if filter_func: assert not torch.allclose( batch_opt_state.cell, multi_state_initial.cell, atol=1e-5 - ), f"{optimizer_func.__name__} cell did not change." + ), f"{filter_func.name} cell did not change." def test_unit_cell_fire_multi_batch( - ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface + ar_supercell_sim_state: SimState, lj_model: ModelInterface ) -> None: """Test FIRE optimization with multiple batches.""" # Create a multi-batch system by duplicating ar_fcc_state @@ -718,7 +752,7 @@ def test_unit_cell_fire_multi_batch( ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state) ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state) - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: + for state in (ar_supercell_sim_state_1, ar_supercell_sim_state_2): generator.manual_seed(43) state.positions += ( torch.randn( @@ -729,19 +763,15 @@ def test_unit_cell_fire_multi_batch( * 0.1 ) - multi_state = concatenate_states( + multi_state = ts.concatenate_states( [ar_supercell_sim_state_1, ar_supercell_sim_state_2], device=ar_supercell_sim_state.device, ) - # Initialize FIRE optimizer - init_fn, update_fn = unit_cell_fire( - model=lj_model, - dt_max=0.3, - dt_start=0.1, + # Initialize FIRE optimizer with unit cell filter + state = ts.fire_init( + model=lj_model, state=multi_state, dt_start=0.1, cell_filter=ts.CellFilter.unit ) - - state = init_fn(multi_state) initial_state = copy.deepcopy(state) # Run optimization for a few steps @@ -750,7 +780,7 @@ def test_unit_cell_fire_multi_batch( step = 0 while not torch.allclose(current_energy, prev_energy, atol=1e-9): prev_energy = current_energy - state = update_fn(state) + state = ts.fire_step(model=lj_model, state=state, dt_max=0.3) current_energy = state.energy step += 1 @@ -784,7 +814,7 @@ def test_unit_cell_fire_multi_batch( def test_fire_fixed_cell_unit_cell_consistency( # noqa: C901 - ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface + ar_supercell_sim_state: SimState, lj_model: ModelInterface ) -> None: """Test batched Frechet Fixed cell FIRE optimization is consistent with FIRE (position only) optimizations.""" @@ -794,7 +824,7 @@ def test_fire_fixed_cell_unit_cell_consistency( # noqa: C901 ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state) # Add same random perturbation to both states - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: + for state in (ar_supercell_sim_state_1, ar_supercell_sim_state_2): generator.manual_seed(43) state.positions += ( torch.randn(state.positions.shape, device=state.device, generator=generator) @@ -805,21 +835,20 @@ def test_fire_fixed_cell_unit_cell_consistency( # noqa: C901 final_individual_states_unit_cell = [] total_steps_unit_cell = [] - def energy_converged(current_energy: float, prev_energy: float) -> bool: + def energy_converged(current_energy: torch.Tensor, prev_energy: torch.Tensor) -> bool: """Check if optimization should continue based on energy convergence.""" return not torch.allclose(current_energy, prev_energy, atol=1e-6) - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: - init_fn, update_fn = unit_cell_fire( - model=lj_model, - dt_max=0.3, + for state in (ar_supercell_sim_state_1, ar_supercell_sim_state_2): + state_opt = ts.fire_init( + lj_model, + state, dt_start=0.1, + cell_filter=ts.CellFilter.unit, hydrostatic_strain=True, constant_volume=True, ) - state_opt = init_fn(state) - # Run optimization until convergence current_energy = state_opt.energy prev_energy = current_energy + 1 @@ -827,7 +856,7 @@ def energy_converged(current_energy: float, prev_energy: float) -> bool: step = 0 while energy_converged(current_energy, prev_energy): prev_energy = current_energy - state_opt = update_fn(state_opt) + state_opt = ts.fire_step(lj_model, state_opt, dt_max=0.3) current_energy = state_opt.energy step += 1 if step > 1000: @@ -840,14 +869,12 @@ def energy_converged(current_energy: float, prev_energy: float) -> bool: final_individual_states_fire = [] total_steps_fire = [] - def energy_converged(current_energy: float, prev_energy: float) -> bool: + def energy_converged(current_energy: torch.Tensor, prev_energy: torch.Tensor) -> bool: """Check if optimization should continue based on energy convergence.""" return not torch.allclose(current_energy, prev_energy, atol=1e-6) - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: - init_fn, update_fn = fire(model=lj_model, dt_max=0.3, dt_start=0.1) - - state_opt = init_fn(state) + for state in (ar_supercell_sim_state_1, ar_supercell_sim_state_2): + state_opt = ts.fire_init(model=lj_model, state=state, dt_start=0.1) # Run optimization until convergence current_energy = state_opt.energy @@ -856,7 +883,7 @@ def energy_converged(current_energy: float, prev_energy: float) -> bool: step = 0 while energy_converged(current_energy, prev_energy): prev_energy = current_energy - state_opt = update_fn(state_opt) + state_opt = ts.fire_step(model=lj_model, state=state_opt, dt_max=0.3) current_energy = state_opt.energy step += 1 if step > 1000: diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py index acf5e0b94..e8192fcea 100644 --- a/tests/test_optimizers_vs_ase.py +++ b/tests/test_optimizers_vs_ase.py @@ -7,18 +7,45 @@ from pymatgen.analysis.structure_matcher import StructureMatcher import torch_sim as ts -from torch_sim.io import atoms_to_state, state_to_atoms, state_to_structures -from torch_sim.models.mace import MaceModel -from torch_sim.optimizers import frechet_cell_fire, unit_cell_fire +from tests.conftest import DTYPE +from torch_sim.models.mace import MaceModel, MaceUrls if TYPE_CHECKING: from mace.calculators import MACECalculator +@pytest.fixture +def ts_mace_mpa() -> MaceModel: + """Provides a MACE MP model instance for the optimizer tests.""" + from mace.calculators.foundations_models import mace_mp + + # Use float64 for potentially higher precision needed in optimization + dtype = getattr(torch, dtype_str := "float64") + raw_mace = mace_mp( + model=MaceUrls.mace_mp_small, return_raw_model=True, default_dtype=dtype_str + ) + return MaceModel( + model=raw_mace, + device=torch.device("cpu"), + dtype=dtype, + compute_forces=True, + compute_stress=True, + ) + + +@pytest.fixture +def ase_mace_mpa() -> "MACECalculator": + """Provides an ASE MACECalculator instance using mace_mp.""" + from mace.calculators.foundations_models import mace_mp + + # Ensure dtype matches the one used in the torch-sim fixture (float64) + return mace_mp(model=MaceUrls.mace_mp_small, default_dtype="float64") + + def _compare_ase_and_ts_states( - ts_current_system_state: ts.state.SimState, - filtered_ase_atoms_for_run: Any, + state: ts.FireState, + filtered_ase_atoms: FrechetCellFilter | UnitCellFilter, tolerances: dict[str, float], current_test_id: str, ) -> None: @@ -28,29 +55,23 @@ def _compare_ase_and_ts_states( angle_tol=tolerances["angle_tol"], scale=False, ) + tensor_kwargs = {"device": state.device, "dtype": state.dtype} - tensor_kwargs = { - "device": ts_current_system_state.device, - "dtype": ts_current_system_state.dtype, - } - - final_custom_energy = ts_current_system_state.energy.item() - final_custom_forces_max = ( - torch.norm(ts_current_system_state.forces, dim=-1).max().item() - ) + final_custom_energy = state.energy.item() + final_custom_forces_max = torch.norm(state.forces, dim=-1).max().item() # Convert torch-sim state to pymatgen Structure - ts_structure = state_to_structures(ts_current_system_state)[0] + ts_structure = ts.io.state_to_structures(state)[0] # Convert ASE atoms to pymatgen Structure - final_ase_atoms = filtered_ase_atoms_for_run.atoms + final_ase_atoms = filtered_ase_atoms.atoms final_ase_energy = final_ase_atoms.get_potential_energy() ase_forces_raw = final_ase_atoms.get_forces() final_ase_forces_max = torch.norm( torch.tensor(ase_forces_raw, **tensor_kwargs), dim=-1 ).max() - ts_state = atoms_to_state(final_ase_atoms, **tensor_kwargs) - ase_structure = state_to_structures(ts_state)[0] + ts_state = ts.io.atoms_to_state(final_ase_atoms, **tensor_kwargs) + ase_structure = ts.io.state_to_structures(ts_state)[0] # Compare energies energy_diff = abs(final_custom_energy - final_ase_energy) @@ -76,56 +97,44 @@ def _compare_ase_and_ts_states( def _run_and_compare_optimizers( - initial_sim_state_fixture: ts.state.SimState, - torchsim_mace_mpa: MaceModel, + initial_sim_state_fixture: ts.SimState, + ts_mace_mpa: MaceModel, ase_mace_mpa: "MACECalculator", - torch_sim_optimizer_type: str, - ase_filter_class: Any, + fire_type: ts.OptimFlavor, + ase_filter_cls: FrechetCellFilter | UnitCellFilter, checkpoints: list[int], force_tol: float, tolerances: dict[str, float], test_id_prefix: str, + **optim_kwargs: Any, ) -> None: """Run and compare optimizations between torch-sim and ASE.""" pytest.importorskip("mace") - dtype = torch.float64 - device = torchsim_mace_mpa.device + device = ts_mace_mpa.device - ts_current_system_state = initial_sim_state_fixture.clone() - - optimizer_builders = { - "frechet": frechet_cell_fire, - "unit_cell": unit_cell_fire, - } - if torch_sim_optimizer_type not in optimizer_builders: - raise ValueError(f"Unknown torch_sim_optimizer_type: {torch_sim_optimizer_type}") - ts_optimizer_builder = optimizer_builders[torch_sim_optimizer_type] - - optimizer_callable_for_ts_optimize = lambda model, **_kwargs: ts_optimizer_builder( # noqa: E731 - model, md_flavor="ase_fire" - ) + state = initial_sim_state_fixture.clone() - ase_atoms_for_run = state_to_atoms( - initial_sim_state_fixture.clone().to(dtype=dtype, device=device) + ase_atoms = ts.io.state_to_atoms( + initial_sim_state_fixture.clone().to(dtype=DTYPE, device=device) )[0] - ase_atoms_for_run.calc = ase_mace_mpa - filtered_ase_atoms_for_run = ase_filter_class(ase_atoms_for_run) - ase_optimizer = FIRE(filtered_ase_atoms_for_run, logfile=None) + ase_atoms.calc = ase_mace_mpa + filtered_ase_atoms = ase_filter_cls(ase_atoms) # type: ignore[call-non-callable] + ase_optimizer = FIRE(filtered_ase_atoms, logfile=None) last_checkpoint_step_count = 0 convergence_fn = ts.generate_force_convergence_fn( force_tol=force_tol, include_cell_forces=True ) - results = torchsim_mace_mpa(ts_current_system_state) - ts_initial_system_state = ts_current_system_state.clone() + results = ts_mace_mpa(state) + ts_initial_system_state = state.clone() ts_initial_system_state.forces = results["forces"] ts_initial_system_state.energy = results["energy"] - ase_atoms_for_run.calc.calculate(ase_atoms_for_run) + ase_mace_mpa.calculate(ase_atoms) _compare_ase_and_ts_states( ts_initial_system_state, - filtered_ase_atoms_for_run, + filtered_ase_atoms, tolerances, f"{test_id_prefix} (Initial)", ) @@ -135,25 +144,22 @@ def _run_and_compare_optimizers( if steps_for_current_segment > 0: updated_ts_state = ts.optimize( - system=ts_current_system_state, - model=torchsim_mace_mpa, - optimizer=optimizer_callable_for_ts_optimize, + system=state, + model=ts_mace_mpa, + optimizer=fire_type, max_steps=steps_for_current_segment, convergence_fn=convergence_fn, steps_between_swaps=1, + md_flavor="ase_fire", # optimizer kwargs + **optim_kwargs, ) - ts_current_system_state = updated_ts_state.clone() + state = updated_ts_state.clone() ase_optimizer.run(fmax=force_tol, steps=steps_for_current_segment) current_test_id = f"{test_id_prefix} (Step {checkpoint_step})" - _compare_ase_and_ts_states( - ts_current_system_state, - filtered_ase_atoms_for_run, - tolerances, - current_test_id, - ) + _compare_ase_and_ts_states(state, filtered_ase_atoms, tolerances, current_test_id) last_checkpoint_step_count = checkpoint_step @@ -161,8 +167,9 @@ def _run_and_compare_optimizers( @pytest.mark.parametrize( ( "sim_state_fixture_name", - "torch_sim_optimizer_type", - "ase_filter_class", + "fire_type", + "cell_filter", + "ase_filter_cls", "checkpoints", "force_tol", "tolerances", @@ -171,7 +178,8 @@ def _run_and_compare_optimizers( [ ( "rattled_sio2_sim_state", - "frechet", + ts.OptimFlavor.fire, + ts.CellFilter.frechet, FrechetCellFilter, [1, 33, 66, 100], 0.02, @@ -186,7 +194,8 @@ def _run_and_compare_optimizers( ), ( "osn2_sim_state", - "frechet", + ts.OptimFlavor.fire, + ts.CellFilter.frechet, FrechetCellFilter, [1, 16, 33, 50], 0.02, @@ -201,7 +210,8 @@ def _run_and_compare_optimizers( ), ( "distorted_fcc_al_conventional_sim_state", - "frechet", + ts.OptimFlavor.fire, + ts.CellFilter.frechet, FrechetCellFilter, [1, 33, 66, 100], 0.01, @@ -216,7 +226,8 @@ def _run_and_compare_optimizers( ), ( "distorted_fcc_al_conventional_sim_state", - "unit_cell", + ts.OptimFlavor.fire, + ts.CellFilter.unit, UnitCellFilter, [1, 33, 66, 100], 0.01, @@ -231,7 +242,8 @@ def _run_and_compare_optimizers( ), ( "rattled_sio2_sim_state", - "unit_cell", + ts.OptimFlavor.fire, + ts.CellFilter.unit, UnitCellFilter, [1, 33, 66, 100], 0.02, @@ -246,7 +258,8 @@ def _run_and_compare_optimizers( ), ( "osn2_sim_state", - "unit_cell", + ts.OptimFlavor.fire, + ts.CellFilter.unit, UnitCellFilter, [1, 16, 33, 50], 0.02, @@ -263,13 +276,14 @@ def _run_and_compare_optimizers( ) def test_optimizer_vs_ase_parametrized( sim_state_fixture_name: str, - torch_sim_optimizer_type: str, - ase_filter_class: Any, + fire_type: ts.OptimFlavor, + cell_filter: ts.CellFilter, + ase_filter_cls: FrechetCellFilter | UnitCellFilter, checkpoints: list[int], force_tol: float, tolerances: dict[str, float], test_id_prefix: str, - torchsim_mace_mpa: MaceModel, + ts_mace_mpa: MaceModel, ase_mace_mpa: "MACECalculator", request: pytest.FixtureRequest, ) -> None: @@ -279,10 +293,11 @@ def test_optimizer_vs_ase_parametrized( _run_and_compare_optimizers( initial_sim_state_fixture=initial_sim_state_fixture, - torchsim_mace_mpa=torchsim_mace_mpa, + ts_mace_mpa=ts_mace_mpa, ase_mace_mpa=ase_mace_mpa, - torch_sim_optimizer_type=torch_sim_optimizer_type, - ase_filter_class=ase_filter_class, + fire_type=fire_type, + cell_filter=cell_filter, + ase_filter_cls=ase_filter_cls, checkpoints=checkpoints, force_tol=force_tol, tolerances=tolerances, diff --git a/tests/test_quantities.py b/tests/test_quantities.py index 7513b6bd9..936484b2e 100644 --- a/tests/test_quantities.py +++ b/tests/test_quantities.py @@ -1,136 +1,145 @@ +"""Tests for quantities module functions.""" + import pytest import torch -from torch._tensor import Tensor - -from torch_sim import quantities -from torch_sim.units import MetalUnits - - -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") -DTYPE = torch.double - - -@pytest.fixture -def single_system_data() -> dict[str, Tensor]: - masses = torch.tensor([1.0, 2.0], device=DEVICE, dtype=DTYPE) - velocities = torch.tensor( - [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], device=DEVICE, dtype=DTYPE - ) - momenta = velocities * masses.unsqueeze(-1) - return { - "masses": masses, - "velocities": velocities, - "momenta": momenta, - "ke": torch.tensor(13.5, device=DEVICE, dtype=DTYPE), - "kt": torch.tensor(4.5, device=DEVICE, dtype=DTYPE), - } - - -@pytest.fixture -def batched_system_data() -> dict[str, Tensor]: - masses = torch.tensor([1.0, 1.0, 2.0, 2.0], device=DEVICE, dtype=DTYPE) - velocities = torch.tensor( - [[1, 1, 1], [1, 1, 1], [2, 2, 2], [2, 2, 2]], device=DEVICE, dtype=DTYPE - ) - momenta = velocities * masses.unsqueeze(-1) - system_idx = torch.tensor([0, 0, 1, 1], device=DEVICE) - return { - "masses": masses, - "velocities": velocities, - "momenta": momenta, - "system_idx": system_idx, - "ke": torch.tensor([3.0, 24.0], device=DEVICE, dtype=DTYPE), - "kt": torch.tensor([1.0, 8.0], device=DEVICE, dtype=DTYPE), - } - - -def test_calc_kinetic_energy_single_system(single_system_data: dict[str, Tensor]) -> None: - # With velocities - ke_vel = quantities.calc_kinetic_energy( - masses=single_system_data["masses"], - velocities=single_system_data["velocities"], - ) - assert torch.allclose(ke_vel, single_system_data["ke"]) - - # With momenta - ke_mom = quantities.calc_kinetic_energy( - masses=single_system_data["masses"], momenta=single_system_data["momenta"] - ) - assert torch.allclose(ke_mom, single_system_data["ke"]) - - -def test_calc_kinetic_energy_batched_system( - batched_system_data: dict[str, Tensor], -) -> None: - # With velocities - ke_vel = quantities.calc_kinetic_energy( - masses=batched_system_data["masses"], - velocities=batched_system_data["velocities"], - system_idx=batched_system_data["system_idx"], - ) - assert torch.allclose(ke_vel, batched_system_data["ke"]) - - # With momenta - ke_mom = quantities.calc_kinetic_energy( - masses=batched_system_data["masses"], - momenta=batched_system_data["momenta"], - system_idx=batched_system_data["system_idx"], - ) - assert torch.allclose(ke_mom, batched_system_data["ke"]) - - -def test_calc_kinetic_energy_errors(single_system_data: dict[str, Tensor]) -> None: - with pytest.raises(ValueError, match="Must pass either one of momenta or velocities"): - quantities.calc_kinetic_energy( - masses=single_system_data["masses"], - momenta=single_system_data["momenta"], - velocities=single_system_data["velocities"], +from numpy.testing import assert_allclose + +from tests.conftest import DEVICE +from torch_sim.quantities import calc_heat_flux + + +class TestHeatFlux: + """Test suite for heat flux calculations.""" + + @pytest.fixture + def mock_simple_system(self) -> dict[str, torch.Tensor]: + """Simple system with known values.""" + return { + "velocities": torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 3.0], + ], + device=DEVICE, + ), + "energies": torch.tensor([1.0, 2.0, 3.0], device=DEVICE), + "stress": torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 3.0, 0.0, 0.0, 0.0], + ], + device=DEVICE, + ), + "masses": torch.ones(3, device=DEVICE), + } + + def test_unbatched_total_flux( + self, mock_simple_system: dict[str, torch.Tensor] + ) -> None: + """Test total heat flux calculation for unbatched case.""" + flux = calc_heat_flux( + momenta=None, + masses=mock_simple_system["masses"], + velocities=mock_simple_system["velocities"], + energies=mock_simple_system["energies"], + stresses=mock_simple_system["stress"], + is_virial_only=False, + ) + + # Heat flux parts should cancel out + expected = torch.zeros(3, device=flux.device) + assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) + + def test_unbatched_virial_only( + self, mock_simple_system: dict[str, torch.Tensor] + ) -> None: + """Test virial-only heat flux calculation for unbatched case.""" + virial = calc_heat_flux( + momenta=None, + masses=mock_simple_system["masses"], + velocities=mock_simple_system["velocities"], + energies=mock_simple_system["energies"], + stresses=mock_simple_system["stress"], + is_virial_only=True, + ) + + expected = -torch.tensor([1.0, 4.0, 9.0], device=virial.device) + assert_allclose(virial.cpu().numpy(), expected.cpu().numpy()) + + def test_batched_calculation(self) -> None: + """Test heat flux calculation with batched data.""" + velocities = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 3.0], + ], + device=DEVICE, + ) + energies = torch.tensor([1.0, 2.0, 3.0], device=DEVICE) + stress = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 3.0, 0.0, 0.0, 0.0], + ], + device=DEVICE, + ) + batch = torch.tensor([0, 0, 1], device=DEVICE) + + flux = calc_heat_flux( + momenta=None, + masses=torch.ones(3, device=DEVICE), + velocities=velocities, + energies=energies, + stresses=stress, + batch=batch, + ) + + # Each batch should cancel heat flux parts + expected = torch.zeros((2, 3), device=DEVICE) + assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) + + def test_centroid_stress(self) -> None: + """Test heat flux with centroid stress formulation.""" + velocities = torch.tensor([[1.0, 1.0, 1.0]], device=DEVICE) + energies = torch.tensor([1.0], device=DEVICE) + + # Symmetric cross-terms + stress = torch.tensor( + [[1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]], device=DEVICE + ) + + flux = calc_heat_flux( + momenta=None, + masses=torch.ones(1, device=DEVICE), + velocities=velocities, + energies=energies, + stresses=stress, + is_centroid_stress=True, + ) + + # Heatflux should be [-1,-1,-1] + expected = torch.full((3,), -1.0, device=DEVICE) + assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) + + def test_momenta_input(self) -> None: + """Test heat flux calculation using momenta instead.""" + momenta = torch.tensor([[1.0, 0.0, 0.0]], device=DEVICE) + masses = torch.tensor([2.0], device=DEVICE) + energies = torch.tensor([1.0], device=DEVICE) + stress = torch.tensor([[1.0, 0.0, 0.0, 0.0, 0.0, 0.0]], device=DEVICE) + + flux = calc_heat_flux( + momenta=momenta, + masses=masses, + velocities=None, + energies=energies, + stresses=stress, ) - with pytest.raises(ValueError, match="Must pass either one of momenta or velocities"): - quantities.calc_kinetic_energy(masses=single_system_data["masses"]) - - -def test_calc_kt_single_system(single_system_data: dict[str, Tensor]) -> None: - # With velocities - kt_vel = quantities.calc_kT( - masses=single_system_data["masses"], - velocities=single_system_data["velocities"], - ) - assert torch.allclose(kt_vel, single_system_data["kt"]) - - # With momenta - kt_mom = quantities.calc_kT( - masses=single_system_data["masses"], momenta=single_system_data["momenta"] - ) - assert torch.allclose(kt_mom, single_system_data["kt"]) - - -def test_calc_kt_batched_system(batched_system_data: dict[str, Tensor]) -> None: - # With velocities - kt_vel = quantities.calc_kT( - masses=batched_system_data["masses"], - velocities=batched_system_data["velocities"], - system_idx=batched_system_data["system_idx"], - ) - assert torch.allclose(kt_vel, batched_system_data["kt"]) - - # With momenta - kt_mom = quantities.calc_kT( - masses=batched_system_data["masses"], - momenta=batched_system_data["momenta"], - system_idx=batched_system_data["system_idx"], - ) - assert torch.allclose(kt_mom, batched_system_data["kt"]) - - -def test_calc_temperature(single_system_data: dict[str, Tensor]) -> None: - temp = quantities.calc_temperature( - masses=single_system_data["masses"], - velocities=single_system_data["velocities"], - ) - kt = quantities.calc_kT( - masses=single_system_data["masses"], - velocities=single_system_data["velocities"], - ) - assert torch.allclose(temp, kt / MetalUnits.temperature) + # Heat flux terms should cancel out + expected = torch.zeros(3, device=DEVICE) + assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) diff --git a/tests/test_runners.py b/tests/test_runners.py index 5c9862d0e..8db32ff18 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -6,16 +6,16 @@ import torch import torch_sim as ts +from tests.conftest import DEVICE, DTYPE from torch_sim.autobatching import BinningAutoBatcher, InFlightAutoBatcher -from torch_sim.integrators import nve, nvt_langevin +from torch_sim.integrators.md import MDState from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.optimizers import unit_cell_fire -from torch_sim.quantities import calc_kinetic_energy +from torch_sim.state import SimState from torch_sim.trajectory import TorchSimTrajectory, TrajectoryReporter def test_integrate_nve( - ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, tmp_path: Path + ar_supercell_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path ) -> None: """Test NVE integration with LJ potential.""" traj_file = tmp_path / "nve.h5md" @@ -24,7 +24,7 @@ def test_integrate_nve( state_frequency=1, prop_calculators={ 1: { - "ke": lambda state: calc_kinetic_energy( + "ke": lambda state: ts.calc_kinetic_energy( momenta=state.momenta, masses=state.masses ) } @@ -34,14 +34,14 @@ def test_integrate_nve( final_state = ts.integrate( system=ar_supercell_sim_state, model=lj_model, - integrator=nve, + integrator=ts.MdFlavor.nve, n_steps=10, temperature=100.0, # K timestep=0.001, # ps trajectory_reporter=reporter, ) - assert isinstance(final_state, ts.SimState) + assert isinstance(final_state, SimState) assert traj_file.is_file() # Check energy conservation @@ -52,7 +52,7 @@ def test_integrate_nve( def test_integrate_single_nvt( - ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, tmp_path: Path + ar_supercell_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path ) -> None: """Test NVT integration with LJ potential.""" traj_file = tmp_path / "nvt.h5md" @@ -61,7 +61,7 @@ def test_integrate_single_nvt( state_frequency=1, prop_calculators={ 1: { - "ke": lambda state: calc_kinetic_energy( + "ke": lambda state: ts.calc_kinetic_energy( momenta=state.momenta, masses=state.masses ) } @@ -71,15 +71,14 @@ def test_integrate_single_nvt( final_state = ts.integrate( system=ar_supercell_sim_state, model=lj_model, - integrator=nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=10, temperature=100.0, # K timestep=0.001, # ps trajectory_reporter=reporter, - gamma=0.1, # ps^-1 ) - assert isinstance(final_state, ts.SimState) + assert isinstance(final_state, SimState) assert traj_file.is_file() # Check energy fluctuations @@ -90,25 +89,25 @@ def test_integrate_single_nvt( def test_integrate_double_nvt( - ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel + ar_double_sim_state: SimState, lj_model: LennardJonesModel ) -> None: """Test NVT integration with LJ potential.""" final_state = ts.integrate( system=ar_double_sim_state, model=lj_model, - integrator=nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=10, temperature=100.0, # K timestep=0.001, # ps ) - assert isinstance(final_state, ts.SimState) + assert isinstance(final_state, SimState) assert final_state.n_atoms == 64 assert not torch.isnan(final_state.energy).any() def test_integrate_double_nvt_with_reporter( - ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel, tmp_path: Path + ar_double_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path ) -> None: """Test NVT integration with LJ potential.""" trajectory_files = [tmp_path / "nvt_0.h5md", tmp_path / "nvt_1.h5md"] @@ -117,7 +116,7 @@ def test_integrate_double_nvt_with_reporter( state_frequency=1, prop_calculators={ 1: { - "ke": lambda state: calc_kinetic_energy( + "ke": lambda state: ts.calc_kinetic_energy( momenta=state.momenta, masses=state.masses ) } @@ -127,15 +126,14 @@ def test_integrate_double_nvt_with_reporter( final_state = ts.integrate( system=ar_double_sim_state, model=lj_model, - integrator=nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, n_steps=10, temperature=100.0, # K timestep=0.001, # ps trajectory_reporter=reporter, - gamma=0.1, # ps^-1 ) - assert isinstance(final_state, ts.SimState) + assert isinstance(final_state, SimState) assert final_state.n_atoms == 64 assert all(traj_file.is_file() for traj_file in trajectory_files) @@ -149,8 +147,8 @@ def test_integrate_double_nvt_with_reporter( def test_integrate_many_nvt( - ar_supercell_sim_state: ts.SimState, - fe_supercell_sim_state: ts.SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path, ) -> None: @@ -161,14 +159,14 @@ def test_integrate_many_nvt( lj_model.dtype, ) trajectory_files = [ - tmp_path / f"nvt_{batch}.h5md" for batch in range(triple_state.n_systems) + tmp_path / f"nvt_{sys_idx}.h5md" for sys_idx in range(triple_state.n_systems) ] reporter = TrajectoryReporter( filenames=trajectory_files, state_frequency=1, prop_calculators={ 1: { - "ke": lambda state: calc_kinetic_energy( + "ke": lambda state: ts.calc_kinetic_energy( momenta=state.momenta, masses=state.masses ) } @@ -178,14 +176,14 @@ def test_integrate_many_nvt( final_state = ts.integrate( system=triple_state, model=lj_model, - integrator=nve, + integrator=ts.MdFlavor.nve, n_steps=10, temperature=300.0, # K timestep=0.001, # ps trajectory_reporter=reporter, ) - assert isinstance(final_state, ts.SimState) + assert isinstance(final_state, SimState) assert all(traj_file.is_file() for traj_file in trajectory_files) assert not torch.isnan(final_state.energy).any() assert not torch.isnan(final_state.positions).any() @@ -196,8 +194,8 @@ def test_integrate_many_nvt( def test_integrate_with_autobatcher( - ar_supercell_sim_state: ts.SimState, - fe_supercell_sim_state: ts.SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, lj_model: LennardJonesModel, ) -> None: """Test integration with autobatcher.""" @@ -215,22 +213,22 @@ def test_integrate_with_autobatcher( final_states = ts.integrate( system=triple_state, model=lj_model, - integrator=nve, + integrator=ts.MdFlavor.nve, n_steps=10, temperature=300.0, timestep=0.001, autobatcher=autobatcher, ) - assert isinstance(final_states, ts.SimState) + assert isinstance(final_states, SimState) for init_state, final_state in zip(states, final_states.split(), strict=True): assert torch.all(final_state.atomic_numbers == init_state.atomic_numbers) assert torch.any(final_state.positions != init_state.positions) def test_integrate_with_autobatcher_and_reporting( - ar_supercell_sim_state: ts.SimState, - fe_supercell_sim_state: ts.SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path, ) -> None: @@ -247,7 +245,7 @@ def test_integrate_with_autobatcher_and_reporting( max_memory_scaler=260, ) trajectory_files = [ - tmp_path / f"nvt_{batch}.h5md" for batch in range(triple_state.n_systems) + tmp_path / f"nvt_{sys_idx}.h5md" for sys_idx in range(triple_state.n_systems) ] reporter = TrajectoryReporter( filenames=trajectory_files, @@ -257,7 +255,7 @@ def test_integrate_with_autobatcher_and_reporting( final_states = ts.integrate( system=triple_state, model=lj_model, - integrator=nve, + integrator=ts.MdFlavor.nve, n_steps=10, temperature=300.0, timestep=0.001, @@ -267,7 +265,7 @@ def test_integrate_with_autobatcher_and_reporting( assert all(traj_file.is_file() for traj_file in trajectory_files) - assert isinstance(final_states, ts.SimState) + assert isinstance(final_states, SimState) for init_state, final_state in zip(states, final_states.split(), strict=True): assert torch.all(final_state.atomic_numbers == init_state.atomic_numbers) assert torch.any(final_state.positions != init_state.positions) @@ -287,7 +285,7 @@ def test_integrate_with_autobatcher_and_reporting( def test_optimize_fire( - ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, tmp_path: Path + ar_supercell_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path ) -> None: """Test FIRE optimization with LJ potential.""" trajectory_files = [tmp_path / "opt.h5md"] @@ -304,7 +302,8 @@ def test_optimize_fire( final_state = ts.optimize( system=ar_supercell_sim_state, model=lj_model, - optimizer=unit_cell_fire, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.unit, convergence_fn=ts.generate_force_convergence_fn(force_tol=1e-1), trajectory_reporter=reporter, ) @@ -320,7 +319,7 @@ def test_optimize_fire( def test_default_converged_fn( - ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, tmp_path: Path + ar_supercell_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path ) -> None: """Test default converged function.""" ar_supercell_sim_state.positions += ( @@ -337,7 +336,8 @@ def test_default_converged_fn( final_state = ts.optimize( system=ar_supercell_sim_state, model=lj_model, - optimizer=unit_cell_fire, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.unit, trajectory_reporter=reporter, ) @@ -350,7 +350,7 @@ def test_default_converged_fn( def test_batched_optimize_fire( - ar_double_sim_state: ts.SimState, + ar_double_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path, ) -> None: @@ -363,7 +363,7 @@ def test_batched_optimize_fire( state_frequency=1, prop_calculators={ 1: { - "ke": lambda state: calc_kinetic_energy( + "ke": lambda state: ts.calc_kinetic_energy( velocities=state.velocities, masses=state.masses ) } @@ -373,17 +373,19 @@ def test_batched_optimize_fire( final_state = ts.optimize( system=ar_double_sim_state, model=lj_model, - optimizer=unit_cell_fire, - convergence_fn=ts.generate_force_convergence_fn(force_tol=1e-1), + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.unit, + convergence_fn=ts.generate_force_convergence_fn(force_tol=1e-5), trajectory_reporter=reporter, + max_steps=500, ) assert torch.all(final_state.forces < 1e-4) def test_optimize_with_autobatcher( - ar_supercell_sim_state: ts.SimState, - fe_supercell_sim_state: ts.SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, lj_model: LennardJonesModel, ) -> None: """Test optimize with autobatcher.""" @@ -394,27 +396,26 @@ def test_optimize_with_autobatcher( lj_model.dtype, ) autobatcher = InFlightAutoBatcher( - model=lj_model, - memory_scales_with="n_atoms", - max_memory_scaler=260, + model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260 ) final_states = ts.optimize( system=triple_state, model=lj_model, - optimizer=unit_cell_fire, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.unit, convergence_fn=ts.generate_force_convergence_fn(force_tol=1e-1), autobatcher=autobatcher, ) - assert isinstance(final_states, ts.SimState) + assert isinstance(final_states, SimState) for init_state, final_state in zip(states, final_states.split(), strict=True): assert torch.all(final_state.atomic_numbers == init_state.atomic_numbers) assert torch.any(final_state.positions != init_state.positions) def test_optimize_with_autobatcher_and_reporting( - ar_supercell_sim_state: ts.SimState, - fe_supercell_sim_state: ts.SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path, ) -> None: @@ -434,7 +435,7 @@ def test_optimize_with_autobatcher_and_reporting( ) trajectory_files = [ - tmp_path / f"opt_{batch}.h5md" for batch in range(triple_state.n_systems) + tmp_path / f"opt_{sys_idx}.h5md" for sys_idx in range(triple_state.n_systems) ] reporter = TrajectoryReporter( filenames=trajectory_files, @@ -445,7 +446,8 @@ def test_optimize_with_autobatcher_and_reporting( final_states = ts.optimize( system=triple_state, model=lj_model, - optimizer=unit_cell_fire, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.unit, convergence_fn=ts.generate_force_convergence_fn(force_tol=1e-1), trajectory_reporter=reporter, autobatcher=autobatcher, @@ -453,7 +455,7 @@ def test_optimize_with_autobatcher_and_reporting( assert all(traj_file.is_file() for traj_file in trajectory_files) - assert isinstance(final_states, ts.SimState) + assert isinstance(final_states, SimState) for init_state, final_state in zip(states, final_states.split(), strict=True): assert torch.all(final_state.atomic_numbers == init_state.atomic_numbers) assert torch.any(final_state.positions != init_state.positions) @@ -476,8 +478,8 @@ def test_optimize_with_autobatcher_and_reporting( def test_integrate_with_default_autobatcher( - ar_supercell_sim_state: ts.SimState, - fe_supercell_sim_state: ts.SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, lj_model: LennardJonesModel, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -486,36 +488,30 @@ def test_integrate_with_default_autobatcher( def mock_estimate(*args, **kwargs) -> float: # noqa: ARG001 return 10_000.0 - monkeypatch.setattr( - "torch_sim.autobatching.estimate_max_memory_scaler", mock_estimate - ) + monkeypatch.setattr("ts.autobatching.estimate_max_memory_scaler", mock_estimate) states = [ar_supercell_sim_state, fe_supercell_sim_state, ar_supercell_sim_state] - triple_state = ts.initialize_state( - states, - lj_model.device, - lj_model.dtype, - ) + triple_state = ts.initialize_state(states, lj_model.device, lj_model.dtype) final_states = ts.integrate( system=triple_state, model=lj_model, - integrator=nve, + integrator=ts.MdFlavor.nve, n_steps=10, temperature=300.0, timestep=0.001, autobatcher=True, ) - assert isinstance(final_states, ts.SimState) + assert isinstance(final_states, SimState) for init_state, final_state in zip(states, final_states.split(), strict=True): assert torch.all(final_state.atomic_numbers == init_state.atomic_numbers) assert torch.any(final_state.positions != init_state.positions) def test_optimize_with_default_autobatcher( - ar_supercell_sim_state: ts.SimState, - fe_supercell_sim_state: ts.SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, lj_model: LennardJonesModel, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -524,7 +520,7 @@ def test_optimize_with_default_autobatcher( def mock_estimate(*args, **kwargs) -> float: # noqa: ARG001 return 200 - monkeypatch.setattr("torch_sim.autobatching.determine_max_batch_size", mock_estimate) + monkeypatch.setattr("ts.autobatching.determine_max_batch_size", mock_estimate) states = [ar_supercell_sim_state, fe_supercell_sim_state, ar_supercell_sim_state] triple_state = ts.initialize_state( @@ -536,19 +532,20 @@ def mock_estimate(*args, **kwargs) -> float: # noqa: ARG001 final_states = ts.optimize( system=triple_state, model=lj_model, - optimizer=unit_cell_fire, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.unit, convergence_fn=ts.generate_force_convergence_fn(force_tol=1e-1), autobatcher=True, ) - assert isinstance(final_states, ts.SimState) + assert isinstance(final_states, SimState) for init_state, final_state in zip(states, final_states.split(), strict=True): assert torch.all(final_state.atomic_numbers == init_state.atomic_numbers) assert torch.any(final_state.positions != init_state.positions) def test_static_single( - ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, tmp_path: Path + ar_supercell_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path ) -> None: """Test static calculation with LJ potential.""" traj_file = tmp_path / "static.h5md" @@ -587,7 +584,7 @@ def test_static_single( def test_static_double( - ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel, tmp_path: Path + ar_double_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path ) -> None: """Test static calculation with multiple systems.""" trajectory_files = [tmp_path / "static_0.h5md", tmp_path / "static_1.h5md"] @@ -619,8 +616,8 @@ def test_static_double( def test_static_with_autobatcher( - ar_supercell_sim_state: ts.SimState, - fe_supercell_sim_state: ts.SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, lj_model: LennardJonesModel, ) -> None: """Test static calculation with autobatcher.""" @@ -666,7 +663,7 @@ def test_static_with_autobatcher_and_reporting( s2_atoms = bulk("Cu", "fcc", a=3.6, cubic=True).repeat((2, 1, 1)) s3_atoms = bulk("Ar", "fcc", a=5.3, cubic=True) # Different params from s0_atoms - initial_sim_states: list[ts.SimState] = [] + initial_sim_states: list[SimState] = [] for idx, atoms_obj in enumerate((s0_atoms, s1_atoms, s2_atoms, s3_atoms)): sim_state_batched = ts.initialize_state( atoms_obj, device=lj_model.device, dtype=lj_model.dtype @@ -699,7 +696,6 @@ def test_static_with_autobatcher_and_reporting( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=10, - return_indices=True, ) # 4. Call ts.static with trajectory reporting @@ -760,7 +756,7 @@ def test_static_with_autobatcher_and_reporting( def test_static_no_filenames( - ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel + ar_supercell_sim_state: SimState, lj_model: LennardJonesModel ) -> None: """Test static calculation with no trajectory filenames.""" reporter = TrajectoryReporter( @@ -784,8 +780,6 @@ def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None: from ase.build import bulk - import torch_sim as ts - cu_atoms = bulk("Cu", "fcc", a=3.58, cubic=True).repeat((2, 2, 2)) many_cu_atoms = [cu_atoms] * 5 trajectory_files = [tmp_path / f"Cu_traj_{i}.h5md" for i in range(len(many_cu_atoms))] @@ -797,10 +791,9 @@ def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None: n_steps=50, timestep=0.002, temperature=1000, - integrator=ts.nvt_langevin, + integrator=ts.MdFlavor.nvt_langevin, trajectory_reporter=dict(filenames=trajectory_files, state_frequency=10), ) - final_atoms_list = final_state.to_atoms() # noqa: F841 # extract the final energy from the trajectory file final_energies = [] @@ -814,7 +807,8 @@ def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None: relaxed_state = ts.optimize( system=final_state, model=lj_model, - optimizer=ts.frechet_cell_fire, + optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.frechet, # autobatcher=True, # disabled for CPU-based LJ model in test ) @@ -824,23 +818,21 @@ def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None: @pytest.fixture def mock_state() -> Callable: """Create a mock state for testing convergence functions.""" - device = torch.device("cpu") - dtype = torch.float64 n_systems, n_atoms = 2, 8 torch.manual_seed(0) # deterministic forces class MockState: def __init__(self, *, include_cell_forces: bool = True) -> None: - self.forces = torch.randn(n_atoms, 3, device=device, dtype=dtype) + self.forces = torch.randn(n_atoms, 3, device=DEVICE, dtype=DTYPE) self.system_idx = torch.repeat_interleave( torch.arange(n_systems), n_atoms // n_systems ) - self.device = device - self.dtype = dtype + self.device = DEVICE + self.dtype = DTYPE self.n_systems = n_systems if include_cell_forces: self.cell_forces = torch.randn( - n_systems, 3, 3, device=device, dtype=dtype + n_systems, 3, 3, device=DEVICE, dtype=DTYPE ) return MockState @@ -858,7 +850,7 @@ def __init__(self, *, include_cell_forces: bool = True) -> None: ) def test_generate_force_convergence_fn( *, - ar_supercell_sim_state: ts.SimState, + ar_supercell_sim_state: SimState, lj_model: LennardJonesModel, mock_state: Callable, force_tol: float, @@ -871,20 +863,22 @@ def test_generate_force_convergence_fn( if should_error: state = mock_state(include_cell_forces=False) else: - # Prepare real state + # Create a proper state with forces from the model output model_output = lj_model(ar_supercell_sim_state) - ar_supercell_sim_state.forces = model_output["forces"] - ar_supercell_sim_state.energy = model_output["energy"] + + state = MDState.from_state( + ar_supercell_sim_state, + energy=model_output["energy"], + forces=model_output["forces"], + momenta=torch.zeros_like(ar_supercell_sim_state.positions), + ) if has_cell_forces: - ar_supercell_sim_state.cell_forces = torch.randn( - ar_supercell_sim_state.n_systems, - 3, - 3, + state.cell_forces = torch.randn( + *(ar_supercell_sim_state.n_systems, 3, 3), device=ar_supercell_sim_state.device, dtype=ar_supercell_sim_state.dtype, ) - state = ar_supercell_sim_state convergence_fn = ts.generate_force_convergence_fn( force_tol=force_tol, include_cell_forces=include_cell_forces @@ -901,13 +895,18 @@ def test_generate_force_convergence_fn( def test_generate_force_convergence_fn_tolerance_ordering( - ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel + ar_supercell_sim_state: SimState, lj_model: LennardJonesModel ) -> None: """Test that higher tolerances are less restrictive than lower ones.""" model_output = lj_model(ar_supercell_sim_state) - ar_supercell_sim_state.forces = model_output["forces"] - ar_supercell_sim_state.energy = model_output["energy"] - ar_supercell_sim_state.cell_forces = torch.randn( + + test_state = MDState.from_state( + ar_supercell_sim_state, + energy=model_output["energy"], + forces=model_output["forces"], + momenta=torch.zeros_like(ar_supercell_sim_state.positions), + ) + test_state.cell_forces = torch.randn( ar_supercell_sim_state.n_systems, 3, 3, @@ -917,8 +916,7 @@ def test_generate_force_convergence_fn_tolerance_ordering( tolerances = [1e-4, 1e-2, 1e0, 1e2] results = [ - ts.generate_force_convergence_fn(force_tol=tol)(ar_supercell_sim_state) - for tol in tolerances + ts.generate_force_convergence_fn(force_tol=tol)(test_state) for tol in tolerances ] # If converged at lower tolerance, must be converged at higher tolerance @@ -960,12 +958,12 @@ def __init__(self) -> None: self.forces = torch.zeros(n_atoms, 3, device=device, dtype=dtype) self.cell_forces = torch.zeros(n_systems, 3, 3, device=device, dtype=dtype) - for system_idx, (atomic_force, cell_force) in enumerate( + for sys_idx, (atomic_force, cell_force) in enumerate( zip(atomic_forces, cell_forces, strict=False) ): - system_mask = self.system_idx == system_idx + system_mask = self.system_idx == sys_idx self.forces[system_mask, 0] = atomic_force - self.cell_forces[system_idx, 0, 0] = cell_force + self.cell_forces[sys_idx, 0, 0] = cell_force state = ControlledMockState() convergence_fn = ts.generate_force_convergence_fn( @@ -977,21 +975,26 @@ def __init__(self) -> None: def test_generate_force_convergence_fn_ignores_last_energy( - ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel + ar_supercell_sim_state: SimState, lj_model: LennardJonesModel ) -> None: """Test that convergence function ignores last_energy parameter.""" model_output = lj_model(ar_supercell_sim_state) - ar_supercell_sim_state.forces = model_output["forces"] - ar_supercell_sim_state.energy = model_output["energy"] + + test_state = MDState.from_state( + ar_supercell_sim_state, + energy=model_output["energy"], + forces=model_output["forces"], + momenta=torch.zeros_like(ar_supercell_sim_state.positions), + ) convergence_fn = ts.generate_force_convergence_fn( force_tol=1e-2, include_cell_forces=False ) results = [ - convergence_fn(ar_supercell_sim_state), - convergence_fn(ar_supercell_sim_state, last_energy=torch.tensor([1.0])), - convergence_fn(ar_supercell_sim_state, last_energy=None), + convergence_fn(test_state), + convergence_fn(test_state, last_energy=torch.tensor([1.0])), + convergence_fn(test_state, last_energy=None), ] # All results should be identical diff --git a/tests/test_state.py b/tests/test_state.py index 67a757fed..1657b3821 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -5,6 +5,7 @@ import torch import torch_sim as ts +from tests.conftest import DEVICE from torch_sim.integrators import MDState from torch_sim.state import ( DeformGradMixin, @@ -12,9 +13,7 @@ _normalize_system_indices, _pop_states, _slice_state, - concatenate_states, get_attrs_for_scope, - initialize_state, ) @@ -24,25 +23,20 @@ from pymatgen.core import Structure -def test_get_attrs_for_scope(si_sim_state: ts.SimState) -> None: +def test_get_attrs_for_scope(si_sim_state: SimState) -> None: """Test getting attributes for a scope.""" per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom")) - assert set(per_atom_attrs.keys()) == { - "positions", - "masses", - "atomic_numbers", - "system_idx", - } + assert set(per_atom_attrs) == {"positions", "masses", "atomic_numbers", "system_idx"} per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) - assert set(per_system_attrs.keys()) == {"cell"} + assert set(per_system_attrs) == {"cell"} global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) - assert set(global_attrs.keys()) == {"pbc"} + assert set(global_attrs) == {"pbc"} def test_all_attributes_must_be_specified_in_scopes() -> None: """Test that an error is raised when we forget to specify the scope for an attribute in a child SimState class.""" - with pytest.raises(TypeError) as excinfo: + with pytest.raises(TypeError) as exc_info: class ChildState(SimState): attribute_specified_in_scopes: bool @@ -52,15 +46,15 @@ class ChildState(SimState): SimState._atom_attributes | {"attribute_specified_in_scopes"} # noqa: SLF001 ) - assert "attribute_not_specified_in_scopes" in str(excinfo.value) - assert "attribute_specified_in_scopes" not in str(excinfo.value) + assert "attribute_not_specified_in_scopes" in str(exc_info.value) + assert "attribute_specified_in_scopes" not in str(exc_info.value) def test_no_duplicate_attributes_in_scopes() -> None: """Test that no attributes are specified in multiple scopes.""" - # Capture the exception information using "as excinfo" - with pytest.raises(TypeError) as excinfo: + # Capture the exception information using "as exc_info" + with pytest.raises(TypeError) as exc_info: class ChildState(SimState): duplicated_attribute: bool @@ -68,13 +62,11 @@ class ChildState(SimState): _system_attributes = SimState._system_attributes | {"duplicated_attribute"} # noqa: SLF001 _global_attributes = SimState._global_attributes | {"duplicated_attribute"} # noqa: SLF001 - assert "are declared multiple times" in str(excinfo.value) - assert "duplicated_attribute" in str(excinfo.value) + assert "are declared multiple times" in str(exc_info.value) + assert "duplicated_attribute" in str(exc_info.value) -def test_slice_substate( - si_double_sim_state: ts.SimState, si_sim_state: ts.SimState -) -> None: +def test_slice_substate(si_double_sim_state: SimState, si_sim_state: SimState) -> None: """Test slicing a substate from the SimState.""" for system_index in range(2): substate = _slice_state(si_double_sim_state, [system_index]) @@ -89,7 +81,7 @@ def test_slice_substate( assert torch.allclose(substate.system_idx, torch.zeros_like(substate.system_idx)) -def test_slice_md_substate(si_double_sim_state: ts.SimState) -> None: +def test_slice_md_substate(si_double_sim_state: SimState) -> None: state = MDState( **asdict(si_double_sim_state), momenta=torch.randn_like(si_double_sim_state.positions), @@ -108,11 +100,11 @@ def test_slice_md_substate(si_double_sim_state: ts.SimState) -> None: def test_concatenate_two_si_states( - si_sim_state: ts.SimState, si_double_sim_state: ts.SimState + si_sim_state: SimState, si_double_sim_state: SimState ) -> None: """Test concatenating two identical silicon states.""" # Concatenate two copies of the sim state - concatenated = concatenate_states([si_sim_state, si_sim_state]) + concatenated = ts.concatenate_states([si_sim_state, si_sim_state]) # Check that the result is the same as the double state assert isinstance(concatenated, SimState) @@ -123,22 +115,19 @@ def test_concatenate_two_si_states( assert concatenated.system_idx.shape == si_double_sim_state.system_idx.shape # Check system indices + tensor_args = dict(dtype=torch.int64, device=si_sim_state.device) expected_system_indices = torch.cat( [ - torch.zeros( - si_sim_state.n_atoms, dtype=torch.int64, device=si_sim_state.device - ), - torch.ones( - si_sim_state.n_atoms, dtype=torch.int64, device=si_sim_state.device - ), + torch.zeros(si_sim_state.n_atoms, **tensor_args), + torch.ones(si_sim_state.n_atoms, **tensor_args), ] ) assert torch.all(concatenated.system_idx == expected_system_indices) # Check that positions match (accounting for system indices) - for system_idx in range(2): - mask_concat = concatenated.system_idx == system_idx - mask_double = si_double_sim_state.system_idx == system_idx + for sys_idx in range(2): + mask_concat = concatenated.system_idx == sys_idx + mask_double = si_double_sim_state.system_idx == sys_idx assert torch.allclose( concatenated.positions[mask_concat], si_double_sim_state.positions[mask_double], @@ -146,11 +135,11 @@ def test_concatenate_two_si_states( def test_concatenate_si_and_fe_states( - si_sim_state: ts.SimState, fe_supercell_sim_state: ts.SimState + si_sim_state: SimState, fe_supercell_sim_state: SimState ) -> None: """Test concatenating silicon and argon states.""" # Concatenate silicon and argon states - concatenated = concatenate_states([si_sim_state, fe_supercell_sim_state]) + concatenated = ts.concatenate_states([si_sim_state, fe_supercell_sim_state]) # Check basic properties assert isinstance(concatenated, SimState) @@ -196,11 +185,11 @@ def test_concatenate_si_and_fe_states( def test_concatenate_double_si_and_fe_states( - si_double_sim_state: ts.SimState, fe_supercell_sim_state: ts.SimState + si_double_sim_state: SimState, fe_supercell_sim_state: SimState ) -> None: """Test concatenating a double silicon state and an argon state.""" # Concatenate double silicon and argon states - concatenated = concatenate_states([si_double_sim_state, fe_supercell_sim_state]) + concatenated = ts.concatenate_states([si_double_sim_state, fe_supercell_sim_state]) # Check basic properties assert isinstance(concatenated, SimState) @@ -239,12 +228,12 @@ def test_concatenate_double_si_and_fe_states( assert torch.allclose(fe_slice.positions, fe_supercell_sim_state.positions) -def test_split_state(si_double_sim_state: ts.SimState) -> None: +def test_split_state(si_double_sim_state: SimState) -> None: """Test splitting a state into a list of states.""" states = si_double_sim_state.split() assert len(states) == si_double_sim_state.n_systems for state in states: - assert isinstance(state, ts.SimState) + assert isinstance(state, SimState) assert state.positions.shape == (8, 3) assert state.masses.shape == (8,) assert state.cell.shape == (1, 3, 3) @@ -253,13 +242,13 @@ def test_split_state(si_double_sim_state: ts.SimState) -> None: def test_split_many_states( - si_sim_state: ts.SimState, - ar_supercell_sim_state: ts.SimState, - fe_supercell_sim_state: ts.SimState, + si_sim_state: SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, ) -> None: """Test splitting a state into a list of states.""" states = [si_sim_state, ar_supercell_sim_state, fe_supercell_sim_state] - concatenated = concatenate_states(states) + concatenated = ts.concatenate_states(states) split_states = concatenated.split() for state, sub_state in zip(states, split_states, strict=True): assert isinstance(sub_state, SimState) @@ -273,13 +262,13 @@ def test_split_many_states( def test_pop_states( - si_sim_state: ts.SimState, - ar_supercell_sim_state: ts.SimState, - fe_supercell_sim_state: ts.SimState, + si_sim_state: SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, ) -> None: """Test popping states from a state.""" states = [si_sim_state, ar_supercell_sim_state, fe_supercell_sim_state] - concatenated_states = concatenate_states(states) + concatenated_states = ts.concatenate_states(states) kept_state, popped_states = _pop_states( concatenated_states, torch.tensor([0], device=concatenated_states.device) ) @@ -298,56 +287,50 @@ def test_pop_states( assert kept_state.system_idx.shape == (len_kept,) -def test_initialize_state_from_structure( - si_structure: "Structure", device: torch.device -) -> None: +def test_initialize_state_from_structure(si_structure: "Structure") -> None: """Test conversion from pymatgen Structure to state tensors.""" - state = initialize_state([si_structure], device, torch.float64) - assert isinstance(state, ts.SimState) + state = ts.initialize_state([si_structure], DEVICE, torch.float64) + assert isinstance(state, SimState) assert state.positions.shape == si_structure.cart_coords.shape assert state.cell.shape[1:] == si_structure.lattice.matrix.shape -def test_initialize_state_from_state( - ar_supercell_sim_state: ts.SimState, device: torch.device -) -> None: +def test_initialize_state_from_state(ar_supercell_sim_state: SimState) -> None: """Test conversion from SimState to SimState.""" - state = initialize_state(ar_supercell_sim_state, device, torch.float64) - assert isinstance(state, ts.SimState) + state = ts.initialize_state(ar_supercell_sim_state, DEVICE, torch.float64) + assert isinstance(state, SimState) assert state.positions.shape == ar_supercell_sim_state.positions.shape assert state.masses.shape == ar_supercell_sim_state.masses.shape assert state.cell.shape == ar_supercell_sim_state.cell.shape -def test_initialize_state_from_atoms(si_atoms: "Atoms", device: torch.device) -> None: +def test_initialize_state_from_atoms(si_atoms: "Atoms") -> None: """Test conversion from ASE Atoms to SimState.""" - state = initialize_state([si_atoms], device, torch.float64) - assert isinstance(state, ts.SimState) + state = ts.initialize_state([si_atoms], DEVICE, torch.float64) + assert isinstance(state, SimState) assert state.positions.shape == si_atoms.positions.shape assert state.masses.shape == si_atoms.get_masses().shape assert state.cell.shape[1:] == si_atoms.cell.array.T.shape -def test_initialize_state_from_phonopy_atoms( - si_phonopy_atoms: "PhonopyAtoms", device: torch.device -) -> None: +def test_initialize_state_from_phonopy_atoms(si_phonopy_atoms: "PhonopyAtoms") -> None: """Test conversion from PhonopyAtoms to SimState.""" - state = initialize_state([si_phonopy_atoms], device, torch.float64) - assert isinstance(state, ts.SimState) + state = ts.initialize_state([si_phonopy_atoms], DEVICE, torch.float64) + assert isinstance(state, SimState) assert state.positions.shape == si_phonopy_atoms.positions.shape assert state.masses.shape == si_phonopy_atoms.masses.shape assert state.cell.shape[1:] == si_phonopy_atoms.cell.shape def test_state_pop_method( - si_sim_state: ts.SimState, - ar_supercell_sim_state: ts.SimState, - fe_supercell_sim_state: ts.SimState, + si_sim_state: SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, ) -> None: """Test the pop method of SimState.""" # Create a concatenated state states = [si_sim_state, ar_supercell_sim_state, fe_supercell_sim_state] - concatenated = concatenate_states(states) + concatenated = ts.concatenate_states(states) # Test popping a single batch popped_states = concatenated.pop(1) @@ -360,7 +343,7 @@ def test_state_pop_method( assert torch.unique(concatenated.system_idx).tolist() == [0, 1] # Test popping multiple batches - multi_state = concatenate_states(states) + multi_state = ts.concatenate_states(states) popped_multi = multi_state.pop([0, 2]) assert len(popped_multi) == 2 assert torch.allclose(popped_multi[0].positions, si_sim_state.positions) @@ -373,14 +356,14 @@ def test_state_pop_method( def test_state_getitem( - si_sim_state: ts.SimState, - ar_supercell_sim_state: ts.SimState, - fe_supercell_sim_state: ts.SimState, + si_sim_state: SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, ) -> None: """Test the __getitem__ method of SimState.""" # Create a concatenated state states = [si_sim_state, ar_supercell_sim_state, fe_supercell_sim_state] - concatenated = concatenate_states(states) + concatenated = ts.concatenate_states(states) # Test integer indexing single_state = concatenated[1] @@ -420,11 +403,10 @@ def test_state_getitem( assert concatenated.n_systems == 3 -def test_normalize_system_indices(si_double_sim_state: ts.SimState) -> None: +def test_normalize_system_indices(si_double_sim_state: SimState) -> None: """Test the _normalize_system_indices utility method.""" state = si_double_sim_state # State with 2 batches - n_systems = state.n_systems - device = state.device + n_systems, device = state.n_systems, state.device # Test integer indexing assert _normalize_system_indices(0, n_systems, device).tolist() == [0] @@ -476,7 +458,7 @@ def test_normalize_system_indices(si_double_sim_state: ts.SimState) -> None: pass -def test_row_vector_cell(si_sim_state: ts.SimState) -> None: +def test_row_vector_cell(si_sim_state: SimState) -> None: """Test the row_vector_cell property getter and setter.""" # Test getter - should return transposed cell original_cell = si_sim_state.cell.clone() @@ -492,7 +474,7 @@ def test_row_vector_cell(si_sim_state: ts.SimState) -> None: assert torch.allclose(si_sim_state.row_vector_cell, new_cell.mT) -def test_column_vector_cell(si_sim_state: ts.SimState) -> None: +def test_column_vector_cell(si_sim_state: SimState) -> None: """Test the column_vector_cell property getter and setter.""" # Test getter - should return cell directly since it's already in column vector format original_cell = si_sim_state.cell.clone() @@ -517,11 +499,7 @@ class DeformState(SimState, DeformGradMixin): ) def __init__( - self, - *args, - velocities: torch.Tensor, - reference_cell: torch.Tensor, - **kwargs, + self, *args, velocities: torch.Tensor, reference_cell: torch.Tensor, **kwargs ) -> None: super().__init__(*args, **kwargs) self.velocities = velocities @@ -529,13 +507,13 @@ def __init__( @pytest.fixture -def deform_grad_state(device: torch.device) -> DeformState: +def deform_grad_state() -> DeformState: """Create a test state with deformation gradient support.""" - positions = torch.randn(10, 3, device=device) - masses = torch.ones(10, device=device) - velocities = torch.randn(10, 3, device=device) - reference_cell = torch.eye(3, device=device).unsqueeze(0) + positions = torch.randn(10, 3, device=DEVICE) + masses = torch.ones(10, device=DEVICE) + velocities = torch.randn(10, 3, device=DEVICE) + reference_cell = torch.eye(3, device=DEVICE).unsqueeze(0) current_cell = 2 * reference_cell return DeformState( @@ -543,7 +521,7 @@ def deform_grad_state(device: torch.device) -> DeformState: masses=masses, cell=current_cell, pbc=True, - atomic_numbers=torch.ones(10, device=device, dtype=torch.long), + atomic_numbers=torch.ones(10, device=DEVICE, dtype=torch.long), velocities=velocities, reference_cell=reference_cell, ) @@ -572,20 +550,20 @@ def test_deform_grad_uniform(deform_grad_state: DeformState) -> None: assert torch.allclose(deform_grad, expected) -def test_deform_grad_non_uniform(device: torch.device) -> None: +def test_deform_grad_non_uniform() -> None: """Test deformation gradient calculation for non-uniform deformation.""" - reference_cell = torch.eye(3, device=device).unsqueeze(0) + reference_cell = torch.eye(3, device=DEVICE).unsqueeze(0) current_cell = torch.tensor( - [[[2.0, 0.1, 0.0], [0.1, 1.5, 0.0], [0.0, 0.0, 1.8]]], device=device + [[[2.0, 0.1, 0.0], [0.1, 1.5, 0.0], [0.0, 0.0, 1.8]]], device=DEVICE ) state = DeformState( - positions=torch.randn(10, 3, device=device), - masses=torch.ones(10, device=device), + positions=torch.randn(10, 3, device=DEVICE), + masses=torch.ones(10, device=DEVICE), cell=current_cell, pbc=True, - atomic_numbers=torch.ones(10, device=device, dtype=torch.long), - velocities=torch.randn(10, 3, device=device), + atomic_numbers=torch.ones(10, device=DEVICE, dtype=torch.long), + velocities=torch.randn(10, 3, device=DEVICE), reference_cell=reference_cell, ) @@ -595,75 +573,49 @@ def test_deform_grad_non_uniform(device: torch.device) -> None: assert torch.allclose(reconstructed_cell, current_cell) -def test_deform_grad_batched(device: torch.device) -> None: +def test_deform_grad_batched() -> None: """Test deformation gradient calculation with batched states.""" - batch_size = 3 - n_atoms = 10 + batch_size, n_atoms = 3, 10 - reference_cell = torch.eye(3, device=device).unsqueeze(0).repeat(batch_size, 1, 1) + reference_cell = torch.eye(3, device=DEVICE).unsqueeze(0).repeat(batch_size, 1, 1) current_cell = torch.stack( [ - 2.0 * torch.eye(3, device=device), # Uniform expansion - torch.eye(3, device=device), # No deformation - 0.5 * torch.eye(3, device=device), # Uniform compression + 2.0 * torch.eye(3, device=DEVICE), # Uniform expansion + torch.eye(3, device=DEVICE), # No deformation + 0.5 * torch.eye(3, device=DEVICE), # Uniform compression ] ) state = DeformState( - positions=torch.randn(n_atoms * batch_size, 3, device=device), - masses=torch.ones(n_atoms * batch_size, device=device), + positions=torch.randn(n_atoms * batch_size, 3, device=DEVICE), + masses=torch.ones(n_atoms * batch_size, device=DEVICE), cell=current_cell, pbc=True, - atomic_numbers=torch.ones(n_atoms * batch_size, device=device, dtype=torch.long), - velocities=torch.randn(n_atoms * batch_size, 3, device=device), + atomic_numbers=torch.ones(n_atoms * batch_size, device=DEVICE, dtype=torch.long), + velocities=torch.randn(n_atoms * batch_size, 3, device=DEVICE), reference_cell=reference_cell, system_idx=torch.repeat_interleave( - torch.arange(batch_size, device=device), n_atoms + torch.arange(batch_size, device=DEVICE), n_atoms ), ) deform_grad = state.deform_grad() assert deform_grad.shape == (batch_size, 3, 3) - expected_factors = torch.tensor([2.0, 1.0, 0.5], device=device) - for i in range(batch_size): - expected = expected_factors[i] * torch.eye(3, device=device) - assert torch.allclose(deform_grad[i], expected) - - -def test_deprecated_batch_properties_equal_to_new_system_properties( - device: torch.device, -) -> None: - """Test that deprecated batch properties are equal to new system properties. - - This tests that the rename from batch to system is not breaking anything.""" - state = SimState( - positions=torch.randn(10, 3, device=device), - masses=torch.ones(10, device=device), - cell=torch.eye(3, device=device).unsqueeze(0).repeat(2, 1, 1), - pbc=True, - atomic_numbers=torch.ones(10, device=device, dtype=torch.long), - system_idx=torch.repeat_interleave(torch.arange(2, device=device), 5), - ) - assert state.batch is state.system_idx - assert state.n_batches == state.n_systems - assert torch.allclose(state.n_atoms_per_batch, state.n_atoms_per_system) - - # now test that assigning the old .batch property behaves the same - new_system_idx = torch.arange(4, device=device) - state.batch = new_system_idx - assert torch.allclose(state.system_idx, new_system_idx) - assert torch.allclose(state.batch, new_system_idx) + expected_factors = torch.tensor([2.0, 1.0, 0.5], device=DEVICE) + for batch_idx in range(batch_size): + expected = expected_factors[batch_idx] * torch.eye(3, device=DEVICE) + assert torch.allclose(deform_grad[batch_idx], expected) def test_derived_classes_trigger_init_subclass() -> None: """Test that derived classes cannot have attributes that are "tensors | None".""" - with pytest.raises(TypeError) as excinfo: + with pytest.raises(TypeError) as exc_info: class DerivedState(SimState): invalid_attr: torch.Tensor | None = None assert "is not allowed to be of type 'torch.Tensor | None' because torch.cat" in str( - excinfo.value + exc_info.value ) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 67d9de1f6..e7bd4531b 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -11,6 +11,7 @@ from torch_sim.integrators import MDState from torch_sim.models.interface import ModelInterface from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.state import SimState from torch_sim.trajectory import TorchSimTrajectory, TrajectoryReporter @@ -28,9 +29,7 @@ def random_state() -> MDState: momenta=torch.randn(10, 3), energy=torch.tensor(1.0), forces=torch.randn(10, 3), - masses=torch.ones( - 10, - ), + masses=torch.ones(10), cell=torch.unsqueeze(torch.eye(3) * 10.0, 0), atomic_numbers=torch.ones(10, dtype=torch.int32), system_idx=torch.zeros(10, dtype=torch.int32), @@ -39,7 +38,7 @@ def random_state() -> MDState: @pytest.fixture -def trajectory(test_file: Path) -> Generator[TorchSimTrajectory, None, None]: +def trajectory(test_file: Path) -> Generator[TorchSimTrajectory]: """Create a trajectory file for testing.""" traj = TorchSimTrajectory(test_file, compress_data=True, mode="w") yield traj @@ -49,7 +48,7 @@ def trajectory(test_file: Path) -> Generator[TorchSimTrajectory, None, None]: def test_initialization(test_file: Path) -> None: """Test trajectory file initialization.""" traj = TorchSimTrajectory(test_file, mode="w") - assert os.path.exists(test_file) + assert os.path.isfile(test_file) assert traj._file.isopen # noqa: SLF001 traj.close() @@ -356,12 +355,12 @@ def test_invalid_dtype_handling(test_file: Path) -> None: complex_data = { "complex": np.random.default_rng(seed=0).random((10, 3)).astype(np.float16) } - with pytest.raises(ValueError, match="Unsupported array.dtype="): + with pytest.raises(ValueError, match=r"Unsupported array.dtype="): traj.write_arrays(complex_data, steps=0) # Test string data string_data = {"strings": np.array([["a", "b", "c"]] * 10)} - with pytest.raises(ValueError, match="Unsupported array.dtype="): + with pytest.raises(ValueError, match=r"Unsupported array.dtype="): traj.write_arrays(string_data, steps=0) traj.close() @@ -370,10 +369,7 @@ def test_invalid_dtype_handling(test_file: Path) -> None: def test_scalar_dtype_handling(test_file: Path) -> None: """Test handling of scalar values with different dtypes.""" traj = TorchSimTrajectory( - test_file, - coerce_to_float32=True, - coerce_to_int32=True, - mode="w", + test_file, coerce_to_float32=True, coerce_to_int32=True, mode="w" ) scalar_data = { @@ -532,12 +528,9 @@ def prop_calculators() -> dict[int, dict[str, Callable]]: } -def test_report_no_properties(si_sim_state: ts.SimState, tmp_path: Path) -> None: +def test_report_no_properties(si_sim_state: SimState, tmp_path: Path) -> None: """Test TrajectoryReporter with no properties.""" - reporter = TrajectoryReporter( - tmp_path / "no_properties.hdf5", - state_frequency=1, - ) + reporter = TrajectoryReporter(tmp_path / "no_properties.hdf5", state_frequency=1) # Run several steps for step in range(5): reporter.report(si_sim_state, step) @@ -545,7 +538,7 @@ def test_report_no_properties(si_sim_state: ts.SimState, tmp_path: Path) -> None reporter.close() # Verify file was created - assert os.path.exists(tmp_path / "no_properties.hdf5") + assert os.path.isfile(tmp_path / "no_properties.hdf5") # Open trajectory and check contents trajectory = TorchSimTrajectory(tmp_path / "no_properties.hdf5", mode="r") @@ -557,11 +550,9 @@ def test_report_no_properties(si_sim_state: ts.SimState, tmp_path: Path) -> None assert "atomic_numbers" in trajectory.array_registry -def test_report_no_filenames(si_sim_state: ts.SimState, prop_calculators: dict) -> None: +def test_report_no_filenames(si_sim_state: SimState, prop_calculators: dict) -> None: """Test TrajectoryReporter with no filenames.""" - from torch_sim.state import initialize_state - - triple_state = initialize_state( + triple_state = ts.initialize_state( [si_sim_state.clone() for _ in range(3)], device=si_sim_state.device, dtype=si_sim_state.dtype, @@ -587,7 +578,7 @@ def test_report_no_filenames(si_sim_state: ts.SimState, prop_calculators: dict) def test_single_batch_reporter( - si_sim_state: ts.SimState, tmp_path: Path, prop_calculators: dict + si_sim_state: SimState, tmp_path: Path, prop_calculators: dict ) -> None: """Test TrajectoryReporter with a single batch.""" # Create a reporter with a single file @@ -604,7 +595,7 @@ def test_single_batch_reporter( reporter.close() # Verify file was created - assert os.path.exists(tmp_path / "single_batch.hdf5") + assert os.path.isfile(tmp_path / "single_batch.hdf5") # Open trajectory and check contents trajectory = TorchSimTrajectory(tmp_path / "single_batch.hdf5", mode="r") @@ -624,7 +615,7 @@ def test_single_batch_reporter( def test_multi_batch_reporter_filenames_none( - si_double_sim_state: ts.SimState, prop_calculators: dict + si_double_sim_state: SimState, prop_calculators: dict ) -> None: """Test TrajectoryReporter with multiple batches and no filenames.""" reporter = TrajectoryReporter( @@ -649,7 +640,7 @@ def test_multi_batch_reporter_filenames_none( def test_multi_batch_reporter( - si_double_sim_state: ts.SimState, tmp_path: Path, prop_calculators: dict + si_double_sim_state: SimState, tmp_path: Path, prop_calculators: dict ) -> None: """Test TrajectoryReporter with multiple batches.""" # Create a reporter with multiple files @@ -666,8 +657,8 @@ def test_multi_batch_reporter( reporter.close() # Verify files were created - assert os.path.exists(tmp_path / "batch0.hdf5") - assert os.path.exists(tmp_path / "batch1.hdf5") + assert os.path.isfile(tmp_path / "batch0.hdf5") + assert os.path.isfile(tmp_path / "batch1.hdf5") # Open trajectories and check contents traj0 = TorchSimTrajectory(tmp_path / "batch0.hdf5", mode="r") @@ -694,7 +685,7 @@ def test_multi_batch_reporter( def test_property_model_consistency( - si_double_sim_state: ts.SimState, tmp_path: Path, prop_calculators: dict + si_double_sim_state: SimState, tmp_path: Path, prop_calculators: dict ) -> None: """Test property models are consistent for single and multi-batch cases.""" # Create reporters for single and multi-batch cases @@ -744,12 +735,12 @@ def test_property_model_consistency( def test_reporter_with_model( - si_double_sim_state: ts.SimState, tmp_path: Path, lj_model: LennardJonesModel + si_double_sim_state: SimState, tmp_path: Path, lj_model: LennardJonesModel ) -> None: """Test TrajectoryReporter with a model argument in property calculators.""" # Create a property calculator that uses the model - def energy_calculator(state: ts.SimState, model: ModelInterface) -> torch.Tensor: + def energy_calculator(state: SimState, model: ModelInterface) -> torch.Tensor: output = model(state) # Calculate a property that depends on the model return output["energy"] @@ -781,13 +772,13 @@ def energy_calculator(state: ts.SimState, model: ModelInterface) -> torch.Tensor TorchSimTrajectory(tmp_path / "model_1.hdf5", mode="r"), ] - for system_idx, trajectory in enumerate(trajectories): + for sys_idx, trajectory in enumerate(trajectories): # Get the property value from file file_energy = trajectory.get_array("energy")[0] - system_props = props[system_idx] + system_props = props[sys_idx] # Calculate expected value - substate = si_double_sim_state[system_idx] + substate = si_double_sim_state[sys_idx] expected = lj_model(substate)["energy"] # Compare file contents with expected @@ -806,7 +797,7 @@ def test_get_atoms_importerror(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) traj = TorchSimTrajectory(tmp_path / "dummy.h5", mode="w") # Write minimal data so get_atoms can be called - state = ts.SimState( + state = SimState( positions=torch.zeros(1, 3), masses=torch.ones(1), cell=torch.eye(3).unsqueeze(0), @@ -830,7 +821,7 @@ def test_write_ase_trajectory_importerror( traj = TorchSimTrajectory(tmp_path / "dummy.h5", mode="w") # Write minimal data so write_ase_trajectory can be called - state = ts.SimState( + state = SimState( positions=torch.zeros(1, 3), masses=torch.ones(1), cell=torch.eye(3).unsqueeze(0), diff --git a/tests/test_transforms.py b/tests/test_transforms.py index c9317cdf9..3b76145e3 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -5,7 +5,8 @@ from ase.geometry import wrap_positions as ase_wrap_positions import torch_sim as ts -import torch_sim.transforms as tst +import torch_sim.transforms as ft +from tests.conftest import DEVICE, DTYPE def test_inverse_box_scalar() -> None: @@ -15,7 +16,7 @@ def test_inverse_box_scalar() -> None: """ # Test scalar inverse x = torch.tensor(2.0) - assert torch.allclose(tst.inverse_box(x), torch.tensor(0.5)) + assert torch.allclose(ft.inverse_box(x), torch.tensor(0.5)) def test_inverse_box_vector() -> None: @@ -26,7 +27,7 @@ def test_inverse_box_vector() -> None: # Test vector inverse x = torch.tensor([2.0, 4.0]) expected = torch.tensor([0.5, 0.25]) - assert torch.allclose(tst.inverse_box(x), expected) + assert torch.allclose(ft.inverse_box(x), expected) def test_inverse_box_matrix() -> None: @@ -37,7 +38,7 @@ def test_inverse_box_matrix() -> None: # Test matrix inverse x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) expected = torch.tensor([[-2.0, 1.0], [1.5, -0.5]]) - assert torch.allclose(tst.inverse_box(x), expected) + assert torch.allclose(ft.inverse_box(x), expected) def test_inverse_box_invalid() -> None: @@ -48,7 +49,7 @@ def test_inverse_box_invalid() -> None: # Test invalid input (3D tensor) x = torch.ones(2, 2, 2) with pytest.raises(ValueError): - tst.inverse_box(x) + ft.inverse_box(x) def test_inverse_box_single_element() -> None: @@ -58,7 +59,7 @@ def test_inverse_box_single_element() -> None: """ # Test single element tensor x = torch.tensor([2.0]) - assert torch.allclose(tst.inverse_box(x), torch.tensor(0.5)) + assert torch.allclose(ft.inverse_box(x), torch.tensor(0.5)) def test_pbc_wrap_general_orthorhombic() -> None: @@ -90,7 +91,7 @@ def test_pbc_wrap_general_orthorhombic() -> None: ] ) - wrapped = tst.pbc_wrap_general(positions, lattice) + wrapped = ft.pbc_wrap_general(positions, lattice) assert torch.allclose(wrapped, expected) @@ -120,7 +121,7 @@ def test_pbc_wrap_general_param(cell: torch.Tensor, shift: torch.Tensor) -> None base_frac = torch.tensor([[0.25, 0.5, 0.75]], dtype=torch.float64) base_cart = base_frac @ cell.T shifted_cart = base_cart + (shift @ cell.T) - wrapped = tst.pbc_wrap_general(shifted_cart, cell) + wrapped = ft.pbc_wrap_general(shifted_cart, cell) torch.testing.assert_close(wrapped, base_cart, rtol=1e-6, atol=1e-6) @@ -141,7 +142,7 @@ def test_pbc_wrap_general_edge_case() -> None: expected = torch.tensor([[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]]) - wrapped = tst.pbc_wrap_general(positions, lattice) + wrapped = ft.pbc_wrap_general(positions, lattice) assert torch.allclose(wrapped, expected) @@ -155,15 +156,15 @@ def test_pbc_wrap_general_invalid_inputs() -> None: """ # Test integer tensors with pytest.raises(TypeError): - tst.pbc_wrap_general(torch.ones(3, dtype=torch.int64), torch.eye(3)) + ft.pbc_wrap_general(torch.ones(3, dtype=torch.int64), torch.eye(3)) # Test non-square lattice with pytest.raises(ValueError): - tst.pbc_wrap_general(torch.ones(3), torch.ones(3, 2)) + ft.pbc_wrap_general(torch.ones(3), torch.ones(3, 2)) # Test dimension mismatch with pytest.raises(ValueError): - tst.pbc_wrap_general(torch.ones(4), torch.eye(3)) + ft.pbc_wrap_general(torch.ones(4), torch.eye(3)) def test_pbc_wrap_general_batch() -> None: @@ -189,7 +190,7 @@ def test_pbc_wrap_general_batch() -> None: ] ) - wrapped = tst.pbc_wrap_general(positions, lattice) + wrapped = ft.pbc_wrap_general(positions, lattice) assert torch.allclose(wrapped, expected) @@ -206,7 +207,7 @@ def test_wrap_positions_matches_ase( cell = torch.eye(3) + 0.1 * torch.randn(3, 3) # Run both implementations - torch_result = tst.wrap_positions( + torch_result = ft.wrap_positions( positions, cell, pbc=pbc, pretty_translation=pretty_translation ) @@ -218,11 +219,13 @@ def test_wrap_positions_matches_ase( def test_wrap_positions_basic(): - pos = torch.tensor([[-0.1, 1.01, -0.5]], dtype=torch.float32) - cell = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 4.0]]) + pos = torch.tensor([[-0.1, 1.01, -0.5]], dtype=torch.float64) + cell = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 4.0]], dtype=torch.float64 + ) - wrapped = tst.wrap_positions(pos, cell, pbc=[True, True, False]) - expected = torch.tensor([[0.9, 0.01, -0.5]]) + wrapped = ft.wrap_positions(pos, cell, pbc=[True, True, False]) + expected = torch.tensor([[0.9, 0.01, -0.5]], dtype=torch.float64) torch.testing.assert_close(wrapped, expected, rtol=1e-6, atol=1e-6) @@ -231,7 +234,7 @@ def test_translate_pretty(): coords = torch.tensor([[0.1, 1.2, -0.3], [0.7, 0.8, 0.9]]) pbc = [True, True, True] - translated = tst.translate_pretty(coords, pbc) + translated = ft.translate_pretty(coords, pbc) # Check that differences between coordinates are preserved orig_diff = (coords[1] - coords[0]) % 1.0 @@ -268,7 +271,7 @@ def test_pbc_wrap_batched_orthorhombic(si_double_sim_state: ts.SimState) -> None test_positions[idx1, 0] = -0.5 # Apply wrapping - wrapped = tst.pbc_wrap_batched( + wrapped = ft.pbc_wrap_batched( test_positions, cell=state.cell, system_idx=state.system_idx ) @@ -281,7 +284,7 @@ def test_pbc_wrap_batched_orthorhombic(si_double_sim_state: ts.SimState) -> None assert wrapped[idx1, 0] >= 0 -def test_pbc_wrap_batched_triclinic(device: torch.device) -> None: +def test_pbc_wrap_batched_triclinic() -> None: """Test batched periodic boundary wrapping with triclinic cell.""" # Define cell matrices (M_row convention) cell1 = torch.tensor( @@ -291,7 +294,7 @@ def test_pbc_wrap_batched_triclinic(device: torch.device) -> None: [0.0, 0.3, 2.0], # c vector with b-tilt ], dtype=torch.float64, - device=device, + device=DEVICE, ) cell2 = torch.tensor( [ @@ -300,7 +303,7 @@ def test_pbc_wrap_batched_triclinic(device: torch.device) -> None: [0.0, 0.0, 2.0], # c vector ], dtype=torch.float64, - device=device, + device=DEVICE, ) cell = torch.stack([cell1, cell2]) @@ -311,29 +314,29 @@ def test_pbc_wrap_batched_triclinic(device: torch.device) -> None: [2.7, 2.7, 2.7], # Atom 1 (batch 1) ], dtype=torch.float64, - device=device, + device=DEVICE, ) - batch = torch.tensor([0, 1], device=device) + batch = torch.tensor([0, 1], device=DEVICE) # Stack the cells for batched processing cell = torch.stack([cell1, cell2]) # Apply wrapping - wrapped = tst.pbc_wrap_batched(positions, cell=cell, system_idx=batch) + wrapped = ft.pbc_wrap_batched(positions, cell=cell, system_idx=batch) # Calculate expected result for first atom (using original algorithm for verification) - expected1 = tst.pbc_wrap_general(positions[0:1], cell1) - expected2 = tst.pbc_wrap_general(positions[1:2], cell2) + expected1 = ft.pbc_wrap_general(positions[0:1], cell1) + expected2 = ft.pbc_wrap_general(positions[1:2], cell2) # Verify results match the expected values assert torch.allclose(wrapped[0:1], expected1, atol=1e-6) assert torch.allclose(wrapped[1:2], expected2, atol=1e-6) -def test_pbc_wrap_batched_edge_case(device: torch.device) -> None: +def test_pbc_wrap_batched_edge_case() -> None: """Test batched boundary wrapping at cell edges.""" # Create two identical cells - cell = torch.eye(3, device=device) * 2.0 + cell = torch.eye(3, device=DEVICE) * 2.0 cell = torch.stack([cell, cell]) # Create positions at cell boundaries @@ -342,14 +345,14 @@ def test_pbc_wrap_batched_edge_case(device: torch.device) -> None: [2.0, 1.0, 0.5], # First atom (batch 0), on +x boundary [1.0, 2.0, 0.5], # Second atom (batch 1), on +y boundary ], - device=device, + device=DEVICE, ) # Create system indices - system_idx = torch.tensor([0, 1], device=device) + system_idx = torch.tensor([0, 1], device=DEVICE) # Apply wrapping - wrapped = tst.pbc_wrap_batched(positions, cell=cell, system_idx=system_idx) + wrapped = ft.pbc_wrap_batched(positions, cell=cell, system_idx=system_idx) # Expected results (wrapping to 0.0 rather than 2.0) expected = torch.tensor( @@ -357,39 +360,39 @@ def test_pbc_wrap_batched_edge_case(device: torch.device) -> None: [0.0, 1.0, 0.5], # x-coordinate wrapped from 2.0 to 0.0 [1.0, 0.0, 0.5], # y-coordinate wrapped from 2.0 to 0.0 ], - device=device, + device=DEVICE, ) # Verify results assert torch.allclose(wrapped, expected) -def test_pbc_wrap_batched_invalid_inputs(device: torch.device) -> None: +def test_pbc_wrap_batched_invalid_inputs() -> None: """Test error handling for invalid inputs in batched wrapping.""" # Valid inputs for reference - positions = torch.ones(4, 3, device=device) - cell = torch.stack([torch.eye(3, device=device)] * 2) - system_idx = torch.tensor([0, 0, 1, 1], device=device) + positions = torch.ones(4, 3, device=DEVICE) + cell = torch.stack([torch.eye(3, device=DEVICE)] * 2) + system_idx = torch.tensor([0, 0, 1, 1], device=DEVICE) # Test integer tensors with pytest.raises(TypeError): - tst.pbc_wrap_batched( - torch.ones(4, 3, dtype=torch.int64, device=device), cell, system_idx + ft.pbc_wrap_batched( + torch.ones(4, 3, dtype=torch.int64, device=DEVICE), cell, system_idx ) # Test dimension mismatch - positions with pytest.raises(ValueError): - tst.pbc_wrap_batched( - torch.ones(4, 2, device=device), # Wrong dimension (2 instead of 3) + ft.pbc_wrap_batched( + torch.ones(4, 2, device=DEVICE), # Wrong dimension (2 instead of 3) cell, system_idx, ) # Test mismatch between system indices and cell with pytest.raises(ValueError): - tst.pbc_wrap_batched( + ft.pbc_wrap_batched( positions, - torch.stack([torch.eye(3, device=device)] * 3), # 3 cell but only 2 batches + torch.stack([torch.eye(3, device=DEVICE)] * 3), # 3 cell but only 2 batches system_idx, ) @@ -412,25 +415,25 @@ def test_pbc_wrap_batched_multi_atom(si_double_sim_state: ts.SimState) -> None: test_positions[system_1_mask, 1] -= cell_size_y # Apply wrapping - wrapped = tst.pbc_wrap_batched( + wrapped = ft.pbc_wrap_batched( test_positions, cell=state.cell, system_idx=state.system_idx ) # Check all positions are within the cell boundaries - for b in range(2): # For each system - system_mask = state.system_idx == b + for sys_idx in range(2): + system_mask = state.system_idx == sys_idx # Check x coordinates assert torch.all(wrapped[system_mask, 0] >= 0) - assert torch.all(wrapped[system_mask, 0] < state.cell[b, 0, 0]) + assert torch.all(wrapped[system_mask, 0] < state.cell[sys_idx, 0, 0]) # Check y coordinates assert torch.all(wrapped[system_mask, 1] >= 0) - assert torch.all(wrapped[system_mask, 1] < state.cell[b, 1, 1]) + assert torch.all(wrapped[system_mask, 1] < state.cell[sys_idx, 1, 1]) # Check z coordinates assert torch.all(wrapped[system_mask, 2] >= 0) - assert torch.all(wrapped[system_mask, 2] < state.cell[b, 2, 2]) + assert torch.all(wrapped[system_mask, 2] < state.cell[sys_idx, 2, 2]) def test_pbc_wrap_batched_preserves_relative_positions( @@ -444,19 +447,19 @@ def test_pbc_wrap_batched_preserves_relative_positions( # Move all atoms outside the cell, but maintain their relative positions test_positions = original_positions.clone() - test_positions += torch.tensor([10.0, 15.0, 20.0], device=state.device) + test_positions += torch.tensor([10.0, 15.0, 20.0], device=DEVICE) # Apply wrapping - wrapped = tst.pbc_wrap_batched( + wrapped = ft.pbc_wrap_batched( test_positions, cell=state.cell, system_idx=state.system_idx ) # Check that relative positions within each system are preserved - for b in range(2): # For each batch - system_idx_mask = state.system_idx == b + for sys_idx in range(2): + system_idx_mask = state.system_idx == sys_idx # Calculate pairwise distances before wrapping - atoms_in_batch = int(torch.sum(system_idx_mask).item()) + atoms_in_batch = torch.sum(system_idx_mask).item() for n_atoms in range(atoms_in_batch - 1): for j in range(n_atoms + 1, atoms_in_batch): # Get the indices of atoms i and j in this batch @@ -467,10 +470,12 @@ def test_pbc_wrap_batched_preserves_relative_positions( # Original vector from i to j orig_vec = ( original_positions[idx_j] - original_positions[idx_i] - ) % state.cell[b].diag() + ) % state.cell[sys_idx].diag() # Vector after wrapping - wrapped_vec = (wrapped[idx_j] - wrapped[idx_i]) % state.cell[b].diag() + wrapped_vec = (wrapped[idx_j] - wrapped[idx_i]) % state.cell[ + sys_idx + ].diag() # Check that relative positions are preserved assert torch.allclose(orig_vec, wrapped_vec, atol=1e-6) @@ -484,7 +489,7 @@ def test_safe_mask_basic() -> None: """ x = torch.tensor([1.0, 2.0, -1.0]) mask = torch.tensor([True, True, False]) - result = tst.safe_mask(mask, torch.log, x) + result = ft.safe_mask(mask, torch.log, x) expected = torch.tensor([0, 0.6931, 0]) torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) @@ -498,7 +503,7 @@ def test_safe_mask_custom_placeholder() -> None: """ x = torch.tensor([1.0, 2.0, -1.0]) mask = torch.tensor([True, False, False]) - result = tst.safe_mask(mask, torch.log, x, placeholder=-999.0) + result = ft.safe_mask(mask, torch.log, x, placeholder=-999.0) expected = torch.tensor([0.0, -999, -999]) torch.testing.assert_close(result, expected) @@ -512,7 +517,7 @@ def test_safe_mask_all_masked() -> None: """ x = torch.tensor([1.0, 2.0, 3.0]) mask = torch.tensor([False, False, False]) - result = tst.safe_mask(mask, torch.log, x) + result = ft.safe_mask(mask, torch.log, x) expected = torch.zeros_like(x) torch.testing.assert_close(result, expected) @@ -526,7 +531,7 @@ def test_safe_mask_none_masked() -> None: """ x = torch.tensor([1.0, 2.0, 3.0]) mask = torch.tensor([True, True, True]) - result = tst.safe_mask(mask, torch.log, x) + result = ft.safe_mask(mask, torch.log, x) expected = torch.log(x) torch.testing.assert_close(result, expected) @@ -542,7 +547,7 @@ def test_safe_mask_shape_mismatch() -> None: mask = torch.tensor([True, False]) with pytest.raises(RuntimeError): - tst.safe_mask(mask, torch.log, x) + ft.safe_mask(mask, torch.log, x) def test_high_precision_sum_float() -> None: @@ -554,7 +559,7 @@ def test_high_precision_sum_float() -> None: 3. The precision is adequate for basic float32 operations """ x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) - result = tst.high_precision_sum(x) + result = ft.high_precision_sum(x) assert result.dtype == torch.float32 expected = torch.tensor(6.0, dtype=torch.float32) torch.testing.assert_close(result, expected) @@ -569,7 +574,7 @@ def test_high_precision_sum_double() -> None: 3. No precision is lost when input is already float64 """ x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64) - result = tst.high_precision_sum(x) + result = ft.high_precision_sum(x) assert result.dtype == torch.float64 expected = torch.tensor(6.0, dtype=torch.float64) torch.testing.assert_close(result, expected) @@ -584,7 +589,7 @@ def test_high_precision_sum_int() -> None: 3. Integer arithmetic is precise and lossless """ x = torch.tensor([1, 2, 3], dtype=torch.int32) - result = tst.high_precision_sum(x) + result = ft.high_precision_sum(x) assert result.dtype == torch.int32 assert result == torch.tensor(6, dtype=torch.int32) @@ -599,7 +604,7 @@ def test_high_precision_sum_complex() -> None: 4. Complex arithmetic is performed at high precision """ x = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.complex64) - result = tst.high_precision_sum(x) + result = ft.high_precision_sum(x) assert result.dtype == torch.complex64 expected = torch.tensor(3 + 3j, dtype=torch.complex64) torch.testing.assert_close(result, expected) @@ -618,7 +623,7 @@ def test_high_precision_sum_dim() -> None: Output shape: (2,) when dim=0 """ x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) - result = tst.high_precision_sum(x, dim=0) + result = ft.high_precision_sum(x, dim=0) expected = torch.tensor([4.0, 6.0], dtype=torch.float32) torch.testing.assert_close(result, expected) @@ -636,7 +641,7 @@ def test_high_precision_sum_keepdim() -> None: Output shape: (1, 2) when dim=0 and keepdim=True """ x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) - result = tst.high_precision_sum(x, dim=0, keepdim=True) + result = ft.high_precision_sum(x, dim=0, keepdim=True) assert result.shape == (1, 2) expected = torch.tensor([[4.0, 6.0]], dtype=torch.float32) torch.testing.assert_close(result, expected) @@ -656,7 +661,7 @@ def test_high_precision_sum_multiple_dims() -> None: Each output element is the sum of 8 numbers (2 * 4 = 8) """ x = torch.ones((2, 3, 4), dtype=torch.float32) - result = tst.high_precision_sum(x, dim=(0, 2)) + result = ft.high_precision_sum(x, dim=(0, 2)) assert result.shape == (3,) expected = torch.tensor([8.0, 8.0, 8.0], dtype=torch.float32) torch.testing.assert_close(result, expected) @@ -672,7 +677,7 @@ def test_high_precision_sum_numerical_stability() -> None: """ # Create a tensor with numbers of very different magnitudes x = torch.tensor([1e-8, 1e8, 1e-8], dtype=torch.float32) - result = tst.high_precision_sum(x) + result = ft.high_precision_sum(x) expected = torch.tensor(1e8 + 2e-8, dtype=torch.float32) torch.testing.assert_close(result, expected, atol=1e-8, rtol=1e-8) @@ -686,7 +691,7 @@ def test_high_precision_sum_empty() -> None: 3. The sum of an empty tensor is 0 of the appropriate type """ x = torch.tensor([], dtype=torch.float32) - result = tst.high_precision_sum(x) + result = ft.high_precision_sum(x) assert result.dtype == torch.float32 assert result == torch.tensor(0.0, dtype=torch.float32) @@ -697,9 +702,7 @@ def test_multiplicative_isotropic_cutoff_basic() -> None: def constant_fn(dr: torch.Tensor) -> torch.Tensor: return torch.ones_like(dr) - cutoff_fn = tst.multiplicative_isotropic_cutoff( - constant_fn, r_onset=torch.tensor(1.0), r_cutoff=torch.tensor(2.0) - ) + cutoff_fn = ft.multiplicative_isotropic_cutoff(constant_fn, r_onset=1.0, r_cutoff=2.0) # Test points in different regions dr = torch.tensor([0.5, 1.5, 2.5]) @@ -716,9 +719,9 @@ def test_multiplicative_isotropic_cutoff_continuity() -> None: def linear_fn(dr: torch.Tensor) -> torch.Tensor: return dr - r_onset = torch.tensor(1.0) - r_cutoff = torch.tensor(2.0) - cutoff_fn = tst.multiplicative_isotropic_cutoff(linear_fn, r_onset, r_cutoff) + r_onset = 1.0 + r_cutoff = 2.0 + cutoff_fn = ft.multiplicative_isotropic_cutoff(linear_fn, r_onset, r_cutoff) # Test near onset dr_before = torch.tensor([r_onset - 1e-5]) @@ -739,11 +742,11 @@ def test_multiplicative_isotropic_cutoff_derivative_continuity() -> None: """Test that the derivative of the cutoff function is continuous.""" def quadratic_fn(dr: torch.Tensor) -> torch.Tensor: - return dr**2 + return torch.pow(dr, 2) - r_onset = torch.tensor(1.0) - r_cutoff = torch.tensor(2.0) - cutoff_fn = tst.multiplicative_isotropic_cutoff(quadratic_fn, r_onset, r_cutoff) + r_onset = 1.0 + r_cutoff = 2.0 + cutoff_fn = ft.multiplicative_isotropic_cutoff(quadratic_fn, r_onset, r_cutoff) # Test derivative near onset and cutoff using finite differences points = torch.tensor([r_onset, r_cutoff], requires_grad=True) @@ -763,8 +766,8 @@ def test_multiplicative_isotropic_cutoff_with_parameters() -> None: def parameterized_fn(dr: torch.Tensor, scale: float) -> torch.Tensor: return scale * dr - cutoff_fn = tst.multiplicative_isotropic_cutoff( - parameterized_fn, r_onset=torch.tensor(1.0), r_cutoff=torch.tensor(2.0) + cutoff_fn = ft.multiplicative_isotropic_cutoff( + parameterized_fn, r_onset=1.0, r_cutoff=2.0 ) dr = torch.tensor([0.5, 1.5, 2.5]) @@ -781,9 +784,7 @@ def test_multiplicative_isotropic_cutoff_batch() -> None: def constant_fn(dr: torch.Tensor) -> torch.Tensor: return torch.ones_like(dr) - cutoff_fn = tst.multiplicative_isotropic_cutoff( - constant_fn, r_onset=torch.tensor(1.0), r_cutoff=torch.tensor(2.0) - ) + cutoff_fn = ft.multiplicative_isotropic_cutoff(constant_fn, r_onset=1.0, r_cutoff=2.0) # Test with 2D input dr = torch.rand(5, 5) * 3.0 @@ -800,9 +801,7 @@ def test_multiplicative_isotropic_cutoff_gradient() -> None: def linear_fn(dr: torch.Tensor) -> torch.Tensor: return dr - cutoff_fn = tst.multiplicative_isotropic_cutoff( - linear_fn, r_onset=torch.tensor(1.0), r_cutoff=torch.tensor(2.0) - ) + cutoff_fn = ft.multiplicative_isotropic_cutoff(linear_fn, r_onset=1.0, r_cutoff=2.0) dr = torch.tensor([1.5], requires_grad=True) result = cutoff_fn(dr) @@ -834,30 +833,28 @@ def test_get_fractional_coordinates( Tests the function with both cubic and non-orthogonal cells. """ - frac = tst.get_fractional_coordinates(torch.tensor(pos), torch.tensor(cell)) + frac = ft.get_fractional_coordinates(torch.tensor(pos), torch.tensor(cell)) torch.testing.assert_close(frac, torch.tensor(expected)) def test_get_fractional_coordinates_batched() -> None: """Test get_fractional_coordinates with batched cell tensors.""" - device = torch.device("cpu") - dtype = torch.float64 positions = torch.tensor( - [[1.0, 1.0, 1.0], [2.0, 0.0, 0.0]], device=device, dtype=dtype + [[1.0, 1.0, 1.0], [2.0, 0.0, 0.0]], device=DEVICE, dtype=DTYPE ) # Test single system case (should work) cell_single_system = torch.tensor( - [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]], device=device, dtype=dtype + [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]], device=DEVICE, dtype=DTYPE ) - frac_batched = tst.get_fractional_coordinates(positions, cell_single_system) + frac_batched = ft.get_fractional_coordinates(positions, cell_single_system) # Compare with 2D case cell_2d = torch.tensor( - [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]], device=device, dtype=dtype + [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]], device=DEVICE, dtype=DTYPE ) - frac_2d = tst.get_fractional_coordinates(positions, cell_2d) + frac_2d = ft.get_fractional_coordinates(positions, cell_2d) assert torch.allclose(frac_batched, frac_2d), ( "Single system case should produce same result as 2D case" @@ -869,12 +866,12 @@ def test_get_fractional_coordinates_batched() -> None: [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]], [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]], ], - device=device, - dtype=dtype, + device=DEVICE, + dtype=DTYPE, ) with pytest.raises(NotImplementedError, match="Multiple system cell tensors"): - tst.get_fractional_coordinates(positions, cell_multi_system) + ft.get_fractional_coordinates(positions, cell_multi_system) @pytest.mark.parametrize( @@ -882,19 +879,19 @@ def test_get_fractional_coordinates_batched() -> None: [ ( [[1.5, 1.5, 1.5], [-1.5, -1.5, -1.5]], - torch.eye(3) * 3.0, + torch.eye(3, dtype=DTYPE) * 3.0, False, [[1.5, 1.5, 1.5], [-1.5, -1.5, -1.5]], ), ( [[1.5, 1.5, 1.5], [-1.5, -1.5, -1.5]], - torch.eye(3) * 3.0, + torch.eye(3, dtype=DTYPE) * 3.0, True, [[1.5, 1.5, 1.5], [-1.5, -1.5, -1.5]], ), ( [[2.2, 0.0, 0.0], [0.0, 2.2, 0.0], [0.0, 0.0, 2.2]], - torch.eye(3) * 2.0, + torch.eye(3, dtype=DTYPE) * 2.0, True, [[0.2, 0.0, 0.0], [0.0, 0.2, 0.0], [0.0, 0.0, 0.2]], ), @@ -907,8 +904,10 @@ def test_minimum_image_displacement( Tests function with and without PBC and with different displacement vectors. """ - result = tst.minimum_image_displacement(dr=torch.tensor(dr), cell=cell, pbc=pbc) - torch.testing.assert_close(result, torch.tensor(expected)) + dr_tensor = torch.tensor(dr, dtype=DTYPE) + cell = torch.tensor(cell, dtype=DTYPE) + result = ft.minimum_image_displacement(dr=dr_tensor, cell=cell, pbc=pbc) + torch.testing.assert_close(result, torch.tensor(expected, dtype=DTYPE)) @pytest.mark.parametrize( @@ -957,7 +956,7 @@ def test_get_pair_displacements( Tests function with and without PBC, with specific pairs, and with explicit shifts. """ - dr, distances = tst.get_pair_displacements( + dr, distances = ft.get_pair_displacements( positions=positions, cell=cell, pbc=pbc, pairs=pairs, shifts=shifts ) @@ -967,10 +966,7 @@ def test_get_pair_displacements( @pytest.mark.parametrize( ("v", "expected"), - [ - ([1, 2, 3], [0, 1, 3, 6]), - ([[1, 2], [3, 4]], [0, 1, 3, 6, 10]), - ], + [([1, 2, 3], [0, 1, 3, 6]), ([[1, 2], [3, 4]], [0, 1, 3, 6, 10])], ) def test_strides_of(v: list[int], expected: list[int]) -> None: """Test strides_of with 1D and 2D tensors. @@ -978,14 +974,14 @@ def test_strides_of(v: list[int], expected: list[int]) -> None: Verifies that the function correctly computes cumulative strides for both 1D and multidimensional tensors. """ - strides = tst.strides_of(torch.tensor(v)) + strides = ft.strides_of(torch.tensor(v)) torch.testing.assert_close(strides, torch.tensor(expected)) def test_strides_of_empty() -> None: """Test strides_of with empty tensor.""" v = torch.tensor([], dtype=torch.int64) - strides = tst.strides_of(v) + strides = ft.strides_of(v) expected = torch.tensor([0], dtype=torch.int64) torch.testing.assert_close(strides, expected) @@ -1027,20 +1023,20 @@ def test_get_number_of_cell_repeats( Tests with different cell sizes, PBC conditions, and batch sizes. """ - num_repeats = tst.get_number_of_cell_repeats(cutoff, cell, pbc) + num_repeats = ft.get_number_of_cell_repeats(cutoff, cell, pbc) # Check shape assert num_repeats.shape == expected_shape # Check specific properties based on test case - if "min_value" in expected_props: - assert torch.all(num_repeats >= expected_props["min_value"]) + if min_value := expected_props.get("min_value"): + assert torch.all(num_repeats >= min_value) if expected_props.get("all_equal"): assert num_repeats[0, 0] == num_repeats[0, 1] == num_repeats[0, 2] - if "zero_dim" in expected_props: - assert num_repeats[0, expected_props["zero_dim"]] == 0 + if zero_dim := expected_props.get("zero_dim"): + assert num_repeats[0, zero_dim] == 0 if expected_props.get("batch_equal"): assert num_repeats[1, 0] == num_repeats[1, 1] == num_repeats[1, 2] @@ -1061,22 +1057,22 @@ def test_get_cell_shift_idx( Tests the function with symmetric, zero, and asymmetric repeats. """ - n_repeats = torch.tensor(num_repeats) - shifts = tst.get_cell_shift_idx(n_repeats, torch.float) + n_repeats = torch.tensor(num_repeats, dtype=torch.float64) + shifts = ft.get_cell_shift_idx(n_repeats, torch.float64) # Check shape assert shifts.shape == expected_shape # Check ranges or exact values - if "min" in expected_range and "max" in expected_range: - assert torch.all(shifts >= expected_range["min"]) - assert torch.all(shifts <= expected_range["max"]) + if (min_val := expected_range.get("min")) and (max_val := expected_range.get("max")): + assert torch.all(shifts >= min_val) + assert torch.all(shifts <= max_val) - if "exact" in expected_range: - torch.testing.assert_close(shifts, torch.tensor(expected_range["exact"])) + if exact := expected_range.get("exact"): + torch.testing.assert_close(shifts, torch.tensor(exact, dtype=torch.float64)) - if "dim_values" in expected_range: - for dim, (min_val, max_val) in expected_range["dim_values"].items(): + if dim_values := expected_range.get("dim_values"): + for dim, (min_val, max_val) in dim_values.items(): assert torch.all(shifts[:, dim] >= min_val) assert torch.all(shifts[:, dim] <= max_val) @@ -1090,7 +1086,7 @@ def test_ravel_3d(idx_3d: list[list[int]], shape: list[int], expected: list[int] Verifies correct conversion of 3D indices to linear indices. """ - linear_idx = tst.ravel_3d(torch.tensor(idx_3d), torch.tensor(shape)) + linear_idx = ft.ravel_3d(torch.tensor(idx_3d), torch.tensor(shape)) torch.testing.assert_close(linear_idx, torch.tensor(expected)) @@ -1105,7 +1101,7 @@ def test_unravel_3d( Verifies correct conversion of linear indices back to 3D indices. """ - idx_3d = tst.unravel_3d(torch.tensor(linear_idx), torch.tensor(shape)) + idx_3d = ft.unravel_3d(torch.tensor(linear_idx), torch.tensor(shape)) torch.testing.assert_close(idx_3d, torch.tensor(expected)) @@ -1114,15 +1110,22 @@ def test_ravel_unravel_3d_roundtrip() -> None: original_idx = torch.tensor([[0, 1, 2], [1, 0, 3], [1, 2, 0]]) shape = torch.tensor([2, 3, 4]) - linear_idx = tst.ravel_3d(original_idx, shape) - reconstructed_idx = tst.unravel_3d(linear_idx, shape) + linear_idx = ft.ravel_3d(original_idx, shape) + reconstructed_idx = ft.unravel_3d(linear_idx, shape) torch.testing.assert_close(reconstructed_idx, original_idx) @pytest.mark.parametrize( ("cell", "pos", "n_bins_s", "expected"), - [(torch.eye(3) * 2.0, [[0.5, 0.5, 0.5], [1.5, 1.5, 1.5]], [2, 2, 2], [0, 7])], + [ + ( + torch.eye(3, dtype=torch.float64) * 2.0, + [[0.5, 0.5, 0.5], [1.5, 1.5, 1.5]], + [2, 2, 2], + [0, 7], + ) + ], ) def test_get_linear_bin_idx( cell: torch.Tensor, pos: list[float], n_bins_s: list[float], expected: list[float] @@ -1131,8 +1134,10 @@ def test_get_linear_bin_idx( Verifies correct calculation of linear bin indices for positions. """ - bin_idx = tst.get_linear_bin_idx(cell, torch.tensor(pos), torch.tensor(n_bins_s)) - torch.testing.assert_close(bin_idx, torch.tensor(expected)) + bin_idx = ft.get_linear_bin_idx( + cell, torch.tensor(pos, dtype=torch.float64), torch.tensor(n_bins_s) + ) + torch.testing.assert_close(bin_idx, torch.tensor(expected, dtype=torch.int64)) def test_scatter_bin_index_basic() -> None: @@ -1142,7 +1147,7 @@ def test_scatter_bin_index_basic() -> None: n_images = 5 bin_index = torch.tensor([0, 0, 1, 2]) - bin_id = tst.scatter_bin_index(n_bins, max_n_atom_per_bin, n_images, bin_index) + bin_id = ft.scatter_bin_index(n_bins, max_n_atom_per_bin, n_images, bin_index) # Check shape and basic properties assert bin_id.shape == torch.Size([3, 2]) @@ -1177,7 +1182,7 @@ def test_compute_distances_with_cell_shifts( Tests with and without cell shifts applied. """ - distances = tst.compute_distances_with_cell_shifts(pos, mapping, cell_shifts) + distances = ft.compute_distances_with_cell_shifts(pos, mapping, cell_shifts) torch.testing.assert_close(distances, expected) @@ -1187,7 +1192,7 @@ def test_compute_cell_shifts_basic() -> None: shifts_idx = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) system_mapping = torch.tensor([0, 0]) - cell_shifts = tst.compute_cell_shifts(cell, shifts_idx, system_mapping) + cell_shifts = ft.compute_cell_shifts(cell, shifts_idx, system_mapping) expected = torch.tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]) torch.testing.assert_close(cell_shifts, expected) @@ -1207,7 +1212,7 @@ def test_get_fully_connected_mapping( i_ids = torch.tensor([0, 1, 2]) shifts_idx = torch.tensor([[0.0, 0.0, 0.0]]) - mapping, shifts = tst.get_fully_connected_mapping( + mapping, shifts = ft.get_fully_connected_mapping( i_ids=i_ids, shifts_idx=shifts_idx, self_interaction=self_interaction ) @@ -1227,7 +1232,7 @@ def test_get_fully_connected_mapping_with_multiple_shifts() -> None: i_ids = torch.tensor([0, 1]) shifts_idx = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) - mapping, shifts = tst.get_fully_connected_mapping( + mapping, shifts = ft.get_fully_connected_mapping( i_ids=i_ids, shifts_idx=shifts_idx, self_interaction=False ) @@ -1244,7 +1249,7 @@ def test_linked_cell_basic() -> None: cutoff = 1.5 num_repeats = torch.tensor([1, 1, 1]) - neigh_atom, neigh_shift_idx = tst.linked_cell( + neigh_atom, _neigh_shift_idx = ft.linked_cell( pos, cell, cutoff, num_repeats, self_interaction=False ) @@ -1274,7 +1279,7 @@ def test_build_linked_cell_neighborhood_basic() -> None: cutoff = 1.5 n_atoms = torch.tensor([2, 2]) - mapping, system_mapping, cell_shifts_idx = tst.build_linked_cell_neighborhood( + mapping, system_mapping, _cell_shifts_idx = ft.build_linked_cell_neighborhood( positions, cell, pbc, cutoff, n_atoms, self_interaction=False ) diff --git a/tests/test_voigt.py b/tests/test_voigt.py index 126f72b04..c490b604b 100644 --- a/tests/test_voigt.py +++ b/tests/test_voigt.py @@ -52,9 +52,9 @@ def test_batch_conversion(): assert result.shape == (2, 3, 3) # Test each batch independently - for i in range(2): - single_result = voigt_6_to_full_3x3_stress(batch_voigt[i]) - assert torch.allclose(result[i], single_result) + for batch_idx in range(2): + single_result = voigt_6_to_full_3x3_stress(batch_voigt[batch_idx]) + assert torch.allclose(result[batch_idx], single_result) def test_symmetry(): diff --git a/tests/workflows/test_a2c.py b/tests/workflows/test_a2c.py index a95ce0e0c..27dd26d7e 100644 --- a/tests/workflows/test_a2c.py +++ b/tests/workflows/test_a2c.py @@ -1,12 +1,10 @@ -from typing import cast - import pytest import torch -from pymatgen.core.composition import Composition +from pymatgen.core import Composition import torch_sim as ts +from tests.conftest import DEVICE, DTYPE from torch_sim.models.soft_sphere import SoftSphereModel -from torch_sim.optimizers import FireState, UnitCellFireState from torch_sim.workflows import a2c @@ -14,18 +12,26 @@ ("positions", "cell", "expected_min_dist"), [ ( - torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]), - torch.eye(3) * 10.0, + torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], device=DEVICE, dtype=DTYPE), + torch.eye(3, device=DEVICE, dtype=DTYPE) * 10.0, 1.0, ), ( - torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]), - torch.eye(3) * 5.0, + torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=DEVICE, dtype=DTYPE), + torch.eye(3, device=DEVICE, dtype=DTYPE) * 5.0, 0.866025, # sqrt(3)/2 ), ( - torch.tensor([[0.0, 0.0, 0.0], [2.9, 0.0, 0.0], [0.0, 0.0, 2.9]]), - torch.tensor([[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]]), + torch.tensor( + [[0.0, 0.0, 0.0], [2.9, 0.0, 0.0], [0.0, 0.0, 2.9]], + device=DEVICE, + dtype=DTYPE, + ), + torch.tensor( + [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]], + device=DEVICE, + dtype=DTYPE, + ), 0.1, # Due to PBC, atoms at 2.9 are closer via boundary ), ], @@ -35,7 +41,9 @@ def test_min_distance( ) -> None: """Test calculation of minimum distance between atoms.""" min_dist = a2c.min_distance(positions, cell) - assert torch.isclose(min_dist, torch.tensor(expected_min_dist), atol=1e-5) + assert torch.isclose( + min_dist, torch.tensor(expected_min_dist, device=DEVICE, dtype=DTYPE), atol=1e-5 + ) @pytest.mark.parametrize( @@ -70,39 +78,38 @@ def test_get_diameter_parametrized( ], ) def test_get_diameter_matrix_parametrized( - composition_str: str, expected_size: int, dtype: torch.dtype, device: torch.device + composition_str: str, expected_size: int, dtype: torch.dtype ) -> None: """Test diameter matrix calculation with different compositions.""" comp = Composition(composition_str) - matrix = a2c.get_diameter_matrix(comp, device=device, dtype=dtype) + matrix = a2c.get_diameter_matrix(comp, device=DEVICE, dtype=dtype) # Check matrix properties assert matrix.shape == (expected_size, expected_size) assert matrix.dtype == dtype - assert matrix.device == device + assert matrix.device == DEVICE assert torch.all(matrix > 0) assert torch.allclose(matrix, matrix.T) # Symmetry -def test_random_packed_structure_basic(device: torch.device) -> None: +def test_random_packed_structure_basic() -> None: """Test basic functionality of random_packed_structure.""" - comp = Composition("Cu4") - cell = torch.eye(3, device=device) * 5.0 + comp: Composition = Composition("Cu4") + cell: torch.Tensor = torch.eye(3, device=DEVICE, dtype=DTYPE) * 5.0 # Test with minimal optimization to ensure state is created - state = a2c.random_packed_structure( + state, _log = a2c.random_packed_structure( composition=comp, cell=cell, seed=42, - # Use a diameter to ensure the state is created - diameter=2.5, + diameter=2.5, # Use a diameter to ensure the state is created max_iter=1, - device=device, + device=DEVICE, + dtype=DTYPE, ) - # Check state properties assert state.positions.shape == (4, 3) - assert state.positions.device == device + assert state.positions.device == DEVICE assert torch.all(state.positions >= 0) assert torch.all(state.positions <= cell[0, 0]) @@ -116,49 +123,40 @@ def test_random_packed_structure_optimization( cell_size: float, diameter: float, max_iter: int, - device: torch.device, ) -> None: """Test random_packed_structure with optimization.""" comp = Composition(composition_str) - cell = torch.eye(3, device=device) * cell_size + cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * cell_size - log = [] - result = a2c.random_packed_structure( + state, log = a2c.random_packed_structure( composition=comp, cell=cell, seed=42, diameter=diameter, max_iter=max_iter, - device=device, - log=log, + device=DEVICE, + dtype=DTYPE, ) - # Handle the case where a tuple is returned when log is provided - if isinstance(result, tuple): - state, _ = result - else: - state = result - # Check that optimization happened assert len(log) > 0 assert state.energy is not None -def test_random_packed_structure_auto_diameter(device: torch.device) -> None: +def test_random_packed_structure_auto_diameter() -> None: """Test random_packed_structure with auto_diameter option.""" comp = Composition("Cu4") - cell = torch.eye(3, device=device) * 6.0 + cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 6.0 - state = a2c.random_packed_structure( + state, _log = a2c.random_packed_structure( composition=comp, cell=cell, seed=42, auto_diameter=True, max_iter=3, - device=device, + device=DEVICE, + dtype=DTYPE, ) - state = cast("FireState", state) - # Just check that it ran without errors assert state.positions is not None assert state.energy is not None @@ -171,30 +169,30 @@ def test_random_packed_structure_auto_diameter(device: torch.device) -> None: "initial_energy", "final_energy", "e_tol", - "fe_lower_limit", + "e_form_lower_limit", "fe_upper_limit", "fusion_distance", "expected", ), [ ( - torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]), - torch.eye(3) * 5.0, + torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]], device=DEVICE, dtype=DTYPE), + torch.eye(3, device=DEVICE) * 5.0, *(0.0, -1.0, 0.001, -5.0, 0.0, 1.5, False), ), ( # Invalid - no energy decrease - torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]), - torch.eye(3) * 5.0, + torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]], device=DEVICE, dtype=DTYPE), + torch.eye(3, device=DEVICE) * 5.0, *(-1.0, -1.0, 0.001, -5.0, 0.0, 1.5, False), ), ( # Invalid - energy too low - torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]), - torch.eye(3) * 5.0, + torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]], device=DEVICE, dtype=DTYPE), + torch.eye(3, device=DEVICE) * 5.0, *(0.0, -10.0, 0.001, -5.0, 0.0, 1.5, False), ), ( # Invalid - atoms too close - torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]), - torch.eye(3) * 5.0, + torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], device=DEVICE, dtype=DTYPE), + torch.eye(3, device=DEVICE) * 5.0, *(0.0, -1.0, 0.001, -5.0, 0.0, 1.5, False), ), ], @@ -206,7 +204,7 @@ def test_valid_subcell( initial_energy: float, final_energy: float, e_tol: float, - fe_lower_limit: float, + e_form_lower_limit: float, fe_upper_limit: float, fusion_distance: float, expected: bool, @@ -219,7 +217,7 @@ def test_valid_subcell( initial_energy=initial_energy, final_energy=final_energy, e_tol=e_tol, - fe_lower_limit=fe_lower_limit, + e_form_lower_limit=e_form_lower_limit, fe_upper_limit=fe_upper_limit, fusion_distance=fusion_distance, ) @@ -242,7 +240,6 @@ def test_get_subcells_to_crystallize_parametrized( n_min: int, n_max: int, should_find_candidates: bool, - device: torch.device, ) -> None: """Test subcell candidate extraction with different parameters.""" frac_positions = torch.tensor( @@ -253,7 +250,8 @@ def test_get_subcells_to_crystallize_parametrized( [0.6, 0.6, 0.6], [0.8, 0.8, 0.8], ], - device=device, + device=DEVICE, + dtype=DTYPE, ) species = ["Cu", "Cu", "O", "O", "O"] @@ -278,7 +276,6 @@ def test_get_subcells_with_max_coeff( max_coeff: int, elements: list[str], expected_min_candidates: int, - device: torch.device, ) -> None: """Test subcell extraction with max_coeff parameter.""" frac_positions = torch.tensor( @@ -289,7 +286,7 @@ def test_get_subcells_with_max_coeff( [0.6, 0.6, 0.6], [0.8, 0.8, 0.8], ], - device=device, + device=DEVICE, ) species = ["Cu", "Cu", "O", "O", "O"] @@ -332,7 +329,7 @@ def test_get_target_temperature_parametrized( def create_test_model( - *, device: torch.device, compute_stress: bool = True + *, device: torch.device, compute_stress: bool = True, dtype: torch.dtype = DTYPE ) -> SoftSphereModel: """Create a simple soft sphere model for testing.""" return SoftSphereModel( @@ -342,33 +339,33 @@ def create_test_model( device=device, compute_forces=True, compute_stress=compute_stress, + dtype=dtype, ) def create_test_state(positions: torch.Tensor, cell: torch.Tensor) -> ts.SimState: """Create a simple simulation state for testing.""" n_atoms = positions.shape[0] - device = positions.device return ts.SimState( positions=positions, cell=cell, pbc=True, - masses=torch.ones(n_atoms, device=device), - atomic_numbers=torch.ones(n_atoms, device=device, dtype=torch.long), + masses=torch.ones(n_atoms, device=positions.device, dtype=positions.dtype), + atomic_numbers=torch.ones(n_atoms, device=positions.device, dtype=torch.long), ) @pytest.mark.parametrize("max_iter", [1, 2, 3]) -def test_get_unit_cell_relaxed_structure(max_iter: int, device: torch.device) -> None: +def test_get_unit_cell_relaxed_structure(max_iter: int) -> None: """Test unit cell relaxation with FIRE algorithm.""" # Create a simple test system positions = torch.tensor( - [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 1.5, 0.0]], device=device + [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 1.5, 0.0]], device=DEVICE, dtype=DTYPE ) - cell = torch.eye(3, device=device) * 5.0 + cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 5.0 # Create model and state - model = create_test_model(device=device) + model = create_test_model(device=DEVICE) state = create_test_state(positions, cell) # Run relaxation with minimal steps @@ -377,7 +374,34 @@ def test_get_unit_cell_relaxed_structure(max_iter: int, device: torch.device) -> ) # Basic checks - assert isinstance(relaxed_state, UnitCellFireState) + assert isinstance(relaxed_state, ts.FireState) + assert logger["energy"].shape[0] == max_iter + assert isinstance(final_energy[0], float) + assert isinstance(final_pressure[0], float) + + +@pytest.mark.parametrize("max_iter", [1, 2, 3]) +def test_get_frechet_cell_relaxed_structure(max_iter: int) -> None: + """Test unit cell relaxation with FIRE algorithm.""" + # Create a simple test system + positions = torch.tensor( + [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 1.5, 0.0]], device=DEVICE, dtype=DTYPE + ) + cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 5.0 + + # Create model and state + model = create_test_model(device=DEVICE) + state = create_test_state(positions, cell) + + # Run relaxation with minimal steps + relaxed_state, logger, final_energy, final_pressure = ( + a2c.get_frechet_cell_relaxed_structure( + state=state, model=model, max_iter=max_iter + ) + ) + + # Basic checks + assert isinstance(relaxed_state, ts.FireState) assert logger["energy"].shape[0] == max_iter assert isinstance(final_energy[0], float) assert isinstance(final_pressure[0], float) @@ -389,12 +413,12 @@ def test_get_unit_cell_relaxed_structure(max_iter: int, device: torch.device) -> ids=["Equal number of positions and species", "Larger system"], ) def test_subcells_to_structures_parametrized( - n_positions: int, n_species: int, cell_size: float, device: torch.device + n_positions: int, n_species: int, cell_size: float ) -> None: """Test subcell extraction and conversion with various parameters.""" # Create test data with varying sizes - frac_positions = torch.rand((n_positions, 3), device=device) - cell = torch.eye(3, device=device) * cell_size + frac_positions = torch.rand((n_positions, 3), device=DEVICE) + cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * cell_size # Create alternating Cu/O species list species = ["Cu" if idx % 2 == 0 else "O" for idx in range(n_species)] @@ -408,13 +432,13 @@ def test_subcells_to_structures_parametrized( # Check output format assert len(structures) == len(candidates) - for pos, subcell, spec in structures: + for pos, subcell, species in structures: assert isinstance(pos, torch.Tensor) assert isinstance(subcell, torch.Tensor) - assert isinstance(spec, list) + assert isinstance(species, list) assert pos.shape[1] == 3 # 3D positions assert subcell.shape == (3, 3) # 3x3 cell matrix - assert all(isinstance(s, str) for s in spec) # Species strings + assert all(isinstance(s, str) for s in species) # Species strings # Ensure positions are in [0,1] range (fractional coordinates) assert torch.all(pos >= 0.0) @@ -423,20 +447,19 @@ def test_subcells_to_structures_parametrized( def test_subcells_to_structures_ensures_proper_scaling() -> None: """Test that subcells_to_structures properly scales the positions and cell.""" - device = torch.device("cpu") # Create test data with a known grid of points frac_positions = torch.tensor( - [[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]], device=device + [[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]], device=DEVICE, dtype=DTYPE ) - cell = torch.eye(3, device=device) * 10.0 + cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 10.0 species = ["Cu", "Cu", "O"] # Create a candidate with known bounds - ids = torch.tensor([0, 1]) # First two atoms - lower = torch.tensor([0.0, 0.0, 0.0]) - upper = torch.tensor([0.25, 0.25, 0.25]) + ids = torch.tensor([0, 1], device=DEVICE, dtype=torch.long) # First two atoms + lower = torch.tensor([0.0, 0.0, 0.0], device=DEVICE, dtype=DTYPE) + upper = torch.tensor([0.25, 0.25, 0.25], device=DEVICE, dtype=DTYPE) candidates = [(ids, lower, upper)] # Convert to structures @@ -447,14 +470,14 @@ def test_subcells_to_structures_ensures_proper_scaling() -> None: # Check that positions are rescaled to [0,1] range assert torch.allclose( - subcell_pos[0], torch.tensor([0.4, 0.4, 0.4]) + subcell_pos[0], torch.tensor([0.4, 0.4, 0.4], device=DEVICE, dtype=DTYPE) ) # (0.1-0.0)/0.25 = 0.4 assert torch.allclose( - subcell_pos[1], torch.tensor([0.8, 0.8, 0.8]) + subcell_pos[1], torch.tensor([0.8, 0.8, 0.8], device=DEVICE, dtype=DTYPE) ) # (0.2-0.0)/0.25 = 0.8 # Check that cell is scaled properly - expected_cell = torch.eye(3) * 2.5 # 10.0 * 0.25 = 2.5 + expected_cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 2.5 # 10.0 * 0.25 = 2.5 assert torch.allclose(subcell, expected_cell) # Check species are correct @@ -463,7 +486,7 @@ def test_subcells_to_structures_ensures_proper_scaling() -> None: @pytest.mark.parametrize("restrict_to_compositions", [["CuO"], ["Cu2O", "CuO2"]]) def test_get_subcells_with_composition_restrictions( - restrict_to_compositions: list[str], device: torch.device + restrict_to_compositions: list[str], ) -> None: """Test subcell extraction with composition restrictions.""" frac_positions = torch.tensor( @@ -475,7 +498,8 @@ def test_get_subcells_with_composition_restrictions( [0.6, 0.6, 0.6], [0.8, 0.8, 0.8], ], - device=device, + device=DEVICE, + dtype=DTYPE, ) species = ["Cu", "Cu", "Cu", "O", "O", "O"] @@ -499,8 +523,7 @@ def test_get_subcells_with_composition_restrictions( def test_get_subcells_to_crystallize_invalid_inputs() -> None: """Test invalid inputs for subcell extraction.""" - device = torch.device("cpu") - frac_positions = torch.tensor([[0.1, 0.1, 0.1]], device=device) + frac_positions = torch.tensor([[0.1, 0.1, 0.1]], device=DEVICE, dtype=DTYPE) species = ["Cu"] # Test with max_coeff but no elements diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index 49d3e380e..ff8e40c93 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -1,10 +1,10 @@ """TorchSim package base module.""" # ruff: noqa: F401 - import os from datetime import datetime +import torch_sim as ts from torch_sim import ( autobatching, elastic, @@ -22,22 +22,52 @@ units, ) from torch_sim.autobatching import BinningAutoBatcher, InFlightAutoBatcher -from torch_sim.integrators import npt_langevin, nve, nvt_langevin - -# state propagators -from torch_sim.monte_carlo import swap_monte_carlo +from torch_sim.integrators import ( + INTEGRATOR_REGISTRY, + MdFlavor, + NVTNoseHooverState, + nve_init, + nve_update, + nvt_langevin_init, + nvt_langevin_update, + nvt_nose_hoover_init, + nvt_nose_hoover_invariant, + nvt_nose_hoover_update, +) +from torch_sim.integrators.npt import ( + NPTLangevinState, + NPTNoseHooverState, + npt_langevin_init, + npt_langevin_update, + npt_nose_hoover_init, + npt_nose_hoover_invariant, + npt_nose_hoover_update, +) +from torch_sim.monte_carlo import SwapMCState, swap_mc_init, swap_mc_step from torch_sim.optimizers import ( - frechet_cell_fire, - gradient_descent, - unit_cell_fire, - unit_cell_gradient_descent, + OPTIM_REGISTRY, + FireState, + OptimFlavor, + OptimState, + fire_init, + fire_step, + gradient_descent_init, + gradient_descent_step, +) +from torch_sim.optimizers.cell_filters import ( + CELL_FILTER_REGISTRY, + CellFilter, + CellFireState, + CellOptimState, + get_cell_filter, ) - -# quantities/properties from torch_sim.properties.correlations import CorrelationCalculator -from torch_sim.quantities import calc_kinetic_energy, calc_kT - -# high level runners and support +from torch_sim.quantities import ( + calc_kinetic_energy, + calc_kT, + get_pressure, + system_wise_max_force, +) from torch_sim.runners import ( generate_energy_convergence_fn, generate_force_convergence_fn, @@ -45,8 +75,6 @@ optimize, static, ) - -# state and state manipulation from torch_sim.state import SimState, concatenate_states, initialize_state from torch_sim.trajectory import TorchSimTrajectory, TrajectoryReporter diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 83c341bf1..0a0966830 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -11,7 +11,7 @@ batcher = BinningAutoBatcher(model, memory_scales_with="n_atoms") batcher.load_states(states) final_states = [] - for batch in batcher: + for batch, _indices in batcher: final_states.append(evolve_batch(batch)) final_states = batcher.restore_original_order(final_states) @@ -20,33 +20,34 @@ model architectures and GPU configurations. """ -from collections.abc import Callable, Iterator +from collections.abc import Callable, Iterator, Sequence from itertools import chain from typing import Any, get_args import torch +import torch_sim as ts from torch_sim.models.interface import ModelInterface -from torch_sim.state import SimState, concatenate_states +from torch_sim.state import SimState from torch_sim.typing import MemoryScaling -def to_constant_volume_bins( # noqa: C901, PLR0915 - items: dict[int, float] | list[float] | list[tuple], +def to_constant_volume_bins[T: dict[int, float] | list[float] | list[tuple[T, ...]]]( # noqa: C901, PLR0915 + items: T, max_volume: float, *, weight_pos: int | None = None, - key: Callable | None = None, + key: Callable[[T], float] | None = None, lower_bound: float | None = None, upper_bound: float | None = None, -) -> list[dict[int, float]] | list[list[float]] | list[list[tuple]]: +) -> list[T]: """Distribute items into bins of fixed maximum volume. Groups items into the minimum number of bins possible while ensuring each bin's total weight does not exceed max_volume. Items are sorted by weight in descending order before binning to improve packing efficiency. - Upstreamed from binpacking by @benmaier. https://pypi.org/project/binpacking/. + Ported here from binpacking by @benmaier. https://pypi.org/project/binpacking. Args: items (dict[int, float] | list[float] | list[tuple]): Items to distribute, @@ -83,36 +84,33 @@ def _get_bins(lst: list[float], ndx: list[int]) -> list[float]: def _argmax_bins(lst: list[float]) -> int: return max(range(len(lst)), key=lst.__getitem__) - def _revargsort_bins(lst: list[float]) -> list[int]: + def _rev_argsort_bins(lst: list[float]) -> list[int]: return sorted(range(len(lst)), key=lambda i: -lst[i]) - is_dict = isinstance(items, dict) - if not hasattr(items, "__len__"): raise TypeError("d must be iterable") - if not is_dict and hasattr(items[0], "__len__"): + if not isinstance(items, dict) and hasattr(items[0], "__len__"): if weight_pos is not None: key = lambda x: x[weight_pos] # noqa: E731 if key is None: raise ValueError("Must provide weight_pos or key for tuple list") - if not is_dict and key: + if not isinstance(items, dict) and key: new_dict = dict(enumerate(items)) - items = {i: key(val) for i, val in enumerate(items)} - is_dict = True + items = {idx: key(val) for idx, val in enumerate(items)} # type: ignore[invalid-assignment] is_tuple_list = True else: is_tuple_list = False - if is_dict: + if isinstance(items, dict): # get keys and values (weights) keys_vals = items.items() keys = [k for k, v in keys_vals] vals = [v for k, v in keys_vals] # sort weights decreasingly - n_dcs = _revargsort_bins(vals) + n_dcs = _rev_argsort_bins(vals) weights = _get_bins(vals, n_dcs) keys = _get_bins(keys, n_dcs) @@ -140,7 +138,7 @@ def _revargsort_bins(lst: list[float]) -> list[int]: weights = _get_bins(weights, valid_ndcs) - if is_dict: + if isinstance(items, dict): keys = _get_bins(keys, valid_ndcs) # prepare array containing the current weight of the bins @@ -148,7 +146,7 @@ def _revargsort_bins(lst: list[float]) -> list[int]: # iterate through the weight list, starting with heaviest for item, weight in enumerate(weights): - if is_dict: + if isinstance(items, dict): key = keys[item] # find candidate bins where the weight might fit @@ -170,7 +168,7 @@ def _revargsort_bins(lst: list[float]) -> list[int]: # open a new bin b = len(weight_sum) weight_sum.append(0.0) - if is_dict: + if isinstance(items, dict): bins.append({}) else: bins.append([]) @@ -180,7 +178,7 @@ def _revargsort_bins(lst: list[float]) -> list[int]: b = 0 # put it in - if is_dict: + if isinstance(items, dict): bins[b][key] = weight else: bins[b].append(weight) @@ -192,10 +190,10 @@ def _revargsort_bins(lst: list[float]) -> list[int]: if not is_tuple_list: return bins new_bins = [] - for b in range(len(bins)): + for bin_idx in range(len(bins)): new_bins.append([]) - for _key in bins[b]: - new_bins[b].append(new_dict[_key]) + for _key in bins[bin_idx]: + new_bins[bin_idx].append(new_dict[_key]) return new_bins @@ -292,16 +290,16 @@ def determine_max_batch_size( while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < max_atoms: sizes.append(next_size) - for i in range(len(sizes)): - n_systems = sizes[i] - concat_state = concatenate_states([state] * n_systems) + for sys_idx in range(len(sizes)): + n_systems = sizes[sys_idx] + concat_state = ts.concatenate_states([state] * n_systems) try: measure_model_memory_forward(concat_state, model) except RuntimeError as exc: if "CUDA out of memory" in str(exc): # Return the last successful size, with a safety margin - return sizes[max(0, i - 2)] + return sizes[max(0, sys_idx - 2)] raise return sizes[-1] @@ -359,7 +357,7 @@ def calculate_memory_scaler( def estimate_max_memory_scaler( model: ModelInterface, state_list: list[SimState], - metric_values: list[float], + metric_values: list[float] | torch.Tensor, **kwargs: Any, ) -> float: """Estimate maximum memory scaling metric that fits in GPU memory. @@ -411,10 +409,13 @@ def estimate_max_memory_scaler( min_state_max_batches = determine_max_batch_size(min_state, model, **kwargs) max_state_max_batches = determine_max_batch_size(max_state, model, **kwargs) - return min(min_state_max_batches * min_metric, max_state_max_batches * max_metric) + return min( + min_state_max_batches * min_metric.item(), + max_state_max_batches * max_metric.item(), + ) -class BinningAutoBatcher: +class BinningAutoBatcher[T: SimState]: """Batcher that groups states into bins of similar computational cost. Divides a collection of states into batches that can be processed efficiently @@ -430,7 +431,6 @@ class BinningAutoBatcher: memory_scales_with (str): Metric type used for memory estimation. max_memory_scaler (float): Maximum memory metric allowed per system. max_atoms_to_try (int): Maximum number of atoms to try when estimating memory. - return_indices (bool): Whether to return original indices with batches. state_slices (list[SimState]): Individual states to be batched. memory_scalers (list[float]): Memory scaling metrics for each state. index_to_scaler (dict): Mapping from state index to its scaling metric. @@ -449,7 +449,7 @@ class BinningAutoBatcher: # Load states and process them in batches batcher.load_states(states) final_states = [] - for batch in batcher: + for batch, _indices in batcher: final_states.append(evolve_batch(batch)) # Restore original order @@ -462,7 +462,6 @@ def __init__( *, memory_scales_with: MemoryScaling = "n_atoms_x_density", max_memory_scaler: float | None = None, - return_indices: bool = False, max_atoms_to_try: int = 500_000, memory_scaling_factor: float = 1.6, max_memory_padding: float = 1.0, @@ -479,29 +478,23 @@ def __init__( Defaults to "n_atoms_x_density". max_memory_scaler (float | None): Maximum metric value allowed per system. If None, will be automatically estimated. Defaults to None. - return_indices (bool): Whether to return original indices along with batches. - Defaults to False. max_atoms_to_try (int): Maximum number of atoms to try when estimating max_memory_scaler. Defaults to 500,000. memory_scaling_factor (float): Factor to multiply batch size by in each iteration. Larger values will get a batch size more quickly, smaller values will get a more accurate limit. Must be greater than 1. Defaults to 1.6. - max_memory_padding (float): Multiply the autodetermined max_memory_scaler + max_memory_padding (float): Multiply the auto-determined max_memory_scaler by this value to account for fluctuations in max memory. Defaults to 1.0. """ self.max_memory_scaler = max_memory_scaler self.max_atoms_to_try = max_atoms_to_try self.memory_scales_with = memory_scales_with - self.return_indices = return_indices self.model = model self.memory_scaling_factor = memory_scaling_factor self.max_memory_padding = max_memory_padding - def load_states( - self, - states: list[SimState] | SimState, - ) -> float: + def load_states(self, states: T | Sequence[T]) -> float: """Load new states into the batcher. Processes the input states, computes memory scaling metrics for each, @@ -509,7 +502,7 @@ def load_states( to maximize GPU utilization. Args: - states (list[SimState] | SimState): Collection of states to batch. Either a + states (SimState | list[SimState]): Collection of states to batch. Either a list of individual SimState objects or a single batched SimState that will be split into individual states. Each SimState has shape information specific to its instance. @@ -565,97 +558,79 @@ def load_states( ) self.batched_states = [] for index_bin in self.index_bins: - self.batched_states.append([self.state_slices[i] for i in index_bin]) + self.batched_states.append([self.state_slices[idx] for idx in index_bin]) self.current_state_bin = 0 return self.max_memory_scaler - def next_batch( - self, *, return_indices: bool = False - ) -> SimState | tuple[SimState, list[int]] | None: + def next_batch(self) -> tuple[T | None, list[int]]: """Get the next batch of states. Returns batches sequentially until all states have been processed. Each batch contains states grouped together to maximize GPU utilization without exceeding memory constraints. - Args: - return_indices (bool): Whether to return original indices along with the - batch. Overrides the value set during initialization. Defaults to False. - Returns: - SimState | tuple[SimState, list[int]] | None: - - If return_indices is False: A concatenated SimState containing the next - batch of states, or None if no more batches. - - If return_indices is True: Tuple of (concatenated SimState, indices), - where indices are the original positions of the states, or None if no - more batches. - - Examples: - Get batches one by one: + tuple[T | None, list[int]]: A tuple containing: + - A concatenated SimState containing the next batch of states, + or None if no more batches + - List of indices of states in the current batch - .. code-block:: python - - all_converged_state, convergence = [], None - while (result := batcher.next_batch(state, convergence))[0] is not None: - state, converged_states = result - all_converged_states.extend(converged_states) + Example:: - evolve_batch(state) - convergence = convergence_criterion(state) - else: - all_converged_states.extend(result[1]) + # Get batches one by one + for batch, indices in batcher: + process_batch(batch) """ # TODO: need to think about how this intersects with reporting too # TODO: definitely a clever treatment to be done with iterators here if self.current_state_bin < len(self.batched_states): state_bin = self.batched_states[self.current_state_bin] - state = concatenate_states(state_bin) + state = ts.concatenate_states(state_bin) + indices = ( + self.index_bins[self.current_state_bin] + if self.current_state_bin < len(self.index_bins) + else [] + ) self.current_state_bin += 1 - if return_indices: - return state, self.index_bins[self.current_state_bin - 1] - return state - return None + return state, indices + return None, [] - def __iter__(self) -> Iterator[SimState | tuple[SimState, list[int]]]: + def __iter__(self) -> Iterator[tuple[T, list[int]]]: """Return self as an iterator. Allows using the batcher in a for loop to iterate through all batches. Resets the current state bin index to start iteration from the beginning. Returns: - Iterator[SimState | tuple[SimState, list[int]]]: Self as an iterator. + Iterator[tuple[T, list[int]]]: Self as an iterator. Example:: # Iterate through all batches - for batch in batcher: + for batch, indices in batcher: process_batch(batch) - """ return self - def __next__(self) -> SimState | tuple[SimState, list[int]]: + def __next__(self) -> tuple[T, list[int]]: """Get the next batch for iteration. Implements the iterator protocol to allow using the batcher in a for loop. - Automatically includes indices if return_indices was set to True during - initialization. Returns: - SimState | tuple[SimState, list[int]]: The next batch of states, - potentially with indices. + tuple[T, list[int]]: The next batch of states and their indices. Raises: StopIteration: When there are no more batches. """ - next_batch = self.next_batch(return_indices=self.return_indices) + next_batch, indices = self.next_batch() if next_batch is None: raise StopIteration - return next_batch + return next_batch, indices - def restore_original_order(self, batched_states: list[SimState]) -> list[SimState]: + def restore_original_order(self, batched_states: Sequence[T]) -> list[T]: """Reorder processed states back to their original sequence. Takes states that were processed in batches and restores them to the @@ -663,7 +638,7 @@ def restore_original_order(self, batched_states: list[SimState]) -> list[SimStat processing to ensure results correspond to the input states. Args: - batched_states (list[SimState]): State batches to reorder. These can be + batched_states (Sequence[SimState]): State batches to reorder. These can be either concatenated batch states that will be split, or already split individual states. @@ -679,7 +654,7 @@ def restore_original_order(self, batched_states: list[SimState]) -> list[SimStat # Process batches and restore original order results = [] - for batch in batcher: + for batch, _indices in batcher: results.append(process_batch(batch)) ordered_results = batcher.restore_original_order(results) @@ -701,7 +676,7 @@ def restore_original_order(self, batched_states: list[SimState]) -> list[SimStat return [state for _, state in sorted(indexed_states, key=lambda x: x[0])] -class InFlightAutoBatcher: +class InFlightAutoBatcher[T: SimState]: """Batcher that dynamically swaps states based on convergence. Optimizes GPU utilization by removing converged states from the batch and @@ -720,7 +695,6 @@ class InFlightAutoBatcher: memory_scales_with (str): Metric type used for memory estimation. max_memory_scaler (float): Maximum memory metric allowed per system. max_atoms_to_try (int): Maximum number of atoms to try when estimating memory. - return_indices (bool): Whether to return original indices with batches. max_iterations (int | None): Maximum number of iterations per state. state_slices (list[SimState]): Individual states to be batched. memory_scalers (list[float]): Memory scaling metrics for each state. @@ -764,7 +738,6 @@ def __init__( max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, memory_scaling_factor: float = 1.6, - return_indices: bool = False, max_iterations: int | None = None, max_memory_padding: float = 1.0, ) -> None: @@ -780,8 +753,6 @@ def __init__( Defaults to "n_atoms_x_density". max_memory_scaler (float | None): Maximum metric value allowed per system. If None, will be automatically estimated. Defaults to None. - return_indices (bool): Whether to return original indices along with batches. - Defaults to False. max_atoms_to_try (int): Maximum number of atoms to try when estimating max_memory_scaler. Defaults to 500,000. memory_scaling_factor (float): Factor to multiply batch size by in each @@ -791,7 +762,7 @@ def __init__( max_iterations (int | None): Maximum number of iterations to process a state before considering it complete, regardless of convergence. Used to prevent infinite loops. Defaults to None (no limit). - max_memory_padding (float): Multiply the autodetermined max_memory_scaler + max_memory_padding (float): Multiply the auto-determined max_memory_scaler by this value to account for fluctuations in max memory. Defaults to 1.0. """ self.model = model @@ -799,14 +770,10 @@ def __init__( self.max_memory_scaler = max_memory_scaler or None self.max_atoms_to_try = max_atoms_to_try self.memory_scaling_factor = memory_scaling_factor - self.return_indices = return_indices self.max_attempts = max_iterations # TODO: change to max_iterations self.max_memory_padding = max_memory_padding - def load_states( - self, - states: list[SimState] | Iterator[SimState] | SimState, - ) -> None: + def load_states(self, states: Sequence[T] | Iterator[T] | T) -> None: """Load new states into the batcher. Processes the input states, computes memory scaling metrics for each, @@ -841,7 +808,7 @@ def load_states( """ if isinstance(states, SimState): states = states.split() - if isinstance(states, list): + if isinstance(states, list | tuple): states = iter(states) self.states_iterator = states @@ -857,7 +824,7 @@ def load_states( self._first_batch = self._get_first_batch() return self.max_memory_scaler - def _get_next_states(self) -> list[SimState]: + def _get_next_states(self) -> list[T]: """Add states from the iterator until max_memory_scaler is reached. Pulls states from the iterator and adds them to the current batch until @@ -866,9 +833,9 @@ def _get_next_states(self) -> list[SimState]: Returns: list[SimState]: new states added to the batch. """ - new_metrics = [] - new_idx = [] - new_states = [] + new_metrics: list[float] = [] + new_idx: list[int] = [] + new_states: list[T] = [] for state in self.states_iterator: metric = calculate_memory_scaler(state, self.memory_scales_with) if metric > self.max_memory_scaler: @@ -914,14 +881,14 @@ def _delete_old_states(self, completed_idx: list[int]) -> None: self.current_scalers.pop(idx) self.completed_idx_og_order.append(og_idx) - def _get_first_batch(self) -> SimState: + def _get_first_batch(self) -> T: """Create and return the first batch of states. Initializes the batcher by estimating memory requirements if needed and creating the first batch of states to process. Returns: - Tuple of (first batch, empty list of completed states). + T: first batch of states. """ # we need to sample a state and use it to estimate the max metric # for the first batch @@ -956,14 +923,11 @@ def _get_first_batch(self) -> SimState: self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding newer_states = self._get_next_states() states = [*states, *newer_states] - return concatenate_states([first_state, *states]) + return ts.concatenate_states([first_state, *states]) def next_batch( # noqa: C901 - self, updated_state: SimState | None, convergence_tensor: torch.Tensor | None - ) -> ( - tuple[SimState | None, list[SimState]] - | tuple[SimState | None, list[SimState], list[int]] - ): + self, updated_state: T | None, convergence_tensor: torch.Tensor | None + ) -> tuple[T, list[T]]: """Get the next batch of states based on convergence. Removes converged states from the batch, adds new states if possible, @@ -979,32 +943,25 @@ def next_batch( # noqa: C901 (False). Should be None only for the first call. Returns: - tuple[SimState | None, list[SimState]] | tuple[SimState | None, - list[SimState], list[int]]: - - If return_indices is False: Tuple of (next_batch, completed_states) - where next_batch is a SimState or None if all states are processed, - and completed_states is a list of SimState objects. - - If return_indices is True: Tuple of (next_batch, completed_states, - indices) where indices are the current batch's positions. + tuple[SimState | None, list[SimState]]: (next_batch, completed_states) + where next_batch is a SimState or None if all states are processed, + and completed_states is a list of SimState objects. Raises: AssertionError: If convergence_tensor doesn't match the expected shape or if other validation checks fail. - Examples: - Process states with convergence checking: - - .. code-block:: python + Example:: - # Initial call - batch, completed = batcher.next_batch(None, None) + # Initial call + batch, completed = batcher.next_batch(None, None) - # Process batch and check for convergence - batch = process_batch(batch) - convergence = check_convergence(batch) + # Process batch and check for convergence + batch = process_batch(batch) + convergence = check_convergence(batch) - # Get next batch with converged states removed and new states added - batch, completed = batcher.next_batch(batch, convergence) + # Get next batch with converged states removed and new states added + batch, completed = batcher.next_batch(batch, convergence) Notes: When max_iterations is set, states that exceed this limit will be @@ -1012,8 +969,6 @@ def next_batch( # noqa: C901 """ if not self.first_batch_returned: self.first_batch_returned = True - if self.return_indices: - return self._first_batch, [], self.current_idx return self._first_batch, [] if ( @@ -1026,6 +981,10 @@ def next_batch( # noqa: C901 # assert statements helpful for debugging, should be moved to validate fn # the first two are most important + if updated_state is None: + raise ValueError("updated_state cannot be None") + if convergence_tensor is None: + raise ValueError("convergence_tensor cannot be None") if len(convergence_tensor) != updated_state.n_systems: raise ValueError(f"{len(convergence_tensor)=} != {updated_state.n_systems=}") if len(self.current_idx) != len(self.current_scalers): @@ -1057,23 +1016,16 @@ def next_batch( # noqa: C901 # there are no states left to run, return the completed states if not self.current_idx: - return ( - (None, completed_states, []) - if self.return_indices - else (None, completed_states) - ) + return None, completed_states # type: ignore[invalid-return-type] # concatenate remaining state with next states if updated_state.n_systems > 0: next_states = [updated_state, *next_states] - next_batch = concatenate_states(next_states) - - if self.return_indices: - return next_batch, completed_states, self.current_idx + next_batch = ts.concatenate_states(next_states) return next_batch, completed_states - def restore_original_order(self, completed_states: list[SimState]) -> list[SimState]: + def restore_original_order(self, completed_states: Sequence[T]) -> list[T]: """Reorder completed states back to their original sequence. Takes states that were completed in arbitrary order and restores them @@ -1081,7 +1033,7 @@ def restore_original_order(self, completed_states: list[SimState]) -> list[SimSt the hot-swapping strategy to ensure results correspond to input states. Args: - completed_states (list[SimState]): Completed states to reorder. Each + completed_states (Sequence[SimState]): Completed states to reorder. Each SimState contains simulation data with shape specific to its instance. Returns: @@ -1112,7 +1064,6 @@ def restore_original_order(self, completed_states: list[SimState]) -> list[SimSt or you will only get the subset of states that have completed so far. """ # TODO: should act on full states, not state slices - if len(completed_states) != len(self.completed_idx_og_order): raise ValueError( f"Number of completed states ({len(completed_states)}) does not match " diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index 6c8a740de..0e91fac55 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -15,8 +15,8 @@ Physical Review B, 90(22), 224104 Online Resources: - -- Materials Project Documentation: https://docs.materialsproject.org/methodology/elasticity/ +- Materials Project Documentation + https://docs.materialsproject.org/methodology/elasticity/ """ from collections.abc import Callable @@ -25,6 +25,7 @@ import torch from torch_sim.models.interface import ModelInterface +from torch_sim.optimizers import OptimState from torch_sim.state import SimState from torch_sim.typing import BravaisType @@ -38,10 +39,10 @@ class DeformationRule: Attributes: axes: List of indices indicating which strain components to consider - for the specific crystal symmetry, following Voigt notation: - [0=xx, 1=yy, 2=zz, 3=yz, 4=xz, 5=xy] + for the specific crystal symmetry, following Voigt notation: + [0=xx, 1=yy, 2=zz, 3=yz, 4=xz, 5=xy] symmetry_handler: Callable function that constructs the stress-strain - relationship matrix according to the crystal symmetry. + relationship matrix according to the crystal symmetry. """ axes: list[int] @@ -62,11 +63,11 @@ def get_bravais_type( # noqa: PLR0911 angle_tol: Tolerance for floating-point comparisons of lattice angles in degrees Returns: - BravaisType: Bravais type + BravaisType: StrEnum value """ # Get cell parameters row_vector_cell = state.row_vector_cell.squeeze() - a, b, c = torch.linalg.norm(row_vector_cell, axis=1) + a, b, c = torch.linalg.norm(row_vector_cell, dim=1) # Get cell angles in degrees alpha = torch.rad2deg( @@ -87,7 +88,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(beta - 90) < angle_tol and abs(gamma - 90) < angle_tol ): - return BravaisType.CUBIC + return BravaisType.cubic # Hexagonal: a = b β‰  c, alpha = beta = 90Β°, gamma = 120Β° if ( @@ -96,7 +97,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(beta - 90) < angle_tol and abs(gamma - 120) < angle_tol ): - return BravaisType.HEXAGONAL + return BravaisType.hexagonal # Tetragonal: a = b β‰  c, alpha = beta = gamma = 90Β° if ( @@ -106,7 +107,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(beta - 90) < angle_tol and abs(gamma - 90) < angle_tol ): - return BravaisType.TETRAGONAL + return BravaisType.tetragonal # Orthorhombic: a β‰  b β‰  c, alpha = beta = gamma = 90Β° if ( @@ -116,7 +117,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(a - b) > length_tol and (abs(b - c) > length_tol or abs(a - c) > length_tol) ): - return BravaisType.ORTHORHOMBIC + return BravaisType.orthorhombic # Monoclinic: a β‰  b β‰  c, alpha = gamma = 90Β°, beta β‰  90Β° if ( @@ -124,7 +125,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(gamma - 90) < angle_tol and abs(beta - 90) > angle_tol ): - return BravaisType.MONOCLINIC + return BravaisType.monoclinic # Trigonal/Rhombohedral: a = b = c, alpha = beta = gamma β‰  90Β° if ( @@ -134,10 +135,10 @@ def get_bravais_type( # noqa: PLR0911 and abs(beta - gamma) < angle_tol and abs(alpha - 90) > angle_tol ): - return BravaisType.TRIGONAL + return BravaisType.trigonal # Triclinic: a β‰  b β‰  c, alpha β‰  beta β‰  gamma β‰  90Β° - return BravaisType.TRICLINIC + return BravaisType.triclinic def regular_symmetry(strains: torch.Tensor) -> torch.Tensor: @@ -551,13 +552,13 @@ def triclinic_symmetry(strains: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Matrix of shape (6, 21) where columns correspond to - all possible elastic constants in order: - [C11, C12, C13, C14, C15, C16, - C22, C23, C24, C25, C26, - C33, C34, C35, C36, - C44, C45, C46, - C55, C56, - C66] + all possible elastic constants in order: + [C11, C12, C13, C14, C15, C16, + C22, C23, C24, C25, C26, + C33, C34, C35, C36, + C44, C45, C46, + C55, C56, + C66] """ if not isinstance(strains, torch.Tensor): strains = torch.tensor(strains) @@ -684,7 +685,7 @@ def get_elementary_deformations( n_deform: int = 5, max_strain_normal: float = 0.01, max_strain_shear: float = 0.06, - bravais_type: BravaisType = None, + bravais_type: BravaisType | None = None, ) -> list[SimState]: """Generate elementary deformations for elastic tensor calculation. @@ -715,29 +716,24 @@ def get_elementary_deformations( # Deformation rules for different Bravais lattices # Each tuple contains (allowed_axes, symmetry_handler_function) deformation_rules: dict[BravaisType, DeformationRule] = { - BravaisType.CUBIC: DeformationRule([0, 3], regular_symmetry), - BravaisType.HEXAGONAL: DeformationRule([0, 2, 3, 5], hexagonal_symmetry), - BravaisType.TRIGONAL: DeformationRule([0, 1, 2, 3, 4, 5], trigonal_symmetry), - BravaisType.TETRAGONAL: DeformationRule([0, 2, 3, 5], tetragonal_symmetry), - BravaisType.ORTHORHOMBIC: DeformationRule( + BravaisType.cubic: DeformationRule([0, 3], regular_symmetry), + BravaisType.hexagonal: DeformationRule([0, 2, 3, 5], hexagonal_symmetry), + BravaisType.trigonal: DeformationRule([0, 1, 2, 3, 4, 5], trigonal_symmetry), + BravaisType.tetragonal: DeformationRule([0, 2, 3, 5], tetragonal_symmetry), + BravaisType.orthorhombic: DeformationRule( [0, 1, 2, 3, 4, 5], orthorhombic_symmetry ), - BravaisType.MONOCLINIC: DeformationRule([0, 1, 2, 3, 4, 5], monoclinic_symmetry), - BravaisType.TRICLINIC: DeformationRule([0, 1, 2, 3, 4, 5], triclinic_symmetry), + BravaisType.monoclinic: DeformationRule([0, 1, 2, 3, 4, 5], monoclinic_symmetry), + BravaisType.triclinic: DeformationRule([0, 1, 2, 3, 4, 5], triclinic_symmetry), } - # Default to triclinic (lowest symmetry) if bravais_type not specified - if bravais_type is None: - bravais_type = BravaisType.TRICLINIC - # Get deformation rules for this Bravais lattice rule = deformation_rules[bravais_type] allowed_axes = rule.axes # Generate deformed structures deformed_states = [] - device = state.device - dtype = state.dtype + device, dtype = state.device, state.dtype for axis in allowed_axes: if axis < 3: # Normal strain @@ -794,8 +790,7 @@ def get_strain( - Ξ΅[4] = Ξ΅xz = u[2,0] - Ξ΅[5] = Ξ΅xy = u[1,0] """ - dtype = deformed_state.positions.dtype - device = deformed_state.positions.device + dtype, device = deformed_state.dtype, deformed_state.device if not isinstance(deformed_state, SimState): raise TypeError("deformed_state must be an SimState") @@ -838,7 +833,7 @@ def voigt_6_to_full_3x3_stress(stress_voigt: torch.Tensor) -> torch.Tensor: [Οƒxx, Οƒyy, Οƒzz, Οƒyz, Οƒxz, Οƒxy] in Voigt notation Returns: - torch.Tensor: Tensor of shape (..., 3, 3) containing the full stress matrix + torch.Tensor: Of shape (..., 3, 3) containing the full stress matrix """ device = stress_voigt.device dtype = stress_voigt.dtype @@ -894,7 +889,7 @@ def get_elastic_coeffs( deformed_states: list[SimState], stresses: torch.Tensor, base_pressure: torch.Tensor, - bravais_type: BravaisType, + bravais_type: BravaisType = BravaisType.triclinic, ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, int, torch.Tensor]]: """Calculate elastic tensor from stress-strain relationships. @@ -907,7 +902,8 @@ def get_elastic_coeffs( stresses: Tensor of shape (n_states, 6) containing stress components for each state base_pressure: Reference pressure of the base state - bravais_type: Crystal system (BravaisType enum) + bravais_type (BravaisType): Crystal system. Defaults to Triclinic (lowest + symmetry). Returns: tuple containing: @@ -927,15 +923,15 @@ def get_elastic_coeffs( """ # Deformation rules for different Bravais lattices deformation_rules: dict[BravaisType, DeformationRule] = { - BravaisType.CUBIC: DeformationRule([0, 3], regular_symmetry), - BravaisType.HEXAGONAL: DeformationRule([0, 2, 3, 5], hexagonal_symmetry), - BravaisType.TRIGONAL: DeformationRule([0, 2, 3, 4, 5], trigonal_symmetry), - BravaisType.TETRAGONAL: DeformationRule([0, 2, 3, 4, 5], tetragonal_symmetry), - BravaisType.ORTHORHOMBIC: DeformationRule( + BravaisType.cubic: DeformationRule([0, 3], regular_symmetry), + BravaisType.hexagonal: DeformationRule([0, 2, 3, 5], hexagonal_symmetry), + BravaisType.trigonal: DeformationRule([0, 2, 3, 4, 5], trigonal_symmetry), + BravaisType.tetragonal: DeformationRule([0, 2, 3, 4, 5], tetragonal_symmetry), + BravaisType.orthorhombic: DeformationRule( [0, 1, 2, 3, 4, 5], orthorhombic_symmetry ), - BravaisType.MONOCLINIC: DeformationRule([0, 1, 2, 3, 4, 5], monoclinic_symmetry), - BravaisType.TRICLINIC: DeformationRule([0, 1, 2, 3, 4, 5], triclinic_symmetry), + BravaisType.monoclinic: DeformationRule([0, 1, 2, 3, 4, 5], monoclinic_symmetry), + BravaisType.triclinic: DeformationRule([0, 1, 2, 3, 4, 5], triclinic_symmetry), } # Get symmetry handler for this Bravais lattice @@ -968,15 +964,15 @@ def get_elastic_coeffs( # Calculate elastic constants with pressure correction p = base_pressure pressure_corrections = { - BravaisType.CUBIC: torch.tensor([-p, p, -p]), - BravaisType.HEXAGONAL: torch.tensor([-p, -p, p, p, -p]), - BravaisType.TRIGONAL: torch.tensor([-p, -p, p, p, p, p, -p]), - BravaisType.TETRAGONAL: torch.tensor([-p, -p, p, p, -p, -p, -p]), - BravaisType.ORTHORHOMBIC: torch.tensor([-p, -p, -p, p, p, p, -p, -p, -p]), - BravaisType.MONOCLINIC: torch.tensor( + BravaisType.cubic: torch.tensor([-p, p, -p]), + BravaisType.hexagonal: torch.tensor([-p, -p, p, p, -p]), + BravaisType.trigonal: torch.tensor([-p, -p, p, p, p, p, -p]), + BravaisType.tetragonal: torch.tensor([-p, -p, p, p, -p, -p, -p]), + BravaisType.orthorhombic: torch.tensor([-p, -p, -p, p, p, p, -p, -p, -p]), + BravaisType.monoclinic: torch.tensor( [-p, -p, -p, p, p, p, -p, -p, -p, p, p, p, p] ), - BravaisType.TRICLINIC: torch.tensor( + BravaisType.triclinic: torch.tensor( [ -p, p, @@ -1039,7 +1035,7 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 # Initialize full tensor C = torch.zeros((6, 6), dtype=Cij.dtype, device=Cij.device) - if bravais_type == BravaisType.TRICLINIC: + if bravais_type == BravaisType.triclinic: if len(Cij) != 21: raise ValueError( f"Triclinic symmetry requires 21 independent constants, " @@ -1052,19 +1048,19 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 C[i, j] = C[j, i] = Cij[idx] idx += 1 - elif bravais_type == BravaisType.CUBIC: + elif bravais_type == BravaisType.cubic: C11, C12, C44 = Cij diag = torch.tensor([C11, C11, C11, C44, C44, C44]) C.diagonal().copy_(diag) C[0, 1] = C[1, 0] = C[0, 2] = C[2, 0] = C[1, 2] = C[2, 1] = C12 - elif bravais_type == BravaisType.HEXAGONAL: + elif bravais_type == BravaisType.hexagonal: C11, C12, C13, C33, C44 = Cij C.diagonal().copy_(torch.tensor([C11, C11, C33, C44, C44, (C11 - C12) / 2])) C[0, 1] = C[1, 0] = C12 C[0, 2] = C[2, 0] = C[1, 2] = C[2, 1] = C13 - elif bravais_type == BravaisType.TRIGONAL: + elif bravais_type == BravaisType.trigonal: C11, C12, C13, C14, C15, C33, C44 = Cij C.diagonal().copy_(torch.tensor([C11, C11, C33, C44, C44, (C11 - C12) / 2])) C[0, 1] = C[1, 0] = C12 @@ -1076,7 +1072,7 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 C[3, 5] = C[5, 3] = -C15 C[4, 5] = C[5, 4] = C14 - elif bravais_type == BravaisType.TETRAGONAL: + elif bravais_type == BravaisType.tetragonal: C11, C12, C13, C16, C33, C44, C66 = Cij C.diagonal().copy_(torch.tensor([C11, C11, C33, C44, C44, C66])) C[0, 1] = C[1, 0] = C12 @@ -1084,14 +1080,14 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 C[0, 5] = C[5, 0] = C16 C[1, 5] = C[5, 1] = -C16 - elif bravais_type == BravaisType.ORTHORHOMBIC: + elif bravais_type == BravaisType.orthorhombic: C11, C12, C13, C22, C23, C33, C44, C55, C66 = Cij C.diagonal().copy_(torch.tensor([C11, C22, C33, C44, C55, C66])) C[0, 1] = C[1, 0] = C12 C[0, 2] = C[2, 0] = C13 C[1, 2] = C[2, 1] = C23 - elif bravais_type == BravaisType.MONOCLINIC: + elif bravais_type == BravaisType.monoclinic: C11, C12, C13, C15, C22, C23, C25, C33, C35, C44, C46, C55, C66 = Cij C.diagonal().copy_(torch.tensor([C11, C22, C33, C44, C55, C66])) C[0, 1] = C[1, 0] = C12 @@ -1108,8 +1104,8 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 def calculate_elastic_tensor( model: ModelInterface, *, - state: SimState, - bravais_type: BravaisType = BravaisType.TRICLINIC, + state: OptimState, + bravais_type: BravaisType = BravaisType.triclinic, max_strain_normal: float = 0.01, max_strain_shear: float = 0.06, n_deform: int = 5, @@ -1127,8 +1123,7 @@ def calculate_elastic_tensor( Returns: torch.Tensor: Elastic tensor """ - device = state.positions.device - dtype = state.positions.dtype + device, dtype = state.device, state.dtype # Calculate deformations for the bravais type deformations = get_elementary_deformations( @@ -1143,17 +1138,15 @@ def calculate_elastic_tensor( ref_pressure = -torch.trace(state.stress.squeeze()) / 3 stresses = torch.zeros((len(deformations), 6), device=device, dtype=dtype) - for i, deformation in enumerate(deformations): + for def_idx, deformation in enumerate(deformations): result = model(deformation) - stresses[i] = full_3x3_to_voigt_6_stress(result["stress"].squeeze()) + stresses[def_idx] = full_3x3_to_voigt_6_stress(result["stress"].squeeze()) # Calculate elastic tensor - C_ij, Res = get_elastic_coeffs( + C_ij, _residuals = get_elastic_coeffs( state, deformations, stresses, ref_pressure, bravais_type ) - C = get_elastic_tensor_from_coeffs(C_ij, bravais_type) - - return C # noqa: RET504 + return get_elastic_tensor_from_coeffs(C_ij, bravais_type) def calculate_elastic_moduli(C: torch.Tensor) -> tuple[float, float, float, float]: diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index e38c1925b..f996d3778 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -6,13 +6,12 @@ supports periodic boundary conditions. Examples: - >>> from torch_sim.integrators import nve - >>> nve_init, nve_update = nve( - ... model, dt=1e-3 * units.time, kT=300.0 * units.temperature - ... ) - >>> state = nve_init(initial_state) + >>> import torch_sim as ts + >>> state = ts.nvt_langevin_init(model, initial_state, kT=300.0 * units.temperature) >>> for _ in range(1000): - ... state = nve_update(state) + ... state = ts.nvt_langevin_update( + ... model, state, dt=1e-3 * units.time, kT=300.0 * units.temperature + ... ) Notes: All integrators support batched operations for efficient parallel simulation @@ -20,8 +19,50 @@ """ # ruff: noqa: F401 +from collections.abc import Callable +from enum import StrEnum +from typing import Any, Final + +import torch_sim as ts from .md import MDState, calculate_momenta, momentum_step, position_step, velocity_verlet -from .npt import NPTLangevinState, npt_langevin -from .nve import nve -from .nvt import nvt_langevin +from .npt import ( + NPTLangevinState, + NPTNoseHooverState, + npt_langevin_init, + npt_langevin_update, + npt_nose_hoover_init, + npt_nose_hoover_invariant, + npt_nose_hoover_update, +) +from .nve import nve_init, nve_update +from .nvt import ( + NVTNoseHooverState, + nvt_langevin_init, + nvt_langevin_update, + nvt_nose_hoover_init, + nvt_nose_hoover_invariant, + nvt_nose_hoover_update, +) + + +class MdFlavor(StrEnum): + """Flavor of molecular dynamics simulation.""" + + nve = "nve" + nvt_langevin = "nvt_langevin" + nvt_nose_hoover = "nvt_nose_hoover" + npt_langevin = "npt_langevin" + npt_nose_hoover = "npt_nose_hoover" + + +# Integrator registry - maps integrator names to (init_fn, step_fn) pairs +INTEGRATOR_REGISTRY: Final[ + dict[MdFlavor, tuple[Callable[..., Any], Callable[..., Any]]] +] = { + MdFlavor.nve: (nve_init, nve_update), + MdFlavor.nvt_langevin: (nvt_langevin_init, nvt_langevin_update), + MdFlavor.nvt_nose_hoover: (nvt_nose_hoover_init, nvt_nose_hoover_update), + MdFlavor.npt_langevin: (npt_langevin_init, npt_langevin_update), + MdFlavor.npt_nose_hoover: (npt_nose_hoover_init, npt_nose_hoover_update), +} diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 490e35284..4d8f209a0 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -60,7 +60,7 @@ def calculate_momenta( positions: torch.Tensor, masses: torch.Tensor, system_idx: torch.Tensor, - kT: torch.Tensor | float, + kT: float | torch.Tensor, seed: int | None = None, ) -> torch.Tensor: """Initialize particle momenta based on temperature. @@ -96,7 +96,7 @@ def calculate_momenta( ) * torch.sqrt(masses * kT).unsqueeze(-1) systemwise_momenta = torch.zeros( - (system_idx[-1] + 1, momenta.shape[1]), device=device, dtype=dtype + size=(int(system_idx[-1]) + 1, momenta.shape[1]), device=device, dtype=dtype ) # create 3 copies of system_idx @@ -117,7 +117,7 @@ def calculate_momenta( ) -def momentum_step(state: MDState, dt: torch.Tensor) -> MDState: +def momentum_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: """Update particle momenta using current forces. This function performs the momentum update step of velocity Verlet integration @@ -137,7 +137,7 @@ def momentum_step(state: MDState, dt: torch.Tensor) -> MDState: return state -def position_step(state: MDState, dt: torch.Tensor) -> MDState: +def position_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: """Update particle positions using current velocities. This function performs the position update step of velocity Verlet integration @@ -164,7 +164,7 @@ def position_step(state: MDState, dt: torch.Tensor) -> MDState: return state -def velocity_verlet(state: MDState, dt: torch.Tensor, model: ModelInterface) -> MDState: +def velocity_verlet[T: MDState](state: T, dt: torch.Tensor, model: ModelInterface) -> T: """Perform one complete velocity Verlet integration step. This function implements the velocity Verlet algorithm, which provides @@ -324,7 +324,7 @@ def init_fn( xi = torch.zeros(chain_length, dtype=dtype, device=device) p_xi = torch.zeros(chain_length, dtype=dtype, device=device) - Q = kT * tau**2 * torch.ones(chain_length, dtype=dtype, device=device) + Q = kT * torch.square(tau) * torch.ones(chain_length, dtype=dtype, device=device) Q[0] *= degrees_of_freedom return NoseHooverChain(xi, p_xi, Q, tau, KE, degrees_of_freedom) @@ -359,11 +359,11 @@ def substep_fn( M = chain_length - 1 # Update chain momenta backwards - G = p_xi[M - 1] ** 2 / Q[M - 1] - kT + G = torch.square(p_xi[M - 1]) / Q[M - 1] - kT p_xi[M] += delta_4 * G for m in range(M - 1, 0, -1): - G = p_xi[m - 1] ** 2 / Q[m - 1] - kT + G = torch.square(p_xi[m - 1]) / Q[m - 1] - kT scale = torch.exp(-delta_8 * p_xi[m + 1] / Q[m + 1]) p_xi[m] = scale * (scale * p_xi[m] + delta_4 * G) @@ -374,7 +374,7 @@ def substep_fn( # Rescale system momenta scale = torch.exp(-delta_2 * p_xi[0] / Q[0]) - KE = KE * scale**2 + KE = KE * torch.square(scale) P = P * scale # Update positions @@ -385,7 +385,7 @@ def substep_fn( for m in range(M): scale = torch.exp(-delta_8 * p_xi[m + 1] / Q[m + 1]) p_xi[m] = scale * (scale * p_xi[m] + delta_4 * G) - G = p_xi[m] ** 2 / Q[m] - kT + G = torch.square(p_xi[m]) / Q[m] - kT p_xi[M] += delta_4 * G return P, NoseHooverChain(xi, p_xi, Q, _tau, KE, DOF), kT @@ -416,29 +416,35 @@ def half_step_chain_fn( return P, state - def update_chain_mass_fn(state: NoseHooverChain, kT: torch.Tensor) -> NoseHooverChain: + def update_chain_mass_fn( + chain_state: NoseHooverChain, kT: torch.Tensor + ) -> NoseHooverChain: """Update chain masses to maintain target oscillation period. Args: - state: Current chain state + chain_state: Current chain state kT: Target temperature Returns: Updated chain state with new masses """ - device = state.positions.device - dtype = state.positions.dtype + device = chain_state.positions.device + dtype = chain_state.positions.dtype - Q = kT * state.tau**2 * torch.ones(chain_length, dtype=dtype, device=device) - Q[0] *= state.degrees_of_freedom + Q = ( + kT + * torch.square(chain_state.tau) + * torch.ones(chain_length, dtype=dtype, device=device) + ) + Q[0] *= chain_state.degrees_of_freedom return NoseHooverChain( - state.positions, - state.momenta, + chain_state.positions, + chain_state.momenta, Q, - state.tau, - state.kinetic_energy, - state.degrees_of_freedom, + chain_state.tau, + chain_state.kinetic_energy, + chain_state.degrees_of_freedom, ) return NoseHooverChainFns(init_fn, half_step_chain_fn, update_chain_mass_fn) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index e7d428443..fb4f6d914 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -15,7 +15,6 @@ construct_nose_hoover_chain, ) from torch_sim.models.interface import ModelInterface -from torch_sim.quantities import calc_kinetic_energy from torch_sim.state import SimState from torch_sim.typing import StateDict @@ -85,7 +84,380 @@ def momenta(self) -> torch.Tensor: return self.velocities * self.masses.unsqueeze(-1) -# Extracted out from npt_langevin body to test fix in https://github.com/TorchSim/torch-sim/pull/153 +def _npt_langevin_beta( + state: NPTLangevinState, + alpha: torch.Tensor, + kT: torch.Tensor, + dt: torch.Tensor, +) -> torch.Tensor: + """Calculate random noise term for particle Langevin dynamics. + + This function generates the stochastic force term for the Langevin thermostat + according to the fluctuation-dissipation theorem, ensuring proper thermal + sampling at the target temperature. + + Args: + state (NPTLangevinState): Current NPT state + alpha (torch.Tensor): Friction coefficient, either scalar or + shape [n_systems] + kT (torch.Tensor): Temperature in energy units, either scalar or + shape [n_systems] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] + + Returns: + torch.Tensor: Random noise term for force calculation [n_particles, n_dim] + """ + # Generate system-specific noise with correct shape + noise = torch.randn_like(state.velocities) + + # Calculate the thermal noise amplitude by system + batch_kT = kT + if kT.ndim == 0: + batch_kT = kT.expand(state.n_systems) + + # Map system kT to atoms + atom_kT = batch_kT[state.system_idx] + + # Calculate the prefactor for each atom + # The standard deviation should be sqrt(2*alpha*kB*T*dt) + prefactor = torch.sqrt(2 * alpha * atom_kT * dt) + + return prefactor.unsqueeze(-1) * noise + + +def _npt_langevin_cell_beta( + state: NPTLangevinState, cell_alpha: torch.Tensor, kT: torch.Tensor, dt: torch.Tensor +) -> torch.Tensor: + """Generate random noise for cell fluctuations in NPT dynamics. + + This function creates properly scaled random noise for cell dynamics in NPT + simulations, following the fluctuation-dissipation theorem to ensure correct + thermal sampling of cell degrees of freedom. + + Args: + state (NPTLangevinState): Current NPT state + cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or + with shape [n_systems] + kT (torch.Tensor): System temperature in energy units, either scalar or + with shape [n_systems] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] + device (torch.device): Device for tensor operations + dtype (torch.dtype): Data type for tensor operations + + Returns: + torch.Tensor: Scaled random noise for cell dynamics with shape + [n_systems, n_dimensions, n_dimensions] + """ + # Generate standard normal distribution (zero mean, unit variance) + noise = torch.randn_like(state.cell_positions, device=state.device, dtype=state.dtype) + + # Ensure cell_alpha and kT have batch dimension if they're scalars + if cell_alpha.ndim == 0: + cell_alpha = cell_alpha.expand(state.n_systems) + if kT.ndim == 0: + kT = kT.expand(state.n_systems) + + # Reshape for broadcasting + cell_alpha = cell_alpha.view(-1, 1, 1) # shape: (n_systems, 1, 1) + kT = kT.view(-1, 1, 1) # shape: (n_systems, 1, 1) + dt = dt.expand(state.n_systems).view(-1, 1, 1) if dt.ndim == 0 else dt.view(-1, 1, 1) + + # Scale to satisfy the fluctuation-dissipation theorem + # The standard deviation should be sqrt(2*alpha*kB*T*dt) + scaling_factor = torch.sqrt(2.0 * cell_alpha * kT * dt) + + return scaling_factor * noise + + +def _npt_langevin_cell_position_step( + state: NPTLangevinState, + dt: torch.Tensor, + pressure_force: torch.Tensor, + kT: torch.Tensor, + cell_alpha: torch.Tensor, +) -> NPTLangevinState: + """Update the cell position in NPT dynamics. + + This function updates the cell position (effectively the volume) in NPT dynamics + using the current cell velocities, pressure forces, and thermal noise. It + implements the position update part of the Langevin barostat algorithm. + + Args: + state (NPTLangevinState): Current NPT state + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] + pressure_force (torch.Tensor): Pressure force for barostat + [n_systems, n_dim, n_dim] + kT (torch.Tensor): Target temperature in energy units, either scalar or + with shape [n_systems] + cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or + with shape [n_systems] + + Returns: + NPTLangevinState: Updated state with new cell positions + """ + # Calculate effective mass term + Q_2 = 2 * state.cell_masses.view(-1, 1, 1) # shape: (n_systems, 1, 1) + + # Ensure parameters have batch dimension + if dt.ndim == 0: + dt = dt.expand(state.n_systems) + if cell_alpha.ndim == 0: + cell_alpha = cell_alpha.expand(state.n_systems) + + # Reshape for broadcasting + dt_expanded = dt.view(-1, 1, 1) + cell_alpha_expanded = cell_alpha.view(-1, 1, 1) + + # Calculate damping factor for cell position update + cell_b = 1 / (1 + ((cell_alpha_expanded * dt_expanded) / Q_2)) + + # Deterministic velocity contribution + c_1 = cell_b * dt_expanded * state.cell_velocities + + # Force contribution + c_2 = cell_b * dt_expanded * dt_expanded * pressure_force / Q_2 + + # Random noise contribution (thermal fluctuations) + c_3 = cell_b * dt_expanded * _npt_langevin_cell_beta(state, cell_alpha, kT, dt) / Q_2 + + # Update cell positions with all contributions + state.cell_positions = state.cell_positions + c_1 + c_2 + c_3 + return state + + +def _npt_langevin_cell_velocity_step( + state: NPTLangevinState, + F_p_n: torch.Tensor, + dt: torch.Tensor, + pressure_force: torch.Tensor, + cell_alpha: torch.Tensor, + kT: torch.Tensor, +) -> NPTLangevinState: + """Update the cell velocities in NPT dynamics. + + This function updates the cell velocities using a Langevin-type integrator, + accounting for both deterministic forces from pressure differences and + stochastic thermal noise. It implements the velocity update part of the + Langevin barostat algorithm. + + Args: + state (NPTLangevinState): Current NPT state + F_p_n (torch.Tensor): Initial pressure force with shape + [n_systems, n_dimensions, n_dimensions] + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] + pressure_force (torch.Tensor): Final pressure force + shape [n_systems, n_dim, n_dim] + cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or + shape [n_systems] + kT (torch.Tensor): Temperature in energy units, either scalar or + shape [n_systems] + + Returns: + NPTLangevinState: Updated state with new cell velocities + """ + # Ensure parameters have batch dimension + if dt.ndim == 0: + dt = dt.expand(state.n_systems) + if cell_alpha.ndim == 0: + cell_alpha = cell_alpha.expand(state.n_systems) + if kT.ndim == 0: + kT = kT.expand(state.n_systems) + + # Reshape for broadcasting - need to maintain 3x3 dimensions + dt_expanded = dt.view(-1, 1, 1) # shape: (n_systems, 1, 1) + cell_alpha_expanded = cell_alpha.view(-1, 1, 1) # shape: (n_systems, 1, 1) + + # Calculate cell masses per system - reshape to match 3x3 cell matrices + cell_masses_expanded = state.cell_masses.view(-1, 1, 1) # shape: (n_systems, 1, 1) + + # These factors come from the Langevin integration scheme + a = (1 - (cell_alpha_expanded * dt_expanded) / cell_masses_expanded) / ( + 1 + (cell_alpha_expanded * dt_expanded) / cell_masses_expanded + ) + b = 1 / (1 + (cell_alpha_expanded * dt_expanded) / cell_masses_expanded) + + # Calculate the three terms for velocity update + # a will broadcast from (n_systems, 1, 1) to (n_systems, 3, 3) + c_1 = a * state.cell_velocities # Damped old velocity + + # Force contribution (average of initial and final forces) + c_2 = dt_expanded * ((a * F_p_n) + pressure_force) / (2 * cell_masses_expanded) + + # Generate system-specific cell noise with correct shape (n_systems, 3, 3) + cell_noise = torch.randn_like(state.cell_velocities) + + # Calculate thermal noise amplitude + noise_prefactor = torch.sqrt( + 2 * cell_alpha_expanded * kT.view(-1, 1, 1) * dt_expanded + ) + noise_term = noise_prefactor * cell_noise / torch.sqrt(cell_masses_expanded) + + # Random noise contribution + c_3 = b * noise_term + + # Update velocities with all contributions + state.cell_velocities = c_1 + c_2 + c_3 + return state + + +def _npt_langevin_position_step( + state: NPTLangevinState, + L_n: torch.Tensor, # This should be shape (n_systems,) + dt: torch.Tensor, + kT: torch.Tensor, + alpha: torch.Tensor, +) -> NPTLangevinState: + """Update the particle positions in NPT dynamics. + + This function updates particle positions accounting for both the changing + cell dimensions and the particle velocities/forces. It handles the scaling + of positions due to volume changes as well as the normal position updates + from velocities. + + Args: + state (NPTLangevinState): Current NPT state + L_n (torch.Tensor): Previous cell length scale with shape [n_systems] + dt: Integration timestep, either scalar or with shape [n_systems] + kT (torch.Tensor): Target temperature in energy units, either scalar or + with shape [n_systems] + alpha (torch.Tensor | None): Friction coefficient, either scalar or with + shape [n_systems]. + + Returns: + NPTLangevinState: Updated state with new positions + """ + # Calculate effective mass term by system + # Map masses to have batch dimension + M_2 = 2 * state.masses.unsqueeze(-1) # shape: (n_atoms, 1) + + # Calculate new cell length scale (cube root of volume for isotropic scaling) + L_n_new = torch.pow( + state.cell_positions.reshape(state.n_systems, -1)[:, 0], 1 / 3 + ) # shape: (n_systems,) + + # Map system-specific L_n and L_n_new to atom-level using system indices + # Make sure L_n is the right shape (n_systems,) before indexing + if L_n.ndim != 1 or L_n.shape[0] != state.n_systems: + # If L_n has wrong shape, calculate it again to ensure correct shape + L_n = torch.pow(state.cell_positions.reshape(state.n_systems, -1)[:, 0], 1 / 3) + + # Map system-specific values to atoms using system indices + L_n_atoms = L_n[state.system_idx] # shape: (n_atoms,) + L_n_new_atoms = L_n_new[state.system_idx] # shape: (n_atoms,) + + # Calculate damping factor + alpha_atoms = alpha + if alpha.ndim > 0: + alpha_atoms = alpha[state.system_idx] + dt_atoms = dt + if dt.ndim > 0: + dt_atoms = dt[state.system_idx] + + b = 1 / (1 + ((alpha_atoms * dt_atoms) / M_2)) + + # Scale positions due to cell volume change + c_1 = (L_n_new_atoms / L_n_atoms).unsqueeze(-1) * state.positions + + # Time step factor with average length scale + c_2 = ( + (2 * L_n_new_atoms / (L_n_new_atoms + L_n_atoms)).unsqueeze(-1) + * b + * dt_atoms.unsqueeze(-1) + ) + + # Generate atom-specific noise + noise = torch.randn_like(state.velocities) + batch_kT = kT + if kT.ndim == 0: + batch_kT = kT.expand(state.n_systems) + atom_kT = batch_kT[state.system_idx] + + # Calculate noise prefactor according to fluctuation-dissipation theorem + noise_prefactor = torch.sqrt(2 * alpha_atoms * atom_kT * dt_atoms) + noise_term = noise_prefactor.unsqueeze(-1) * noise + + # Velocity and force contributions with random noise + c_3 = ( + state.velocities + dt_atoms.unsqueeze(-1) * state.forces / M_2 + noise_term / M_2 + ) + + # Update positions with all contributions + state.positions = c_1 + c_2 * c_3 + + # Apply periodic boundary conditions if needed + if state.pbc: + state.positions = ts.transforms.pbc_wrap_batched( + state.positions, state.cell, state.system_idx + ) + + return state + + +def _npt_langevin_velocity_step( + state: NPTLangevinState, + forces: torch.Tensor, + dt: torch.Tensor, + kT: torch.Tensor, + alpha: torch.Tensor, +) -> NPTLangevinState: + """Update the particle velocities in NPT dynamics. + + This function updates particle velocities using a Langevin-type integrator, + accounting for both deterministic forces and stochastic thermal noise. + It implements the velocity update part of the Langevin thermostat algorithm. + + Args: + state (NPTLangevinState): Current NPT state + forces: Forces on particles + dt: Integration timestep, either scalar or with shape [n_systems] + kT: Target temperature in energy units, either scalar or + with shape [n_systems] + alpha (torch.Tensor | None): Friction coefficient, either scalar or with + shape [n_systems]. + + Returns: + NPTLangevinState: Updated state with new velocities + """ + # Calculate denominator for update equations + M_2 = 2 * state.masses.unsqueeze(-1) # shape: (n_atoms, 1) + + # Map batch parameters to atom level + alpha_atoms = alpha + if alpha.ndim > 0: + alpha_atoms = alpha[state.system_idx] + dt_atoms = dt + if dt.ndim > 0: + dt_atoms = dt[state.system_idx] + + # Calculate damping factors for Langevin integration + a = (1 - (alpha_atoms * dt_atoms) / M_2) / (1 + (alpha_atoms * dt_atoms) / M_2) + b = 1 / (1 + (alpha_atoms * dt_atoms) / M_2) + + # Velocity contribution with damping + c_1 = a * state.velocities + + # Force contribution (average of initial and final forces) + c_2 = dt_atoms.unsqueeze(-1) * ((a * forces) + state.forces) / M_2 + + # Generate atom-specific noise + noise = torch.randn_like(state.velocities) + batch_kT = kT + if kT.ndim == 0: + batch_kT = kT.expand(state.n_systems) + atom_kT = batch_kT[state.system_idx] + + # Calculate noise prefactor according to fluctuation-dissipation theorem + noise_prefactor = torch.sqrt(2 * alpha_atoms * atom_kT * dt_atoms) + noise_term = noise_prefactor.unsqueeze(-1) * noise + + # Random noise contribution + c_3 = b * noise_term / state.masses.unsqueeze(-1) + + # Update velocities with all contributions + state.velocities = c_1 + c_2 + c_3 + return state + + def _compute_cell_force( state: NPTLangevinState, external_pressure: torch.Tensor, @@ -152,36 +524,33 @@ def _compute_cell_force( return virial + e_kin_per_atom * state.n_atoms_per_system.view(-1, 1, 1) -def npt_langevin( # noqa: C901, PLR0915 +def npt_langevin_init( model: ModelInterface, + state: SimState | StateDict, *, - dt: torch.Tensor, kT: torch.Tensor, - external_pressure: torch.Tensor, + dt: torch.Tensor, alpha: torch.Tensor | None = None, cell_alpha: torch.Tensor | None = None, b_tau: torch.Tensor | None = None, seed: int | None = None, -) -> tuple[ - Callable[[SimState | StateDict, torch.Tensor], NPTLangevinState], - Callable[[NPTLangevinState, torch.Tensor], NPTLangevinState], -]: - """Initialize and return an NPT (isothermal-isobaric) integrator with Langevin - dynamics. - - This function sets up integration in the NPT ensemble, where particle number (N), - pressure (P), and temperature (T) are conserved. It allows the simulation cell to - fluctuate to maintain the target pressure, while using Langevin dynamics to - maintain constant temperature. + **_kwargs: Any, +) -> NPTLangevinState: + """Initialize an NPT Langevin state from input data. + + This function creates the initial state for NPT Langevin dynamics, + setting up all necessary variables including particle velocities, + cell parameters, and barostat variables. It computes initial forces + and stress using the provided model. Args: model (ModelInterface): Neural network model that computes energies, forces, and stress. Must return a dict with 'energy', 'forces', and 'stress' keys. - dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] + state (SimState | StateDict): Either a SimState object or a dictionary + containing positions, masses, cell, pbc kT (torch.Tensor): Target temperature in energy units, either scalar or with shape [n_systems] - external_pressure (torch.Tensor): Target pressure to maintain, either scalar - or shape [n_systems, n_dim, n_dim] for anisotropic pressure + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] alpha (torch.Tensor, optional): Friction coefficient for particle Langevin thermostat, either scalar or shape [n_systems]. Defaults to 1/(100*dt). cell_alpha (torch.Tensor, optional): Friction coefficient for cell Langevin @@ -192,13 +561,8 @@ def npt_langevin( # noqa: C901, PLR0915 seed (int, optional): Random seed for reproducibility. Defaults to None. Returns: - tuple: - - callable: Function to initialize the NPTLangevinState from input data - with signature: init_fn(state, kT=kT, seed=seed) -> NPTLangevinState - - callable: Update function that evolves system by one timestep - with signature: update_fn(state, dt=dt, kT=kT, - external_pressure=external_pressure, alpha=alpha, - cell_alpha=cell_alpha) -> NPTLangevinState + NPTLangevinState: Initialized state for NPT Langevin integration containing + all required attributes for particle and cell dynamics Notes: - The model must provide stress tensor calculations for proper pressure coupling @@ -224,596 +588,167 @@ def npt_langevin( # noqa: C901, PLR0915 kT = torch.tensor(kT, device=device, dtype=dtype) if isinstance(b_tau, float): b_tau = torch.tensor(b_tau, device=device, dtype=dtype) - if isinstance(external_pressure, float): - external_pressure = torch.tensor(external_pressure, device=device, dtype=dtype) - - def beta( - state: NPTLangevinState, - alpha: torch.Tensor, - kT: torch.Tensor, - dt: torch.Tensor, - ) -> torch.Tensor: - """Calculate random noise term for particle Langevin dynamics. - - This function generates the stochastic force term for the Langevin thermostat - according to the fluctuation-dissipation theorem, ensuring proper thermal - sampling at the target temperature. - - Args: - state (NPTLangevinState): Current NPT state - alpha (torch.Tensor): Friction coefficient, either scalar or - shape [n_systems] - kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_systems] - dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] - - Returns: - torch.Tensor: Random noise term for force calculation [n_particles, n_dim] - """ - # Generate system-specific noise with correct shape - noise = torch.randn_like(state.velocities) - - # Calculate the thermal noise amplitude by system - batch_kT = kT - if kT.ndim == 0: - batch_kT = kT.expand(state.n_systems) - - # Map system kT to atoms - atom_kT = batch_kT[state.system_idx] - - # Calculate the prefactor for each atom - # The standard deviation should be sqrt(2*alpha*kB*T*dt) - prefactor = torch.sqrt(2 * alpha * atom_kT * dt) - - return prefactor.unsqueeze(-1) * noise - - def cell_beta( - state: NPTLangevinState, - cell_alpha: torch.Tensor, - kT: torch.Tensor, - dt: torch.Tensor, - ) -> torch.Tensor: - """Generate random noise for cell fluctuations in NPT dynamics. - - This function creates properly scaled random noise for cell dynamics in NPT - simulations, following the fluctuation-dissipation theorem to ensure correct - thermal sampling of cell degrees of freedom. - - Args: - state (NPTLangevinState): Current NPT state - cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - with shape [n_systems] - kT (torch.Tensor): System temperature in energy units, either scalar or - with shape [n_systems] - dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] - - Returns: - torch.Tensor: Scaled random noise for cell dynamics with shape - [n_systems, n_dimensions, n_dimensions] - """ - # Generate standard normal distribution (zero mean, unit variance) - noise = torch.randn_like(state.cell_positions, device=device, dtype=dtype) - - # Ensure cell_alpha and kT have batch dimension if they're scalars - if cell_alpha.ndim == 0: - cell_alpha = cell_alpha.expand(state.n_systems) - if kT.ndim == 0: - kT = kT.expand(state.n_systems) - - # Reshape for broadcasting - cell_alpha = cell_alpha.view(-1, 1, 1) # shape: (n_systems, 1, 1) - kT = kT.view(-1, 1, 1) # shape: (n_systems, 1, 1) - if dt.ndim == 0: - dt = dt.expand(state.n_systems).view(-1, 1, 1) - else: - dt = dt.view(-1, 1, 1) - - # Scale to satisfy the fluctuation-dissipation theorem - # The standard deviation should be sqrt(2*alpha*kB*T*dt) - scaling_factor = torch.sqrt(2.0 * cell_alpha * kT * dt) - - return scaling_factor * noise - - def compute_cell_force( - state: NPTLangevinState, - external_pressure: torch.Tensor, - kT: torch.Tensor, - ) -> torch.Tensor: - """Compute forces on the cell for NPT dynamics. - - This function calculates the forces acting on the simulation cell - based on the difference between internal stress and external pressure, - plus a kinetic contribution. These forces drive the volume changes - needed to maintain constant pressure. - - Args: - state (NPTLangevinState): Current NPT state - external_pressure (torch.Tensor): Target external pressure, either scalar or - tensor with shape [n_systems, n_dimensions, n_dimensions] - kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_systems] - - Returns: - torch.Tensor: Force acting on the cell [n_systems, n_dim, n_dim] - """ - return _compute_cell_force(state, external_pressure, kT) - - def cell_position_step( - state: NPTLangevinState, - dt: torch.Tensor, - pressure_force: torch.Tensor, - kT: torch.Tensor = kT, - cell_alpha: torch.Tensor = cell_alpha, - ) -> NPTLangevinState: - """Update the cell position in NPT dynamics. - - This function updates the cell position (effectively the volume) in NPT dynamics - using the current cell velocities, pressure forces, and thermal noise. It - implements the position update part of the Langevin barostat algorithm. - - Args: - state (NPTLangevinState): Current NPT state - dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] - pressure_force (torch.Tensor): Pressure force for barostat - [n_systems, n_dim, n_dim] - kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_systems] - cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - with shape [n_systems] - - Returns: - NPTLangevinState: Updated state with new cell positions - """ - # Calculate effective mass term - Q_2 = 2 * state.cell_masses.view(-1, 1, 1) # shape: (n_systems, 1, 1) - - # Ensure parameters have batch dimension - if dt.ndim == 0: - dt = dt.expand(state.n_systems) - if cell_alpha.ndim == 0: - cell_alpha = cell_alpha.expand(state.n_systems) - - # Reshape for broadcasting - dt_expanded = dt.view(-1, 1, 1) - cell_alpha_expanded = cell_alpha.view(-1, 1, 1) - - # Calculate damping factor for cell position update - cell_b = 1 / (1 + ((cell_alpha_expanded * dt_expanded) / Q_2)) - - # Deterministic velocity contribution - c_1 = cell_b * dt_expanded * state.cell_velocities - - # Force contribution - c_2 = cell_b * dt_expanded * dt_expanded * pressure_force / Q_2 - - # Random noise contribution (thermal fluctuations) - c_3 = ( - cell_b - * dt_expanded - * cell_beta(state=state, cell_alpha=cell_alpha, kT=kT, dt=dt) - / Q_2 - ) - - # Update cell positions with all contributions - state.cell_positions = state.cell_positions + c_1 + c_2 + c_3 - return state - - def cell_velocity_step( - state: NPTLangevinState, - F_p_n: torch.Tensor, - dt: torch.Tensor, - pressure_force: torch.Tensor, - cell_alpha: torch.Tensor, - kT: torch.Tensor, - ) -> NPTLangevinState: - """Update the cell velocities in NPT dynamics. - - This function updates the cell velocities using a Langevin-type integrator, - accounting for both deterministic forces from pressure differences and - stochastic thermal noise. It implements the velocity update part of the - Langevin barostat algorithm. - - Args: - state (NPTLangevinState): Current NPT state - F_p_n (torch.Tensor): Initial pressure force with shape - [n_systems, n_dimensions, n_dimensions] - dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] - pressure_force (torch.Tensor): Final pressure force - shape [n_systems, n_dim, n_dim] - cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - shape [n_systems] - kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_systems] - - Returns: - NPTLangevinState: Updated state with new cell velocities - """ - # Ensure parameters have batch dimension - if dt.ndim == 0: - dt = dt.expand(state.n_systems) - if cell_alpha.ndim == 0: - cell_alpha = cell_alpha.expand(state.n_systems) - if kT.ndim == 0: - kT = kT.expand(state.n_systems) - - # Reshape for broadcasting - need to maintain 3x3 dimensions - dt_expanded = dt.view(-1, 1, 1) # shape: (n_systems, 1, 1) - cell_alpha_expanded = cell_alpha.view(-1, 1, 1) # shape: (n_systems, 1, 1) - - # Calculate cell masses per system - reshape to match 3x3 cell matrices - cell_masses_expanded = state.cell_masses.view( - -1, 1, 1 - ) # shape: (n_systems, 1, 1) - - # These factors come from the Langevin integration scheme - a = (1 - (cell_alpha_expanded * dt_expanded) / cell_masses_expanded) / ( - 1 + (cell_alpha_expanded * dt_expanded) / cell_masses_expanded - ) - b = 1 / (1 + (cell_alpha_expanded * dt_expanded) / cell_masses_expanded) - - # Calculate the three terms for velocity update - # a will broadcast from (n_systems, 1, 1) to (n_systems, 3, 3) - c_1 = a * state.cell_velocities # Damped old velocity - - # Force contribution (average of initial and final forces) - c_2 = dt_expanded * ((a * F_p_n) + pressure_force) / (2 * cell_masses_expanded) - - # Generate system-specific cell noise with correct shape (n_systems, 3, 3) - cell_noise = torch.randn_like(state.cell_velocities) - - # Calculate thermal noise amplitude - noise_prefactor = torch.sqrt( - 2 * cell_alpha_expanded * kT.view(-1, 1, 1) * dt_expanded - ) - noise_term = noise_prefactor * cell_noise / torch.sqrt(cell_masses_expanded) - - # Random noise contribution - c_3 = b * noise_term - - # Update velocities with all contributions - state.cell_velocities = c_1 + c_2 + c_3 - return state - - def langevin_position_step( - state: NPTLangevinState, - L_n: torch.Tensor, # This should be shape (n_systems,) - dt: torch.Tensor, - kT: torch.Tensor, - ) -> NPTLangevinState: - """Update the particle positions in NPT dynamics. - - This function updates particle positions accounting for both the changing - cell dimensions and the particle velocities/forces. It handles the scaling - of positions due to volume changes as well as the normal position updates - from velocities. - - Args: - state (NPTLangevinState): Current NPT state - L_n (torch.Tensor): Previous cell length scale with shape [n_systems] - dt: Integration timestep, either scalar or with shape [n_systems] - kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_systems] - - Returns: - NPTLangevinState: Updated state with new positions - """ - # Calculate effective mass term by system - # Map masses to have batch dimension - M_2 = 2 * state.masses.unsqueeze(-1) # shape: (n_atoms, 1) - - # Calculate new cell length scale (cube root of volume for isotropic scaling) - L_n_new = torch.pow( - state.cell_positions.reshape(state.n_systems, -1)[:, 0], 1 / 3 - ) # shape: (n_systems,) - - # Map system-specific L_n and L_n_new to atom-level using system indices - # Make sure L_n is the right shape (n_systems,) before indexing - if L_n.ndim != 1 or L_n.shape[0] != state.n_systems: - # If L_n has wrong shape, calculate it again to ensure correct shape - L_n = torch.pow( - state.cell_positions.reshape(state.n_systems, -1)[:, 0], 1 / 3 - ) - - # Map system-specific values to atoms using system indices - L_n_atoms = L_n[state.system_idx] # shape: (n_atoms,) - L_n_new_atoms = L_n_new[state.system_idx] # shape: (n_atoms,) - - # Calculate damping factor - alpha_atoms = alpha - if alpha.ndim > 0: - alpha_atoms = alpha[state.system_idx] - dt_atoms = dt - if dt.ndim > 0: - dt_atoms = dt[state.system_idx] - - b = 1 / (1 + ((alpha_atoms * dt_atoms) / M_2)) - - # Scale positions due to cell volume change - c_1 = (L_n_new_atoms / L_n_atoms).unsqueeze(-1) * state.positions - - # Time step factor with average length scale - c_2 = ( - (2 * L_n_new_atoms / (L_n_new_atoms + L_n_atoms)).unsqueeze(-1) - * b - * dt_atoms.unsqueeze(-1) - ) - - # Generate atom-specific noise - noise = torch.randn_like(state.velocities) - batch_kT = kT - if kT.ndim == 0: - batch_kT = kT.expand(state.n_systems) - atom_kT = batch_kT[state.system_idx] - - # Calculate noise prefactor according to fluctuation-dissipation theorem - noise_prefactor = torch.sqrt(2 * alpha_atoms * atom_kT * dt_atoms) - noise_term = noise_prefactor.unsqueeze(-1) * noise - - # Velocity and force contributions with random noise - c_3 = ( - state.velocities - + dt_atoms.unsqueeze(-1) * state.forces / M_2 - + noise_term / M_2 - ) - - # Update positions with all contributions - state.positions = c_1 + c_2 * c_3 - - # Apply periodic boundary conditions if needed - if state.pbc: - state.positions = ts.transforms.pbc_wrap_batched( - state.positions, state.cell, state.system_idx - ) - - return state - - def langevin_velocity_step( - state: NPTLangevinState, - forces: torch.Tensor, - dt: torch.Tensor, - kT: torch.Tensor, - ) -> NPTLangevinState: - """Update the particle velocities in NPT dynamics. - - This function updates particle velocities using a Langevin-type integrator, - accounting for both deterministic forces and stochastic thermal noise. - It implements the velocity update part of the Langevin thermostat algorithm. - - Args: - state (NPTLangevinState): Current NPT state - forces: Forces on particles - dt: Integration timestep, either scalar or with shape [n_systems] - kT: Target temperature in energy units, either scalar or - with shape [n_systems] - - Returns: - NPTLangevinState: Updated state with new velocities - """ - # Calculate denominator for update equations - M_2 = 2 * state.masses.unsqueeze(-1) # shape: (n_atoms, 1) - - # Map batch parameters to atom level - alpha_atoms = alpha - if alpha.ndim > 0: - alpha_atoms = alpha[state.system_idx] - dt_atoms = dt - if dt.ndim > 0: - dt_atoms = dt[state.system_idx] - - # Calculate damping factors for Langevin integration - a = (1 - (alpha_atoms * dt_atoms) / M_2) / (1 + (alpha_atoms * dt_atoms) / M_2) - b = 1 / (1 + (alpha_atoms * dt_atoms) / M_2) - - # Velocity contribution with damping - c_1 = a * state.velocities - - # Force contribution (average of initial and final forces) - c_2 = dt_atoms.unsqueeze(-1) * ((a * forces) + state.forces) / M_2 - - # Generate atom-specific noise - noise = torch.randn_like(state.velocities) - batch_kT = kT - if kT.ndim == 0: - batch_kT = kT.expand(state.n_systems) - atom_kT = batch_kT[state.system_idx] - - # Calculate noise prefactor according to fluctuation-dissipation theorem - noise_prefactor = torch.sqrt(2 * alpha_atoms * atom_kT * dt_atoms) - noise_term = noise_prefactor.unsqueeze(-1) * noise - - # Random noise contribution - c_3 = b * noise_term / state.masses.unsqueeze(-1) - - # Update velocities with all contributions - state.velocities = c_1 + c_2 + c_3 - return state - - def npt_init( - state: SimState | StateDict, - kT: torch.Tensor = kT, - seed: int | None = seed, - ) -> NPTLangevinState: - """Initialize an NPT Langevin state from input data. - - This function creates the initial state for NPT Langevin dynamics, - setting up all necessary variables including particle velocities, - cell parameters, and barostat variables. It computes initial forces - and stress using the provided model. - - Args: - state (SimState | StateDict): Either a SimState object or a dictionary - containing positions, masses, cell, pbc - kT (torch.Tensor): Temperature in energy units for initializing momenta - seed (int, optional): Random seed for reproducibility - - Returns: - NPTLangevinState: Initialized state for NPT Langevin integration containing - all required attributes for particle and cell dynamics - """ - if not isinstance(state, SimState): - state = SimState(**state) - # Get model output to initialize forces and stress - model_output = model(state) + if not isinstance(state, SimState): + state = SimState(**state) - # Initialize momenta if not provided - momenta = getattr( - state, - "momenta", - calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), - ) + # Get model output to initialize forces and stress + model_output = model(state) - # Initialize cell parameters - reference_cell = state.cell.clone() + # Initialize momenta if not provided + momenta = getattr( + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + ) - # Calculate initial cell_positions (volume) - cell_positions = ( - torch.linalg.det(state.cell).unsqueeze(-1).unsqueeze(-1) - ) # shape: (n_systems, 1, 1) - - # Initialize cell velocities to zero - cell_velocities = torch.zeros((state.n_systems, 3, 3), device=device, dtype=dtype) - - # Calculate cell masses based on system size and temperature - # This follows standard NPT barostat mass scaling - n_atoms_per_system = torch.bincount(state.system_idx) - batch_kT = ( - kT.expand(state.n_systems) - if isinstance(kT, torch.Tensor) and kT.ndim == 0 - else kT - ) - cell_masses = (n_atoms_per_system + 1) * batch_kT * b_tau * b_tau - - # Create the initial state - return NPTLangevinState( - positions=state.positions, - velocities=momenta / state.masses.unsqueeze(-1), - energy=model_output["energy"], - forces=model_output["forces"], - stress=model_output["stress"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - atomic_numbers=state.atomic_numbers, - reference_cell=reference_cell, - cell_positions=cell_positions, - cell_velocities=cell_velocities, - cell_masses=cell_masses, - ) + # Initialize cell parameters + reference_cell = state.cell.clone() - def npt_update( - state: NPTLangevinState, - dt: torch.Tensor = dt, - kT: torch.Tensor = kT, - external_pressure: torch.Tensor = external_pressure, - alpha: torch.Tensor = alpha, - cell_alpha: torch.Tensor = cell_alpha, - ) -> NPTLangevinState: - """Perform one complete NPT Langevin dynamics integration step. - - This function implements a modified integration scheme for NPT dynamics, - handling both atomic and cell updates with Langevin thermostats to maintain - constant temperature and pressure. The integration scheme couples particle - motion with cell volume fluctuations. + # Calculate initial cell_positions (volume) + cell_positions = ( + torch.linalg.det(state.cell).unsqueeze(-1).unsqueeze(-1) + ) # shape: (n_systems, 1, 1) - Args: - state (NPTLangevinState): Current NPT state with particle and cell variables - dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] - kT (torch.Tensor): Target temperature in energy units, either scalar or - shape [n_systems] - external_pressure (torch.Tensor): Target external pressure, either scalar or - tensor with shape [n_systems, n_dim, n_dim] - alpha (torch.Tensor): Position friction coefficient, either scalar or - shape [n_systems] - cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - shape [n_systems] + # Initialize cell velocities to zero + cell_velocities = torch.zeros((state.n_systems, 3, 3), device=device, dtype=dtype) - Returns: - NPTLangevinState: Updated NPT state after one timestep with new positions, - velocities, cell parameters, forces, energy, and stress - """ - # Convert any scalar parameters to tensors with batch dimension if needed - if isinstance(alpha, float): - alpha = torch.tensor(alpha, device=device, dtype=dtype) - if isinstance(kT, float): - kT = torch.tensor(kT, device=device, dtype=dtype) - if isinstance(cell_alpha, float): - cell_alpha = torch.tensor(cell_alpha, device=device, dtype=dtype) - if isinstance(dt, float): - dt = torch.tensor(dt, device=device, dtype=dtype) - - # Make sure parameters have batch dimension if they're scalars - batch_kT = kT.expand(state.n_systems) if kT.ndim == 0 else kT - - # Update barostat mass based on current temperature - # This ensures proper coupling between system and barostat - n_atoms_per_system = torch.bincount(state.system_idx) - state.cell_masses = (n_atoms_per_system + 1) * batch_kT * b_tau * b_tau - - # Compute model output for current state - model_output = model(state) - state.forces = model_output["forces"] - state.stress = model_output["stress"] - - # Store initial values for integration - forces = state.forces - F_p_n = compute_cell_force( - state=state, external_pressure=external_pressure, kT=kT - ) - L_n = torch.pow( - state.cell_positions.reshape(state.n_systems, -1)[:, 0], 1 / 3 - ) # shape: (n_systems,) - - # Step 1: Update cell position - state = cell_position_step(state=state, dt=dt, pressure_force=F_p_n, kT=kT) - - # Update cell (currently only isotropic fluctuations) - dim = state.positions.shape[1] # Usually 3 for 3D - # V_0 and V are shape: (n_systems,) - V_0 = torch.linalg.det(state.reference_cell) - V = state.cell_positions.reshape(state.n_systems, -1)[:, 0] + # Calculate cell masses based on system size and temperature + # This follows standard NPT barostat mass scaling + n_atoms_per_system = torch.bincount(state.system_idx) + batch_kT = ( + kT.expand(state.n_systems) + if isinstance(kT, torch.Tensor) and kT.ndim == 0 + else kT + ) + cell_masses = (n_atoms_per_system + 1) * batch_kT * b_tau * b_tau + + # Create the initial state + return NPTLangevinState( + positions=state.positions, + velocities=momenta / state.masses.unsqueeze(-1), + energy=model_output["energy"], + forces=model_output["forces"], + stress=model_output["stress"], + masses=state.masses, + cell=state.cell, + pbc=state.pbc, + system_idx=state.system_idx, + atomic_numbers=state.atomic_numbers, + reference_cell=reference_cell, + cell_positions=cell_positions, + cell_velocities=cell_velocities, + cell_masses=cell_masses, + ) - # Scale cell uniformly in all dimensions - scaling = (V / V_0) ** (1.0 / dim) # shape: (n_systems,) - # Apply scaling to reference cell to get new cell - new_cell = torch.zeros_like(state.cell) - for b in range(state.n_systems): - new_cell[b] = scaling[b] * state.reference_cell[b] +def npt_langevin_update( + model: ModelInterface, + state: NPTLangevinState, + *, + dt: torch.Tensor, + kT: torch.Tensor, + external_pressure: torch.Tensor, + alpha: torch.Tensor, + cell_alpha: torch.Tensor, + b_tau: torch.Tensor, +) -> NPTLangevinState: + """Perform one complete NPT Langevin dynamics integration step. - state.cell = new_cell + This function implements a modified integration scheme for NPT dynamics, + handling both atomic and cell updates with Langevin thermostats to maintain + constant temperature and pressure. The integration scheme couples particle + motion with cell volume fluctuations. - # Step 2: Update particle positions - state = langevin_position_step(state=state, L_n=L_n, dt=dt, kT=kT) + Args: + model (ModelInterface): Neural network model that computes energies, forces, + and stress. Must return a dict with 'energy', 'forces', and 'stress' keys. + state (NPTLangevinState): Current NPT state with particle and cell variables + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] + kT (torch.Tensor): Target temperature in energy units, either scalar or + shape [n_systems] + external_pressure (torch.Tensor): Target external pressure, either scalar or + tensor with shape [n_systems, n_dim, n_dim] + alpha (torch.Tensor): Position friction coefficient, either scalar or + shape [n_systems] + cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or + shape [n_systems] + b_tau (torch.Tensor): Barostat time constant, either scalar or shape [n_systems] - # Recompute model output after position updates - model_output = model(state) - state.energy = model_output["energy"] - state.forces = model_output["forces"] - state.stress = model_output["stress"] + Returns: + NPTLangevinState: Updated NPT state after one timestep with new positions, + velocities, cell parameters, forces, energy, and stress + """ + device, dtype = model.device, model.dtype - # Compute updated pressure force - F_p_n_new = compute_cell_force( - state=state, external_pressure=external_pressure, kT=kT - ) + # Convert any scalar parameters to tensors with batch dimension if needed + if isinstance(alpha, float): + alpha = torch.tensor(alpha, device=device, dtype=dtype) + if isinstance(kT, float): + kT = torch.tensor(kT, device=device, dtype=dtype) + if isinstance(cell_alpha, float): + cell_alpha = torch.tensor(cell_alpha, device=device, dtype=dtype) + if isinstance(dt, float): + dt = torch.tensor(dt, device=device, dtype=dtype) - # Step 3: Update cell velocities - state = cell_velocity_step( - state=state, - F_p_n=F_p_n, - dt=dt, - pressure_force=F_p_n_new, - cell_alpha=cell_alpha, - kT=kT, - ) + # Make sure parameters have batch dimension if they're scalars + batch_kT = kT.expand(state.n_systems) if kT.ndim == 0 else kT - # Step 4: Update particle velocities - state = langevin_velocity_step(state=state, forces=forces, dt=dt, kT=kT) + # Update barostat mass based on current temperature + # This ensures proper coupling between system and barostat + n_atoms_per_system = torch.bincount(state.system_idx) + state.cell_masses = (n_atoms_per_system + 1) * batch_kT * b_tau * b_tau + + # Compute model output for current state + model_output = model(state) + state.forces = model_output["forces"] + state.stress = model_output["stress"] + + # Store initial values for integration + forces = state.forces + F_p_n = _compute_cell_force(state=state, external_pressure=external_pressure, kT=kT) + L_n = torch.pow( + state.cell_positions.reshape(state.n_systems, -1)[:, 0], 1 / 3 + ) # shape: (n_systems,) + + # Step 1: Update cell position + state = _npt_langevin_cell_position_step(state, dt, F_p_n, kT, cell_alpha) + + # Update cell (currently only isotropic fluctuations) + dim = state.positions.shape[1] # Usually 3 for 3D + # V_0 and V are shape: (n_systems,) + V_0 = torch.linalg.det(state.reference_cell) + V = state.cell_positions.reshape(state.n_systems, -1)[:, 0] + + # Scale cell uniformly in all dimensions + scaling = (V / V_0) ** (1.0 / dim) # shape: (n_systems,) + + # Apply scaling to reference cell to get new cell + new_cell = torch.zeros_like(state.cell) + for sys_idx in range(state.n_systems): + new_cell[sys_idx] = scaling[sys_idx] * state.reference_cell[sys_idx] + + state.cell = new_cell + + # Step 2: Update particle positions + state = _npt_langevin_position_step(state, L_n, dt, kT, alpha) + + # Recompute model output after position updates + model_output = model(state) + state.energy = model_output["energy"] + state.forces = model_output["forces"] + state.stress = model_output["stress"] + + # Compute updated pressure force + F_p_n_new = _compute_cell_force( + state=state, external_pressure=external_pressure, kT=kT + ) - return state # noqa: RET504 + # Step 3: Update cell velocities + state = _npt_langevin_cell_velocity_step(state, F_p_n, dt, F_p_n_new, cell_alpha, kT) - return npt_init, npt_update + # Step 4: Update particle velocities + return _npt_langevin_velocity_step(state, forces, dt, kT, alpha) @dataclass @@ -928,674 +863,663 @@ def current_cell(self) -> torch.Tensor: return scale * self.reference_cell -def npt_nose_hoover( # noqa: C901, PLR0915 - *, - model: ModelInterface, - kT: torch.Tensor, - external_pressure: torch.Tensor, - dt: torch.Tensor, - chain_length: int = 3, - chain_steps: int = 2, - sy_steps: int = 3, -) -> tuple[ - Callable[[SimState | StateDict], NPTNoseHooverState], - Callable[[NPTNoseHooverState, torch.Tensor], NPTNoseHooverState], -]: - """Create an NPT simulation with Nose-Hoover chain thermostats. +def _npt_nose_hoover_cell_info( + state: NPTNoseHooverState, +) -> tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]: + """Gets the current volume and a function to compute the cell from volume. - This function returns initialization and update functions for NPT molecular dynamics - with Nose-Hoover chain thermostats for temperature and pressure control. + This helper function computes the current system volume and returns a function + that can compute the simulation cell for any given volume. This is useful for + integration algorithms that need to update the cell based on volume changes. Args: - model (ModelInterface): Model to compute forces and energies - kT (torch.Tensor): Target temperature in energy units - external_pressure (torch.Tensor): Target external pressure - dt (torch.Tensor): Integration timestep - chain_length (int, optional): Length of Nose-Hoover chains. Defaults to 3. - chain_steps (int, optional): Chain integration substeps. Defaults to 2. - sy_steps (int, optional): Suzuki-Yoshida integration order. Defaults to 3. + state (NPTNoseHooverState): Current state of the NPT system Returns: tuple: - - Callable[[SimState | StateDict], NPTNoseHooverState]: Initialization - function - - Callable[[NPTNoseHooverState, torch.Tensor], NPTNoseHooverState]: Update - function + - torch.Tensor: Current system volume with shape [n_systems] + - callable: Function that takes a volume tensor [n_systems] and returns + the corresponding cell matrix [n_systems, n_dimensions, n_dimensions] Notes: - - Uses Nose-Hoover chains for both temperature and pressure control - - Implements symplectic integration with Suzuki-Yoshida decomposition - - Cell dynamics use logarithmic coordinates for volume updates - - Conserves extended system Hamiltonian + - Uses logarithmic cell coordinate parameterization + - Volume changes are measured relative to reference cell + - Cell scaling preserves shape while changing volume + - Supports batched operations """ - device, dtype = model.device, model.dtype - - def _npt_cell_info( - state: NPTNoseHooverState, - ) -> tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]: - """Gets the current volume and a function to compute the cell from volume. + dim = state.positions.shape[1] + ref = state.reference_cell # [n_systems, dim, dim] + V_0 = torch.det(ref) # [n_systems] - Reference volume + V = V_0 * torch.exp(dim * state.cell_position) # [n_systems] - Current volume - This helper function computes the current system volume and returns a function - that can compute the simulation cell for any given volume. This is useful for - integration algorithms that need to update the cell based on volume changes. + def volume_to_cell(V: torch.Tensor) -> torch.Tensor: + """Compute cell matrix for given volumes. Args: - state (NPTNoseHooverState): Current state of the NPT system + V (torch.Tensor): Volumes with shape [n_systems] Returns: - tuple: - - torch.Tensor: Current system volume with shape [n_systems] - - callable: Function that takes a volume tensor [n_systems] and returns - the corresponding cell matrix [n_systems, n_dimensions, n_dimensions] - - Notes: - - Uses logarithmic cell coordinate parameterization - - Volume changes are measured relative to reference cell - - Cell scaling preserves shape while changing volume - - Supports batched operations + torch.Tensor: Cell matrices with shape [n_systems, dim, dim] """ - dim = state.positions.shape[1] - ref = state.reference_cell # [n_systems, dim, dim] - V_0 = torch.det(ref) # [n_systems] - Reference volume - V = V_0 * torch.exp(dim * state.cell_position) # [n_systems] - Current volume + scale = torch.pow(V / V_0, 1.0 / dim) # [n_systems] + # Expand scale to [n_systems, 1, 1] for broadcasting + scale = scale.unsqueeze(-1).unsqueeze(-1) + return scale * ref - def volume_to_cell(V: torch.Tensor) -> torch.Tensor: - """Compute cell matrix for given volumes. + return V, volume_to_cell - Args: - V (torch.Tensor): Volumes with shape [n_systems] - Returns: - torch.Tensor: Cell matrices with shape [n_systems, dim, dim] - """ - scale = (V / V_0) ** (1.0 / dim) # [n_systems] - # Expand scale to [n_systems, 1, 1] for broadcasting - scale = scale.unsqueeze(-1).unsqueeze(-1) - return scale * ref +def _npt_nose_hoover_update_cell_mass( + state: NPTNoseHooverState, kT: torch.Tensor, device: torch.device, dtype: torch.dtype +) -> NPTNoseHooverState: + """Update the cell mass parameter in an NPT simulation. - return V, volume_to_cell + This function updates the mass parameter associated with cell volume fluctuations + based on the current system size and target temperature. The cell mass controls + how quickly the volume can change and is chosen to maintain stable pressure + control. - def update_cell_mass( - state: NPTNoseHooverState, kT: torch.Tensor - ) -> NPTNoseHooverState: - """Update the cell mass parameter in an NPT simulation. + Args: + state (NPTNoseHooverState): Current state of the NPT system + kT (torch.Tensor): Target temperature in energy units, either scalar or + shape [n_systems] + device (torch.device): Device for tensor operations + dtype (torch.dtype): Data type for tensor operations - This function updates the mass parameter associated with cell volume fluctuations - based on the current system size and target temperature. The cell mass controls - how quickly the volume can change and is chosen to maintain stable pressure - control. + Returns: + NPTNoseHooverState: Updated state with new cell mass - Args: - state (NPTNoseHooverState): Current state of the NPT system - kT (torch.Tensor): Target temperature in energy units, either scalar or - shape [n_systems] + Notes: + - Cell mass scales with system size (N+1) and dimensionality + - Larger cell mass gives slower but more stable volume fluctuations + - Mass depends on barostat relaxation time (tau) + - Supports batched operations + """ + _n_particles, dim = state.positions.shape - Returns: - NPTNoseHooverState: Updated state with new cell mass + # Convert kT to tensor if it's not already one + if not isinstance(kT, torch.Tensor): + kT = torch.tensor(kT, device=device, dtype=dtype) - Notes: - - Cell mass scales with system size (N+1) and dimensionality - - Larger cell mass gives slower but more stable volume fluctuations - - Mass depends on barostat relaxation time (tau) - - Supports batched operations - """ - n_particles, dim = state.positions.shape + # Handle both scalar and batched kT + kT_system = kT.expand(state.n_systems) if kT.ndim == 0 else kT - # Convert kT to tensor if it's not already one - if not isinstance(kT, torch.Tensor): - kT = torch.tensor(kT, device=device, dtype=dtype) + # Calculate cell masses for each system + n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems) + cell_mass = ( + dim * (n_atoms_per_system + 1) * kT_system * torch.square(state.barostat.tau) + ) - # Handle both scalar and batched kT - kT_system = kT.expand(state.n_systems) if kT.ndim == 0 else kT + # Update state with new cell masses + state.cell_mass = cell_mass.to(device=device, dtype=dtype) + return state - # Calculate cell masses for each system - n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems) - cell_mass = dim * (n_atoms_per_system + 1) * kT_system * state.barostat.tau**2 - # Update state with new cell masses - state.cell_mass = cell_mass.to(device=device, dtype=dtype) - return state +def _npt_nose_hoover_sinhx_x(x: torch.Tensor) -> torch.Tensor: + """Compute sinh(x)/x using Taylor series expansion near x=0. - def sinhx_x(x: torch.Tensor) -> torch.Tensor: - """Compute sinh(x)/x using Taylor series expansion near x=0. + This function implements a Taylor series approximation of sinh(x)/x that is + accurate near x=0. The series expansion is: + sinh(x)/x = 1 + xΒ²/6 + x⁴/120 + x⁢/5040 + x⁸/362880 + x¹⁰/39916800 - This function implements a Taylor series approximation of sinh(x)/x that is - accurate near x=0. The series expansion is: - sinh(x)/x = 1 + xΒ²/6 + x⁴/120 + x⁢/5040 + x⁸/362880 + x¹⁰/39916800 + Args: + x (torch.Tensor): Input tensor - Args: - x (torch.Tensor): Input tensor + Returns: + torch.Tensor: Approximation of sinh(x)/x - Returns: - torch.Tensor: Approximation of sinh(x)/x - - Notes: - - Uses 6 terms of Taylor series for good accuracy near x=0 - - Relative error < 1e-12 for |x| < 0.5 - - More efficient than direct sinh(x)/x computation for small x - - Avoids division by zero at x=0 - - Example: - >>> x = torch.tensor([0.0, 0.1, 0.2]) - >>> y = sinhx_x(x) - >>> print(y) # tensor([1, 1.0017, 1.0067]) - """ - return ( - 1 + x**2 / 6 + x**4 / 120 + x**6 / 5040 + x**8 / 362_880 + x**10 / 39_916_800 - ) + Notes: + - Uses 6 terms of Taylor series for good accuracy near x=0 + - Relative error < 1e-12 for |x| < 0.5 + - More efficient than direct sinh(x)/x computation for small x + - Avoids division by zero at x=0 + + Example: + >>> x = torch.tensor([0.0, 0.1, 0.2]) + >>> y = sinhx_x(x) + >>> print(y) # tensor([1, 1.0017, 1.0067]) + """ + return ( + 1 + + torch.pow(x, 2) / 6 + + torch.pow(x, 4) / 120 + + torch.pow(x, 6) / 5040 + + torch.pow(x, 8) / 362_880 + + torch.pow(x, 10) / 39_916_800 + ) - def exp_iL1( # noqa: N802 - state: NPTNoseHooverState, - velocities: torch.Tensor, - cell_velocity: torch.Tensor, - dt: torch.Tensor, - ) -> torch.Tensor: - """Apply the exp(iL1) operator for NPT dynamics position updates. - This function implements the position update operator for NPT dynamics using - a symplectic integration scheme. It accounts for both particle motion and - cell scaling effects through the cell velocity, with optional periodic boundary - conditions. +def _npt_nose_hoover_exp_iL1( # noqa: N802 + state: NPTNoseHooverState, + velocities: torch.Tensor, + cell_velocity: torch.Tensor, + dt: torch.Tensor, +) -> torch.Tensor: + """Apply the exp(iL1) operator for NPT dynamics position updates. - The update follows the form: - R_new = R + (exp(x) - 1)R + dt*V*exp(x/2)*sinh(x/2)/(x/2) - where x = V_b * dt is the cell velocity term + This function implements the position update operator for NPT dynamics using + a symplectic integration scheme. It accounts for both particle motion and + cell scaling effects through the cell velocity, with optional periodic boundary + conditions. - Args: - state (NPTNoseHooverState): Current simulation state - velocities (torch.Tensor): Particle velocities [n_particles, n_dimensions] - cell_velocity (torch.Tensor): Cell velocity with shape [n_systems] - dt (torch.Tensor): Integration timestep + The update follows the form: + R_new = R + (exp(x) - 1)R + dt*V*exp(x/2)*sinh(x/2)/(x/2) + where x = V_b * dt is the cell velocity term - Returns: - torch.Tensor: Updated particle positions with optional periodic wrapping - - Notes: - - Uses Taylor series for sinh(x)/x near x=0 for numerical stability - - Properly handles cell scaling through cell_velocity - - Maintains time-reversibility of the integration scheme - - Applies periodic boundary conditions if state.pbc is True - - Supports batched operations with proper atom-to-system mapping - """ - # Map system-level cell velocities to atom level using system indices - cell_velocity_atoms = cell_velocity[state.system_idx] # [n_atoms] + Args: + state (NPTNoseHooverState): Current simulation state + velocities (torch.Tensor): Particle velocities [n_particles, n_dimensions] + cell_velocity (torch.Tensor): Cell velocity with shape [n_systems] + dt (torch.Tensor): Integration timestep + + Returns: + torch.Tensor: Updated particle positions with optional periodic wrapping - # Compute cell velocity terms per atom - x = cell_velocity_atoms * dt # [n_atoms] - x_2 = x / 2 # [n_atoms] + Notes: + - Uses Taylor series for sinh(x)/x near x=0 for numerical stability + - Properly handles cell scaling through cell_velocity + - Maintains time-reversibility of the integration scheme + - Applies periodic boundary conditions if state.pbc is True + - Supports batched operations with proper atom-to-system mapping + """ + # Map system-level cell velocities to atom level using system indices + cell_velocity_atoms = cell_velocity[state.system_idx] # [n_atoms] - # Compute sinh(x/2)/(x/2) using stable Taylor series - sinh_term = sinhx_x(x_2) # [n_atoms] + # Compute cell velocity terms per atom + x = cell_velocity_atoms * dt # [n_atoms] + x_2 = x / 2 # [n_atoms] - # Expand dimensions for broadcasting with positions [n_atoms, 3] - x_expanded = x.unsqueeze(-1) # [n_atoms, 1] - x_2_expanded = x_2.unsqueeze(-1) # [n_atoms, 1] - sinh_expanded = sinh_term.unsqueeze(-1) # [n_atoms, 1] + # Compute sinh(x/2)/(x/2) using stable Taylor series + sinh_term = _npt_nose_hoover_sinhx_x(x_2) # [n_atoms] - # Compute position updates - new_positions = ( - state.positions * (torch.exp(x_expanded) - 1) - + dt * velocities * torch.exp(x_2_expanded) * sinh_expanded + # Expand dimensions for broadcasting with positions [n_atoms, 3] + x_expanded = x.unsqueeze(-1) # [n_atoms, 1] + x_2_expanded = x_2.unsqueeze(-1) # [n_atoms, 1] + sinh_expanded = sinh_term.unsqueeze(-1) # [n_atoms, 1] + + # Compute position updates + new_positions = ( + state.positions * (torch.exp(x_expanded) - 1) + + dt * velocities * torch.exp(x_2_expanded) * sinh_expanded + ) + new_positions = state.positions + new_positions + + # Apply periodic boundary conditions if needed + if state.pbc: + return ts.transforms.pbc_wrap_batched( + new_positions, state.current_cell, state.system_idx ) - new_positions = state.positions + new_positions + return new_positions + + +def _npt_nose_hoover_exp_iL2( # noqa: N802 + state: NPTNoseHooverState, + alpha: torch.Tensor, + momenta: torch.Tensor, + forces: torch.Tensor, + cell_velocity: torch.Tensor, + dt_2: torch.Tensor, +) -> torch.Tensor: + """Apply the exp(iL2) operator for NPT dynamics momentum updates. + + This function implements the momentum update operator for NPT dynamics using + a symplectic integration scheme. It accounts for both force terms and + cell velocity scaling effects. + + The update follows the form: + P_new = P*exp(-x) + dt/2 * F * exp(-x/2) * sinh(x/2)/(x/2) + where x = alpha * V_b * dt/2 + + Args: + state (NPTNoseHooverState): Current simulation state for batch mapping + alpha (torch.Tensor): Cell scaling parameter + momenta (torch.Tensor): Current particle momenta [n_particles, n_dimensions] + forces (torch.Tensor): Forces on particles [n_particles, n_dimensions] + cell_velocity (torch.Tensor): Cell velocity with shape [n_systems] + dt_2 (torch.Tensor): Half timestep (dt/2) + + Returns: + torch.Tensor: Updated particle momenta + + Notes: + - Uses Taylor series for sinh(x)/x near x=0 for numerical stability + - Properly handles cell velocity scaling effects + - Maintains time-reversibility of the integration scheme + - Part of the NPT integration algorithm + - Supports batched operations with proper atom-to-system mapping + """ + # Map system-level cell velocities to atom level using system indices + cell_velocity_atoms = cell_velocity[state.system_idx] # [n_atoms] + + # Compute scaling terms per atom + alpha_atoms = alpha[state.system_idx] # [n_atoms] + x = alpha_atoms * cell_velocity_atoms * dt_2 # [n_atoms] + x_2 = x / 2 # [n_atoms] + + # Compute sinh(x/2)/(x/2) using stable Taylor series + sinh_term = _npt_nose_hoover_sinhx_x(x_2) # [n_atoms] + + # Expand dimensions for broadcasting with momenta [n_atoms, 3] + x_expanded = x.unsqueeze(-1) # [n_atoms, 1] + x_2_expanded = x_2.unsqueeze(-1) # [n_atoms, 1] + sinh_expanded = sinh_term.unsqueeze(-1) # [n_atoms, 1] + + # Update momenta with both scaling and force terms + return momenta * torch.exp(-x_expanded) + dt_2 * forces * sinh_expanded * torch.exp( + -x_2_expanded + ) + + +def _npt_nose_hoover_compute_cell_force( + alpha: torch.Tensor, + volume: torch.Tensor, + positions: torch.Tensor, + momenta: torch.Tensor, + masses: torch.Tensor, + stress: torch.Tensor, + external_pressure: torch.Tensor, + system_idx: torch.Tensor, +) -> torch.Tensor: + """Compute the force on the cell degree of freedom in NPT dynamics. + + This function calculates the force driving cell volume changes in NPT simulations. + The force includes contributions from: + 1. Kinetic energy scaling (alpha * KE) + 2. Internal stress (from stress_fn) + 3. External pressure (P*V) + + Args: + alpha (torch.Tensor): Cell scaling parameter + volume (torch.Tensor): Current system volume with shape [n_systems] + positions (torch.Tensor): Particle positions [n_particles, n_dimensions] + momenta (torch.Tensor): Particle momenta [n_particles, n_dimensions] + masses (torch.Tensor): Particle masses [n_particles] + stress (torch.Tensor): Stress tensor [n_systems, n_dimensions, n_dimensions] + external_pressure (torch.Tensor): Target external pressure + system_idx (torch.Tensor): System indices for atoms [n_particles] - # Apply periodic boundary conditions if needed - if state.pbc: - return ts.transforms.pbc_wrap_batched( - new_positions, state.current_cell, state.system_idx + Returns: + torch.Tensor: Force on the cell degree of freedom with shape [n_systems] + + Notes: + - Force drives volume changes to maintain target pressure + - Includes both kinetic and potential contributions + - Uses stress tensor for potential energy contribution + - Properly handles periodic boundary conditions + - Supports batched operations + """ + _N, dim = positions.shape + n_systems = len(volume) + + # Compute kinetic energy contribution per system + # Split momenta and masses by system + KE_per_system = torch.zeros(n_systems, device=positions.device, dtype=positions.dtype) + for sys_idx in range(n_systems): + system_mask = system_idx == sys_idx + if system_mask.any(): + system_momenta = momenta[system_mask] + system_masses = masses[system_mask] + KE_per_system[sys_idx] = ts.calc_kinetic_energy( + masses=system_masses, momenta=system_momenta ) - return new_positions - - def exp_iL2( # noqa: N802 - state: NPTNoseHooverState, - alpha: torch.Tensor, - momenta: torch.Tensor, - forces: torch.Tensor, - cell_velocity: torch.Tensor, - dt_2: torch.Tensor, - ) -> torch.Tensor: - """Apply the exp(iL2) operator for NPT dynamics momentum updates. - - This function implements the momentum update operator for NPT dynamics using - a symplectic integration scheme. It accounts for both force terms and - cell velocity scaling effects. - - The update follows the form: - P_new = P*exp(-x) + dt/2 * F * exp(-x/2) * sinh(x/2)/(x/2) - where x = alpha * V_b * dt/2 - Args: - state (NPTNoseHooverState): Current simulation state for batch mapping - alpha (torch.Tensor): Cell scaling parameter - momenta (torch.Tensor): Current particle momenta [n_particles, n_dimensions] - forces (torch.Tensor): Forces on particles [n_particles, n_dimensions] - cell_velocity (torch.Tensor): Cell velocity with shape [n_systems] - dt_2 (torch.Tensor): Half timestep (dt/2) + # Get stress tensor and compute trace per system + # Handle stress tensor with batch dimension + if stress.ndim == 3: + internal_pressure = torch.diagonal(stress, dim1=-2, dim2=-1).sum( + dim=-1 + ) # [n_systems] + else: + # Single system case - expand to batch dimension + internal_pressure = torch.trace(stress).unsqueeze(0).expand(n_systems) + + # Compute force on cell coordinate per system + # F = alpha * KE - dU/dV - P*V*d + return ( + (alpha * KE_per_system) + - (internal_pressure * volume) + - (external_pressure * volume * dim) + ) - Returns: - torch.Tensor: Updated particle momenta - - Notes: - - Uses Taylor series for sinh(x)/x near x=0 for numerical stability - - Properly handles cell velocity scaling effects - - Maintains time-reversibility of the integration scheme - - Part of the NPT integration algorithm - - Supports batched operations with proper atom-to-system mapping - """ - # Map system-level cell velocities to atom level using system indices - cell_velocity_atoms = cell_velocity[state.system_idx] # [n_atoms] - - # Compute scaling terms per atom - x = alpha * cell_velocity_atoms * dt_2 # [n_atoms] - x_2 = x / 2 # [n_atoms] - - # Compute sinh(x/2)/(x/2) using stable Taylor series - sinh_term = sinhx_x(x_2) # [n_atoms] - - # Expand dimensions for broadcasting with momenta [n_atoms, 3] - x_expanded = x.unsqueeze(-1) # [n_atoms, 1] - x_2_expanded = x_2.unsqueeze(-1) # [n_atoms, 1] - sinh_expanded = sinh_term.unsqueeze(-1) # [n_atoms, 1] - - # Update momenta with both scaling and force terms - return momenta * torch.exp( - -x_expanded - ) + dt_2 * forces * sinh_expanded * torch.exp(-x_2_expanded) - - def compute_cell_force( - alpha: torch.Tensor, - volume: torch.Tensor, - positions: torch.Tensor, - momenta: torch.Tensor, - masses: torch.Tensor, - stress: torch.Tensor, - external_pressure: torch.Tensor, - system_idx: torch.Tensor, - ) -> torch.Tensor: - """Compute the force on the cell degree of freedom in NPT dynamics. - - This function calculates the force driving cell volume changes in NPT simulations. - The force includes contributions from: - 1. Kinetic energy scaling (alpha * KE) - 2. Internal stress (from stress_fn) - 3. External pressure (P*V) - Args: - alpha (torch.Tensor): Cell scaling parameter - volume (torch.Tensor): Current system volume with shape [n_systems] - positions (torch.Tensor): Particle positions [n_particles, n_dimensions] - momenta (torch.Tensor): Particle momenta [n_particles, n_dimensions] - masses (torch.Tensor): Particle masses [n_particles] - stress (torch.Tensor): Stress tensor [n_systems, n_dimensions, n_dimensions] - external_pressure (torch.Tensor): Target external pressure - system_idx (torch.Tensor): System indices for atoms [n_particles] +def _npt_nose_hoover_inner_step( + model: ModelInterface, + state: NPTNoseHooverState, + dt: torch.Tensor, + external_pressure: torch.Tensor, +) -> NPTNoseHooverState: + """Perform one inner step of NPT integration using velocity Verlet algorithm. - Returns: - torch.Tensor: Force on the cell degree of freedom with shape [n_systems] - - Notes: - - Force drives volume changes to maintain target pressure - - Includes both kinetic and potential contributions - - Uses stress tensor for potential energy contribution - - Properly handles periodic boundary conditions - - Supports batched operations - """ - N, dim = positions.shape - n_systems = len(volume) + This function implements a single integration step for NPT dynamics, including: + 1. Cell momentum and particle momentum updates (half step) + 2. Position and cell position updates (full step) + 3. Force updates with new positions and cell + 4. Final momentum updates (half step) - # Compute kinetic energy contribution per system - # Split momenta and masses by system - KE_per_system = torch.zeros( - n_systems, device=positions.device, dtype=positions.dtype - ) - for b in range(n_systems): - system_mask = system_idx == b - if system_mask.any(): - system_momenta = momenta[system_mask] - system_masses = masses[system_mask] - KE_per_system[b] = calc_kinetic_energy( - masses=system_masses, momenta=system_momenta - ) - - # Get stress tensor and compute trace per system - # Handle stress tensor with batch dimension - if stress.ndim == 3: - internal_pressure = torch.diagonal(stress, dim1=-2, dim2=-1).sum( - dim=-1 - ) # [n_systems] - else: - # Single system case - expand to batch dimension - internal_pressure = torch.trace(stress).unsqueeze(0).expand(n_systems) - - # Compute force on cell coordinate per system - # F = alpha * KE - dU/dV - P*V*d - return ( - (alpha * KE_per_system) - - (internal_pressure * volume) - - (external_pressure * volume * dim) - ) + Args: + model (ModelInterface): Model to compute forces and energies + state (NPTNoseHooverState): Current system state + dt (torch.Tensor): Integration timestep + external_pressure (torch.Tensor): Target external pressure - def npt_inner_step( - state: NPTNoseHooverState, - dt: torch.Tensor, - external_pressure: torch.Tensor, - ) -> NPTNoseHooverState: - """Perform one inner step of NPT integration using velocity Verlet algorithm. + Returns: + NPTNoseHooverState: Updated state after one integration step + """ + # Get target pressure from kwargs or use default + dt_2 = dt / 2 + + # Unpack state variables for clarity + positions = state.positions + momenta = state.momenta + masses = state.masses + forces = state.forces + cell_position = state.cell_position # [n_systems] + cell_momentum = state.cell_momentum # [n_systems] + cell_mass = state.cell_mass # [n_systems] + + # Get current volume and cell function + volume, volume_to_cell = _npt_nose_hoover_cell_info(state) + cell = volume_to_cell(volume) + + # Get model output + state.cell = cell + model_output = model(state) + + # First half step: Update momenta + n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems) + alpha = 1 + 1 / n_atoms_per_system # [n_systems] + + cell_force_val = _npt_nose_hoover_compute_cell_force( + alpha=alpha, + volume=volume, + positions=positions, + momenta=momenta, + masses=masses, + stress=model_output["stress"], + external_pressure=external_pressure, + system_idx=state.system_idx, + ) - This function implements a single integration step for NPT dynamics, including: - 1. Cell momentum and particle momentum updates (half step) - 2. Position and cell position updates (full step) - 3. Force updates with new positions and cell - 4. Final momentum updates (half step) + # Update cell momentum and particle momenta + cell_momentum = cell_momentum + dt_2 * cell_force_val + momenta = _npt_nose_hoover_exp_iL2( + state, alpha, momenta, forces, cell_momentum / cell_mass, dt_2 + ) - Args: - state (NPTNoseHooverState): Current system state - dt (torch.Tensor): Integration timestep - external_pressure (torch.Tensor): Target external pressure + # Full step: Update positions + cell_position = cell_position + cell_momentum / cell_mass * dt - Returns: - NPTNoseHooverState: Updated state after one integration step - """ - # Get target pressure from kwargs or use default - dt_2 = dt / 2 - - # Unpack state variables for clarity - positions = state.positions - momenta = state.momenta - masses = state.masses - forces = state.forces - cell_position = state.cell_position # [n_systems] - cell_momentum = state.cell_momentum # [n_systems] - cell_mass = state.cell_mass # [n_systems] - - n_particles, dim = positions.shape - - # Get current volume and cell function - volume, volume_to_cell = _npt_cell_info(state) - cell = volume_to_cell(volume) - - # Get model output - state.cell = cell - model_output = model(state) - - # First half step: Update momenta - n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems) - alpha = 1 + 1 / n_atoms_per_system # [n_systems] - - cell_force_val = compute_cell_force( - alpha=alpha, - volume=volume, - positions=positions, - momenta=momenta, - masses=masses, - stress=model_output["stress"], - external_pressure=external_pressure, - system_idx=state.system_idx, - ) + # Update state with new cell_position before calling functions that depend on it + state.cell_position = cell_position - # Update cell momentum and particle momenta - cell_momentum = cell_momentum + dt_2 * cell_force_val - momenta = exp_iL2(state, alpha, momenta, forces, cell_momentum / cell_mass, dt_2) + # Get updated cell + volume, volume_to_cell = _npt_nose_hoover_cell_info(state) + cell = volume_to_cell(volume) - # Full step: Update positions - cell_position = cell_position + cell_momentum / cell_mass * dt + # Update particle positions and forces + positions = _npt_nose_hoover_exp_iL1( + state, state.velocities, cell_momentum / cell_mass, dt + ) + state.positions = positions + state.cell = cell + model_output = model(state) - # Update state with new cell_position before calling functions that depend on it - state.cell_position = cell_position + # Second half step: Update momenta + momenta = _npt_nose_hoover_exp_iL2( + state, alpha, momenta, model_output["forces"], cell_momentum / cell_mass, dt_2 + ) + cell_force_val = _npt_nose_hoover_compute_cell_force( + alpha=alpha, + volume=volume, + positions=positions, + momenta=momenta, + masses=masses, + stress=model_output["stress"], + external_pressure=external_pressure, + system_idx=state.system_idx, + ) + cell_momentum = cell_momentum + dt_2 * cell_force_val - # Get updated cell - volume, volume_to_cell = _npt_cell_info(state) - cell = volume_to_cell(volume) + # Return updated state + state.positions = positions + state.momenta = momenta + state.forces = model_output["forces"] + state.energy = model_output["energy"] + state.cell_position = cell_position + state.cell_momentum = cell_momentum + state.cell_mass = cell_mass + return state - # Update particle positions and forces - positions = exp_iL1(state, state.velocities, cell_momentum / cell_mass, dt) - state.positions = positions - state.cell = cell - model_output = model(state) - # Second half step: Update momenta - momenta = exp_iL2( - state, alpha, momenta, model_output["forces"], cell_momentum / cell_mass, dt_2 - ) - cell_force_val = compute_cell_force( - alpha=alpha, - volume=volume, - positions=positions, - momenta=momenta, - masses=masses, - stress=model_output["stress"], - external_pressure=external_pressure, - system_idx=state.system_idx, - ) - cell_momentum = cell_momentum + dt_2 * cell_force_val - - # Return updated state - state.positions = positions - state.momenta = momenta - state.forces = model_output["forces"] - state.energy = model_output["energy"] - state.cell_position = cell_position - state.cell_momentum = cell_momentum - state.cell_mass = cell_mass - return state - - def npt_nose_hoover_init( - state: SimState | StateDict, - kT: torch.Tensor = kT, - t_tau: torch.Tensor | None = None, - b_tau: torch.Tensor | None = None, - seed: int | None = None, - **kwargs: Any, - ) -> NPTNoseHooverState: - """Initialize the NPT Nose-Hoover state. - - This function initializes a state for NPT molecular dynamics with Nose-Hoover - chain thermostats for both temperature and pressure control. It sets up the - system with appropriate initial conditions including particle positions, momenta, - cell variables, and thermostat chains. +def npt_nose_hoover_init( + model: ModelInterface, + state: SimState | StateDict, + *, + kT: torch.Tensor, + dt: torch.Tensor, + chain_length: int = 3, + chain_steps: int = 2, + sy_steps: int = 3, + t_tau: torch.Tensor | None = None, + b_tau: torch.Tensor | None = None, + seed: int | None = None, + **kwargs: Any, +) -> NPTNoseHooverState: + """Initialize the NPT Nose-Hoover state. - Args: - state: Initial system state as SimState or dict containing positions, masses, - cell, and PBC information - kT: Target temperature in energy units - t_tau: Thermostat relaxation time. Controls how quickly temperature - equilibrates. Defaults to 100*dt - b_tau: Barostat relaxation time. Controls how quickly pressure equilibrates. - Defaults to 1000*dt - seed: Random seed for momenta initialization. Used for reproducible runs - **kwargs: Additional state variables like atomic_numbers or - pre-initialized momenta + This function initializes a state for NPT molecular dynamics with Nose-Hoover + chain thermostats for both temperature and pressure control. It sets up the + system with appropriate initial conditions including particle positions, momenta, + cell variables, and thermostat chains. - Returns: - NPTNoseHooverState: Initialized state containing: - - Particle positions, momenta, forces - - Cell position, momentum and mass (all with batch dimensions) - - Reference cell matrix (with batch dimensions) - - Thermostat and barostat chain variables - - System energy - - Other state variables (masses, PBC, etc.) - - Notes: - - Uses separate Nose-Hoover chains for temperature and pressure control - - Cell mass is set based on system size and barostat relaxation time - - Initial momenta are drawn from Maxwell-Boltzmann distribution if not - provided - - Cell dynamics use logarithmic coordinates for volume updates - - All cell properties are properly initialized with batch dimensions - """ - # Initialize the NPT Nose-Hoover state - # Thermostat relaxation time - if t_tau is None: - t_tau = 100 * dt - - # Barostat relaxation time - if b_tau is None: - b_tau = 1000 * dt - - # Setup thermostats with appropriate timescales - barostat_fns = construct_nose_hoover_chain( - dt, chain_length, chain_steps, sy_steps, b_tau - ) - thermostat_fns = construct_nose_hoover_chain( - dt, chain_length, chain_steps, sy_steps, t_tau - ) + Args: + model (ModelInterface): Model to compute forces and energies + state: Initial system state as SimState or dict containing positions, masses, + cell, and PBC information + kT: Target temperature in energy units + external_pressure: Target external pressure + dt: Integration timestep + chain_length: Length of Nose-Hoover chains. Defaults to 3. + chain_steps: Chain integration substeps. Defaults to 2. + sy_steps: Suzuki-Yoshida integration order. Defaults to 3. + t_tau: Thermostat relaxation time. Controls how quickly temperature + equilibrates. Defaults to 100*dt + b_tau: Barostat relaxation time. Controls how quickly pressure equilibrates. + Defaults to 1000*dt + seed: Random seed for momenta initialization. Used for reproducible runs + **kwargs: Additional state variables like atomic_numbers or + pre-initialized momenta - if not isinstance(state, SimState): - state = SimState(**state) + Returns: + NPTNoseHooverState: Initialized state containing: + - Particle positions, momenta, forces + - Cell position, momentum and mass (all with batch dimensions) + - Reference cell matrix (with batch dimensions) + - Thermostat and barostat chain variables + - System energy + - Other state variables (masses, PBC, etc.) - n_particles, dim = state.positions.shape - n_systems = state.n_systems - atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers) + Notes: + - Uses separate Nose-Hoover chains for temperature and pressure control + - Cell mass is set based on system size and barostat relaxation time + - Initial momenta are drawn from Maxwell-Boltzmann distribution if not + provided + - Cell dynamics use logarithmic coordinates for volume updates + - All cell properties are properly initialized with batch dimensions + """ + device, dtype = model.device, model.dtype - # Initialize cell variables with proper system dimensions - cell_position = torch.zeros(n_systems, device=device, dtype=dtype) - cell_momentum = torch.zeros(n_systems, device=device, dtype=dtype) + # Initialize the NPT Nose-Hoover state + # Thermostat relaxation time + if t_tau is None: + t_tau = 100 * dt - # Convert kT to tensor if it's not already one - if not isinstance(kT, torch.Tensor): - kT = torch.tensor(kT, device=device, dtype=dtype) + # Barostat relaxation time + if b_tau is None: + b_tau = 1000 * dt - # Handle both scalar and batched kT - kT_system = kT.expand(n_systems) if kT.ndim == 0 else kT + # Setup thermostats with appropriate timescales + barostat_fns = construct_nose_hoover_chain( + dt, chain_length, chain_steps, sy_steps, b_tau + ) + thermostat_fns = construct_nose_hoover_chain( + dt, chain_length, chain_steps, sy_steps, t_tau + ) - # Calculate cell masses for each system - n_atoms_per_system = torch.bincount(state.system_idx, minlength=n_systems) - cell_mass = dim * (n_atoms_per_system + 1) * kT_system * b_tau**2 - cell_mass = cell_mass.to(device=device, dtype=dtype) + if not isinstance(state, SimState): + state = SimState(**state) - # Calculate cell kinetic energy (using first system for initialization) - KE_cell = calc_kinetic_energy(masses=cell_mass[:1], momenta=cell_momentum[:1]) + _n_particles, dim = state.positions.shape + n_systems = state.n_systems + atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers) - # Ensure reference_cell has proper system dimensions - if state.cell.ndim == 2: - # Single cell matrix - expand to batch dimension - reference_cell = state.cell.unsqueeze(0).expand(n_systems, -1, -1).clone() - else: - # Already has batch dimension - reference_cell = state.cell.clone() - - # Handle scalar cell input - if (torch.is_tensor(state.cell) and state.cell.ndim == 0) or isinstance( - state.cell, int | float - ): - cell_matrix = torch.eye(dim, device=device, dtype=dtype) * state.cell - reference_cell = cell_matrix.unsqueeze(0).expand(n_systems, -1, -1).clone() - state.cell = reference_cell - - # Get model output - model_output = model(state) - forces = model_output["forces"] - energy = model_output["energy"] - - # Create initial state - npt_state = NPTNoseHooverState( - positions=state.positions, - momenta=None, - energy=energy, - forces=forces, - masses=state.masses, - atomic_numbers=atomic_numbers, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - reference_cell=reference_cell, - cell_position=cell_position, - cell_momentum=cell_momentum, - cell_mass=cell_mass, - barostat=barostat_fns.initialize(1, KE_cell, kT), - thermostat=None, - barostat_fns=barostat_fns, - thermostat_fns=thermostat_fns, - ) + # Initialize cell variables with proper system dimensions + cell_position = torch.zeros(n_systems, device=device, dtype=dtype) + cell_momentum = torch.zeros(n_systems, device=device, dtype=dtype) - # Initialize momenta - momenta = kwargs.get( - "momenta", - calculate_momenta( - npt_state.positions, npt_state.masses, npt_state.system_idx, kT, seed - ), - ) + # Convert kT to tensor if it's not already one + if not isinstance(kT, torch.Tensor): + kT = torch.tensor(kT, device=device, dtype=dtype) - # Initialize thermostat - npt_state.momenta = momenta - KE = calc_kinetic_energy( - momenta=npt_state.momenta, - masses=npt_state.masses, - system_idx=npt_state.system_idx, - ) - npt_state.thermostat = thermostat_fns.initialize( - npt_state.positions.numel(), KE, kT - ) + # Handle both scalar and batched kT + kT_system = kT.expand(n_systems) if kT.ndim == 0 else kT - return npt_state + # Calculate cell masses for each system + n_atoms_per_system = torch.bincount(state.system_idx, minlength=n_systems) + cell_mass = dim * (n_atoms_per_system + 1) * kT_system * torch.square(b_tau) + cell_mass = cell_mass.to(device=device, dtype=dtype) - def npt_nose_hoover_update( - state: NPTNoseHooverState, - dt: torch.Tensor = dt, - kT: torch.Tensor = kT, - external_pressure: torch.Tensor = external_pressure, - ) -> NPTNoseHooverState: - """Perform a complete NPT integration step with Nose-Hoover chain thermostats. + # Calculate cell kinetic energy (using first system for initialization) + KE_cell = ts.calc_kinetic_energy(masses=cell_mass[:1], momenta=cell_momentum[:1]) - This function performs a full NPT integration step including: - 1. Mass parameter updates for thermostats and cell - 2. Thermostat chain updates (half step) - 3. Inner NPT dynamics step - 4. Energy updates for thermostats - 5. Final thermostat chain updates (half step) + # Ensure reference_cell has proper system dimensions + if state.cell.ndim == 2: + # Single cell matrix - expand to batch dimension + reference_cell = state.cell.unsqueeze(0).expand(n_systems, -1, -1).clone() + else: + # Already has batch dimension + reference_cell = state.cell.clone() - Args: - state (NPTNoseHooverState): Current system state - dt (torch.Tensor): Integration timestep - kT (torch.Tensor): Target temperature - external_pressure (torch.Tensor): Target external pressure + # Handle scalar cell input + if (torch.is_tensor(state.cell) and state.cell.ndim == 0) or isinstance( + state.cell, int | float + ): + cell_matrix = torch.eye(dim, device=device, dtype=dtype) * state.cell + reference_cell = cell_matrix.unsqueeze(0).expand(n_systems, -1, -1).clone() + state.cell = reference_cell + + # Get model output + model_output = model(state) + forces = model_output["forces"] + energy = model_output["energy"] + + # Create initial state + npt_state = NPTNoseHooverState( + positions=state.positions, + momenta=torch.zeros_like(state.positions), + energy=energy, + forces=forces, + masses=state.masses, + atomic_numbers=atomic_numbers, + cell=state.cell, + pbc=state.pbc, + system_idx=state.system_idx, + reference_cell=reference_cell, + cell_position=cell_position, + cell_momentum=cell_momentum, + cell_mass=cell_mass, + barostat=barostat_fns.initialize(1, KE_cell, kT), + thermostat=thermostat_fns.initialize(), + barostat_fns=barostat_fns, + thermostat_fns=thermostat_fns, + ) - Returns: - NPTNoseHooverState: Updated state after complete integration step - """ - # Unpack state variables for clarity - barostat = state.barostat - thermostat = state.thermostat - - # Update mass parameters - state.barostat = state.barostat_fns.update_mass(barostat, kT) - state.thermostat = state.thermostat_fns.update_mass(thermostat, kT) - state = update_cell_mass(state, kT) - - # First half step of thermostat chains - state.cell_momentum, state.barostat = state.barostat_fns.half_step( - state.cell_momentum, state.barostat, kT - ) - state.momenta, state.thermostat = state.thermostat_fns.half_step( - state.momenta, state.thermostat, kT - ) + # Initialize momenta + momenta = kwargs.get( + "momenta", + calculate_momenta( + npt_state.positions, npt_state.masses, npt_state.system_idx, kT, seed + ), + ) - # Perform inner NPT step - state = npt_inner_step( - state=state, - dt=dt, - external_pressure=external_pressure, - ) + # Initialize thermostat + npt_state.momenta = momenta + KE = ts.calc_kinetic_energy( + momenta=npt_state.momenta, + masses=npt_state.masses, + system_idx=npt_state.system_idx, + ) + npt_state.thermostat = thermostat_fns.initialize(npt_state.positions.numel(), KE, kT) - # Update kinetic energies for thermostats - KE = calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, system_idx=state.system_idx - ) - state.thermostat.kinetic_energy = KE + return npt_state - KE_cell = calc_kinetic_energy(masses=state.cell_mass, momenta=state.cell_momentum) - state.barostat.kinetic_energy = KE_cell - # Second half step of thermostat chains - state.momenta, state.thermostat = state.thermostat_fns.half_step( - state.momenta, state.thermostat, kT - ) - state.cell_momentum, state.barostat = state.barostat_fns.half_step( - state.cell_momentum, state.barostat, kT - ) - return state +def npt_nose_hoover_update( + model: ModelInterface, + state: NPTNoseHooverState, + *, + dt: torch.Tensor, + kT: torch.Tensor, + external_pressure: torch.Tensor, +) -> NPTNoseHooverState: + """Perform a complete NPT integration step with Nose-Hoover chain thermostats. + + This function performs a full NPT integration step including: + 1. Mass parameter updates for thermostats and cell + 2. Thermostat chain updates (half step) + 3. Inner NPT dynamics step + 4. Energy updates for thermostats + 5. Final thermostat chain updates (half step) - return npt_nose_hoover_init, npt_nose_hoover_update + Args: + model (ModelInterface): Model to compute forces and energies + state (NPTNoseHooverState): Current system state + dt (torch.Tensor): Integration timestep + kT (torch.Tensor): Target temperature + external_pressure (torch.Tensor): Target external pressure + + Returns: + NPTNoseHooverState: Updated state after complete integration step + """ + device, dtype = model.device, model.dtype + + # Unpack state variables for clarity + barostat = state.barostat + thermostat = state.thermostat + + # Update mass parameters + state.barostat = state.barostat_fns.update_mass(barostat, kT) + state.thermostat = state.thermostat_fns.update_mass(thermostat, kT) + state = _npt_nose_hoover_update_cell_mass(state, kT, device, dtype) + + # First half step of thermostat chains + state.cell_momentum, state.barostat = state.barostat_fns.half_step( + state.cell_momentum, state.barostat, kT + ) + state.momenta, state.thermostat = state.thermostat_fns.half_step( + state.momenta, state.thermostat, kT + ) + + # Perform inner NPT step + state = _npt_nose_hoover_inner_step(model, state, dt, external_pressure) + + # Update kinetic energies for thermostats + KE = ts.calc_kinetic_energy( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + state.thermostat.kinetic_energy = KE + + KE_cell = ts.calc_kinetic_energy(masses=state.cell_mass, momenta=state.cell_momentum) + state.barostat.kinetic_energy = KE_cell + + # Second half step of thermostat chains + state.momenta, state.thermostat = state.thermostat_fns.half_step( + state.momenta, state.thermostat, kT + ) + state.cell_momentum, state.barostat = state.barostat_fns.half_step( + state.cell_momentum, state.barostat, kT + ) + return state def npt_nose_hoover_invariant( @@ -1634,7 +1558,7 @@ def npt_nose_hoover_invariant( e_pot = state.energy # Should be scalar or [n_systems] # Calculate kinetic energy of particles per system - e_kin_per_system = calc_kinetic_energy( + e_kin_per_system = ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) @@ -1654,7 +1578,7 @@ def npt_nose_hoover_invariant( # Note: These are global thermostat variables, so we add them to each system # Start thermostat_energy as a tensor with the right shape thermostat_energy = torch.zeros_like(e_tot) - thermostat_energy += (state.thermostat.momenta[0] ** 2) / ( + thermostat_energy += torch.square(state.thermostat.momenta[0]) / ( 2 * state.thermostat.masses[0] ) @@ -1676,10 +1600,10 @@ def npt_nose_hoover_invariant( ): if isinstance(kT, torch.Tensor) and kT.ndim == 0: # Scalar kT case - thermostat_energy += (momentum**2) / (2 * mass) + kT * pos + thermostat_energy += torch.square(momentum) / (2 * mass) + kT * pos else: # Batched kT case - thermostat_energy += (momentum**2) / (2 * mass) + kT_expanded * pos + thermostat_energy += torch.square(momentum) / (2 * mass) + kT_expanded * pos e_tot = e_tot + thermostat_energy @@ -1693,16 +1617,16 @@ def npt_nose_hoover_invariant( ): if isinstance(kT, torch.Tensor) and kT.ndim == 0: # Scalar kT case - barostat_energy += (momentum**2) / (2 * mass) + kT * pos + barostat_energy += torch.square(momentum) / (2 * mass) + kT * pos else: # Batched kT case - barostat_energy += (momentum**2) / (2 * mass) + kT_expanded * pos + barostat_energy += torch.square(momentum) / (2 * mass) + kT_expanded * pos e_tot = e_tot + barostat_energy # Add PV term and cell kinetic energy (both are per system) e_tot += external_pressure * volume - e_tot += (state.cell_momentum**2) / (2 * state.cell_mass) + e_tot += torch.square(state.cell_momentum) / (2 * state.cell_mass) # Return scalar if single system, otherwise return per-system values if state.n_systems == 1: diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index c7e413902..49509567e 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -1,6 +1,6 @@ """Implementations of NVE integrators.""" -from collections.abc import Callable +from typing import Any import torch @@ -15,128 +15,94 @@ from torch_sim.typing import StateDict -def nve( +def nve_init( model: ModelInterface, + state: SimState | StateDict, *, - dt: torch.Tensor, kT: torch.Tensor, seed: int | None = None, -) -> tuple[ - Callable[[SimState | StateDict, torch.Tensor], MDState], - Callable[[MDState, torch.Tensor], MDState], -]: - """Initialize and return an NVE (microcanonical) integrator. + **_kwargs: Any, +) -> MDState: + """Initialize an NVE state from input data. - This function sets up integration in the NVE ensemble, where particle number (N), - volume (V), and total energy (E) are conserved. It returns both an initialization - function and an update function for time evolution. - - The initialization function samples initial momenta from a Maxwell-Boltzmann - distribution at the specified temperature, while the update function - implements the velocity Verlet algorithm for energy-conserving dynamics. + Creates an initial state for NVE molecular dynamics by computing initial + energies and forces, and sampling momenta from a Maxwell-Boltzmann distribution + at the specified temperature. Args: - model (torch.nn.Module): Neural network model that computes energies and forces. + model: Neural network model that computes energies and forces. Must return a dict with 'energy' and 'forces' keys. - dt (torch.Tensor): Integration timestep, either scalar or with shape [n_systems] - kT (torch.Tensor): Temperature in energy units for initializing momenta, - either scalar or with shape [n_systems] - seed (int, optional): Random seed for reproducibility. Defaults to None. + state: Either a SimState object or a dictionary containing positions, + masses, cell, pbc, and other required state variables + kT: Temperature in energy units for initializing momenta, + scalar or with shape [n_systems] + seed: Random seed for reproducibility Returns: - tuple: - - callable: Function to initialize the MDState from input data and kT - with signature: init_fn(state, kT=kT, seed=seed) -> MDState - - callable: Update function that evolves system by one timestep - with signature: update_fn(state, dt=dt) -> MDState + MDState: Initialized state for NVE integration containing positions, + momenta, forces, energy, and other required attributes Notes: - - Uses velocity Verlet algorithm for time-reversible integration - - Conserves total energy in the absence of numerical errors - Initial velocities sampled from Maxwell-Boltzmann distribution - Time integration error scales as O(dtΒ²) """ + if not isinstance(state, SimState): + state = SimState(**state) + + model_output = model(state) + + momenta = getattr( + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + ) + + return MDState( + positions=state.positions, + momenta=momenta, + energy=model_output["energy"], + forces=model_output["forces"], + masses=state.masses, + cell=state.cell, + pbc=state.pbc, + system_idx=state.system_idx, + atomic_numbers=state.atomic_numbers, + ) + + +def nve_update( + model: ModelInterface, state: MDState, *, dt: torch.Tensor, **_kwargs: Any +) -> MDState: + """Perform one complete NVE (microcanonical) integration step. + + This function implements the velocity Verlet algorithm for NVE dynamics, + which provides energy-conserving time evolution. The integration sequence is: + 1. Half momentum update using current forces + 2. Full position update using updated momenta + 3. Force update at new positions + 4. Half momentum update using new forces + + Args: + model: Neural network model that computes energies and forces. + Must return a dict with 'energy' and 'forces' keys. + state: Current system state containing positions, momenta, forces + dt: Integration timestep, either scalar or shape [n_systems] + + Returns: + MDState: Updated state after one complete NVE step with new positions, + momenta, forces, and energy + + Notes: + - Uses velocity Verlet algorithm for time reversible integration + - Conserves energy in the absence of numerical errors + - Handles periodic boundary conditions if enabled in state + - Symplectic integrator preserving phase space volume + """ + state = momentum_step(state, dt / 2) + state = position_step(state, dt) + + model_output = model(state) + state.energy = model_output["energy"] + state.forces = model_output["forces"] - def nve_init( - state: SimState | StateDict, - kT: torch.Tensor = kT, - seed: int | None = seed, - ) -> MDState: - """Initialize an NVE state from input data. - - Creates an initial state for NVE molecular dynamics by computing initial - energies and forces, and sampling momenta from a Maxwell-Boltzmann distribution - at the specified temperature. - - Args: - state (SimState | StateDict): Either a SimState object or a dictionary - containing positions, masses, cell, pbc, and other required state - variables - kT (torch.Tensor): Temperature in energy units for initializing momenta, - scalar or with shape [n_systems] - seed (int, optional): Random seed for reproducibility - - Returns: - MDState: Initialized state for NVE integration containing positions, - momenta, forces, energy, and other required attributes - """ - # Extract required data from input - if not isinstance(state, SimState): - state = SimState(**state) - - model_output = model(state) - - momenta = getattr( - state, - "momenta", - calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), - ) - - initial_state = MDState( - positions=state.positions, - momenta=momenta, - energy=model_output["energy"], - forces=model_output["forces"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - atomic_numbers=state.atomic_numbers, - ) - return initial_state # noqa: RET504 - - def nve_update(state: MDState, dt: torch.Tensor = dt, **_) -> MDState: - """Perform one complete NVE (microcanonical) integration step. - - This function implements the velocity Verlet algorithm for NVE dynamics, - which provides energy-conserving time evolution. The integration sequence is: - 1. Half momentum update using current forces - 2. Full position update using updated momenta - 3. Force update at new positions - 4. Half momentum update using new forces - - Args: - state (MDState): Current system state containing positions, momenta, forces - dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] - **_: Additional unused keyword arguments (for compatibility) - - Returns: - MDState: Updated state after one complete NVE step with new positions, - momenta, forces, and energy - - Notes: - - Uses velocity Verlet algorithm for time reversible integration - - Conserves energy in the absence of numerical errors - - Handles periodic boundary conditions if enabled in state - - Symplectic integrator preserving phase space volume - """ - state = momentum_step(state, dt / 2) - state = position_step(state, dt) - - model_output = model(state) - state.energy = model_output["energy"] - state.forces = model_output["forces"] - - return momentum_step(state, dt / 2) - - return nve_init, nve_update + return momentum_step(state, dt / 2) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 18f0ae154..59ef2b8fd 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -1,11 +1,11 @@ """Implementations of NVT integrators.""" -from collections.abc import Callable from dataclasses import dataclass from typing import Any import torch +import torch_sim as ts from torch_sim.integrators.md import ( MDState, NoseHooverChain, @@ -17,48 +17,153 @@ velocity_verlet, ) from torch_sim.models.interface import ModelInterface -from torch_sim.quantities import calc_kinetic_energy from torch_sim.state import SimState from torch_sim.typing import StateDict -def nvt_langevin( # noqa: C901 +def _ou_step( + state: MDState, + dt: float | torch.Tensor, + kT: float | torch.Tensor, + gamma: float | torch.Tensor, +) -> MDState: + """Apply stochastic noise and friction for Langevin dynamics. + + This function implements the Ornstein-Uhlenbeck process for Langevin dynamics, + applying random noise and friction forces to particle momenta. The noise amplitude + is chosen to satisfy the fluctuation-dissipation theorem, ensuring proper + sampling of the canonical ensemble at temperature kT. + + Args: + state (MDState): Current system state containing positions, momenta, etc. + dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] + kT (torch.Tensor): Target temperature in energy units, either scalar or + with shape [n_systems] + gamma (torch.Tensor): Friction coefficient controlling noise strength, + either scalar or with shape [n_systems] + + Returns: + MDState: Updated state with new momenta after stochastic step + + Notes: + - Implements the "O" step in the BAOAB Langevin integration scheme + - Uses Ornstein-Uhlenbeck process for correct thermal sampling + - Noise amplitude scales with sqrt(mass) for equipartition + - Preserves detailed balance through fluctuation-dissipation relation + - The equation implemented is: + p(t+dt) = c1*p(t) + c2*sqrt(m)*N(0,1) + where c1 = exp(-gamma*dt) and c2 = sqrt(kT*(1-c1Β²)) + """ + c1 = torch.exp(torch.tensor(-gamma * dt)) + + if isinstance(kT, torch.Tensor) and len(kT.shape) > 0: + # kT is a tensor with shape (n_systems,) + kT = kT[state.system_idx] + + # Index c1 and c2 with state.system_idx to align shapes with state.momenta + if isinstance(c1, torch.Tensor) and len(c1.shape) > 0: + c1 = c1[state.system_idx] + + c2 = torch.sqrt(kT * (1 - torch.square(c1))).unsqueeze(-1) + + # Generate random noise from normal distribution + noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype) + new_momenta = ( + c1.unsqueeze(-1) * state.momenta + + c2 * torch.sqrt(state.masses).unsqueeze(-1) * noise + ) + state.momenta = new_momenta + return state + + +def nvt_langevin_init( model: ModelInterface, + state: SimState | StateDict, *, - dt: torch.Tensor, - kT: torch.Tensor, - gamma: torch.Tensor | None = None, + kT: float | torch.Tensor, seed: int | None = None, -) -> tuple[ - Callable[[SimState | StateDict, torch.Tensor], MDState], - Callable[[MDState, torch.Tensor], MDState], -]: - """Initialize and return an NVT (canonical) integrator using Langevin dynamics. + **_kwargs: Any, +) -> MDState: + """Initialize an NVT state from input data for Langevin dynamics. - This function sets up integration in the NVT ensemble, where particle number (N), - volume (V), and temperature (T) are conserved. It returns both an initial state - and an update function for time evolution. + Creates an initial state for NVT molecular dynamics by computing initial + energies and forces, and sampling momenta from a Maxwell-Boltzmann distribution + at the specified temperature. - It uses Langevin dynamics with stochastic noise and friction to maintain constant - temperature. The integration scheme combines deterministic velocity Verlet steps with - stochastic Ornstein-Uhlenbeck processes following the BAOAB splitting scheme. + Args: + model: Neural network model that computes energies and forces. + Must return a dict with 'energy' and 'forces' keys. + state: Either a SimState object or a dictionary containing positions, + masses, cell, pbc, and other required state vars + kT: Temperature in energy units for initializing momenta, + either scalar or with shape [n_systems] + seed: Random seed for reproducibility + + Returns: + MDState: Initialized state for NVT integration containing positions, + momenta, forces, energy, and other required attributes + + Notes: + The initial momenta are sampled from a Maxwell-Boltzmann distribution + at the specified temperature. This provides a proper thermal initial + state for the subsequent Langevin dynamics. + """ + if not isinstance(state, SimState): + state = SimState(**state) + + model_output = model(state) + + momenta = getattr( + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + ) + + return MDState( + positions=state.positions, + momenta=momenta, + energy=model_output["energy"], + forces=model_output["forces"], + masses=state.masses, + cell=state.cell, + pbc=state.pbc, + system_idx=state.system_idx, + atomic_numbers=state.atomic_numbers, + ) + + +def nvt_langevin_update( + model: ModelInterface, + state: MDState, + *, + dt: float | torch.Tensor, + kT: float | torch.Tensor, + gamma: float | torch.Tensor | None = None, +) -> MDState: + """Perform one complete Langevin dynamics integration step. + + This function implements the BAOAB splitting scheme for Langevin dynamics, + which provides accurate sampling of the canonical ensemble. The integration + sequence is: + 1. Half momentum update using forces (B step) + 2. Half position update using updated momenta (A step) + 3. Full stochastic update with noise and friction (O step) + 4. Half position update using updated momenta (A step) + 5. Half momentum update using new forces (B step) Args: - model (torch.nn.Module): Neural network model that computes energies and forces. + model: Neural network model that computes energies and forces. Must return a dict with 'energy' and 'forces' keys. - dt (torch.Tensor): Integration timestep, either scalar or with shape [n_systems] - kT (torch.Tensor): Target temperature in energy units, either scalar or + state: Current system state containing positions, momenta, forces + dt: Integration timestep, either scalar or shape [n_systems] + kT: Target temperature in energy units, either scalar or with shape [n_systems] - gamma (torch.Tensor, optional): Friction coefficient for Langevin thermostat, + gamma: Friction coefficient for Langevin thermostat, either scalar or with shape [n_systems]. Defaults to 1/(100*dt). - seed (int, optional): Random seed for reproducibility. Defaults to None. Returns: - tuple: - - callable: Function to initialize the MDState from input data - with signature: init_fn(state, kT=kT, seed=seed) -> MDState - - callable: Update function that evolves system by one timestep - with signature: update_fn(state, dt=dt, kT=kT, gamma=gamma) -> MDState + MDState: Updated state after one complete Langevin step with new positions, + momenta, forces, and energy Notes: - Uses BAOAB splitting scheme for Langevin dynamics @@ -79,158 +184,16 @@ def nvt_langevin( # noqa: C901 if isinstance(dt, float): dt = torch.tensor(dt, device=device, dtype=dtype) - def ou_step( - state: MDState, - dt: torch.Tensor, - kT: torch.Tensor, - gamma: torch.Tensor, - ) -> MDState: - """Apply stochastic noise and friction for Langevin dynamics. - - This function implements the Ornstein-Uhlenbeck process for Langevin dynamics, - applying random noise and friction forces to particle momenta. The noise amplitude - is chosen to satisfy the fluctuation-dissipation theorem, ensuring proper - sampling of the canonical ensemble at temperature kT. - - Args: - state (MDState): Current system state containing positions, momenta, etc. - dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] - kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_systems] - gamma (torch.Tensor): Friction coefficient controlling noise strength, - either scalar or with shape [n_systems] - - Returns: - MDState: Updated state with new momenta after stochastic step - - Notes: - - Implements the "O" step in the BAOAB Langevin integration scheme - - Uses Ornstein-Uhlenbeck process for correct thermal sampling - - Noise amplitude scales with sqrt(mass) for equipartition - - Preserves detailed balance through fluctuation-dissipation relation - - The equation implemented is: - p(t+dt) = c1*p(t) + c2*sqrt(m)*N(0,1) - where c1 = exp(-gamma*dt) and c2 = sqrt(kT*(1-c1Β²)) - """ - c1 = torch.exp(-gamma * dt) - - if isinstance(kT, torch.Tensor) and len(kT.shape) > 0: - # kT is a tensor with shape (n_systems,) - kT = kT[state.system_idx] - - # Index c1 and c2 with state.system_idx to align shapes with state.momenta - if isinstance(c1, torch.Tensor) and len(c1.shape) > 0: - c1 = c1[state.system_idx] - - c2 = torch.sqrt(kT * (1 - c1**2)).unsqueeze(-1) - - # Generate random noise from normal distribution - noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype) - new_momenta = ( - c1.unsqueeze(-1) * state.momenta - + c2 * torch.sqrt(state.masses).unsqueeze(-1) * noise - ) - state.momenta = new_momenta - return state - - def langevin_init( - state: SimState | StateDict, - kT: torch.Tensor = kT, - seed: int | None = seed, - ) -> MDState: - """Initialize an NVT state from input data for Langevin dynamics. - - Creates an initial state for NVT molecular dynamics by computing initial - energies and forces, and sampling momenta from a Maxwell-Boltzmann distribution - at the specified temperature. - - Args: - state (SimState | StateDict): Either a SimState object or a dictionary - containing positions, masses, cell, pbc, and other required state vars - kT (torch.Tensor): Temperature in energy units for initializing momenta, - either scalar or with shape [n_systems] - seed (int, optional): Random seed for reproducibility - - Returns: - MDState: Initialized state for NVT integration containing positions, - momenta, forces, energy, and other required attributes - - Notes: - The initial momenta are sampled from a Maxwell-Boltzmann distribution - at the specified temperature. This provides a proper thermal initial - state for the subsequent Langevin dynamics. - """ - if not isinstance(state, SimState): - state = SimState(**state) - - model_output = model(state) - - momenta = getattr( - state, - "momenta", - calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), - ) - - initial_state = MDState( - positions=state.positions, - momenta=momenta, - energy=model_output["energy"], - forces=model_output["forces"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - atomic_numbers=state.atomic_numbers, - ) - return initial_state # noqa: RET504 - - def langevin_update( - state: MDState, - dt: torch.Tensor = dt, - kT: torch.Tensor = kT, - gamma: torch.Tensor = gamma, - ) -> MDState: - """Perform one complete Langevin dynamics integration step. - - This function implements the BAOAB splitting scheme for Langevin dynamics, - which provides accurate sampling of the canonical ensemble. The integration - sequence is: - 1. Half momentum update using forces (B step) - 2. Half position update using updated momenta (A step) - 3. Full stochastic update with noise and friction (O step) - 4. Half position update using updated momenta (A step) - 5. Half momentum update using new forces (B step) - - Args: - state (MDState): Current system state containing positions, momenta, forces - dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] - kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_systems] - gamma (torch.Tensor): Friction coefficient for Langevin thermostat, - either scalar or with shape [n_systems] - - Returns: - MDState: Updated state after one complete Langevin step with new positions, - momenta, forces, and energy - """ - if isinstance(gamma, float): - gamma = torch.tensor(gamma, device=device, dtype=dtype) - - if isinstance(dt, float): - dt = torch.tensor(dt, device=device, dtype=dtype) - - state = momentum_step(state, dt / 2) - state = position_step(state, dt / 2) - state = ou_step(state, dt, kT, gamma) - state = position_step(state, dt / 2) + state = momentum_step(state, dt / 2) + state = position_step(state, dt / 2) + state = _ou_step(state, dt, kT, gamma) + state = position_step(state, dt / 2) - model_output = model(state) - state.energy = model_output["energy"] - state.forces = model_output["forces"] + model_output = model(state) + state.energy = model_output["energy"] + state.forces = model_output["forces"] - return momentum_step(state, dt / 2) - - return langevin_init, langevin_update + return momentum_step(state, dt / 2) @dataclass @@ -278,19 +241,20 @@ def velocities(self) -> torch.Tensor: return self.momenta / self.masses.unsqueeze(-1) -def nvt_nose_hoover( - *, +def nvt_nose_hoover_init( model: ModelInterface, - dt: torch.Tensor, + state: SimState | StateDict, + *, kT: torch.Tensor, + dt: torch.Tensor, + tau: torch.Tensor | None = None, chain_length: int = 3, chain_steps: int = 3, sy_steps: int = 3, -) -> tuple[ - Callable[[SimState | StateDict, torch.Tensor, int | None, Any], NVTNoseHooverState], - Callable[[NVTNoseHooverState, torch.Tensor], NVTNoseHooverState], -]: - """Initialize NVT Nose-Hoover chain thermostat integration. + seed: int | None = None, + **kwargs: Any, +) -> NVTNoseHooverState: + """Initialize the NVT Nose-Hoover state. This function sets up integration of an NVT system using a Nose-Hoover chain thermostat. The Nose-Hoover chain provides deterministic temperature control by @@ -299,160 +263,129 @@ def nvt_nose_hoover( Args: model: Neural network model that computes energies and forces - dt: Integration timestep + state: Initial system state as SimState or dict kT: Target temperature in energy units + dt: Integration timestep + tau: Thermostat relaxation time (defaults to 100*dt) chain_length: Number of thermostats in Nose-Hoover chain (default: 3) chain_steps: Number of chain integration substeps (default: 3) sy_steps: Number of Suzuki-Yoshida steps - must be 1, 3, 5, or 7 (default: 3) + seed: Random seed for momenta initialization + **kwargs: Additional state variables Returns: - tuple containing: - - Initialization function that takes a state and returns NVTNoseHooverState - - Update function that performs one complete integration step + Initialized NVTNoseHooverState with positions, momenta, forces, + and thermostat chain variables Notes: - The initialization function accepts: - - state: Initial system state (SimState or dict) - - kT: Target temperature (optional, defaults to constructor value) - - tau: Thermostat relaxation time (optional, defaults to 100*dt) - - seed: Random seed for momenta initialization (optional) - - **kwargs: Additional state variables - - The update function accepts: - - state: Current NVTNoseHooverState - - dt: Integration timestep (optional, defaults to constructor value) - - kT: Target temperature (optional, defaults to constructor value) - - The integration sequence is: - 1. Update chain masses + - The Nose-Hoover chain provides deterministic temperature control + - Extended system approach conserves an extended energy quantity + - Chain variables evolve to maintain target temperature + - Time-reversible when integrated with appropriate algorithms + """ + if tau is None: # Set default tau if not provided + tau = dt * 100.0 + + # Create thermostat functions + chain_fns = construct_nose_hoover_chain(dt, chain_length, chain_steps, sy_steps, tau) + if not isinstance(state, SimState): + state = SimState(**state) + + atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers) + + model_output = model(state) + momenta = kwargs.get( + "momenta", + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + ) + + # Calculate initial kinetic energy per system + KE = ts.calc_kinetic_energy( + masses=state.masses, momenta=momenta, system_idx=state.system_idx + ) + + # Calculate degrees of freedom per system + n_atoms_per_system = torch.bincount(state.system_idx) + dof_per_system = ( + n_atoms_per_system * state.positions.shape[-1] + ) # n_atoms * n_dimensions + + # For now, sum the per-system DOF as chain expects a single int + # This is a limitation that should be addressed in the chain implementation + total_dof = int(dof_per_system.sum().item()) + + # Initialize state + return NVTNoseHooverState( + positions=state.positions, + momenta=momenta, + energy=model_output["energy"], + forces=model_output["forces"], + masses=state.masses, + cell=state.cell, + pbc=state.pbc, + atomic_numbers=atomic_numbers, + system_idx=state.system_idx, + chain=chain_fns.initialize(total_dof, KE, kT), + _chain_fns=chain_fns, # Store the chain functions + ) + + +def nvt_nose_hoover_update( + model: ModelInterface, + state: NVTNoseHooverState, + *, + dt: torch.Tensor, + kT: torch.Tensor, +) -> NVTNoseHooverState: + """Perform one complete Nose-Hoover chain integration step. + + This function performs one integration step for an NVT system using a Nose-Hoover + chain thermostat. The integration scheme is time-reversible and conserves an + extended energy quantity. + + Args: + model: Neural network model that computes energies and forces + state: Current system state containing positions, momenta, forces, and chain + dt: Integration timestep + kT: Target temperature in energy units + + Returns: + Updated state after one complete Nose-Hoover step + + Notes: + Integration sequence: + 1. Update chain masses based on target temperature 2. First half-step of chain evolution 3. Full velocity Verlet step 4. Update chain kinetic energy 5. Second half-step of chain evolution """ + # Get chain functions from state + chain_fns = state._chain_fns # noqa: SLF001 + chain = state.chain - def nvt_nose_hoover_init( - state: SimState | StateDict, - kT: torch.Tensor = kT, - tau: torch.Tensor | None = None, - seed: int | None = None, - **kwargs: Any, - ) -> NVTNoseHooverState: - """Initialize the NVT Nose-Hoover state. - - Args: - state: Initial system state as SimState or dict - kT: Target temperature in energy units - tau: Thermostat relaxation time (defaults to 100*dt) - seed: Random seed for momenta initialization - **kwargs: Additional state variables - - Returns: - Initialized NVTNoseHooverState with positions, momenta, forces, - and thermostat chain variables - """ - # Set default tau if not provided - if tau is None: - tau = dt * 100.0 - - # Create thermostat functions - chain_fns = construct_nose_hoover_chain( - dt, chain_length, chain_steps, sy_steps, tau - ) - - if not isinstance(state, SimState): - state = SimState(**state) - - atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers) - - model_output = model(state) - momenta = kwargs.get( - "momenta", - calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), - ) - - # Calculate initial kinetic energy per system - KE = calc_kinetic_energy( - masses=state.masses, momenta=momenta, system_idx=state.system_idx - ) - - # Calculate degrees of freedom per system - n_atoms_per_system = torch.bincount(state.system_idx) - dof_per_system = ( - n_atoms_per_system * state.positions.shape[-1] - ) # n_atoms * n_dimensions - - # For now, sum the per-system DOF as chain expects a single int - # This is a limitation that should be addressed in the chain implementation - total_dof = int(dof_per_system.sum().item()) - - # Initialize state - state = NVTNoseHooverState( - positions=state.positions, - momenta=momenta, - energy=model_output["energy"], - forces=model_output["forces"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - atomic_numbers=atomic_numbers, - system_idx=state.system_idx, - chain=chain_fns.initialize(total_dof, KE, kT), - _chain_fns=chain_fns, # Store the chain functions - ) - return state # noqa: RET504 - - def nvt_nose_hoover_update( - state: NVTNoseHooverState, - dt: torch.Tensor = dt, - kT: torch.Tensor = kT, - ) -> NVTNoseHooverState: - """Perform one complete Nose-Hoover chain integration step. - - Args: - state: Current system state containing positions, momenta, forces, and chain - dt: Integration timestep - kT: Target temperature in energy units - - Returns: - Updated state after one complete Nose-Hoover step - - Notes: - Integration sequence: - 1. Update chain masses based on target temperature - 2. First half-step of chain evolution - 3. Full velocity Verlet step - 4. Update chain kinetic energy - 5. Second half-step of chain evolution - """ - # Get chain functions from state - chain_fns = state._chain_fns # noqa: SLF001 - chain = state.chain - - # Update chain masses based on target temperature - chain = chain_fns.update_mass(chain, kT) - - # First half-step of chain evolution - momenta, chain = chain_fns.half_step(state.momenta, chain, kT) - state.momenta = momenta + # Update chain masses based on target temperature + chain = chain_fns.update_mass(chain, kT) - # Full velocity Verlet step - state = velocity_verlet(state=state, dt=dt, model=model) + # First half-step of chain evolution + momenta, chain = chain_fns.half_step(state.momenta, chain, kT) + state.momenta = momenta - # Update chain kinetic energy per system - KE = calc_kinetic_energy( - masses=state.masses, momenta=state.momenta, system_idx=state.system_idx - ) - chain.kinetic_energy = KE + # Full velocity Verlet step + state = velocity_verlet(state=state, dt=dt, model=model) - # Second half-step of chain evolution - momenta, chain = chain_fns.half_step(state.momenta, chain, kT) - state.momenta = momenta - state.chain = chain + # Update chain kinetic energy per system + KE = ts.calc_kinetic_energy( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + chain.kinetic_energy = KE - return state + # Second half-step of chain evolution + momenta, chain = chain_fns.half_step(state.momenta, chain, kT) + state.momenta = momenta + state.chain = chain - return nvt_nose_hoover_init, nvt_nose_hoover_update + return state def nvt_nose_hoover_invariant( @@ -486,7 +419,7 @@ def nvt_nose_hoover_invariant( """ # Calculate system energy terms per system e_pot = state.energy - e_kin = calc_kinetic_energy( + e_kin = ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) @@ -500,7 +433,7 @@ def nvt_nose_hoover_invariant( # Add first thermostat term c = state.chain # Ensure chain momenta and masses broadcast correctly with batch dimensions - chain_ke_0 = c.momenta[0] ** 2 / (2 * c.masses[0]) + chain_ke_0 = torch.square(c.momenta[0]) / (2 * c.masses[0]) chain_pe_0 = dof * kT * c.positions[0] # If chain variables are scalars but we have batches, broadcast them diff --git a/torch_sim/io.py b/torch_sim/io.py index 8f61bd496..a2081c206 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -52,14 +52,14 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: positions = state.positions.detach().cpu().numpy() cell = state.cell.detach().cpu().numpy() # Shape: (n_systems, 3, 3) atomic_numbers = state.atomic_numbers.detach().cpu().numpy() - system_idx = state.system_idx.detach().cpu().numpy() + system_indices = state.system_idx.detach().cpu().numpy() atoms_list = [] - for idx in np.unique(system_idx): - mask = system_idx == idx + for sys_idx in np.unique(system_indices): + mask = system_indices == sys_idx system_positions = positions[mask] system_numbers = atomic_numbers[mask] - system_cell = cell[idx].T # Transpose for ASE convention + system_cell = cell[sys_idx].T # Transpose for ASE convention # Convert atomic numbers to chemical symbols symbols = [chemical_symbols[z] for z in system_numbers] @@ -89,8 +89,7 @@ def state_to_structures(state: "ts.SimState") -> list["Structure"]: - Assumes periodic boundary conditions """ try: - from pymatgen.core import Lattice, Structure - from pymatgen.core.periodic_table import Element + from pymatgen.core import Element, Lattice, Structure except ImportError: raise ImportError( "Pymatgen is required for state_to_structures conversion" @@ -100,18 +99,18 @@ def state_to_structures(state: "ts.SimState") -> list["Structure"]: positions = state.positions.detach().cpu().numpy() cell = state.cell.detach().cpu().numpy() # Shape: (n_systems, 3, 3) atomic_numbers = state.atomic_numbers.detach().cpu().numpy() - system_idx = state.system_idx.detach().cpu().numpy() + system_indices = state.system_idx.detach().cpu().numpy() # Get unique system indices and counts - unique_systems = np.unique(system_idx) - structures = [] + uniq_systems = np.unique(system_indices) + structures: list[Structure] = [] - for unique_system_idx in unique_systems: + for uniq_sys_idx in uniq_systems: # Get mask for current system - mask = system_idx == unique_system_idx + mask = system_indices == uniq_sys_idx system_positions = positions[mask] system_numbers = atomic_numbers[mask] - system_cell = cell[unique_system_idx].T # Transpose for conventional form + system_cell = cell[uniq_sys_idx].T # Transpose for conventional form # Create species list from atomic numbers species = [Element.from_Z(z) for z in system_numbers] @@ -154,33 +153,27 @@ def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: positions = state.positions.detach().cpu().numpy() cell = state.cell.detach().cpu().numpy() # Shape: (n_systems, 3, 3) atomic_numbers = state.atomic_numbers.detach().cpu().numpy() - system_idx = state.system_idx.detach().cpu().numpy() + system_indices = state.system_idx.detach().cpu().numpy() - phonopy_atoms_list = [] - for idx in np.unique(system_idx): - mask = system_idx == idx + phonopy_atoms_list: list[PhonopyAtoms] = [] + for sys_idx in np.unique(system_indices): + mask = system_indices == sys_idx system_positions = positions[mask] system_numbers = atomic_numbers[mask] - system_cell = cell[idx].T # Transpose for Phonopy convention + system_cell = cell[sys_idx].T # Transpose for Phonopy convention # Convert atomic numbers to chemical symbols symbols = [chemical_symbols[z] for z in system_numbers] - phonopy_atoms_list.append( - PhonopyAtoms( - symbols=symbols, - positions=system_positions, - cell=system_cell, - pbc=state.pbc, - ) + phonopy_atoms = PhonopyAtoms( + symbols=symbols, positions=system_positions, cell=system_cell, pbc=state.pbc ) + phonopy_atoms_list.append(phonopy_atoms) return phonopy_atoms_list def atoms_to_state( - atoms: "Atoms | list[Atoms]", - device: torch.device, - dtype: torch.dtype, + atoms: "Atoms | list[Atoms]", device: torch.device, dtype: torch.dtype ) -> "ts.SimState": """Convert an ASE Atoms object or list of Atoms objects to a SimState. @@ -211,28 +204,28 @@ def atoms_to_state( # Stack all properties in one go positions = torch.tensor( - np.concatenate([a.positions for a in atoms_list]), dtype=dtype, device=device + np.concatenate([at.positions for at in atoms_list]), dtype=dtype, device=device ) masses = torch.tensor( - np.concatenate([a.get_masses() for a in atoms_list]), dtype=dtype, device=device + np.concatenate([at.get_masses() for at in atoms_list]), dtype=dtype, device=device ) atomic_numbers = torch.tensor( - np.concatenate([a.get_atomic_numbers() for a in atoms_list]), + np.concatenate([at.get_atomic_numbers() for at in atoms_list]), dtype=torch.int, device=device, ) - cell = torch.tensor( # Transpose cell from ASE convention to torchsim convention - np.stack([a.cell.array.T for a in atoms_list]), dtype=dtype, device=device + cell = torch.tensor( # Transpose cell from ASE convention to TorchSim convention + np.stack([at.cell.array.T for at in atoms_list]), dtype=dtype, device=device ) # Create system indices using repeat_interleave - atoms_per_system = torch.tensor([len(a) for a in atoms_list], device=device) + atoms_per_system = torch.tensor([len(at) for at in atoms_list], device=device) system_idx = torch.repeat_interleave( torch.arange(len(atoms_list), device=device), atoms_per_system ) # Verify consistent pbc - if not all(all(a.pbc) == all(atoms_list[0].pbc) for a in atoms_list): + if not all(all(at.pbc) == all(atoms_list[0].pbc) for at in atoms_list): raise ValueError("All systems must have the same periodic boundary conditions") return ts.SimState( @@ -246,9 +239,7 @@ def atoms_to_state( def structures_to_state( - structure: "Structure | list[Structure]", - device: torch.device, - dtype: torch.dtype, + structure: "Structure | list[Structure]", device: torch.device, dtype: torch.dtype ) -> "ts.SimState": """Create a SimState from pymatgen Structure(s). @@ -349,15 +340,12 @@ def phonopy_to_state( ) # Stack all properties in one go + kwargs = {"dtype": dtype, "device": device} positions = torch.tensor( - np.concatenate([a.positions for a in phonopy_atoms_list]), - dtype=dtype, - device=device, + np.concatenate([at.positions for at in phonopy_atoms_list]), **kwargs ) masses = torch.tensor( - np.concatenate([a.masses for a in phonopy_atoms_list]), - dtype=dtype, - device=device, + np.concatenate([at.masses for at in phonopy_atoms_list]), **kwargs ) atomic_numbers = torch.tensor( np.concatenate([a.numbers for a in phonopy_atoms_list]), @@ -365,11 +353,11 @@ def phonopy_to_state( device=device, ) cell = torch.tensor( - np.stack([a.cell.T for a in phonopy_atoms_list]), dtype=dtype, device=device + np.stack([at.cell.T for at in phonopy_atoms_list]), dtype=dtype, device=device ) # Create system indices using repeat_interleave - atoms_per_system = torch.tensor([len(a) for a in phonopy_atoms_list], device=device) + atoms_per_system = torch.tensor([len(at) for at in phonopy_atoms_list], device=device) system_idx = torch.repeat_interleave( torch.arange(len(phonopy_atoms_list), device=device), atoms_per_system ) @@ -377,7 +365,7 @@ def phonopy_to_state( """ NOTE: PhonopyAtoms does not have pbc attribute for Supercells assume True Verify consistent pbc - if not all(all(a.pbc) == all(phonopy_atoms_list[0].pbc) for a in phonopy_atoms_list): + if not all(all(at.pbc) == all(phonopy_atoms_lst[0].pbc) for at in phonopy_atoms_lst): raise ValueError("All systems must have the same periodic boundary conditions") """ diff --git a/torch_sim/math.py b/torch_sim/math.py index 7b4596f78..737422dc3 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -1,10 +1,9 @@ -"""Mathematical operations and utilities.""" +"""Mathematical operations and utilities. Adapted from https://github.com/abhijeetgangan/torch_matfunc.""" # ruff: noqa: FBT001, FBT002, RUF002, RUF003 from typing import Any, Final -import numpy as np import torch from torch.autograd import Function @@ -27,75 +26,43 @@ def torch_divmod(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch. return d, m -"""Below code is taken from https://github.com/abhijeetgangan/torch_matfunc""" - - -def expm_frechet( - A: torch.Tensor | np.ndarray, - E: torch.Tensor | np.ndarray, - method: str | None = None, - check_finite: bool = True, -) -> torch.Tensor: - """Frechet derivative of the matrix exponential of A in the direction E. - - Args: - A: (N, N) tensor.Tensor or np.ndarray. - Matrix of which to take the matrix exponential. - E: (N, N) tensor.Tensor or np.ndarray. - Matrix direction in which to take the Frechet derivative. - method: str, optional. Choice of algorithm. Should be one of - - `SPS` (default) - - `blockEnlarge` - check_finite: bool, optional. Whether to check that the input matrix contains - only finite numbers. Disabling may give a performance gain, but may result - in problems (crashes, non-termination) if the inputs do contain - infinities or NaNs. Defaults to True. - - Returns: torch.Tensor. Frechet derivative of the matrix exponential of A - in the direction E - """ - return expm_frechet_with_matrix_exp(A, E, method, check_finite)[1] - - -def expm_frechet_with_matrix_exp( # noqa: C901 - A: torch.Tensor | np.ndarray, - E: torch.Tensor | np.ndarray, +def expm_frechet( # noqa: C901 + A: torch.Tensor, + E: torch.Tensor, method: str | None = None, check_finite: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Frechet derivative of the matrix exponential of A in the direction E. Args: - A: (N, N) tensor.Tensor or np.ndarray. - Matrix of which to take the matrix exponential. - E: (N, N) tensor.Tensor or np.ndarray. - Matrix direction in which to take the Frechet derivative. + A: (N, N) array_like. Matrix of which to take the matrix exponential. + E: (N, N) array_like. Matrix direction in which to take the Frechet derivative. method: str, optional. Choice of algorithm. Should be one of - `SPS` (default) - `blockEnlarge` check_finite: bool, optional. Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain - infinities or NaNs. Defaults to True. + infinities or NaNs. Returns: tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - expm_A: Matrix exponential of A - - expm_frechet_AE: Frechet derivative of the matrix exponential of A - in the direction E + expm_A: Matrix exponential of A. + expm_frechet_AE: Frechet derivative of the matrix exponential of A + in the direction E. """ - # Convert inputs to torch tensors if they aren't already - if not isinstance(A, torch.Tensor): - A = torch.tensor(A, dtype=torch.float64) - if not isinstance(E, torch.Tensor): - E = torch.tensor(E, dtype=torch.float64) - if check_finite: if not torch.isfinite(A).all(): raise ValueError("Matrix A contains non-finite values") if not torch.isfinite(E).all(): raise ValueError("Matrix E contains non-finite values") + # Convert inputs to torch tensors if they aren't already + if not isinstance(A, torch.Tensor): + A = torch.tensor(A, dtype=torch.float64) + if not isinstance(E, torch.Tensor): + E = torch.tensor(E, dtype=torch.float64) + if A.dim() != 2 or A.shape[0] != A.shape[1]: raise ValueError("expected A to be a square matrix") if E.dim() != 2 or E.shape[0] != E.shape[1]: @@ -111,7 +78,7 @@ def expm_frechet_with_matrix_exp( # noqa: C901 elif method == "blockEnlarge": expm_A, expm_frechet_AE = expm_frechet_block_enlarge(A, E) else: - raise ValueError(f"Unknown implementation {method}") + raise ValueError(f"Unknown {method=}") return expm_A, expm_frechet_AE @@ -126,10 +93,9 @@ def expm_frechet_block_enlarge( E: Direction matrix Returns: - tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - expm_A: Matrix exponential of A - - expm_frechet_AE: Frechet derivative of the matrix exponential of A - in the direction E. + expm_A: Matrix exponential of A + expm_frechet_AE: torch.Tensor + Frechet derivative of the matrix exponential of A in the direction E """ n = A.shape[0] # Create block matrix M = [[A, E], [0, A]] @@ -146,7 +112,7 @@ def expm_frechet_block_enlarge( # Maximal values ell_m of ||2**-s A|| such that the backward error bound # does not exceed 2**-53. ell_table_61: Final = ( - 0, + None, # 1 2.11e-8, 3.56e-4, @@ -419,30 +385,28 @@ def vec(M: torch.Tensor) -> torch.Tensor: def expm_frechet_kronform( - A: torch.Tensor | np.ndarray, method: str | None = None, check_finite: bool = True + A: torch.Tensor, method: str | None = None, check_finite: bool = True ) -> torch.Tensor: """Construct the Kronecker form of the Frechet derivative of expm. Args: - A: torch.Tensor or np.ndarray. - Square matrix tensor with shape (N, N) + A: Square matrix tensor with shape (N, N) method: Optional extra keyword to be passed to expm_frechet check_finite: Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs. - Defaults to True. Returns: torch.Tensor: Kronecker form of the Frechet derivative of the matrix exponential with shape (N*N, N*N) """ + if check_finite and not torch.isfinite(A).all(): + raise ValueError("Matrix A contains non-finite values") + # Convert input to torch tensor if it isn't already if not isinstance(A, torch.Tensor): A = torch.tensor(A, dtype=torch.float64) - if check_finite and not torch.isfinite(A).all(): - raise ValueError("Matrix A contains non-finite values") - if A.dim() != 2 or A.shape[0] != A.shape[1]: raise ValueError("expected a square matrix") @@ -453,31 +417,31 @@ def expm_frechet_kronform( for i in range(n): for j in range(n): E = torch.outer(ident[i], ident[j]) - F = expm_frechet(A, E, method=method, check_finite=False) + _, F = expm_frechet(A, E, method=method, check_finite=False) cols.append(vec(F)) return torch.stack(cols, dim=1) -def expm_cond(A: torch.Tensor | np.ndarray, check_finite: bool = True) -> torch.Tensor: +def expm_cond(A: torch.Tensor, check_finite: bool = True) -> torch.Tensor: """Relative condition number of the matrix exponential in the Frobenius norm. Args: - A: torch.Tensor or np.ndarray. Square input matrix with shape (N, N) + A: Square input matrix with shape (N, N) check_finite: Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs. - Defaults to True. Returns: - kappa: torch.Tensor. The relative condition number of the matrix exponential + kappa: The relative condition number of the matrix exponential in the Frobenius norm """ + if check_finite and not torch.isfinite(A).all(): + raise ValueError("Matrix A contains non-finite values") + # Convert input to torch tensor if it isn't already if not isinstance(A, torch.Tensor): A = torch.tensor(A, dtype=torch.float64) - if check_finite and not torch.isfinite(A).all(): - raise ValueError("Matrix A contains non-finite values") if A.dim() != 2 or A.shape[0] != A.shape[1]: raise ValueError("expected a square matrix") @@ -529,7 +493,8 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: (A,) = ctx.saved_tensors # Compute the Frechet derivative in the direction of grad_output - return expm_frechet(A, grad_output, method="SPS", check_finite=False) + _, frechet_deriv = expm_frechet(A, grad_output, method="SPS", check_finite=False) + return frechet_deriv def _is_valid_matrix(T: torch.Tensor, n: int = 3) -> bool: @@ -562,32 +527,30 @@ def _determine_eigenvalue_case( # noqa: C901 ValueError: If the eigenvalue structure cannot be determined """ # Get unique values and their counts directly with one call - unique_vals, counts = torch.unique(eigenvalues, return_counts=True) + uniq_vals, counts = torch.unique(eigenvalues, return_counts=True) # Use np.isclose to group eigenvalues that are numerically close # We can create a mask for each unique value to see if other values are close to it - if len(unique_vals) > 1: + if len(uniq_vals) > 1: # Check if some "unique" values should actually be considered the same i = 0 - while i < len(unique_vals): + while i < len(uniq_vals): # Find all values close to the current one - close_mask = torch.isclose(unique_vals, unique_vals[i], rtol=0, atol=num_tol) + close_mask = torch.isclose(uniq_vals, uniq_vals[i], rtol=0, atol=num_tol) close_count = torch.sum(close_mask) if close_count > 1: # If there are other close values # Merge them (keep the first one, remove the others) counts[i] = torch.sum(counts[close_mask]) - unique_vals = unique_vals[ - ~(close_mask & torch.arange(len(close_mask)) != i) - ] + uniq_vals = uniq_vals[~(close_mask & torch.arange(len(close_mask)) != i)] counts = counts[~(close_mask & torch.arange(len(counts)) != i)] else: i += 1 # Now determine the case based on the number of unique eigenvalues - if len(unique_vals) == 1: + if len(uniq_vals) == 1: # Case 1: All eigenvalues are equal (Ξ», Ξ», Ξ») - lambda_val = unique_vals[0] + lambda_val = uniq_vals[0] Identity = torch.eye(3, dtype=lambda_val.dtype, device=lambda_val.device) T_minus_lambdaI = T - lambda_val * Identity @@ -601,14 +564,14 @@ def _determine_eigenvalue_case( # noqa: C901 return "case1c" # q(T) = (T - Ξ»I)Β³ - if len(unique_vals) == 2: + if len(uniq_vals) == 2: # Case 2: Two distinct eigenvalues # The counts array already tells us which eigenvalue is repeated if counts.max() != 2 or counts.min() != 1: raise ValueError("Unexpected eigenvalue pattern for Case 2") - mu = unique_vals[torch.argmin(counts)] # The non-repeated eigenvalue - lambda_val = unique_vals[torch.argmax(counts)] # The repeated eigenvalue + mu = uniq_vals[torch.argmin(counts)] # The non-repeated eigenvalue + lambda_val = uniq_vals[torch.argmax(counts)] # The repeated eigenvalue Identity = torch.eye(3, dtype=lambda_val.dtype, device=lambda_val.device) T_minus_muI = T - mu * Identity @@ -622,7 +585,7 @@ def _determine_eigenvalue_case( # noqa: C901 return "case2a" # q(T) = (T - Ξ»I)(T - ΞΌI) return "case2b" # q(T) = (T - ΞΌI)(T - Ξ»I)Β² - if len(unique_vals) == 3: + if len(uniq_vals) == 3: # Case 3: Three distinct eigenvalues (Ξ», ΞΌ, Ξ½) return "case3" # q(T) = (T - Ξ»I)(T - ΞΌI)(T - Ξ½I) @@ -636,7 +599,7 @@ def _matrix_log_case1a(T: torch.Tensor, lambda_val: torch.Tensor) -> torch.Tenso Args: T: The matrix whose logarithm is to be computed - lambda_val: The eigenvalue of T (a complex number) + lambda_val: The eigenvalue of T as a tensor Returns: The logarithm of T, which is log(Ξ»)Β·I @@ -655,7 +618,7 @@ def _matrix_log_case1b( Args: T: The matrix whose logarithm is to be computed - lambda_val: The eigenvalue of T (a complex number) + lambda_val: The eigenvalue of T num_tol: Numerical tolerance for stability checks, default=1e-16 Returns: @@ -670,7 +633,7 @@ def _matrix_log_case1b( scaled_T_minus_lambdaI = T_minus_lambdaI / lambda_val return torch.log(lambda_val) * Identity + scaled_T_minus_lambdaI # Alternative computation for small lambda - return torch.log(lambda_val) * Identity + T_minus_lambdaI / max(lambda_val, num_tol) # type: ignore[call-overload] + return torch.log(lambda_val) * Identity + T_minus_lambdaI / max(lambda_val, num_tol) def _matrix_log_case1c( @@ -682,7 +645,7 @@ def _matrix_log_case1c( Args: T: The matrix whose logarithm is to be computed - lambda_val: The eigenvalue of T (a complex number) + lambda_val: The eigenvalue of T num_tol: Numerical tolerance for stability checks, default=1e-16 Returns: @@ -699,8 +662,8 @@ def _matrix_log_case1c( lambda_squared = lambda_val * lambda_val term1 = torch.log(lambda_val) * Identity - term2 = T_minus_lambdaI / max(lambda_val, num_tol) # type: ignore[call-overload] - term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol) # type: ignore[call-overload] + term2 = T_minus_lambdaI / max(lambda_val, num_tol) + term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol) return term1 + term2 - term3 @@ -715,8 +678,8 @@ def _matrix_log_case2a( Args: T: The matrix whose logarithm is to be computed - lambda_val: The repeated eigenvalue of T (a complex number) - mu: The non-repeated eigenvalue of T (a complex number) + lambda_val: The repeated eigenvalue of T + mu: The non-repeated eigenvalue of T num_tol: Numerical tolerance for stability checks, default=1e-16 Returns: @@ -755,8 +718,8 @@ def _matrix_log_case2b( Args: T: The matrix whose logarithm is to be computed - lambda_val: The repeated eigenvalue of T (a complex number) - mu: The non-repeated eigenvalue of T (a complex number) + lambda_val: The repeated eigenvalue of T + mu: The non-repeated eigenvalue of T num_tol: Numerical tolerance for stability checks, default=1e-16 Returns: @@ -827,7 +790,7 @@ def _matrix_log_case3( # Check if eigenvalues are distinct enough for numerical stability if ( - min(torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)) # type: ignore[call-overload] + min(torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)) < num_tol ): raise ValueError("Eigenvalues are too close, computation may be unstable") @@ -903,7 +866,7 @@ def _matrix_log_33( # noqa: C901 case = _determine_eigenvalue_case(T, eigenvalues, num_tol) # Case 1: All eigenvalues are equal (Ξ», Ξ», Ξ») - if case in ["case1a", "case1b", "case1c"]: + if case in ("case1a", "case1b", "case1c"): lambda_val = eigenvalues[0] # Check for numerical stability @@ -918,18 +881,18 @@ def _matrix_log_33( # noqa: C901 return _matrix_log_case1c(T, lambda_val, num_tol) # Case 2: Two distinct eigenvalues (ΞΌ, Ξ», Ξ») - elif case in ["case2a", "case2b"]: + elif case in ("case2a", "case2b"): # Find the unique eigenvalue (ΞΌ) and the repeated eigenvalue (Ξ») - unique_vals, counts = torch.unique( + uniq_vals, counts = torch.unique( torch.round(eigenvalues, decimals=10), return_counts=True ) - if len(unique_vals) != 2 or counts.max() != 2: + if len(uniq_vals) != 2 or counts.max() != 2: raise ValueError( "Case 2 requires exactly two distinct eigenvalues with one repeated" ) - mu = unique_vals[torch.argmin(counts)] # The non-repeated eigenvalue - lambda_val = unique_vals[torch.argmax(counts)] # The repeated eigenvalue + mu = uniq_vals[torch.argmin(counts)] # The non-repeated eigenvalue + lambda_val = uniq_vals[torch.argmax(counts)] # The repeated eigenvalue if case == "case2a": return _matrix_log_case2a(T, lambda_val, mu, num_tol) @@ -944,7 +907,11 @@ def _matrix_log_33( # noqa: C901 lambda_val, mu, nu = torch.sort(eigenvalues).values # Sort for consistency return _matrix_log_case3(T, lambda_val, mu, nu, num_tol) - raise ValueError(f"Unknown eigenvalue {case=}") + else: + raise ValueError(f"Unknown eigenvalue {case=}") + + # should never be reached, just for type checker + raise RuntimeError("Unexpected code path in _matrix_log_33") def matrix_log_scipy(matrix: torch.Tensor) -> torch.Tensor: @@ -961,9 +928,7 @@ def matrix_log_scipy(matrix: torch.Tensor) -> torch.Tensor: import scipy.linalg # Save original device and dtype - device = matrix.device - dtype = matrix.dtype - requires_grad = matrix.requires_grad + device, dtype, requires_grad = matrix.device, matrix.dtype, matrix.requires_grad # Detach and move to CPU for scipy matrix_cpu = matrix.detach().cpu().numpy() @@ -1042,7 +1007,7 @@ def batched_vdot( if batch_indices.min() < 0: raise ValueError("batch_indices must be non-negative") - output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) # type: ignore[call-overload] + output = torch.zeros(int(batch_indices.max()) + 1, dtype=x.dtype, device=x.device) output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1)) return output diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 77b1b0bae..73320d144 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -1,28 +1,16 @@ -"""Wrapper for FairChem ecosystem models in TorchSim. +"""FairChem model wrapper for torch-sim. -This module provides a TorchSim wrapper of the FairChem models for computing -energies, forces, and stresses of atomistic systems. It serves as a wrapper around -the FairChem library, integrating it with the torch_sim framework to enable seamless -simulation of atomistic systems with machine learning potentials. +Provides a TorchSim-compatible interface to FairChem models for computing +energies, forces, and stresses of atomistic systems. -The FairChemModel class adapts FairChem models to the ModelInterface protocol, -allowing them to be used within the broader torch_sim simulation framework. - -Notes: - This implementation requires FairChem to be installed and accessible. - It supports various model configurations through configuration files or - pretrained model checkpoints. +Requires fairchem-core to be installed. """ -# ruff: noqa: T201 - from __future__ import annotations -import copy import traceback import typing import warnings -from types import MappingProxyType from typing import Any import torch @@ -32,21 +20,16 @@ try: - from fairchem.core.common.registry import registry - from fairchem.core.common.utils import ( - load_config, - setup_imports, - setup_logging, - update_config, - ) - from fairchem.core.models.model_registry import model_name_to_local_file - from torch_geometric.data import Batch, Data + from fairchem.core import pretrained_mlip + from fairchem.core.calculate.ase_calculator import UMATask + from fairchem.core.common.utils import setup_imports, setup_logging + from fairchem.core.datasets.atomic_data import AtomicData, atomicdata_list_to_batch except ImportError as exc: warnings.warn(f"FairChem import failed: {traceback.format_exc()}", stacklevel=2) class FairChemModel(ModelInterface): - """FairChem model wrapper for torch_sim. + """FairChem model wrapper for torch-sim. This class is a placeholder for the FairChemModel class. It raises an ImportError if FairChem is not installed. @@ -63,92 +46,56 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: from torch_sim.typing import StateDict -_DTYPE_DICT = { - torch.float16: "float16", - torch.float32: "float32", - torch.float64: "float64", -} - class FairChemModel(ModelInterface): - """Computes atomistic energies, forces and stresses using a FairChem model. + """FairChem model wrapper for computing atomistic properties. - This class wraps a FairChem model to compute energies, forces, and stresses for - atomistic systems. It handles model initialization, checkpoint loading, and - provides a forward pass that accepts a SimState object and returns model - predictions. + Wraps FairChem models to compute energies, forces, and stresses. Can be + initialized with a model checkpoint path or pretrained model name. - The model can be initialized either with a configuration file or a pretrained - checkpoint. It supports various model architectures and configurations supported by - FairChem. + Uses the fairchem-core-2.2.0+ predictor API for batch inference. Attributes: - neighbor_list_fn (Callable | None): Function to compute neighbor lists - config (dict): Complete model configuration dictionary - trainer: FairChem trainer object that contains the model - data_object (Batch): Data object containing system information - implemented_properties (list): Model outputs the model can compute - pbc (bool): Whether periodic boundary conditions are used + predictor: The FairChem predictor for batch inference + task_name (UMATask): Task type for the model + _device (torch.device): Device where computation is performed _dtype (torch.dtype): Data type used for computation _compute_stress (bool): Whether to compute stress tensor - _compute_forces (bool): Whether to compute forces - _device (torch.device): Device where computation is performed - _reshaped_props (dict): Properties that need reshaping after computation + implemented_properties (list): Model outputs the model can compute Examples: >>> model = FairChemModel(model="path/to/checkpoint.pt", compute_stress=True) >>> results = model(state) """ - _reshaped_props = MappingProxyType( - {"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)} - ) - - def __init__( # noqa: C901, PLR0915 + def __init__( self, model: str | Path | None, neighbor_list_fn: Callable | None = None, *, # force remaining arguments to be keyword-only - config_yml: str | None = None, model_name: str | None = None, - local_cache: str | None = None, - trainer: str | None = None, cpu: bool = False, - seed: int | None = None, dtype: torch.dtype | None = None, compute_stress: bool = False, - pbc: bool = True, - disable_amp: bool = True, + task_name: UMATask | str | None = None, ) -> None: - """Initialize the FairChemModel with specified configuration. - - Loads a FairChem model from either a checkpoint path or a configuration file. - Sets up the model parameters, trainer, and configuration for subsequent use - in energy and force calculations. + """Initialize the FairChem model. Args: model (str | Path | None): Path to model checkpoint file neighbor_list_fn (Callable | None): Function to compute neighbor lists (not currently supported) - config_yml (str | None): Path to configuration YAML file model_name (str | None): Name of pretrained model to load - local_cache (str | None): Path to local model cache directory - trainer (str | None): Name of trainer class to use cpu (bool): Whether to use CPU instead of GPU for computation - seed (int | None): Random seed for reproducibility dtype (torch.dtype | None): Data type to use for computation compute_stress (bool): Whether to compute stress tensor - pbc (bool): Whether to use periodic boundary conditions - disable_amp (bool): Whether to disable AMP + task_name (UMATask | str | None): Task type for UMA models (optional, + only needed for UMA models) + Raises: RuntimeError: If both model_name and model are specified - NotImplementedError: If local_cache is not set when model_name is used NotImplementedError: If custom neighbor list function is provided - ValueError: If stress computation is requested but not supported by model - - Notes: - Either config_yml or model must be provided. The model loads configuration - from the checkpoint if config_yml is not specified. + ValueError: If neither model nor model_name is provided """ setup_imports() setup_logging() @@ -158,7 +105,11 @@ def __init__( # noqa: C901, PLR0915 self._compute_stress = compute_stress self._compute_forces = True self._memory_scales_with = "n_atoms" - self.pbc = pbc + + if neighbor_list_fn is not None: + raise NotImplementedError( + "Custom neighbor list is not supported for FairChemModel." + ) if model_name is not None: if model is not None: @@ -166,166 +117,42 @@ def __init__( # noqa: C901, PLR0915 "model_name and checkpoint_path were both specified, " "please use only one at a time" ) - if local_cache is None: - raise NotImplementedError( - "Local cache must be set when specifying a model name" - ) - model = model_name_to_local_file( - model_name=model_name, local_cache=local_cache - ) + model = model_name - # Either the config path or the checkpoint path needs to be provided - if not config_yml and model is None: - raise ValueError("Either config_yml or model must be provided") - - checkpoint = None - if config_yml is not None: - if isinstance(config_yml, str): - config, duplicates_warning, duplicates_error = load_config(config_yml) - if len(duplicates_warning) > 0: - print( - "Overwritten config parameters from included configs " - f"(non-included parameters take precedence): {duplicates_warning}" - ) - if len(duplicates_error) > 0: - raise ValueError( - "Conflicting (duplicate) parameters in simultaneously " - f"included configs: {duplicates_error}" - ) - else: - config = config_yml - - # Only keeps the train data that might have normalizer values - if isinstance(config["dataset"], list): - config["dataset"] = config["dataset"][0] - elif isinstance(config["dataset"], dict): - config["dataset"] = config["dataset"].get("train", None) - else: - # Loads the config from the checkpoint directly (always on CPU). - checkpoint = torch.load(model, map_location=torch.device("cpu")) - config = checkpoint["config"] - - if trainer is not None: - config["trainer"] = trainer - else: - config["trainer"] = config.get("trainer", "ocp") - - if "model_attributes" in config: - config["model_attributes"]["name"] = config.pop("model") - config["model"] = config["model_attributes"] - - self.neighbor_list_fn = neighbor_list_fn - - if neighbor_list_fn is None: - # Calculate the edge indices on the fly - config["model"]["otf_graph"] = True - else: - raise NotImplementedError( - "Custom neighbor list is not supported for FairChemModel." - ) + if model is None: + raise ValueError("Either model or model_name must be provided") - if "backbone" in config["model"]: - config["model"]["backbone"]["use_pbc"] = pbc - config["model"]["backbone"]["use_pbc_single"] = False - if dtype is not None: - try: - config["model"]["backbone"].update({"dtype": _DTYPE_DICT[dtype]}) - for key in config["model"]["heads"]: - config["model"]["heads"][key].update( - {"dtype": _DTYPE_DICT[dtype]} - ) - except KeyError: - print( - "WARNING: dtype not found in backbone, using default model dtype" - ) - else: - config["model"]["use_pbc"] = pbc - config["model"]["use_pbc_single"] = False - if dtype is not None: - try: - config["model"].update({"dtype": _DTYPE_DICT[dtype]}) - except KeyError: - print( - "WARNING: dtype not found in backbone, using default model dtype" - ) - - ### backwards compatibility with OCP v<2.0 - config = update_config(config) - - self.config = copy.deepcopy(config) - self.config["checkpoint"] = str(model) - del config["dataset"]["src"] - - self.trainer = registry.get_trainer_class(config["trainer"])( - task=config.get("task", {}), - model=config["model"], - dataset=[config["dataset"]], - outputs=config["outputs"], - loss_functions=config["loss_functions"], - evaluation_metrics=config["evaluation_metrics"], - optimizer=config["optim"], - identifier="", - slurm=config.get("slurm", {}), - local_rank=config.get("local_rank", 0), - is_debug=config.get("is_debug", True), - cpu=cpu, - amp=False if dtype is not None else config.get("amp", False), - inference_only=True, - ) + # Convert task_name to UMATask if it's a string (only for UMA models) + if isinstance(task_name, str): + task_name = UMATask(task_name) - if dtype is not None: - # Convert model parameters to specified dtype - self.trainer.model = self.trainer.model.to(dtype=self.dtype) + # Use the efficient predictor API for optimal performance + device_str = "cpu" if cpu else "cuda" if torch.cuda.is_available() else "cpu" + self._device = torch.device(device_str) + self.task_name = task_name - if model is not None: - self.load_checkpoint(checkpoint_path=model, checkpoint=checkpoint) + # Create efficient batch predictor for fast inference + self.predictor = pretrained_mlip.get_predict_unit(str(model), device=device_str) - seed = seed if seed is not None else self.trainer.config["cmd"]["seed"] - if seed is None: - print( - "No seed has been set in model checkpoint or OCPCalculator! Results may " - "not be reproducible on re-run" - ) - else: - self.trainer.set_seed(seed) - - if disable_amp: - self.trainer.scaler = None - - self.implemented_properties = list(self.config["outputs"]) + # Determine implemented properties + # This is a simplified approach - in practice you might want to + # inspect the model configuration more carefully + self.implemented_properties = ["energy", "forces"] + if compute_stress: + self.implemented_properties.append("stress") - self._device = self.trainer.device - - stress_output = "stress" in self.implemented_properties - if not stress_output and compute_stress: - raise NotImplementedError("Stress output not implemented for this model") - - def load_checkpoint( - self, checkpoint_path: str, checkpoint: dict | None = None - ) -> None: - """Load an existing trained model checkpoint. + @property + def dtype(self) -> torch.dtype: + """Return the data type used by the model.""" + return self._dtype - Loads model parameters from a checkpoint file or dictionary, - setting the model to inference mode. - - Args: - checkpoint_path (str): Path to the trained model checkpoint file - checkpoint (dict | None): A pretrained checkpoint dictionary. If provided, - this dictionary is used instead of loading from checkpoint_path. - - Notes: - If loading fails, a message is printed but no exception is raised. - """ - try: - self.trainer.load_checkpoint(checkpoint_path, checkpoint, inference_only=True) - except NotImplementedError: - print("Unable to load checkpoint!") + @property + def device(self) -> torch.device: + """Return the device where the model is located.""" + return self._device def forward(self, state: ts.SimState | StateDict) -> dict: - """Perform forward pass to compute energies, forces, and other properties. - - Takes a simulation state and computes the properties implemented by the model, - such as energy, forces, and stresses. + """Compute energies, forces, and other properties. Args: state (SimState | StateDict): State object containing positions, cells, @@ -336,63 +163,72 @@ def forward(self, state: ts.SimState | StateDict) -> dict: dict: Dictionary of model predictions, which may include: - energy (torch.Tensor): Energy with shape [batch_size] - forces (torch.Tensor): Forces with shape [n_atoms, 3] - - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3], - if compute_stress is True - - Notes: - The state is automatically transferred to the model's device if needed. - All output tensors are detached from the computation graph. + - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3] """ - if isinstance(state, dict): - state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) - if state.device != self._device: - state = state.to(self._device) + if sim_state.device != self._device: + sim_state = sim_state.to(self._device) - if state.system_idx is None: - state.system_idx = torch.zeros(state.positions.shape[0], dtype=torch.int) + # Ensure system_idx has integer dtype (SimState guarantees presence) + if sim_state.system_idx.dtype != torch.int64: + sim_state.system_idx = sim_state.system_idx.to(dtype=torch.int64) - if self.pbc != state.pbc: - raise ValueError( - "PBC mismatch between model and state. " - "For FairChemModel PBC needs to be defined in the model class." - ) + # Convert SimState to AtomicData objects for efficient batch processing + from ase import Atoms - natoms = torch.bincount(state.system_idx) - fixed = torch.zeros((state.system_idx.size(0), natoms.sum()), dtype=torch.int) - data_list = [] - for i, (n, c) in enumerate( - zip(natoms, torch.cumsum(natoms, dim=0), strict=False) + n_atoms = torch.bincount(sim_state.system_idx) + atomic_data_list = [] + + for idx, (n, c) in enumerate( + zip(n_atoms, torch.cumsum(n_atoms, dim=0), strict=False) ): - data_list.append( - Data( - pos=state.positions[c - n : c].clone(), - cell=state.row_vector_cell[i, None].clone(), - atomic_numbers=state.atomic_numbers[c - n : c].clone(), - fixed=fixed[c - n : c].clone(), - natoms=n, - pbc=torch.tensor([state.pbc, state.pbc, state.pbc], dtype=torch.bool), - ) + # Extract system data + positions = sim_state.positions[c - n : c].cpu().numpy() + atomic_nums = sim_state.atomic_numbers[c - n : c].cpu().numpy() + cell = ( + sim_state.row_vector_cell[idx].cpu().numpy() + if sim_state.row_vector_cell is not None + else None ) - self.data_object = Batch.from_data_list(data_list) - if self.dtype is not None: - self.data_object.pos = self.data_object.pos.to(self.dtype) - self.data_object.cell = self.data_object.cell.to(self.dtype) + # Create ASE Atoms object first + atoms = Atoms( + numbers=atomic_nums, + positions=positions, + cell=cell, + pbc=sim_state.pbc if cell is not None else False, + ) - predictions = self.trainer.predict( - self.data_object, per_image=False, disable_tqdm=True - ) + # Convert ASE Atoms to AtomicData (task_name only applies to UMA models) + if self.task_name is None: + atomic_data = AtomicData.from_ase(atoms) + else: + atomic_data = AtomicData.from_ase(atoms, task_name=self.task_name) + atomic_data_list.append(atomic_data) + + # Create batch for efficient inference + batch = atomicdata_list_to_batch(atomic_data_list) + batch = batch.to(self._device) + + # Run efficient batch prediction + predictions = self.predictor.predict(batch) - results = {} + # Convert predictions to torch-sim format + results: dict[str, torch.Tensor] = {} + results["energy"] = predictions["energy"].to(dtype=self._dtype) + results["forces"] = predictions["forces"].to(dtype=self._dtype) - for key in predictions: - _pred = predictions[key] - if key in self._reshaped_props: - _pred = _pred.reshape(self._reshaped_props.get(key)).squeeze() - results[key] = _pred.detach() + # Handle stress if requested and available + if self._compute_stress and "stress" in predictions: + stress = predictions["stress"].to(dtype=self._dtype) + # Ensure stress has correct shape [batch_size, 3, 3] + if stress.dim() == 2 and stress.shape[0] == len(atomic_data_list): + stress = stress.view(-1, 3, 3) + results["stress"] = stress - results["energy"] = results["energy"].squeeze(dim=1) - if results.get("stress") is not None and len(results["stress"].shape) == 2: - results["stress"] = results["stress"].unsqueeze(dim=0) return results diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index a7b287e07..188a002de 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -1,12 +1,12 @@ -"""An interface for using arbitrary GraphPESModels in torch_sim. +"""An interface for using arbitrary GraphPESModels in ts. This module provides a TorchSim wrapper of the GraphPES models for computing energies, forces, and stresses of atomistic systems. It serves as a wrapper around -the graph_pes library, integrating it with the torch_sim framework to enable seamless +the graph_pes library, integrating it with the torch-sim framework to enable seamless simulation of atomistic systems with machine learning potentials. The GraphPESWrapper class adapts GraphPESModels to the ModelInterface protocol, -allowing them to be used within the broader torch_sim simulation framework. +allowing them to be used within the broader torch-sim simulation framework. Notes: This implementation requires graph_pes to be installed and accessible. @@ -14,7 +14,6 @@ """ import traceback -import typing import warnings from pathlib import Path from typing import Any @@ -37,7 +36,7 @@ PropertyKey = str class GraphPESWrapper(ModelInterface): # type: ignore[reportRedeclaration] - """GraphPESModel wrapper for torch_sim. + """GraphPESModel wrapper for torch-sim. This class is a placeholder for the GraphPESWrapper class. It raises an ImportError if graph_pes is not installed. @@ -67,34 +66,28 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra """ graphs = [] - for i in range(state.n_systems): - system_mask = state.system_idx == i + for sys_idx in range(state.n_systems): + system_mask = state.system_idx == sys_idx R = state.positions[system_mask] Z = state.atomic_numbers[system_mask] - cell = state.row_vector_cell[i] - nl, shifts = vesin_nl_ts( - R, - cell, - state.pbc, - # graph-pes models internally trim the neighbour list to the - # model's cutoff value. To ensure no strange edge effects whereby - # edges that are exactly `cutoff` long are included/excluded, - # we bump this up slightly here - cutoff + 1e-5, - ) - - graphs.append( - AtomicGraph( - Z=Z.long(), - R=R, - cell=cell, - neighbour_list=nl.long(), - neighbour_cell_offsets=shifts, - properties={}, - cutoff=cutoff.item(), - other={}, - ) + cell = state.row_vector_cell[sys_idx] + # graph-pes models internally trim the neighbor list to the + # model's cutoff value. To ensure no strange edge effects whereby + # edges that are exactly `cutoff` long are included/excluded, + # we bump cutoff + 1e-5 up slightly + nl, shifts = vesin_nl_ts(R, cell, state.pbc, cutoff + 1e-5) + + atomic_graph = AtomicGraph( + Z=Z.long(), + R=R, + cell=cell, + neighbour_list=nl.long(), + neighbour_cell_offsets=shifts, + properties={}, + cutoff=cutoff.item(), + other={}, ) + graphs.append(atomic_graph) return to_batch(graphs) @@ -103,7 +96,7 @@ class GraphPESWrapper(ModelInterface): """Wrapper for GraphPESModel in TorchSim. This class provides a TorchSim wrapper around GraphPESModel instances, - allowing them to be used within the broader torch_sim simulation framework. + allowing them to be used within the broader torch-sim simulation framework. The graph-pes package allows for the training of existing model architectures, including SchNet, PaiNN, MACE, NequIP, TensorNet, EDDP and more. @@ -154,12 +147,7 @@ def __init__( ) self._dtype = dtype - _model = typing.cast( - "GraphPESModel", - ( - model if isinstance(model, GraphPESModel) else load_model(model) # type: ignore[arg-type] - ), - ) + _model = model if isinstance(model, GraphPESModel) else load_model(model) self._gp_model = _model.to(device=self.device, dtype=self.dtype) self._compute_forces = compute_forces diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 27c032779..58f233e84 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -1,6 +1,6 @@ -"""Core interfaces for all models in torchsim. +"""Core interfaces for all models in TorchSim. -This module defines the abstract base class that all torchsim models must implement. +This module defines the abstract base class that all TorchSim models must implement. It establishes a common API for interacting with different force and energy models, ensuring consistent behavior regardless of the underlying implementation. The module also provides validation utilities to verify model conformance to the interface. @@ -36,7 +36,7 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): class ModelInterface(torch.nn.Module, ABC): - """Abstract base class for all simulation models in torchsim. + """Abstract base class for all simulation models in TorchSim. This interface provides a common structure for all energy and force models, ensuring they implement the required methods and properties. It defines how @@ -208,14 +208,14 @@ def validate_model_outputs( # noqa: C901, PLR0915 try: if not model.compute_stress: - model.compute_stress = True + model.compute_stress = True # type: ignore[unresolved-attribute] stress_computed = True except NotImplementedError: stress_computed = False try: if not model.compute_forces: - model.compute_forces = True + model.compute_forces = True # type: ignore[unresolved-attribute] force_computed = True except NotImplementedError: force_computed = False @@ -228,7 +228,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 og_positions = sim_state.positions.clone() og_cell = sim_state.cell.clone() og_system_idx = sim_state.system_idx.clone() - og_atomic_numbers = sim_state.atomic_numbers.clone() + og_atomic_nums = sim_state.atomic_numbers.clone() model_output = model.forward(sim_state) @@ -239,8 +239,8 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{og_cell=} != {sim_state.cell=}") if not torch.allclose(og_system_idx, sim_state.system_idx): raise ValueError(f"{og_system_idx=} != {sim_state.system_idx=}") - if not torch.allclose(og_atomic_numbers, sim_state.atomic_numbers): - raise ValueError(f"{og_atomic_numbers=} != {sim_state.atomic_numbers=}") + if not torch.allclose(og_atomic_nums, sim_state.atomic_numbers): + raise ValueError(f"{og_atomic_nums=} != {sim_state.atomic_numbers=}") # assert model output has the correct keys if "energy" not in model_output: diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index 2a05e2f8c..80c115dba 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -33,14 +33,14 @@ from torch_sim.typing import StateDict -DEFAULT_SIGMA = torch.tensor(1.0) -DEFAULT_EPSILON = torch.tensor(1.0) +DEFAULT_SIGMA = 1.0 +DEFAULT_EPSILON = 1.0 def lennard_jones_pair( dr: torch.Tensor, - sigma: torch.Tensor = DEFAULT_SIGMA, - epsilon: torch.Tensor = DEFAULT_EPSILON, + sigma: float | torch.Tensor = DEFAULT_SIGMA, + epsilon: float | torch.Tensor = DEFAULT_EPSILON, ) -> torch.Tensor: """Calculate pairwise Lennard-Jones interaction energies between particles. @@ -78,8 +78,8 @@ def lennard_jones_pair( def lennard_jones_pair_force( dr: torch.Tensor, - sigma: torch.Tensor = DEFAULT_SIGMA, - epsilon: torch.Tensor = DEFAULT_EPSILON, + sigma: float | torch.Tensor = DEFAULT_SIGMA, + epsilon: float | torch.Tensor = DEFAULT_EPSILON, ) -> torch.Tensor: """Calculate pairwise Lennard-Jones forces between particles. @@ -271,18 +271,12 @@ def unbatched_forward( ) # Get displacements using neighbor list dr_vec, distances = transforms.get_pair_displacements( - positions=positions, - cell=cell, - pbc=pbc, - pairs=mapping, - shifts=shifts, + positions=positions, cell=cell, pbc=pbc, pairs=mapping, shifts=shifts ) else: # Get all pairwise displacements dr_vec, distances = transforms.get_pair_displacements( - positions=positions, - cell=cell, - pbc=pbc, + positions=positions, cell=cell, pbc=pbc ) # Mask out self-interactions mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) @@ -391,19 +385,24 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: energies = results["energies"] # Shape: [n_atoms] stresses = results["stresses"] # Shape: [n_atoms, 3, 3] """ - if isinstance(state, dict): - state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) - if state.system_idx is None and state.cell.shape[0] > 1: + if sim_state.system_idx is None and sim_state.cell.shape[0] > 1: raise ValueError("System can only be inferred for batch size 1.") - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)] + outputs = [ + self.unbatched_forward(sim_state[idx]) for idx in range(sim_state.n_systems) + ] properties = outputs[0] # we always return tensors # per atom properties are returned as (atoms, ...) tensors # global properties are returned as shape (..., n) tensors - results = {} + results: dict[str, torch.Tensor] = {} for key in ("stress", "energy"): if key in properties: results[key] = torch.stack([out[key] for out in outputs]) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 5ca2a629c..fda955e45 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -39,7 +39,7 @@ warnings.warn(f"MACE import failed: {traceback.format_exc()}", stacklevel=2) class MaceModel(ModelInterface): - """MACE model wrapper for torch_sim. + """MACE model wrapper for torch-sim. This class is a placeholder for the MaceModel class. It raises an ImportError if MACE is not installed. @@ -68,7 +68,7 @@ def to_one_hot( torch.Tensor: A tensor of shape (N x num_classes) containing the one-hot encodings. """ - shape = indices.shape[:-1] + (num_classes,) + shape = (*indices.shape[:-1], num_classes) oh = torch.zeros(shape, device=indices.device, dtype=dtype).view(shape) # scatter_ is the in-place version of scatter @@ -217,8 +217,8 @@ def setup_from_system_idx( # Create ptr tensor for system boundaries self.n_atoms_per_system = [] ptr = [0] - for i in range(self.n_systems): - system_mask = system_idx == i + for sys_idx in range(self.n_systems): + system_mask = system_idx == sys_idx n_atoms = system_mask.sum().item() self.n_atoms_per_system.append(n_atoms) ptr.append(ptr[-1] + n_atoms) @@ -229,7 +229,7 @@ def setup_from_system_idx( # Create one-hot encodings for all atoms self.node_attrs = to_one_hot( torch.tensor( - atomic_numbers_to_indices(atomic_numbers.cpu(), z_table=self.z_table), + atomic_numbers_to_indices(atomic_numbers.numpy(), z_table=self.z_table), dtype=torch.long, device=self.device, ).unsqueeze(-1), @@ -237,10 +237,7 @@ def setup_from_system_idx( dtype=self.dtype, ) - def forward( # noqa: C901 - self, - state: ts.SimState | StateDict, - ) -> dict[str, torch.Tensor]: + def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # noqa: C901 """Compute energies, forces, and stresses for the given atomic systems. Processes the provided state information and computes energies, forces, and @@ -264,38 +261,40 @@ def forward( # noqa: C901 or in the forward pass, or if provided in both places. ValueError: If system indices are not provided when needed. """ - # Extract required data from input - if isinstance(state, dict): - state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) # Handle input validation for atomic numbers - if state.atomic_numbers is None and not self.atomic_numbers_in_init: + if sim_state.atomic_numbers is None and not self.atomic_numbers_in_init: raise ValueError( "Atomic numbers must be provided in either the constructor or forward." ) - if state.atomic_numbers is not None and self.atomic_numbers_in_init: + if sim_state.atomic_numbers is not None and self.atomic_numbers_in_init: raise ValueError( "Atomic numbers cannot be provided in both the constructor and forward." ) # Use system_idx from init if not provided - if state.system_idx is None: + if sim_state.system_idx is None: if not hasattr(self, "system_idx"): raise ValueError( "System indices must be provided if not set during initialization" ) - state.system_idx = self.system_idx + sim_state.system_idx = self.system_idx # Update system_idx information if new atomic numbers are provided if ( - state.atomic_numbers is not None + sim_state.atomic_numbers is not None and not self.atomic_numbers_in_init and not torch.equal( - state.atomic_numbers, + sim_state.atomic_numbers, getattr(self, "atomic_numbers", torch.zeros(0, device=self.device)), ) ): - self.setup_from_system_idx(state.atomic_numbers, state.system_idx) + self.setup_from_system_idx(sim_state.atomic_numbers, sim_state.system_idx) # Process each system's neighbor list separately edge_indices = [] @@ -304,25 +303,25 @@ def forward( # noqa: C901 offset = 0 # TODO (AG): Currently doesn't work for batched neighbor lists - for b in range(self.n_systems): - system_mask = state.system_idx == b + for sys_idx in range(self.n_systems): + system_mask = sim_state.system_idx == sys_idx # Calculate neighbor list for this system edge_idx, shifts_idx = self.neighbor_list_fn( - positions=state.positions[system_mask], - cell=state.row_vector_cell[b], - pbc=state.pbc, + positions=sim_state.positions[system_mask], + cell=sim_state.row_vector_cell[sys_idx], + pbc=sim_state.pbc, cutoff=self.r_max, ) # Adjust indices for the system edge_idx = edge_idx + offset - shifts = torch.mm(shifts_idx, state.row_vector_cell[b]) + shifts = torch.mm(shifts_idx, sim_state.row_vector_cell[sys_idx]) edge_indices.append(edge_idx) unit_shifts_list.append(shifts_idx) shifts_list.append(shifts) - offset += len(state.positions[system_mask]) + offset += len(sim_state.positions[system_mask]) # Combine all neighbor lists edge_index = torch.cat(edge_indices, dim=1) @@ -334,10 +333,10 @@ def forward( # noqa: C901 dict( ptr=self.ptr, node_attrs=self.node_attrs, - batch=state.system_idx, - pbc=state.pbc, - cell=state.row_vector_cell, - positions=state.positions, + batch=sim_state.system_idx, + pbc=sim_state.pbc, + cell=sim_state.row_vector_cell, + positions=sim_state.positions, edge_index=edge_index, unit_shifts=unit_shifts, shifts=shifts, @@ -346,7 +345,7 @@ def forward( # noqa: C901 compute_stress=self.compute_stress, ) - results = {} + results: dict[str, torch.Tensor] = {} # Process energy energy = out["energy"] diff --git a/torch_sim/models/mattersim.py b/torch_sim/models/mattersim.py index 9b2efb23a..e074a970a 100644 --- a/torch_sim/models/mattersim.py +++ b/torch_sim/models/mattersim.py @@ -22,7 +22,7 @@ warnings.warn(f"MatterSim import failed: {traceback.format_exc()}", stacklevel=2) class MatterSimModel(ModelInterface): - """MatterSim model wrapper for torch_sim. + """MatterSim model wrapper for torch-sim. This class is a placeholder for the MatterSimModel class. It raises an ImportError if sevenn is not installed. @@ -132,13 +132,16 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: The state is automatically transferred to the model's device if needed. All output tensors are detached from the computation graph. """ - if isinstance(state, dict): - state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) - if state.device != self._device: - state = state.to(self._device) + if sim_state.device != self._device: + sim_state = sim_state.to(self._device) - atoms_list = ts.io.state_to_atoms(state) + atoms_list = ts.io.state_to_atoms(sim_state) data_list = [self.convertor.convert(atoms) for atoms in atoms_list] batched_data = Collater([], follow_batch=None, exclude_keys=None)(data_list) batched_data.to(self._device) @@ -148,7 +151,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: include_stresses=self.compute_stress, ) - results = {} + results: dict[str, torch.Tensor] = {} results["energy"] = output["total_energy"].detach() results["forces"] = output["forces"].detach() results["stress"] = self.stress_weight * output["stresses"].detach() diff --git a/torch_sim/models/metatomic.py b/torch_sim/models/metatomic.py index 5655ed889..1e97cddac 100644 --- a/torch_sim/models/metatomic.py +++ b/torch_sim/models/metatomic.py @@ -5,7 +5,7 @@ for multiple systems simultaneously. The MetatomicModel class adapts metatomic models to the ModelInterface protocol, -allowing them to be used within the broader torch_sim simulation framework. +allowing them to be used within the broader torch-sim simulation framework. Notes: This module depends on the metatomic-torch package. @@ -37,7 +37,7 @@ warnings.warn(f"Metatomic import failed: {traceback.format_exc()}", stacklevel=2) class MetatomicModel(ModelInterface): - """Metatomic model wrapper for torch_sim. + """Metatomic model wrapper for torch-sim. This class is a placeholder for the MetatomicModel class. It raises an ImportError if metatomic is not installed. @@ -105,10 +105,10 @@ def __init__( if model == "pet-mad": path = "https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt" self._model = load_model(path).export() - elif model.endswith(".ckpt"): + elif str(model).endswith(".ckpt"): path = model self._model = load_model(path).export() - elif model.endswith(".pt"): + elif str(model).endswith(".pt"): path = model self._model = load_atomistic_model(path, extensions_path) else: @@ -117,7 +117,7 @@ def __init__( if "energy" not in self._model.capabilities().outputs: raise ValueError( "This model does not support energy predictions. " - "The model must have an `energy` output to be used in torch-sim." + "The model must have an `energy` output to be used in TorchSim." ) self._device = device or ("cuda" if torch.cuda.is_available() else "cpu") @@ -139,19 +139,10 @@ def __init__( self._requested_neighbor_lists = self._model.requested_neighbor_lists() self._evaluation_options = ModelEvaluationOptions( length_unit="angstrom", - outputs={ - "energy": ModelOutput( - quantity="energy", - unit="eV", - per_atom=False, - ) - }, + outputs={"energy": ModelOutput(quantity="energy", unit="eV", per_atom=False)}, ) - def forward( # noqa: C901, PLR0915 - self, - state: ts.SimState | StateDict, - ) -> dict[str, torch.Tensor]: + def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # noqa: C901, PLR0915 """Compute energies, forces, and stresses for the given atomic systems. Processes the provided state information and computes energies, forces, and @@ -170,17 +161,19 @@ def forward( # noqa: C901, PLR0915 - 'stress': System stresses with shape [n_systems, 3, 3] if compute_stress=True """ - # Extract required data from input - if isinstance(state, dict): - state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) # Input validation is already done inside the forward method of the # AtomisticModel class, so we don't need to do it again here. - atomic_numbers = state.atomic_numbers - cell = state.row_vector_cell - positions = state.positions - pbc = state.pbc + atomic_nums = sim_state.atomic_numbers + cell = sim_state.row_vector_cell + positions = sim_state.positions + pbc = sim_state.pbc # Check dtype (metatomic models require a specific input dtype) if positions.dtype != self._dtype: @@ -199,14 +192,14 @@ def forward( # noqa: C901, PLR0915 # Process each system separately systems: list[System] = [] strains = [] - for b in range(len(cell)): - system_mask = state.system_idx == b + for sys_idx in range(len(cell)): + system_mask = sim_state.system_idx == sys_idx system_positions = positions[system_mask] - system_cell = cell[b] + system_cell = cell[sys_idx] system_pbc = torch.tensor( [pbc, pbc, pbc], device=self._device, dtype=torch.bool ) - system_atomic_numbers = atomic_numbers[system_mask] + system_atomic_numbers = atomic_nums[system_mask] # Create a System object for this system if self._compute_forces: @@ -245,7 +238,7 @@ def forward( # noqa: C901, PLR0915 check_consistency=self._check_consistency, ) - results = {} + results: dict[str, torch.Tensor] = {} results["energy"] = model_outputs["energy"].block().values.detach().squeeze(-1) # Compute forces and/or stresses if requested diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index 702dc41fa..a3bebbb2b 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -34,16 +34,16 @@ from torch_sim.typing import StateDict -DEFAULT_SIGMA = torch.tensor(1.0) -DEFAULT_EPSILON = torch.tensor(5.0) -DEFAULT_ALPHA = torch.tensor(5.0) +DEFAULT_SIGMA = 1.0 +DEFAULT_EPSILON = 5.0 +DEFAULT_ALPHA = 5.0 def morse_pair( dr: torch.Tensor, - sigma: torch.Tensor = DEFAULT_SIGMA, - epsilon: torch.Tensor = DEFAULT_EPSILON, - alpha: torch.Tensor = DEFAULT_ALPHA, + sigma: float | torch.Tensor = DEFAULT_SIGMA, + epsilon: float | torch.Tensor = DEFAULT_EPSILON, + alpha: float | torch.Tensor = DEFAULT_ALPHA, ) -> torch.Tensor: """Calculate pairwise Morse potential energies between particles. @@ -73,14 +73,13 @@ def morse_pair( # Handle potential numerical instabilities return torch.where(dr > 0, energy, torch.zeros_like(energy)) - # return torch.nan_to_num(energy, nan=0.0, posinf=0.0, neginf=0.0) def morse_pair_force( dr: torch.Tensor, - sigma: torch.Tensor = DEFAULT_SIGMA, - epsilon: torch.Tensor = DEFAULT_EPSILON, - alpha: torch.Tensor = DEFAULT_ALPHA, + sigma: float | torch.Tensor = DEFAULT_SIGMA, + epsilon: float | torch.Tensor = DEFAULT_EPSILON, + alpha: float | torch.Tensor = DEFAULT_ALPHA, ) -> torch.Tensor: """Calculate pairwise Morse forces between particles. @@ -256,13 +255,16 @@ def unbatched_forward( This method can work with both neighbor list and full pairwise calculations. In both cases, interactions are truncated at the cutoff distance. """ - if isinstance(state, dict): - state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) - positions = state.positions - cell = state.row_vector_cell + positions = sim_state.positions + cell = sim_state.row_vector_cell cell = cell.squeeze() - pbc = state.pbc + pbc = sim_state.pbc if self.use_neighbor_list: mapping, shifts = vesin_nl_ts( @@ -320,20 +322,20 @@ def unbatched_forward( force_vectors = (pair_forces / distances)[:, None] * dr_vec if self.compute_forces: - forces = torch.zeros_like(state.positions) + forces = torch.zeros_like(sim_state.positions) forces.index_add_(0, mapping[0], -force_vectors) forces.index_add_(0, mapping[1], force_vectors) results["forces"] = forces - if self.compute_stress and state.cell is not None: + if self.compute_stress and sim_state.cell is not None: stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors) - volume = torch.abs(torch.linalg.det(state.cell)) + volume = torch.abs(torch.linalg.det(sim_state.cell)) results["stress"] = -stress_per_pair.sum(dim=0) / volume if self._per_atom_stresses: atom_stresses = torch.zeros( - (state.positions.shape[0], 3, 3), + (sim_state.positions.shape[0], 3, 3), dtype=self.dtype, device=self.device, ) @@ -376,21 +378,26 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: forces = results["forces"] # Shape: [n_atoms, 3] ``` """ - if isinstance(state, dict): - state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) - if state.system_idx is None and state.cell.shape[0] > 1: + if sim_state.system_idx is None and sim_state.cell.shape[0] > 1: raise ValueError( "system_idx can only be inferred if there is only one system." ) - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)] + outputs = [ + self.unbatched_forward(sim_state[i]) for i in range(sim_state.n_systems) + ] properties = outputs[0] # we always return tensors # per atom properties are returned as (atoms, ...) tensors # global properties are returned as shape (..., n) tensors - results = {} + results: dict[str, torch.Tensor] = {} for key in ("stress", "energy"): if key in properties: results[key] = torch.stack([out[key] for out in outputs]) diff --git a/torch_sim/models/nequip_framework.py b/torch_sim/models/nequip_framework.py new file mode 100644 index 000000000..916281a52 --- /dev/null +++ b/torch_sim/models/nequip_framework.py @@ -0,0 +1,379 @@ +"""Wrapper for NequIP-Allegro models in TorchSim. + +This module provides a TorchSim wrapper of the NequIP-Allegro models for computing +energies, forces, and stresses for atomistic systems. It integrates the NequIP-Allegro +models with TorchSim's simulation framework, handling batched computations for multiple +systems simultaneously. + +The implementation supports various features including: + +* Computing energies, forces, and stresses +* Batched calculations for multiple systems + +References: + - NequIP Package: https://github.com/mir-group/nequip +""" + +import traceback +import warnings +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import ase.data +import torch + +import torch_sim as ts +from torch_sim.models.interface import ModelInterface +from torch_sim.neighbors import vesin_nl_ts +from torch_sim.typing import StateDict + + +try: + from nequip.model.inference_models import load_compiled_model + from nequip.nn import graph_model + from nequip.scripts._compile_utils import ASE_OUTPUTS, PAIR_NEQUIP_INPUTS +except ImportError as exc: + warnings.warn(f"NequIP import failed: {traceback.format_exc()}", stacklevel=2) + + class NequIPFrameworkModel(ModelInterface): + """NequIP model wrapper for torch-sim. + + This class is a placeholder for the NequIPModel class. + It raises an ImportError if NequIP is not installed. + """ + + def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: + """Dummy init for type checking.""" + raise err + + +class ChemicalSpeciesToAtomTypeMapper: + """Maps atomic numbers to model-specific atom type indices. + + This class provides functionality to map atomic numbers to the corresponding atom type + indices used by the model. It handles cases where the model's internal representation + of atom types may differ from conventional chemical species, such as when modeling + different charge states of the same element. + + The mapping is created using a lookup table that converts atomic numbers to + zero-based indices based on the provided list of chemical symbols. The order of + chemical symbols must match the order of atom types expected by the model. + + NOTE: This is adapted from the NequIP package. + + Attributes: + lookup_table (torch.Tensor): Tensor mapping atomic numbers to model type indices. + Contains -1 for unmapped atomic numbers. + + Args: + chemical_symbols (list[str]): List of chemical symbols in the order matching + the model's internal type ordering. Each symbol must be a valid chemical + element symbol. + + Raises: + AssertionError: If an invalid chemical symbol is provided. + """ + + def __init__(self, chemical_symbols: list[str]) -> None: # noqa: D107 + # Create lookup table mapping atomic numbers to model type indices + self.lookup_table = torch.full( + (max(ase.data.atomic_numbers.values()),), -1, dtype=torch.long + ) + for idx, sym in enumerate(chemical_symbols): + assert sym in ase.data.atomic_numbers, f"Invalid chemical symbol {sym}" # noqa: S101 + self.lookup_table[ase.data.atomic_numbers[sym]] = idx + + def __call__(self, atomic_numbers: torch.Tensor) -> torch.Tensor: + """Convert atomic numbers to model-specific atom type indices. + + Args: + atomic_numbers (torch.Tensor): Tensor of atomic numbers to convert. + + Returns: + torch.Tensor: Atom type indices used by the model. + """ + return torch.index_select(self.lookup_table, 0, atomic_numbers) + + +def from_compiled_model( + compile_path: str | Path, device: str | torch.device = "cpu" +) -> tuple[torch.nn.Module, tuple[float, list[str]]]: + """Load a compiled NequIP model from a file. + + Loads a compiled NequIP model from a file and extracts the necessary metadata + for using it in TorchSim. The model must have been compiled using nequip-compile. + + Args: + compile_path (str): Path to the compiled model file. The file should have been + created using nequip-compile. + device (str | torch.device): Device to load the model on. Can be either a string + like 'cpu' or 'cuda', or a torch.device object. Defaults to 'cpu'. + + Returns: + tuple[torch.nn.Module, tuple[float, list[str]]]: A tuple containing: + - The loaded NequIP model as a torch.nn.Module + - A tuple with: + - r_max (float): Cutoff radius used by the model + - type_names (list[str]): List of chemical symbols supported by the model + + Example: + >>> model, (r_max, type_names) = from_compiled_model("model.pth", device="cuda") + >>> print(f"Model cutoff: {r_max:.2f}") + >>> print(f"Supported elements: {type_names}") + + References: + For model compilation please refer to the NequIP documentation: + https://nequip.readthedocs.io/en/latest/guide/getting-started/workflow.html#compilation + """ + model, metadata = load_compiled_model( + str(compile_path), device, PAIR_NEQUIP_INPUTS, ASE_OUTPUTS + ) + + # extract r_max and type_names for transforms + r_max = metadata[graph_model.R_MAX_KEY] + type_names = metadata[graph_model.TYPE_NAMES_KEY] + + return model, (r_max, type_names) + + +class NequIPFrameworkModel(ModelInterface): + """NequIP model for energy, force and stress calculations. + + This class wraps a NequIP model to compute energies, forces and stresses + for atomic systems. + + Args: + model (torch.nn.Module): The NequIP model to use. Must be a torch.nn.Module. + r_max (float): Cutoff radius for neighbor list construction. + type_names (list[str]): List of chemical symbols supported by the model. + device (torch.device | None): Device to run calculations on. + Defaults to CUDA if available, otherwise CPU. + neighbor_list_fn (Callable): Function to compute neighbor lists. + Defaults to vesin_nl_ts. + atomic_numbers (torch.Tensor | None): Atomic numbers with shape [n_atoms]. + If provided at initialization, cannot be provided again during forward pass. + system_idx (torch.Tensor | None): Batch indices with shape [n_atoms] indicating + which system each atom belongs to. If not provided with atomic_numbers, + all atoms are assumed to be in the same system. + """ + + def __init__( + self, + model: str | Path | torch.nn.Module | None = None, + *, + r_max: float, + type_names: list[str], + device: torch.device | None = None, + neighbor_list_fn: Callable = vesin_nl_ts, + atomic_numbers: torch.Tensor | None = None, + system_idx: torch.Tensor | None = None, + ) -> None: + """Initialize the NequIP model. + + Args: + model: The NequIP model to use. Must be a torch.nn.Module. + r_max: Cutoff radius for neighbor list construction. + type_names: List of chemical symbols supported by the model. + device: Device to run calculations on. + Defaults to CUDA if available, otherwise CPU. + neighbor_list_fn: Function to compute neighbor lists. Defaults to vesin_nl_ts. + atomic_numbers: Atomic numbers with shape [n_atoms]. If provided at + initialization, cannot be provided again during forward pass. + system_idx: Batch indices with shape [n_atoms] indicating which system + each atom belongs to. If not provided with atomic_numbers, all atoms + are assumed to be in the same system. If provided, must be a tensor + of long integers. + + Raises: + TypeError: If model is not a torch.nn.Module. + """ + super().__init__() + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self.neighbor_list_fn = neighbor_list_fn + self._memory_scales_with = "n_atoms_x_density" + self._compute_forces = True + self._compute_stress = True + + if isinstance(model, torch.nn.Module): + self.model = model + else: + raise TypeError("Invalid model type. Must be a torch.nn.Module.") + + # Set model properties + # using float64 for the cutoff radius (neighbor list) + self.r_max = torch.tensor(r_max, dtype=torch.float64, device=self.device) + self.type_names = type_names + + # Store flag to track if atomic numbers were provided at init + self.atomic_numbers_in_init = atomic_numbers is not None + self.n_systems = system_idx.max().item() + 1 if system_idx is not None else 1 + + # Set up batch information if atomic numbers are provided + if atomic_numbers is not None: + if system_idx is None: + # If batch is not provided, assume all atoms belong to same system + system_idx = torch.zeros( + len(atomic_numbers), dtype=torch.long, device=self.device + ) + + self.setup_from_system_idx(atomic_numbers, system_idx) + + def setup_from_system_idx( + self, atomic_numbers: torch.Tensor, system_idx: torch.Tensor + ) -> None: + """Set up internal state from atomic numbers and system indices. + + Processes the atomic numbers and system indices to prepare the model for + forward pass calculations. Creates the necessary data structures for + batched processing of multiple systems. + + Args: + atomic_numbers (torch.Tensor): Atomic numbers tensor with shape [n_atoms]. + system_idx (torch.Tensor): System indices tensor with shape [n_atoms] + indicating which system each atom belongs to. + """ + self.atomic_numbers = atomic_numbers + self.system_idx = system_idx + self.atomic_types = ChemicalSpeciesToAtomTypeMapper(self.type_names)( + atomic_numbers + ) + + # Determine number of systems and atoms per system + self.n_systems = system_idx.max().item() + 1 + self.total_atoms = atomic_numbers.shape[0] + + def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # noqa: C901 + """Compute energies, forces, and stresses for the given atomic systems. + + Processes the provided state information and computes energies, forces, and + stresses using the underlying MACE model. Handles batched calculations for + multiple systems and constructs the necessary neighbor lists. + + Args: + state (SimState | StateDict): State object containing positions, cell, + and other system information. Can be either a SimState object or a + dictionary with the relevant fields. + + Returns: + dict[str, torch.Tensor]: Computed properties: + - 'energy': System energies with shape [n_systems] + - 'forces': Atomic forces with shape [n_atoms, 3] if compute_forces=True + - 'stress': System stresses with shape [n_systems, 3, 3] if + compute_stress=True + + Raises: + ValueError: If atomic numbers are not provided either in the constructor + or in the forward pass, or if provided in both places. + ValueError: If system indices are not provided when needed. + """ + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) + + # Handle input validation for atomic numbers + if sim_state.atomic_numbers is None and not self.atomic_numbers_in_init: + raise ValueError( + "Atomic numbers must be provided in either the constructor or forward." + ) + if sim_state.atomic_numbers is not None and self.atomic_numbers_in_init: + raise ValueError( + "Atomic numbers cannot be provided in both the constructor and forward." + ) + + # Use system_idx from init if not provided + if sim_state.system_idx is None: + if not hasattr(self, "system_idx"): + raise ValueError( + "System indices must be provided if not set during initialization" + ) + sim_state.system_idx = self.system_idx + + # Update batch information if new atomic numbers are provided + if ( + sim_state.atomic_numbers is not None + and not self.atomic_numbers_in_init + and not torch.equal( + sim_state.atomic_numbers, + getattr(self, "atomic_numbers", torch.zeros(0, device=self.device)), + ) + ): + self.setup_from_system_idx(sim_state.atomic_numbers, sim_state.system_idx) + + # Process each system's neighbor list separately + edge_indices = [] + shifts_list = [] + unit_shifts_list = [] + offset = 0 + + # TODO (AG): Currently doesn't work for batched neighbor lists + for sys_idx in range(self.n_systems): + system_idx_mask = sim_state.system_idx == sys_idx + # Calculate neighbor list for this system + edge_idx, shifts_idx = self.neighbor_list_fn( + positions=sim_state.positions[system_idx_mask], + cell=sim_state.row_vector_cell[sys_idx], + pbc=sim_state.pbc, + cutoff=self.r_max, + ) + + # Adjust indices for the batch + edge_idx = edge_idx + offset + shifts = torch.mm(shifts_idx, sim_state.row_vector_cell[sys_idx]) + + edge_indices.append(edge_idx) + unit_shifts_list.append(shifts_idx) + shifts_list.append(shifts) + + offset += len(sim_state.positions[system_idx_mask]) + + # Combine all neighbor lists + edge_index = torch.cat(edge_indices, dim=1) + unit_shifts = torch.cat(unit_shifts_list, dim=0) + shifts = torch.cat(shifts_list, dim=0) + atomic_types = ChemicalSpeciesToAtomTypeMapper(self.type_names)( + sim_state.atomic_numbers + ) + + # Get model output + data: dict[str, torch.Tensor] = { + "pos": sim_state.positions, + "cell": sim_state.row_vector_cell, + "batch": sim_state.system_idx, + "num_atoms": sim_state.system_idx.bincount(), + "pbc": torch.tensor( + [sim_state.pbc, sim_state.pbc, sim_state.pbc], + dtype=torch.bool, + device=self.device, + ), + "atomic_numbers": sim_state.atomic_numbers, + "atom_types": atomic_types, + "edge_index": edge_index, + "edge_cell_shift": unit_shifts, + } + out = self.model(data) + results: dict[str, torch.Tensor] = {} + # Process energy + energy = out["total_energy"] + if energy is None: + results["energy"] = torch.zeros(self.n_systems, device=self.device) + else: + results["energy"] = energy.detach() + + # Process forces + if self.compute_forces: + forces = out["forces"] + if forces is not None: + results["forces"] = forces.detach() + + # Process stress + if self.compute_stress: + stress = out["stress"] + if stress is not None: + results["stress"] = stress.detach() + + return results diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index acc6f4a19..93ac4c333 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -2,11 +2,11 @@ This module provides a TorchSim wrapper of the ORB models for computing energies, forces, and stresses of atomistic systems. It serves as a wrapper around -the ORB models library, integrating it with the torch_sim framework to enable seamless +the ORB models library, integrating it with the torch-sim framework to enable seamless simulation of atomistic systems with machine learning potentials. The OrbModel class adapts ORB models to the ModelInterface protocol, -allowing them to be used within the broader torch_sim simulation framework. +allowing them to be used within the broader torch-sim simulation framework. Notes: This implementation requires orb_models to be installed and accessible. @@ -40,7 +40,7 @@ warnings.warn(f"Orb import failed: {traceback.format_exc()}", stacklevel=2) class OrbModel(ModelInterface): - """ORB model wrapper for torch_sim. + """ORB model wrapper for torch-sim. This class is a placeholder for the OrbModel class. It raises an ImportError if orb_models is not installed. @@ -69,7 +69,7 @@ def state_to_atom_graphs( # noqa: PLR0915 system_config: SystemConfig | None = None, max_num_neighbors: int | None = None, system_id: int | None = None, # noqa: ARG001 - half_supercell: bool = False, + half_supercell: bool | torch.Tensor = False, device: torch.device | None = None, output_dtype: torch.dtype | None = None, graph_construction_dtype: torch.dtype | None = None, @@ -146,22 +146,22 @@ def state_to_atom_graphs( # noqa: PLR0915 n_systems = state.system_idx.max().item() + 1 # Prepare lists to collect data from each system - all_edges = [] - all_vectors = [] - all_unit_shifts = [] - num_edges = [] - node_feats_list = [] - edge_feats_list = [] - graph_feats_list = [] + all_edges: list[torch.Tensor] = [] + all_vectors: list[torch.Tensor] = [] + all_unit_shifts: list[torch.Tensor] = [] + num_edges: list[torch.Tensor] = [] + node_feats_list: list[dict[str, torch.Tensor]] = [] + edge_feats_list: list[dict[str, torch.Tensor]] = [] + graph_feats_list: list[dict[str, torch.Tensor]] = [] # Process each system in a single loop offset = 0 - for i in range(n_systems): - system_mask = state.system_idx == i + for sys_idx in range(n_systems): + system_mask = state.system_idx == sys_idx positions_per_system = positions[system_mask] atomic_numbers_per_system = atomic_numbers[system_mask] atomic_numbers_embedding_per_system = atomic_numbers_embedding[system_mask] - cell_per_system = row_vector_cell[i] + cell_per_system = row_vector_cell[sys_idx] pbc_per_system = pbc # Compute edges directly for this system @@ -172,7 +172,7 @@ def state_to_atom_graphs( # noqa: PLR0915 radius=system_config.radius, max_number_neighbors=max_num_neighbors, edge_method=edge_method, - half_supercell=half_supercell, + half_supercell=bool(half_supercell), device=device, ) @@ -323,18 +323,22 @@ def __init__( self._compute_stress = compute_stress self._compute_forces = compute_forces + # Load model if path is provided + if isinstance(model, str | Path): + loaded_model = torch.load(model, map_location=self._device) + elif isinstance(model, torch.nn.Module): + loaded_model = model + else: + raise TypeError("Model must be a path or torch.nn.Module") + # Set up system configuration - self.system_config = system_config or model.system_config + self.system_config = system_config or loaded_model.system_config self._max_num_neighbors = max_num_neighbors self._edge_method = edge_method self._half_supercell = half_supercell self.conservative = conservative - # Load model if path is provided - if isinstance(model, str | Path): - model = torch.load(model, map_location=self._device) - - self.model = model.to(self._device) + self.model = loaded_model.to(self._device) self.model = self.model.eval() if self.dtype is not None: @@ -380,21 +384,24 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: The state is automatically transferred to the model's device if needed. All output tensors are detached from the computation graph. """ - if isinstance(state, dict): - state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) - if state.device != self._device: - state = state.to(self._device) + if sim_state.device != self._device: + sim_state = sim_state.to(self._device) half_supercell = ( - torch.min(torch.det(state.cell)) > 1000 + torch.min(sim_state.volume) > 1000 if self._half_supercell is None else self._half_supercell ) # Convert state to atom graphs batch = state_to_atom_graphs( - state, + sim_state, system_config=self.system_config, max_num_neighbors=self._max_num_neighbors, edge_method=self._edge_method, @@ -405,7 +412,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # Run forward pass predictions = self.model.predict(batch) - results = {} + results: dict[str, torch.Tensor] = {} model_has_direct_heads = ( "forces" in self.model.heads and "stress" in self.model.heads ) diff --git a/torch_sim/models/particle_life.py b/torch_sim/models/particle_life.py index 3a13a333a..518f3544e 100644 --- a/torch_sim/models/particle_life.py +++ b/torch_sim/models/particle_life.py @@ -139,7 +139,6 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: Returns: A dictionary containing the energy, forces, and stresses """ - # Extract required data from input if isinstance(state, dict): state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) @@ -194,7 +193,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: # Calculate forces and apply cutoff pair_forces = asymmetric_particle_pair_force_jit( - distances, sigma=self.sigma, epsilon=self.epsilon + dr=distances, A=self.epsilon, sigma=self.sigma, beta=self.beta ) pair_forces = torch.where(mask, pair_forces, torch.zeros_like(pair_forces)) @@ -236,21 +235,26 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: Raises: ValueError: If batch cannot be inferred for multi-cell systems. """ - if isinstance(state, dict): - state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) - if state.system_idx is None and state.cell.shape[0] > 1: + if sim_state.system_idx is None and sim_state.cell.shape[0] > 1: raise ValueError( "system_idx can only be inferred if there is only one system." ) - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)] + outputs = [ + self.unbatched_forward(sim_state[idx]) for idx in range(sim_state.n_systems) + ] properties = outputs[0] # we always return tensors # per atom properties are returned as (atoms, ...) tensors # global properties are returned as shape (..., n) tensors - results = {} + results: dict[str, torch.Tensor] = {} for key in ("stress", "energy"): if key in properties: results[key] = torch.stack([out[key] for out in outputs]) diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index cda8e1833..6a5d81b28 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -33,7 +33,7 @@ warnings.warn(f"SevenNet import failed: {traceback.format_exc()}", stacklevel=2) class SevenNetModel(ModelInterface): - """SevenNet model wrapper for torch_sim. + """SevenNet model wrapper for torch-sim. This class is a placeholder for the SevenNetModel class. It raises an ImportError if sevenn is not installed. @@ -143,11 +143,7 @@ def __init__( if self.dtype is not None: self.model = self.model.to(dtype=self.dtype) - self.implemented_properties = [ - "energy", - "forces", - "stress", - ] + self.implemented_properties = ["energy", "forces", "stress"] def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: """Perform forward pass to compute energies, forces, and other properties. @@ -171,24 +167,27 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: The state is automatically transferred to the model's device if needed. All output tensors are detached from the computation graph. """ - if isinstance(state, dict): - state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) - if state.device != self._device: - state = state.to(self._device) + if sim_state.device != self._device: + sim_state = sim_state.to(self._device) # TODO: is this clone necessary? - state = state.clone() + sim_state = sim_state.clone() data_list = [] - for b in range(state.system_idx.max().item() + 1): - system_mask = state.system_idx == b + for sys_idx in range(sim_state.system_idx.max().item() + 1): + system_mask = sim_state.system_idx == sys_idx - pos = state.positions[system_mask] + pos = sim_state.positions[system_mask] # SevenNet uses row vector cell convention for neighbor list - row_vector_cell = state.row_vector_cell[b] - pbc = state.pbc - atomic_numbers = state.atomic_numbers[system_mask] + row_vector_cell = sim_state.row_vector_cell[sys_idx] + pbc = sim_state.pbc + atomic_nums = sim_state.atomic_numbers[system_mask] edge_idx, shifts_idx = self.neighbor_list_fn( positions=pos, @@ -203,17 +202,15 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # vol = vol if vol > 0.0 else torch.tensor(np.finfo(float).eps) data = { - key.NODE_FEATURE: atomic_numbers, - key.ATOMIC_NUMBERS: atomic_numbers.to( - dtype=torch.int64, device=self.device - ), + key.NODE_FEATURE: atomic_nums, + key.ATOMIC_NUMBERS: atomic_nums.to(dtype=torch.int64, device=self.device), key.POS: pos, key.EDGE_IDX: edge_idx, key.EDGE_VEC: edge_vec, key.CELL: row_vector_cell, key.CELL_SHIFT: shifts_idx, key.CELL_VOLUME: vol, - key.NUM_ATOMS: torch.tensor(len(atomic_numbers), device=self.device), + key.NUM_ATOMS: torch.tensor(len(atomic_nums), device=self.device), key.DATA_MODALITY: self.modal, } data[key.INFO] = {} @@ -239,13 +236,13 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: output = self.model(batched_data) - results = {} + results: dict[str, torch.Tensor] = {} energy = output[key.PRED_TOTAL_ENERGY] if energy is not None: results["energy"] = energy.detach() else: results["energy"] = torch.zeros( - state.system_idx.max().item() + 1, device=self.device + sim_state.system_idx.max().item() + 1, device=self.device ) forces = output[key.PRED_FORCE] diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index e97c90466..a9a1b97ac 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -39,7 +39,6 @@ epsilon_matrix=strength_matrix, ) results = multi_model(sim_state) - """ import torch @@ -58,9 +57,9 @@ def soft_sphere_pair( dr: torch.Tensor, - sigma: torch.Tensor = DEFAULT_SIGMA, - epsilon: torch.Tensor = DEFAULT_EPSILON, - alpha: torch.Tensor = DEFAULT_ALPHA, + sigma: float | torch.Tensor = DEFAULT_SIGMA, + epsilon: float | torch.Tensor = DEFAULT_EPSILON, + alpha: float | torch.Tensor = DEFAULT_ALPHA, ) -> torch.Tensor: """Calculate pairwise repulsive energies between soft spheres with finite-range interactions. @@ -250,10 +249,7 @@ def __init__( self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self.device) self.alpha = torch.tensor(alpha, dtype=dtype, device=self.device) - def unbatched_forward( - self, - state: ts.SimState, - ) -> dict[str, torch.Tensor]: + def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: """Compute energies and forces for a single unbatched system. Internal implementation that processes a single, non-batched simulation state. @@ -411,20 +407,25 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: forces = results["forces"] # Shape: [n_atoms, 3] ``` """ - if isinstance(state, dict): - state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) # Handle System indices if not provided - if state.system_idx is None and state.cell.shape[0] > 1: + if sim_state.system_idx is None and sim_state.cell.shape[0] > 1: raise ValueError( "system_idx can only be inferred if there is only one system" ) - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)] + outputs = [ + self.unbatched_forward(sim_state[i]) for i in range(sim_state.n_systems) + ] properties = outputs[0] # Combine results - results = {} + results: dict[str, torch.Tensor] = {} for key in ("stress", "energy"): if key in properties: results[key] = torch.stack([out[key] for out in outputs]) @@ -506,7 +507,7 @@ def __init__( epsilon_matrix: torch.Tensor | None = None, alpha_matrix: torch.Tensor | None = None, device: torch.device | None = None, - dtype: torch.dtype = torch.float32, + dtype: torch.dtype = torch.float64, *, # Force keyword-only arguments pbc: bool = True, compute_forces: bool = True, @@ -713,7 +714,7 @@ def unbatched_forward( # noqa: PLR0915 cell=cell, pbc=self.pbc, cutoff=self.cutoff, - sort_id=False, + sorti=False, ) # Get displacements between neighbor pairs dr_vec, distances = transforms.get_pair_displacements( @@ -862,11 +863,13 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: "system_idx can only be inferred if there is only one system" ) - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)] + outputs = [ + self.unbatched_forward(state[sys_idx]) for sys_idx in range(state.n_systems) + ] properties = outputs[0] # Combine results - results = {} + results: dict[str, torch.Tensor] = {} for key in ("stress", "energy", "forces", "energies", "stresses"): if key in properties: results[key] = torch.stack([out[key] for out in outputs]) diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index be2d99a8d..7db5213d7 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -5,12 +5,17 @@ implementations of the Metropolis criterion, swap generation, and utility functions for handling permutations in batched systems. -The `swap_monte_carlo` function can be used with `integrate` but if -a trajectory is being reported, the `TorchSimTrajectory.write_state` method -must be called with `variable_masses=True`. +The `swap_mc_init` and `swap_mc_step` functions can be used +with `integrate` but if a trajectory is being reported, the +`TorchSimTrajectory.write_state` method must be called with `variable_masses=True`. + +Examples: + >>> import torch_sim as ts + >>> mc_state = ts.swap_mc_init(model, initial_state, seed=42) + >>> for _ in range(1000): + ... mc_state = ts.swap_mc_step(model, mc_state, kT=0.1 * units.energy) """ -from collections.abc import Callable from dataclasses import dataclass import torch @@ -86,20 +91,20 @@ def generate_swaps( # Process each system - we need this loop because of ragged systems system_starts = system_lengths.cumsum(dim=0) - system_lengths[0] - for b in range(n_systems): + for sys_idx in range(n_systems): # Get global index of selected atom - first_idx = first_index[b, 0].item() + system_starts[b].item() + first_idx = first_index[sys_idx, 0].item() + system_starts[sys_idx].item() first_type = atomic_numbers[first_idx] # Get indices of atoms in this system - system_start = system_starts[b].item() - system_end = system_start + system_lengths[b].item() + system_start = system_starts[sys_idx].item() + system_end = system_start + system_lengths[sys_idx].item() # Create mask for same-type atoms same_type = atomic_numbers[system_start:system_end] == first_type # Zero out weights for same-type atoms (accounting for padding) - weights[b, : len(same_type)][same_type] = 0.0 + weights[sys_idx, : len(same_type)][same_type] = 0.0 second_index = torch.multinomial(weights, 1, replacement=False, generator=generator) zeroed_swaps = torch.concatenate([first_index, second_index], dim=1) @@ -123,26 +128,14 @@ def swaps_to_permutation(swaps: torch.Tensor, n_atoms: int) -> torch.Tensor: contains the index of the atom that should be moved to position i """ permutation = torch.arange(n_atoms, device=swaps.device) - permutation[swaps[:, 0]] = swaps[:, 1] - permutation[swaps[:, 1]] = swaps[:, 0] - return permutation - - -def validate_permutation(permutation: torch.Tensor, system_idx: torch.Tensor) -> None: - """Validate that permutations only swap atoms within the same system. - Confirms that no swaps are attempted between atoms in different systems, - which would lead to physically invalid configurations. - - Args: - permutation (torch.Tensor): Permutation tensor of shape [n_atoms] - system_idx (torch.Tensor): system_idx for each atom of shape [n_atoms] + for swap in swaps: + idx1, idx2 = swap + temp = permutation[idx1].clone() + permutation[idx1] = permutation[idx2] + permutation[idx2] = temp - Raises: - ValueError: If any swaps are between atoms in different systems - """ - if not torch.all(system_idx == system_idx[permutation]): - raise ValueError("Swaps must be between atoms in the same system") + return permutation def metropolis_criterion( @@ -174,7 +167,7 @@ def metropolis_criterion( delta_e = energy_new - energy_old # Calculate acceptance probability: min(1, exp(-Ξ”E/kT)) - p_acceptance = torch.exp(-delta_e / kT) + p_acceptance = torch.clamp(torch.exp(-delta_e / kT), max=1.0) # Generate random numbers between 0 and 1 using the generator random_values = torch.rand( @@ -185,108 +178,97 @@ def metropolis_criterion( return random_values < p_acceptance -def swap_monte_carlo( - *, +def swap_mc_init( model: ModelInterface, - kT: float, - seed: int | None = None, -) -> tuple[ - Callable[[SimState], SwapMCState], - Callable[[SwapMCState, float, torch.Generator | None], SwapMCState], -]: - """Initialize a swap Monte Carlo simulation for atomic structure optimization. + state: SimState, +) -> SwapMCState: + """Initialize a swap Monte Carlo state from input data. - Creates and returns functions for initializing the Monte Carlo state and performing - Monte Carlo steps. The simulation uses the Metropolis criterion to accept or reject - proposed swaps based on energy differences. + Creates an initial state for swap Monte Carlo simulations by computing initial + energy and setting up the permutation tracking. The simulation uses the Metropolis + criterion to accept or reject proposed swaps based on energy differences. Make sure that if the trajectory is being reported, the `TorchSimTrajectory.write_state` method is called with `variable_masses=True`. Args: - model (torch.nn.Module): Energy model that takes a SimState and returns a dict - containing 'energy' as a key - kT (float): Temperature of the system in energy units - seed (int | None, optional): Seed for the random number generator. - Defaults to None. + model: Energy model that takes a SimState and returns a dict containing + 'energy' as a key + state: The simulation state to initialize from Returns: - tuple: A tuple containing: - - init_function (Callable): Function to initialize a SwapMCState from a - SimState - - step_function (Callable): Function to perform a single Monte Carlo step + SwapMCState: Initialized state for swap Monte Carlo simulation containing + positions, energy, and permutation tracking Examples: - >>> init_fn, step_fn = swap_monte_carlo(model=energy_model, kT=0.1, seed=42) - >>> mc_state = init_fn(initial_state) + >>> mc_state = swap_monte_carlo_init(model=energy_model, state=initial_state) >>> for _ in range(100): - >>> mc_state = step_fn(mc_state) + >>> mc_state = swap_monte_carlo_step(model, mc_state, kT=0.1) + """ + model_output = model(state) + + return SwapMCState( + positions=state.positions, + masses=state.masses, + cell=state.cell, + pbc=state.pbc, + atomic_numbers=state.atomic_numbers, + system_idx=state.system_idx, + energy=model_output["energy"], + last_permutation=torch.arange(state.n_atoms, device=state.device), + ) + + +def swap_mc_step( + model: ModelInterface, + state: SwapMCState, + *, + kT: float, + seed: int | None = None, +) -> SwapMCState: + """Perform a single swap Monte Carlo step. + + Proposes atom swaps, evaluates the energy change, and uses the Metropolis + criterion to determine whether to accept the move. Rejected moves are reversed. + + Args: + model: Energy model that takes a SimState and returns a dict containing + 'energy' as a key + state: The current Monte Carlo state + kT: Temperature parameter in energy units + seed: Seed for the random number generator. Defaults to None. + + Returns: + SwapMCState: Updated Monte Carlo state after applying the step + + Notes: + The function handles batched systems and ensures that swaps only occur + within the same system. """ + generator = None if seed is not None: generator = torch.Generator(device=model.device) generator.manual_seed(seed) - else: - generator = None - - def init_swap_mc_state(state: SimState) -> SwapMCState: - model_output = model(state) - - return SwapMCState( - positions=state.positions, - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - atomic_numbers=state.atomic_numbers, - system_idx=state.system_idx, - energy=model_output["energy"], - last_permutation=torch.arange(state.n_atoms, device=state.device), - ) - - def swap_monte_carlo_step( - state: SwapMCState, - kT: float = kT, - generator: torch.Generator | None = generator, - ) -> SwapMCState: - """Perform a single swap Monte Carlo step. - - Proposes atom swaps, evaluates the energy change, and uses the Metropolis - criterion to determine whether to accept the move. Rejected moves are reversed. - - Args: - state (SwapMCState): The current Monte Carlo state - kT (float, optional): Temperature parameter in energy units. Defaults to the - value specified in the outer function. - generator (torch.Generator | None, optional): Random number generator. - Defaults to None. - - Returns: - SwapMCState: Updated Monte Carlo state after applying the step - - Notes: - The function handles batched systems and ensures that swaps only occur - within the same system. - """ - swaps = generate_swaps(state, generator=generator) - - permutation = swaps_to_permutation(swaps, state.n_atoms) - validate_permutation(permutation, state.system_idx) - - energies_old = state.energy.clone() - state.positions = state.positions[permutation].clone() - - model_output = model(state) - energies_new = model_output["energy"] - - accepted = metropolis_criterion( - energies_new, energies_old, kT, generator=generator - ) - rejected_swaps = swaps[~accepted] - reverse_rejected_swaps = swaps_to_permutation(rejected_swaps, state.n_atoms) - state.positions = state.positions[reverse_rejected_swaps] - - state.energy = torch.where(accepted, energies_new, energies_old) - state.last_permutation = permutation[reverse_rejected_swaps].clone() - - return state - - return init_swap_mc_state, swap_monte_carlo_step + + swaps = generate_swaps(state, generator=generator) + + permutation = swaps_to_permutation(swaps, state.n_atoms) + + if not torch.all(state.system_idx == state.system_idx[permutation]): + raise ValueError("Swaps must be between atoms in the same system") + + energies_old = state.energy.clone() + state.positions = state.positions[permutation].clone() + + model_output = model(state) + energies_new = model_output["energy"] + + accepted = metropolis_criterion(energies_new, energies_old, kT, generator=generator) + rejected_swaps = swaps[~accepted] + reverse_rejected_swaps = swaps_to_permutation(rejected_swaps, state.n_atoms) + state.positions = state.positions[reverse_rejected_swaps] + + state.energy = torch.where(accepted, energies_new, energies_old) + state.last_permutation = permutation[reverse_rejected_swaps].clone() + + return state diff --git a/torch_sim/neighbors.py b/torch_sim/neighbors.py index 05cba999e..491d72cbb 100644 --- a/torch_sim/neighbors.py +++ b/torch_sim/neighbors.py @@ -2,9 +2,9 @@ import torch from vesin import NeighborList as VesinNeighborList -from vesin.torch import NeighborList as VesinNeighborList_ts +from vesin.torch import NeighborList as VesinNeighborListTorch -import torch_sim.math as tsm +import torch_sim.math as fm from torch_sim import transforms @@ -170,11 +170,11 @@ def primitive_neighbor_list( # noqa: C901, PLR0915 for c in range(3): if pbc[c]: # (Note: torch.divmod does not exist in older numpy versions) - cell_shift_ic[:, c], bin_index_ic[:, c] = tsm.torch_divmod( + cell_shift_ic[:, c], bin_index_ic[:, c] = fm.torch_divmod( bin_index_ic[:, c], n_bins_c[c] ) else: - bin_index_ic[:, c] = torch.clip(bin_index_ic[:, c], 0, n_bins_c[c] - 1) # type: ignore[call-overload] + bin_index_ic[:, c] = torch.clip(bin_index_ic[:, c], 0, n_bins_c[c] - 1) # Convert Cartesian bin index to unique scalar bin index. bin_index_i = bin_index_ic[:, 0] + n_bins_c[0] * ( @@ -193,8 +193,8 @@ def primitive_neighbor_list( # noqa: C901, PLR0915 # homogeneous, i.e. has the same size *max_n_atoms_per_bin* for all bins. # The list is padded with -1 values. atoms_in_bin_ba = -torch.ones( - n_bins, max_n_atoms_per_bin, dtype=torch.long, device=device - ) # type: ignore[call-overload] + n_bins.item(), max_n_atoms_per_bin.item(), dtype=torch.long, device=device + ) for bin_cnt in range(int(max_n_atoms_per_bin.item())): # Create a mask array that identifies the first atom of each bin. mask = torch.cat( @@ -227,8 +227,8 @@ def primitive_neighbor_list( # noqa: C901, PLR0915 # (max_n_atoms_per_bin, max_n_atoms_per_bin), dtype=int # ).reshape(2, -1) atom_pairs_pn = torch.cartesian_prod( - torch.arange(max_n_atoms_per_bin, device=device), # type: ignore[call-overload] - torch.arange(max_n_atoms_per_bin, device=device), # type: ignore[call-overload] + torch.arange(max_n_atoms_per_bin, device=device), + torch.arange(max_n_atoms_per_bin, device=device), ) atom_pairs_pn = atom_pairs_pn.T.reshape(2, -1) @@ -244,9 +244,9 @@ def primitive_neighbor_list( # noqa: C901, PLR0915 # that each bin contains exactly max_n_atoms_per_bin atoms. We then throw # out pairs involving pad atoms with atom index -1 below. binz_xyz, biny_xyz, binx_xyz = torch.meshgrid( - torch.arange(n_bins_c[2], device=device), # type: ignore[call-overload] - torch.arange(n_bins_c[1], device=device), # type: ignore[call-overload] - torch.arange(n_bins_c[0], device=device), # type: ignore[call-overload] + torch.arange(n_bins_c[2], device=device), + torch.arange(n_bins_c[1], device=device), + torch.arange(n_bins_c[0], device=device), indexing="ij", ) # The memory layout of binx_xyz, biny_xyz, binz_xyz is such that computing @@ -262,9 +262,9 @@ def primitive_neighbor_list( # noqa: C901, PLR0915 for dy in range(-int(neigh_search_y.item()), int(neigh_search_y.item()) + 1): for dx in range(-int(neigh_search_x.item()), int(neigh_search_x.item()) + 1): # Bin index of neighboring bin and shift vector. - shiftx_xyz, neighbinx_xyz = tsm.torch_divmod(binx_xyz + dx, n_bins_c[0]) - shifty_xyz, neighbiny_xyz = tsm.torch_divmod(biny_xyz + dy, n_bins_c[1]) - shiftz_xyz, neighbinz_xyz = tsm.torch_divmod(binz_xyz + dz, n_bins_c[2]) + shiftx_xyz, neighbinx_xyz = fm.torch_divmod(binx_xyz + dx, n_bins_c[0]) + shifty_xyz, neighbiny_xyz = fm.torch_divmod(biny_xyz + dy, n_bins_c[1]) + shiftz_xyz, neighbinz_xyz = fm.torch_divmod(binz_xyz + dz, n_bins_c[2]) neighbin_b = ( neighbinx_xyz + n_bins_c[0] * (neighbiny_xyz + n_bins_c[1] * neighbinz_xyz) @@ -363,10 +363,10 @@ def primitive_neighbor_list( # noqa: C901, PLR0915 cell_shift_vector_n = cell_shift_vector_n[m] # Sort neighbor list. - bin_cnt_sort_idx = torch.argsort(first_at_neigh_tuple_n) - first_at_neigh_tuple_n = first_at_neigh_tuple_n[bin_cnt_sort_idx] - second_at_neigh_tuple_n = second_at_neigh_tuple_n[bin_cnt_sort_idx] - cell_shift_vector_n = cell_shift_vector_n[bin_cnt_sort_idx] + bin_cnt = torch.argsort(first_at_neigh_tuple_n) + first_at_neigh_tuple_n = first_at_neigh_tuple_n[bin_cnt] + second_at_neigh_tuple_n = second_at_neigh_tuple_n[bin_cnt] + cell_shift_vector_n = cell_shift_vector_n[bin_cnt] # Compute distance vectors. # TODO: Use .T? @@ -508,7 +508,7 @@ def vesin_nl_ts( """Compute neighbor lists using TorchScript-compatible Vesin implementation. This function provides a TorchScript-compatible interface to the Vesin neighbor - list algorithm using VesinNeighborList_ts. It handles both periodic and non-periodic + list algorithm using VesinNeighborListTorch. It handles both periodic and non-periodic systems and returns neighbor pairs along with their periodic shifts. Args: @@ -530,7 +530,7 @@ def vesin_nl_ts( neighbor pair. Notes: - - Uses VesinNeighborList_ts for TorchScript compatibility + - Uses VesinNeighborListTorch for TorchScript compatibility - Requires CPU tensors in float64 precision internally - Returns tensors on the same device as input with original precision - For non-periodic systems (pbc=False), shifts will be zero vectors @@ -542,7 +542,7 @@ def vesin_nl_ts( device = positions.device dtype = positions.dtype - neighbor_list_fn = VesinNeighborList_ts(cutoff.item(), full_list=True) + neighbor_list_fn = VesinNeighborListTorch(cutoff.item(), full_list=True) # Convert tensors to CPU and float64 properly positions_cpu = positions.cpu().to(dtype=torch.float64) @@ -569,12 +569,11 @@ def vesin_nl_ts( def vesin_nl( - *, positions: torch.Tensor, cell: torch.Tensor, - pbc: bool, - cutoff: torch.Tensor, - sort_id: bool = False, + pbc: bool, # noqa: FBT001 + cutoff: float | torch.Tensor, + sort_id: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor]: """Compute neighbor lists using the standard Vesin implementation. @@ -614,7 +613,7 @@ def vesin_nl( device = positions.device dtype = positions.dtype - neighbor_list_fn = VesinNeighborList(cutoff, full_list=True, sorted=sort_id) + neighbor_list_fn = VesinNeighborList((float(cutoff)), full_list=True, sorted=sort_id) # Convert tensors to CPU and float64 without gradients positions_cpu = positions.detach().cpu().to(dtype=torch.float64) @@ -640,7 +639,7 @@ def vesin_nl( def strict_nl( cutoff: float, positions: torch.Tensor, - cell: torch.Tensor | None, + cell: torch.Tensor, mapping: torch.Tensor, system_mapping: torch.Tensor, shifts_idx: torch.Tensor, @@ -658,8 +657,8 @@ def strict_nl( is used to filter the neighbor pairs based on their distances. positions (torch.Tensor): A tensor of shape (n_atoms, 3) representing the positions of the atoms. - cell (torch.Tensor | None): Unit cell vectors according to the row vector - convention. i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. + cell (torch.Tensor): Unit cell vectors according to the row vector convention, + i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. mapping (torch.Tensor): A tensor of shape (2, n_pairs) that specifies pairs of indices in `positions` for which to compute distances. @@ -689,12 +688,10 @@ def strict_nl( References: - https://github.com/felixmusil/torch_nl """ - if cell is None: + cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, system_mapping) + if cell_shifts is None: d2 = (positions[mapping[0]] - positions[mapping[1]]).square().sum(dim=1) else: - cell_shifts = transforms.compute_cell_shifts_strict( - cell, shifts_idx, system_mapping - ) d2 = ( (positions[mapping[0]] - positions[mapping[1]] - cell_shifts) .square() @@ -710,10 +707,10 @@ def strict_nl( @torch.jit.script def torch_nl_n2( - cutoff: float, positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, + cutoff: torch.Tensor, system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -754,22 +751,22 @@ def torch_nl_n2( """ n_atoms = torch.bincount(system_idx) mapping, system_mapping, shifts_idx = transforms.build_naive_neighborhood( - positions, cell, pbc, cutoff, n_atoms, self_interaction + positions, cell, pbc, cutoff.item(), n_atoms, self_interaction ) mapping, mapping_system, shifts_idx = strict_nl( - cutoff, positions, cell, mapping, system_mapping, shifts_idx + cutoff.item(), positions, cell, mapping, system_mapping, shifts_idx ) return mapping, mapping_system, shifts_idx @torch.jit.script def torch_nl_linked_cell( - cutoff: float, positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, + cutoff: torch.Tensor, system_idx: torch.Tensor, - self_interaction: bool = False, # noqa: FBT001, FBT002 + self_interaction: bool = False, # noqa: FBT001, FBT002 (*, not compatible with torch.jit.script) ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute the neighbor list for a set of atomic structures using the linked cell algorithm before applying a strict `cutoff`. The atoms positions `pos` @@ -810,10 +807,10 @@ def torch_nl_linked_cell( """ n_atoms = torch.bincount(system_idx) mapping, system_mapping, shifts_idx = transforms.build_linked_cell_neighborhood( - positions, cell, pbc, cutoff, n_atoms, self_interaction + positions, cell, pbc, cutoff.item(), n_atoms, self_interaction ) mapping, mapping_system, shifts_idx = strict_nl( - cutoff, positions, cell, mapping, system_mapping, shifts_idx + cutoff.item(), positions, cell, mapping, system_mapping, shifts_idx ) return mapping, mapping_system, shifts_idx diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py deleted file mode 100644 index d715c5c56..000000000 --- a/torch_sim/optimizers.py +++ /dev/null @@ -1,1743 +0,0 @@ -"""Optimizers for geometry relaxations. - -This module provides optimization algorithms for atomic structures in a batched format, -enabling efficient relaxation of multiple atomic structures simultaneously. It includes -several gradient-based methods with support for both atomic position and unit cell -optimization. - -The module offers: - -* Standard gradient descent for atomic positions -* Gradient descent with unit cell optimization -* FIRE (Fast Inertial Relaxation Engine) optimization with unit cell parameters -* FIRE optimization with Frechet cell parameterization for improved cell relaxation - -ASE-style FIRE: https://gitlab.com/ase/ase/-/blob/master/ase/optimize/fire.py?ref_type=heads -Velocity Verlet-style FIRE: https://doi.org/10.1103/PhysRevLett.97.170201 - -""" - -import functools -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any, Literal, get_args - -import torch - -import torch_sim.math as tsm -from torch_sim.models.interface import ModelInterface -from torch_sim.state import DeformGradMixin, SimState -from torch_sim.typing import StateDict - - -MdFlavor = Literal["vv_fire", "ase_fire"] -vv_fire_key, ase_fire_key = get_args(MdFlavor) - -_md_atom_attributes = SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001 -_fire_system_attributes = ( - SimState._system_attributes # noqa: SLF001 - | DeformGradMixin._system_attributes # noqa: SLF001 - | { - "energy", - "stress", - "cell_positions", - "cell_velocities", - "cell_forces", - "cell_masses", - "cell_factor", - "pressure", - "dt", - "alpha", - "n_pos", - } -) -_fire_global_attributes = SimState._global_attributes | { # noqa: SLF001 - "hydrostatic_strain", - "constant_volume", -} - - -@dataclass -class GDState(SimState): - """State class for batched gradient descent optimization. - - This class extends SimState to store and track the evolution of system state - during gradient descent optimization. It maintains the energies and forces - needed to perform gradient-based structure relaxation in a batched manner. - - Attributes: - positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] - masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_systems, 3, 3] - pbc (bool): Whether to use periodic boundary conditions - atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - system_idx (torch.Tensor): System indices with shape [n_atoms] - forces (torch.Tensor): Forces acting on atoms with shape [n_atoms, 3] - energy (torch.Tensor): Potential energy with shape [n_systems] - """ - - forces: torch.Tensor - energy: torch.Tensor - - _atom_attributes = SimState._atom_attributes | {"forces"} # noqa: SLF001 - _system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001 - - -def gradient_descent( - model: ModelInterface, *, lr: torch.Tensor | float = 0.01 -) -> tuple[Callable[[StateDict | SimState], GDState], Callable[[GDState], GDState]]: - """Initialize a batched gradient descent optimization. - - Creates an optimizer that performs standard gradient descent on atomic positions - for multiple systems in parallel. The optimizer updates atomic positions based on - forces computed by the provided model. The cell is not optimized with this optimizer. - - Args: - model (torch.nn.Module): Model that computes energies and forces - lr (torch.Tensor | float): Learning rate(s) for optimization. Can be a single - float applied to all systems or a tensor with shape [n_systems] for - system-specific rates - - Returns: - tuple: A pair of functions: - - Initialization function that creates the initial BatchedGDState - - Update function that performs one gradient descent step - - Notes: - The learning rate controls the step size during optimization. Larger values can - speed up convergence but may cause instability in the optimization process. - """ - device, dtype = model.device, model.dtype - - def gd_init( - state: SimState | StateDict, - **kwargs: Any, - ) -> GDState: - """Initialize the batched gradient descent optimization state. - - Args: - state: SimState containing positions, masses, cell, etc. - kwargs: Additional keyword arguments to override state attributes - - Returns: - Initialized BatchedGDState with forces and energy - """ - if not isinstance(state, SimState): - state = SimState(**state) - - atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers) - - # Get initial forces and energy from model - model_output = model(state) - energy = model_output["energy"] - forces = model_output["forces"] - - return GDState( - positions=state.positions, - forces=forces, - energy=energy, - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - atomic_numbers=atomic_numbers, - system_idx=state.system_idx, - ) - - def gd_step(state: GDState, lr: torch.Tensor = lr) -> GDState: - """Perform one gradient descent optimization step to update the - atomic positions. The cell is not optimized. - - Args: - state: Current optimization state - lr: Learning rate(s) to use for this step, overriding the default - - Returns: - Updated GDState after one optimization step - """ - # Get per-atom learning rates by mapping batch learning rates to atoms - if isinstance(lr, float): - lr = torch.full((state.n_systems,), lr, device=device, dtype=dtype) - - atom_lr = lr[state.system_idx].unsqueeze(-1) # shape: (total_atoms, 1) - - # Update positions using forces and per-atom learning rates - state.positions = state.positions + atom_lr * state.forces - - # Get updated forces and energy from model - model_output = model(state) - - # Update state with new forces and energy - state.forces = model_output["forces"] - state.energy = model_output["energy"] - - return state - - return gd_init, gd_step - - -@dataclass(kw_only=True) -class UnitCellGDState(GDState, DeformGradMixin): - """State class for batched gradient descent optimization with unit cell. - - Extends GDState to include unit cell optimization parameters and stress - information. This class maintains the state variables needed for simultaneously - optimizing atomic positions and unit cell parameters. - - Attributes: - # Inherited from GDState - positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] - masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_systems, 3, 3] - pbc (bool): Whether to use periodic boundary conditions - atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - system_idx (torch.Tensor): System indices with shape [n_atoms] - forces (torch.Tensor): Forces acting on atoms with shape [n_atoms, 3] - energy (torch.Tensor): Potential energy with shape [n_systems] - - # Additional attributes for cell optimization - stress (torch.Tensor): Stress tensor with shape [n_systems, 3, 3] - reference_cell (torch.Tensor): Reference unit cells with shape - [n_systems, 3, 3] - cell_factor (torch.Tensor): Scaling factor for cell optimization with shape - [n_systems, 1, 1] - hydrostatic_strain (bool): Whether to only allow hydrostatic deformation - constant_volume (bool): Whether to maintain constant volume - pressure (torch.Tensor): Applied pressure tensor with shape [n_systems, 3, 3] - cell_positions (torch.Tensor): Cell positions with shape [n_systems, 3, 3] - cell_forces (torch.Tensor): Cell forces with shape [n_systems, 3, 3] - cell_masses (torch.Tensor): Cell masses with shape [n_systems, 3] - """ - - # Required attributes not in BatchedGDState - reference_cell: torch.Tensor - cell_factor: torch.Tensor - hydrostatic_strain: bool - constant_volume: bool - pressure: torch.Tensor - stress: torch.Tensor - - # Cell attributes - cell_positions: torch.Tensor - cell_forces: torch.Tensor - cell_masses: torch.Tensor - - _system_attributes = ( - GDState._system_attributes # noqa: SLF001 - | DeformGradMixin._system_attributes # noqa: SLF001 - | { - "cell_forces", - "pressure", - "stress", - "cell_positions", - "cell_factor", - "cell_masses", - } - ) - _global_attributes = ( - GDState._global_attributes | {"hydrostatic_strain", "constant_volume"} # noqa: SLF001 - ) - - -def unit_cell_gradient_descent( # noqa: PLR0915, C901 - model: ModelInterface, - *, - positions_lr: float = 0.01, - cell_lr: float = 0.1, - cell_factor: float | torch.Tensor | None = None, - hydrostatic_strain: bool = False, - constant_volume: bool = False, - scalar_pressure: float = 0.0, -) -> tuple[ - Callable[[SimState | StateDict], UnitCellGDState], - Callable[[UnitCellGDState], UnitCellGDState], -]: - """Initialize a batched gradient descent optimization with unit cell parameters. - - Creates an optimizer that performs gradient descent on both atomic positions and - unit cell parameters for multiple systems in parallel. Supports constraints on cell - deformation and applied external pressure. - - This optimizer extends standard gradient descent to simultaneously optimize - both atomic coordinates and unit cell parameters based on forces and stress - computed by the provided model. - - Args: - model (torch.nn.Module): Model that computes energies, forces, and stress - positions_lr (float): Learning rate for atomic positions optimization. Default - is 0.01. - cell_lr (float): Learning rate for unit cell optimization. Default is 0.1. - cell_factor (float | torch.Tensor | None): Scaling factor for cell - optimization. If None, defaults to number of atoms per system - hydrostatic_strain (bool): Whether to only allow hydrostatic deformation - (isotropic scaling). Default is False. - constant_volume (bool): Whether to maintain constant volume during optimization - Default is False. - scalar_pressure (float): Applied external pressure in GPa. Default is 0.0. - - Returns: - tuple: A pair of functions: - - Initialization function that creates a BatchedUnitCellGDState - - Update function that performs one gradient descent step with cell - optimization - - Notes: - - To fix the cell and only optimize atomic positions, set both - constant_volume=True and hydrostatic_strain=True - - The cell_factor parameter controls the relative scale of atomic vs cell - optimization - - Larger values for positions_lr and cell_lr can speed up convergence but - may cause instability in the optimization process - """ - device, dtype = model.device, model.dtype - - def gd_init( - state: SimState, - cell_factor: float | torch.Tensor | None = cell_factor, - hydrostatic_strain: bool = hydrostatic_strain, # noqa: FBT001 - constant_volume: bool = constant_volume, # noqa: FBT001 - scalar_pressure: float = scalar_pressure, - ) -> UnitCellGDState: - """Initialize the batched gradient descent optimization state with unit cell. - - Args: - state: Initial system state containing positions, masses, cell, etc. - cell_factor: Scaling factor for cell optimization (default: number of atoms) - hydrostatic_strain: Whether to only allow hydrostatic deformation - constant_volume: Whether to maintain constant volume - scalar_pressure: Applied pressure in GPa - **kwargs: Additional keyword arguments for state initialization - - Returns: - Initial UnitCellGDState with system configuration and forces - """ - if not isinstance(state, SimState): - state = SimState(**state) - - n_systems = state.n_systems - - # Setup cell_factor - if cell_factor is None: - # Count atoms per system - _, counts = torch.unique(state.system_idx, return_counts=True) - cell_factor = counts.to(dtype=dtype) - - if isinstance(cell_factor, int | float): - # Use same factor for all systems - cell_factor = torch.full( - (state.n_systems,), cell_factor, device=device, dtype=dtype - ) - - # Reshape to (n_systems, 1, 1) for broadcasting - cell_factor = cell_factor.view(n_systems, 1, 1) - - scalar_pressure = torch.full( - (state.n_systems, 1, 1), scalar_pressure, device=device, dtype=dtype - ) - # Setup pressure tensor - pressure = scalar_pressure * torch.eye(3, device=device) - - # Get initial forces and energy from model - model_output = model(state) - energy = model_output["energy"] - forces = model_output["forces"] - stress = model_output["stress"] # Already shape: (n_systems, 3, 3) - - # Create cell masses - cell_masses = torch.ones( - (state.n_systems, 3), device=device, dtype=dtype - ) # One mass per cell DOF - - # Get current deformation gradient - cur_deform_grad = DeformGradMixin._deform_grad( # noqa: SLF001 - state.row_vector_cell, state.row_vector_cell - ) - - # Calculate cell positions - cell_factor_expanded = cell_factor.expand( - state.n_systems, 3, 1 - ) # shape: (n_systems, 3, 1) - cell_positions = ( - cur_deform_grad.reshape(state.n_systems, 3, 3) * cell_factor_expanded - ) # shape: (n_systems, 3, 3) - - # Calculate virial - volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) - virial = -volumes * (stress + pressure) - - if hydrostatic_strain: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( - 0 - ).expand(state.n_systems, -1, -1) - - if constant_volume: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = virial - diag_mean.unsqueeze(-1) * torch.eye( - 3, device=device - ).unsqueeze(0).expand(state.n_systems, -1, -1) - - return UnitCellGDState( - positions=state.positions, - forces=forces, - energy=energy, - stress=stress, - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - reference_cell=state.cell.clone(), - cell_factor=cell_factor, - hydrostatic_strain=hydrostatic_strain, - constant_volume=constant_volume, - pressure=pressure, - atomic_numbers=state.atomic_numbers, - system_idx=state.system_idx, - cell_positions=cell_positions, - cell_forces=virial / cell_factor, - cell_masses=cell_masses, - ) - - def gd_step( - state: UnitCellGDState, - positions_lr: torch.Tensor = positions_lr, - cell_lr: torch.Tensor = cell_lr, - ) -> UnitCellGDState: - """Perform one gradient descent optimization step with unit cell. - - Updates both atomic positions and cell parameters based on forces and stress. - - Args: - state: Current optimization state - positions_lr: Learning rate for atomic positions optimization - cell_lr: Learning rate for unit cell optimization - - Returns: - Updated UnitCellGDState after one optimization step - """ - # Get dimensions - n_systems = state.n_systems - - # Get per-atom learning rates by mapping system learning rates to atoms - if isinstance(positions_lr, float): - positions_lr = torch.full( - (state.n_systems,), positions_lr, device=device, dtype=dtype - ) - - if isinstance(cell_lr, float): - cell_lr = torch.full((state.n_systems,), cell_lr, device=device, dtype=dtype) - - # Get current deformation gradient - cur_deform_grad = state.deform_grad() - - # Calculate cell positions from deformation gradient - cell_factor_expanded = state.cell_factor.expand(n_systems, 3, 1) - cell_positions = ( - cur_deform_grad.reshape(n_systems, 3, 3) * cell_factor_expanded - ) # shape: (n_systems, 3, 3) - - # Get per-atom and per-cell learning rates - atom_wise_lr = positions_lr[state.system_idx].unsqueeze(-1) - cell_wise_lr = cell_lr.view(n_systems, 1, 1) # shape: (n_systems, 1, 1) - - # Update atomic and cell positions - atomic_positions_new = state.positions + atom_wise_lr * state.forces - cell_positions_new = cell_positions + cell_wise_lr * state.cell_forces - - # Update cell with deformation gradient - cell_update = cell_positions_new / cell_factor_expanded - new_row_vector_cell = torch.bmm(state.reference_row_vector_cell, cell_update.mT) - - # Update state - state.positions = atomic_positions_new - state.row_vector_cell = new_row_vector_cell - - # Get new forces and energy - model_output = model(state) - - state.energy = model_output["energy"] - state.forces = model_output["forces"] - state.stress = model_output["stress"] - - # Calculate virial for cell forces - volumes = torch.linalg.det(new_row_vector_cell).view(n_systems, 1, 1) - virial = -volumes * (state.stress + state.pressure) - if state.hydrostatic_strain: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( - 0 - ).expand(n_systems, -1, -1) - if state.constant_volume: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = virial - diag_mean.unsqueeze(-1) * torch.eye( - 3, device=device - ).unsqueeze(0).expand(n_systems, -1, -1) - - # Update cell forces - state.cell_positions = cell_positions_new - state.cell_forces = virial / state.cell_factor - - return state - - return gd_init, gd_step - - -@dataclass(kw_only=True) -class FireState(SimState): - """State information for batched FIRE optimization. - - This class extends SimState to store and track the system state during FIRE - (Fast Inertial Relaxation Engine) optimization. It maintains the atomic - parameters along with their velocities and forces for structure relaxation using - the FIRE algorithm. - - Attributes: - # Inherited from SimState - positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] - masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_systems, 3, 3] - pbc (bool): Whether to use periodic boundary conditions - atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - system_idx (torch.Tensor): System indices with shape [n_atoms] - - # Atomic quantities - forces (torch.Tensor): Forces on atoms with shape [n_atoms, 3] - velocities (torch.Tensor): Atomic velocities with shape [n_atoms, 3] - energy (torch.Tensor): Energy per system with shape [n_systems] - - # FIRE optimization parameters - dt (torch.Tensor): Current timestep per system with shape [n_systems] - alpha (torch.Tensor): Current mixing parameter per system with shape [n_systems] - n_pos (torch.Tensor): Number of positive power steps per system with shape - [n_systems] - - Properties: - momenta (torch.Tensor): Atomwise momenta of the system with shape [n_atoms, 3], - calculated as velocities * masses - """ - - # Required attributes not in SimState - forces: torch.Tensor - energy: torch.Tensor - velocities: torch.Tensor - - # FIRE algorithm parameters - dt: torch.Tensor - alpha: torch.Tensor - n_pos: torch.Tensor - - _atom_attributes = _md_atom_attributes - _system_attributes = ( - SimState._system_attributes # noqa: SLF001 - | { - "energy", - "dt", - "alpha", - "n_pos", - } - ) - - -def fire( - model: ModelInterface, - *, - dt_max: float = 1.0, - dt_start: float = 0.1, - n_min: int = 5, - f_inc: float = 1.1, - f_dec: float = 0.5, - alpha_start: float = 0.1, - f_alpha: float = 0.99, - max_step: float = 0.2, - md_flavor: MdFlavor = ase_fire_key, -) -> tuple[ - Callable[[SimState | StateDict], FireState], - Callable[[FireState], FireState], -]: - """Initialize a batched FIRE optimization. - - Creates an optimizer that performs FIRE (Fast Inertial Relaxation Engine) - optimization on atomic positions. - - Args: - model (torch.nn.Module): Model that computes energies, forces, and stress - dt_max (float): Maximum allowed timestep - dt_start (float): Initial timestep - n_min (int): Minimum steps before timestep increase - f_inc (float): Factor for timestep increase when power is positive - f_dec (float): Factor for timestep decrease when power is negative - alpha_start (float): Initial velocity mixing parameter - f_alpha (float): Factor for mixing parameter decrease - max_step (float): Maximum distance an atom can move per iteration (default - value is 0.2). Only used when md_flavor='ase_fire'. - md_flavor ("vv_fire" | "ase_fire"): Optimization flavor. Default is "ase_fire". - - Returns: - tuple[Callable, Callable]: - - Initialization function that creates a FireState - - Update function (either vv_fire_step or ase_fire_step) that performs - one FIRE optimization step. - - Notes: - - md_flavor="vv_fire" follows the original paper closely, including - integration with Velocity Verlet steps. See https://doi.org/10.1103/PhysRevLett.97.170201 - and https://github.com/TorchSim/torch-sim/issues/90#issuecomment-2826179997 - for details. - - md_flavor="ase_fire" mimics the implementation in ASE, which differs slightly - in the update steps and does not explicitly use atomic masses in the - velocity update step. See https://gitlab.com/ase/ase/-/blob/66963e6e38/ase/optimize/fire.py#L164-214 - for details. - - FIRE is generally more efficient than standard gradient descent for atomic - structure optimization. - - The algorithm adaptively adjusts step sizes and mixing parameters based - on the dot product of forces and velocities (power). - """ - if md_flavor not in get_args(MdFlavor): - raise ValueError(f"Unknown {md_flavor=}, must be one of {get_args(MdFlavor)}") - - device, dtype = model.device, model.dtype - - eps = 1e-8 if dtype == torch.float32 else 1e-16 - - # Setup parameters, added max_step for ASE style - dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step = ( - torch.as_tensor(p, device=device, dtype=dtype) - for p in (dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step) - ) - - def fire_init( - state: SimState | StateDict, - dt_start: float = dt_start, - alpha_start: float = alpha_start, - ) -> FireState: - """Initialize a batched FIRE optimization state. - - Args: - state: Input state as SimState object or state parameter dict - dt_start: Initial timestep per system - alpha_start: Initial mixing parameter per system - - Returns: - FireState with initialized optimization tensors - """ - if not isinstance(state, SimState): - state = SimState(**state) - - # Get dimensions - n_systems = state.n_systems - - # Get initial forces and energy from model - model_output = model(state) - - energy = model_output["energy"] # [n_systems] - forces = model_output["forces"] # [n_total_atoms, 3] - - # Setup parameters - dt_start = torch.full((n_systems,), dt_start, device=device, dtype=dtype) - alpha_start = torch.full((n_systems,), alpha_start, device=device, dtype=dtype) - n_pos = torch.zeros((n_systems,), device=device, dtype=torch.int32) - - return FireState( # Create initial state - # Copy SimState attributes - positions=state.positions.clone(), - masses=state.masses.clone(), - cell=state.cell.clone(), - atomic_numbers=state.atomic_numbers.clone(), - system_idx=state.system_idx.clone(), - pbc=state.pbc, - velocities=torch.full( - state.positions.shape, torch.nan, device=device, dtype=dtype - ), - forces=forces, - energy=energy, - # Optimization attributes - dt=dt_start, - alpha=alpha_start, - n_pos=n_pos, - ) - - step_func_kwargs = dict( - model=model, - dt_max=dt_max, - n_min=n_min, - f_inc=f_inc, - f_dec=f_dec, - alpha_start=alpha_start, - f_alpha=f_alpha, - eps=eps, - is_cell_optimization=False, - is_frechet=False, - ) - if md_flavor == ase_fire_key: - step_func_kwargs["max_step"] = max_step - step_func = {vv_fire_key: _vv_fire_step, ase_fire_key: _ase_fire_step}[md_flavor] - return fire_init, functools.partial(step_func, **step_func_kwargs) - - -@dataclass(kw_only=True) -class UnitCellFireState(SimState, DeformGradMixin): - """State information for batched FIRE optimization with unit cell degrees of - freedom. - - This class extends SimState to store and track the system state during FIRE - (Fast Inertial Relaxation Engine) optimization. It maintains both atomic and cell - parameters along with their velocities and forces for structure relaxation using - the FIRE algorithm. - - Attributes: - # Inherited from SimState - positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] - masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_systems, 3, 3] - pbc (bool): Whether to use periodic boundary conditions - atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - system_idx (torch.Tensor): System indices with shape [n_atoms] - - # Atomic quantities - forces (torch.Tensor): Forces on atoms with shape [n_atoms, 3] - velocities (torch.Tensor): Atomic velocities with shape [n_atoms, 3] - energy (torch.Tensor): Energy per system with shape [n_systems] - stress (torch.Tensor): Stress tensor with shape [n_systems, 3, 3] - - # Cell quantities - cell_positions (torch.Tensor): Cell positions with shape [n_systems, 3, 3] - cell_velocities (torch.Tensor): Cell velocities with shape [n_systems, 3, 3] - cell_forces (torch.Tensor): Cell forces with shape [n_systems, 3, 3] - cell_masses (torch.Tensor): Cell masses with shape [n_systems, 3] - - # Cell optimization parameters - reference_cell (torch.Tensor): Original unit cells with shape [n_systems, 3, 3] - cell_factor (torch.Tensor): Cell optimization scaling factor with shape - [n_systems, 1, 1] - pressure (torch.Tensor): Applied pressure tensor with shape [n_systems, 3, 3] - hydrostatic_strain (bool): Whether to only allow hydrostatic deformation - constant_volume (bool): Whether to maintain constant volume - - # FIRE optimization parameters - dt (torch.Tensor): Current timestep per system with shape [n_systems] - alpha (torch.Tensor): Current mixing parameter per system with shape [n_systems] - n_pos (torch.Tensor): Number of positive power steps per system with shape - [n_systems] - - Properties: - momenta (torch.Tensor): Atomwise momenta of the system with shape [n_atoms, 3], - calculated as velocities * masses - """ - - # Required attributes not in SimState - forces: torch.Tensor - energy: torch.Tensor - stress: torch.Tensor - velocities: torch.Tensor - - # Cell attributes - cell_positions: torch.Tensor - cell_velocities: torch.Tensor - cell_forces: torch.Tensor - cell_masses: torch.Tensor - - # Optimization-specific attributes - cell_factor: torch.Tensor - pressure: torch.Tensor - hydrostatic_strain: bool - constant_volume: bool - - # FIRE algorithm parameters - dt: torch.Tensor - alpha: torch.Tensor - n_pos: torch.Tensor - - _atom_attributes = _md_atom_attributes - _system_attributes = _fire_system_attributes - _global_attributes = _fire_global_attributes - - -def unit_cell_fire( - model: ModelInterface, - *, - dt_max: float = 1.0, - dt_start: float = 0.1, - n_min: int = 5, - f_inc: float = 1.1, - f_dec: float = 0.5, - alpha_start: float = 0.1, - f_alpha: float = 0.99, - cell_factor: float | None = None, - hydrostatic_strain: bool = False, - constant_volume: bool = False, - scalar_pressure: float = 0.0, - max_step: float = 0.2, - md_flavor: MdFlavor = ase_fire_key, -) -> tuple[ - Callable[[SimState | StateDict], UnitCellFireState], - Callable[[UnitCellFireState], UnitCellFireState], -]: - """Initialize a batched FIRE optimization with unit cell degrees of freedom. - - Creates an optimizer that performs FIRE (Fast Inertial Relaxation Engine) - optimization on both atomic positions and unit cell parameters for multiple systems - in parallel. FIRE combines molecular dynamics with velocity damping and adjustment - of time steps to efficiently find local minima. - - Args: - model (torch.nn.Module): Model that computes energies, forces, and stress - dt_max (float): Maximum allowed timestep - dt_start (float): Initial timestep - n_min (int): Minimum steps before timestep increase - f_inc (float): Factor for timestep increase when power is positive - f_dec (float): Factor for timestep decrease when power is negative - alpha_start (float): Initial velocity mixing parameter - f_alpha (float): Factor for mixing parameter decrease - cell_factor (float | None): Scaling factor for cell optimization. - If None, defaults to number of atoms per system - hydrostatic_strain (bool): Whether to only allow hydrostatic deformation - (isotropic scaling) - constant_volume (bool): Whether to maintain constant volume during optimization - scalar_pressure (float): Applied external pressure in GPa - max_step (float): Maximum allowed step size for ase_fire - md_flavor ("vv_fire" | "ase_fire"): Optimization flavor. Default is "ase_fire". - - Returns: - tuple: A pair of functions: - - Initialization function that creates a BatchedUnitCellFireState - - Update function that performs one FIRE optimization step - - Notes: - - md_flavor="vv_fire" follows the original paper closely, including - integration with Velocity Verlet steps. See https://doi.org/10.1103/PhysRevLett.97.170201 - and https://github.com/TorchSim/torch-sim/issues/90#issuecomment-2826179997 - for details. - - md_flavor="ase_fire" mimics the implementation in ASE, which differs slightly - in the update steps and does not explicitly use atomic masses in the - velocity update step. See https://gitlab.com/ase/ase/-/blob/66963e6e38/ase/optimize/fire.py#L164-214 - for details. - - FIRE is generally more efficient than standard gradient descent for atomic - structure optimization - - The algorithm adaptively adjusts step sizes and mixing parameters based - on the dot product of forces and velocities - - To fix the cell and only optimize atomic positions, set both - constant_volume=True and hydrostatic_strain=True - - The cell_factor parameter controls the relative scale of atomic vs cell - optimization - """ - if md_flavor not in get_args(MdFlavor): - raise ValueError(f"Unknown {md_flavor=}, must be one of {get_args(MdFlavor)}") - device, dtype = model.device, model.dtype - - eps = 1e-8 if dtype == torch.float32 else 1e-16 - - # Setup parameters - dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step = ( - torch.as_tensor(p, device=device, dtype=dtype) - for p in (dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step) - ) - - def fire_init( - state: SimState | StateDict, - cell_factor: torch.Tensor | None = cell_factor, - scalar_pressure: float = scalar_pressure, - dt_start: float = dt_start, - alpha_start: float = alpha_start, - ) -> UnitCellFireState: - """Initialize a batched FIRE optimization state with unit cell. - - Args: - state: Input state as SimState object or state parameter dict - cell_factor: Cell optimization scaling factor. If None, uses atoms per system. - Single value or tensor of shape [n_systems]. - scalar_pressure: Applied pressure in energy units - dt_start: Initial timestep per system - alpha_start: Initial mixing parameter per system - - Returns: - UnitCellFireState with initialized optimization tensors - """ - if not isinstance(state, SimState): - state = SimState(**state) - - # Get dimensions - n_systems = state.n_systems - - # Setup cell_factor - if cell_factor is None: - # Count atoms per system - _, counts = torch.unique(state.system_idx, return_counts=True) - cell_factor = counts.to(dtype=dtype) - - if isinstance(cell_factor, int | float): - # Use same factor for all systems - cell_factor = torch.full( - (state.n_systems,), cell_factor, device=device, dtype=dtype - ) - - # Reshape to (n_systems, 1, 1) for broadcasting - cell_factor = cell_factor.view(n_systems, 1, 1) - - # Setup pressure tensor - pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) - pressure = pressure.unsqueeze(0).expand(n_systems, -1, -1) - - # Get initial forces and energy from model - model_output = model(state) - - energy = model_output["energy"] # [n_systems] - forces = model_output["forces"] # [n_total_atoms, 3] - stress = model_output["stress"] # [n_systems, 3, 3] - - volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) - virial = -volumes * (stress + pressure) # P is P_ext * I - - if hydrostatic_strain: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( - 0 - ).expand(n_systems, -1, -1) - - if constant_volume: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = virial - diag_mean.unsqueeze(-1) * torch.eye( - 3, device=device - ).unsqueeze(0).expand(n_systems, -1, -1) - - cell_forces = virial / cell_factor - - # Sum masses per system using segment_reduce - # TODO (AG): check this - system_counts = torch.bincount(state.system_idx) - - cell_masses = torch.segment_reduce( - state.masses, reduce="sum", lengths=system_counts - ) # shape: (n_systems,) - cell_masses = cell_masses.unsqueeze(-1).expand(-1, 3) # shape: (n_systems, 3) - - # Setup parameters - dt_start = torch.full((n_systems,), dt_start, device=device, dtype=dtype) - alpha_start = torch.full((n_systems,), alpha_start, device=device, dtype=dtype) - n_pos = torch.zeros((n_systems,), device=device, dtype=torch.int32) - - return UnitCellFireState( # Create initial state - # Copy SimState attributes - positions=state.positions.clone(), - masses=state.masses.clone(), - cell=state.cell.clone(), - atomic_numbers=state.atomic_numbers.clone(), - system_idx=state.system_idx.clone(), - pbc=state.pbc, - velocities=torch.full( - state.positions.shape, torch.nan, device=device, dtype=dtype - ), - forces=forces, - energy=energy, - stress=stress, - # Cell attributes - cell_positions=torch.zeros(n_systems, 3, 3, device=device, dtype=dtype), - cell_velocities=torch.full( - cell_forces.shape, torch.nan, device=device, dtype=dtype - ), - cell_forces=cell_forces, - cell_masses=cell_masses, - # Optimization attributes - reference_cell=state.cell.clone(), - cell_factor=cell_factor, - pressure=pressure, - dt=dt_start, - alpha=alpha_start, - n_pos=n_pos, - hydrostatic_strain=hydrostatic_strain, - constant_volume=constant_volume, - ) - - step_func_kwargs = dict( - model=model, - dt_max=dt_max, - n_min=n_min, - f_inc=f_inc, - f_dec=f_dec, - alpha_start=alpha_start, - f_alpha=f_alpha, - eps=eps, - is_cell_optimization=True, - is_frechet=False, - ) - if md_flavor == ase_fire_key: - step_func_kwargs["max_step"] = max_step - step_func = {vv_fire_key: _vv_fire_step, ase_fire_key: _ase_fire_step}[md_flavor] - return fire_init, functools.partial(step_func, **step_func_kwargs) - - -@dataclass(kw_only=True) -class FrechetCellFIREState(SimState, DeformGradMixin): - """State class for batched FIRE optimization with Frechet cell derivatives. - - This class extends SimState to store and track the system state during FIRE - optimization with matrix logarithm parameterization for cell degrees of freedom. - This parameterization provides improved handling of cell deformations during - optimization. - - Attributes: - # Inherited from SimState - positions (torch.Tensor): Atomic positions with shape [n_atoms, 3] - masses (torch.Tensor): Atomic masses with shape [n_atoms] - cell (torch.Tensor): Unit cell vectors with shape [n_systems, 3, 3] - pbc (bool): Whether to use periodic boundary conditions - atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms] - system_idx (torch.Tensor): System indices with shape [n_atoms] - - # Additional atomic quantities - forces (torch.Tensor): Forces on atoms with shape [n_atoms, 3] - energy (torch.Tensor): Energy per system with shape [n_systems] - velocities (torch.Tensor): Atomic velocities with shape [n_atoms, 3] - stress (torch.Tensor): Stress tensor with shape [n_systems, 3, 3] - - # Optimization-specific attributes - reference_cell (torch.Tensor): Original unit cell with shape [n_systems, 3, 3] - cell_factor (torch.Tensor): Scaling factor for cell optimization with shape - [n_systems, 1, 1] - pressure (torch.Tensor): Applied pressure tensor with shape [n_systems, 3, 3] - hydrostatic_strain (bool): Whether to only allow hydrostatic deformation - constant_volume (bool): Whether to maintain constant volume - - # Cell attributes using log parameterization - cell_positions (torch.Tensor): Cell positions using log parameterization with - shape [n_systems, 3, 3] - cell_velocities (torch.Tensor): Cell velocities with shape [n_systems, 3, 3] - cell_forces (torch.Tensor): Cell forces with shape [n_systems, 3, 3] - cell_masses (torch.Tensor): Cell masses with shape [n_systems, 3] - - # FIRE algorithm parameters - dt (torch.Tensor): Current timestep per system with shape [n_systems] - alpha (torch.Tensor): Current mixing parameter per system with shape [n_systems] - n_pos (torch.Tensor): Number of positive power steps per system with shape - [n_systems] - - Properties: - momenta (torch.Tensor): Atomwise momenta of the system with shape [n_atoms, 3], - calculated as velocities * masses - """ - - # Required attributes not in SimState - forces: torch.Tensor - energy: torch.Tensor - velocities: torch.Tensor - stress: torch.Tensor - - # Optimization-specific attributes - cell_factor: torch.Tensor - pressure: torch.Tensor - hydrostatic_strain: bool - constant_volume: bool - - # Cell attributes - cell_positions: torch.Tensor - cell_velocities: torch.Tensor - cell_forces: torch.Tensor - cell_masses: torch.Tensor - - # FIRE algorithm parameters - dt: torch.Tensor - alpha: torch.Tensor - n_pos: torch.Tensor - - _atom_attributes = _md_atom_attributes - _system_attributes = _fire_system_attributes - _global_attributes = _fire_global_attributes - - -def frechet_cell_fire( - model: ModelInterface, - *, - dt_max: float = 1.0, - dt_start: float = 0.1, - n_min: int = 5, - f_inc: float = 1.1, - f_dec: float = 0.5, - alpha_start: float = 0.1, - f_alpha: float = 0.99, - cell_factor: float | None = None, - hydrostatic_strain: bool = False, - constant_volume: bool = False, - scalar_pressure: float = 0.0, - max_step: float = 0.2, - md_flavor: MdFlavor = ase_fire_key, -) -> tuple[ - Callable[[SimState | StateDict], FrechetCellFIREState], - Callable[[FrechetCellFIREState], FrechetCellFIREState], -]: - """Initialize a batched FIRE optimization with Frechet cell parameterization. - - Creates an optimizer that performs FIRE optimization on both atomic positions and - unit cell parameters using matrix logarithm parameterization for cell degrees of - freedom. This parameterization provides forces consistent with numerical - derivatives of the potential energy with respect to cell variables, resulting in - more robust cell optimization. - - Args: - model (torch.nn.Module): Model that computes energies, forces, and stress. - dt_max (float): Maximum allowed timestep - dt_start (float): Initial timestep - n_min (int): Minimum steps before timestep increase - f_inc (float): Factor for timestep increase when power is positive - f_dec (float): Factor for timestep decrease when power is negative - alpha_start (float): Initial velocity mixing parameter - f_alpha (float): Factor for mixing parameter decrease - cell_factor (float | None): Scaling factor for cell optimization. - If None, defaults to number of atoms per system - hydrostatic_strain (bool): Whether to only allow hydrostatic deformation - (isotropic scaling) - constant_volume (bool): Whether to maintain constant volume during optimization - scalar_pressure (float): Applied external pressure in GPa - max_step (float): Maximum allowed step size for ase_fire - md_flavor ("vv_fire" | "ase_fire"): Optimization flavor. Default is "ase_fire". - - Returns: - tuple: A pair of functions: - - Initialization function that creates a FrechetCellFIREState - - Update function that performs one FIRE step with Frechet derivatives - - Notes: - - md_flavor="vv_fire" follows the original paper closely, including - integration with Velocity Verlet steps. See https://doi.org/10.1103/PhysRevLett.97.170201 - and https://github.com/TorchSim/torch-sim/issues/90#issuecomment-2826179997 - for details. - - md_flavor="ase_fire" mimics the implementation in ASE, which differs slightly - in the update steps and does not explicitly use atomic masses in the - velocity update step. See https://gitlab.com/ase/ase/-/blob/66963e6e38/ase/optimize/fire.py#L164-214 - for details. - - Frechet cell parameterization uses matrix logarithm to represent cell - deformations, which provides improved numerical properties for cell - optimization - - This method generally performs better than standard unit cell optimization - for cases with large cell deformations - - To fix the cell and only optimize atomic positions, set both - constant_volume=True and hydrostatic_strain=True - """ - if md_flavor not in get_args(MdFlavor): - raise ValueError(f"Unknown {md_flavor=}, must be one of {get_args(MdFlavor)}") - device, dtype = model.device, model.dtype - - eps = 1e-8 if dtype == torch.float32 else 1e-16 - - # Setup parameters - dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step = ( - torch.as_tensor(p, device=device, dtype=dtype) - for p in (dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step) - ) - - def fire_init( - state: SimState | StateDict, - cell_factor: torch.Tensor | None = cell_factor, - scalar_pressure: float = scalar_pressure, - dt_start: float = dt_start, - alpha_start: float = alpha_start, - ) -> FrechetCellFIREState: - """Initialize a batched FIRE optimization state with Frechet cell - parameterization. - - Args: - state: Input state as SimState object or state parameter dict - cell_factor: Cell optimization scaling factor. If None, uses atoms per system. - Single value or tensor of shape [n_systems]. - scalar_pressure: Applied pressure in energy units - dt_start: Initial timestep per system - alpha_start: Initial mixing parameter per system - - Returns: - FrechetCellFIREState with initialized optimization tensors - """ - if not isinstance(state, SimState): - state = SimState(**state) - - # Get dimensions - n_systems = state.n_systems - - # Setup cell_factor - if cell_factor is None: - # Count atoms per system - _, counts = torch.unique(state.system_idx, return_counts=True) - cell_factor = counts.to(dtype=dtype) - - if isinstance(cell_factor, int | float): - # Use same factor for all systems - cell_factor = torch.full( - (state.n_systems,), cell_factor, device=device, dtype=dtype - ) - - # Reshape to (n_systems, 1, 1) for broadcasting - cell_factor = cell_factor.view(n_systems, 1, 1) - - # Setup pressure tensor - pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) - pressure = pressure.unsqueeze(0).expand(n_systems, -1, -1) - - # Get initial forces and energy from model - model_output = model(state) - - energy = model_output["energy"] # [n_systems] - forces = model_output["forces"] # [n_total_atoms, 3] - stress = model_output["stress"] # [n_systems, 3, 3] - - # Calculate initial cell positions using matrix logarithm - # Calculate current deformation gradient (identity matrix at start) - cur_deform_grad = DeformGradMixin._deform_grad( # noqa: SLF001 - state.row_vector_cell, state.row_vector_cell - ) # shape: (n_systems, 3, 3) - - # For identity matrix, logm gives zero matrix - # Initialize cell positions to zeros - cell_positions = torch.zeros((n_systems, 3, 3), device=device, dtype=dtype) - - # Calculate virial for cell forces - volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) - virial = -volumes * (stress + pressure) # P is P_ext * I - - if hydrostatic_strain: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( - 0 - ).expand(n_systems, -1, -1) - - if constant_volume: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = virial - diag_mean.unsqueeze(-1) * torch.eye( - 3, device=device - ).unsqueeze(0).expand(n_systems, -1, -1) - - # Calculate UCF-style cell gradient - ucf_cell_grad = torch.zeros_like(virial) - for b in range(n_systems): - ucf_cell_grad[b] = virial[b] @ torch.linalg.inv(cur_deform_grad[b].T) - # Calculate cell forces using Frechet derivative approach (all zeros for identity) - cell_forces = ucf_cell_grad / cell_factor - - # Sum masses per system - system_counts = torch.bincount(state.system_idx) - cell_masses = torch.segment_reduce( - state.masses, reduce="sum", lengths=system_counts - ) # shape: (n_systems,) - cell_masses = cell_masses.unsqueeze(-1).expand(-1, 3) # shape: (n_systems, 3) - - # Setup parameters - dt_start = torch.full((n_systems,), dt_start, device=device, dtype=dtype) - alpha_start = torch.full((n_systems,), alpha_start, device=device, dtype=dtype) - n_pos = torch.zeros((n_systems,), device=device, dtype=torch.int32) - - return FrechetCellFIREState( # Create initial state - # Copy SimState attributes - positions=state.positions, - masses=state.masses, - cell=state.cell, - atomic_numbers=state.atomic_numbers, - system_idx=state.system_idx, - pbc=state.pbc, - velocities=torch.full( - state.positions.shape, torch.nan, device=device, dtype=dtype - ), - forces=forces, - energy=energy, - stress=stress, - # Cell attributes - cell_positions=cell_positions, - cell_velocities=torch.full( - cell_forces.shape, torch.nan, device=device, dtype=dtype - ), - cell_forces=cell_forces, - cell_masses=cell_masses, - # Optimization attributes - reference_cell=state.cell.clone(), - cell_factor=cell_factor, - pressure=pressure, - dt=dt_start, - alpha=alpha_start, - n_pos=n_pos, - hydrostatic_strain=hydrostatic_strain, - constant_volume=constant_volume, - ) - - step_func_kwargs = dict( - model=model, - dt_max=dt_max, - n_min=n_min, - f_inc=f_inc, - f_dec=f_dec, - alpha_start=alpha_start, - f_alpha=f_alpha, - eps=eps, - is_cell_optimization=True, - is_frechet=True, - ) - if md_flavor == ase_fire_key: - step_func_kwargs["max_step"] = max_step - step_func = {vv_fire_key: _vv_fire_step, ase_fire_key: _ase_fire_step}[md_flavor] - return fire_init, functools.partial(step_func, **step_func_kwargs) - - -AnyFireCellState = UnitCellFireState | FrechetCellFIREState - - -def _vv_fire_step( # noqa: C901, PLR0915 - state: FireState | AnyFireCellState, - model: ModelInterface, - *, - dt_max: torch.Tensor, - n_min: torch.Tensor, - f_inc: torch.Tensor, - f_dec: torch.Tensor, - alpha_start: torch.Tensor, - f_alpha: torch.Tensor, - eps: float, - is_cell_optimization: bool = False, - is_frechet: bool = False, -) -> FireState | AnyFireCellState: - """Perform one Velocity-Verlet based FIRE optimization step. - - Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm for - optimizing atomic positions and optionally unit cell parameters in a batched setting. - Uses velocity Verlet integration with adaptive velocity mixing. - - Args: - state: Current optimization state (FireState, UnitCellFireState, or - FrechetCellFIREState). - model: Model that computes energies, forces, and potentially stress. - dt_max: Maximum allowed timestep. - n_min: Minimum steps before timestep increase. - f_inc: Factor for timestep increase when power is positive. - f_dec: Factor for timestep decrease when power is negative. - alpha_start: Initial mixing parameter for velocity update. - f_alpha: Factor for mixing parameter decrease. - eps: Small epsilon value for numerical stability. - is_cell_optimization: Flag indicating if cell optimization is active. - is_frechet: Flag indicating if Frechet cell parameterization is used. - - Returns: - Updated state after performing one VV-FIRE step. - """ - n_systems = state.n_systems - device = state.positions.device - dtype = state.positions.dtype - deform_grad_new: torch.Tensor | None = None - - nan_velocities = state.velocities.isnan().any(dim=1) - if nan_velocities.any(): - state.velocities[nan_velocities] = torch.zeros_like( - state.positions[nan_velocities] - ) - if is_cell_optimization: - if not isinstance(state, get_args(AnyFireCellState)): - raise ValueError( - f"Cell optimization requires one of {get_args(AnyFireCellState)}." - ) - nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) - state.cell_velocities[nan_cell_velocities] = torch.zeros_like( - state.cell_positions[nan_cell_velocities] - ) - - alpha_start_system = torch.full( - (n_systems,), alpha_start.item(), device=device, dtype=dtype - ) - - atom_wise_dt = state.dt[state.system_idx].unsqueeze(-1) - state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) - - if is_cell_optimization: - cell_wise_dt = state.dt.unsqueeze(-1).unsqueeze(-1) - state.cell_velocities += ( - 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) - ) - - state.positions = state.positions + atom_wise_dt * state.velocities - - if is_cell_optimization: - cell_factor_reshaped = state.cell_factor.view(n_systems, 1, 1) - if is_frechet: - if not isinstance(state, expected_cls := FrechetCellFIREState): - raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") - cur_deform_grad = state.deform_grad() - deform_grad_log = torch.zeros_like(cur_deform_grad) - for b in range(n_systems): - deform_grad_log[b] = tsm.matrix_log_33(cur_deform_grad[b]) - - cell_positions_log_scaled = deform_grad_log * cell_factor_reshaped - cell_positions_log_scaled_new = ( - cell_positions_log_scaled + cell_wise_dt * state.cell_velocities - ) - deform_grad_log_new = cell_positions_log_scaled_new / cell_factor_reshaped - deform_grad_new = torch.matrix_exp(deform_grad_log_new) - new_row_vector_cell = torch.bmm( - state.reference_row_vector_cell, deform_grad_new.transpose(1, 2) - ) - state.row_vector_cell = new_row_vector_cell - state.cell_positions = cell_positions_log_scaled_new - else: - if not isinstance(state, expected_cls := UnitCellFireState): - raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") - cur_deform_grad = state.deform_grad() - cell_factor_expanded = state.cell_factor.expand(n_systems, 3, 1) - current_cell_positions_scaled = ( - cur_deform_grad.view(n_systems, 3, 3) * cell_factor_expanded - ) - - cell_positions_scaled_new = ( - current_cell_positions_scaled + cell_wise_dt * state.cell_velocities - ) - cell_update = cell_positions_scaled_new / cell_factor_expanded - new_cell = torch.bmm( - state.reference_row_vector_cell, cell_update.transpose(1, 2) - ) - state.row_vector_cell = new_cell - state.cell_positions = cell_positions_scaled_new - - results = model(state) - state.forces = results["forces"] - state.energy = results["energy"] - - if is_cell_optimization: - state.stress = results["stress"] - volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) - virial = -volumes * (state.stress + state.pressure) - - if state.hydrostatic_strain: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = diag_mean.unsqueeze(-1) * torch.eye( - 3, device=device, dtype=dtype - ).unsqueeze(0).expand(n_systems, -1, -1) - if state.constant_volume: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = virial - diag_mean.unsqueeze(-1) * torch.eye( - 3, device=device, dtype=dtype - ).unsqueeze(0).expand(n_systems, -1, -1) - - if is_frechet: - if not isinstance(state, expected_cls := FrechetCellFIREState): - raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") - ucf_cell_grad = torch.bmm( - virial, torch.linalg.inv(torch.transpose(deform_grad_new, 1, 2)) - ) - directions = torch.zeros((9, 3, 3), device=device, dtype=dtype) - for idx, (mu, nu) in enumerate([(i, j) for i in range(3) for j in range(3)]): - directions[idx, mu, nu] = 1.0 - - new_cell_forces = torch.zeros_like(ucf_cell_grad) - for b in range(n_systems): - expm_derivs = torch.stack( - [ - tsm.expm_frechet(deform_grad_log_new[b], direction) - for direction in directions - ] - ) - forces_flat = torch.sum( - expm_derivs * ucf_cell_grad[b].unsqueeze(0), dim=(1, 2) - ) - new_cell_forces[b] = forces_flat.reshape(3, 3) - state.cell_forces = new_cell_forces / cell_factor_reshaped - else: - if not isinstance(state, expected_cls := UnitCellFireState): - raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") - state.cell_forces = virial / cell_factor_reshaped - - state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) - if is_cell_optimization: - state.cell_velocities += ( - 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) - ) - - system_power = tsm.batched_vdot(state.forces, state.velocities, state.system_idx) - - if is_cell_optimization: - system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) - - # 2. Update dt, alpha, n_pos - pos_mask_system = system_power > 0.0 - neg_mask_system = ~pos_mask_system - - state.n_pos[pos_mask_system] += 1 - inc_mask = (state.n_pos > n_min) & pos_mask_system - state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) - state.alpha[inc_mask] *= f_alpha - - state.dt[neg_mask_system] *= f_dec - state.alpha[neg_mask_system] = alpha_start_system[neg_mask_system] - state.n_pos[neg_mask_system] = 0 - - v_scaling_system = tsm.batched_vdot( - state.velocities, state.velocities, state.system_idx - ) - f_scaling_system = tsm.batched_vdot(state.forces, state.forces, state.system_idx) - - if is_cell_optimization: - v_scaling_system += state.cell_velocities.pow(2).sum(dim=(1, 2)) - f_scaling_system += state.cell_forces.pow(2).sum(dim=(1, 2)) - - v_scaling_cell = torch.sqrt(v_scaling_system.view(n_systems, 1, 1)) - f_scaling_cell = torch.sqrt(f_scaling_system.view(n_systems, 1, 1)) - v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - - alpha_cell_bc = state.alpha.view(n_systems, 1, 1) - state.cell_velocities = torch.where( - pos_mask_system.view(n_systems, 1, 1), - (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, - torch.zeros_like(state.cell_velocities), - ) - - v_scaling_atom = torch.sqrt(v_scaling_system[state.system_idx].unsqueeze(-1)) - f_scaling_atom = torch.sqrt(f_scaling_system[state.system_idx].unsqueeze(-1)) - v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) - - alpha_atom = state.alpha[state.system_idx].unsqueeze(-1) # per-atom alpha - state.velocities = torch.where( - pos_mask_system[state.system_idx].unsqueeze(-1), - (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, - torch.zeros_like(state.velocities), - ) - - return state - - -def _ase_fire_step( # noqa: C901, PLR0915 - state: FireState | AnyFireCellState, - model: ModelInterface, - *, - dt_max: torch.Tensor, - n_min: torch.Tensor, - f_inc: torch.Tensor, - f_dec: torch.Tensor, - alpha_start: torch.Tensor, - f_alpha: torch.Tensor, - max_step: torch.Tensor, - eps: float, - is_cell_optimization: bool = False, - is_frechet: bool = False, -) -> FireState | AnyFireCellState: - """Perform one ASE-style FIRE optimization step. - - Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm - mimicking the ASE implementation. It can handle atomic position optimization - only, or combined position and cell optimization (standard or Frechet). - - Args: - state: Current optimization state. - model: Model that computes energies, forces, and potentially stress. - dt_max: Maximum allowed timestep. - n_min: Minimum steps before timestep increase. - f_inc: Factor for timestep increase when power is positive. - f_dec: Factor for timestep decrease when power is negative. - alpha_start: Initial mixing parameter for velocity update. - f_alpha: Factor for mixing parameter decrease. - max_step: Maximum allowed step size. - eps: Small epsilon value for numerical stability. - is_cell_optimization: Flag indicating if cell optimization is active. - is_frechet: Flag indicating if Frechet cell parameterization is used. - - Returns: - Updated state after performing one ASE-FIRE step. - """ - device, dtype = state.positions.device, state.positions.dtype - n_systems = state.n_systems - - cur_deform_grad = None # Initialize cur_deform_grad to prevent UnboundLocalError - - nan_velocities = state.velocities.isnan().any(dim=1) - if nan_velocities.any(): - state.velocities[nan_velocities] = torch.zeros_like( - state.positions[nan_velocities] - ) - forces = state.forces - if is_cell_optimization: - if not isinstance(state, get_args(AnyFireCellState)): - raise ValueError( - f"Cell optimization requires one of {get_args(AnyFireCellState)}." - ) - nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) - state.cell_velocities[nan_cell_velocities] = torch.zeros_like( - state.cell_positions[nan_cell_velocities] - ) - cur_deform_grad = state.deform_grad() - else: - alpha_start_system = torch.full( - (n_systems,), alpha_start.item(), device=device, dtype=dtype - ) - - if is_cell_optimization: - cur_deform_grad = state.deform_grad() - forces = torch.bmm( - state.forces.unsqueeze(1), cur_deform_grad[state.system_idx] - ).squeeze(1) - else: - forces = state.forces - - # 1. Current power (FΒ·v) per system (atoms + cell) - system_power = tsm.batched_vdot(forces, state.velocities, state.system_idx) - - if is_cell_optimization: - system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) - - # 2. Update dt, alpha, n_pos - pos_mask_system = system_power > 0.0 - neg_mask_system = ~pos_mask_system - - inc_mask = (state.n_pos > n_min) & pos_mask_system - state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) - state.alpha[inc_mask] *= f_alpha - state.n_pos[pos_mask_system] += 1 - - state.dt[neg_mask_system] *= f_dec - state.alpha[neg_mask_system] = alpha_start_system[neg_mask_system] - state.n_pos[neg_mask_system] = 0 - - # 3. Velocity mixing BEFORE acceleration (ASE ordering) - v_scaling_system = tsm.batched_vdot( - state.velocities, state.velocities, state.system_idx - ) - f_scaling_system = tsm.batched_vdot(forces, forces, state.system_idx) - - if is_cell_optimization: - v_scaling_system += state.cell_velocities.pow(2).sum(dim=(1, 2)) - f_scaling_system += state.cell_forces.pow(2).sum(dim=(1, 2)) - - v_scaling_cell = torch.sqrt(v_scaling_system.view(n_systems, 1, 1)) - f_scaling_cell = torch.sqrt(f_scaling_system.view(n_systems, 1, 1)) - v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - - alpha_cell_bc = state.alpha.view(n_systems, 1, 1) - state.cell_velocities = torch.where( - pos_mask_system.view(n_systems, 1, 1), - (1.0 - alpha_cell_bc) * state.cell_velocities - + alpha_cell_bc * v_mixing_cell, - torch.zeros_like(state.cell_velocities), - ) - - v_scaling_atom = torch.sqrt(v_scaling_system[state.system_idx].unsqueeze(-1)) - f_scaling_atom = torch.sqrt(f_scaling_system[state.system_idx].unsqueeze(-1)) - v_mixing_atom = forces * (v_scaling_atom / (f_scaling_atom + eps)) - - alpha_atom = state.alpha[state.system_idx].unsqueeze(-1) # per-atom alpha - state.velocities = torch.where( - pos_mask_system[state.system_idx].unsqueeze(-1), - (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, - torch.zeros_like(state.velocities), - ) - - # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) - state.velocities += forces * state.dt[state.system_idx].unsqueeze(-1) - dr_atom = state.velocities * state.dt[state.system_idx].unsqueeze(-1) - dr_scaling_system = tsm.batched_vdot(dr_atom, dr_atom, state.system_idx) - - if is_cell_optimization: - state.cell_velocities += state.cell_forces * state.dt.view(n_systems, 1, 1) - dr_cell = state.cell_velocities * state.dt.view(n_systems, 1, 1) - - dr_scaling_system += dr_cell.pow(2).sum(dim=(1, 2)) - dr_scaling_cell = torch.sqrt(dr_scaling_system).view(n_systems, 1, 1) - dr_cell = torch.where( - dr_scaling_cell > max_step, - max_step * dr_cell / (dr_scaling_cell + eps), - dr_cell, - ) - - dr_scaling_atom = torch.sqrt(dr_scaling_system)[state.system_idx].unsqueeze(-1) - - dr_atom = torch.where( - dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom - ) - - if is_cell_optimization: - state.positions = ( - torch.linalg.solve( - cur_deform_grad[state.system_idx], state.positions.unsqueeze(-1) - ).squeeze(-1) - + dr_atom - ) - - if is_frechet: - if not isinstance(state, expected_cls := FrechetCellFIREState): - raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") - new_logm_F_scaled = state.cell_positions + dr_cell - state.cell_positions = new_logm_F_scaled - logm_F_new = new_logm_F_scaled / (state.cell_factor + eps) - F_new = torch.matrix_exp(logm_F_new) - new_row_vector_cell = torch.bmm(state.reference_row_vector_cell, F_new.mT) - state.row_vector_cell = new_row_vector_cell - else: - if not isinstance(state, expected_cls := UnitCellFireState): - raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") - F_current = state.deform_grad() - cell_factor_exp_mult = state.cell_factor.expand(n_systems, 3, 1) - current_F_scaled = F_current * cell_factor_exp_mult - - F_new_scaled = current_F_scaled + dr_cell - state.cell_positions = F_new_scaled - F_new = F_new_scaled / (cell_factor_exp_mult + eps) - new_row_vector_cell = torch.bmm(state.reference_row_vector_cell, F_new.mT) - state.row_vector_cell = new_row_vector_cell - - state.positions = torch.bmm( - state.positions.unsqueeze(1), F_new[state.system_idx].mT - ).squeeze(1) - else: - state.positions = state.positions + dr_atom - - # 7. Force / stress refresh & new cell forces - results = model(state) - state.forces = results["forces"] - state.energy = results["energy"] - - if is_cell_optimization: - state.stress = results["stress"] - volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) - if torch.any(volumes <= 0): - bad_indices = torch.where(volumes <= 0)[0].tolist() - print( # noqa: T201 - f"WARNING: Non-positive volume(s) detected during _ase_fire_step: " - f"{volumes[bad_indices].tolist()} at {bad_indices=} ({is_frechet=})" - ) - - virial = -volumes * (state.stress + state.pressure) - - if state.hydrostatic_strain: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = diag_mean.unsqueeze(-1) * torch.eye( - 3, device=device, dtype=dtype - ).unsqueeze(0).expand(n_systems, -1, -1) - - if state.constant_volume: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = virial - diag_mean.unsqueeze(-1) * torch.eye( - 3, device=device, dtype=dtype - ).unsqueeze(0).expand(n_systems, -1, -1) - - if is_frechet: - if not isinstance(state, expected_cls := FrechetCellFIREState): - raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") - if F_new is None: - raise ValueError( - "F_new should be defined for Frechet cell force calculation" - ) - if logm_F_new is None: - raise ValueError( - "logm_F_new should be defined for Frechet cell force calculation" - ) - ucf_cell_grad = torch.bmm( - virial, torch.linalg.inv(torch.transpose(F_new, 1, 2)) - ) - directions = torch.zeros((9, 3, 3), device=device, dtype=dtype) - for idx, (mu, nu) in enumerate( - [(i_idx, j_idx) for i_idx in range(3) for j_idx in range(3)] - ): - directions[idx, mu, nu] = 1.0 - - new_cell_forces_log_space = torch.zeros_like(state.cell_forces) - for b_idx in range(n_systems): - expm_derivs = torch.stack( - [ - tsm.expm_frechet(logm_F_new[b_idx], direction) - for direction in directions - ] - ) - forces_flat = torch.sum( - expm_derivs * ucf_cell_grad[b_idx].unsqueeze(0), dim=(1, 2) - ) - new_cell_forces_log_space[b_idx] = forces_flat.reshape(3, 3) - state.cell_forces = new_cell_forces_log_space / (state.cell_factor + eps) - else: - if not isinstance(state, expected_cls := UnitCellFireState): - raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") - state.cell_forces = virial / state.cell_factor - - return state diff --git a/torch_sim/optimizers/__init__.py b/torch_sim/optimizers/__init__.py new file mode 100644 index 000000000..9985a1d37 --- /dev/null +++ b/torch_sim/optimizers/__init__.py @@ -0,0 +1,38 @@ +"""Optimizers for geometry relaxations. + +This module provides optimization algorithms for atomic structures in a batched format, +enabling efficient relaxation of multiple atomic structures simultaneously. It uses a +filter-based design where cell optimization constraints and parameterizations are +handled by separate filter functions. +""" + +from collections.abc import Callable +from enum import StrEnum +from typing import Any, Final, Literal, get_args + +from torch_sim.optimizers.cell_filters import CellFireState, CellOptimState # noqa: F401 +from torch_sim.optimizers.fire import fire_init, fire_step +from torch_sim.optimizers.gradient_descent import ( + gradient_descent_init, + gradient_descent_step, +) +from torch_sim.optimizers.state import FireState, OptimState # noqa: F401 + + +MdFlavor = Literal["vv_fire", "ase_fire"] +vv_fire_key, ase_fire_key = get_args(MdFlavor) + + +class OptimFlavor(StrEnum): + """Enumeration of the optimization flavors.""" + + gradient_descent = "gradient_descent" + fire = "fire" + + +OPTIM_REGISTRY: Final[ + dict[OptimFlavor, tuple[Callable[..., Any], Callable[..., Any]]] +] = { + OptimFlavor.gradient_descent: (gradient_descent_init, gradient_descent_step), + OptimFlavor.fire: (fire_init, fire_step), +} diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py new file mode 100644 index 000000000..0aaa8fe28 --- /dev/null +++ b/torch_sim/optimizers/cell_filters.py @@ -0,0 +1,381 @@ +"""Cell filters for optimization algorithms. + +This module provides filter functions that can be applied to optimization algorithms +to handle different types of cell optimization constraints and parameterizations. +Filters encapsulate the logic for computing cell forces and updating cell parameters +during optimization. +""" + +from collections.abc import Callable +from dataclasses import dataclass, field +from enum import StrEnum +from typing import Any + +import torch + +import torch_sim.math as fm +from torch_sim.models.interface import ModelInterface +from torch_sim.optimizers.state import FireState, OptimState +from torch_sim.state import SimState + + +def _setup_cell_factor( + state: SimState, + cell_factor: float | torch.Tensor | None, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """Setup cell factor tensor.""" + n_systems = state.n_systems + + if cell_factor is None: + # Count atoms per system + _, counts = torch.unique(state.system_idx, return_counts=True) + cell_factor_tensor = counts.to(dtype=dtype) + elif isinstance(cell_factor, (int, float)): + cell_factor_tensor = torch.full( + (n_systems,), cell_factor, device=device, dtype=dtype + ) + else: + cell_factor_tensor = torch.tensor(cell_factor, device=device, dtype=dtype) + if (n_cft := cell_factor_tensor.numel()) != n_systems: + raise ValueError( + f"cell_factor tensor must have {n_systems} elements, got {n_cft}" + ) + + return cell_factor_tensor.view(n_systems, 1, 1) + + +def _setup_pressure( + n_systems: int, scalar_pressure: float, device: torch.device, dtype: torch.dtype +) -> torch.Tensor: + """Setup pressure tensor.""" + pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) + return pressure.unsqueeze(0).expand(n_systems, -1, -1) + + +def _compute_cell_masses(state: SimState) -> torch.Tensor: + """Compute cell masses by summing atomic masses per system.""" + system_counts = torch.bincount(state.system_idx) + cell_masses = torch.segment_reduce(state.masses, reduce="sum", lengths=system_counts) + return cell_masses.unsqueeze(-1).expand(-1, 3) + + +def _apply_constraints( + virial: torch.Tensor, *, hydrostatic_strain: bool, constant_volume: bool +) -> torch.Tensor: + """Apply hydrostatic strain and constant volume constraints to virial.""" + n_systems, device = virial.shape[0], virial.device + + if hydrostatic_strain: + diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) + virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( + 0 + ).expand(n_systems, -1, -1) + + if constant_volume: + diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) + virial = virial - diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( + 0 + ).expand(n_systems, -1, -1) + + return virial + + +def deform_grad(reference_cell: torch.Tensor, current_cell: torch.Tensor) -> torch.Tensor: + """Compute deformation gradient between current and reference cells.""" + return torch.linalg.solve(reference_cell, current_cell).transpose(-2, -1) + + +def unit_cell_filter_init[T: AnyCellState]( + model: ModelInterface, + state: T, + *, + cell_factor: float | torch.Tensor | None = None, + hydrostatic_strain: bool = False, + constant_volume: bool = False, + scalar_pressure: float = 0.0, + **_kwargs: Any, +) -> None: + """Initialize unit cell filter state.""" + device, dtype = model.device, model.dtype + n_systems = state.n_systems + + # Setup parameters + cell_factor_tensor = _setup_cell_factor(state, cell_factor, device, dtype) + pressure = _setup_pressure(n_systems, scalar_pressure, device, dtype) + cell_masses = _compute_cell_masses(state) + + # Get initial model output for stress + model_output = model(state) + + # Calculate initial cell forces + stress = model_output["stress"] + volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) + virial = -volumes * (stress + pressure) + virial = _apply_constraints( + virial, hydrostatic_strain=hydrostatic_strain, constant_volume=constant_volume + ) + cell_forces = virial / cell_factor_tensor + + # Calculate initial cell positions from deformation gradient + # Use current cell as reference (matches reference implementation) + reference_cell = state.cell.clone() + cur_deform_grad = deform_grad(reference_cell.mT, state.row_vector_cell) + cell_factor_expanded = cell_factor_tensor.expand(n_systems, 3, 1) + cell_positions = cur_deform_grad.reshape(n_systems, 3, 3) * cell_factor_expanded + + # update state cell attributes in place + state.cell_factor = cell_factor_tensor + state.pressure = pressure + state.hydrostatic_strain = hydrostatic_strain + state.constant_volume = constant_volume + state.reference_cell = reference_cell + state.cell_positions = cell_positions + state.cell_forces = cell_forces + state.cell_masses = cell_masses + + +def frechet_cell_filter_init[T: AnyCellState]( + model: ModelInterface, + state: T, + *, + cell_factor: float | torch.Tensor | None = None, + hydrostatic_strain: bool = False, + constant_volume: bool = False, + scalar_pressure: float = 0.0, + **_kwargs: Any, +) -> None: + """Initialize Frechet cell filter state.""" + device, dtype = model.device, model.dtype + n_systems = state.n_systems + + # Setup parameters + cell_factor_tensor = _setup_cell_factor(state, cell_factor, device, dtype) + pressure = _setup_pressure(n_systems, scalar_pressure, device, dtype) + cell_masses = _compute_cell_masses(state) + + # Initialize cell positions to zeros (log of identity matrix) + cell_positions = torch.zeros((n_systems, 3, 3), device=device, dtype=dtype) + + # Get initial model output for stress + model_output = model(state) + + # Calculate initial cell forces using Frechet approach + stress = model_output["stress"] + volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) + virial = -volumes * (stress + pressure) + virial = _apply_constraints( + virial, hydrostatic_strain=hydrostatic_strain, constant_volume=constant_volume + ) + + # Get current deformation gradient (identity at start for Frechet) + reference_cell = state.cell.clone() + cur_deform_grad = deform_grad(reference_cell.mT, state.row_vector_cell) + ucf_cell_grad = torch.bmm( + virial, torch.linalg.inv(torch.transpose(cur_deform_grad, 1, 2)) + ) + + # For identity matrix (initial state), Frechet derivative gives zero forces + # This matches the reference implementation behavior + cell_forces = ucf_cell_grad / cell_factor_tensor + + # update state cell attributes in place + state.cell_factor = cell_factor_tensor + state.pressure = pressure + state.hydrostatic_strain = hydrostatic_strain + state.constant_volume = constant_volume + state.reference_cell = reference_cell + state.cell_positions = cell_positions + state.cell_forces = cell_forces + state.cell_masses = cell_masses + + +class CellFilter(StrEnum): + """Enumeration of the cell filters.""" + + unit = "unit" + frechet = "frechet" + + +# Filter type definitions for convenience +def unit_cell_update[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) -> None: + """Update cell using unit cell approach.""" + if isinstance(cell_lr, (int, float)): + cell_lr = torch.full( + (state.n_systems,), cell_lr, device=state.device, dtype=state.dtype + ) + + # Get current deformation gradient + cur_deform_grad = deform_grad(state.reference_cell.mT, state.row_vector_cell) + + # Calculate cell positions from current deformation gradient + cell_factor_expanded = state.cell_factor.expand(state.n_systems, 3, 1) + current_cell_positions = ( + cur_deform_grad.reshape(state.n_systems, 3, 3) * cell_factor_expanded + ) + + # Update cell positions + cell_wise_lr = cell_lr.view(state.n_systems, 1, 1) + cell_step = cell_wise_lr * state.cell_forces + cell_positions_new = current_cell_positions + cell_step + + # Update cell from new positions + cell_update = cell_positions_new / cell_factor_expanded + state.row_vector_cell = torch.bmm( + state.reference_cell.mT, cell_update.transpose(-2, -1) + ) + state.cell_positions = cell_positions_new + + +def frechet_cell_update[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) -> None: + """Update cell using frechet approach.""" + if isinstance(cell_lr, (int, float)): + cell_lr = torch.full( + (state.n_systems,), cell_lr, device=state.device, dtype=state.dtype + ) + cell_wise_lr = cell_lr.view(state.n_systems, 1, 1) + + # Compute cell step and update cell positions in log space + cell_step = cell_wise_lr * state.cell_forces + cell_positions_new = state.cell_positions + cell_step + + # Convert from log space to deformation gradient + cell_factor_reshaped = state.cell_factor.view(state.n_systems, 1, 1) + deform_grad_log_new = cell_positions_new / cell_factor_reshaped + deform_grad_new = torch.matrix_exp(deform_grad_log_new) + + # Update cell from new deformation gradient + new_row_vector_cell = torch.bmm( + state.reference_cell.mT, deform_grad_new.transpose(-2, -1) + ) + state.row_vector_cell = new_row_vector_cell + state.cell_positions = cell_positions_new + + +def compute_cell_forces[T: AnyCellState]( + model_output: dict[str, torch.Tensor], state: T +) -> None: + """Compute cell forces for both unit and frechet methods.""" + stress = model_output["stress"] + volumes = torch.linalg.det(state.cell).view(state.n_systems, 1, 1) + virial = -volumes * (stress + state.pressure) + virial = _apply_constraints( + virial, + hydrostatic_strain=state.hydrostatic_strain, + constant_volume=state.constant_volume, + ) + + # Check if this is Frechet method by examining the stored cell filter functions + cell_filter_funcs = getattr(state, "cell_filter", None) + is_frechet = ( + cell_filter_funcs is not None and cell_filter_funcs[0] is frechet_cell_filter_init + ) + + if is_frechet: + # Frechet cell force computation + cur_deform_grad = deform_grad(state.reference_cell.mT, state.row_vector_cell) + ucf_cell_grad = torch.bmm( + virial, torch.linalg.inv(torch.transpose(cur_deform_grad, 1, 2)) + ) + + # Calculate Frechet derivative for non-identity deformation gradients + device, dtype = virial.device, virial.dtype + n_systems = state.n_systems + + # Create direction matrices for Frechet derivative + directions = torch.zeros((9, 3, 3), device=device, dtype=dtype) + for idx, (mu, nu) in enumerate([(i, j) for i in range(3) for j in range(3)]): + directions[idx, mu, nu] = 1.0 + + # Compute deformation gradient log + deform_grad_log = torch.zeros_like(cur_deform_grad) + for sys_idx in range(n_systems): + deform_grad_log[sys_idx] = fm.matrix_log_33(cur_deform_grad[sys_idx]) + + # Compute Frechet derivatives + cell_forces = torch.zeros_like(ucf_cell_grad) + for sys_idx in range(n_systems): + expm_derivs = torch.stack( + [ + fm.expm_frechet(deform_grad_log[sys_idx], direction)[1] + for direction in directions + ] + ) + forces_flat = torch.sum( + expm_derivs * ucf_cell_grad[sys_idx].unsqueeze(0), dim=(1, 2) + ) + cell_forces[sys_idx] = forces_flat.reshape(3, 3) + + state.cell_forces = cell_forces / state.cell_factor + else: # Unit cell force computation + state.cell_forces = virial / state.cell_factor + + +CellFilterFuncs = tuple[Callable[..., None], Callable[..., None]] # (init_fn, update_fn) + +CELL_FILTER_REGISTRY: dict[CellFilter, CellFilterFuncs] = { + CellFilter.unit: (unit_cell_filter_init, unit_cell_update), + CellFilter.frechet: (frechet_cell_filter_init, frechet_cell_update), +} + + +def get_cell_filter(cell_filter: "CellFilter | tuple") -> CellFilterFuncs: + """Resolve cell filter into a tuple of init and update functions.""" + if isinstance(cell_filter, CellFilter): + return CELL_FILTER_REGISTRY[cell_filter] + if ( + isinstance(cell_filter, tuple) + and len(cell_filter) == 2 + and all(map(callable, cell_filter)) + ): + return cell_filter + raise ValueError( + f"Unknown {cell_filter=}, must be one of {list(map(str, CellFilter))} or " + "2-tuple of callables" + ) + + +@dataclass(kw_only=True) +class CellOptimState(OptimState): + """State class for cell optimization.""" + + reference_cell: torch.Tensor + cell_filter: CellFilterFuncs + cell_factor: torch.Tensor = field(default_factory=lambda: None) + pressure: torch.Tensor = field(default_factory=lambda: None) + hydrostatic_strain: bool = False + constant_volume: bool = False + cell_positions: torch.Tensor = field(default_factory=lambda: None) + cell_forces: torch.Tensor = field(default_factory=lambda: None) + cell_masses: torch.Tensor = field(default_factory=lambda: None) + + _system_attributes = OptimState._system_attributes | { # noqa: SLF001 + "cell_factor", + "pressure", + "cell_positions", + "cell_forces", + "cell_masses", + "reference_cell", + "cell_filter", + } + _global_attributes = OptimState._global_attributes | { # noqa: SLF001 + "hydrostatic_strain", + "constant_volume", + } + + +@dataclass(kw_only=True) +class CellFireState(CellOptimState, FireState): + """State class for FIRE optimization with cell optimization.""" + + cell_velocities: torch.Tensor = field(default_factory=lambda: None) + + _system_attributes = ( + CellOptimState._system_attributes # noqa: SLF001 + | FireState._system_attributes # noqa: SLF001 + | {"cell_velocities"} + ) + + +AnyCellState = CellFireState | CellOptimState diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py new file mode 100644 index 000000000..147324c36 --- /dev/null +++ b/torch_sim/optimizers/fire.py @@ -0,0 +1,473 @@ +"""FIRE (Fast Inertial Relaxation Engine) optimizer implementation.""" + +from typing import TYPE_CHECKING, Any, get_args + +import torch + +import torch_sim as ts +import torch_sim.math as fm +from torch_sim.optimizers import cell_filters +from torch_sim.state import SimState +from torch_sim.typing import StateDict + + +if TYPE_CHECKING: + from torch_sim.models.interface import ModelInterface + from torch_sim.optimizers import FireState, MdFlavor + from torch_sim.optimizers.cell_filters import ( + CellFilter, + CellFilterFuncs, + CellFireState, + ) + + +def fire_init( + model: "ModelInterface", + state: SimState | StateDict, + *, + dt_start: float = 0.1, + alpha_start: float = 0.1, + md_flavor: "MdFlavor" = "ase_fire", + cell_filter: "CellFilter | CellFilterFuncs | None" = None, + **filter_kwargs: Any, +) -> "FireState | CellFireState": + """Initialize a FIRE optimization state. + + Creates an optimizer that performs FIRE (Fast Inertial Relaxation Engine) + optimization on atomic positions and optionally cell parameters. + + Args: + model: Model that computes energies, forces, and optionally stress + state: Input state as SimState object or state parameter dict + dt_start: Initial timestep per system + alpha_start: Initial mixing parameter per system + md_flavor: Optimization flavor ("vv_fire" or "ase_fire") + cell_filter: Filter for cell optimization (None for position-only optimization) + **filter_kwargs: Additional arguments passed to cell filter initialization + + Returns: + FireState with initialized optimization tensors + + Notes: + - md_flavor="vv_fire" follows the original paper closely + - md_flavor="ase_fire" mimics the ASE implementation + - Use cell_filter=UNIT_CELL_FILTER or FRECHET_CELL_FILTER for cell optimization + """ + # Import here to avoid circular imports + from torch_sim.optimizers import CellFireState, FireState, MdFlavor + + if md_flavor not in get_args(MdFlavor): + raise ValueError(f"Unknown {md_flavor=}, must be one of {get_args(MdFlavor)}") + + tensor_args = dict(device=model.device, dtype=model.dtype) + + if not isinstance(state, SimState): + state = SimState(**state) + + n_systems = state.n_systems + + # Get initial forces and energy from model + model_output = model(state) + energy = model_output["energy"] + forces = model_output["forces"] + stress = model_output.get("stress") + + # Common state arguments + common_args = { + # Copy SimState attributes + "positions": state.positions.clone(), + "masses": state.masses.clone(), + "cell": state.cell.clone(), + "atomic_numbers": state.atomic_numbers.clone(), + "system_idx": state.system_idx.clone(), + "pbc": state.pbc, + # Optimization state + "forces": forces, + "energy": energy, + "stress": stress, + "velocities": torch.full(state.positions.shape, torch.nan, **tensor_args), + # FIRE parameters + "dt": torch.full((n_systems,), dt_start, **tensor_args), + "alpha": torch.full((n_systems,), alpha_start, **tensor_args), + "n_pos": torch.zeros((n_systems,), device=model.device, dtype=torch.int32), + } + + if cell_filter is not None: # Create cell optimization state + cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) + common_args["reference_cell"] = state.cell.clone() + common_args["cell_filter"] = cell_filter_funcs + cell_state = CellFireState(**common_args) + + # Initialize cell-specific attributes + init_fn(model, cell_state, **filter_kwargs) + + # Initialize cell velocities after cell_forces is set + cell_state.cell_velocities = torch.full( + cell_state.cell_forces.shape, torch.nan, **tensor_args + ) + + return cell_state + # Create regular FireState without cell optimization + return FireState(**common_args) + + +def fire_step( + model: "ModelInterface", + state: "FireState | CellFireState", + *, + dt_max: float = 1.0, + n_min: int = 5, + f_inc: float = 1.1, + f_dec: float = 0.5, + alpha_start: float = 0.1, + f_alpha: float = 0.99, + max_step: float = 0.2, + md_flavor: "MdFlavor" = "ase_fire", +) -> "FireState | CellFireState": + """Perform one FIRE optimization step. + + Args: + model: Model that computes energies, forces, and optionally stress + state: Current FIRE optimization state + dt_max: Maximum allowed timestep + n_min: Minimum steps before timestep increase + f_inc: Factor for timestep increase when power is positive + f_dec: Factor for timestep decrease when power is negative + alpha_start: Initial velocity mixing parameter + f_alpha: Factor for mixing parameter decrease + max_step: Maximum distance an atom can move per iteration + md_flavor: Optimization flavor ("vv_fire" or "ase_fire") + + Returns: + Updated FireState after one optimization step + """ + # Import here to avoid circular imports + from torch_sim.optimizers import MdFlavor, ase_fire_key, vv_fire_key + + if md_flavor not in get_args(MdFlavor): + raise ValueError(f"Unknown {md_flavor=}, must be one of {get_args(MdFlavor)}") + + device, dtype = model.device, model.dtype + eps = 1e-8 if dtype == torch.float32 else 1e-16 + + # Setup parameters + dt_max, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step = ( + torch.as_tensor(p, device=device, dtype=dtype) + for p in (dt_max, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step) + ) + + step_func_kwargs = dict( + model=model, + dt_max=dt_max, + n_min=n_min, + f_inc=f_inc, + f_dec=f_dec, + alpha_start=alpha_start, + f_alpha=f_alpha, + eps=eps, + ) + if md_flavor == ase_fire_key: + step_func_kwargs["max_step"] = max_step + + step_func = {vv_fire_key: _vv_fire_step, ase_fire_key: _ase_fire_step}[md_flavor] + return step_func(state, **step_func_kwargs) + + +def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 + state: T, + model: "ModelInterface", + *, + dt_max: torch.Tensor, + n_min: torch.Tensor, + f_inc: torch.Tensor, + f_dec: torch.Tensor, + alpha_start: torch.Tensor, + f_alpha: torch.Tensor, + eps: float, +) -> T: + """Perform one Velocity-Verlet based FIRE optimization step.""" + from torch_sim.optimizers import CellFireState + + n_systems, device, dtype = state.n_systems, state.device, state.dtype + + # Initialize velocities if NaN + nan_velocities = state.velocities.isnan().any(dim=1) + if nan_velocities.any(): + state.velocities[nan_velocities] = torch.zeros_like( + state.positions[nan_velocities] + ) + if isinstance(state, CellFireState): # update velocities to zero if NaN + nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) + state.cell_velocities[nan_cell_velocities] = torch.zeros_like( + state.cell_positions[nan_cell_velocities] + ) + + alpha_start_system = torch.full( + (n_systems,), alpha_start.item(), device=device, dtype=dtype + ) + + # First half of velocity update + atom_wise_dt = state.dt[state.system_idx].unsqueeze(-1) + state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) + + # Position update + state.positions = state.positions + atom_wise_dt * state.velocities + + # Cell position updates are handled in the velocity update step above + + # Get new forces and energy + model_output = model(state) + state.forces = model_output["forces"] + state.energy = model_output["energy"] + if "stress" in model_output: + state.stress = model_output["stress"] + + # Update cell forces + if isinstance(state, CellFireState): + cell_filters.compute_cell_forces(model_output, state) + + # Second half of velocity update + state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) + if isinstance(state, CellFireState): + cell_wise_dt = state.dt.view(n_systems, 1, 1) + state.cell_velocities += ( + 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) + ) + + # Calculate power + system_power = fm.batched_vdot(state.forces, state.velocities, state.system_idx) + if isinstance(state, CellFireState): + system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + + # Update dt, alpha, n_pos + pos_mask_system = system_power > 0.0 + neg_mask_system = ~pos_mask_system + + state.n_pos[pos_mask_system] += 1 + inc_mask = (state.n_pos > n_min) & pos_mask_system + state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) + state.alpha[inc_mask] *= f_alpha + + state.dt[neg_mask_system] *= f_dec + state.alpha[neg_mask_system] = alpha_start_system[neg_mask_system] + state.n_pos[neg_mask_system] = 0 + + # Velocity mixing + v_scaling_system = fm.batched_vdot( + state.velocities, state.velocities, state.system_idx + ) + f_scaling_system = fm.batched_vdot(state.forces, state.forces, state.system_idx) + + if isinstance(state, CellFireState): + v_scaling_system += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_system += state.cell_forces.pow(2).sum(dim=(1, 2)) + + v_scaling_cell = torch.sqrt(v_scaling_system.view(n_systems, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_system.view(n_systems, 1, 1)) + v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell + + alpha_cell_bc = state.alpha.view(n_systems, 1, 1) + state.cell_velocities = torch.where( + pos_mask_system.view(n_systems, 1, 1), + (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, + torch.zeros_like(state.cell_velocities), + ) + + v_scaling_atom = torch.sqrt(v_scaling_system[state.system_idx].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_system[state.system_idx].unsqueeze(-1)) + v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) + + alpha_atom = state.alpha[state.system_idx].unsqueeze(-1) + state.velocities = torch.where( + pos_mask_system[state.system_idx].unsqueeze(-1), + (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, + torch.zeros_like(state.velocities), + ) + + return state + + +def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 + state: T, + model: "ModelInterface", + *, + dt_max: torch.Tensor, + n_min: torch.Tensor, + f_inc: torch.Tensor, + f_dec: torch.Tensor, + alpha_start: torch.Tensor, + f_alpha: torch.Tensor, + max_step: torch.Tensor, + eps: float, +) -> T: + """Perform one ASE-style FIRE optimization step.""" + from torch_sim.optimizers import CellFireState + + n_systems, device, dtype = state.n_systems, state.device, state.dtype + + # Initialize velocities if NaN + nan_velocities = state.velocities.isnan().any(dim=1) + if nan_velocities.any(): + state.velocities[nan_velocities] = torch.zeros_like( + state.velocities[nan_velocities] + ) + forces = state.forces + if isinstance(state, CellFireState): + nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) + state.cell_velocities[nan_cell_velocities] = torch.zeros_like( + state.cell_velocities[nan_cell_velocities] + ) + else: + alpha_start_system = torch.full( + (n_systems,), alpha_start.item(), device=device, dtype=dtype + ) + + # Transform forces for cell optimization + if isinstance(state, CellFireState): + # Get deformation gradient for force transformation + cur_deform_grad = cell_filters.deform_grad( + state.row_vector_cell, + getattr(state, "reference_row_vector_cell", state.row_vector_cell), + ) + forces = torch.bmm( + state.forces.unsqueeze(1), cur_deform_grad[state.system_idx] + ).squeeze(1) + else: + forces = state.forces + + # Calculate power + system_power = fm.batched_vdot(forces, state.velocities, state.system_idx) + if isinstance(state, CellFireState): + system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + + # Update dt, alpha, n_pos + pos_mask_system = system_power > 0.0 + neg_mask_system = ~pos_mask_system + + inc_mask = (state.n_pos > n_min) & pos_mask_system + state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) + state.alpha[inc_mask] *= f_alpha + state.n_pos[pos_mask_system] += 1 + + state.dt[neg_mask_system] *= f_dec + state.alpha[neg_mask_system] = alpha_start_system[neg_mask_system] + state.n_pos[neg_mask_system] = 0 + + # Velocity mixing BEFORE acceleration (ASE ordering) + v_scaling_system = fm.batched_vdot( + state.velocities, state.velocities, state.system_idx + ) + f_scaling_system = fm.batched_vdot(forces, forces, state.system_idx) + + if isinstance(state, CellFireState): + v_scaling_system += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_system += state.cell_forces.pow(2).sum(dim=(1, 2)) + + v_scaling_cell = torch.sqrt(v_scaling_system.view(n_systems, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_system.view(n_systems, 1, 1)) + v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell + + alpha_cell_bc = state.alpha.view(n_systems, 1, 1) + state.cell_velocities = torch.where( + pos_mask_system.view(n_systems, 1, 1), + (1.0 - alpha_cell_bc) * state.cell_velocities + + alpha_cell_bc * v_mixing_cell, + torch.zeros_like(state.cell_velocities), + ) + + v_scaling_atom = torch.sqrt(v_scaling_system[state.system_idx].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_system[state.system_idx].unsqueeze(-1)) + v_mixing_atom = forces * (v_scaling_atom / (f_scaling_atom + eps)) + + alpha_atom = state.alpha[state.system_idx].unsqueeze(-1) + state.velocities = torch.where( + pos_mask_system[state.system_idx].unsqueeze(-1), + (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, + torch.zeros_like(state.velocities), + ) + + # Acceleration (single forward-Euler, no mass for ASE FIRE) + state.velocities += forces * state.dt[state.system_idx].unsqueeze(-1) + dr_atom = state.velocities * state.dt[state.system_idx].unsqueeze(-1) + dr_scaling_system = fm.batched_vdot(dr_atom, dr_atom, state.system_idx) + + if isinstance(state, CellFireState): + state.cell_velocities += state.cell_forces * state.dt.view(n_systems, 1, 1) + dr_cell = state.cell_velocities * state.dt.view(n_systems, 1, 1) + + dr_scaling_system += dr_cell.pow(2).sum(dim=(1, 2)) + dr_scaling_cell = torch.sqrt(dr_scaling_system).view(n_systems, 1, 1) + dr_cell = torch.where( + dr_scaling_cell > max_step, + max_step * dr_cell / (dr_scaling_cell + eps), + dr_cell, + ) + + dr_scaling_atom = torch.sqrt(dr_scaling_system)[state.system_idx].unsqueeze(-1) + dr_atom = torch.where( + dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom + ) + + # Position updates + if isinstance(state, CellFireState): + # For cell optimization, handle both atomic and cell position updates + # This follows the ASE FIRE implementation pattern + + # Transform atomic positions to fractional coordinates + cur_deform_grad = cell_filters.deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) + state.positions = ( + torch.linalg.solve( + cur_deform_grad[state.system_idx], state.positions.unsqueeze(-1) + ).squeeze(-1) + + dr_atom + ) + + # Update cell positions directly based on stored cell filter type + if hasattr(state, "cell_filter") and state.cell_filter is not None: + from torch_sim.optimizers.cell_filters import frechet_cell_filter_init + + init_fn, _step_fn = state.cell_filter + is_frechet = init_fn is frechet_cell_filter_init + + # Update cell positions + cell_positions_new = state.cell_positions + dr_cell + state.cell_positions = cell_positions_new + + if is_frechet: # Frechet: convert from log space to deformation gradient + cell_factor_reshaped = state.cell_factor.view(state.n_systems, 1, 1) + deform_grad_log_new = cell_positions_new / cell_factor_reshaped + deform_grad_new = torch.matrix_exp(deform_grad_log_new) + else: # Unit cell: positions are scaled deformation gradient + cell_factor_expanded = state.cell_factor.expand(state.n_systems, 3, 1) + deform_grad_new = cell_positions_new / cell_factor_expanded + + # Update cell from deformation gradient + state.row_vector_cell = torch.bmm( + state.reference_cell.mT, deform_grad_new.transpose(-2, -1) + ) + + # Transform positions back to Cartesian + new_deform_grad = cell_filters.deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) + state.positions = torch.bmm( + state.positions.unsqueeze(1), + new_deform_grad[state.system_idx].transpose(-2, -1), + ).squeeze(1) + else: + state.positions = state.positions + dr_atom + + # Get new forces, energy, and stress + model_output = model(state) + state.forces = model_output["forces"] + state.energy = model_output["energy"] + if "stress" in model_output: + state.stress = model_output["stress"] + + # Update cell forces + if isinstance(state, CellFireState): + cell_filters.compute_cell_forces(model_output, state) + + return state diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py new file mode 100644 index 000000000..8f3678045 --- /dev/null +++ b/torch_sim/optimizers/gradient_descent.py @@ -0,0 +1,129 @@ +"""Gradient descent optimizer implementation.""" + +from typing import TYPE_CHECKING, Any + +import torch + +import torch_sim as ts +from torch_sim.optimizers import cell_filters +from torch_sim.state import SimState +from torch_sim.typing import StateDict + + +if TYPE_CHECKING: + from torch_sim.models.interface import ModelInterface + from torch_sim.optimizers import CellOptimState, OptimState + from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs + + +def gradient_descent_init( + model: "ModelInterface", + state: SimState | StateDict, + *, + cell_filter: "CellFilter | CellFilterFuncs | None" = None, + **filter_kwargs: Any, +) -> "OptimState | CellOptimState": + """Initialize a gradient descent optimization state. + + Args: + model: Model that computes energies, forces, and optionally stress + state: SimState containing positions, masses, cell, etc. + cell_filter: Filter for cell optimization (None for position-only optimization) + **filter_kwargs: Additional arguments passed to cell filter initialization + + Returns: + Initialized OptimState with forces, energy, and optional cell state + + Notes: + Use cell_filter=None for position-only optimization. + Use cell_filter=UNIT_CELL_FILTER or FRECHET_CELL_FILTER for cell optimization. + """ + # Import here to avoid circular imports + from torch_sim.optimizers import CellOptimState, OptimState + + if not isinstance(state, SimState): + state = SimState(**state) + + # Get initial forces and energy from model + model_output = model(state) + energy = model_output["energy"] + forces = model_output["forces"] + stress = model_output.get("stress") + + # Common state arguments + common_args = { + "positions": state.positions, + "forces": forces, + "energy": energy, + "stress": stress, + "masses": state.masses, + "cell": state.cell, + "pbc": state.pbc, + "atomic_numbers": state.atomic_numbers, + "system_idx": state.system_idx, + } + + if cell_filter is not None: # Create cell optimization state + cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) + common_args["reference_cell"] = state.cell.clone() + common_args["cell_filter"] = cell_filter_funcs + cell_state = CellOptimState(**common_args) + + # Initialize cell-specific attributes + init_fn(model, cell_state, **filter_kwargs) + + return cell_state + # Create regular OptimState without cell optimization + return OptimState(**common_args) + + +def gradient_descent_step( + model: "ModelInterface", + state: "OptimState | CellOptimState", + *, + pos_lr: float | torch.Tensor = 0.01, + cell_lr: float | torch.Tensor = 0.1, +) -> "OptimState | CellOptimState": + """Perform one gradient descent optimization step. + + Updates atomic positions and optionally cell parameters based on the filter. + + Args: + model: Model that computes energies, forces, and optionally stress + state: Current optimization state + pos_lr: Learning rate(s) for atomic positions + cell_lr: Learning rate(s) for cell optimization (ignored if no cell filter) + + Returns: + Updated OptimState after one optimization step + """ + from torch_sim.optimizers import CellOptimState + + device, dtype = model.device, model.dtype + + # Get per-atom learning rates + if isinstance(pos_lr, (int, float)): + pos_lr = torch.full((state.n_systems,), pos_lr, device=device, dtype=dtype) + atom_lr = pos_lr[state.system_idx].unsqueeze(-1) + + # Update atomic positions + state.positions = state.positions + atom_lr * state.forces + + # Update cell if using cell optimization + if isinstance(state, CellOptimState): + # Compute cell step and update cell + _init_fn, step_fn = state.cell_filter + step_fn(state, cell_lr) + + # Get updated forces, energy, and stress + model_output = model(state) + state.forces = model_output["forces"] + state.energy = model_output["energy"] + if stress := model_output.get("stress"): + state.stress = stress + + # Update cell forces + if isinstance(state, CellOptimState): + cell_filters.compute_cell_forces(model_output, state) + + return state diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py new file mode 100644 index 000000000..6a6d37628 --- /dev/null +++ b/torch_sim/optimizers/state.py @@ -0,0 +1,48 @@ +"""Optimizer state classes.""" + +from dataclasses import dataclass +from typing import Literal, get_args + +import torch + +from torch_sim.state import SimState + + +MdFlavor = Literal["vv_fire", "ase_fire"] +vv_fire_key, ase_fire_key = get_args(MdFlavor) + + +@dataclass(kw_only=True) +class OptimState(SimState): + """Unified state class for optimization algorithms. + + This class extends SimState to store and track the evolution of system state + during optimization. It maintains the energies, forces, and optional cell + optimization state needed for structure relaxation. + """ + + forces: torch.Tensor + energy: torch.Tensor + stress: torch.Tensor + + _atom_attributes = SimState._atom_attributes | {"forces"} # noqa: SLF001 + _system_attributes = SimState._system_attributes | {"energy", "stress"} # noqa: SLF001 + + +@dataclass(kw_only=True) +class FireState(OptimState): + """State class for FIRE optimization. + + Extends OptimState with FIRE-specific parameters for velocity-based optimization. + """ + + velocities: torch.Tensor + dt: torch.Tensor + alpha: torch.Tensor + n_pos: torch.Tensor + + _atom_attributes = OptimState._atom_attributes | {"velocities"} # noqa: SLF001 + _system_attributes = OptimState._system_attributes | {"dt", "alpha", "n_pos"} # noqa: SLF001 + + +# there's no GradientDescentState, it's the same as OptimState diff --git a/torch_sim/properties/correlations.py b/torch_sim/properties/correlations.py index 2dda340d5..9f6869a06 100644 --- a/torch_sim/properties/correlations.py +++ b/torch_sim/properties/correlations.py @@ -16,7 +16,6 @@ .. [1] D. Frenkel and B. Smit, "Understanding molecular simulation: From algorithms to applications", Academic Press, 2002. .. [2] `pwtools: Phonon DOS `_ - """ from collections.abc import Callable @@ -24,6 +23,8 @@ import torch +from torch_sim.elastic import full_3x3_to_voigt_6_stress +from torch_sim.quantities import calc_heat_flux from torch_sim.state import SimState @@ -58,7 +59,7 @@ def append(self, value: torch.Tensor) -> None: """Append a new value to the buffer. Args: - value: New tensor to store + value (torch.Tensor): New tensor to store """ if self.buffer is None: # Initialize buffer shape as first value @@ -74,7 +75,7 @@ def get_array(self) -> torch.Tensor: """Get the current buffer contents as a tensor. Returns: - Tensor containing the buffered data in chron. order + torch.Tensor: Containing the buffered data in chronological order. """ if self.count == 0 or self.buffer is None: return torch.empty(0, device=self.device) @@ -226,7 +227,7 @@ def _compute_correlations(self) -> None: # noqa: C901, PLR0915 # Batch FFT processing n_fft = 2 * data_batch.shape[1] fft_batch = torch.fft.rfft(data_batch, n=n_fft) - power_batch = torch.abs(fft_batch) ** 2 + power_batch = torch.square(torch.abs(fft_batch)) corr_batch = torch.fft.irfft(power_batch)[:, : data_batch.shape[1]] corr_batch = corr_batch.T # Shape: [time_steps, n_dims] @@ -248,7 +249,7 @@ def _compute_correlations(self) -> None: # noqa: C901, PLR0915 # FFT: n=2*len for zero-padding n_fft = 2 * len(dim_data) fft = torch.fft.rfft(dim_data, n=n_fft) - power = torch.abs(fft) ** 2 + power = torch.square(torch.abs(fft)) corr = torch.fft.irfft(power)[: len(dim_data)] if self.normalize and corr[0] > 1e-10: @@ -264,7 +265,7 @@ def _compute_correlations(self) -> None: # noqa: C901, PLR0915 n_fft = 2 * len(dim_data) fft = torch.fft.rfft(dim_data, n=n_fft) - power = torch.abs(fft) ** 2 + power = torch.square(torch.abs(fft)) corr = torch.fft.irfft(power)[: len(dim_data)] if self.normalize and corr[0] > 1e-10: @@ -396,7 +397,7 @@ class VelocityAutoCorrelation: Using ``VelocityAutoCorrelation`` with - :class:`~torch_sim.trajectory.TrajectoryReporter`:: + :class:`~ts.trajectory.TrajectoryReporter`:: # Create VACF calculator vacf_calc = VelocityAutoCorrelation( @@ -473,3 +474,109 @@ def __call__(self, state: SimState, _: Any = None) -> torch.Tensor: def vacf(self) -> torch.Tensor | None: """Current VACF result.""" return self._avg + + +class HeatFluxAutoCorrelation: + """Calculator for heat flux autocorrelation function (HFACF). + + Computes HFACF by averaging over atoms and dimensions, with optional + running average across correlation windows. + + + Using ``HeatFluxAutoCorrelation`` with + :class:`~ts.trajectory.TrajectoryReporter`:: + + # Create HFACF calculator + hfacf_calc = HeatFluxAutoCorrelation( + window_size=100, + device=device, + use_running_average=True, + model=model, + ) + + # Set up trajectory reporter + reporter = TrajectoryReporter( + "simulation_hfacf.h5", + state_frequency=100, + prop_calculators={10: {"hfacf": hfacf_calc}}, + ) + + """ + + def __init__( + self, + *, + model: torch.nn.Module, + window_size: int, + device: torch.device, + use_running_average: bool = True, + normalize: bool = True, + ) -> None: + """Initialize HFACF calculator. + + Args: + window_size: Number of steps in correlation window + device: Computation device + use_running_average: Whether to compute running average across windows + normalize: Whether to normalize correlation functions to [0,1] + model: Model to use for calculating heat flux + """ + # TODO (AG): Figure out how to do it in a more efficient way + self.model = model + self.model.per_atom_stresses = True + self.model.per_atom_energies = True + + self.corr_calc = CorrelationCalculator( + window_size=window_size, + properties={ + "heat_flux": lambda s: calc_heat_flux( + momenta=s.momenta, + masses=s.masses, + velocities=None, + energies=self.model(s)["energies"], + stresses=full_3x3_to_voigt_6_stress(self.model(s)["stresses"]), + batch=s.batch, + is_centroid_stress=False, + is_virial_only=False, + ) + }, + device=device, + normalize=normalize, + ) + self.use_running_average = use_running_average + self._window_count = 0 + self._avg = torch.zeros(window_size, device=device) + + def __call__(self, state: SimState, _: Any = None) -> torch.Tensor: + """Update HFACF with new state. + + Args: + state: Current simulation state + _: Unused model argument (required property calculator interface) + + Returns: + Tensor containing average HFACF + """ + self.corr_calc.update(state) + + if self.corr_calc.buffers["heat_flux"].count == self.corr_calc.window_size: + correlations = self.corr_calc.get_auto_correlations() + # dims: (ndims, 1) + hfacf = torch.mean(correlations["heat_flux"], dim=(1, 2)) + + self._window_count += 1 + + if self.use_running_average: + factor = 1.0 / self._window_count + self._avg += (hfacf - self._avg) * factor + else: + self._avg = hfacf + + self.corr_calc.reset() + + return torch.tensor([self._window_count], device=state.device) + + @property + def hfacf(self) -> torch.Tensor: + """Current HFACF result.""" + return self._avg diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index a1ac0811e..bcb824c02 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -1,24 +1,15 @@ """Functions for computing physical quantities.""" -from typing import cast +from typing import TYPE_CHECKING import torch -from torch_sim.state import SimState from torch_sim.units import MetalUnits -# @torch.jit.script -def count_dof(tensor: torch.Tensor) -> int: - """Count the degrees of freedom in the system. - - Args: - tensor: Tensor to count the degrees of freedom in - - Returns: - Number of degrees of freedom - """ - return tensor.numel() +if TYPE_CHECKING: + from torch_sim.integrators.md import MDState + from torch_sim.optimizers import OptimState # @torch.jit.script @@ -44,17 +35,18 @@ def calc_kT( # noqa: N802 if not ((momenta is not None) ^ (velocities is not None)): raise ValueError("Must pass either one of momenta or velocities") - if momenta is None: + if momenta is None and velocities is not None: # If velocity provided, calculate mv^2 - velocities = cast("torch.Tensor", velocities) - squared_term = (velocities**2) * masses.unsqueeze(-1) - else: + squared_term = torch.square(velocities) * masses.unsqueeze(-1) + elif momenta is not None and velocities is None: # If momentum provided, calculate v^2 = p^2/m^2 - squared_term = (momenta**2) / masses.unsqueeze(-1) + squared_term = torch.square(momenta) / masses.unsqueeze(-1) + else: + raise ValueError("Must pass either one of momenta or velocities") if system_idx is None: # Count total degrees of freedom - dof = count_dof(squared_term) + dof = squared_term.numel() return torch.sum(squared_term) / dof # Sum squared terms for each system flattened_squared = torch.sum(squared_term, dim=-1) @@ -121,10 +113,12 @@ def calc_kinetic_energy( if not ((momenta is not None) ^ (velocities is not None)): raise ValueError("Must pass either one of momenta or velocities") - if momenta is None: # Using velocities - squared_term = (velocities**2) * masses.unsqueeze(-1) - else: # Using momenta - squared_term = (momenta**2) / masses.unsqueeze(-1) + if momenta is None and velocities is not None: # Using velocities + squared_term = torch.square(velocities) * masses.unsqueeze(-1) + elif momenta is not None and velocities is None: # Using momenta + squared_term = torch.square(momenta) / masses.unsqueeze(-1) + else: + raise ValueError("Must pass either one of momenta or velocities") if system_idx is None: return 0.5 * torch.sum(squared_term) @@ -135,7 +129,10 @@ def calc_kinetic_energy( def get_pressure( - stress: torch.Tensor, kinetic_energy: torch.Tensor, volume: torch.Tensor, dim: int = 3 + stress: torch.Tensor, + kinetic_energy: float | torch.Tensor, + volume: torch.Tensor, + dim: int = 3, ) -> torch.Tensor: """Compute the pressure from the stress tensor. @@ -145,7 +142,133 @@ def get_pressure( return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress)) -def systemwise_max_force(state: SimState) -> torch.Tensor: +def calc_heat_flux( + momenta: torch.Tensor | None, + masses: torch.Tensor, + velocities: torch.Tensor | None, + energies: torch.Tensor, + stresses: torch.Tensor, + batch: torch.Tensor | None = None, + *, # Force keyword arguments for booleans + is_centroid_stress: bool = False, + is_virial_only: bool = False, +) -> torch.Tensor: + r"""Calculate the heat flux vector. + + Computes the microscopic heat flux, :math:`\mathbf{J}` + defined as: + + .. math:: + \mathbf{J} = \mathbf{J}^c + \mathbf{J}^v + + where the convective part :math:`\mathbf{J}^c` and virial part + :math:`\mathbf{J}^v` are: + + .. math:: + \mathbf{J}^c &= \sum_i \epsilon_i \mathbf{v}_i \\ + \mathbf{J}^v &= \sum_i \sum_j \mathbf{S}_{ij} \cdot \mathbf{v}_j + + where :math:`\epsilon_i` is the per-atom energy (p.e. + k.e.), + :math:`\mathbf{v}_i` is velocity, and :math:`\mathbf{S}_{ij}` is the + per-atom stress tensor. + + Args: + momenta: Particle momenta, shape (n_particles, n_dim) + masses: Particle masses, shape (n_particles,) + velocities: Particle velocities, shape (n_particles, n_dim) + energies: Per-atom energies (p.e. + k.e.), shape (n_particles,) + stresses: Per-atom stress tensor components: + - If is_centroid_stress=False: shape (n_particles, 6) for + :math:`[\sigma_{xx}, \sigma_{yy}, \sigma_{zz}, + \sigma_{xy}, \sigma_{xz}, \sigma_{yz}]` + - If is_centroid_stress=True: shape (n_particles, 9) for + :math:`[\mathbf{r}_{ix}f_{ix}, \mathbf{r}_{iy}f_{iy}, + \mathbf{r}_{iz}f_{iz}, \mathbf{r}_{ix}f_{iy}, + \mathbf{r}_{ix}f_{iz}, \mathbf{r}_{iy}f_{iz}, + \mathbf{r}_{iy}f_{ix}, \mathbf{r}_{iz}f_{ix}, + \mathbf{r}_{iz}f_{iy}]` + batch: Optional tensor indicating system membership + is_centroid_stress: Whether stress uses centroid formulation + is_virial_only: If True, returns only virial part :math:`\mathbf{J}^v` + + Returns: + Heat flux vector of shape (3,) or (n_systems, 3) + """ + if momenta is not None and velocities is not None: + raise ValueError("Must pass either momenta or velocities, not both") + if momenta is None and velocities is None: + raise ValueError("Must pass either momenta or velocities") + + # Deduce velocities + if velocities is None: + velocities = momenta / masses.unsqueeze(-1) + + convective_flux = energies.unsqueeze(-1) * velocities + + # Calculate virial flux + if is_centroid_stress: + # Centroid formulation: r_i[x,y,z] . f_i[x,y,z] + virial_x = -( + stresses[:, 0] * velocities[:, 0] # r_ix.f_ix.v_x + + stresses[:, 3] * velocities[:, 1] # r_ix.f_iy.v_y + + stresses[:, 4] * velocities[:, 2] # r_ix.f_iz.v_z + ) + virial_y = -( + stresses[:, 6] * velocities[:, 0] # r_iy.f_ix.v_x + + stresses[:, 1] * velocities[:, 1] # r_iy.f_iy.v_y + + stresses[:, 5] * velocities[:, 2] # r_iy.f_iz.v_z + ) + virial_z = -( + stresses[:, 7] * velocities[:, 0] # r_iz.f_ix.v_x + + stresses[:, 8] * velocities[:, 1] # r_iz.f_iy.v_y + + stresses[:, 2] * velocities[:, 2] # r_iz.f_iz.v_z + ) + else: + # Standard stress tensor components + virial_x = -( + stresses[:, 0] * velocities[:, 0] # s_xx.v_x + + stresses[:, 3] * velocities[:, 1] # s_xy.v_y + + stresses[:, 4] * velocities[:, 2] # s_xz.v_z + ) + virial_y = -( + stresses[:, 3] * velocities[:, 0] # s_xy.v_x + + stresses[:, 1] * velocities[:, 1] # s_yy.v_y + + stresses[:, 5] * velocities[:, 2] # s_yz.v_z + ) + virial_z = -( + stresses[:, 4] * velocities[:, 0] # s_xz.v_x + + stresses[:, 5] * velocities[:, 1] # s_yz.v_y + + stresses[:, 2] * velocities[:, 2] # s_zz.v_z + ) + + virial_flux = torch.stack([virial_x, virial_y, virial_z], dim=-1) + + if batch is None: + # All atoms + virial_sum = torch.sum(virial_flux, dim=0) + if is_virial_only: + return virial_sum + conv_sum = torch.sum(convective_flux, dim=0) + return conv_sum + virial_sum + + # All atoms in each system + n_systems = int(torch.max(batch) + 1) + virial_sum = torch.zeros( + (n_systems, 3), device=velocities.device, dtype=velocities.dtype + ) + virial_sum.scatter_add_(0, batch.unsqueeze(-1).expand(-1, 3), virial_flux) + + if is_virial_only: + return virial_sum + + conv_sum = torch.zeros( + (n_systems, 3), device=velocities.device, dtype=velocities.dtype + ) + conv_sum.scatter_add_(0, batch.unsqueeze(-1).expand(-1, 3), convective_flux) + return conv_sum + virial_sum + + +def system_wise_max_force[T: MDState | OptimState](state: T) -> torch.Tensor: """Compute the maximum force per system. Args: diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 14164375d..3eb8a0aab 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -9,46 +9,45 @@ from collections.abc import Callable from dataclasses import dataclass from itertools import chain -from typing import Any +from typing import TYPE_CHECKING, Any import torch from tqdm import tqdm +import torch_sim as ts from torch_sim.autobatching import BinningAutoBatcher, InFlightAutoBatcher +from torch_sim.integrators import INTEGRATOR_REGISTRY, MdFlavor +from torch_sim.integrators.md import MDState from torch_sim.models.interface import ModelInterface -from torch_sim.optimizers import ( - FireState, - FrechetCellFIREState, - UnitCellFireState, - UnitCellGDState, -) -from torch_sim.quantities import calc_kinetic_energy, calc_kT, systemwise_max_force -from torch_sim.state import SimState, concatenate_states, initialize_state +from torch_sim.optimizers import OPTIM_REGISTRY, FireState, OptimFlavor, OptimState +from torch_sim.state import SimState from torch_sim.trajectory import TrajectoryReporter from torch_sim.typing import StateLike from torch_sim.units import UnitSystem +if TYPE_CHECKING: + from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs + + def _configure_reporter( - trajectory_reporter: TrajectoryReporter | dict | None, + trajectory_reporter: TrajectoryReporter | dict, *, state_kwargs: dict | None = None, properties: list[str] | None = None, prop_frequency: int = 10, state_frequency: int = 100, ) -> TrajectoryReporter: - if trajectory_reporter is None: - return None if isinstance(trajectory_reporter, TrajectoryReporter): return trajectory_reporter possible_properties = { "potential_energy": lambda state: state.energy, "forces": lambda state: state.forces, "stress": lambda state: state.stress, - "kinetic_energy": lambda state: calc_kinetic_energy( + "kinetic_energy": lambda state: ts.calc_kinetic_energy( velocities=state.velocities, masses=state.masses ), - "temperature": lambda state: calc_kT( + "temperature": lambda state: ts.calc_kT( velocities=state.velocities, masses=state.masses ), } @@ -56,7 +55,7 @@ def _configure_reporter( prop_calculators = { prop: calculator for prop, calculator in possible_properties.items() - if prop in properties + if prop in (properties or ()) } # ordering is important to ensure we can override defaults @@ -73,6 +72,7 @@ def _configure_reporter( def _configure_batches_iterator( model: ModelInterface, state: SimState, + *, autobatcher: BinningAutoBatcher | bool, ) -> BinningAutoBatcher | list[tuple[SimState, list[int]]]: """Create a batches iterator for the integrate function. @@ -89,49 +89,47 @@ def _configure_batches_iterator( if autobatcher is True: autobatcher = BinningAutoBatcher( model=model, - return_indices=True, max_memory_padding=0.9, ) autobatcher.load_states(state) batches = autobatcher elif isinstance(autobatcher, BinningAutoBatcher): autobatcher.load_states(state) - autobatcher.return_indices = True batches = autobatcher elif autobatcher is False: batches = [(state, [])] else: + autobatcher_type = type(autobatcher).__name__ raise TypeError( - f"Invalid autobatcher type: {type(autobatcher).__name__}, " - "must be bool or BinningAutoBatcher." + f"Invalid {autobatcher_type=}, must be bool or BinningAutoBatcher." ) return batches -def integrate( +def integrate[T: SimState]( # noqa: C901 system: StateLike, model: ModelInterface, *, - integrator: Callable, + integrator: MdFlavor | tuple[Callable[..., T], Callable[..., T]], n_steps: int, temperature: float | list | torch.Tensor, timestep: float, trajectory_reporter: TrajectoryReporter | dict | None = None, autobatcher: BinningAutoBatcher | bool = False, pbar: bool | dict[str, Any] = False, - **integrator_kwargs: dict, -) -> SimState: + **integrator_kwargs: Any, +) -> T: """Simulate a system using a model and integrator. Args: system (StateLike): Input system to simulate model (ModelInterface): Neural network model module - integrator (Callable): Integration algorithm function + integrator (MdFlavor | tuple): Either a key from MdFlavor or a tuple of + (init_func, step_func) functions. n_steps (int): Number of integration steps temperature (float | ArrayLike): Temperature or array of temperatures for each step timestep (float): Integration time step - integrator_kwargs: Additional keyword arguments for integrator trajectory_reporter (TrajectoryReporter | dict | None): Optional reporter for tracking trajectory. If a dict, will be passed to the TrajectoryReporter constructor. @@ -142,33 +140,49 @@ def integrate( **integrator_kwargs: Additional keyword arguments for integrator init function Returns: - SimState: Final state after integration + T: Final state after integration """ unit_system = UnitSystem.metal # create a list of temperatures - temps = temperature if hasattr(temperature, "__iter__") else [temperature] * n_steps + temps = ( + [temperature] * n_steps + if isinstance(temperature, (float, int)) + else list(temperature) + ) if len(temps) != n_steps: raise ValueError(f"{len(temps)=:,}. It must equal n_steps = {n_steps=:,}") - # initialize the state - state: SimState = initialize_state(system, model.device, model.dtype) - dtype, device = state.dtype, state.device + initial_state: SimState = ts.initialize_state(system, model.device, model.dtype) + dtype, device = initial_state.dtype, initial_state.device kTs = torch.tensor(temps, dtype=dtype, device=device) * unit_system.temperature - init_fn, update_fn = integrator( - model=model, - kT=kTs[0], - dt=torch.tensor(timestep * unit_system.time, dtype=dtype, device=device), - **integrator_kwargs, - ) + dt = torch.tensor(timestep * unit_system.time, dtype=dtype, device=device) + + # Handle both string names and direct function tuples + if isinstance(integrator, MdFlavor): + init_func, step_func = INTEGRATOR_REGISTRY[integrator] + elif ( + isinstance(integrator, tuple) + and len(integrator) == 2 + and {*map(callable, integrator)} == {True} + ): + init_func, step_func = integrator + else: + raise ValueError( + f"integrator must be key from MdFlavor or a tuple of " + f"(init_func, step_func), got {type(integrator)}" + ) # batch_iterator will be a list if autobatcher is False - batch_iterator = _configure_batches_iterator(model, state, autobatcher) - trajectory_reporter = _configure_reporter( - trajectory_reporter, - properties=["kinetic_energy", "potential_energy", "temperature"], + batch_iterator = _configure_batches_iterator( + model, initial_state, autobatcher=autobatcher ) + if trajectory_reporter is not None: + trajectory_reporter = _configure_reporter( + trajectory_reporter, + properties=["kinetic_energy", "potential_energy", "temperature"], + ) - final_states: list[SimState] = [] + final_states: list[T] = [] og_filenames = trajectory_reporter.filenames if trajectory_reporter else None tqdm_pbar = None @@ -176,13 +190,15 @@ def integrate( pbar_kwargs = pbar if isinstance(pbar, dict) else {} pbar_kwargs.setdefault("desc", "Integrate") pbar_kwargs.setdefault("disable", None) - tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) + tqdm_pbar = tqdm(total=initial_state.n_systems, **pbar_kwargs) + # Handle both BinningAutoBatcher and list of tuples for state, system_indices in batch_iterator: - state = init_fn(state) + # Pass correct parameters based on integrator type + state = init_func(model=model, state=state, kT=kTs[0], dt=dt, **integrator_kwargs) # set up trajectory reporters - if autobatcher and trajectory_reporter: + if autobatcher and trajectory_reporter is not None and og_filenames is not None: # we must remake the trajectory reporter for each system trajectory_reporter.load_new_trajectories( filenames=[og_filenames[i] for i in system_indices] @@ -190,7 +206,7 @@ def integrate( # run the simulation for step in range(1, n_steps + 1): - state = update_fn(state, kT=kTs[step - 1]) + state = step_func(model, state, dt=dt, kT=kTs[step - 1], **integrator_kwargs) if trajectory_reporter: trajectory_reporter.report(state, step, model=model) @@ -205,7 +221,7 @@ def integrate( if isinstance(batch_iterator, BinningAutoBatcher): reordered_states = batch_iterator.restore_original_order(final_states) - return concatenate_states(reordered_states) + return ts.concatenate_states(reordered_states) return state @@ -213,6 +229,7 @@ def integrate( def _configure_in_flight_autobatcher( model: ModelInterface, state: SimState, + *, autobatcher: InFlightAutoBatcher | bool, max_attempts: int, # TODO: change name to max_iterations ) -> InFlightAutoBatcher: @@ -230,7 +247,6 @@ def _configure_in_flight_autobatcher( """ # load and properly configure the autobatcher if isinstance(autobatcher, InFlightAutoBatcher): - autobatcher.return_indices = True autobatcher.max_attempts = max_attempts elif isinstance(autobatcher, bool): if autobatcher: @@ -241,7 +257,6 @@ def _configure_in_flight_autobatcher( max_memory_scaler = state.n_atoms + 1 autobatcher = InFlightAutoBatcher( model=model, - return_indices=True, max_memory_scaler=max_memory_scaler, memory_scales_with=memory_scales_with, max_iterations=max_attempts, @@ -254,41 +269,41 @@ def _configure_in_flight_autobatcher( return autobatcher -def _chunked_apply( - fn: Callable, +def _chunked_apply[T: SimState]( + fn: Callable[..., T], states: SimState, model: ModelInterface, - **batcher_kwargs: dict, -) -> SimState: + init_kwargs: Any, + **batcher_kwargs: Any, +) -> T: """Apply a function to a state in chunks. This prevents us from running out of memory when applying a function to a large number of states. Args: - fn (Callable): The function to apply - states (SimState): The state to apply the function to + fn (Callable): The state function to apply + states (SimState): The states to apply the function to model (ModelInterface): The model to use for the autobatcher + init_kwargs (Any): Unpacked into state init function. **batcher_kwargs: Additional keyword arguments for the autobatcher Returns: A state with the function applied """ - autobatcher = BinningAutoBatcher( - model=model, - return_indices=False, - **batcher_kwargs, - ) + autobatcher = BinningAutoBatcher(model=model, **batcher_kwargs) autobatcher.load_states(states) initialized_states = [] - initialized_states = [fn(system) for system in autobatcher] + initialized_states = [ + fn(model=model, state=system, **init_kwargs) for system, _indices in autobatcher + ] ordered_states = autobatcher.restore_original_order(initialized_states) - return concatenate_states(ordered_states) + return ts.concatenate_states(ordered_states) -def generate_force_convergence_fn( +def generate_force_convergence_fn[T: MDState | FireState]( force_tol: float = 1e-1, *, include_cell_forces: bool = True ) -> Callable: """Generate a force-based convergence function for the convergence_fn argument @@ -305,7 +320,7 @@ def generate_force_convergence_fn( """ def convergence_fn( - state: SimState, + state: T, last_energy: torch.Tensor | None = None, # noqa: ARG001 ) -> torch.Tensor: """Check if the system has converged. @@ -314,7 +329,7 @@ def convergence_fn( torch.Tensor: Boolean tensor of shape (n_systems,) indicating convergence status for each system. """ - force_conv = systemwise_max_force(state) < force_tol + force_conv = ts.system_wise_max_force(state) < force_tol if include_cell_forces: if (cell_forces := getattr(state, "cell_forces", None)) is None: @@ -328,7 +343,9 @@ def convergence_fn( return convergence_fn -def generate_energy_convergence_fn(energy_tol: float = 1e-3) -> Callable: +def generate_energy_convergence_fn[T: MDState | OptimState]( + energy_tol: float = 1e-3, +) -> Callable[[T, torch.Tensor | None], torch.Tensor]: """Generate an energy-based convergence function for the convergence_fn argument of the optimize function. @@ -336,14 +353,11 @@ def generate_energy_convergence_fn(energy_tol: float = 1e-3) -> Callable: energy_tol (float): Energy tolerance for convergence Returns: - Convergence function that takes a state and last energy and - returns a systemwise boolean function + Callable[[T, torch.Tensor | None], torch.Tensor]: Convergence function that takes + a state and last energy and returns a systemwise boolean function. """ - def convergence_fn( - state: SimState, - last_energy: torch.Tensor | None = None, - ) -> torch.Tensor: + def convergence_fn(state: T, last_energy: torch.Tensor | None = None) -> torch.Tensor: """Check if the system has converged. Returns: @@ -355,29 +369,31 @@ def convergence_fn( return convergence_fn -def optimize( # noqa: C901 +def optimize[T: OptimState]( # noqa: C901, PLR0915 system: StateLike, model: ModelInterface, *, - optimizer: Callable, - convergence_fn: Callable | None = None, + optimizer: OptimFlavor | tuple[Callable[..., T], Callable[..., T]], + cell_filter: "CellFilter | CellFilterFuncs | None" = None, + convergence_fn: Callable[[T, torch.Tensor | None], torch.Tensor] | None = None, trajectory_reporter: TrajectoryReporter | dict | None = None, autobatcher: InFlightAutoBatcher | bool = False, max_steps: int = 10_000, steps_between_swaps: int = 5, pbar: bool | dict[str, Any] = False, - **optimizer_kwargs: dict, -) -> SimState: + **optimizer_kwargs: Any, +) -> T: """Optimize a system using a model and optimizer. Args: system (StateLike): Input system to optimize (ASE Atoms, Pymatgen Structure, or SimState) model (ModelInterface): Neural network model module - optimizer (Callable): Optimization algorithm function + optimizer (OptimFlavor | tuple): Optimization algorithm function convergence_fn (Callable | None): Condition for convergence, should return a boolean tensor of length n_systems - optimizer_kwargs: Additional keyword arguments for optimizer init function + cell_filter (CellFilter | CellFilterFuncs | None): Optional cell filter to use. + If None, the system will not optimize the cell. trajectory_reporter (TrajectoryReporter | dict | None): Optional reporter for tracking optimization trajectory. If a dict, will be passed to the TrajectoryReporter constructor. @@ -394,43 +410,58 @@ def optimize( # noqa: C901 pbar (bool | dict[str, Any], optional): Show a progress bar. Only works with an autobatcher in interactive shell. If a dict is given, it's passed to `tqdm` as kwargs. + optimizer_kwargs: Additional keyword arguments for optimizer init function Returns: - Optimized system state + T: Optimized system state """ # create a default convergence function if one is not provided # TODO: document this behavior if convergence_fn is None: convergence_fn = generate_energy_convergence_fn(energy_tol=1e-3) - # initialize the state - state: SimState = initialize_state(system, model.device, model.dtype) - init_fn, update_fn = optimizer(model=model, **optimizer_kwargs) + initial_state = ts.initialize_state(system, model.device, model.dtype) + if isinstance(optimizer, OptimFlavor): + init_fn, step_fn = OPTIM_REGISTRY[optimizer] + elif ( + isinstance(optimizer, tuple) + and len(optimizer) == 2 + and {*map(callable, optimizer)} == {True} + ): + init_fn, step_fn = optimizer + else: + optimizer_type = type(optimizer).__name__ + raise TypeError( + f"Invalid {optimizer_type=}, must be key from OptimFlavor or a tuple of " + f"(init_func, step_func), got {optimizer_type}" + ) max_attempts = max_steps // steps_between_swaps autobatcher = _configure_in_flight_autobatcher( - model, state, autobatcher, max_attempts + model, initial_state, autobatcher=autobatcher, max_attempts=max_attempts ) - if not isinstance( - state, (FireState, UnitCellFireState, UnitCellGDState, FrechetCellFIREState) - ): + if isinstance(initial_state, OptimState): + state = initial_state + else: state = _chunked_apply( init_fn, - state, + initial_state, model, + init_kwargs=dict(cell_filter=cell_filter), max_memory_scaler=autobatcher.max_memory_scaler, memory_scales_with=autobatcher.memory_scales_with, ) autobatcher.load_states(state) - trajectory_reporter = _configure_reporter( - trajectory_reporter, - properties=["potential_energy"], - ) + if trajectory_reporter is not None: + trajectory_reporter = _configure_reporter( + trajectory_reporter, properties=["potential_energy"] + ) step: int = 1 last_energy = None - all_converged_states, convergence_tensor = [], None + all_converged_states: list[T] = [] + convergence_tensor = None og_filenames = trajectory_reporter.filenames if trajectory_reporter else None tqdm_pbar = None @@ -438,22 +469,32 @@ def optimize( # noqa: C901 pbar_kwargs = pbar if isinstance(pbar, dict) else {} pbar_kwargs.setdefault("desc", "Optimize") pbar_kwargs.setdefault("disable", None) - tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) - - while (result := autobatcher.next_batch(state, convergence_tensor))[0] is not None: - state, converged_states, system_indices = result + tqdm_pbar = tqdm(total=initial_state.n_systems, **pbar_kwargs) + + while True: + result = autobatcher.next_batch(state, convergence_tensor) + if result[0] is None: + # All states have converged, collect the final converged states + all_converged_states.extend(result[1]) + break + state, converged_states = result all_converged_states.extend(converged_states) # need to update the trajectory reporter if any states have converged - if trajectory_reporter and (step == 1 or len(converged_states) > 0): + if ( + trajectory_reporter is not None + and og_filenames is not None + and (step == 1 or len(converged_states) > 0) + ): trajectory_reporter.load_new_trajectories( - filenames=[og_filenames[i] for i in system_indices] + filenames=[og_filenames[i] for i in autobatcher.current_idx] ) for _step in range(steps_between_swaps): - last_energy = state.energy + if hasattr(state, "energy"): + last_energy = state.energy - state = update_fn(state) + state = step_fn(model=model, state=state, **optimizer_kwargs) if trajectory_reporter: trajectory_reporter.report(state, step, model=model) @@ -468,16 +509,14 @@ def optimize( # noqa: C901 # assume convergence_tensor shape is correct tqdm_pbar.update(torch.count_nonzero(convergence_tensor).item()) - all_converged_states.extend(result[1]) - if trajectory_reporter: trajectory_reporter.finish() if autobatcher: final_states = autobatcher.restore_original_order(all_converged_states) - return concatenate_states(final_states) + return ts.concatenate_states(final_states) - return state + return state # type: ignore[return-value] def static( @@ -516,10 +555,9 @@ def static( Returns: list[dict[str, torch.Tensor]]: Maps of property names to tensors for all batches """ - # initialize the state - state: SimState = initialize_state(system, model.device, model.dtype) + state: SimState = ts.initialize_state(system, model.device, model.dtype) - batch_iterator = _configure_batches_iterator(model, state, autobatcher) + batch_iterator = _configure_batches_iterator(model, state, autobatcher=autobatcher) properties = ["potential_energy"] if model.compute_forces: properties.append("forces") @@ -536,7 +574,7 @@ def static( ) @dataclass - class StaticState(type(state)): + class StaticState(SimState): energy: torch.Tensor forces: torch.Tensor stress: torch.Tensor @@ -558,6 +596,7 @@ class StaticState(type(state)): pbar_kwargs.setdefault("disable", None) tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) + # Handle both BinningAutoBatcher and list of tuples for sub_state, system_indices in batch_iterator: # set up trajectory reporters if autobatcher and trajectory_reporter and og_filenames is not None: diff --git a/torch_sim/state.py b/torch_sim/state.py index a36312a90..a04fa5d37 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -7,16 +7,15 @@ import copy import importlib import typing -import warnings from collections import defaultdict -from collections.abc import Generator +from collections.abc import Generator, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self import torch import torch_sim as ts -from torch_sim.typing import SimStateVar, StateLike +from torch_sim.typing import StateLike if TYPE_CHECKING: @@ -99,7 +98,7 @@ def __init__( positions: torch.Tensor, masses: torch.Tensor, cell: torch.Tensor, - pbc: bool, # noqa: FBT001 # TODO(curtis): maybe make the constructor be keyword-only (it can be easy to confuse positions vs masses, etc.) + pbc: bool, # noqa: FBT001 atomic_numbers: torch.Tensor, system_idx: torch.Tensor | None = None, ) -> None: @@ -148,10 +147,7 @@ def __init__( self.system_idx = torch.zeros( self.n_atoms, device=self.device, dtype=torch.int64 ) - else: - # assert that system indices are unique consecutive integers - # TODO(curtis): I feel like this logic is not reliable. - # I'll come up with something better later. + else: # assert that system indices are unique consecutive integers _, counts = torch.unique_consecutive(system_idx, return_counts=True) if not torch.all(counts == torch.bincount(system_idx)): raise ValueError("System indices must be unique consecutive integers") @@ -200,62 +196,6 @@ def n_atoms_per_system(self) -> torch.Tensor: else torch.tensor([self.n_atoms], device=self.device) ) - @property - def n_atoms_per_batch(self) -> torch.Tensor: - """Number of atoms per batch. - - deprecated:: - Use :attr:`n_atoms_per_system` instead. - """ - warnings.warn( - "n_atoms_per_batch is deprecated, use n_atoms_per_system instead", - DeprecationWarning, - stacklevel=2, - ) - return self.n_atoms_per_system - - @property - def batch(self) -> torch.Tensor: - """System indices. - - deprecated:: - Use :attr:`system_idx` instead. - """ - warnings.warn( - "batch is deprecated, use system_idx instead", - DeprecationWarning, - stacklevel=2, - ) - return self.system_idx - - @batch.setter - def batch(self, system_idx: torch.Tensor) -> None: - """Set the system indices from a batch index. - - deprecated:: - Use :attr:`system_idx` instead. - """ - warnings.warn( - "Setting batch is deprecated, use system_idx instead", - DeprecationWarning, - stacklevel=2, - ) - self.system_idx = system_idx - - @property - def n_batches(self) -> int: - """Number of batches in the system. - - deprecated:: - Use :attr:`n_systems` instead. - """ - warnings.warn( - "n_batches is deprecated, use n_systems instead", - DeprecationWarning, - stacklevel=2, - ) - return self.n_systems - @property def n_systems(self) -> int: """Number of systems in the system.""" @@ -264,8 +204,6 @@ def n_systems(self) -> int: @property def volume(self) -> torch.Tensor: """Volume of the system.""" - if not self.pbc: - raise ValueError("Volume is only defined for periodic systems") return torch.det(self.cell) @property @@ -312,7 +250,44 @@ def clone(self) -> Self: else: attrs[attr_name] = copy.deepcopy(attr_value) - return self.__class__(**attrs) + return type(self)(**attrs) + + @classmethod + def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: + """Create a new state from an existing state with additional attributes. + + This method copies all attributes from the source state and adds any additional + attributes needed for the target state class. It's useful for converting between + different state types (e.g., SimState to MDState). + + Args: + state: Source state to copy base attributes from + **additional_attrs: Additional attributes required by the target state class + + Returns: + New state of the target class with copied and additional attributes + + Example: + >>> from torch_sim.integrators.md import MDState + >>> md_state = MDState.from_state( + ... sim_state, + ... energy=model_output["energy"], + ... forces=model_output["forces"], + ... momenta=torch.zeros_like(sim_state.positions), + ... ) + """ + # Copy all attributes from the source state + attrs = {} + for attr_name, attr_value in vars(state).items(): + if isinstance(attr_value, torch.Tensor): + attrs[attr_name] = attr_value.clone() + else: + attrs[attr_name] = copy.deepcopy(attr_value) + + # Add/override with additional attributes + attrs.update(additional_attrs) + + return cls(**attrs) def to_atoms(self) -> list["Atoms"]: """Convert the SimState to a list of ASE Atoms objects. @@ -376,7 +351,7 @@ def pop(self, system_indices: int | list[int] | slice | torch.Tensor) -> list[Se for attr_name, attr_value in vars(modified_state).items(): setattr(self, attr_name, attr_value) - return cast("list[Self]", popped_states) + return popped_states def to( self, device: torch.device | None = None, dtype: torch.dtype | None = None @@ -415,7 +390,6 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> def __init_subclass__(cls, **kwargs) -> None: """Enforce that all derived states cannot have tensor attributes that can also be None. This is because torch.concatenate cannot concat between a tensor and a None. - See https://github.com/TorchSim/torch-sim/pull/219 for more details. Also enforce all of child classes's attributes are specified in _atom_attributes, _system_attributes, or _global_attributes. @@ -428,8 +402,8 @@ def __init_subclass__(cls, **kwargs) -> None: def _assert_no_tensor_attributes_can_be_none(cls) -> None: # We need to use get_type_hints to correctly inspect the types type_hints = typing.get_type_hints(cls) - for attr_name, attr_typehint in type_hints.items(): - origin = typing.get_origin(attr_typehint) + for attr_name, attr_type_hint in type_hints.items(): + origin = typing.get_origin(attr_type_hint) is_union = origin is typing.Union if not is_union and origin is not None: @@ -437,7 +411,7 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None: # We check by name to be robust against module reloading/patching issues is_union = origin.__module__ == "types" and origin.__name__ == "UnionType" if is_union: - args = typing.get_args(attr_typehint) + args = typing.get_args(attr_type_hint) if torch.Tensor in args and type(None) in args: raise TypeError( f"Attribute '{attr_name}' in class '{cls.__name__}' is not " @@ -465,11 +439,11 @@ def _assert_all_attributes_have_defined_scope(cls) -> None: # 2) assert that all attributes are defined in all_defined_attributes all_annotations = {} - for c in cls.mro(): - if hasattr(c, "__annotations__"): - all_annotations.update(c.__annotations__) + for parent_cls in cls.mro(): + if hasattr(parent_cls, "__annotations__"): + all_annotations.update(parent_cls.__annotations__) - attributes_to_check = set(vars(cls).keys()) | set(all_annotations.keys()) + attributes_to_check = set(vars(cls)) | set(all_annotations) for attr_name in attributes_to_check: is_special_attribute = attr_name.startswith("__") @@ -539,7 +513,7 @@ def deform_grad(self) -> torch.Tensor: def _normalize_system_indices( - system_indices: int | list[int] | slice | torch.Tensor, + system_indices: int | Sequence[int] | slice | torch.Tensor, n_systems: int, device: torch.device, ) -> torch.Tensor: @@ -578,11 +552,9 @@ def _normalize_system_indices( raise TypeError(f"Unsupported index type: {type(system_indices)}") -def state_to_device( - state: SimStateVar, - device: torch.device | None = None, - dtype: torch.dtype | None = None, -) -> SimStateVar: +def state_to_device[T: SimState]( + state: T, device: torch.device | None = None, dtype: torch.dtype | None = None +) -> T: """Convert the SimState to a new device and dtype. Creates a new SimState with all tensors moved to the specified device and @@ -611,7 +583,7 @@ def state_to_device( attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) - return type(state)(**attrs) + return type(state)(**attrs) # type: ignore[invalid-return-type] def get_attrs_for_scope( @@ -666,7 +638,7 @@ def _filter_attrs_by_mask( for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): if attr_name == "system_idx": # Get the old system indices for the selected atoms - old_system_idxs = attr_value[atom_mask] + old_system_indices = attr_value[atom_mask] # Get the system indices that are kept kept_indices = torch.arange(attr_value.max() + 1, device=attr_value.device)[ @@ -678,7 +650,7 @@ def _filter_attrs_by_mask( # Create new system tensor with remapped indices new_system_idxs = torch.tensor( - [system_idx_map[b.item()] for b in old_system_idxs], + [system_idx_map[b.item()] for b in old_system_indices], device=attr_value.device, dtype=attr_value.dtype, ) @@ -688,14 +660,15 @@ def _filter_attrs_by_mask( # Filter per-system attributes for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): - filtered_attrs[attr_name] = attr_value[system_mask] + if isinstance(attr_value, torch.Tensor): + filtered_attrs[attr_name] = attr_value[system_mask] + else: # Non-tensor attributes (e.g. cell filter) are copied as-is + filtered_attrs[attr_name] = attr_value return filtered_attrs -def _split_state( - state: SimStateVar, -) -> list[SimStateVar]: +def _split_state[T: SimState](state: T) -> list[T]: """Split a SimState into a list of states, each containing a single system. Divides a multi-system state into individual single-system states, preserving @@ -717,38 +690,43 @@ def _split_state( split_per_system = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): - split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) + if isinstance(attr_value, torch.Tensor): + split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) + else: # Non-tensor attributes are replicated for each split + split_per_system[attr_name] = [attr_value] * state.n_systems global_attrs = dict(get_attrs_for_scope(state, "global")) # Create a state for each system - states = [] + states: list[T] = [] n_systems = len(system_sizes) - for i in range(n_systems): + for sys_idx in range(n_systems): system_attrs = { # Create a system tensor with all zeros for this system "system_idx": torch.zeros( - system_sizes[i], device=state.device, dtype=torch.int64 + system_sizes[sys_idx], device=state.device, dtype=torch.int64 ), # Add the split per-atom attributes - **{attr_name: split_per_atom[attr_name][i] for attr_name in split_per_atom}, + **{ + attr_name: split_per_atom[attr_name][sys_idx] + for attr_name in split_per_atom + }, # Add the split per-system attributes **{ - attr_name: split_per_system[attr_name][i] + attr_name: split_per_system[attr_name][sys_idx] for attr_name in split_per_system }, # Add the global attributes **global_attrs, } - states.append(type(state)(**system_attrs)) + states.append(type(state)(**system_attrs)) # type: ignore[invalid-argument-type] return states -def _pop_states( - state: SimState, - pop_indices: list[int] | torch.Tensor, -) -> tuple[SimState, list[SimState]]: +def _pop_states[T: SimState]( + state: T, pop_indices: list[int] | torch.Tensor +) -> tuple[T, list[T]]: """Pop off the states with the specified indices. Extracts and removes the specified system indices from the state. @@ -784,19 +762,16 @@ def _pop_states( pop_attrs = _filter_attrs_by_mask(state, pop_atom_mask, pop_system_mask) # Create the keep state - keep_state = type(state)(**keep_attrs) + keep_state: T = type(state)(**keep_attrs) # type: ignore[assignment] # Create and split the pop state - pop_state = type(state)(**pop_attrs) + pop_state: T = type(state)(**pop_attrs) # type: ignore[assignment] pop_states = _split_state(pop_state) return keep_state, pop_states -def _slice_state( - state: SimStateVar, - system_indices: list[int] | torch.Tensor, -) -> SimStateVar: +def _slice_state[T: SimState](state: T, system_indices: list[int] | torch.Tensor) -> T: """Slice a substate from the SimState containing only the specified system indices. Creates a new SimState containing only the specified systems, preserving @@ -830,12 +805,12 @@ def _slice_state( filtered_attrs = _filter_attrs_by_mask(state, atom_mask, system_mask) # Create the sliced state - return type(state)(**filtered_attrs) + return type(state)(**filtered_attrs) # type: ignore[invalid-return-type] -def concatenate_states( - states: list[SimState], device: torch.device | None = None -) -> SimState: +def concatenate_states[T: SimState]( # noqa: C901 + states: Sequence[T], device: torch.device | None = None +) -> T: """Concatenate a list of SimStates into a single SimState. Combines multiple states into a single state with multiple systems. @@ -843,7 +818,7 @@ def concatenate_states( properties are concatenated. Args: - states (list[SimState]): A list of SimState objects to concatenate + states (Sequence[SimState]): A list of SimState objects to concatenate device (torch.device, optional): The device to place the concatenated state on. Defaults to the device of the first state. @@ -907,7 +882,10 @@ def concatenate_states( for prop, tensors in per_system_tensors.items(): # if tensors: - concatenated[prop] = torch.cat(tensors, dim=0) + if isinstance(tensors[0], torch.Tensor): + concatenated[prop] = torch.cat(tensors, dim=0) + else: # Non-tensor attributes, take first one (they should all be identical) + concatenated[prop] = tensors[0] # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) @@ -943,14 +921,14 @@ def initialize_state( if isinstance(system, SimState): return state_to_device(system, device, dtype) - if isinstance(system, list) and all(isinstance(s, SimState) for s in system): - if not all(cast("SimState", state).n_systems == 1 for state in system): + if isinstance(system, list | tuple) and all(isinstance(s, SimState) for s in system): + if not all(state.n_systems == 1 for state in system): raise ValueError( "When providing a list of states, to the initialize_state function, " "all states must have n_systems == 1. To fix this, you can split the " "states into individual states with the split_state function." ) - return concatenate_states(system) + return ts.concatenate_states(system) converters = [ ("pymatgen.core", "Structure", ts.io.structures_to_state), @@ -965,23 +943,27 @@ def initialize_state( cls = getattr(module, class_name) if isinstance(system, cls) or ( - isinstance(system, list) and all(isinstance(s, cls) for s in system) + isinstance(system, list | tuple) + and all(isinstance(s, cls) for s in system) ): return converter_func(system, device, dtype) except ImportError: continue # remaining code just for informative error - is_list = isinstance(system, list) all_same_type = ( - is_list and all(isinstance(s, type(system[0])) for s in system) and system + isinstance(system, list | tuple) + and all(isinstance(s, type(system[0])) for s in system) + and system ) - if is_list and not all_same_type: + if isinstance(system, list | tuple) and not all_same_type: raise ValueError( f"All items in list must be of the same type, " f"found {type(system[0])} and {type(system[1])}" ) - system_type = f"list[{type(system[0])}]" if is_list else type(system) + system_type = ( + f"list[{type(system[0])}]" if isinstance(system, list | tuple) else type(system) + ) - raise ValueError(f"Unsupported system type, {system_type}") + raise ValueError(f"Unsupported {system_type=}") diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index fb170754d..d8a306d8c 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -44,7 +44,7 @@ if TYPE_CHECKING: from ase import Atoms - from ase.io.trajectory import Trajectory + from ase.io.trajectory import TrajectoryWriter _DATA_TYPE_MAP = { np.dtype("float32"): tables.Float32Atom(), @@ -197,14 +197,16 @@ def _add_model_arg_to_prop_calculators(self) -> None: if len(sig.parameters) == 1: # we partially evaluate the function to create a new function with # an optional second argument, this can be set to state later on - new_fn = partial(lambda state, _=None, fn=None: fn(state), fn=prop_fn) + new_fn = partial( + lambda state, _=None, fn=None: ( + None if fn is None else fn(state) + ), + fn=prop_fn, + ) self.prop_calculators[frequency][name] = new_fn def report( - self, - state: SimState, - step: int, - model: ModelInterface | None = None, + self, state: SimState, step: int, model: ModelInterface | None = None ) -> list[dict[str, torch.Tensor]]: """Report a state and step to the trajectory files. @@ -377,7 +379,9 @@ def __init__( compression = None # TODO FIX THIS - if handles := tables.file._open_files.get_handlers_by_name(str(filename)): + if hasattr(tables, "file") and ( + handles := tables.file._open_files.get_handlers_by_name(str(filename)) + ): list(handles)[-1].close() # create parent directory if it doesn't exist @@ -387,7 +391,7 @@ def __init__( self.array_registry: dict[str, tuple[tuple[int, ...], np.dtype]] = {} # check if the header has already been written - if "header" not in [node._v_name for node in self._file.list_nodes("/")]: + if "header" not in (node._v_name for node in self._file.list_nodes("/")): self._initialize_header(metadata) self._initialize_registry() @@ -715,10 +719,8 @@ def write_state( # noqa: C901 steps = [steps] if isinstance(system_index, int): - system_index = [system_index] sub_states = [state[system_index] for state in state] elif system_index is None and torch.unique(state[0].system_idx) == 0: - system_index = 0 sub_states = state else: raise ValueError( @@ -970,7 +972,7 @@ def __len__(self) -> int: """ return self._file.root.data.positions.shape[0] - def write_ase_trajectory(self, filename: str | pathlib.Path) -> "Trajectory": + def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryWriter": """Convert trajectory to ASE Trajectory format. Writes the entire trajectory to a new file in ASE format for compatibility diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 65947e8db..1b2c416b5 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -5,7 +5,7 @@ general PBC wrapping. """ -from collections.abc import Callable +from collections.abc import Callable, Iterable from functools import wraps import torch @@ -129,8 +129,7 @@ def pbc_wrap_general( lattice vectors as columns (A matrix in the equations). Returns: - torch.Tensor: Tensor of wrapped positions in real space with - same shape as input positions. + torch.Tensor: Wrapped positions in real space with same shape as input positions. """ # Validate inputs if not torch.is_floating_point(positions) or not torch.is_floating_point( @@ -172,8 +171,7 @@ def pbc_wrap_batched( indices for each atom. Returns: - torch.Tensor: Tensor of wrapped positions in real space with - same shape as input positions. + torch.Tensor: Wrapped positions in real space with same shape as input positions. """ # Validate inputs if not torch.is_floating_point(positions) or not torch.is_floating_point(cell): @@ -183,8 +181,8 @@ def pbc_wrap_batched( raise ValueError("Position dimensionality must match lattice vectors.") # Get unique system indices and counts - unique_systems = torch.unique(system_idx) - n_systems = len(unique_systems) + uniq_systems = torch.unique(system_idx) + n_systems = len(uniq_systems) if n_systems != cell.shape[0]: raise ValueError( @@ -349,7 +347,7 @@ def wrap_positions( cell: torch.Tensor, *, pbc: bool | list[bool] | torch.Tensor = True, - center: tuple[float, float, float] | float = (0.5, 0.5, 0.5), + center: tuple[float, float, float] = (0.5, 0.5, 0.5), pretty_translation: bool = False, eps: float = 1e-7, ) -> torch.Tensor: @@ -374,10 +372,7 @@ def wrap_positions( device = positions.device # Convert center to tensor - if isinstance(center, float): - center_pos = torch.tensor((center,) * 3, dtype=positions.dtype, device=device) - else: - center_pos = torch.tensor(center, dtype=positions.dtype, device=device) + center_tensor = torch.tensor(center, dtype=positions.dtype, device=device) # Handle PBC input if isinstance(pbc, bool): @@ -386,7 +381,7 @@ def wrap_positions( pbc = torch.tensor(pbc, dtype=torch.bool, device=device) # Calculate shift based on center - shift = center_pos - 0.5 - eps + shift = center_tensor - 0.5 - eps shift[~pbc] = 0.0 # Convert positions to fractional coordinates @@ -394,7 +389,7 @@ def wrap_positions( if pretty_translation: fractional = translate_pretty(fractional, pbc) - shift = center_pos - 0.5 + shift = center_tensor - 0.5 shift[~pbc] = 0.0 fractional += shift else: @@ -489,7 +484,7 @@ def get_cell_shift_idx(num_repeats: torch.Tensor, dtype: _dtype) -> torch.Tensor num_repeats[ii] + 1, device=num_repeats.device, dtype=dtype, - ) # type: ignore[call-overload] + ) _, indices = torch.sort(torch.abs(r1)) reps.append(r1[indices]) return torch.cartesian_prod(reps[0], reps[1], reps[2]) @@ -498,7 +493,7 @@ def get_cell_shift_idx(num_repeats: torch.Tensor, dtype: _dtype) -> torch.Tensor def compute_distances_with_cell_shifts( pos: torch.Tensor, mapping: torch.Tensor, - cell_shifts: torch.Tensor | None, + cell_shifts: torch.Tensor, ) -> torch.Tensor: """Compute distances between pairs of positions, optionally including cell shifts. @@ -514,7 +509,7 @@ def compute_distances_with_cell_shifts( mapping (torch.Tensor): A tensor of shape (2, n_pairs) that specifies pairs of indices in `pos` for which to compute distances. - cell_shifts (torch.Tensor | None): A tensor of shape (n_pairs, 3) + cell_shifts (Optional[torch.Tensor]): A tensor of shape (n_pairs, 3) representing the shifts to apply to the distances for periodic boundary conditions. If None, no shifts are applied. @@ -536,15 +531,15 @@ def compute_distances_with_cell_shifts( def compute_cell_shifts( - cell: torch.Tensor | None, shifts_idx: torch.Tensor, system_mapping: torch.Tensor -) -> torch.Tensor | None: + cell: torch.Tensor, shifts_idx: torch.Tensor, system_mapping: torch.Tensor +) -> torch.Tensor: """Compute the cell shifts based on the provided indices and cell matrix. This function calculates the shifts to apply to positions based on the specified indices and the unit cell matrix. If the cell is None, it returns None. Args: - cell (torch.Tensor | None): A tensor of shape (n_cells, 3, 3) + cell (torch.Tensor): A tensor of shape (n_cells, 3, 3) representing the unit cell matrices. shifts_idx (torch.Tensor): A tensor of shape (n_shifts, 3) representing the indices for shifts. @@ -556,17 +551,12 @@ def compute_cell_shifts( the computed cell shifts. """ if cell is None: - return None - return compute_cell_shifts_strict(cell, shifts_idx, system_mapping) - - -def compute_cell_shifts_strict( - cell: torch.Tensor, shifts_idx: torch.Tensor, system_mapping: torch.Tensor -) -> torch.Tensor: - """Same thing as compute_cell_shifts, but cell cannot be None. - Having a non-optional cell makes torchjit not complain. - """ - return torch.einsum("jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[system_mapping]) + cell_shifts = None + else: + cell_shifts = torch.einsum( + "jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[system_mapping] + ) + return cell_shifts def get_fully_connected_mapping( @@ -666,22 +656,17 @@ def build_naive_neighborhood( ids = torch.arange(positions.shape[0], device=device, dtype=torch.long) mapping, system_mapping, shifts_idx_ = [], [], [] - for i_structure in range(n_atoms.shape[0]): - num_repeats = num_repeats_[i_structure] + for struct_idx in range(n_atoms.shape[0]): + num_repeats = num_repeats_[struct_idx] shifts_idx = get_cell_shift_idx(num_repeats, dtype) - i_ids = ids[stride[i_structure] : stride[i_structure + 1]] + i_ids = ids[stride[struct_idx] : stride[struct_idx + 1]] s_mapping, shifts_idx = get_fully_connected_mapping( i_ids=i_ids, shifts_idx=shifts_idx, self_interaction=self_interaction ) mapping.append(s_mapping) system_mapping.append( - torch.full( - (s_mapping.shape[0],), - i_structure, - dtype=torch.long, - device=device, - ) + torch.full((s_mapping.shape[0],), struct_idx, dtype=torch.long, device=device) ) shifts_idx_.append(shifts_idx) return ( @@ -854,7 +839,7 @@ def linked_cell( # noqa: PLR0915 shifts_idx, n_atom, dim=0, output_size=n_atom * n_cell_image ) batch_image = torch.zeros((shifts_idx.shape[0]), dtype=torch.long) - cell_shifts = compute_cell_shifts_strict(cell.view(-1, 3, 3), shifts_idx, batch_image) + cell_shifts = compute_cell_shifts(cell.view(-1, 3, 3), shifts_idx, batch_image) i_ids = torch.arange(n_atom, device=device, dtype=torch.long) i_ids = i_ids.repeat(n_cell_image) @@ -1021,21 +1006,21 @@ def build_linked_cell_neighborhood( stride = strides_of(n_atoms) mapping, system_mapping, cell_shifts_idx = [], [], [] - for i_structure in range(n_structure): + for struct_idx in range(n_structure): # Compute the neighborhood with the linked cell algorithm neigh_atom, neigh_shift_idx = linked_cell( - positions[stride[i_structure] : stride[i_structure + 1]], - cell[i_structure], + positions[stride[struct_idx] : stride[struct_idx + 1]], + cell[struct_idx], cutoff, - num_repeats[i_structure], + num_repeats[struct_idx], self_interaction, ) system_mapping.append( - i_structure * torch.ones(neigh_atom.shape[1], dtype=torch.long, device=device) + struct_idx * torch.ones(neigh_atom.shape[1], dtype=torch.long, device=device) ) # Shift the mapping indices to access positions - mapping.append(neigh_atom + stride[i_structure]) + mapping.append(neigh_atom + stride[struct_idx]) cell_shifts_idx.append(neigh_shift_idx) return ( @@ -1047,8 +1032,8 @@ def build_linked_cell_neighborhood( def multiplicative_isotropic_cutoff( fn: Callable[..., torch.Tensor], - r_onset: torch.Tensor, - r_cutoff: torch.Tensor, + r_onset: float | torch.Tensor, + r_cutoff: float | torch.Tensor, ) -> Callable[..., torch.Tensor]: """Creates a smoothly truncated version of an isotropic function. @@ -1080,16 +1065,16 @@ def multiplicative_isotropic_cutoff( HOOMD-blue documentation: https://hoomd-blue.readthedocs.io/en/latest/hoomd/md/module-pair.html """ - r_c = r_cutoff**2 - r_o = r_onset**2 + r_c = torch.square(torch.tensor(r_cutoff)) + r_o = torch.square(torch.tensor(r_onset)) def smooth_fn(dr: torch.Tensor) -> torch.Tensor: """Compute the smooth switching function.""" - r = dr**2 + r = torch.square(dr) # Compute switching function for intermediate region - numerator = (r_c - r) ** 2 * (r_c + 2 * r - 3 * r_o) - denominator = (r_c - r_o) ** 3 + numerator = torch.square(r_c - r) * (r_c + 2 * r - 3 * r_o) + denominator = torch.pow(r_c - r_o, 3) intermediate = torch.where( dr < r_cutoff, numerator / denominator, torch.zeros_like(dr) ) @@ -1107,7 +1092,7 @@ def cutoff_fn(dr: torch.Tensor, *args, **kwargs) -> torch.Tensor: def high_precision_sum( x: torch.Tensor, - dim: int | list[int] | tuple[int, ...] | None = None, + dim: int | Iterable[int] | None = None, *, keepdim: bool = False, ) -> torch.Tensor: @@ -1144,7 +1129,7 @@ def high_precision_sum( def safe_mask( mask: torch.Tensor, - fn: Callable[..., torch.Tensor], + fn: Callable[[torch.Tensor], torch.Tensor], operand: torch.Tensor, placeholder: float = 0.0, ) -> torch.Tensor: @@ -1156,7 +1141,7 @@ def safe_mask( Args: mask: Boolean tensor indicating which elements to process (True) or mask (False) - fn: callable function to apply to the masked elements + fn: TorchScript function to apply to the masked elements operand: Input tensor to apply the function to placeholder: Value to use for masked-out positions (default: 0.0) diff --git a/torch_sim/typing.py b/torch_sim/typing.py index 94ec44caf..94c6dab2d 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -1,7 +1,7 @@ -"""Types used across torch-sim.""" +"""Types used across TorchSim.""" -from enum import Enum -from typing import TYPE_CHECKING, Literal, TypeVar, Union +from enum import StrEnum +from typing import TYPE_CHECKING, Literal, Union import torch @@ -17,10 +17,9 @@ MemoryScaling = Literal["n_atoms_x_density", "n_atoms"] StateKey = Literal["positions", "masses", "cell", "pbc", "atomic_numbers", "system_idx"] StateDict = dict[StateKey, torch.Tensor] -SimStateVar = TypeVar("SimStateVar", bound="SimState") -class BravaisType(Enum): +class BravaisType(StrEnum): """Enumeration of the seven Bravais lattice types in 3D crystals. These lattice types represent the distinct crystal systems classified @@ -31,13 +30,13 @@ class BravaisType(Enum): which determine the number of independent elastic constants. """ - CUBIC = "cubic" - HEXAGONAL = "hexagonal" - TRIGONAL = "trigonal" - TETRAGONAL = "tetragonal" - ORTHORHOMBIC = "orthorhombic" - MONOCLINIC = "monoclinic" - TRICLINIC = "triclinic" + cubic = "cubic" + hexagonal = "hexagonal" + trigonal = "trigonal" + tetragonal = "tetragonal" + orthorhombic = "orthorhombic" + monoclinic = "monoclinic" + triclinic = "triclinic" StateLike = Union[ @@ -47,6 +46,5 @@ class BravaisType(Enum): list["Atoms"], list["Structure"], list["PhonopyAtoms"], - SimStateVar, - list[SimStateVar], + "SimState", ] diff --git a/torch_sim/workflows/a2c.py b/torch_sim/workflows/a2c.py index fb5bba0b7..2d5e2c618 100644 --- a/torch_sim/workflows/a2c.py +++ b/torch_sim/workflows/a2c.py @@ -8,21 +8,19 @@ """ # ruff: noqa: T201 - import itertools from collections.abc import Sequence -from typing import Any import numpy as np import torch -from pymatgen.core.composition import Composition +from pymatgen.core import Composition import torch_sim as ts from torch_sim import transforms from torch_sim.models.interface import ModelInterface from torch_sim.models.soft_sphere import SoftSphereModel, SoftSphereMultiModel -from torch_sim.optimizers import FireState, UnitCellFireState, fire -from torch_sim.optimizers import unit_cell_fire as batched_unit_cell_fire +from torch_sim.optimizers import FireState +from torch_sim.quantities import get_pressure def min_distance( @@ -95,7 +93,7 @@ def get_diameter(composition: Composition) -> float: For single elements, it is twice the appropriate atomic radius. Examples: - >>> from pymatgen.core.composition import Composition + >>> from pymatgen.core import Composition >>> # Multi-element example: Fe2O3 >>> comp = Composition("Fe2O3") >>> diameter = get_diameter(comp) # Returns minimum Fe-O or O-O separation @@ -168,7 +166,7 @@ def get_diameter_matrix( at [i,j] represents the minimum separation between elements i and j. Examples: - >>> from pymatgen.core.composition import Composition + >>> from pymatgen.core import Composition >>> comp = Composition("Fe2O3") >>> diameters = get_diameter_matrix(comp) >>> print(diameters) # Shows Fe-Fe, Fe-O, and O-O separations @@ -228,8 +226,7 @@ def random_packed_structure( distance_tolerance: float = 0.0001, device: torch.device | None = None, dtype: torch.dtype = torch.float32, - log: Any | None = None, -) -> FireState | tuple[FireState, list[np.ndarray]]: +) -> tuple[FireState, list[np.ndarray]]: """Generates a random packed atomic structure and minimizes atomic overlaps. This function creates a random atomic structure within a given cell and optionally @@ -256,7 +253,7 @@ def random_packed_structure( Returns: FIREState: The optimized structure state containing positions, forces, - energies and other optimization parameters. + energies and a list of positions at each iteration. Notes: - If both diameter and auto_diameter are None, no overlap minimization @@ -274,6 +271,7 @@ def random_packed_structure( if seed is not None: generator.manual_seed(seed) + log = [] # Generate initial random positions in fractional coordinates N_atoms = sum(element_counts) positions = torch.rand((N_atoms, 3), device=device, dtype=dtype, generator=generator) @@ -309,8 +307,7 @@ def random_packed_structure( cell=cell, pbc=True, ) - fire_init, fire_update = fire(model=model) - state = fire_init(state) + state = ts.fire_init(model, state) print(f"Initial energy: {state.energy.item():.4f}") # Run FIRE optimization until convergence or max iterations for _step in range(max_iter): @@ -318,17 +315,13 @@ def random_packed_structure( if min_distance(state.positions, cell, distance_tolerance) > diameter * 0.95: break - if log is not None: - log.append(state.positions.cpu().numpy()) + log.append(state.positions.cpu().numpy()) - state = fire_update(state) + state = ts.fire_step(model, state) print(f"Final energy: {state.energy.item():.4f}") - if log is not None: - return state, log - - return state + return state, log def random_packed_structure_multi( @@ -391,7 +384,7 @@ def random_packed_structure_multi( # Create species indices tensor mapping each atom to its species type # e.g. for Fe80B20: [0,0,...,0,1,1,...,1] where 0=Fe, 1=B species_idx = torch.tensor( - [i for i, count in enumerate(element_counts) for _ in range(count)], + [idx for idx, count in enumerate(element_counts) for _ in range(count)], device=device, ) @@ -439,8 +432,7 @@ def random_packed_structure_multi( pbc=True, ) # Set up FIRE optimizer with unit masses for all atoms - fire_init, fire_update = fire(model=model) - state = fire_init(state_dict) + state = ts.fire_init(model, state_dict) print(f"Initial energy: {state.energy.item():.4f}") # Run FIRE optimization until convergence or max iterations for _step in range(max_iter): @@ -448,7 +440,7 @@ def random_packed_structure_multi( min_dist = min_distance(state.positions, cell, distance_tolerance) if min_dist > diameter_matrix.min() * 0.95: break - state = fire_update(state) + state = ts.fire_step(model, state) print(f"Final energy: {state.energy.item():.4f}") return state @@ -460,7 +452,7 @@ def valid_subcell( initial_energy: float, final_energy: float, e_tol: float = 0.001, - fe_lower_limit: float = -5.0, + e_form_lower_limit: float = -5.0, fe_upper_limit: float = 0.0, fusion_distance: float = 1.5, distance_tolerance: float = 0.0001, @@ -482,8 +474,8 @@ def valid_subcell( final_energy: Total energy of the structure after relaxation, in eV. e_tol: Energy tolerance for comparing initial and final energies, in eV. Used to check if optimization reduced the energy. Defaults to 0.001 eV. - fe_lower_limit: Lower limit for formation energy, in eV/atom. Values below this - are considered unphysical. Defaults to -5.0 eV/atom. + e_form_lower_limit: Lower limit for formation energy, in eV/atom. Values below + this are considered unphysical. Defaults to -5.0 eV/atom. fe_upper_limit: Upper limit for formation energy, in eV/atom. Values above this indicate poor convergence. Defaults to 0.0 eV/atom. fusion_distance: Minimum allowed distance between any pair of atoms, in Γ…. @@ -501,7 +493,7 @@ def valid_subcell( - Atomic fusion (distances < ~1.5 Γ…) indicates an unphysical structure """ # Check if formation energy is unphysically negative - if final_energy < fe_lower_limit: + if final_energy < e_form_lower_limit: return False # Check if optimization properly reduced the energy @@ -708,7 +700,8 @@ def get_unit_cell_relaxed_structure( state: ts.SimState, model: ModelInterface, max_iter: int = 200, -) -> tuple[UnitCellFireState, dict[str, torch.Tensor], list[float], list[float]]: + verbose: bool = True, # noqa: FBT001, FBT002 +) -> tuple[ts.FireState, dict[str, torch.Tensor], list[float], list[float]]: """Relax both atomic positions and cell parameters using FIRE algorithm. This function performs geometry optimization of both atomic positions and unit cell @@ -719,6 +712,7 @@ def get_unit_cell_relaxed_structure( state: State containing positions, cell and atomic numbers model: Model to compute energies, forces, and stresses max_iter: Maximum number of FIRE iterations. Defaults to 200. + verbose: Whether to print initial and final energy and pressure. Defaults to True. Returns: tuple containing: @@ -740,23 +734,21 @@ def get_unit_cell_relaxed_structure( results = model(state) init_energy = [e.item() for e in results["energy"]] init_stress = results["stress"] - init_pressure = [(torch.trace(stress) / 3.0).item() for stress in init_stress] - print( - f"Initial energy: {[f'{e:.4f}' for e in init_energy]} eV, " - f"Initial pressure: {[f'{p:.4f}' for p in init_pressure]} eV/A^3" - ) + init_pressure = [p.item() for p in get_pressure(init_stress, 0.0, state.volume)] + if verbose: + print( + f"Initial energy: {[f'{e:.4f}' for e in init_energy]} eV, " + f"Initial pressure: {[f'{p:.4f}' for p in init_pressure]} eV/A^3" + ) - unit_cell_fire_init, unit_cell_fire_update = batched_unit_cell_fire( - model=model, - ) - state = unit_cell_fire_init(state) + state = ts.fire_init(model=model, state=state, cell_filter=ts.CellFilter.unit) def step_fn( - step: int, state: UnitCellFireState, logger: dict[str, torch.Tensor] - ) -> tuple[UnitCellFireState, dict[str, torch.Tensor]]: + step: int, state: ts.FireState, logger: dict[str, torch.Tensor] + ) -> tuple[ts.FireState, dict[str, torch.Tensor]]: logger["energy"][step] = state.energy logger["stress"][step] = state.stress - state = unit_cell_fire_update(state) + state = ts.fire_step(model=model, state=state) return state, logger for step in range(max_iter): @@ -767,9 +759,148 @@ def step_fn( final_energy = [e.item() for e in final_results["energy"]] final_stress = final_results["stress"] - final_pressure = [(torch.trace(stress) / 3.0).item() for stress in final_stress] - print( - f"Final energy: {[f'{e:.4f}' for e in final_energy]} eV, " - f"Final pressure: {[f'{p:.4f}' for p in final_pressure]} eV/A^3" + final_pressure = [p.item() for p in get_pressure(final_stress, 0.0, state.volume)] + if verbose: + print( + f"Final energy: {[f'{e:.4f}' for e in final_energy]} eV, " + f"Final pressure: {[f'{p:.4f}' for p in final_pressure]} eV/A^3" + ) + return state, logger, final_energy, final_pressure + + +def get_frechet_cell_relaxed_structure( + state: ts.SimState, + model: ModelInterface, + max_iter: int = 200, + verbose: bool = True, # noqa: FBT001, FBT002 +) -> tuple[ts.FireState, dict[str, torch.Tensor], list[float], list[float]]: + """Relax both atomic positions and cell parameters using FIRE algorithm. + + This function performs geometry optimization of both atomic positions and unit cell + parameters simultaneously. Uses the Fast Inertial Relaxation Engine (FIRE) algorithm + to minimize forces on atoms and stresses on the cell. + + Args: + state: State containing positions, cell and atomic numbers + model: Model to compute energies, forces, and stresses + max_iter: Maximum number of FIRE iterations. Defaults to 200. + verbose: Whether to print initial and final energy and pressure. Defaults to True. + + Returns: + tuple containing: + - ts.FireState: Final state containing relaxed positions, + cell and more + - dict: Logger with energy and stress trajectories + - float: Final energy in eV + - float: Final pressure in eV/Γ…Β³ + """ + # Get device and dtype from model + device, dtype = model.device, model.dtype + + logger = { + "energy": torch.zeros((max_iter, state.n_systems), device=device, dtype=dtype), + "stress": torch.zeros( + (max_iter, state.n_systems, 3, 3), device=device, dtype=dtype + ), + } + + results = model(state) + init_energy = [e.item() for e in results["energy"]] + init_stress = results["stress"] + init_pressure = [p.item() for p in get_pressure(init_stress, 0.0, state.volume)] + if verbose: + print( + f"Initial energy: {[f'{e:.4f}' for e in init_energy]} eV, " + f"Initial pressure: {[f'{p:.4f}' for p in init_pressure]} eV/A^3" + ) + + state = ts.fire_init(model=model, state=state, cell_filter=ts.CellFilter.frechet) + + def step_fn( + step: int, state: ts.FireState, logger: dict[str, torch.Tensor] + ) -> tuple[ts.FireState, dict]: + logger["energy"][step] = state.energy + logger["stress"][step] = state.stress + state = ts.fire_step(model=model, state=state) + return state, logger + + for step in range(max_iter): + state, logger = step_fn(step, state, logger) + + # Get final results + final_results = model(state) + + final_energy = [e.item() for e in final_results["energy"]] + final_stress = final_results["stress"] + final_pressure = [p.item() for p in get_pressure(final_stress, 0.0, state.volume)] + if verbose: + print( + f"Final energy: {[f'{e:.4f}' for e in final_energy]} eV, " + f"Final pressure: {[f'{p:.4f}' for p in final_pressure]} eV/A^3" + ) + return state, logger, final_energy, final_pressure + + +def get_relaxed_structure( + state: ts.SimState, + model: ModelInterface, + max_iter: int = 200, + verbose: bool = True, # noqa: FBT001, FBT002 +) -> tuple[FireState, dict[str, torch.Tensor], list[float], list[float]]: + """Relax atomic positions at fixed cell parameters using FIRE algorithm. + + Does geometry optimization of atomic positions while keeping the unit cell fixed. + Uses the Fast Inertial Relaxation Engine (FIRE) algorithm to minimize forces on atoms. + + Args: + state: State containing positions, cell and atomic numbers + model: Model to compute energies, forces, and stresses + max_iter: Maximum number of FIRE iterations. Defaults to 200. + verbose: Whether to print initial and final energy and pressure. Defaults to True. + + Returns: + tuple containing: + - FIREState: Final state containing relaxed positions and other quantities + - dict: Logger with energy trajectory + - float: Final energy in eV + - float: Final pressure in eV/Γ…Β³ + """ + # Get device and dtype from model + device, dtype = model.device, model.dtype + + logger = {"energy": torch.zeros((max_iter, 1), device=device, dtype=dtype)} + + results = model(state) + init_energy = [e.item() for e in results["energy"]] + if verbose: + print(f"Initial energy: {[f'{e:.4f}' for e in init_energy]} eV") + + state = ts.fire_init(model=model, state=state) + + def step_fn( + idx: int, state: FireState, logger: dict[str, torch.Tensor] + ) -> tuple[FireState, dict[str, torch.Tensor]]: + logger["energy"][idx] = state.energy + state = ts.fire_step(model=model, state=state) + return state, logger + + for idx in range(max_iter): + state, logger = step_fn(idx, state, logger) + + # Get final results with stress computation enabled + final_results = model( + positions=state.positions, + cell=state.cell, + atomic_numbers=state.atomic_numbers, + compute_stress=True, ) + + final_energy = [e.item() for e in final_results["energy"]] + final_stress = final_results["stress"] + final_pressure = [p.item() for p in get_pressure(final_stress, 0.0, state.volume)] + if verbose: + print( + f"Final energy: {[f'{e:.4f}' for e in final_energy]} eV, " + f"Final pressure: {[f'{p:.4f}' for p in final_pressure]} eV/A^3" + ) return state, logger, final_energy, final_pressure From 8d585bb3b9f254fb7c43cb8a97f1c05645e81cbc Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 27 Sep 2025 15:29:41 -0700 Subject: [PATCH 02/40] bump min python version 3.11->3.12 fix README.md file name casing --- .github/workflows/docs.yml | 2 +- .github/workflows/test.yml | 28 ++++++++++++++-------------- .pre-commit-config.yaml | 4 ++-- citation.cff | 2 +- docs/user/introduction.rst | 2 +- pyproject.toml | 4 ++-- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 7aeabd47f..492e299fd 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -24,7 +24,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: '3.12' - name: Set up uv uses: astral-sh/setup-uv@v6 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f08a83c1f..8b8626e65 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,8 +18,8 @@ jobs: matrix: os: [ubuntu-latest, macos-14] version: - - { python: "3.11", resolution: highest } - - { python: "3.12", resolution: lowest-direct } + - { python: '3.12', resolution: highest } + - { python: '3.13', resolution: lowest-direct } runs-on: ${{ matrix.os }} steps: @@ -64,18 +64,18 @@ jobs: matrix: os: [ubuntu-latest, macos-14] version: - - { python: "3.11", resolution: highest } - - { python: "3.12", resolution: lowest-direct } + - { python: '3.12', resolution: highest } + - { python: '3.13', resolution: lowest-direct } model: - - { name: fairchem, test_path: "tests/models/test_fairchem.py" } - - { name: graphpes, test_path: "tests/models/test_graphpes.py" } - - { name: mace, test_path: "tests/models/test_mace.py" } - - { name: mace, test_path: "tests/test_elastic.py" } - - { name: mace, test_path: "tests/test_optimizers_vs_ase.py" } - - { name: mattersim, test_path: "tests/models/test_mattersim.py" } - - { name: metatomic, test_path: "tests/models/test_metatomic.py" } - - { name: orb, test_path: "tests/models/test_orb.py" } - - { name: sevenn, test_path: "tests/models/test_sevennet.py" } + - { name: fairchem, test_path: tests/models/test_fairchem.py } + - { name: graphpes, test_path: tests/models/test_graphpes.py } + - { name: mace, test_path: tests/models/test_mace.py } + - { name: mace, test_path: tests/test_elastic.py } + - { name: mace, test_path: tests/test_optimizers_vs_ase.py } + - { name: mattersim, test_path: tests/models/test_mattersim.py } + - { name: metatomic, test_path: tests/models/test_metatomic.py } + - { name: orb, test_path: tests/models/test_orb.py } + - { name: sevenn, test_path: tests/models/test_sevennet.py } runs-on: ${{ matrix.os }} steps: @@ -162,7 +162,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.11 + python-version: '3.12' - name: Set up uv uses: astral-sh/setup-uv@v6 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 96e9ec319..300ac347e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.11.13 hooks: - - id: ruff + - id: ruff-check args: [--fix] - id: ruff-format @@ -39,4 +39,4 @@ repos: # MD033: no inline HTML # MD041: first line in a file should be a top-level heading # MD034: bare URL used - args: [--disable, MD013, MD033, MD041, MD034, "--"] + args: [--disable, MD013, MD033, MD041, MD034, '--'] diff --git a/citation.cff b/citation.cff index 9ac5db9b8..bcf594e7b 100644 --- a/citation.cff +++ b/citation.cff @@ -15,7 +15,7 @@ authors: - family-names: Falletta given-names: Stefano license: MIT -license-url: https://github.com/torchsim/torch-sim/blob/main/license +license-url: https://github.com/torchsim/torch-sim/blob/main/LICENSE repository-code: https://github.com/torchsim/torch-sim url: https://github.com/torchsim/torch-sim type: software diff --git a/docs/user/introduction.rst b/docs/user/introduction.rst index 910a3b816..67833f0e1 100644 --- a/docs/user/introduction.rst +++ b/docs/user/introduction.rst @@ -3,6 +3,6 @@ Introduction ============ -.. include:: ../../readme.md +.. include:: ../../README.md :start-after: :parser: myst_parser.sphinx_ diff --git a/pyproject.toml b/pyproject.toml index 4fc030287..3b6154b71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,8 +7,8 @@ authors = [ { name = "Janosh Riebesell", email = "janosh.riebesell@gmail.com" }, { name = "Orion Cohen", email = "orioncohen@berkeley.edu" }, ] -readme = "readme.md" -license = { file = "license" } +readme = "README.md" +license = { file = "LICENSE" } keywords = [ "chemistry", "interatomic-potentials", From 8480fee1c49a6f59fe14b44a74241ce1d9d354d6 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 27 Sep 2025 16:07:16 -0700 Subject: [PATCH 03/40] CI fixes --- .../6_Phonons/6.3_Conductivity_MACE.py | 8 +++-- examples/tutorials/diff_sim.py | 36 ++++++++++--------- pyproject.toml | 2 +- tests/models/test_metatomic.py | 1 + tests/test_autobatching.py | 10 ++++-- tests/test_runners.py | 6 ++-- torch_sim/properties/correlations.py | 2 +- torch_sim/runners.py | 7 ++-- 8 files changed, 44 insertions(+), 28 deletions(-) diff --git a/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py b/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py index 066a06add..34d36755f 100644 --- a/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py +++ b/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py @@ -13,6 +13,7 @@ import time from typing import TYPE_CHECKING, Literal, cast +import IPython import numpy as np import plotly.graph_objects as go import torch @@ -101,11 +102,11 @@ def print_relax_info(trajectory_file: str, device: torch.device) -> None: system=struct, model=model, optimizer=ts.OptimFlavor.fire, + cell_filter=ts.CellFilter.frechet, max_steps=max_steps, convergence_fn=converge_max_force, trajectory_reporter=reporter, - constant_volume=True, - hydrostatic_strain=True, + init_kwargs=dict(constant_volume=True, hydrostatic_strain=True), ) print_relax_info(trajectory_file, device) @@ -199,4 +200,5 @@ def print_relax_info(trajectory_file: str, device: torch.device) -> None: height=600, plot_bgcolor="white", ) -fig.show() +if IPython.get_ipython() is not None: + fig.show() diff --git a/examples/tutorials/diff_sim.py b/examples/tutorials/diff_sim.py index 8fbe01a85..6db460893 100644 --- a/examples/tutorials/diff_sim.py +++ b/examples/tutorials/diff_sim.py @@ -100,16 +100,6 @@ def draw_system( # %% -@dataclass -class BaseState: - """Simple simulation state""" - - positions: torch.Tensor - cell: torch.Tensor - pbc: bool - species: torch.Tensor - - class SoftSphereMultiModel(ModelInterface): """Soft sphere potential""" @@ -127,8 +117,8 @@ def __init__( ) -> None: """Initialize a soft sphere model for multi-component systems.""" super().__init__() - self.device = device or torch.device("cpu") - self.dtype = dtype + self._device = device or torch.device("cpu") + self._dtype = dtype self.pbc = pbc # Store species list and determine number of unique species @@ -183,14 +173,14 @@ def __init__( ) def forward( - self, custom_state: BaseState, species: torch.Tensor | None = None + self, custom_state: ts.SimState, species: torch.Tensor | None = None ) -> dict[str, torch.Tensor]: """Compute energies and forces for a single unbatched system with multiple species.""" # Convert inputs to proper device/dtype and handle species positions = custom_state.positions.requires_grad_(True) cell = custom_state.cell - species = custom_state.species + species = custom_state.atomic_numbers if species is not None: species = species.to(device=self.device, dtype=torch.long) @@ -307,7 +297,14 @@ def simulation( R = torch.rand(N, 2) * box_size # Minimize to the nearest minimum. - custom_state = BaseState(positions=R, cell=cell, species=species, pbc=True) + custom_state = ts.SimState( + atomic_numbers=species, + masses=torch.ones(N), + system_idx=torch.arange(N), + positions=R, + cell=cell, + pbc=True, + ) state = ts.gradient_descent_init(model, state=custom_state) for _ in range(simulation_steps): @@ -393,7 +390,14 @@ def short_simulation( model = SoftSphereMultiModel(sigma_matrix=sigma, species=species) # Minimize to the nearest minimum. - custom_state = BaseState(positions=R, cell=cell, species=species, pbc=True) + custom_state = ts.SimState( + atomic_numbers=species, + masses=torch.ones(N), + system_idx=torch.arange(N), + positions=R, + cell=cell, + pbc=True, + ) state = ts.gradient_descent_init(model, state=custom_state) for i in range(short_simulation_steps): diff --git a/pyproject.toml b/pyproject.toml index 3b6154b71..dd61083f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ test = [ io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] mace = ["mace-torch>=0.3.12"] mattersim = ["mattersim>=0.1.2"] -metatomic = ["metatomic-torch>=0.1.1,<0.2", "metatrain[pet]==2025.7"] +metatomic = ["metatomic-torch>=0.1.4,<0.2", "metatrain[pet]==2025.10"] orb = ["orb-models>=0.5.2"] sevenn = ["sevenn>=0.11.0"] graphpes = ["graph-pes>=0.0.34", "mace-torch>=0.3.12"] diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index 7145bbd94..ca360e015 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -26,6 +26,7 @@ def metatomic_calculator(): "https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt" ).export(), device=DEVICE, + dtype=torch.float32, ) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index dd0ff94ed..4549a1c12 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -175,7 +175,7 @@ def test_binning_auto_batcher_auto_metric( """Test BinningAutoBatcher with different states.""" # monkeypatch determine max memory scaler monkeypatch.setattr( - "ts.autobatching.determine_max_batch_size", + "torch_sim.autobatching.determine_max_batch_size", lambda *args, **kwargs: 50, # noqa: ARG005 ) @@ -356,7 +356,9 @@ def test_determine_max_batch_size_fibonacci( def mock_measure(*_args: Any, **_kwargs: Any) -> float: return 0.1 # Return a small constant memory usage - monkeypatch.setattr("ts.autobatching.measure_model_memory_forward", mock_measure) + monkeypatch.setattr( + "torch_sim.autobatching.measure_model_memory_forward", mock_measure + ) # Test with a small max_atoms value to limit the sequence max_size = determine_max_batch_size(si_sim_state, lj_model, max_atoms=10) @@ -375,7 +377,9 @@ def test_determine_max_batch_size_small_scale_factor_no_infinite_loop( scale_factor: float, ) -> None: """Test determine_max_batch_size doesn't infinite loop with small scale factors.""" - monkeypatch.setattr("ts.autobatching.measure_model_memory_forward", lambda *_: 0.1) + monkeypatch.setattr( + "torch_sim.autobatching.measure_model_memory_forward", lambda *_: 0.1 + ) max_size = determine_max_batch_size( si_sim_state, lj_model, max_atoms=20, scale_factor=scale_factor diff --git a/tests/test_runners.py b/tests/test_runners.py index 8db32ff18..a1896a048 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -488,7 +488,9 @@ def test_integrate_with_default_autobatcher( def mock_estimate(*args, **kwargs) -> float: # noqa: ARG001 return 10_000.0 - monkeypatch.setattr("ts.autobatching.estimate_max_memory_scaler", mock_estimate) + monkeypatch.setattr( + "torch_sim.autobatching.estimate_max_memory_scaler", mock_estimate + ) states = [ar_supercell_sim_state, fe_supercell_sim_state, ar_supercell_sim_state] triple_state = ts.initialize_state(states, lj_model.device, lj_model.dtype) @@ -520,7 +522,7 @@ def test_optimize_with_default_autobatcher( def mock_estimate(*args, **kwargs) -> float: # noqa: ARG001 return 200 - monkeypatch.setattr("ts.autobatching.determine_max_batch_size", mock_estimate) + monkeypatch.setattr("torch_sim.autobatching.determine_max_batch_size", mock_estimate) states = [ar_supercell_sim_state, fe_supercell_sim_state, ar_supercell_sim_state] triple_state = ts.initialize_state( diff --git a/torch_sim/properties/correlations.py b/torch_sim/properties/correlations.py index 9f6869a06..dfa81e89e 100644 --- a/torch_sim/properties/correlations.py +++ b/torch_sim/properties/correlations.py @@ -535,7 +535,7 @@ def __init__( velocities=None, energies=self.model(s)["energies"], stresses=full_3x3_to_voigt_6_stress(self.model(s)["stresses"]), - batch=s.batch, + batch=s.system_idx, is_centroid_stress=False, is_virial_only=False, ) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 3eb8a0aab..2c5d7274d 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -381,6 +381,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 max_steps: int = 10_000, steps_between_swaps: int = 5, pbar: bool | dict[str, Any] = False, + init_kwargs: dict[str, Any] | None = None, **optimizer_kwargs: Any, ) -> T: """Optimize a system using a model and optimizer. @@ -410,7 +411,9 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 pbar (bool | dict[str, Any], optional): Show a progress bar. Only works with an autobatcher in interactive shell. If a dict is given, it's passed to `tqdm` as kwargs. - optimizer_kwargs: Additional keyword arguments for optimizer init function + init_kwargs (dict[str, Any], optional): Additional keyword arguments for optimizer + init function. + **optimizer_kwargs: Additional keyword arguments for optimizer step function Returns: T: Optimized system state @@ -448,7 +451,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 init_fn, initial_state, model, - init_kwargs=dict(cell_filter=cell_filter), + init_kwargs=dict(cell_filter=cell_filter, **init_kwargs), max_memory_scaler=autobatcher.max_memory_scaler, memory_scales_with=autobatcher.memory_scales_with, ) From b6f50c75a4e2d070f9a726bd86e5d54758c98c77 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 27 Sep 2025 16:09:36 -0700 Subject: [PATCH 04/40] fix outdated mention of black in dev_install.md address https://github.com/TorchSim/torch-sim/pull/264#discussion_r2384381351 --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- .github/release.yml | 4 ++-- docs/dev/dev_install.md | 10 +++++----- pyproject.toml | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 0194cc5e0..48744c31c 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -11,4 +11,4 @@ Before a pull request can be merged, the following items must be checked: * [ ] Run [ruff](https://beta.ruff.rs/docs/rules/#pydocstyle-d) on your code. * [ ] Tests have been added for any new functionality or bug fixes. -We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run `pip install pre-commit && pre-commit install` to install the hooks which will check your code before each commit. +We highly recommended installing the `prek` hooks running in CI locally to speedup the development process. Simply run `pip install prek && prek install` to install the hooks which will check your code before each commit. diff --git a/.github/release.yml b/.github/release.yml index 0b3044524..ca128cb1d 100644 --- a/.github/release.yml +++ b/.github/release.yml @@ -1,6 +1,6 @@ changelog: exclude: - authors: [github-actions, pre-commit-ci] + authors: [github-actions] categories: - title: πŸ’₯ Breaking Changes labels: [breaking] @@ -33,4 +33,4 @@ changelog: - title: 🧹 Linting labels: [linting] - title: πŸ€·β€β™‚οΈ Other Changes - labels: ["*"] + labels: ['*'] diff --git a/docs/dev/dev_install.md b/docs/dev/dev_install.md index a08aa1239..f4779c87d 100644 --- a/docs/dev/dev_install.md +++ b/docs/dev/dev_install.md @@ -20,19 +20,19 @@ cd torch-sim pip install . -e ``` -### Installing pre-commit +### Installing prek If you're planning on contributing to the torch-sim source, you should also install the developer requirements with: ```bash pip install -e . -pre-commit install -pre-commit run --all-files +prek install +prek run --all-files ``` -The `pre-commit` command will ensure that changes to the source code match the -TorchSim style guidelines by running code linters such as `black` and `ruff` automatically with each commit. +The `prek` command will ensure that changes to the source code match the +TorchSim style guidelines by running the `ruff` code linters and the `ty` type checker automatically with each commit. ## Running unit tests diff --git a/pyproject.toml b/pyproject.toml index dd61083f2..23ef8f88f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,7 +155,7 @@ conflicts = [ ] [dependency-groups] -dev = ["pre-commit>=4.3.0", "ty>=0.0.1a20"] +dev = ["prek>=4.3.0", "ty>=0.0.1a20"] [tool.ty.rules] # TODO: Unable to work with **kwargs: https://github.com/astral-sh/ty/issues/247 From 49be1076b942c4a243b30ec3e9d1a7736f71ae12 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 27 Sep 2025 16:17:22 -0700 Subject: [PATCH 05/40] clean up fairchem CI install, just install latest v2 from pypi --- .github/workflows/test.yml | 20 ++------------------ pyproject.toml | 2 +- torch_sim/runners.py | 2 +- 3 files changed, 4 insertions(+), 20 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8b8626e65..0e59093cc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -82,14 +82,6 @@ jobs: - name: Check out repo uses: actions/checkout@v4 - - name: Check out fairchem repository - if: ${{ matrix.model.name == 'fairchem' }} - uses: actions/checkout@v4 - with: - repository: FAIR-Chem/fairchem - path: fairchem-repo - ref: fairchem_core-1.10.0 - - name: Set up Python uses: actions/setup-python@v5 with: @@ -103,20 +95,12 @@ jobs: env: HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} run: | - uv pip install huggingface_hub --system + uv pip install huggingface_hub fairchem-core --system if [ -n "$HF_TOKEN" ]; then - huggingface-cli login --token "$HF_TOKEN" + hf auth login --token "$HF_TOKEN" else echo "HF_TOKEN is not set. Skipping login." fi - if [ -f fairchem-repo/packages/requirements.txt ]; then - uv pip install -r fairchem-repo/packages/requirements.txt --system - fi - if [ -f fairchem-repo/packages/requirements-optional.txt ]; then - uv pip install -r fairchem-repo/packages/requirements-optional.txt --system - fi - uv pip install -e fairchem-repo/packages/fairchem-core[dev] --system - uv pip install -e .[test] --resolution=${{ matrix.version.resolution }} --system - name: Install torch_sim with model dependencies if: ${{ matrix.model.name != 'fairchem' }} diff --git a/pyproject.toml b/pyproject.toml index 23ef8f88f..4443fbffc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ test = [ io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] mace = ["mace-torch>=0.3.12"] mattersim = ["mattersim>=0.1.2"] -metatomic = ["metatomic-torch>=0.1.4,<0.2", "metatrain[pet]==2025.10"] +metatomic = ["metatomic-torch>=0.1.1", "metatrain[pet]>=2025.7"] orb = ["orb-models>=0.5.2"] sevenn = ["sevenn>=0.11.0"] graphpes = ["graph-pes>=0.0.34", "mace-torch>=0.3.12"] diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 2c5d7274d..2d5ea0dcf 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -451,7 +451,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 init_fn, initial_state, model, - init_kwargs=dict(cell_filter=cell_filter, **init_kwargs), + init_kwargs=dict(cell_filter=cell_filter, **init_kwargs or {}), max_memory_scaler=autobatcher.max_memory_scaler, memory_scales_with=autobatcher.memory_scales_with, ) From b1dc332db7eaadba72c3d4447c0cfd4560e5aeab Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 27 Sep 2025 17:51:28 -0700 Subject: [PATCH 06/40] address @CompRhys comments (thanks for the thorough and fast review!) --- examples/tutorials/low_level_tutorial.py | 4 +- tests/conftest.py | 13 ++---- tests/models/test_orb.py | 8 ++-- tests/models/test_soft_sphere.py | 52 ++++++++++++------------ 4 files changed, 35 insertions(+), 42 deletions(-) diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index 71049eb91..c12fa8937 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -22,9 +22,7 @@ """ ## Setting up the system -TorchSim's state aka `SimState` is a -import torch_sim as ts -class that contains the information of the +TorchSim's state aka `SimState` is a class that contains the information of the system like positions, cell, etc. of the system(s). All the models in the TorchSim package take in a `SimState` as an input and return the properties of the system(s). diff --git a/tests/conftest.py b/tests/conftest.py index 6e6f5ac6b..de97b021a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,6 @@ import torch_sim as ts from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.models.mace import MaceModel DEVICE = torch.device("cpu") @@ -310,7 +309,7 @@ def mixed_double_sim_state( @pytest.fixture -def osn2_sim_state(ts_mace_mpa: MaceModel) -> ts.SimState: +def osn2_sim_state() -> ts.SimState: """Provides an initial SimState for rhombohedral OsN2.""" # For pymatgen Structure initialization from pymatgen.core import Lattice, Structure @@ -320,13 +319,11 @@ def osn2_sim_state(ts_mace_mpa: MaceModel) -> ts.SimState: species = ["Os", "N"] frac_coords = [[0.75, 0.7501, -0.25], [0, 0, 0]] # Slightly perturbed structure = Structure(lattice, species, frac_coords, coords_are_cartesian=False) - return ts.initialize_state( - structure, dtype=ts_mace_mpa.dtype, device=ts_mace_mpa.device - ) + return ts.initialize_state(structure, dtype=DTYPE, device=DEVICE) @pytest.fixture -def distorted_fcc_al_conventional_sim_state(ts_mace_mpa: MaceModel) -> ts.SimState: +def distorted_fcc_al_conventional_sim_state() -> ts.SimState: """Initial SimState for a slightly distorted FCC Al conventional cell (4 atoms).""" # Create a standard 4-atom conventional FCC Al cell atoms_fcc = bulk("Al", crystalstructure="fcc", a=4.05, cubic=True) @@ -344,7 +341,5 @@ def distorted_fcc_al_conventional_sim_state(ts_mace_mpa: MaceModel) -> ts.SimSta positions += np_rng.normal(scale=0.01, size=positions.shape) atoms_fcc.set_positions(positions) - dtype = ts_mace_mpa.dtype - device = ts_mace_mpa.device # Convert the ASE Atoms object to SimState (will be a single batch with 4 atoms) - return ts.io.atoms_to_state(atoms_fcc, device=device, dtype=dtype) + return ts.io.atoms_to_state(atoms_fcc, device=DEVICE, dtype=DTYPE) diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index c6559fb1c..84f46c283 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -52,8 +52,8 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator: model_fixture_name="orbv3_conservative_inf_omat_model", calculator_fixture_name="orbv3_conservative_inf_omat_calculator", sim_state_names=consistency_test_simstate_fixtures, - energy_rtol=1e-3, - energy_atol=1e-3, + energy_rtol=5e-5, + energy_atol=5e-5, ) test_orb_direct_consistency = make_model_calculator_consistency_test( @@ -61,8 +61,8 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator: model_fixture_name="orbv3_direct_20_omat_model", calculator_fixture_name="orbv3_direct_20_omat_calculator", sim_state_names=consistency_test_simstate_fixtures, - energy_rtol=1e-3, - energy_atol=1e-3, + energy_rtol=5e-5, + energy_atol=5e-5, ) test_validate_conservative_model_outputs = make_validate_model_outputs_test( diff --git a/tests/models/test_soft_sphere.py b/tests/models/test_soft_sphere.py index 1f89b8f48..99c2ed649 100644 --- a/tests/models/test_soft_sphere.py +++ b/tests/models/test_soft_sphere.py @@ -4,7 +4,7 @@ import torch import torch_sim as ts -import torch_sim.models.soft_sphere as fss +import torch_sim.models.soft_sphere as ss from tests.conftest import DEVICE from torch_sim.models.interface import validate_model_outputs @@ -23,8 +23,8 @@ def models( "compute_stress": True, } - model_nl = fss.SoftSphereModel(use_neighbor_list=True, **calc_params) - model_direct = fss.SoftSphereModel(use_neighbor_list=False, **calc_params) + model_nl = ss.SoftSphereModel(use_neighbor_list=True, **calc_params) + model_direct = ss.SoftSphereModel(use_neighbor_list=False, **calc_params) return model_nl(fe_supercell_sim_state), model_direct(fe_supercell_sim_state) @@ -45,8 +45,8 @@ def models_with_per_atom( "per_atom_stresses": True, } - model_nl = fss.SoftSphereModel(use_neighbor_list=True, **calc_params) - model_direct = fss.SoftSphereModel(use_neighbor_list=False, **calc_params) + model_nl = ss.SoftSphereModel(use_neighbor_list=True, **calc_params) + model_direct = ss.SoftSphereModel(use_neighbor_list=False, **calc_params) return model_nl(fe_supercell_sim_state), model_direct(fe_supercell_sim_state) @@ -122,8 +122,8 @@ def test_validate_model_outputs() -> None: "compute_stress": True, } - model_nl = fss.SoftSphereModel(use_neighbor_list=True, **model_params) - model_direct = fss.SoftSphereModel(use_neighbor_list=False, **model_params) + model_nl = ss.SoftSphereModel(use_neighbor_list=True, **model_params) + model_direct = ss.SoftSphereModel(use_neighbor_list=False, **model_params) for out in (model_nl, model_direct): validate_model_outputs(out, DEVICE, torch.float64) @@ -166,7 +166,7 @@ def test_soft_sphere_pair_single( distance: float, sigma: float, epsilon: float, alpha: float, expected: float ) -> None: """Test the soft sphere pair calculation for single values.""" - energy = fss.soft_sphere_pair( + energy = ss.soft_sphere_pair( torch.tensor(distance), torch.tensor(sigma), torch.tensor(epsilon), @@ -177,13 +177,13 @@ def test_soft_sphere_pair_single( def test_model_initialization_defaults() -> None: """Test initialization with default parameters.""" - model = fss.SoftSphereModel() + model = ss.SoftSphereModel() # Check default parameters are used - assert torch.allclose(model.sigma, fss.DEFAULT_SIGMA) - assert torch.allclose(model.epsilon, fss.DEFAULT_EPSILON) - assert torch.allclose(model.alpha, fss.DEFAULT_ALPHA) - assert torch.allclose(model.cutoff, fss.DEFAULT_SIGMA) # Default cutoff is sigma + assert torch.allclose(model.sigma, ss.DEFAULT_SIGMA) + assert torch.allclose(model.epsilon, ss.DEFAULT_EPSILON) + assert torch.allclose(model.alpha, ss.DEFAULT_ALPHA) + assert torch.allclose(model.cutoff, ss.DEFAULT_SIGMA) # Default cutoff is sigma @pytest.mark.parametrize( @@ -200,7 +200,7 @@ def test_model_initialization_custom_params( ) -> None: """Test initialization with custom parameters.""" params = {param_name: param_value, "dtype": expected_dtype} - model = fss.SoftSphereModel(**params) + model = ss.SoftSphereModel(**params) param_tensor = getattr(model, param_name) assert torch.allclose(param_tensor, torch.tensor(param_value, dtype=expected_dtype)) @@ -219,7 +219,7 @@ def test_model_initialization_custom_params( ) def test_model_initialization_custom_flags(*, flag_name: str, flag_value: bool) -> None: """Test initialization with custom flags.""" - model = fss.SoftSphereModel(**{flag_name: flag_value}) + model = ss.SoftSphereModel(**{flag_name: flag_value}) # For compute_forces and compute_stress, we need to check the private attributes if flag_name == "compute_forces": @@ -233,7 +233,7 @@ def test_model_initialization_custom_flags(*, flag_name: str, flag_value: bool) @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_model_dtype(dtype: torch.dtype) -> None: """Test model with different dtypes.""" - model = fss.SoftSphereModel(dtype=dtype) + model = ss.SoftSphereModel(dtype=dtype) assert model.sigma.dtype == dtype assert model.epsilon.dtype == dtype @@ -246,7 +246,7 @@ def test_multispecies_initialization_defaults() -> None: # Create with minimal parameters species = torch.tensor([0, 1], dtype=torch.long) dtype = torch.float32 - model = fss.SoftSphereMultiModel(species=species, dtype=dtype) + model = ss.SoftSphereMultiModel(species=species, dtype=dtype) # Check matrices are created with defaults assert model.sigma_matrix.shape == (2, 2) @@ -255,12 +255,12 @@ def test_multispecies_initialization_defaults() -> None: # Check default values ones = torch.ones(2, 2, dtype=dtype) - assert torch.allclose(model.sigma_matrix, fss.DEFAULT_SIGMA * ones) - assert torch.allclose(model.epsilon_matrix, fss.DEFAULT_EPSILON * ones) - assert torch.allclose(model.alpha_matrix, fss.DEFAULT_ALPHA * ones) + assert torch.allclose(model.sigma_matrix, ss.DEFAULT_SIGMA * ones) + assert torch.allclose(model.epsilon_matrix, ss.DEFAULT_EPSILON * ones) + assert torch.allclose(model.alpha_matrix, ss.DEFAULT_ALPHA * ones) # Check cutoff is max sigma - assert model.cutoff.item() == fss.DEFAULT_SIGMA.item() + assert model.cutoff.item() == ss.DEFAULT_SIGMA.item() def test_multispecies_initialization_custom() -> None: @@ -270,7 +270,7 @@ def test_multispecies_initialization_custom() -> None: epsilon_matrix = torch.tensor([[1.0, 0.5], [0.5, 1.5]], dtype=torch.float64) alpha_matrix = torch.tensor([[2.0, 3.0], [3.0, 4.0]], dtype=torch.float64) - model = fss.SoftSphereMultiModel( + model = ss.SoftSphereMultiModel( species=species, sigma_matrix=sigma_matrix, epsilon_matrix=epsilon_matrix, @@ -298,7 +298,7 @@ def test_multispecies_matrix_validation() -> None: # Should raise ValueError due to matrix size mismatch with pytest.raises(ValueError, match="sigma_matrix must have shape"): - fss.SoftSphereMultiModel( + ss.SoftSphereMultiModel( species=species, sigma_matrix=sigma_matrix, epsilon_matrix=epsilon_matrix, @@ -332,7 +332,7 @@ def test_matrix_symmetry_validation(matrix_name: str, matrix: torch.Tensor) -> N # Should raise ValueError due to asymmetric matrix with pytest.raises(ValueError, match="is not symmetric"): - fss.SoftSphereMultiModel(**params) + ss.SoftSphereMultiModel(**params) def test_multispecies_cutoff_default() -> None: @@ -341,7 +341,7 @@ def test_multispecies_cutoff_default() -> None: species = torch.tensor([0, 1, 2], dtype=torch.long) sigma_matrix = torch.tensor([[1.0, 1.5, 2.0], [1.5, 2.0, 2.5], [2.0, 2.5, 3.0]]) - model = fss.SoftSphereMultiModel(species=species, sigma_matrix=sigma_matrix) + model = ss.SoftSphereMultiModel(species=species, sigma_matrix=sigma_matrix) # Cutoff should default to max value in sigma_matrix assert model.cutoff.item() == 3.0 @@ -364,7 +364,7 @@ def test_multispecies_model_flags(*, flag_name: str, flag_value: bool) -> None: """Test flags of the SoftSphereMultiModel.""" species = torch.tensor([0, 1], dtype=torch.long) - model = fss.SoftSphereMultiModel(species=species, **{flag_name: flag_value}) + model = ss.SoftSphereMultiModel(species=species, **{flag_name: flag_value}) # For SoftSphereMultiModel, we don't need to convert attribute names # as it uses public attribute names for all flags From 38befe2de165892383359f763c8920b4c2cff865 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 27 Sep 2025 18:01:59 -0700 Subject: [PATCH 07/40] fix tests and add rng kwarg to ts.swap_mc_step used in tutorials/hybrid_swap_tutorial.py --- .github/workflows/test.yml | 10 +++------ examples/tutorials/hybrid_swap_tutorial.py | 6 ++++- tests/models/test_metatomic.py | 12 +++------- torch_sim/monte_carlo.py | 26 +++++++++++++--------- torch_sim/optimizers/gradient_descent.py | 4 ++-- 5 files changed, 29 insertions(+), 29 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0e59093cc..2a2b04922 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -90,25 +90,21 @@ jobs: - name: Set up uv uses: astral-sh/setup-uv@v6 - - name: Install fairchem repository and dependencies + - name: Login to @janosh's HuggingFace account to access fairchem models if: ${{ matrix.model.name == 'fairchem' }} env: HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} run: | - uv pip install huggingface_hub fairchem-core --system + uv pip install huggingface_hub --system if [ -n "$HF_TOKEN" ]; then hf auth login --token "$HF_TOKEN" else echo "HF_TOKEN is not set. Skipping login." fi - - name: Install torch_sim with model dependencies - if: ${{ matrix.model.name != 'fairchem' }} - run: | - uv pip install -e .[test,${{ matrix.model.name }}] --resolution=${{ matrix.version.resolution }} --system - - name: Run ${{ matrix.model.test_path }} tests run: | + uv pip install -e .[test,${{ matrix.model.name }}] --resolution=${{ matrix.version.resolution }} --system pytest --cov=torch_sim --cov-report=xml ${{ matrix.model.test_path }} - name: Upload coverage to Codecov diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index 74b8f8e5d..93c5942cb 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -149,12 +149,16 @@ class HybridSwapMCState(SwapMCState, MDState): - Make larger compositional changes through swap moves """ +# Create a persistent PRNG for reproducibility across the whole run +rng = torch.Generator(device=mace_model.device) +rng.manual_seed(seed=42) + # %% Run the hybrid simulation n_steps = 100 for step in range(n_steps): if step % 10 == 0: # Attempt swap Monte Carlo move hybrid_state = ts.swap_mc_step( - model=mace_model, state=hybrid_state, kT=kT, seed=42 + step + model=mace_model, state=hybrid_state, kT=kT, rng=rng ) else: # Perform MD step hybrid_state = ts.nvt_langevin_update( diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index ca360e015..fe6566d3f 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -21,22 +21,16 @@ @pytest.fixture def metatomic_calculator(): """Load a pretrained metatomic model for testing.""" + model_url = "https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt" return ase_calculator.MetatomicCalculator( - model=load_model( - "https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt" - ).export(), - device=DEVICE, - dtype=torch.float32, + model=load_model(model_url).export(), device=DEVICE ) @pytest.fixture def metatomic_model() -> MetatomicModel: """Create an MetatomicModel wrapper for the pretrained model.""" - return MetatomicModel( - model="pet-mad", - device=DEVICE, - ) + return MetatomicModel(model="pet-mad", device=DEVICE) def test_metatomic_initialization() -> None: diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 7db5213d7..9ba66a9ed 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -142,7 +142,7 @@ def metropolis_criterion( energy_new: torch.Tensor, energy_old: torch.Tensor, kT: float, - generator: torch.Generator | None = None, + rng: torch.Generator | None = None, ) -> torch.Tensor: """Apply the Metropolis acceptance criterion for Monte Carlo moves. @@ -153,7 +153,7 @@ def metropolis_criterion( energy_new (torch.Tensor): New energy after proposed move of shape [batch_size] energy_old (torch.Tensor): Old energy before proposed move of shape [batch_size] kT (float): Temperature of the system in energy units - generator (torch.Generator | None, optional): Random number generator for + rng (torch.Generator | None, optional): Random number generator for reproducibility. Defaults to None. Returns: @@ -171,7 +171,7 @@ def metropolis_criterion( # Generate random numbers between 0 and 1 using the generator random_values = torch.rand( - p_acceptance.shape, generator=generator, device=p_acceptance.device + p_acceptance.shape, generator=rng, device=p_acceptance.device ) # Accept if random value < acceptance probability @@ -225,6 +225,7 @@ def swap_mc_step( *, kT: float, seed: int | None = None, + rng: torch.Generator | None = None, ) -> SwapMCState: """Perform a single swap Monte Carlo step. @@ -236,7 +237,11 @@ def swap_mc_step( 'energy' as a key state: The current Monte Carlo state kT: Temperature parameter in energy units - seed: Seed for the random number generator. Defaults to None. + seed: (Deprecated) Seed for the random number generator. If provided and + `generator` is None, a temporary generator seeded with this value will + be used. + rng: Optional torch.Generator to drive all randomness for this step. + Prefer passing a persistent generator across steps for reproducibility. Returns: SwapMCState: Updated Monte Carlo state after applying the step @@ -245,12 +250,13 @@ def swap_mc_step( The function handles batched systems and ensures that swaps only occur within the same system. """ - generator = None - if seed is not None: - generator = torch.Generator(device=model.device) - generator.manual_seed(seed) + # Prefer explicit generator if provided; otherwise build one from seed + _rng = rng + if _rng is None and seed is not None: + _rng = torch.Generator(device=model.device) + _rng.manual_seed(seed) - swaps = generate_swaps(state, generator=generator) + swaps = generate_swaps(state, generator=_rng) permutation = swaps_to_permutation(swaps, state.n_atoms) @@ -263,7 +269,7 @@ def swap_mc_step( model_output = model(state) energies_new = model_output["energy"] - accepted = metropolis_criterion(energies_new, energies_old, kT, generator=generator) + accepted = metropolis_criterion(energies_new, energies_old, kT, generator=_rng) rejected_swaps = swaps[~accepted] reverse_rejected_swaps = swaps_to_permutation(rejected_swaps, state.n_atoms) state.positions = state.positions[reverse_rejected_swaps] diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index 8f3678045..252cc88ac 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -119,8 +119,8 @@ def gradient_descent_step( model_output = model(state) state.forces = model_output["forces"] state.energy = model_output["energy"] - if stress := model_output.get("stress"): - state.stress = stress + if "stress" in model_output: + state.stress = model_output["stress"] # Update cell forces if isinstance(state, CellOptimState): From 4059d36c1e1418151b46262802ac27d6c56bc0b7 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 27 Sep 2025 18:11:00 -0700 Subject: [PATCH 08/40] add fairchem-core to optional deps --- pyproject.toml | 1 + tests/models/test_metatomic.py | 2 ++ torch_sim/monte_carlo.py | 14 ++++++-------- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4443fbffc..66e7f9971 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ orb = ["orb-models>=0.5.2"] sevenn = ["sevenn>=0.11.0"] graphpes = ["graph-pes>=0.0.34", "mace-torch>=0.3.12"] nequip = ["nequip>=0.12.0"] +fairchem = ["fairchem-core>=2.7"] docs = [ "autodoc_pydantic==2.2.0", "furo==2024.8.6", diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index fe6566d3f..a4acaa67c 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -53,4 +53,6 @@ def test_metatomic_initialization() -> None: test_metatomic_model_outputs = make_validate_model_outputs_test( model_fixture_name="metatomic_model", + dtype=torch.float32, + device=DEVICE, ) diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 9ba66a9ed..115688cc7 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -45,9 +45,7 @@ class SwapMCState(SimState): _system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001 -def generate_swaps( - state: SimState, generator: torch.Generator | None = None -) -> torch.Tensor: +def generate_swaps(state: SimState, rng: torch.Generator | None = None) -> torch.Tensor: """Generate atom swaps for a given batched system. Generates proposed swaps between atoms of different types within the same system. @@ -56,7 +54,7 @@ def generate_swaps( Args: state (SimState): The simulation state - generator (torch.Generator | None, optional): Random number generator for + rng (torch.Generator | None, optional): Random number generator for reproducibility. Defaults to None. Returns: @@ -86,7 +84,7 @@ def generate_swaps( system_lengths_expanded = system_lengths.unsqueeze(1).expand(n_systems, max_length) weights = (range_tensor < system_lengths_expanded).float() - first_index = torch.multinomial(weights, 1, replacement=False, generator=generator) + first_index = torch.multinomial(weights, 1, replacement=False, generator=rng) # Process each system - we need this loop because of ragged systems system_starts = system_lengths.cumsum(dim=0) - system_lengths[0] @@ -106,7 +104,7 @@ def generate_swaps( # Zero out weights for same-type atoms (accounting for padding) weights[sys_idx, : len(same_type)][same_type] = 0.0 - second_index = torch.multinomial(weights, 1, replacement=False, generator=generator) + second_index = torch.multinomial(weights, 1, replacement=False, generator=rng) zeroed_swaps = torch.concatenate([first_index, second_index], dim=1) return zeroed_swaps + (system_lengths.cumsum(dim=0) - system_lengths[0]).unsqueeze(1) @@ -256,7 +254,7 @@ def swap_mc_step( _rng = torch.Generator(device=model.device) _rng.manual_seed(seed) - swaps = generate_swaps(state, generator=_rng) + swaps = generate_swaps(state, rng=_rng) permutation = swaps_to_permutation(swaps, state.n_atoms) @@ -269,7 +267,7 @@ def swap_mc_step( model_output = model(state) energies_new = model_output["energy"] - accepted = metropolis_criterion(energies_new, energies_old, kT, generator=_rng) + accepted = metropolis_criterion(energies_new, energies_old, kT, rng=_rng) rejected_swaps = swaps[~accepted] reverse_rejected_swaps = swaps_to_permutation(rejected_swaps, state.n_atoms) state.positions = state.positions[reverse_rejected_swaps] From d18cd80e2d61326b43b281d520731e91c6434d0e Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 27 Sep 2025 18:26:41 -0700 Subject: [PATCH 09/40] port hugginface login from CI to examples/scripts/1_Introduction/1.3_Fairchem.py --- .../scripts/1_Introduction/1.3_Fairchem.py | 30 +++++++++--- tests/models/test_metatomic.py | 2 + tests/test_monte_carlo.py | 49 +++++++++---------- 3 files changed, 49 insertions(+), 32 deletions(-) diff --git a/examples/scripts/1_Introduction/1.3_Fairchem.py b/examples/scripts/1_Introduction/1.3_Fairchem.py index 23175c033..06974026e 100644 --- a/examples/scripts/1_Introduction/1.3_Fairchem.py +++ b/examples/scripts/1_Introduction/1.3_Fairchem.py @@ -1,8 +1,11 @@ """Minimal FairChem example demonstrating batching.""" # /// script -# dependencies = ["fairchem-core>=2.2.0"] +# dependencies = ["fairchem-core>=2.2.0", "huggingface_hub"] # /// + +import os + import torch from ase.build import bulk @@ -10,6 +13,19 @@ from torch_sim.models.fairchem import FairChemModel +# Optional Hugging Face login if HF_TOKEN is available (for private model access) +try: + from huggingface_hub import login as hf_login # type: ignore[import-not-found] +except ImportError: # pragma: no cover - optional dependency + hf_login = None # type: ignore[assignment] + +hf_token = os.environ.get("HF_TOKEN") +if hf_token and hf_login is not None: + hf_login(token=hf_token) +else: + print("Need to login to HuggingFace to access fairchem models") + raise SystemExit(1) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 @@ -32,16 +48,16 @@ print(results["energy"].shape) print(results["forces"].shape) -if stress := results.get("stress"): - print(stress.shape) +if "stress" in results: + print(results["stress"].shape) print(f"Energy: {results['energy']}") print(f"Forces: {results['forces']}") -if stress := results.get("stress"): - print(f"{stress=}") +if "stress" in results: + print(f"{results['stress']=}") # Check if the energy, forces, and stress are the same for the Si system across the batch print(torch.max(torch.abs(results["energy"][0] - results["energy"][1]))) print(torch.max(torch.abs(results["forces"][0] - results["forces"][1]))) -if stress := results.get("stress"): - print(torch.max(torch.abs(stress[0] - stress[1]))) +if "stress" in results: + print(torch.max(torch.abs(results["stress"][0] - results["stress"][1]))) diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index a4acaa67c..c0dbbbd1c 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -49,6 +49,8 @@ def test_metatomic_initialization() -> None: calculator_fixture_name="metatomic_calculator", sim_state_names=consistency_test_simstate_fixtures, energy_atol=5e-5, + dtype=torch.float32, + device=DEVICE, ) test_metatomic_model_outputs = make_validate_model_outputs_test( diff --git a/tests/test_monte_carlo.py b/tests/test_monte_carlo.py index 2414e8b5b..c88951318 100644 --- a/tests/test_monte_carlo.py +++ b/tests/test_monte_carlo.py @@ -37,11 +37,11 @@ def batched_diverse_state() -> ts.SimState: @pytest.mark.parametrize("use_generator", [True, False]) def test_generate_swaps(batched_diverse_state: ts.SimState, *, use_generator: bool): """Test swap generation with and without generator.""" - generator = torch.Generator(device=DEVICE) if use_generator else None - if generator: - generator.manual_seed(42) + rng = torch.Generator(device=DEVICE) if use_generator else None + if rng: + rng.manual_seed(42) - swaps = generate_swaps(batched_diverse_state, generator=generator) + swaps = generate_swaps(batched_diverse_state, rng=rng) # Basic validation assert isinstance(swaps, torch.Tensor) @@ -59,9 +59,9 @@ def test_generate_swaps(batched_diverse_state: ts.SimState, *, use_generator: bo assert atomic_numbers[swap[0]] != atomic_numbers[swap[1]] # Test reproducibility with generator - if use_generator and generator is not None: - generator.manual_seed(42) - swaps2 = generate_swaps(batched_diverse_state, generator=generator) + if use_generator and rng is not None: + rng.manual_seed(42) + swaps2 = generate_swaps(batched_diverse_state, rng=rng) assert torch.equal(swaps, swaps2) @@ -69,15 +69,14 @@ def test_generate_swaps(batched_diverse_state: ts.SimState, *, use_generator: bo def test_swaps_to_permutation(batched_diverse_state: ts.SimState, *, n_swaps: int): """Test permutation generation with different numbers of swaps.""" n_atoms = batched_diverse_state.n_atoms - generator = torch.Generator(device=DEVICE) - generator.manual_seed(42) + rng = torch.Generator(device=DEVICE) + rng.manual_seed(42) if n_swaps == 0: combined_swaps = torch.empty((0, 2), dtype=torch.long, device=DEVICE) else: all_swaps = [ - generate_swaps(batched_diverse_state, generator=generator) - for _ in range(n_swaps) + generate_swaps(batched_diverse_state, rng=rng) for _ in range(n_swaps) ] combined_swaps = torch.cat(all_swaps, dim=0) @@ -122,12 +121,10 @@ def test_metropolis_criterion( assert abs(actual_rate - expected_rate) < 0.1 else: # Statistical test - generator = torch.Generator(device=DEVICE) - generator.manual_seed(42) + rng = torch.Generator(device=DEVICE) + rng.manual_seed(42) total_accepted = sum( - metropolis_criterion( - energy_new_tensor, energy_old_tensor, kT, generator=generator - ) + metropolis_criterion(energy_new_tensor, energy_old_tensor, kT, rng=rng) .sum() .item() for _ in range(1000) @@ -141,14 +138,14 @@ def test_metropolis_criterion_randomness(): energy_old = torch.tensor([10.0, 20.0], device=DEVICE) energy_new = torch.tensor([11.0, 21.0], device=DEVICE) # ~37% acceptance - gen1 = torch.Generator(device=DEVICE) - gen1.manual_seed(42) - gen2 = torch.Generator(device=DEVICE) - gen2.manual_seed(43) + rng1 = torch.Generator(device=DEVICE) + rng1.manual_seed(42) + rng2 = torch.Generator(device=DEVICE) + rng2.manual_seed(43) - accepted1 = metropolis_criterion(energy_new, energy_old, kT=1.0, generator=gen1) - accepted2 = metropolis_criterion(energy_new, energy_old, kT=1.0, generator=gen2) - accepted3 = metropolis_criterion(energy_new, energy_old, kT=1.0, generator=None) + accepted1 = metropolis_criterion(energy_new, energy_old, kT=1.0, rng=rng1) + accepted2 = metropolis_criterion(energy_new, energy_old, kT=1.0, rng=rng2) + accepted3 = metropolis_criterion(energy_new, energy_old, kT=1.0, rng=None) different_results = not torch.equal(accepted1, accepted2) or not torch.equal( accepted1, accepted3 @@ -166,6 +163,8 @@ def test_monte_carlo_integration( ): """Test the complete Monte Carlo workflow.""" # Initialize + rng = torch.Generator(device=DEVICE) + rng.manual_seed(42) mc_state = swap_mc_init(model=lj_model, state=batched_diverse_state) assert isinstance(mc_state, SwapMCState) assert mc_state.energy.shape == (batched_diverse_state.n_systems,) @@ -174,8 +173,8 @@ def test_monte_carlo_integration( assert torch.equal(mc_state.last_permutation, expected_identity) # Run steps - for step in range(n_steps): - mc_state = swap_mc_step(model=lj_model, state=mc_state, kT=kT, seed=42 + step) + for _step in range(n_steps): + mc_state = swap_mc_step(model=lj_model, state=mc_state, kT=kT, rng=rng) assert isinstance(mc_state, SwapMCState) # Verify conservation properties From a078d4476b25eb7901a94a16e2bbe399ea9ffb6f Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 27 Sep 2025 18:38:19 -0700 Subject: [PATCH 10/40] consistently name integrator step functions (nve|nvt|npt)_update -> (nve|nvt|npt)_step --- .../scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 2 +- .../3.11_Lennard_Jones_NPT_Langevin.py | 2 +- .../3_Dynamics/3.12_MACE_NPT_Langevin.py | 4 ++-- .../3_Dynamics/3.13_MACE_NVE_non_pbc.py | 2 +- .../3_Dynamics/3.1_Lennard_Jones_NVE.py | 2 +- examples/scripts/3_Dynamics/3.2_MACE_NVE.py | 2 +- .../scripts/3_Dynamics/3.3_MACE_NVE_cueq.py | 4 ++-- .../3_Dynamics/3.4_MACE_NVT_Langevin.py | 4 ++-- .../3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py | 2 +- .../3.6_MACE_NVT_Nose_Hoover_temp_profile.py | 2 +- .../3.7_Lennard_Jones_NPT_Nose_Hoover.py | 2 +- .../3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py | 4 ++-- .../3.9_MACE_NVT_staggered_stress.py | 2 +- .../4_High_level_api/4.2_auto_batching_api.py | 2 +- .../5_Workflow/5.1_a2c_silicon_batched.py | 2 +- .../7_Others/7.4_Velocity_AutoCorrelation.py | 2 +- .../7_Others/7.7_Heat_flux_and_kappa.py | 4 ++-- examples/tutorials/autobatching_tutorial.py | 2 +- examples/tutorials/hybrid_swap_tutorial.py | 2 +- examples/tutorials/low_level_tutorial.py | 2 +- tests/test_integrators.py | 12 +++++----- torch_sim/__init__.py | 10 ++++----- torch_sim/integrators/__init__.py | 22 +++++++++---------- torch_sim/integrators/npt.py | 4 ++-- torch_sim/integrators/nve.py | 2 +- torch_sim/integrators/nvt.py | 4 ++-- torch_sim/optimizers/cell_filters.py | 8 +++---- 27 files changed, 56 insertions(+), 56 deletions(-) diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index 3aaaa420c..f9b156c13 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -97,6 +97,6 @@ class HybridSwapMCState(ts.SwapMCState, MDState): model=model, state=hybrid_state, kT=kT, seed=42 + step ) else: - hybrid_state = ts.nvt_langevin_update( + hybrid_state = ts.nvt_langevin_step( model=model, state=hybrid_state, dt=dt, kT=torch.tensor(kT) ) diff --git a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py index 1a1c47939..281cfb295 100644 --- a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py @@ -131,7 +131,7 @@ f"{pressure=:.4f}, " f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" ) - state = ts.npt_langevin_update( + state = ts.npt_langevin_step( model=model, state=state, dt=dt, diff --git a/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py b/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py index 3bf1b6ee0..587322630 100644 --- a/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py @@ -69,7 +69,7 @@ ) invariant = float(ts.nvt_nose_hoover_invariant(state, kT=kT)) print(f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}, ") - state = ts.nvt_nose_hoover_update(model=model, state=state, dt=dt, kT=kT) + state = ts.nvt_nose_hoover_step(model=model, state=state, dt=dt, kT=kT) state = ts.npt_langevin_init(model=model, state=state, kT=kT, dt=dt, seed=1) @@ -101,7 +101,7 @@ f"pressure: {pressure:.4f}, " f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" ) - state = ts.npt_langevin_update( + state = ts.npt_langevin_step( model=model, state=state, dt=dt, diff --git a/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py b/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py index bd8228807..df1ee784f 100644 --- a/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py +++ b/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py @@ -70,7 +70,7 @@ ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") - state = ts.nve_update(model=model, state=state, dt=dt) + state = ts.nve_step(model=model, state=state, dt=dt) end_time = time.perf_counter() # Report simulation results diff --git a/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py b/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py index 54a3fab17..839d68d04 100644 --- a/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py +++ b/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py @@ -115,7 +115,7 @@ print(f"{step=}: Total energy: {total_energy.item():.4f}") # Update state using NVE integrator - state = ts.nve_update(model=model, state=state, dt=dt) + state = ts.nve_step(model=model, state=state, dt=dt) final_total_energy = state.energy + ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta diff --git a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py index d32aa23c9..0903c46f6 100644 --- a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py +++ b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py @@ -78,7 +78,7 @@ ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") - state = ts.nve_update(model=model, state=state, dt=dt) + state = ts.nve_step(model=model, state=state, dt=dt) end_time = time.perf_counter() # Report simulation results diff --git a/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py b/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py index f926df135..e3e7247ad 100644 --- a/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py +++ b/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py @@ -11,7 +11,7 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts -from torch_sim.integrators import nve_init, nve_update +from torch_sim.integrators import nve_init, nve_step from torch_sim.models.mace import MaceModel, MaceUrls from torch_sim.units import MetalUnits as Units @@ -69,7 +69,7 @@ ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") - state = nve_update(model=model, state=state, dt=dt) + state = nve_step(model=model, state=state, dt=dt) end_time = time.perf_counter() # Report simulation results diff --git a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py index 45dcff874..d270bc03f 100644 --- a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py @@ -10,7 +10,7 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts -from torch_sim.integrators import nvt_langevin_init, nvt_langevin_update +from torch_sim.integrators import nvt_langevin_init, nvt_langevin_step from torch_sim.models.mace import MaceModel, MaceUrls from torch_sim.units import MetalUnits as Units @@ -75,7 +75,7 @@ / Units.temperature ) print(f"{step=}: Temperature: {temp.item():.4f}") - state = nvt_langevin_update(model=model, state=state, dt=dt, kT=kT, gamma=gamma) + state = nvt_langevin_step(model=model, state=state, dt=dt, kT=kT, gamma=gamma) final_temp = ( ts.calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) diff --git a/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py index b9fd7864e..60a13b161 100644 --- a/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py @@ -68,7 +68,7 @@ ) invariant = float(ts.nvt_nose_hoover_invariant(state, kT=kT)) print(f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}") - state = ts.nvt_nose_hoover_update(model=model, state=state, dt=dt, kT=kT) + state = ts.nvt_nose_hoover_step(model=model, state=state, dt=dt, kT=kT) final_temp = ( ts.calc_kT(masses=state.masses, momenta=state.momenta, system_idx=state.system_idx) diff --git a/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py b/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py index d13daa8ca..52279c3a8 100644 --- a/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py +++ b/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py @@ -174,7 +174,7 @@ def get_kT( print(f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}") # Update simulation state - state = ts.nvt_nose_hoover_update( + state = ts.nvt_nose_hoover_step( model=model, state=state, dt=dt, kT=current_kT * Units.temperature ) diff --git a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py index c1d4a5db9..f64277195 100644 --- a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py @@ -140,7 +140,7 @@ f"{invariant=:.4f}, {pressure=:.4f}, " f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" ) - state = ts.npt_nose_hoover_update( + state = ts.npt_nose_hoover_step( model=model, state=state, dt=dt, kT=kT, external_pressure=target_pressure ) diff --git a/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py index 97d169de7..0aa9e4f87 100644 --- a/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py @@ -71,7 +71,7 @@ ts.npt_nose_hoover_invariant(state, kT=kT, external_pressure=target_pressure) ) print(f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}, ") - state = ts.npt_nose_hoover_update( + state = ts.npt_nose_hoover_step( model=model, state=state, dt=torch.tensor(dt), @@ -104,7 +104,7 @@ f"{pressure=:.4f}, " f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" ) - state = ts.npt_nose_hoover_update( + state = ts.npt_nose_hoover_step( model=model, state=state, dt=torch.tensor(dt), diff --git a/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py b/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py index 6f5ae7b26..7f993362d 100644 --- a/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py +++ b/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py @@ -71,7 +71,7 @@ invariant = float(kinetic_energy + state.energy) print(f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}") - state = ts.nvt_langevin_update(model=model, state=state, dt=torch.tensor(dt), kT=kT) + state = ts.nvt_langevin_step(model=model, state=state, dt=torch.tensor(dt), kT=kT) if step % 10 == 0: results = model(state) stress[step // 10] = results["stress"] diff --git a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py index b3336a1cc..71005a00d 100644 --- a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py @@ -116,6 +116,6 @@ finished_states: list[ts.SimState] = [] for batch, _indices in batcher: for _ in range(100): - batch = ts.nvt_langevin_update(model=mace_model, state=batch) + batch = ts.nvt_langevin_step(model=mace_model, state=batch) finished_states.extend(batch.split()) diff --git a/examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py b/examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py index 9aae48b09..cf31d5914 100644 --- a/examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py +++ b/examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py @@ -142,7 +142,7 @@ def step_fn( state, kT=torch.tensor(current_temp * Units.temperature, device=device, dtype=dtype), ).item() - state = ts.nvt_nose_hoover_update( + state = ts.nvt_nose_hoover_step( model=model, state=state, dt=dt, diff --git a/examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py b/examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py index 50bb17a70..b58a1aad2 100644 --- a/examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py +++ b/examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py @@ -93,7 +93,7 @@ def main() -> None: num_steps = 15000 # NOTE: short run for step in range(num_steps): - state = ts.nve_update(model=lj_model, state=state, dt=dt) # type: ignore[call-arg] + state = ts.nve_step(model=lj_model, state=state, dt=dt) # type: ignore[call-arg] reporter.report(state, step) reporter.close() diff --git a/examples/scripts/7_Others/7.7_Heat_flux_and_kappa.py b/examples/scripts/7_Others/7.7_Heat_flux_and_kappa.py index bd2d0ff53..f75beecb5 100644 --- a/examples/scripts/7_Others/7.7_Heat_flux_and_kappa.py +++ b/examples/scripts/7_Others/7.7_Heat_flux_and_kappa.py @@ -66,7 +66,7 @@ heat_flux = torch.zeros((num_steps_equilibration, 3), device=device, dtype=dtype) for step in range(num_steps_equilibration): - state = ts.nvt_langevin_update(model=lj_model, state=state, dt=dt, kT=kT) + state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) results = lj_model(state) J = ts.quantities.calc_heat_flux( momenta=state.momenta, @@ -103,7 +103,7 @@ # Short production run for step in range(num_steps_production): - state = ts.nvt_langevin_update(model=lj_model, state=state, dt=dt, kT=kT) + state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) reporter.report(state, step) if step % 1000 == 0: print(f"Step {step} | {state.energy.item():.4f} eV") diff --git a/examples/tutorials/autobatching_tutorial.py b/examples/tutorials/autobatching_tutorial.py index 317567dcd..b0eedb90d 100644 --- a/examples/tutorials/autobatching_tutorial.py +++ b/examples/tutorials/autobatching_tutorial.py @@ -195,7 +195,7 @@ def process_batch(batch): for batch, _indices in batcher: # Run 5 steps of NVT dynamics for _ in range(5): - batch = ts.nvt_langevin_update(mace_model, batch, dt=0.001, kT=0.01) + batch = ts.nvt_langevin_step(mace_model, batch, dt=0.001, kT=0.01) finished_states.append(batch) diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index 93c5942cb..c59072e0f 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -161,7 +161,7 @@ class HybridSwapMCState(SwapMCState, MDState): model=mace_model, state=hybrid_state, kT=kT, rng=rng ) else: # Perform MD step - hybrid_state = ts.nvt_langevin_update( + hybrid_state = ts.nvt_langevin_step( model=mace_model, state=hybrid_state, dt=0.002, kT=kT ) diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index c12fa8937..68e7c7a39 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -209,7 +209,7 @@ initial_kT = kT for step in range(30): current_kT = initial_kT * (1 + step / 30) - state = ts.nvt_langevin_update( + state = ts.nvt_langevin_step( model=model, state=state, dt=dt, kT=current_kT, gamma=gamma ) if step % 5 == 0: diff --git a/tests/test_integrators.py b/tests/test_integrators.py index fd39ec58e..aad4ed28b 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -97,7 +97,7 @@ def test_npt_langevin( energies = [] temperatures = [] for _step in range(n_steps): - state = ts.npt_langevin_update( + state = ts.npt_langevin_step( model=lj_model, state=state, dt=dt, @@ -168,7 +168,7 @@ def test_npt_langevin_multi_kt( energies = [] temperatures = [] for _step in range(n_steps): - state = ts.npt_langevin_update( + state = ts.npt_langevin_step( model=lj_model, state=state, dt=dt, @@ -214,7 +214,7 @@ def test_nvt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo energies = [] temperatures = [] for _step in range(n_steps): - state = ts.nvt_langevin_update(model=lj_model, state=state, dt=dt, kT=kT) + state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) # Calculate instantaneous temperature from kinetic energy temp = ts.calc_kT( @@ -270,7 +270,7 @@ def test_nvt_langevin_multi_kt( energies = [] temperatures = [] for _step in range(n_steps): - state = ts.nvt_langevin_update(model=lj_model, state=state, dt=dt, kT=kT) + state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) # Calculate instantaneous temperature from kinetic energy temp = ts.calc_kT( @@ -306,7 +306,7 @@ def test_nve(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): # Run dynamics for several steps energies = [] for _step in range(n_steps): - state = ts.nve_update(model=lj_model, state=state, dt=dt) + state = ts.nve_step(model=lj_model, state=state, dt=dt) energies.append(state.energy) @@ -352,7 +352,7 @@ def test_compare_single_vs_batched_integrators( state.momenta = torch.zeros_like(state.momenta) # Start from rest for _step in range(n_steps): - state = ts.nve_update(model=lj_model, state=state, dt=dt) + state = ts.nve_step(model=lj_model, state=state, dt=dt) final_states[state_name] = state diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index ff8e40c93..5ebbe10bd 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -27,21 +27,21 @@ MdFlavor, NVTNoseHooverState, nve_init, - nve_update, + nve_step, nvt_langevin_init, - nvt_langevin_update, + nvt_langevin_step, nvt_nose_hoover_init, nvt_nose_hoover_invariant, - nvt_nose_hoover_update, + nvt_nose_hoover_step, ) from torch_sim.integrators.npt import ( NPTLangevinState, NPTNoseHooverState, npt_langevin_init, - npt_langevin_update, + npt_langevin_step, npt_nose_hoover_init, npt_nose_hoover_invariant, - npt_nose_hoover_update, + npt_nose_hoover_step, ) from torch_sim.monte_carlo import SwapMCState, swap_mc_init, swap_mc_step from torch_sim.optimizers import ( diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index f996d3778..ca56404be 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -9,7 +9,7 @@ >>> import torch_sim as ts >>> state = ts.nvt_langevin_init(model, initial_state, kT=300.0 * units.temperature) >>> for _ in range(1000): - ... state = ts.nvt_langevin_update( + ... state = ts.nvt_langevin_step( ... model, state, dt=1e-3 * units.time, kT=300.0 * units.temperature ... ) @@ -30,19 +30,19 @@ NPTLangevinState, NPTNoseHooverState, npt_langevin_init, - npt_langevin_update, + npt_langevin_step, npt_nose_hoover_init, npt_nose_hoover_invariant, - npt_nose_hoover_update, + npt_nose_hoover_step, ) -from .nve import nve_init, nve_update +from .nve import nve_init, nve_step from .nvt import ( NVTNoseHooverState, nvt_langevin_init, - nvt_langevin_update, + nvt_langevin_step, nvt_nose_hoover_init, nvt_nose_hoover_invariant, - nvt_nose_hoover_update, + nvt_nose_hoover_step, ) @@ -60,9 +60,9 @@ class MdFlavor(StrEnum): INTEGRATOR_REGISTRY: Final[ dict[MdFlavor, tuple[Callable[..., Any], Callable[..., Any]]] ] = { - MdFlavor.nve: (nve_init, nve_update), - MdFlavor.nvt_langevin: (nvt_langevin_init, nvt_langevin_update), - MdFlavor.nvt_nose_hoover: (nvt_nose_hoover_init, nvt_nose_hoover_update), - MdFlavor.npt_langevin: (npt_langevin_init, npt_langevin_update), - MdFlavor.npt_nose_hoover: (npt_nose_hoover_init, npt_nose_hoover_update), + MdFlavor.nve: (nve_init, nve_step), + MdFlavor.nvt_langevin: (nvt_langevin_init, nvt_langevin_step), + MdFlavor.nvt_nose_hoover: (nvt_nose_hoover_init, nvt_nose_hoover_step), + MdFlavor.npt_langevin: (npt_langevin_init, npt_langevin_step), + MdFlavor.npt_nose_hoover: (npt_nose_hoover_init, npt_nose_hoover_step), } diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index fb4f6d914..86bebc616 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -642,7 +642,7 @@ def npt_langevin_init( ) -def npt_langevin_update( +def npt_langevin_step( model: ModelInterface, state: NPTLangevinState, *, @@ -1454,7 +1454,7 @@ def npt_nose_hoover_init( return npt_state -def npt_nose_hoover_update( +def npt_nose_hoover_step( model: ModelInterface, state: NPTNoseHooverState, *, diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index 49509567e..53b5e31c9 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -70,7 +70,7 @@ def nve_init( ) -def nve_update( +def nve_step( model: ModelInterface, state: MDState, *, dt: torch.Tensor, **_kwargs: Any ) -> MDState: """Perform one complete NVE (microcanonical) integration step. diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 59ef2b8fd..1c64e1d7f 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -132,7 +132,7 @@ def nvt_langevin_init( ) -def nvt_langevin_update( +def nvt_langevin_step( model: ModelInterface, state: MDState, *, @@ -330,7 +330,7 @@ def nvt_nose_hoover_init( ) -def nvt_nose_hoover_update( +def nvt_nose_hoover_step( model: ModelInterface, state: NVTNoseHooverState, *, diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py index 0aaa8fe28..653fc5a81 100644 --- a/torch_sim/optimizers/cell_filters.py +++ b/torch_sim/optimizers/cell_filters.py @@ -199,7 +199,7 @@ class CellFilter(StrEnum): # Filter type definitions for convenience -def unit_cell_update[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) -> None: +def unit_cell_step[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) -> None: """Update cell using unit cell approach.""" if isinstance(cell_lr, (int, float)): cell_lr = torch.full( @@ -228,7 +228,7 @@ def unit_cell_update[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) - state.cell_positions = cell_positions_new -def frechet_cell_update[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) -> None: +def frechet_cell_step[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) -> None: """Update cell using frechet approach.""" if isinstance(cell_lr, (int, float)): cell_lr = torch.full( @@ -315,8 +315,8 @@ def compute_cell_forces[T: AnyCellState]( CellFilterFuncs = tuple[Callable[..., None], Callable[..., None]] # (init_fn, update_fn) CELL_FILTER_REGISTRY: dict[CellFilter, CellFilterFuncs] = { - CellFilter.unit: (unit_cell_filter_init, unit_cell_update), - CellFilter.frechet: (frechet_cell_filter_init, frechet_cell_update), + CellFilter.unit: (unit_cell_filter_init, unit_cell_step), + CellFilter.frechet: (frechet_cell_filter_init, frechet_cell_step), } From bd55620ee1872f0cf4f156268e1f4681dd3c29dd Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 27 Sep 2025 18:40:42 -0700 Subject: [PATCH 11/40] str(dtype).lstrip("torch.") -> str(dtype).removeprefix("torch.") --- examples/scripts/1_Introduction/1.2_MACE.py | 2 +- .../2.3_MACE_Gradient_Descent.py | 2 +- .../2.4_MACE_FIRE.py | 2 +- ....5_MACE_UnitCellFilter_Gradient_Descent.py | 2 +- .../2.6_MACE_UnitCellFilter_FIRE.py | 2 +- .../2.7_MACE_FrechetCellFilter_FIRE.py | 2 +- .../scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 5 +- .../3_Dynamics/3.12_MACE_NPT_Langevin.py | 2 +- .../3_Dynamics/3.13_MACE_NVE_non_pbc.py | 2 +- examples/scripts/3_Dynamics/3.2_MACE_NVE.py | 2 +- .../scripts/3_Dynamics/3.3_MACE_NVE_cueq.py | 2 +- .../3_Dynamics/3.4_MACE_NVT_Langevin.py | 2 +- .../3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py | 2 +- .../3.6_MACE_NVT_Nose_Hoover_temp_profile.py | 2 +- .../3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py | 2 +- .../3.9_MACE_NVT_staggered_stress.py | 2 +- .../5_Workflow/5.1_a2c_silicon_batched.py | 2 +- examples/scripts/5_Workflow/5.3_Elastic.py | 2 +- .../scripts/6_Phonons/6.1_Phonons_MACE.py | 2 +- .../6_Phonons/6.2_QuasiHarmonic_MACE.py | 2 +- .../6_Phonons/6.3_Conductivity_MACE.py | 2 +- .../7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 48 +++++++++---------- examples/tutorials/low_level_tutorial.py | 2 +- tests/models/test_mace.py | 4 +- 24 files changed, 48 insertions(+), 51 deletions(-) diff --git a/examples/scripts/1_Introduction/1.2_MACE.py b/examples/scripts/1_Introduction/1.2_MACE.py index ee3f23f26..d91d42b35 100644 --- a/examples/scripts/1_Introduction/1.2_MACE.py +++ b/examples/scripts/1_Introduction/1.2_MACE.py @@ -20,7 +20,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py b/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py index 37c6c6c21..0a69d7a7c 100644 --- a/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py +++ b/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py @@ -21,7 +21,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py b/examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py index c152385ac..3e326dd64 100644 --- a/examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py @@ -22,7 +22,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py b/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py index c7b6bbe09..f790ca903 100644 --- a/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py +++ b/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py @@ -23,7 +23,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py index 60044fc2f..53a564a76 100644 --- a/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py @@ -23,7 +23,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py index 56c8eeb6e..bc8256560 100644 --- a/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py @@ -23,7 +23,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index f9b156c13..cebfb9739 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -17,15 +17,13 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 - - kT = 1000 * Units.temperature # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) @@ -41,7 +39,6 @@ enable_cueq=False, ) - # %% lattice = [[5.43, 0, 0], [0, 5.43, 0], [0, 0, 5.43]] species = ["Cu", "Cu", "Cu", "Zr", "Cu", "Zr", "Zr", "Zr"] diff --git a/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py b/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py index 587322630..4d8539da9 100644 --- a/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py @@ -22,7 +22,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py b/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py index df1ee784f..0ee349d4f 100644 --- a/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py +++ b/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py @@ -23,7 +23,7 @@ loaded_model = mace_off( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py index 0903c46f6..7fe0ed262 100644 --- a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py +++ b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py @@ -24,7 +24,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py b/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py index e3e7247ad..f21578a38 100644 --- a/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py +++ b/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py @@ -24,7 +24,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py index d270bc03f..7530fc14b 100644 --- a/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py @@ -23,7 +23,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py index 60a13b161..9f7f9274d 100644 --- a/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py @@ -22,7 +22,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py b/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py index 52279c3a8..03e474f05 100644 --- a/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py +++ b/examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py @@ -78,7 +78,7 @@ def get_kT( loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py index 0aa9e4f87..76dbe6627 100644 --- a/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py @@ -22,7 +22,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py b/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py index 7f993362d..b9033f4ba 100644 --- a/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py +++ b/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py @@ -23,7 +23,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py b/examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py index cf31d5914..25eba0fdf 100644 --- a/examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py +++ b/examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py @@ -55,7 +55,7 @@ raw_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/5_Workflow/5.3_Elastic.py b/examples/scripts/5_Workflow/5.3_Elastic.py index 891439703..6f24851bf 100644 --- a/examples/scripts/5_Workflow/5.3_Elastic.py +++ b/examples/scripts/5_Workflow/5.3_Elastic.py @@ -21,7 +21,7 @@ model=MaceUrls.mace_mpa_medium, enable_cueq=False, device=str(device), - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), return_raw_model=True, ) diff --git a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py index 19e60d25e..43a8bf844 100644 --- a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py +++ b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py @@ -93,7 +93,7 @@ def get_labels_qpts(ph: Phonopy, n_points: int = 101) -> tuple[list[str], list[b loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) diff --git a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py index 1031245fc..eb209dd2d 100644 --- a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py +++ b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py @@ -219,7 +219,7 @@ def get_qha_phonons( loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) model = MaceModel( diff --git a/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py b/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py index 34d36755f..bbc8ae04e 100644 --- a/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py +++ b/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py @@ -59,7 +59,7 @@ def print_relax_info(trajectory_file: str, device: torch.device) -> None: loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=str(device), ) model = MaceModel( diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index 819ebb68f..027c7a4a2 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -28,16 +28,16 @@ # Set device, data type and unit conversion SMOKE_TEST = os.getenv("CI") is not None -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -dtype = torch.float32 +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DTYPE = torch.float32 unit_conv = ts.units.UnitConversion # Option 1: Load the raw model from the downloaded model loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), - device=str(device), + default_dtype=str(DTYPE).removeprefix("torch."), + device=str(DEVICE), ) # Number of steps to run @@ -109,15 +109,15 @@ # Create batched model model = MaceModel( model=loaded_model, - device=device, + device=DEVICE, compute_forces=True, compute_stress=True, - dtype=dtype, + dtype=DTYPE, enable_cueq=False, ) # Convert atoms to state -state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) +state = ts.io.atoms_to_state(atoms_list, device=DEVICE, dtype=DTYPE) # Run initial inference initial_energies = model(state)["energy"] @@ -165,13 +165,13 @@ def run_optimization_ts( # noqa: PLR0915 total_structures = opt_state.n_systems convergence_steps = torch.full( - (total_structures,), -1, dtype=torch.long, device=device + (total_structures,), -1, dtype=torch.long, device=DEVICE ) convergence_fn = ts.generate_force_convergence_fn( force_tol=force_tol, include_cell_forces=ts_use_frechet ) converged_tensor_global = torch.zeros( - total_structures, dtype=torch.bool, device=device + total_structures, dtype=torch.bool, device=DEVICE ) global_step = 0 all_converged_states = [] @@ -190,7 +190,7 @@ def run_optimization_ts( # noqa: PLR0915 last_active_state = opt_state current_indices = torch.tensor( - batcher.current_idx, dtype=torch.long, device=device + batcher.current_idx, dtype=torch.long, device=DEVICE ) steps_this_round = 1 @@ -281,8 +281,8 @@ def run_optimization_ase( # noqa: C901, PLR0915 ase_calc_instance = mace_mp_calculator_for_ase( model=MaceUrls.mace_mpa_medium, - device=device, - default_dtype=str(dtype).split(".")[-1], + device=DEVICE, + default_dtype=str(DTYPE).removeprefix("torch."), ) ase_atoms_orig.calc = ase_calc_instance @@ -330,22 +330,22 @@ def run_optimization_ase( # noqa: C901, PLR0915 current_atom_offset = 0 for sys_idx, ats_final in enumerate(final_ase_atoms_list): all_positions.append( - torch.tensor(ats_final.get_positions(), device=device, dtype=dtype) + torch.tensor(ats_final.get_positions(), device=DEVICE, dtype=DTYPE) ) all_masses.append( - torch.tensor(ats_final.get_masses(), device=device, dtype=dtype) + torch.tensor(ats_final.get_masses(), device=DEVICE, dtype=DTYPE) ) all_atomic_numbers.append( - torch.tensor(ats_final.get_atomic_numbers(), device=device, dtype=torch.long) + torch.tensor(ats_final.get_atomic_numbers(), device=DEVICE, dtype=torch.long) ) # ASE cell is row-vector, SimState expects column-vector all_cells.append( - torch.tensor(ats_final.get_cell().array.T, device=device, dtype=dtype) + torch.tensor(ats_final.get_cell().array.T, device=DEVICE, dtype=DTYPE) ) num_atoms_in_current = len(ats_final) all_systems_for_gd.append( - torch.full((num_atoms_in_current,), sys_idx, device=device, dtype=torch.long) + torch.full((num_atoms_in_current,), sys_idx, device=DEVICE, dtype=torch.long) ) current_atom_offset += num_atoms_in_current @@ -357,13 +357,13 @@ def run_optimization_ase( # noqa: C901, PLR0915 ) temp_calc = mace_mp_calculator_for_ase( model=MaceUrls.mace_mpa_medium, - device=device, - default_dtype=str(dtype).split(".")[-1], + device=DEVICE, + default_dtype=str(DTYPE).removeprefix("torch."), ) ats_final.calc = temp_calc final_energies_ase.append(ats_final.get_potential_energy()) final_forces_ase_tensors.append( - torch.tensor(ats_final.get_forces(), device=device, dtype=dtype) + torch.tensor(ats_final.get_forces(), device=DEVICE, dtype=DTYPE) ) except Exception as exc: # noqa: BLE001 print(f"Couldn't get final energy/forces for ASE structure {sys_idx}: {exc}") @@ -372,12 +372,12 @@ def run_optimization_ase( # noqa: C901, PLR0915 final_forces_ase_tensors.append(torch.zeros_like(all_positions[-1])) else: final_forces_ase_tensors.append( - torch.empty((0, 3), device=device, dtype=dtype) + torch.empty((0, 3), device=DEVICE, dtype=DTYPE) ) if not all_positions: # If all optimizations failed early print("Warning: No successful ASE structures to form OptimState.") - return torch.tensor(convergence_steps_list, dtype=torch.long, device=device), None + return torch.tensor(convergence_steps_list, dtype=torch.long, device=DEVICE), None # Concatenate all parts concatenated_positions = torch.cat(all_positions, dim=0) @@ -386,7 +386,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 concatenated_cells = torch.stack(all_cells, dim=0) # Cells are (n_systems, 3, 3) concatenated_system_indices = torch.cat(all_systems_for_gd, dim=0) - concatenated_energies = torch.tensor(final_energies_ase, device=device, dtype=dtype) + concatenated_energies = torch.tensor(final_energies_ase, device=DEVICE, dtype=DTYPE) concatenated_forces = torch.cat(final_forces_ase_tensors, dim=0) # Check for NaN energies which might cause issues @@ -409,7 +409,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 ) convergence_steps = torch.tensor( - convergence_steps_list, dtype=torch.long, device=device + convergence_steps_list, dtype=torch.long, device=DEVICE ) end_time = time.perf_counter() diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index 68e7c7a39..88257a318 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -66,7 +66,7 @@ loaded_model = mace_mp( model=MaceUrls.mace_mpa_medium, return_raw_model=True, - default_dtype=str(dtype).lstrip("torch."), + default_dtype=str(dtype).removeprefix("torch."), device=device, ) diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 50246ecb0..8276381a8 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -28,7 +28,7 @@ @pytest.fixture def ase_mace_calculator() -> MACECalculator: - dtype = str(DTYPE).lstrip("torch.") + dtype = str(DTYPE).removeprefix("torch.") return mace_mp( model=MaceUrls.mace_mp_small, device="cpu", default_dtype=dtype, dispersion=False ) @@ -87,7 +87,7 @@ def ase_mace_off_calculator() -> MACECalculator: return mace_off( model=MaceUrls.mace_off_small, device=str(DEVICE), - default_dtype=str(DTYPE).lstrip("torch."), + default_dtype=str(DTYPE).removeprefix("torch."), dispersion=False, ) From 7a4be59efa9fe0a4bab0a3400839bce74a8a2043 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sun, 28 Sep 2025 11:19:46 -0700 Subject: [PATCH 12/40] address @orionarcher comments --- README.md | 20 ++++++++++++++++--- docs/conf.py | 2 +- docs/reference/index.rst | 2 +- docs/tutorials/index.rst | 2 +- .../scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 13 ++++++------ 5 files changed, 27 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index ef1025167..f5c52c54c 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,20 @@ relaxed_state = ts.optimize( print(relaxed_state.energy) ``` +## Speedup + +TorchSim achieves up to 100x speedup compared to ASE with popular MLIPs. + +
Speedup comparison + +This figure compares the time per atom of ASE and `torch_sim`. Time per atom is defined +as the number of atoms / total time. While ASE can only run a single system of `n_atoms` +(on the $x$ axis), `torch_sim` can run as many systems as will fit in memory. On an H100 80 GB card, +the max atoms that could fit in memory was ~8,000 for [EGIP](https://github.com/FAIR-Chem/fairchem), +~10,000 for [MACE-MPA-0](https://github.com/ACEsuit/mace), ~22,000 for [Mattersim V1 1M](https://github.com/microsoft/mattersim), +~2,500 for [SevenNet](https://github.com/MDIL-SNU/SevenNet), and ~9000 for [PET-MAD](https://github.com/lab-cosmo/pet-mad). +This metric describes model performance by capturing speed and memory usage simultaneously. + ## Installation ### PyPI Installation @@ -100,7 +114,7 @@ pip install torch-sim-atomistic ### Installing from source ```sh -git clone https://github.com/torchsim/torch-sim +git clone https://github.com/TorchSim/torch-sim cd torch-sim pip install . ``` @@ -113,11 +127,11 @@ To understand how TorchSim works, start with the [comprehensive tutorials](https TorchSim's package structure is summarized in the [API reference](https://torchsim.github.io/torch-sim/reference/index.html) documentation and drawn as a treemap below. -![TorchSim package treemap](https://github.com/user-attachments/assets/56f894ad-b995-4108-a6de-a48714276d89) +![TorchSim package treemap](https://github.com/user-attachments/assets/1ccb3a15-233d-4bc0-b11c-35a676a2bcf3) ## License -TorchSim is released under an [MIT license](license). +TorchSim is released under an [MIT license](LICENSE). ## Citation diff --git a/docs/conf.py b/docs/conf.py index d2fbc6fce..218c91442 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -17,7 +17,7 @@ # -- Project information ----------------------------------------------------- -project = "torch_sim" +project = "torch-sim-atomistic" copyright = "2025, Project TorchSim" # noqa: A001 author = "Abhijeet Gangan, Orion Cohen, Janosh Riebesell" diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 2ac6fcaa5..a21d7418b 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -5,7 +5,7 @@ API reference Overview of the TorchSim API. -.. currentmodule:: torch-sim +.. currentmodule:: torch_sim .. autosummary:: :recursive: diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index 7ea9a8370..a240df48f 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -7,7 +7,7 @@ For a high-level overview of the tutorials see :doc:`../user/overview`. Runnable versions of the tutorials can also be found in the `torch-sim /examples/tutorials `_ directory. -.. currentmodule:: torch-sim +.. currentmodule:: torch_sim .. toctree:: :titlesonly: diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index cebfb9739..19f8dd48c 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -69,7 +69,10 @@ class HybridSwapMCState(ts.SwapMCState, MDState): last_permutation: torch.Tensor _atom_attributes = ( - ts.SwapMCState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ts.SwapMCState._atom_attributes | MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ) + _system_attributes = ( + ts.SwapMCState._system_attributes | MDState._system_attributes # noqa: SLF001 ) @@ -83,16 +86,14 @@ class HybridSwapMCState(ts.SwapMCState, MDState): ), ) -generator = torch.Generator(device=device) -generator.manual_seed(42) +rng = torch.Generator(device=device) +rng.manual_seed(42) n_steps = 100 dt = torch.tensor(0.002) for step in range(n_steps): if step % 10 == 0: - hybrid_state = ts.swap_mc_step( - model=model, state=hybrid_state, kT=kT, seed=42 + step - ) + hybrid_state = ts.swap_mc_step(model=model, state=hybrid_state, kT=kT, rng=rng) else: hybrid_state = ts.nvt_langevin_step( model=model, state=hybrid_state, dt=dt, kT=torch.tensor(kT) From 691bad32e5251e4e0f79dc6f05a05b95e55209ea Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sun, 28 Sep 2025 11:30:53 -0700 Subject: [PATCH 13/40] fix thermostat init in examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py, @abhijeetgangan look ok to you? --- .../scripts/1_Introduction/1.3_Fairchem.py | 2 +- examples/tutorials/diff_sim.py | 54 +++++++++---------- examples/tutorials/hybrid_swap_tutorial.py | 2 +- torch_sim/integrators/npt.py | 13 +++-- 4 files changed, 35 insertions(+), 36 deletions(-) diff --git a/examples/scripts/1_Introduction/1.3_Fairchem.py b/examples/scripts/1_Introduction/1.3_Fairchem.py index 06974026e..2bcdadc2a 100644 --- a/examples/scripts/1_Introduction/1.3_Fairchem.py +++ b/examples/scripts/1_Introduction/1.3_Fairchem.py @@ -24,7 +24,7 @@ hf_login(token=hf_token) else: print("Need to login to HuggingFace to access fairchem models") - raise SystemExit(1) + raise SystemExit(0) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 diff --git a/examples/tutorials/diff_sim.py b/examples/tutorials/diff_sim.py index 6db460893..c74d5899e 100644 --- a/examples/tutorials/diff_sim.py +++ b/examples/tutorials/diff_sim.py @@ -277,7 +277,7 @@ def species_sigma(diameter: torch.Tensor) -> torch.Tensor: species = torch.tensor([0] * (N_2) + [1] * (N_2), dtype=torch.int32) simulation_steps = 1000 packing_fraction = 0.98 -markersize = 260 +marker_size = 260 # %% @@ -319,13 +319,13 @@ def simulation( plt.subplot(1, 2, 1) box_size, raft_energy, bubble_positions = simulation(torch.tensor(1.0)) -draw_system(bubble_positions, box_size.numpy(), markersize) +draw_system(bubble_positions, box_size.numpy(), marker_size) plt.subplot(1, 2, 2) box_size, raft_energy, bubble_positions = simulation(torch.tensor(0.8)) -draw_system(bubble_positions[:N_2], box_size.numpy(), 0.8 * markersize) -draw_system(bubble_positions[N_2:], box_size.numpy(), markersize) +draw_system(bubble_positions[:N_2], box_size.numpy(), 0.8 * marker_size) +draw_system(bubble_positions[N_2:], box_size.numpy(), marker_size) # %% [markdown] """## Forward simulation for different diameters and seeds.""" @@ -335,13 +335,15 @@ def simulation( box_size_tensor = torch.zeros(len(diameters), len(seeds)) raft_energy_tensor = torch.zeros(len(diameters), len(seeds)) bubble_positions_tensor = torch.zeros(len(diameters), len(seeds), N, 2) -for i, d in enumerate(diameters): - for j, s in enumerate(seeds): - box_size, raft_energy, bubble_positions = simulation(d, s) - box_size_tensor[i, j] = box_size - raft_energy_tensor[i, j] = raft_energy.detach() - bubble_positions_tensor[i, j] = bubble_positions - print(f"Finished simulation for diameter {d}, final energy: {raft_energy.detach()}") +for ii, diam in enumerate(diameters): + for jj, seed in enumerate(seeds): + box_size, raft_energy, bubble_positions = simulation(diam, seed) + box_size_tensor[ii, jj] = box_size + raft_energy_tensor[ii, jj] = raft_energy.detach() + bubble_positions_tensor[ii, jj] = bubble_positions + print( + f"Finished simulation for diameter {diam}, final energy: {raft_energy.detach()}" + ) # %% U_mean = torch.mean(raft_energy_tensor, dim=1) U_std = torch.std(raft_energy_tensor, dim=1) @@ -353,21 +355,21 @@ def simulation( plt.ylabel(r"$U$", fontsize=20) plt.show() # %% -ms = 185 -for i, d in enumerate(diameters): - plt.subplot(2, 5, i + 1) - c = min(1, max(0, (U_mean[i].detach().numpy() - 0.4) * 4)) +marker_size = 185 +for ii, diam in enumerate(diameters): + plt.subplot(2, 5, ii + 1) + c = min(1, max(0, (U_mean[ii].detach().numpy() - 0.4) * 4)) color = [c, 0, 1 - c] draw_system( - bubble_positions_tensor[i, 0, :N_2].detach().numpy(), - box_size_tensor[i, 0].detach().numpy(), - d * ms, + bubble_positions_tensor[ii, 0, :N_2].detach().numpy(), + box_size_tensor[ii, 0].detach().numpy(), + diam * marker_size, color=color, ) draw_system( - bubble_positions_tensor[i, 0, N_2:].detach().numpy(), - box_size_tensor[i, 0].detach().numpy(), - ms, + bubble_positions_tensor[ii, 0, N_2:].detach().numpy(), + box_size_tensor[ii, 0].detach().numpy(), + marker_size, color=color, ) @@ -409,9 +411,7 @@ def short_simulation( ) ] grad = torch.autograd.grad( - outputs=[ - model(state)["energy"], - ], + outputs=[model(state)["energy"]], inputs=[diameter], grad_outputs=grad_outputs, create_graph=True, @@ -424,9 +424,9 @@ def short_simulation( # %% dU_dD = torch.zeros(len(diameters), len(seeds)) -for i, d in enumerate(diameters): - for j, s in enumerate(seeds): - _, dU_dD[i, j] = short_simulation(d, bubble_positions_tensor[i, j]) +for ii, diam in enumerate(diameters): + for jj, seed in enumerate(seeds): + _, dU_dD[ii, jj] = short_simulation(diam, bubble_positions_tensor[ii, jj]) # %% plt.subplot(2, 1, 1) diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index c59072e0f..b7bf6ad45 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -151,7 +151,7 @@ class HybridSwapMCState(SwapMCState, MDState): # Create a persistent PRNG for reproducibility across the whole run rng = torch.Generator(device=mace_model.device) -rng.manual_seed(seed=42) +rng.manual_seed(42) # %% Run the hybrid simulation n_steps = 100 diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 86bebc616..aca778521 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1392,6 +1392,11 @@ def npt_nose_hoover_init( # Calculate cell kinetic energy (using first system for initialization) KE_cell = ts.calc_kinetic_energy(masses=cell_mass[:1], momenta=cell_momentum[:1]) + # Compute total DOF for thermostat initialization and a zero KE placeholder + dof_per_system = torch.bincount(state.system_idx, minlength=n_systems) * dim + total_dof = int(dof_per_system.sum().item()) + KE_zero = torch.tensor(0.0, device=device, dtype=dtype) + # Ensure reference_cell has proper system dimensions if state.cell.ndim == 2: # Single cell matrix - expand to batch dimension @@ -1429,7 +1434,7 @@ def npt_nose_hoover_init( cell_momentum=cell_momentum, cell_mass=cell_mass, barostat=barostat_fns.initialize(1, KE_cell, kT), - thermostat=thermostat_fns.initialize(), + thermostat=thermostat_fns.initialize(total_dof, KE_zero, kT), barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, ) @@ -1444,12 +1449,6 @@ def npt_nose_hoover_init( # Initialize thermostat npt_state.momenta = momenta - KE = ts.calc_kinetic_energy( - momenta=npt_state.momenta, - masses=npt_state.masses, - system_idx=npt_state.system_idx, - ) - npt_state.thermostat = thermostat_fns.initialize(npt_state.positions.numel(), KE, kT) return npt_state From 25c633f9b41d543058b999edb5c07a162c57130d Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sun, 28 Sep 2025 11:43:48 -0700 Subject: [PATCH 14/40] fix examples 6.1_Phonons_MACE.py + 6.2_QuasiHarmonic_MACE.py fix state = step_fn(model=model, state=state, **optimizer_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TypeError: fire_step() got an unexpected keyword argument 'constant_volume' --- examples/scripts/6_Phonons/6.1_Phonons_MACE.py | 3 +-- examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py | 6 ++---- examples/scripts/6_Phonons/6.3_Conductivity_MACE.py | 4 ++-- examples/scripts/7_Others/7.3_Batched_neighbor_list.py | 6 +++--- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py index 43a8bf844..1508b0681 100644 --- a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py +++ b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py @@ -119,8 +119,7 @@ def get_labels_qpts(ph: Phonopy, n_points: int = 101) -> tuple[list[str], list[b optimizer=ts.OptimFlavor.fire, cell_filter=ts.CellFilter.frechet, max_steps=max_steps, - constant_volume=True, - hydrostatic_strain=True, + init_kwargs=dict(constant_volume=True, hydrostatic_strain=True), ) # Define atoms and Phonopy object diff --git a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py index eb209dd2d..d5bfadca2 100644 --- a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py +++ b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py @@ -68,8 +68,7 @@ def get_relaxed_structure( convergence_fn=converge_max_force, trajectory_reporter=reporter, autobatcher=use_autobatcher, - constant_volume=True, - hydrostatic_strain=True, + init_kwargs=dict(constant_volume=True, hydrostatic_strain=True), ) os.remove(trajectory_file) @@ -121,8 +120,7 @@ def get_qha_structures( max_steps=Nmax, convergence_fn=ts.runners.generate_force_convergence_fn(force_tol=fmax), autobatcher=use_autobatcher, - constant_volume=True, - hydrostatic_strain=True, + init_kwargs=dict(constant_volume=True, hydrostatic_strain=True), ) return scaled_state.to_phonopy() diff --git a/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py b/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py index bbc8ae04e..0fc842169 100644 --- a/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py +++ b/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py @@ -10,10 +10,10 @@ # ] # /// import os +import sys import time from typing import TYPE_CHECKING, Literal, cast -import IPython import numpy as np import plotly.graph_objects as go import torch @@ -200,5 +200,5 @@ def print_relax_info(trajectory_file: str, device: torch.device) -> None: height=600, plot_bgcolor="white", ) -if IPython.get_ipython() is not None: +if "IPython" in sys.modules: fig.show() diff --git a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py index 489b9dd5a..724231e37 100644 --- a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py +++ b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py @@ -15,14 +15,14 @@ state = ts.io.atoms_to_state(atoms_list, device=torch.device("cpu"), dtype=torch.float32) pos, cell, pbc = state.positions, state.cell, state.pbc system_idx, n_atoms = state.system_idx, state.n_atoms -cutoff = 4.0 +cutoff = torch.tensor(4.0, dtype=pos.dtype) self_interaction = False # Fix: Ensure pbc has the correct shape [n_systems, 3] pbc_tensor = torch.tensor([[pbc] * 3] * len(atoms_list), dtype=torch.bool) mapping, mapping_system, shifts_idx = torch_nl_linked_cell( - cutoff, pos, cell, pbc_tensor, system_idx, self_interaction + pos, cell, pbc_tensor, cutoff, system_idx, self_interaction ) cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, mapping_system) dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts) @@ -34,7 +34,7 @@ print(dds.shape) mapping_n2, mapping_system_n2, shifts_idx_n2 = torch_nl_n2( - cutoff, pos, cell, pbc_tensor, system_idx, self_interaction + pos, cell, pbc_tensor, cutoff, system_idx, self_interaction ) cell_shifts_n2 = transforms.compute_cell_shifts(cell, shifts_idx_n2, mapping_system_n2) dds_n2 = transforms.compute_distances_with_cell_shifts(pos, mapping_n2, cell_shifts_n2) From 5566668d1b0bd1800b9710a1a9f399d17da49f42 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 1 Oct 2025 17:15:14 -0400 Subject: [PATCH 15/40] Merge remote-tracking branch 'origin/main' into api-redesign --- CONTRIBUTING.md | 14 ++++++++++++- docs/conf.py | 2 +- tests/test_trajectory.py | 4 ++-- torch_sim/models/orb.py | 2 +- torch_sim/trajectory.py | 45 ++++++++++++++++++++++++++-------------- 5 files changed, 46 insertions(+), 21 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7d731a12a..62cbf0746 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,18 @@ # Contributing to TorchSim -TorchSim is an experimental library and we would appreciate any feedback from the community. +TorchSim welcomes contributions and feedback from the community. + +## Contributor's Certification + +By making a contribution to this project, you certify that you agree to the following: + +(a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or + +(b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or + +(c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. + +(d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. ## Code Reviews diff --git a/docs/conf.py b/docs/conf.py index 218c91442..99c297a83 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -83,7 +83,7 @@ napoleon_use_ivar = True # The suffix(es) of source filenames. -source_suffix = {".rst": "restructuredtext", ".md": "restructuredtext"} +source_suffix = {".rst": "restructuredtext", ".md": "markdown"} mathjax3_config = { "tex": { diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index e7bd4531b..215877c6f 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -259,7 +259,7 @@ def test_data_type_conversions(test_file: Path) -> None: rng = np.random.default_rng(seed=0) # Test data with different types - test_data = { + test_data: dict[str, np.ndarray | torch.Tensor] = { # NumPy arrays "np_float64": rng.random((10, 3)).astype(np.float64), "np_float32": rng.random((10, 3)).astype(np.float32), @@ -372,7 +372,7 @@ def test_scalar_dtype_handling(test_file: Path) -> None: test_file, coerce_to_float32=True, coerce_to_int32=True, mode="w" ) - scalar_data = { + scalar_data: dict[str, np.ndarray | np.generic | torch.Tensor] = { "float64_scalar": np.float64(1.0), "float32_scalar": np.float32(1.0), "int64_scalar": np.int64(1), diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 93ac4c333..cff515e30 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -423,7 +423,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: if not model_has_direct_heads and prop == "stress": continue _property = "energy" if prop == "free_energy" else prop - results[prop] = predictions[_property].squeeze() + results[prop] = predictions[_property] if self.conservative: results["forces"] = results[self.model.grad_forces_name] diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index d8a306d8c..c9167d0ff 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -30,7 +30,7 @@ import copy import inspect import pathlib -from collections.abc import Callable +from collections.abc import Callable, Mapping, Sequence from functools import partial from typing import TYPE_CHECKING, Any, Literal, Self @@ -44,7 +44,7 @@ if TYPE_CHECKING: from ase import Atoms - from ase.io.trajectory import TrajectoryWriter + from ase.io.trajectory import TrajectoryReader _DATA_TYPE_MAP = { np.dtype("float32"): tables.Float32Atom(), @@ -91,9 +91,18 @@ class TrajectoryReporter: >>> reporter.close() """ + state_frequency: int + trajectory_kwargs: dict[str, Any] + prop_calculators: dict[int, dict[str, Callable]] + state_kwargs: dict[str, Any] + metadata: dict[str, str] | None + shape_warned: bool + trajectories: list["TorchSimTrajectory"] + filenames: list[str | pathlib.Path] | None + def __init__( self, - filenames: str | pathlib.Path | list[str | pathlib.Path] | None, + filenames: str | pathlib.Path | Sequence[str | pathlib.Path] | None, state_frequency: int = 100, *, prop_calculators: dict[int, dict[str, Callable]] | None = None, @@ -137,14 +146,13 @@ def __init__( self.trajectories = [] if filenames is None: self.filenames = None - self.trajectories = [] else: self.load_new_trajectories(filenames) self._add_model_arg_to_prop_calculators() def load_new_trajectories( - self, filenames: str | pathlib.Path | list[str | pathlib.Path] + self, filenames: str | pathlib.Path | Sequence[str | pathlib.Path] ) -> None: """Load new trajectories into the reporter. @@ -159,7 +167,9 @@ def load_new_trajectories( """ self.finish() - filenames = [filenames] if not isinstance(filenames, list) else filenames + filenames = ( + [filenames] if isinstance(filenames, (str, pathlib.Path)) else list(filenames) + ) self.filenames = [pathlib.Path(filename) for filename in filenames] if len(set(self.filenames)) != len(self.filenames): raise ValueError("All filenames must be unique.") @@ -459,7 +469,7 @@ def _initialize_type_map( def write_arrays( self, - data: dict[str, np.ndarray | torch.Tensor], + data: "Mapping[str, np.ndarray | np.generic | torch.Tensor]", steps: int | list[int], ) -> None: """Write arrays to the trajectory file. @@ -473,8 +483,8 @@ def write_arrays( file and that the steps are monotonically increasing. Args: - data (dict[str, np.ndarray | torch.Tensor]): Map of array names to numpy - arrays or torch tensors with shapes [n_frames, ...] + data (Mapping[str, np.ndarray | np.generic | torch.Tensor]): Map of array + names to numpy arrays or torch tensors with shapes [n_frames, ...] steps (int | list[int]): Step number(s) for the frame(s) being written. If steps is an integer, arrays will be treated as single frame data. @@ -489,9 +499,12 @@ def write_arrays( pad_first_dim = False for name, array in data.items(): - # TODO: coerce dtypes to numpy + # Normalize to numpy arrays if isinstance(array, torch.Tensor): array = array.cpu().detach().numpy() + elif not isinstance(array, np.ndarray): + # Convert numpy scalar (np.generic) or Python scalar to ndarray + array = np.array(array) if pad_first_dim: # pad 1st dim of array with 1 @@ -774,7 +787,7 @@ def write_state( # noqa: C901 # Write all arrays to file self.write_arrays(data, steps) - def _get_state_arrays(self, frame: int) -> dict[str, torch.Tensor]: + def _get_state_arrays(self, frame: int) -> dict[str, np.ndarray]: """Get all available state tensors for a given frame. Retrieves all state-related arrays (positions, cell, masses, etc.) for a @@ -784,7 +797,7 @@ def _get_state_arrays(self, frame: int) -> dict[str, torch.Tensor]: frame (int): Frame index to retrieve (-1 for last frame) Returns: - dict[str, torch.Tensor]: Map of tensor names to their values + dict[str, np.ndarray]: Map of array names to their values Raises: ValueError: If required arrays are missing from trajectory or frame is @@ -918,7 +931,7 @@ def get_state( positions=torch.tensor(arrays["positions"], device=device, dtype=dtype), masses=torch.tensor(arrays.get("masses", None), device=device, dtype=dtype), cell=torch.tensor(arrays["cell"], device=device, dtype=dtype), - pbc=arrays.get("pbc", True), + pbc=bool(arrays.get("pbc", True)), atomic_numbers=torch.tensor( arrays["atomic_numbers"], device=device, dtype=torch.int ), @@ -972,7 +985,7 @@ def __len__(self) -> int: """ return self._file.root.data.positions.shape[0] - def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryWriter": + def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryReader": """Convert trajectory to ASE Trajectory format. Writes the entire trajectory to a new file in ASE format for compatibility @@ -982,7 +995,7 @@ def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryWrite filename (str | pathlib.Path): Path to the output ASE trajectory file Returns: - ase.io.trajectory.Trajectory: ASE trajectory object + ase.io.trajectory.TrajectoryReader: ASE trajectory object Raises: ImportError: If ASE is not installed @@ -1003,4 +1016,4 @@ def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryWrite traj.write(atoms) traj.close() - return Trajectory(filename) # Reopen in read mode + return Trajectory(filename, mode="r") # Reopen in read mode From 58aa8f6b198ffcc951262007ca24193c6382d4f7 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 1 Oct 2025 22:26:32 -0400 Subject: [PATCH 16/40] maint: list more conflicting dependencies --- pyproject.toml | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 66e7f9971..453d0d37a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,10 +145,26 @@ conflicts = [ { extra = "graphpes" }, { extra = "sevenn" }, ], + [ + { extra = "graphpes" }, + { extra = "fairchem" }, + ], + [ + { extra = "graphpes" }, + { extra = "nequip" }, + ], + [ + { extra = "fairchem" }, + { extra = "mace" }, + ], [ { extra = "mace" }, { extra = "mattersim" }, ], + [ + { extra = "mace" }, + { extra = "nequip" }, + ], [ { extra = "mace" }, { extra = "sevenn" }, @@ -156,7 +172,7 @@ conflicts = [ ] [dependency-groups] -dev = ["prek>=4.3.0", "ty>=0.0.1a20"] +dev = ["prek>=0.2.0", "ty>=0.0.1a20"] [tool.ty.rules] # TODO: Unable to work with **kwargs: https://github.com/astral-sh/ty/issues/247 From 6bc02da2348d5c4d992de448d54b86396a57f291 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 3 Oct 2025 15:27:09 -0400 Subject: [PATCH 17/40] fix: bad merge wrt hf token --- .github/workflows/test.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8da092077..ef70eb185 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -109,6 +109,8 @@ jobs: uv pip install -e ".[test,${{ matrix.model.name }}]" --resolution=${{ matrix.version.resolution }} --system - name: Run ${{ matrix.model.test_path }} tests + env: + HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} run: | if [[ "${{ matrix.model.name }}" == *"fairchem"* ]]; then uv pip install "huggingface_hub[cli]" --system From eb150ded5ee79b98f28faacdd37aa525f3d8c274 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 3 Oct 2025 15:48:47 -0400 Subject: [PATCH 18/40] fix: device fixture was deleted but missed in merge --- tests/models/test_fairchem_legacy.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/models/test_fairchem_legacy.py b/tests/models/test_fairchem_legacy.py index 3073c424e..7bcd9df9c 100644 --- a/tests/models/test_fairchem_legacy.py +++ b/tests/models/test_fairchem_legacy.py @@ -2,8 +2,8 @@ import traceback import pytest -import torch +from tests.conftest import DEVICE from tests.models.conftest import ( consistency_test_simstate_fixtures, make_model_calculator_consistency_test, @@ -32,16 +32,16 @@ def model_path_oc20(tmp_path_factory: pytest.TempPathFactory) -> str: @pytest.fixture -def eqv2_oc20_model_pbc(model_path_oc20: str, device: torch.device) -> FairChemV1Model: - cpu = device.type == "cpu" +def eqv2_oc20_model_pbc(model_path_oc20: str) -> FairChemV1Model: + cpu = DEVICE.type == "cpu" return FairChemV1Model(model=model_path_oc20, cpu=cpu, seed=0, pbc=True) @pytest.fixture def eqv2_oc20_model_non_pbc( - model_path_oc20: str, device: torch.device + model_path_oc20: str, ) -> FairChemV1Model: - cpu = device.type == "cpu" + cpu = DEVICE.type == "cpu" return FairChemV1Model(model=model_path_oc20, cpu=cpu, seed=0, pbc=False) @@ -55,9 +55,9 @@ def model_path_omat24(tmp_path_factory: pytest.TempPathFactory) -> str: @pytest.fixture def eqv2_omat24_model_pbc( - model_path_omat24: str, device: torch.device + model_path_omat24: str, ) -> FairChemV1Model: - cpu = device.type == "cpu" + cpu = DEVICE.type == "cpu" return FairChemV1Model(model=model_path_omat24, cpu=cpu, seed=0, pbc=True) From c72febd14cf89d8b236dbc3fed0d2a8eb3a41cb4 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 3 Oct 2025 16:04:31 -0400 Subject: [PATCH 19/40] fix: move over test skip for optimizers vs ase given conftest fixture deleted --- tests/test_optimizers_vs_ase.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py index e8192fcea..f8e83b386 100644 --- a/tests/test_optimizers_vs_ase.py +++ b/tests/test_optimizers_vs_ase.py @@ -1,3 +1,4 @@ +import traceback from typing import TYPE_CHECKING, Any import pytest @@ -18,7 +19,12 @@ @pytest.fixture def ts_mace_mpa() -> MaceModel: """Provides a MACE MP model instance for the optimizer tests.""" - from mace.calculators.foundations_models import mace_mp + try: + from mace.calculators.foundations_models import mace_mp + except ImportError: + pytest.skip( + f"MACE not installed: {traceback.format_exc()}", allow_module_level=True + ) # Use float64 for potentially higher precision needed in optimization dtype = getattr(torch, dtype_str := "float64") @@ -37,7 +43,12 @@ def ts_mace_mpa() -> MaceModel: @pytest.fixture def ase_mace_mpa() -> "MACECalculator": """Provides an ASE MACECalculator instance using mace_mp.""" - from mace.calculators.foundations_models import mace_mp + try: + from mace.calculators.foundations_models import mace_mp + except ImportError: + pytest.skip( + f"MACE not installed: {traceback.format_exc()}", allow_module_level=True + ) # Ensure dtype matches the one used in the torch-sim fixture (float64) return mace_mp(model=MaceUrls.mace_mp_small, default_dtype="float64") From 35265690e2cfaeb9412bb97d28fd386905ecd376 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 8 Oct 2025 14:58:35 -0700 Subject: [PATCH 20/40] fix: example 7.6 --- examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index 027c7a4a2..dee318394 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -281,7 +281,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 ase_calc_instance = mace_mp_calculator_for_ase( model=MaceUrls.mace_mpa_medium, - device=DEVICE, + device=str(DEVICE), default_dtype=str(DTYPE).removeprefix("torch."), ) ase_atoms_orig.calc = ase_calc_instance @@ -357,7 +357,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 ) temp_calc = mace_mp_calculator_for_ase( model=MaceUrls.mace_mpa_medium, - device=DEVICE, + device=str(DEVICE), default_dtype=str(DTYPE).removeprefix("torch."), ) ats_final.calc = temp_calc From e6996e2f4f2c5b620f27a4e9c84a7a9fd2391175 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 18:00:55 -0400 Subject: [PATCH 21/40] dont run fairchem with 3.13 --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ef70eb185..fc41fa1f8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -92,7 +92,7 @@ jobs: uses: astral-sh/setup-uv@v6 - name: Install legacy fairchem repository and dependencies - if: ${{ matrix.model.name == 'fairchem-legacy' }} + if: ${{ matrix.model.name == 'fairchem-legacy' && matrix.version.python != '3.13' }} run: | if [ -f fairchem-repo/packages/requirements.txt ]; then uv pip install -r fairchem-repo/packages/requirements.txt --system From fc25199bf65767b89efe58387f80cda119cfa496 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 8 Oct 2025 15:05:38 -0700 Subject: [PATCH 22/40] fix: diff sim --- examples/tutorials/diff_sim.py | 230 ++++++++++++++++++++++----------- 1 file changed, 155 insertions(+), 75 deletions(-) diff --git a/examples/tutorials/diff_sim.py b/examples/tutorials/diff_sim.py index c74d5899e..3835c926c 100644 --- a/examples/tutorials/diff_sim.py +++ b/examples/tutorials/diff_sim.py @@ -1,15 +1,16 @@ -# %% +# %% [markdown] +#
+# Dependencies # /// script -# dependencies = ["matplotlib"] +# dependencies = [ +# "matplotlib", +# ] # /// +#
# %% -import torch_sim as ts -from typing import cast -from numpy.typing import NDArray import torch import matplotlib.pyplot as plt -from torch_sim.models.interface import ModelInterface from torch_sim.models.soft_sphere import ( soft_sphere_pair, DEFAULT_SIGMA, @@ -22,7 +23,6 @@ from torch._functorch import config config.donated_buffer = False - # %% [markdown] """ # Differentiable Simulation @@ -34,6 +34,15 @@ # %% +def finalize_plot(shape: tuple[int, int] = (1, 1)): + """Finalize the plot by setting the size and layout.""" + plt.gcf().set_size_inches( + shape[0] * 1.5 * plt.gcf().get_size_inches()[1], + shape[1] * 1.5 * plt.gcf().get_size_inches()[1], + ) + plt.tight_layout() + + def draw_system( R: torch.Tensor, box_size: float, marker_size: float, color: list[float] | None = None ): @@ -96,11 +105,23 @@ def draw_system( plt.show() # %% [markdown] -"""## Define the simple TorchSim model for the soft sphere potential.""" +""" +## Define the simple TorchSim model for the soft sphere potential. +""" # %% -class SoftSphereMultiModel(ModelInterface): +@dataclass +class BaseState: + """Simple simulation state""" + + positions: torch.Tensor + cell: torch.Tensor + pbc: bool + species: torch.Tensor + + +class SoftSphereMultiModel(torch.nn.Module): """Soft sphere potential""" def __init__( @@ -117,8 +138,8 @@ def __init__( ) -> None: """Initialize a soft sphere model for multi-component systems.""" super().__init__() - self._device = device or torch.device("cpu") - self._dtype = dtype + self.device = device or torch.device("cpu") + self.dtype = dtype self.pbc = pbc # Store species list and determine number of unique species @@ -173,14 +194,16 @@ def __init__( ) def forward( - self, custom_state: ts.SimState, species: torch.Tensor | None = None + self, + custom_state: BaseState, + species: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: """Compute energies and forces for a single unbatched system with multiple species.""" # Convert inputs to proper device/dtype and handle species positions = custom_state.positions.requires_grad_(True) cell = custom_state.cell - species = custom_state.atomic_numbers + species = custom_state.species if species is not None: species = species.to(device=self.device, dtype=torch.long) @@ -223,9 +246,15 @@ def forward( # Initialize results with total energy (divide by 2 to avoid double counting) potential_energy = pair_energies.sum() / 2 - grad_outputs: list[torch.Tensor | None] = [torch.ones_like(potential_energy)] + grad_outputs: list[torch.Tensor | None] = [ + torch.ones_like( + potential_energy, + ) + ] grad = torch.autograd.grad( - outputs=[potential_energy], + outputs=[ + potential_energy, + ], inputs=[positions], grad_outputs=grad_outputs, create_graph=False, @@ -247,8 +276,62 @@ def forward( """ +# %% +@dataclass +class GDState(BaseState): + """Simple simulation state""" + + forces: torch.Tensor + energy: torch.Tensor + + +def gradient_descent( + model: torch.nn.Module, *, lr: torch.Tensor | float = 0.01 +) -> tuple[Callable[[dict[str, torch.Tensor]], GDState], Callable[[GDState], GDState]]: + """Initialize a gradient descent optimization.""" + + def gd_init( + state: dict[str, torch.Tensor], + ) -> GDState: + """Initialize the gradient descent optimization state.""" + + # Get initial forces and energy from model + model_output = model(state) + energy = model_output["energy"] + forces = model_output["forces"] + + return GDState( + positions=state.positions, + forces=forces, + energy=energy, + cell=state.cell, + pbc=state.pbc, + species=state.species, + ) + + def gd_step(state: GDState, lr: torch.Tensor = lr) -> GDState: + """Perform one gradient descent optimization step to update the + atomic positions. The cell is not optimized.""" + + # Update positions using forces and per-atom learning rates + state.positions = state.positions + lr * state.forces + + # Get updated forces and energy from model + model_output = model(state) + + # Update state with new forces and energy + state.forces = model_output["forces"] + state.energy = model_output["energy"] + + return state + + return gd_init, gd_step + + # %% [markdown] -"""## Setup the simulation environment.""" +""" +## Setup the simulation environment. +""" # %% @@ -261,7 +344,7 @@ def box_size_at_number_density( def box_size_at_packing_fraction( diameter: torch.Tensor, packing_fraction: float ) -> torch.Tensor: - bubble_volume = N_2 * torch.pi * (torch.square(diameter) + 1) / 4 + bubble_volume = N_2 * torch.pi * (diameter**2 + 1) / 4 return torch.sqrt(bubble_volume / packing_fraction) @@ -277,7 +360,7 @@ def species_sigma(diameter: torch.Tensor) -> torch.Tensor: species = torch.tensor([0] * (N_2) + [1] * (N_2), dtype=torch.int32) simulation_steps = 1000 packing_fraction = 0.98 -marker_size = 260 +markersize = 260 # %% @@ -290,63 +373,59 @@ def simulation( # Create the energy function. sigma = species_sigma(diameter) model = SoftSphereMultiModel(sigma_matrix=sigma, species=species) - model = cast(SoftSphereMultiModel, torch.compile(model)) + model = torch.compile(model) # Randomly initialize the system. # Fix seed for reproducible random positions torch.manual_seed(seed) R = torch.rand(N, 2) * box_size # Minimize to the nearest minimum. - custom_state = ts.SimState( - atomic_numbers=species, - masses=torch.ones(N), - system_idx=torch.arange(N), - positions=R, - cell=cell, - pbc=True, - ) - state = ts.gradient_descent_init(model, state=custom_state) + init_fn, apply_fn = gradient_descent(model, lr=0.1) + custom_state = BaseState(positions=R, cell=cell, species=species, pbc=True) + state = init_fn(custom_state) for _ in range(simulation_steps): - state = ts.gradient_descent_step(model, state, pos_lr=0.1) + state = apply_fn(state) return box_size, model(state)["energy"], state.positions # %% [markdown] -"""## Packing at different diameters.""" - +""" +## Packing at different diameters. +""" # %% plt.subplot(1, 2, 1) box_size, raft_energy, bubble_positions = simulation(torch.tensor(1.0)) -draw_system(bubble_positions, box_size.numpy(), marker_size) +draw_system(bubble_positions, box_size, markersize) +finalize_plot((0.5, 0.5)) plt.subplot(1, 2, 2) box_size, raft_energy, bubble_positions = simulation(torch.tensor(0.8)) -draw_system(bubble_positions[:N_2], box_size.numpy(), 0.8 * marker_size) -draw_system(bubble_positions[N_2:], box_size.numpy(), marker_size) +draw_system(bubble_positions[:N_2], box_size, 0.8 * markersize) +draw_system(bubble_positions[N_2:], box_size, markersize) +finalize_plot((2.0, 1)) # %% [markdown] -"""## Forward simulation for different diameters and seeds.""" - +""" +## Forward simulation for different diameters and seeds. +""" # %% diameters = torch.linspace(0.4, 1.0, 10) seeds = torch.arange(1, 6) box_size_tensor = torch.zeros(len(diameters), len(seeds)) raft_energy_tensor = torch.zeros(len(diameters), len(seeds)) bubble_positions_tensor = torch.zeros(len(diameters), len(seeds), N, 2) -for ii, diam in enumerate(diameters): - for jj, seed in enumerate(seeds): - box_size, raft_energy, bubble_positions = simulation(diam, seed) - box_size_tensor[ii, jj] = box_size - raft_energy_tensor[ii, jj] = raft_energy.detach() - bubble_positions_tensor[ii, jj] = bubble_positions - print( - f"Finished simulation for diameter {diam}, final energy: {raft_energy.detach()}" - ) +for i, d in enumerate(diameters): + for j, s in enumerate(seeds): + box_size, raft_energy, bubble_positions = simulation(d, s) + box_size_tensor[i, j] = box_size + raft_energy_tensor[i, j] = raft_energy.detach() + bubble_positions_tensor[i, j] = bubble_positions + print(f"Finished simulation for diameter {d}, final energy: {raft_energy.detach()}") # %% -U_mean = torch.mean(raft_energy_tensor, dim=1) -U_std = torch.std(raft_energy_tensor, dim=1) +U_mean = torch.mean(raft_energy_tensor, axis=1) +U_std = torch.std(raft_energy_tensor, axis=1) plt.plot(diameters.detach().numpy(), U_mean, linewidth=3) plt.fill_between(diameters.detach().numpy(), U_mean + U_std, U_mean - U_std, alpha=0.4) @@ -355,29 +434,32 @@ def simulation( plt.ylabel(r"$U$", fontsize=20) plt.show() # %% -marker_size = 185 -for ii, diam in enumerate(diameters): - plt.subplot(2, 5, ii + 1) - c = min(1, max(0, (U_mean[ii].detach().numpy() - 0.4) * 4)) +ms = 185 +for i, d in enumerate(diameters): + plt.subplot(2, 5, i + 1) + c = min(1, max(0, (U_mean[i].detach().numpy() - 0.4) * 4)) color = [c, 0, 1 - c] draw_system( - bubble_positions_tensor[ii, 0, :N_2].detach().numpy(), - box_size_tensor[ii, 0].detach().numpy(), - diam * marker_size, + bubble_positions_tensor[i, 0, :N_2].detach().numpy(), + box_size_tensor[i, 0].detach().numpy(), + d * ms, color=color, ) draw_system( - bubble_positions_tensor[ii, 0, N_2:].detach().numpy(), - box_size_tensor[ii, 0].detach().numpy(), - marker_size, + bubble_positions_tensor[i, 0, N_2:].detach().numpy(), + box_size_tensor[i, 0].detach().numpy(), + ms, color=color, ) +finalize_plot((2.5, 1)) # %% [markdown] -"""## Meta-optimization with differentiable simulation.""" - +""" +## Meta-optimization with differentiable simulation. +""" # %% + short_simulation_steps = 10 @@ -392,18 +474,12 @@ def short_simulation( model = SoftSphereMultiModel(sigma_matrix=sigma, species=species) # Minimize to the nearest minimum. - custom_state = ts.SimState( - atomic_numbers=species, - masses=torch.ones(N), - system_idx=torch.arange(N), - positions=R, - cell=cell, - pbc=True, - ) - state = ts.gradient_descent_init(model, state=custom_state) + init_fn, apply_fn = gradient_descent(model, lr=0.1) + custom_state = BaseState(positions=R, cell=cell, species=species, pbc=True) + state = init_fn(custom_state) for i in range(short_simulation_steps): - state = ts.gradient_descent_step(model, state, pos_lr=0.1) + state = apply_fn(state) grad_outputs: list[torch.Tensor | None] = [ torch.ones_like( @@ -411,7 +487,9 @@ def short_simulation( ) ] grad = torch.autograd.grad( - outputs=[model(state)["energy"]], + outputs=[ + model(state)["energy"], + ], inputs=[diameter], grad_outputs=grad_outputs, create_graph=True, @@ -424,15 +502,15 @@ def short_simulation( # %% dU_dD = torch.zeros(len(diameters), len(seeds)) -for ii, diam in enumerate(diameters): - for jj, seed in enumerate(seeds): - _, dU_dD[ii, jj] = short_simulation(diam, bubble_positions_tensor[ii, jj]) +for i, d in enumerate(diameters): + for j, s in enumerate(seeds): + _, dU_dD[i, j] = short_simulation(d, bubble_positions_tensor[i, j]) # %% plt.subplot(2, 1, 1) dU_dD = dU_dD.detach() -dU_mean = torch.mean(dU_dD, dim=1) -dU_std = torch.std(dU_dD, dim=1) +dU_mean = torch.mean(dU_dD, axis=1) +dU_std = torch.std(dU_dD, axis=1) plt.plot(diameters.detach().numpy(), dU_mean, linewidth=3) plt.fill_between( diameters.detach().numpy(), dU_mean + dU_std, dU_mean - dU_std, alpha=0.4 @@ -450,3 +528,5 @@ def short_simulation( plt.xlim([0.4, 1.0]) plt.xlabel(r"$D$", fontsize=20) plt.ylabel(r"$U$", fontsize=20) + +finalize_plot((1.25, 1)) From b15cccec4b8c57a9bf62beeb388d9fb3d544b642 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 8 Oct 2025 15:21:13 -0700 Subject: [PATCH 23/40] fix: 7.6 example --- examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index dee318394..27936b3f5 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -406,6 +406,7 @@ def run_optimization_ase( # noqa: C901, PLR0915 system_idx=concatenated_system_indices, energy=concatenated_energies, forces=concatenated_forces, + stress=None, ) convergence_steps = torch.tensor( From 1206cdd0759be967b4dc6b5acb54370d2cfda26f Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 18:27:01 -0400 Subject: [PATCH 24/40] try removing changing graph-pes versions and skipping 3.13 with fairchem and orb --- .github/workflows/test.yml | 4 +++- pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fc41fa1f8..731180ecc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -70,6 +70,7 @@ jobs: - { name: orb, test_path: "tests/models/test_orb.py" } - { name: sevenn, test_path: "tests/models/test_sevennet.py" } runs-on: ${{ matrix.os }} + if: ${{ !(contains(matrix.model.name, 'fairchem') && matrix.version.python != '3.12') || !(contains(matrix.model.name, 'orb') && matrix.version.python != '3.12') }} steps: - name: Check out repo @@ -92,7 +93,7 @@ jobs: uses: astral-sh/setup-uv@v6 - name: Install legacy fairchem repository and dependencies - if: ${{ matrix.model.name == 'fairchem-legacy' && matrix.version.python != '3.13' }} + if: ${{ matrix.model.name == 'fairchem-legacy' }} run: | if [ -f fairchem-repo/packages/requirements.txt ]; then uv pip install -r fairchem-repo/packages/requirements.txt --system @@ -103,6 +104,7 @@ jobs: uv pip install -e fairchem-repo/packages/fairchem-core[dev] --system uv pip install -e ".[test]" --resolution=${{ matrix.version.resolution }} --system + - name: Install torch_sim with model dependencies if: ${{ matrix.model.name != 'fairchem-legacy' }} run: | diff --git a/pyproject.toml b/pyproject.toml index a28890f4d..d3e6c9f9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ mattersim = ["mattersim>=0.1.2"] metatomic = ["metatomic-torch>=0.1.1", "metatrain[pet]>=2025.7"] orb = ["orb-models>=0.5.2"] sevenn = ["sevenn>=0.11.0"] -graphpes = ["graph-pes>=0.0.34", "mace-torch>=0.3.12"] +graphpes = ["graph-pes>=0.0.34,<=0.2.0", "mace-torch>=0.3.12"] nequip = ["nequip>=0.12.0"] fairchem = ["fairchem-core>=2.7"] docs = [ From 81739aa98624a78867a4af34c460ae67cf9d176c Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 18:28:26 -0400 Subject: [PATCH 25/40] change license back to Radical AI --- LICENSE | 2 +- docs/about/license.md | 2 +- docs/conf.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/LICENSE b/LICENSE index 33225395d..6573a2ab4 100644 --- a/LICENSE +++ b/LICENSE @@ -1,5 +1,5 @@ The MIT License (MIT) -Copyright 2025 Project TorchSim +Copyright 2025 Radical AI Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the β€œSoftware”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: diff --git a/docs/about/license.md b/docs/about/license.md index 8aacb2b17..8354221ec 100644 --- a/docs/about/license.md +++ b/docs/about/license.md @@ -1,7 +1,7 @@ # License The MIT License (MIT) -Copyright 2025 Project TorchSim +Copyright 2025 Radical AI Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the β€œSoftware”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: diff --git a/docs/conf.py b/docs/conf.py index 99c297a83..126fcd56b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -18,7 +18,7 @@ # -- Project information ----------------------------------------------------- project = "torch-sim-atomistic" -copyright = "2025, Project TorchSim" # noqa: A001 +copyright = "2025, Radical AI" # noqa: A001 author = "Abhijeet Gangan, Orion Cohen, Janosh Riebesell" # The short X.Y version From 7f42382ea9f8f03e67c684de58e64f28f8a65789 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 18:39:32 -0400 Subject: [PATCH 26/40] try skipping correct tests --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 731180ecc..5d44dffdd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -70,7 +70,7 @@ jobs: - { name: orb, test_path: "tests/models/test_orb.py" } - { name: sevenn, test_path: "tests/models/test_sevennet.py" } runs-on: ${{ matrix.os }} - if: ${{ !(contains(matrix.model.name, 'fairchem') && matrix.version.python != '3.12') || !(contains(matrix.model.name, 'orb') && matrix.version.python != '3.12') }} + if: ${{ !((matrix.model.name == 'fairchem-legacy' && matrix.version.python == '3.13') || (matrix.model.name == 'orb' && matrix.version.python == '3.13')) }} steps: - name: Check out repo From 1fd77219666bbd355c6c0984c020bbb89ab3cec6 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 18:41:09 -0400 Subject: [PATCH 27/40] dont skip tests --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5d44dffdd..9f4a14bb1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -70,7 +70,7 @@ jobs: - { name: orb, test_path: "tests/models/test_orb.py" } - { name: sevenn, test_path: "tests/models/test_sevennet.py" } runs-on: ${{ matrix.os }} - if: ${{ !((matrix.model.name == 'fairchem-legacy' && matrix.version.python == '3.13') || (matrix.model.name == 'orb' && matrix.version.python == '3.13')) }} + # if: ${{ !((matrix.model.name == 'fairchem-legacy' && matrix.version.python == '3.13') || (matrix.model.name == 'orb' && matrix.version.python == '3.13')) }} steps: - name: Check out repo From a4d3686119755b3d6e9c81572065444439af2e24 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 18:54:57 -0400 Subject: [PATCH 28/40] try ignoring orb and fairchem 3.13 in a new way --- .github/workflows/test.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9f4a14bb1..a0a0c5e9d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -69,8 +69,12 @@ jobs: - { name: nequip, test_path: "tests/models/test_nequip_framework.py" } - { name: orb, test_path: "tests/models/test_orb.py" } - { name: sevenn, test_path: "tests/models/test_sevennet.py" } + exclude: + - version: { python: '3.13', resolution: lowest-direct } + model: { name: orb, test_path: "tests/models/test_orb.py" } + - version: { python: '3.13', resolution: lowest-direct } + model: { name: fairchem-legacy, test_path: "tests/models/test_fairchem_legacy.py" } runs-on: ${{ matrix.os }} - # if: ${{ !((matrix.model.name == 'fairchem-legacy' && matrix.version.python == '3.13') || (matrix.model.name == 'orb' && matrix.version.python == '3.13')) }} steps: - name: Check out repo From 3be3c0fa97e5d56687821960c079aee38ff08260 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 8 Oct 2025 15:55:32 -0700 Subject: [PATCH 29/40] fix: tutorial block --- examples/tutorials/hybrid_swap_tutorial.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index b7bf6ad45..7eef6c15b 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -149,11 +149,12 @@ class HybridSwapMCState(SwapMCState, MDState): - Make larger compositional changes through swap moves """ +# %% # Create a persistent PRNG for reproducibility across the whole run rng = torch.Generator(device=mace_model.device) rng.manual_seed(42) -# %% Run the hybrid simulation +# Run the hybrid simulation n_steps = 100 for step in range(n_steps): if step % 10 == 0: # Attempt swap Monte Carlo move From 8618a7d151cbd221bec1898e545a8d05b1783be8 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 19:38:44 -0400 Subject: [PATCH 30/40] use napoleon instead of numpydoc --- docs/conf.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 126fcd56b..b42066ace 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -39,7 +39,6 @@ "sphinx.ext.autosummary", "myst_parser", "sphinxcontrib.autodoc_pydantic", - "numpydoc", "nbsphinx", "sphinx_design", "sphinx_copybutton", @@ -77,10 +76,13 @@ # autoclass_content = "both" # autodoc_member_order = "bysource" -# better napoleon support +# napoleon support for NumPy and Google style docstrings +napoleon_google_docstring = True +napoleon_numpy_docstring = True napoleon_use_param = True napoleon_use_rtype = True napoleon_use_ivar = True +napoleon_preprocess_types = True # The suffix(es) of source filenames. source_suffix = {".rst": "restructuredtext", ".md": "markdown"} @@ -120,13 +122,6 @@ autosummary_generate = True autosummary_generate_overwrite = True -# numpydoc options -numpydoc_class_members_toctree = False -numpydoc_show_class_members = False -numpydoc_show_inherited_class_members = False -numpydoc_attributes_as_param_list = False -numpydoc_xref_param_type = True - # sphinx-panels shouldn't add bootstrap css as the pydata-sphinx-theme already loads it panels_add_bootstrap_css = False From b1e3ced478828b1f9a23b998d72b3a7a7b8fb1c0 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 20:37:14 -0400 Subject: [PATCH 31/40] update changelog with new changes --- CHANGELOG.md | 148 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 106 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 60b74931b..c37f28e31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,60 +1,124 @@ -## v0.2.1 + +## v0.3.0 + +Thank you to everyone who contributed to this release! @t-reents, @curtischong, and @CompRhys did great work squashing an issue with `SimState` concatenation. @curtischong continued his crusade to type and improve the TorchSim API. @orionarcher, @kianpu34593, and @janosh all made contributions that continue to improve package quality and usability. πŸš€ + +## What's Changed + +### πŸ›  Enhancements +* Define attribute scopes in `SimStates` by @curtischong, @CompRhys, @orionarcher in [#228](https://github.com/Radical-AI/torch-sim/pull/228) +* Improve typing of `ModelInterface` by @curtischong, @CompRhys in [#215](https://github.com/Radical-AI/torch-sim/pull/215) +* Make `system_idx` non-optional in `SimState` by @curtischong in [#231](https://github.com/Radical-AI/torch-sim/pull/231) +* Add new states when the `max_memory_scaler` is updated by @kianpu34593 in [#222](https://github.com/Radical-AI/torch-sim/pull/222) +* Rename `batch` to `system` by @curtischong in [#217](https://github.com/Radical-AI/torch-sim/pull/217), [#233](https://github.com/Radical-AI/torch-sim/pull/233) + +### πŸ› Bug Fixes +* Initial fix for concatenation of states in `InFlightAutoBatcher` by @t-reents in [#219](https://github.com/Radical-AI/torch-sim/pull/219) +* Finish fix for `SimState` concatenation by @t-reents and @curtischong in [#232](https://github.com/Radical-AI/torch-sim/pull/232) +* Fix broken code block in low-level tutorial by @CompRhys in [#226](https://github.com/Radical-AI/torch-sim/pull/226) +* Update metatomic checkpoint to fix tests by @curtischong in [#223](https://github.com/Radical-AI/torch-sim/pull/223) +* Fix memory scaling in `determine_max_batch_size` by @t-reents, @janosh in [#212](https://github.com/Radical-AI/torch-sim/pull/212) + +### πŸ“– Documentation +* Update README plot with more models by @orionarcher in [#236](https://github.com/Radical-AI/torch-sim/pull/236), [#237](https://github.com/Radical-AI/torch-sim/pull/237) +* Update `citation.cff` by @CompRhys in [#225](https://github.com/Radical-AI/torch-sim/pull/225) + +**Full Changelog**: https://github.com/Radical-AI/torch-sim/compare/v0.2.2...v0.3.0 -2025-05-01 +## v0.2.2 + +## What's Changed +### πŸ’₯ Breaking Changes +* Remove higher level model imports by @CompRhys in https://github.com/Radical-AI/torch-sim/pull/179 +### πŸ›  Enhancements +* Add per atom energies and stresses for batched LJ by @abhijeetgangan in https://github.com/Radical-AI/torch-sim/pull/144 +* throw error if autobatcher type is wrong by @orionarcher in https://github.com/Radical-AI/torch-sim/pull/167 +### πŸ› Bug Fixes +* Mattersim fix tensors on wrong device (CPU->GPU) by @orionarcher in https://github.com/Radical-AI/torch-sim/pull/154 +* fix `npt_langevin` by @jla-gardner in https://github.com/Radical-AI/torch-sim/pull/153 +* Make sure to move data to CPU before calling vesin by @Luthaf in https://github.com/Radical-AI/torch-sim/pull/156 +* Fix virial calculations in `optimizers` and `integrators` by @janosh in https://github.com/Radical-AI/torch-sim/pull/163 +* Pad memory estimation by @orionarcher in https://github.com/Radical-AI/torch-sim/pull/160 +* Refactor sevennet model by @YutackPark in https://github.com/Radical-AI/torch-sim/pull/172 +* `io` optional dependencies in `pyproject.toml` by @curtischong in https://github.com/Radical-AI/torch-sim/pull/185 +* Fix column->row cell vector mismatch in integrators by @CompRhys in https://github.com/Radical-AI/torch-sim/pull/175 +### πŸ“– Documentation +* (tiny) add graph-pes to README by @jla-gardner in https://github.com/Radical-AI/torch-sim/pull/149 +* Better module fig by @janosh in https://github.com/Radical-AI/torch-sim/pull/168 +### πŸš€ Performance +* More efficient Orb `state_to_atoms_graph` calculation by @AdeeshKolluru in https://github.com/Radical-AI/torch-sim/pull/165 +### 🚧 CI +* Refactor `test_math.py` and `test_transforms.py` by @janosh in https://github.com/Radical-AI/torch-sim/pull/151 +### πŸ₯ Package Health +* Try out hatchling for build vs setuptools by @CompRhys in https://github.com/Radical-AI/torch-sim/pull/177 +### πŸ“¦ Dependencies +* Bump `mace-torch` to v0.3.12 by @janosh in https://github.com/Radical-AI/torch-sim/pull/170 +* Update metatrain dependency by @Luthaf in https://github.com/Radical-AI/torch-sim/pull/186 +### 🏷️ Type Hints +* Add `torch_sim/typing.py` by @janosh in https://github.com/Radical-AI/torch-sim/pull/157 + +## New Contributors +* @Luthaf made their first contribution in https://github.com/Radical-AI/torch-sim/pull/156 +* @YutackPark made their first contribution in https://github.com/Radical-AI/torch-sim/pull/172 +* @curtischong made their first contribution in https://github.com/Radical-AI/torch-sim/pull/185 + +**Full Changelog**: https://github.com/Radical-AI/torch-sim/compare/v0.2.0...v0.2.1 + +## v0.2.1 ## What's Changed ### πŸ’₯ Breaking Changes -* Remove higher level model imports by @CompRhys in #179 +* Remove higher level model imports by @CompRhys in [#179](https://github.com/TorchSim/torch-sim/pull/179) ### πŸ›  Enhancements -* Add per atom energies and stresses for batched LJ by @abhijeetgangan in #144 -* throw error if autobatcher type is wrong by @orionarcher in #167 +* Add per atom energies and stresses for batched LJ by @abhijeetgangan in [#144](https://github.com/TorchSim/torch-sim/pull/144) +* throw error if autobatcher type is wrong by @orionarcher in [#167](https://github.com/TorchSim/torch-sim/pull/167) ### πŸ› Bug Fixes -* Fix column->row cell vector mismatch in integrators by @CompRhys in #175 -* Mattersim fix tensors on wrong device (CPU->GPU) by @orionarcher in #154 -* fix `npt_langevin` by @jla-gardner in #153 -* Make sure to move data to CPU before calling vesin by @Luthaf in #156 -* Fix virial calculations in `optimizers` and `integrators` by @janosh in #163 -* Pad memory estimation by @orionarcher in #160 -* Refactor sevennet model by @YutackPark in #172 -* `io` optional dependencies in `pyproject.toml` by @curtischong in #185 +* Fix column->row cell vector mismatch in integrators by @CompRhys in [#175](https://github.com/TorchSim/torch-sim/pull/175) +* Mattersim fix tensors on wrong device (CPU->GPU) by @orionarcher in [#154](https://github.com/TorchSim/torch-sim/pull/154) +* fix `npt_langevin` by @jla-gardner in [#153](https://github.com/TorchSim/torch-sim/pull/153) +* Make sure to move data to CPU before calling vesin by @Luthaf in [#156](https://github.com/TorchSim/torch-sim/pull/156) +* Fix virial calculations in `optimizers` and `integrators` by @janosh in [#163](https://github.com/TorchSim/torch-sim/pull/163) +* Pad memory estimation by @orionarcher in [#160](https://github.com/TorchSim/torch-sim/pull/160) +* Refactor sevennet model by @YutackPark in [#172](https://github.com/TorchSim/torch-sim/pull/172) +* `io` optional dependencies in `pyproject.toml` by @curtischong in [#185](https://github.com/TorchSim/torch-sim/pull/185) ### πŸ“– Documentation -* (tiny) add graph-pes to README by @jla-gardner in #149 -* Better module fig by @janosh in #168 +* (tiny) add graph-pes to README by @jla-gardner in [#149](https://github.com/TorchSim/torch-sim/pull/149) +* Better module fig by @janosh in [#168](https://github.com/TorchSim/torch-sim/pull/168) ### πŸš€ Performance -* More efficient Orb `state_to_atoms_graph` calculation by @AdeeshKolluru in #165 +* More efficient Orb `state_to_atoms_graph` calculation by @AdeeshKolluru in [#165](https://github.com/TorchSim/torch-sim/pull/165) ### 🚧 CI -* Refactor `test_math.py` and `test_transforms.py` by @janosh in #151 +* Refactor `test_math.py` and `test_transforms.py` by @janosh in [#151](https://github.com/TorchSim/torch-sim/pull/151) ### πŸ₯ Package Health -* Try out hatchling for build vs setuptools by @CompRhys in #177 +* Try out hatchling for build vs setuptools by @CompRhys in [#177](https://github.com/TorchSim/torch-sim/pull/177) ### 🏷️ Type Hints -* Add `torch-sim/typing.py` by @janosh in #157 +* Add `torch-sim/typing.py` by @janosh in [#157](https://github.com/TorchSim/torch-sim/pull/157) ### πŸ“¦ Dependencies -* Bump `mace-torch` to v0.3.12 by @janosh in #170 -* Update metatrain dependency by @Luthaf in #186 +* Bump `mace-torch` to v0.3.12 by @janosh in [#170](https://github.com/TorchSim/torch-sim/pull/170) +* Update metatrain dependency by @Luthaf in [#186](https://github.com/TorchSim/torch-sim/pull/186) ## New Contributors -* @Luthaf made their first contribution in #156 -* @YutackPark made their first contribution in #172 -* @curtischong made their first contribution in #185 +* @Luthaf made their first contribution in [#156](https://github.com/TorchSim/torch-sim/pull/156) +* @YutackPark made their first contribution in [#172](https://github.com/TorchSim/torch-sim/pull/172) +* @curtischong made their first contribution in [#185](https://github.com/TorchSim/torch-sim/pull/185) **Full Changelog**: https://github.com/torchsim/torch-sim/compare/v0.2.0...v0.2.1 @@ -62,33 +126,33 @@ ### Bug Fixes πŸ› -* Fix integrate reporting kwarg to arg error, #113 (raised by @hn-yu) -* Allow runners to take large initial batches, #128 (raised by @YutackPark) -* Add Fairchem model support for PBC, #111 (raised by @ryanliu30) +* Fix integrate reporting kwarg to arg error, [#113](https://github.com/TorchSim/torch-sim/pull/113) (raised by @hn-yu) +* Allow runners to take large initial batches, [#128](https://github.com/TorchSim/torch-sim/pull/128) (raised by @YutackPark) +* Add Fairchem model support for PBC, [#111](https://github.com/TorchSim/torch-sim/pull/111) (raised by @ryanliu30) ### Enhancements πŸ›  -* **breaking** Rename `HotSwappingAutobatcher` to `InFlightAutobatcher` and `ChunkingAutoBatcher` to `BinningAutoBatcher`, #143 @orionarcher -* Support for Orbv3, #140, @AdeeshKolluru -* Support metatensor models, #141, @frostedoyter @Luthaf -* Support for graph-pes models, #118 @jla-gardner -* Support MatterSim and fix ASE cell convention issues, #112 @CompRhys -* Implement positions only FIRE optimization, #139 @abhijeetgangan -* Allow different temperatures in batches, #123 @orionarcher -* FairChem model updates: PBC handling, test on OMat24 e-trained model, #126 @AdeeshKolluru -* FairChem model from_data_list support, #138 @ryanliu30 -* New correlation function module, #115 @stefanbringuier +* **breaking** Rename `HotSwappingAutobatcher` to `InFlightAutobatcher` and `ChunkingAutoBatcher` to `BinningAutoBatcher`, [#143](https://github.com/TorchSim/torch-sim/pull/143) @orionarcher +* Support for Orbv3, [#140](https://github.com/TorchSim/torch-sim/pull/140), @AdeeshKolluru +* Support metatensor models, [#141](https://github.com/TorchSim/torch-sim/pull/141), @frostedoyter @Luthaf +* Support for graph-pes models, [#118](https://github.com/TorchSim/torch-sim/pull/118) @jla-gardner +* Support MatterSim and fix ASE cell convention issues, [#112](https://github.com/TorchSim/torch-sim/pull/112) @CompRhys +* Implement positions only FIRE optimization, [#139](https://github.com/TorchSim/torch-sim/pull/139) @abhijeetgangan +* Allow different temperatures in batches, [#123](https://github.com/TorchSim/torch-sim/pull/123) @orionarcher +* FairChem model updates: PBC handling, test on OMat24 e-trained model, [#126](https://github.com/TorchSim/torch-sim/pull/126) @AdeeshKolluru +* FairChem model from_data_list support, [#138](https://github.com/TorchSim/torch-sim/pull/138) @ryanliu30 +* New correlation function module, [#115](https://github.com/TorchSim/torch-sim/pull/115) @stefanbringuier ### Documentation πŸ“– -* Improved model documentation, #121 @orionarcher -* Plot of TorchSim module graph in docs, #132 @janosh +* Improved model documentation, [#121](https://github.com/TorchSim/torch-sim/pull/121) @orionarcher +* Plot of TorchSim module graph in docs, [#132](https://github.com/TorchSim/torch-sim/pull/132) @janosh ### House-Keeping 🧹 -* Only install HF for fairchem tests, #134 @CompRhys -* Don't download MBD in CI, #135 @orionarcher -* Tighten graph-pes test bounds, #143 @orionarcher +* Only install HF for fairchem tests, [#134](https://github.com/TorchSim/torch-sim/pull/134) @CompRhys +* Don't download MBD in CI, [#135](https://github.com/TorchSim/torch-sim/pull/135) @orionarcher +* Tighten graph-pes test bounds, [#143](https://github.com/TorchSim/torch-sim/pull/143) @orionarcher ## v0.1.0 From 5dc082ea845f35f740299aed7b05241a9c3f52d5 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 20:51:03 -0400 Subject: [PATCH 32/40] demote cell_filter from runners api --- tests/test_runners.py | 14 +++++++------- torch_sim/runners.py | 38 +++++++++++++++++--------------------- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/tests/test_runners.py b/tests/test_runners.py index a1896a048..5792664f7 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -303,9 +303,9 @@ def test_optimize_fire( system=ar_supercell_sim_state, model=lj_model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.unit, convergence_fn=ts.generate_force_convergence_fn(force_tol=1e-1), trajectory_reporter=reporter, + init_kwargs={"cell_filter": ts.CellFilter.unit}, ) with TorchSimTrajectory(trajectory_files[0]) as traj: @@ -337,8 +337,8 @@ def test_default_converged_fn( system=ar_supercell_sim_state, model=lj_model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.unit, trajectory_reporter=reporter, + init_kwargs={"cell_filter": ts.CellFilter.unit}, ) with TorchSimTrajectory(traj_file) as traj: @@ -374,10 +374,10 @@ def test_batched_optimize_fire( system=ar_double_sim_state, model=lj_model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.unit, convergence_fn=ts.generate_force_convergence_fn(force_tol=1e-5), trajectory_reporter=reporter, max_steps=500, + init_kwargs={"cell_filter": ts.CellFilter.unit}, ) assert torch.all(final_state.forces < 1e-4) @@ -402,9 +402,9 @@ def test_optimize_with_autobatcher( system=triple_state, model=lj_model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.unit, convergence_fn=ts.generate_force_convergence_fn(force_tol=1e-1), autobatcher=autobatcher, + init_kwargs={"cell_filter": ts.CellFilter.unit}, ) assert isinstance(final_states, SimState) @@ -447,10 +447,10 @@ def test_optimize_with_autobatcher_and_reporting( system=triple_state, model=lj_model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.unit, convergence_fn=ts.generate_force_convergence_fn(force_tol=1e-1), trajectory_reporter=reporter, autobatcher=autobatcher, + init_kwargs={"cell_filter": ts.CellFilter.unit}, ) assert all(traj_file.is_file() for traj_file in trajectory_files) @@ -535,9 +535,9 @@ def mock_estimate(*args, **kwargs) -> float: # noqa: ARG001 system=triple_state, model=lj_model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.unit, convergence_fn=ts.generate_force_convergence_fn(force_tol=1e-1), autobatcher=True, + init_kwargs={"cell_filter": ts.CellFilter.unit}, ) assert isinstance(final_states, SimState) @@ -810,8 +810,8 @@ def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None: system=final_state, model=lj_model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.frechet, # autobatcher=True, # disabled for CPU-based LJ model in test + init_kwargs={"cell_filter": ts.CellFilter.frechet}, ) assert relaxed_state.energy.shape == (final_state.n_systems,) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 2d5ea0dcf..62146473b 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -9,7 +9,7 @@ from collections.abc import Callable from dataclasses import dataclass from itertools import chain -from typing import TYPE_CHECKING, Any +from typing import Any import torch from tqdm import tqdm @@ -26,10 +26,6 @@ from torch_sim.units import UnitSystem -if TYPE_CHECKING: - from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs - - def _configure_reporter( trajectory_reporter: TrajectoryReporter | dict, *, @@ -374,7 +370,6 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 model: ModelInterface, *, optimizer: OptimFlavor | tuple[Callable[..., T], Callable[..., T]], - cell_filter: "CellFilter | CellFilterFuncs | None" = None, convergence_fn: Callable[[T, torch.Tensor | None], torch.Tensor] | None = None, trajectory_reporter: TrajectoryReporter | dict | None = None, autobatcher: InFlightAutoBatcher | bool = False, @@ -393,8 +388,6 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 optimizer (OptimFlavor | tuple): Optimization algorithm function convergence_fn (Callable | None): Condition for convergence, should return a boolean tensor of length n_systems - cell_filter (CellFilter | CellFilterFuncs | None): Optional cell filter to use. - If None, the system will not optimize the cell. trajectory_reporter (TrajectoryReporter | dict | None): Optional reporter for tracking optimization trajectory. If a dict, will be passed to the TrajectoryReporter constructor. @@ -451,7 +444,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 init_fn, initial_state, model, - init_kwargs=dict(cell_filter=cell_filter, **init_kwargs or {}), + init_kwargs=dict(**init_kwargs or {}), max_memory_scaler=autobatcher.max_memory_scaler, memory_scales_with=autobatcher.memory_scales_with, ) @@ -582,12 +575,11 @@ class StaticState(SimState): forces: torch.Tensor stress: torch.Tensor - _atom_attributes = ( - state._atom_attributes | {"forces"} # noqa: SLF001 - ) - _system_attributes = ( - state._system_attributes | {"energy", "stress"} # noqa: SLF001 - ) + _atom_attributes = state._atom_attributes | {"forces"} # noqa: SLF001 + _system_attributes = state._system_attributes | { # noqa: SLF001 + "energy", + "stress", + } all_props: list[dict[str, torch.Tensor]] = [] og_filenames = trajectory_reporter.filenames @@ -613,12 +605,16 @@ class StaticState(SimState): sub_state = StaticState( **vars(sub_state), energy=model_outputs["energy"], - forces=model_outputs["forces"] - if model.compute_forces - else torch.full_like(sub_state.positions, fill_value=float("nan")), - stress=model_outputs["stress"] - if model.compute_stress - else torch.full_like(sub_state.cell, fill_value=float("nan")), + forces=( + model_outputs["forces"] + if model.compute_forces + else torch.full_like(sub_state.positions, fill_value=float("nan")) + ), + stress=( + model_outputs["stress"] + if model.compute_stress + else torch.full_like(sub_state.cell, fill_value=float("nan")) + ), ) props = trajectory_reporter.report(sub_state, 0, model=model) From 534fecd602f1cffa855a9098a242e900c628fcc0 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 20:57:15 -0400 Subject: [PATCH 33/40] swap state and model in elastic and monte_carlo files --- examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 4 ++-- examples/scripts/5_Workflow/5.3_Elastic.py | 2 +- examples/tutorials/hybrid_swap_tutorial.py | 4 ++-- tests/test_elastic.py | 10 +++++++--- tests/test_monte_carlo.py | 4 ++-- torch_sim/elastic.py | 8 ++++++-- torch_sim/monte_carlo.py | 4 ++-- 7 files changed, 22 insertions(+), 14 deletions(-) diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index 19f8dd48c..c9d6f049b 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -78,7 +78,7 @@ class HybridSwapMCState(ts.SwapMCState, MDState): md_state = ts.nvt_langevin_init(model=model, state=state, kT=torch.tensor(kT), seed=42) -swap_state = ts.swap_mc_init(model=model, state=md_state) +swap_state = ts.swap_mc_init(state=md_state, model=model) hybrid_state = HybridSwapMCState( **vars(md_state), last_permutation=torch.arange( @@ -93,7 +93,7 @@ class HybridSwapMCState(ts.SwapMCState, MDState): dt = torch.tensor(0.002) for step in range(n_steps): if step % 10 == 0: - hybrid_state = ts.swap_mc_step(model=model, state=hybrid_state, kT=kT, rng=rng) + hybrid_state = ts.swap_mc_step(state=hybrid_state, model=model, kT=kT, rng=rng) else: hybrid_state = ts.nvt_langevin_step( model=model, state=hybrid_state, dt=dt, kT=torch.tensor(kT) diff --git a/examples/scripts/5_Workflow/5.3_Elastic.py b/examples/scripts/5_Workflow/5.3_Elastic.py index 6f24851bf..a2b24d5e1 100644 --- a/examples/scripts/5_Workflow/5.3_Elastic.py +++ b/examples/scripts/5_Workflow/5.3_Elastic.py @@ -62,7 +62,7 @@ # Calculate elastic tensor elastic_tensor = ts.elastic.calculate_elastic_tensor( - model, state=state, bravais_type=bravais_type + state=state, model=model, bravais_type=bravais_type ) # Convert to GPa diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index 7eef6c15b..b74a2ab96 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -125,7 +125,7 @@ class HybridSwapMCState(SwapMCState, MDState): md_state = ts.nvt_langevin_init(model=mace_model, state=state, kT=kT, seed=42) # Initialize swap Monte Carlo state -swap_state = ts.swap_mc_init(model=mace_model, state=md_state) +swap_state = ts.swap_mc_init(state=md_state, model=mace_model) # Create hybrid state combining both hybrid_state = HybridSwapMCState( @@ -159,7 +159,7 @@ class HybridSwapMCState(SwapMCState, MDState): for step in range(n_steps): if step % 10 == 0: # Attempt swap Monte Carlo move hybrid_state = ts.swap_mc_step( - model=mace_model, state=hybrid_state, kT=kT, rng=rng + state=hybrid_state, model=mace_model, kT=kT, rng=rng ) else: # Perform MD step hybrid_state = ts.nvt_langevin_step( diff --git a/tests/test_elastic.py b/tests/test_elastic.py index 321ed0571..e060945e4 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -333,11 +333,15 @@ def test_elastic_tensor_symmetries( # Calculate elastic tensors C_symmetric = ( - calculate_elastic_tensor(model, state=state, bravais_type=expected_bravais_type) + calculate_elastic_tensor( + state=state, model=model, bravais_type=expected_bravais_type + ) * UnitConversion.eV_per_Ang3_to_GPa ) C_triclinic = ( - calculate_elastic_tensor(model, state=state, bravais_type=BravaisType.triclinic) + calculate_elastic_tensor( + state=state, model=model, bravais_type=BravaisType.triclinic + ) * UnitConversion.eV_per_Ang3_to_GPa ) @@ -373,7 +377,7 @@ def test_copper_elastic_properties( # Calculate elastic tensor bravais_type = get_bravais_type(state) elastic_tensor = calculate_elastic_tensor( - mace_model, state=state, bravais_type=bravais_type + state=state, model=mace_model, bravais_type=bravais_type ) # Convert to GPa diff --git a/tests/test_monte_carlo.py b/tests/test_monte_carlo.py index c88951318..f02c0ead3 100644 --- a/tests/test_monte_carlo.py +++ b/tests/test_monte_carlo.py @@ -165,7 +165,7 @@ def test_monte_carlo_integration( # Initialize rng = torch.Generator(device=DEVICE) rng.manual_seed(42) - mc_state = swap_mc_init(model=lj_model, state=batched_diverse_state) + mc_state = swap_mc_init(state=batched_diverse_state, model=lj_model) assert isinstance(mc_state, SwapMCState) assert mc_state.energy.shape == (batched_diverse_state.n_systems,) assert mc_state.last_permutation.shape == (batched_diverse_state.n_atoms,) @@ -174,7 +174,7 @@ def test_monte_carlo_integration( # Run steps for _step in range(n_steps): - mc_state = swap_mc_step(model=lj_model, state=mc_state, kT=kT, rng=rng) + mc_state = swap_mc_step(state=mc_state, model=lj_model, kT=kT, rng=rng) assert isinstance(mc_state, SwapMCState) # Verify conservation properties diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index 0e91fac55..78547cbec 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -748,7 +748,11 @@ def get_elementary_deformations( else: # Shear strain # Generate symmetric strains around zero strains = torch.linspace( - -max_strain_shear, max_strain_shear, n_deform, device=device, dtype=dtype + -max_strain_shear, + max_strain_shear, + n_deform, + device=device, + dtype=dtype, ) # Skip zero strain @@ -1102,9 +1106,9 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 def calculate_elastic_tensor( + state: OptimState, model: ModelInterface, *, - state: OptimState, bravais_type: BravaisType = BravaisType.triclinic, max_strain_normal: float = 0.01, max_strain_shear: float = 0.06, diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 115688cc7..81fe2ab28 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -177,8 +177,8 @@ def metropolis_criterion( def swap_mc_init( - model: ModelInterface, state: SimState, + model: ModelInterface, ) -> SwapMCState: """Initialize a swap Monte Carlo state from input data. @@ -218,8 +218,8 @@ def swap_mc_init( def swap_mc_step( - model: ModelInterface, state: SwapMCState, + model: ModelInterface, *, kT: float, seed: int | None = None, From de6edc1f256850129fb366e4d68089a70c43a259 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 20:59:55 -0400 Subject: [PATCH 34/40] fix state and model order in autobatching and runners --- examples/tutorials/autobatching_tutorial.py | 2 +- torch_sim/autobatching.py | 6 +++--- torch_sim/runners.py | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/tutorials/autobatching_tutorial.py b/examples/tutorials/autobatching_tutorial.py index b0eedb90d..08b2c81a5 100644 --- a/examples/tutorials/autobatching_tutorial.py +++ b/examples/tutorials/autobatching_tutorial.py @@ -101,7 +101,7 @@ ] max_memory_metric = estimate_max_memory_scaler( - mace_model, state_list, metric_values=memory_metric_values + state_list, mace_model, metric_values=memory_metric_values ) print(f"Max memory metric: {max_memory_metric}") diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 0a0966830..3a45b267c 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -355,8 +355,8 @@ def calculate_memory_scaler( def estimate_max_memory_scaler( - model: ModelInterface, state_list: list[SimState], + model: ModelInterface, metric_values: list[float] | torch.Tensor, **kwargs: Any, ) -> float: @@ -533,8 +533,8 @@ def load_states(self, states: T | Sequence[T]) -> float: ] if not self.max_memory_scaler: self.max_memory_scaler = estimate_max_memory_scaler( - self.model, self.state_slices, + self.model, self.memory_scalers, max_atoms=self.max_atoms_to_try, scale_factor=self.memory_scaling_factor, @@ -914,8 +914,8 @@ def _get_first_batch(self) -> T: if not has_max_metric: self.max_memory_scaler = estimate_max_memory_scaler( - self.model, [first_state, *states], + self.model, self.current_scalers, max_atoms=self.max_atoms_to_try, scale_factor=self.memory_scaling_factor, diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 62146473b..ded757f6a 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -66,8 +66,8 @@ def _configure_reporter( def _configure_batches_iterator( - model: ModelInterface, state: SimState, + model: ModelInterface, *, autobatcher: BinningAutoBatcher | bool, ) -> BinningAutoBatcher | list[tuple[SimState, list[int]]]: @@ -170,7 +170,7 @@ def integrate[T: SimState]( # noqa: C901 # batch_iterator will be a list if autobatcher is False batch_iterator = _configure_batches_iterator( - model, initial_state, autobatcher=autobatcher + initial_state, model, autobatcher=autobatcher ) if trajectory_reporter is not None: trajectory_reporter = _configure_reporter( @@ -223,8 +223,8 @@ def integrate[T: SimState]( # noqa: C901 def _configure_in_flight_autobatcher( - model: ModelInterface, state: SimState, + model: ModelInterface, *, autobatcher: InFlightAutoBatcher | bool, max_attempts: int, # TODO: change name to max_iterations @@ -434,7 +434,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 max_attempts = max_steps // steps_between_swaps autobatcher = _configure_in_flight_autobatcher( - model, initial_state, autobatcher=autobatcher, max_attempts=max_attempts + initial_state, model, autobatcher=autobatcher, max_attempts=max_attempts ) if isinstance(initial_state, OptimState): @@ -553,7 +553,7 @@ def static( """ state: SimState = ts.initialize_state(system, model.device, model.dtype) - batch_iterator = _configure_batches_iterator(model, state, autobatcher=autobatcher) + batch_iterator = _configure_batches_iterator(state, model, autobatcher=autobatcher) properties = ["potential_energy"] if model.compute_forces: properties.append("forces") From 97abba56f79550495f3f43c9e108643c033dd4ef Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 21:09:17 -0400 Subject: [PATCH 35/40] fix state and model order for nve and npt --- .../3.11_Lennard_Jones_NPT_Langevin.py | 4 +- .../3_Dynamics/3.12_MACE_NPT_Langevin.py | 4 +- .../3_Dynamics/3.13_MACE_NVE_non_pbc.py | 4 +- .../3_Dynamics/3.1_Lennard_Jones_NVE.py | 4 +- examples/scripts/3_Dynamics/3.2_MACE_NVE.py | 4 +- .../scripts/3_Dynamics/3.3_MACE_NVE_cueq.py | 4 +- .../3.7_Lennard_Jones_NPT_Nose_Hoover.py | 4 +- .../3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py | 8 +-- .../7_Others/7.4_Velocity_AutoCorrelation.py | 4 +- tests/test_integrators.py | 16 ++--- torch_sim/integrators/npt.py | 59 ++++++++++--------- torch_sim/integrators/nve.py | 4 +- 12 files changed, 60 insertions(+), 59 deletions(-) diff --git a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py index 281cfb295..642083a6e 100644 --- a/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py @@ -106,7 +106,7 @@ torch.tensor(10_000, device=device, dtype=dtype) * Units.pressure ) # Target pressure (10 kbar) -state = ts.npt_langevin_init(model=model, state=state, dt=dt, kT=kT, seed=1) +state = ts.npt_langevin_init(state=state, model=model, dt=dt, kT=kT, seed=1) # Run the simulation for step in range(N_steps): @@ -132,8 +132,8 @@ f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" ) state = ts.npt_langevin_step( - model=model, state=state, + model=model, dt=dt, kT=kT, external_pressure=target_pressure, diff --git a/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py b/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py index 4d8539da9..ec09c9ca1 100644 --- a/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py +++ b/examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py @@ -71,7 +71,7 @@ print(f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}, ") state = ts.nvt_nose_hoover_step(model=model, state=state, dt=dt, kT=kT) -state = ts.npt_langevin_init(model=model, state=state, kT=kT, dt=dt, seed=1) +state = ts.npt_langevin_init(state=state, model=model, kT=kT, dt=dt, seed=1) for step in range(N_steps_npt): if step % 10 == 0: @@ -102,8 +102,8 @@ f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" ) state = ts.npt_langevin_step( - model=model, state=state, + model=model, dt=dt, kT=kT, external_pressure=target_pressure, diff --git a/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py b/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py index 0ee349d4f..8acef6aba 100644 --- a/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py +++ b/examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py @@ -59,7 +59,7 @@ # Initialize NVE integrator -state = ts.nve_init(model=model, state=state, kT=kT, seed=1) +state = ts.nve_init(state=state, model=model, kT=kT, seed=1) # Run MD simulation print("\nStarting NVE molecular dynamics simulation...") @@ -70,7 +70,7 @@ ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") - state = ts.nve_step(model=model, state=state, dt=dt) + state = ts.nve_step(state=state, model=model, dt=dt) end_time = time.perf_counter() # Report simulation results diff --git a/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py b/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py index 839d68d04..c8b8f6f55 100644 --- a/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py +++ b/examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py @@ -103,7 +103,7 @@ dt = torch.tensor(0.001 * Units.time, device=device, dtype=dtype) # Initialize NVE integrator -state = ts.nve_init(model=model, state=state, kT=kT, seed=1) +state = ts.nve_init(state=state, model=model, kT=kT, seed=1) # Run NVE simulation for 1000 steps for step in range(N_steps): @@ -115,7 +115,7 @@ print(f"{step=}: Total energy: {total_energy.item():.4f}") # Update state using NVE integrator - state = ts.nve_step(model=model, state=state, dt=dt) + state = ts.nve_step(state=state, model=model, dt=dt) final_total_energy = state.energy + ts.calc_kinetic_energy( masses=state.masses, momenta=state.momenta diff --git a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py index 7fe0ed262..06f97c1f4 100644 --- a/examples/scripts/3_Dynamics/3.2_MACE_NVE.py +++ b/examples/scripts/3_Dynamics/3.2_MACE_NVE.py @@ -67,7 +67,7 @@ # Initialize NVE integrator -state = ts.nve_init(model=model, state=state, kT=kT, seed=1) +state = ts.nve_init(state=state, model=model, kT=kT, seed=1) # Run MD simulation print("\nStarting NVE molecular dynamics simulation...") @@ -78,7 +78,7 @@ ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") - state = ts.nve_step(model=model, state=state, dt=dt) + state = ts.nve_step(state=state, model=model, dt=dt) end_time = time.perf_counter() # Report simulation results diff --git a/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py b/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py index f21578a38..7d906c694 100644 --- a/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py +++ b/examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py @@ -58,7 +58,7 @@ # Initialize NVE integrator -state = nve_init(model=model, state=state, kT=kT, seed=1) +state = nve_init(state=state, model=model, kT=kT, seed=1) # Run MD simulation print("\nStarting NVE molecular dynamics simulation...") @@ -69,7 +69,7 @@ ) if step % 10 == 0: print(f"Step {step}: Total energy: {total_energy.item():.4f} eV") - state = nve_step(model=model, state=state, dt=dt) + state = nve_step(state=state, model=model, dt=dt) end_time = time.perf_counter() # Report simulation results diff --git a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py index f64277195..d69ddf951 100644 --- a/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py @@ -106,8 +106,8 @@ ) # Target pressure (10 kbar) state = ts.npt_nose_hoover_init( - model=model, state=state, + model=model, dt=dt, kT=kT, chain_length=3, # Chain length @@ -141,7 +141,7 @@ f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" ) state = ts.npt_nose_hoover_step( - model=model, state=state, dt=dt, kT=kT, external_pressure=target_pressure + state=state, model=model, dt=dt, kT=kT, external_pressure=target_pressure ) temp = ( diff --git a/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py b/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py index 76dbe6627..e8ccf814e 100644 --- a/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py +++ b/examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py @@ -57,7 +57,7 @@ 0.0 * Units.pressure, device=device, dtype=dtype ) # Target pressure (0 bar) -state = ts.npt_nose_hoover_init(model=model, state=state, kT=kT, dt=torch.tensor(dt)) +state = ts.npt_nose_hoover_init(state=state, model=model, kT=kT, dt=torch.tensor(dt)) for step in range(N_steps_nvt): if step % 10 == 0: @@ -72,14 +72,14 @@ ) print(f"{step=}: Temperature: {temp.item():.4f}: {invariant=:.4f}, ") state = ts.npt_nose_hoover_step( - model=model, state=state, + model=model, dt=torch.tensor(dt), kT=kT, external_pressure=target_pressure, ) -state = ts.npt_nose_hoover_init(model=model, state=state, kT=kT, dt=torch.tensor(dt)) +state = ts.npt_nose_hoover_init(state=state, model=model, kT=kT, dt=torch.tensor(dt)) for step in range(N_steps_npt): if step % 10 == 0: @@ -105,8 +105,8 @@ f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" ) state = ts.npt_nose_hoover_step( - model=model, state=state, + model=model, dt=torch.tensor(dt), kT=kT, external_pressure=target_pressure, diff --git a/examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py b/examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py index b58a1aad2..cdd912646 100644 --- a/examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py +++ b/examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py @@ -72,7 +72,7 @@ def plot_results(*, time: np.ndarray, vacf: np.ndarray, window_count: int) -> No def main() -> None: """Run velocity autocorrelation simulation using Lennard-Jones model.""" state, lj_model, dt, kT, device, _dtype, timestep = prepare_system() - state = ts.nve_init(model=lj_model, state=state, kT=kT) + state = ts.nve_init(state=state, model=lj_model, kT=kT) window_size = 150 # Length of correlation: dt * correlation_dt * window_size vacf_calc = VelocityAutoCorrelation( @@ -93,7 +93,7 @@ def main() -> None: num_steps = 15000 # NOTE: short run for step in range(num_steps): - state = ts.nve_step(model=lj_model, state=state, dt=dt) # type: ignore[call-arg] + state = ts.nve_step(state=state, model=lj_model, dt=dt) # type: ignore[call-arg] reporter.report(state, step) reporter.close() diff --git a/tests/test_integrators.py b/tests/test_integrators.py index aad4ed28b..d20015a92 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -90,7 +90,7 @@ def test_npt_langevin( # Initialize integrator using new direct API state = ts.npt_langevin_init( - model=lj_model, state=ar_double_sim_state, dt=dt, kT=kT, alpha=alpha, seed=42 + state=ar_double_sim_state, model=lj_model, dt=dt, kT=kT, alpha=alpha, seed=42 ) # Run dynamics for several steps @@ -98,8 +98,8 @@ def test_npt_langevin( temperatures = [] for _step in range(n_steps): state = ts.npt_langevin_step( - model=lj_model, state=state, + model=lj_model, dt=dt, kT=kT, external_pressure=external_pressure, @@ -161,7 +161,7 @@ def test_npt_langevin_multi_kt( # Initialize integrator using new direct API state = ts.npt_langevin_init( - model=lj_model, state=ar_double_sim_state, dt=dt, kT=kT, alpha=alpha, seed=42 + state=ar_double_sim_state, model=lj_model, dt=dt, kT=kT, alpha=alpha, seed=42 ) # Run dynamics for several steps @@ -169,8 +169,8 @@ def test_npt_langevin_multi_kt( temperatures = [] for _step in range(n_steps): state = ts.npt_langevin_step( - model=lj_model, state=state, + model=lj_model, dt=dt, kT=kT, external_pressure=external_pressure, @@ -301,12 +301,12 @@ def test_nve(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): kT = torch.tensor(100.0, dtype=DTYPE) * MetalUnits.temperature # Initialize integrator - state = ts.nve_init(model=lj_model, state=ar_double_sim_state, kT=kT, seed=42) + state = ts.nve_init(state=ar_double_sim_state, model=lj_model, kT=kT, seed=42) # Run dynamics for several steps energies = [] for _step in range(n_steps): - state = ts.nve_step(model=lj_model, state=state, dt=dt) + state = ts.nve_step(state=state, model=lj_model, dt=dt) energies.append(state.energy) @@ -346,13 +346,13 @@ def test_compare_single_vs_batched_integrators( # Initialize momenta (even if zero) and get forces state = ts.nve_init( - model=lj_model, state=state, kT=kT, seed=42 + state=state, model=lj_model, kT=kT, seed=42 ) # kT is ignored if momenta are set below # Ensure momenta start at zero AFTER init which might randomize them based on kT state.momenta = torch.zeros_like(state.momenta) # Start from rest for _step in range(n_steps): - state = ts.nve_step(model=lj_model, state=state, dt=dt) + state = ts.nve_step(state=state, model=lj_model, dt=dt) final_states[state_name] = state diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index aca778521..1ab4e7c33 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -66,9 +66,10 @@ class NPTLangevinState(SimState): cell_velocities: torch.Tensor cell_masses: torch.Tensor - _atom_attributes = ( - SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001 - ) + _atom_attributes = SimState._atom_attributes | { # noqa: SLF001 + "forces", + "velocities", + } _system_attributes = SimState._system_attributes | { # noqa: SLF001 "stress", "cell_positions", @@ -126,7 +127,10 @@ def _npt_langevin_beta( def _npt_langevin_cell_beta( - state: NPTLangevinState, cell_alpha: torch.Tensor, kT: torch.Tensor, dt: torch.Tensor + state: NPTLangevinState, + cell_alpha: torch.Tensor, + kT: torch.Tensor, + dt: torch.Tensor, ) -> torch.Tensor: """Generate random noise for cell fluctuations in NPT dynamics. @@ -525,8 +529,8 @@ def _compute_cell_force( def npt_langevin_init( - model: ModelInterface, state: SimState | StateDict, + model: ModelInterface, *, kT: torch.Tensor, dt: torch.Tensor, @@ -643,8 +647,8 @@ def npt_langevin_init( def npt_langevin_step( - model: ModelInterface, state: NPTLangevinState, + model: ModelInterface, *, dt: torch.Tensor, kT: torch.Tensor, @@ -814,24 +818,18 @@ class NPTNoseHooverState(MDState): barostat: NoseHooverChain barostat_fns: NoseHooverChainFns - _system_attributes = ( - MDState._system_attributes # noqa: SLF001 - | { - "reference_cell", - "cell_position", - "cell_momentum", - "cell_mass", - } - ) - _global_attributes = ( - MDState._global_attributes # noqa: SLF001 - | { - "thermostat", - "barostat", - "thermostat_fns", - "barostat_fns", - } - ) + _system_attributes = MDState._system_attributes | { # noqa: SLF001 + "reference_cell", + "cell_position", + "cell_momentum", + "cell_mass", + } + _global_attributes = MDState._global_attributes | { # noqa: SLF001 + "thermostat", + "barostat", + "thermostat_fns", + "barostat_fns", + } @property def velocities(self) -> torch.Tensor: @@ -910,7 +908,10 @@ def volume_to_cell(V: torch.Tensor) -> torch.Tensor: def _npt_nose_hoover_update_cell_mass( - state: NPTNoseHooverState, kT: torch.Tensor, device: torch.device, dtype: torch.dtype + state: NPTNoseHooverState, + kT: torch.Tensor, + device: torch.device, + dtype: torch.dtype, ) -> NPTNoseHooverState: """Update the cell mass parameter in an NPT simulation. @@ -1183,8 +1184,8 @@ def _npt_nose_hoover_compute_cell_force( def _npt_nose_hoover_inner_step( - model: ModelInterface, state: NPTNoseHooverState, + model: ModelInterface, dt: torch.Tensor, external_pressure: torch.Tensor, ) -> NPTNoseHooverState: @@ -1292,8 +1293,8 @@ def _npt_nose_hoover_inner_step( def npt_nose_hoover_init( - model: ModelInterface, state: SimState | StateDict, + model: ModelInterface, *, kT: torch.Tensor, dt: torch.Tensor, @@ -1454,8 +1455,8 @@ def npt_nose_hoover_init( def npt_nose_hoover_step( - model: ModelInterface, state: NPTNoseHooverState, + model: ModelInterface, *, dt: torch.Tensor, kT: torch.Tensor, @@ -1500,7 +1501,7 @@ def npt_nose_hoover_step( ) # Perform inner NPT step - state = _npt_nose_hoover_inner_step(model, state, dt, external_pressure) + state = _npt_nose_hoover_inner_step(state, model, dt, external_pressure) # Update kinetic energies for thermostats KE = ts.calc_kinetic_energy( diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index 53b5e31c9..d3773b3ce 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -16,8 +16,8 @@ def nve_init( - model: ModelInterface, state: SimState | StateDict, + model: ModelInterface, *, kT: torch.Tensor, seed: int | None = None, @@ -71,7 +71,7 @@ def nve_init( def nve_step( - model: ModelInterface, state: MDState, *, dt: torch.Tensor, **_kwargs: Any + state: MDState, model: ModelInterface, *, dt: torch.Tensor, **_kwargs: Any ) -> MDState: """Perform one complete NVE (microcanonical) integration step. From b974edd5b0b3bd2f40659d53190ec56bbfc563fa Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 21:10:51 -0400 Subject: [PATCH 36/40] fix state and model order for nvt --- examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 2 +- examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py | 2 +- examples/scripts/4_High_level_api/4.2_auto_batching_api.py | 4 ++-- examples/scripts/7_Others/7.7_Heat_flux_and_kappa.py | 4 ++-- examples/tutorials/autobatching_tutorial.py | 2 +- examples/tutorials/hybrid_swap_tutorial.py | 2 +- examples/tutorials/low_level_tutorial.py | 2 +- tests/test_integrators.py | 4 ++-- torch_sim/integrators/nvt.py | 2 +- 9 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index c9d6f049b..331d69a14 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -76,7 +76,7 @@ class HybridSwapMCState(ts.SwapMCState, MDState): ) -md_state = ts.nvt_langevin_init(model=model, state=state, kT=torch.tensor(kT), seed=42) +md_state = ts.nvt_langevin_init(state=state, model=model, kT=torch.tensor(kT), seed=42) swap_state = ts.swap_mc_init(state=md_state, model=model) hybrid_state = HybridSwapMCState( diff --git a/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py b/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py index b9033f4ba..9faae200c 100644 --- a/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py +++ b/examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py @@ -54,7 +54,7 @@ torch.tensor(1000, device=device, dtype=dtype) * Units.temperature ) # Initial temperature (K) -state = ts.nvt_langevin_init(model=model, state=state, kT=kT) +state = ts.nvt_langevin_init(state=state, model=model, kT=kT) stress = torch.zeros(N_steps // 10, 3, 3, device=device, dtype=dtype) for step in range(N_steps): diff --git a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py index 71005a00d..bb1af0167 100644 --- a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py @@ -85,14 +85,14 @@ # %% run binning autobatcher si_nvt_state = ts.nvt_langevin_init( - model=mace_model, state=si_state, + model=mace_model, dt=torch.tensor(0.001), kT=torch.tensor(300 * MetalUnits.temperature), ) fe_nvt_state = ts.nvt_langevin_init( - model=mace_model, state=fe_state, + model=mace_model, dt=torch.tensor(0.001), kT=torch.tensor(300 * MetalUnits.temperature), ) diff --git a/examples/scripts/7_Others/7.7_Heat_flux_and_kappa.py b/examples/scripts/7_Others/7.7_Heat_flux_and_kappa.py index f75beecb5..4f50568ec 100644 --- a/examples/scripts/7_Others/7.7_Heat_flux_and_kappa.py +++ b/examples/scripts/7_Others/7.7_Heat_flux_and_kappa.py @@ -59,7 +59,7 @@ dt = torch.tensor(timestep * Units.time, device=device, dtype=dtype) kT = torch.tensor(temperature * Units.temperature, device=device, dtype=dtype) -state = ts.nvt_langevin_init(model=lj_model, state=state, kT=kT) +state = ts.nvt_langevin_init(state=state, model=lj_model, kT=kT) # Short equilibration run # Shape: (num_steps, batch, dim) @@ -82,7 +82,7 @@ if step % 1000 == 0: print(f"Step {step} | {state.energy.item():.4f} eV") -state = ts.nvt_langevin_init(model=lj_model, state=state, kT=kT) +state = ts.nvt_langevin_init(state=state, model=lj_model, kT=kT) hfacf_calc = HeatFluxAutoCorrelation( model=lj_model, diff --git a/examples/tutorials/autobatching_tutorial.py b/examples/tutorials/autobatching_tutorial.py index 08b2c81a5..b31ddc50b 100644 --- a/examples/tutorials/autobatching_tutorial.py +++ b/examples/tutorials/autobatching_tutorial.py @@ -177,7 +177,7 @@ def process_batch(batch): """ # %% Initialize nvt langevin integrator -nvt_state = ts.nvt_langevin_init(mace_model, state, kT=0.01) +nvt_state = ts.nvt_langevin_init(state, mace_model, kT=0.01) # Initialize the batcher batcher = ts.BinningAutoBatcher( diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index b74a2ab96..c41a7f0b1 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -122,7 +122,7 @@ class HybridSwapMCState(SwapMCState, MDState): kT = 1000 * MetalUnits.temperature # Initialize NVT Langevin dynamics state -md_state = ts.nvt_langevin_init(model=mace_model, state=state, kT=kT, seed=42) +md_state = ts.nvt_langevin_init(state=state, model=mace_model, kT=kT, seed=42) # Initialize swap Monte Carlo state swap_state = ts.swap_mc_init(state=md_state, model=mace_model) diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index 88257a318..92115c62e 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -204,7 +204,7 @@ """ # %% -state = ts.nvt_langevin_init(model=model, state=state, kT=kT) +state = ts.nvt_langevin_init(state=state, model=model, kT=kT) initial_kT = kT for step in range(30): diff --git a/tests/test_integrators.py b/tests/test_integrators.py index d20015a92..31ccc75b0 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -209,7 +209,7 @@ def test_nvt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo # Initialize integrator state = ts.nvt_langevin_init( - model=lj_model, state=ar_double_sim_state, kT=kT, seed=42 + state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 ) energies = [] temperatures = [] @@ -265,7 +265,7 @@ def test_nvt_langevin_multi_kt( # Initialize integrator state = ts.nvt_langevin_init( - model=lj_model, state=ar_double_sim_state, kT=kT, seed=42 + state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 ) energies = [] temperatures = [] diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 1c64e1d7f..b4cb41ce0 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -77,8 +77,8 @@ def _ou_step( def nvt_langevin_init( - model: ModelInterface, state: SimState | StateDict, + model: ModelInterface, *, kT: float | torch.Tensor, seed: int | None = None, From 9dbd05bd0fe07cf43665d155a5f46d43abd0c217 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 21:23:38 -0400 Subject: [PATCH 37/40] change state and model order for fire and gradient descent optimizers --- .../2.1_Lennard_Jones_FIRE.py | 4 +- .../2.2_Soft_Sphere_FIRE.py | 4 +- .../2.3_MACE_Gradient_Descent.py | 4 +- .../2.4_MACE_FIRE.py | 4 +- ....5_MACE_UnitCellFilter_Gradient_Descent.py | 4 +- .../2.6_MACE_UnitCellFilter_FIRE.py | 4 +- .../2.7_MACE_FrechetCellFilter_FIRE.py | 4 +- .../4_High_level_api/4.2_auto_batching_api.py | 8 +-- .../scripts/5_Workflow/5.2_In_Flight_WBM.py | 4 +- examples/scripts/5_Workflow/5.3_Elastic.py | 4 +- examples/tutorials/autobatching_tutorial.py | 4 +- examples/tutorials/low_level_tutorial.py | 8 +-- tests/test_autobatching.py | 12 ++--- tests/test_elastic.py | 8 +-- tests/test_optimizers.py | 50 +++++++++---------- torch_sim/optimizers/cell_filters.py | 4 +- torch_sim/optimizers/fire.py | 10 ++-- torch_sim/optimizers/gradient_descent.py | 6 +-- torch_sim/workflows/a2c.py | 20 ++++---- 19 files changed, 84 insertions(+), 82 deletions(-) diff --git a/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py b/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py index 507367c87..f9475196c 100644 --- a/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py @@ -93,14 +93,14 @@ results = model(state) # Initialize FIRE optimizer -state = ts.fire_init(model=model, state=state, dt_start=0.005) +state = ts.fire_init(state=state, model=model, dt_start=0.005) # Run optimization for N_steps for step in range(N_steps): if step % 100 == 0: print(f"{step=}: Potential energy: {state.energy[0].item()} eV") - state = ts.fire_step(model, state, dt_max=0.01) + state = ts.fire_step(state=state, model=model, dt_max=0.01) # Print max force after optimization print(f"Initial energy: {results['energy'][0].item()} eV") diff --git a/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py b/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py index 61eedb24e..76511f834 100644 --- a/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py @@ -94,14 +94,14 @@ results = model(state) # Initialize FIRE optimizer -state = ts.fire_init(model=model, state=state, dt_start=0.005) +state = ts.fire_init(state=state, model=model, dt_start=0.005) # Run optimization for N_steps for step in range(N_steps): if step % 100 == 0: print(f"{step=}: Total energy: {state.energy[0].item()} eV") - state = ts.fire_step(model, state, dt_max=0.01) + state = ts.fire_step(state=state, model=model, dt_max=0.01) # Print max force after optimization print(f"Initial energy: {results['energy'][0].item()} eV") diff --git a/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py b/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py index 0a69d7a7c..7a9c3d53e 100644 --- a/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py +++ b/examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py @@ -109,7 +109,7 @@ learning_rate = 0.01 # Initialize batched gradient descent optimizer -state = ts.gradient_descent_init(model=batched_model, state=state) +state = ts.gradient_descent_init(state=state, model=batched_model) # Run batched optimization for a few steps print("\nRunning batched gradient descent:") @@ -117,7 +117,7 @@ if step % 10 == 0: print(f"Step {step}, Energy: {[res.item() for res in state.energy]} eV") state = ts.gradient_descent_step( - model=batched_model, state=state, pos_lr=learning_rate + state=state, model=batched_model, pos_lr=learning_rate ) print(f"Initial energies: {[res.item() for res in results['energy']]} eV") diff --git a/examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py b/examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py index 3e326dd64..3f7379553 100644 --- a/examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py @@ -73,7 +73,7 @@ results = model(state) # Initialize unit cell gradient descent optimizer -state = ts.fire_init(model=model, state=state, dt_start=0.005) +state = ts.fire_init(state=state, model=model, dt_start=0.005) # Run optimization for a few steps @@ -82,7 +82,7 @@ if step % 20 == 0: print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}") - state = ts.fire_step(model, state, dt_max=0.01) + state = ts.fire_step(state=state, model=model, dt_max=0.01) print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV") print(f"Final energies: {[energy.item() for energy in state.energy]} eV") diff --git a/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py b/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py index f790ca903..0185fcb8f 100644 --- a/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py +++ b/examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py @@ -78,8 +78,8 @@ state = ts.gradient_descent_init( - model=model, state=state, + model=model, cell_filter=ts.CellFilter.unit, cell_factor=None, # Will default to atoms per system hydrostatic_strain=False, @@ -102,7 +102,7 @@ ) state = ts.gradient_descent_step( - model=model, state=state, pos_lr=pos_lr, cell_lr=cell_lr + state=state, model=model, pos_lr=pos_lr, cell_lr=cell_lr ) print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV") diff --git a/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py index 53a564a76..29e244063 100644 --- a/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py @@ -75,8 +75,8 @@ # Initialize FIRE optimizer with unit cell filter state = ts.fire_init( - model=model, state=state, + model=model, cell_filter=ts.CellFilter.unit, cell_factor=None, # Will default to atoms per system hydrostatic_strain=False, @@ -97,7 +97,7 @@ f"P1={P1:.4f} GPa, P2={P2:.4f} GPa, P3={P3:.4f} GPa" ) - state = ts.fire_step(model=model, state=state) + state = ts.fire_step(state=state, model=model) print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV") print(f"Final energies: {[energy.item() for energy in state.energy]} eV") diff --git a/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py index bc8256560..3a1d881c9 100644 --- a/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py @@ -75,8 +75,8 @@ # Initialize FIRE optimizer with Frechet cell filter state = ts.fire_init( - model=model, state=state, + model=model, cell_filter=ts.CellFilter.frechet, cell_factor=None, # Will default to atoms per system hydrostatic_strain=False, @@ -97,7 +97,7 @@ f"P1={P1:.4f} GPa, P2={P2:.4f} GPa, P3={P3:.4f} GPa" ) - state = ts.fire_step(model=model, state=state) + state = ts.fire_step(state=state, model=model) print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV") print(f"Final energies: {[energy.item() for energy in state.energy]} eV") diff --git a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py index bb1af0167..9715d91aa 100644 --- a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py @@ -42,13 +42,13 @@ si_state = ts.io.atoms_to_state(si_atoms, device=device, dtype=torch.float64) fe_state = ts.io.atoms_to_state(fe_atoms, device=device, dtype=torch.float64) -state = ts.fire_init(model=mace_model, state=si_state, cell_filter=ts.CellFilter.unit) +state = ts.fire_init(state=si_state, model=mace_model, cell_filter=ts.CellFilter.unit) si_fire_state = ts.fire_init( - model=mace_model, state=si_state, cell_filter=ts.CellFilter.unit + state=si_state, model=mace_model, cell_filter=ts.CellFilter.unit ) fe_fire_state = ts.fire_init( - model=mace_model, state=fe_state, cell_filter=ts.CellFilter.unit + state=fe_state, model=mace_model, cell_filter=ts.CellFilter.unit ) fire_states = [si_fire_state, fe_fire_state] * (2 if SMOKE_TEST else 20) @@ -77,7 +77,7 @@ print(f"Total number of completed states {len(all_completed_states)}") for _step in range(10): - state = ts.fire_step(model=mace_model, state=state) + state = ts.fire_step(state=state, model=mace_model) convergence_tensor = converge_max_force(state, last_energy=None) all_completed_states.extend(result[1]) print(f"Total number of completed states {len(all_completed_states)}") diff --git a/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py b/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py index 8bebb4f79..b95cc73eb 100644 --- a/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py +++ b/examples/scripts/5_Workflow/5.2_In_Flight_WBM.py @@ -61,8 +61,8 @@ # Initialize first batch fire_states = ts.fire_init( - model=mace_model, state=ts.io.atoms_to_state(atoms=ase_atoms_list, device=device, dtype=dtype), + model=mace_model, cell_filter=ts.CellFilter.frechet, ) @@ -86,7 +86,7 @@ print(f"Total number of completed states {len(all_completed_states)}") for _step in range(10): - state = ts.fire_step(model=mace_model, state=state) + state = ts.fire_step(state=state, model=mace_model) convergence_tensor = converge_max_force(state, last_energy=None) all_completed_states.extend(result[1]) print(f"Total number of completed states {len(all_completed_states)}") diff --git a/examples/scripts/5_Workflow/5.3_Elastic.py b/examples/scripts/5_Workflow/5.3_Elastic.py index a2b24d5e1..ba9bac775 100644 --- a/examples/scripts/5_Workflow/5.3_Elastic.py +++ b/examples/scripts/5_Workflow/5.3_Elastic.py @@ -42,7 +42,7 @@ # Relax positions and cell state = ts.io.atoms_to_state(atoms=struct, device=device, dtype=dtype) state = ts.fire_init( - model=model, state=state, scalar_pressure=0.0, cell_filter=ts.CellFilter.frechet + state=state, model=model, scalar_pressure=0.0, cell_filter=ts.CellFilter.frechet ) for step in range(300): @@ -55,7 +55,7 @@ ) if current_fmax < fmax and abs(pressure) < 1e-2: break - state = ts.fire_step(model=model, state=state) + state = ts.fire_step(state=state, model=model) # Get bravais type bravais_type = get_bravais_type(state) diff --git a/examples/tutorials/autobatching_tutorial.py b/examples/tutorials/autobatching_tutorial.py index b31ddc50b..225cce571 100644 --- a/examples/tutorials/autobatching_tutorial.py +++ b/examples/tutorials/autobatching_tutorial.py @@ -221,7 +221,7 @@ def process_batch(batch): # %% fire_state = ts.fire_init( - model=mace_model, state=state, cell_filter=ts.CellFilter.frechet + state=state, model=mace_model, cell_filter=ts.CellFilter.frechet ) # Initialize the batcher @@ -252,7 +252,7 @@ def process_batch(batch): # optimize the batch, we stagger the steps to avoid state processing overhead for _ in range(10): - fire_state = ts.fire_step(model=mace_model, state=fire_state) + fire_state = ts.fire_step(state=fire_state, model=mace_model) # Check which states have converged convergence_tensor = convergence_fn(fire_state, None) diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index 92115c62e..3539ade6c 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -139,13 +139,13 @@ """ # %% -state = ts.fire_init(model=model, state=state, cell_filter=ts.CellFilter.unit) +state = ts.fire_init(state=state, model=model, cell_filter=ts.CellFilter.unit) # add a little noise so we have something to relax state.positions = state.positions + torch.randn_like(state.positions) * 0.05 for step in range(20): - state = ts.fire_step(model=model, state=state) + state = ts.fire_step(state=state, model=model) print(f"{step=}: Total energy: {state.energy} eV") @@ -159,11 +159,11 @@ # %% state = ts.fire_init( - model=model, state=state, dt_start=0.02, cell_filter=ts.CellFilter.unit + state=state, model=model, dt_start=0.02, cell_filter=ts.CellFilter.unit ) for step in range(5): - state = ts.fire_step(model=model, state=state, dt_max=0.1) + state = ts.fire_step(state=state, model=model, dt_max=0.1) # %% [markdown] diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 4549a1c12..f0da41c0a 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -456,9 +456,9 @@ def test_in_flight_with_fire( lj_model: LennardJonesModel, num_steps_per_batch: int, ) -> None: - si_fire_state = ts.fire_init(lj_model, si_sim_state, cell_filter=ts.CellFilter.unit) + si_fire_state = ts.fire_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) fe_fire_state = ts.fire_init( - lj_model, fe_supercell_sim_state, cell_filter=ts.CellFilter.unit + fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit ) fire_states = [si_fire_state, fe_fire_state] * 5 @@ -493,7 +493,7 @@ def convergence_fn(state: ts.FireState) -> torch.Tensor: break for _ in range(num_steps_per_batch): - state = ts.fire_step(lj_model, state) + state = ts.fire_step(state=state, model=lj_model) convergence_tensor = convergence_fn(state) assert len(all_completed_states) == len(fire_states) @@ -504,9 +504,9 @@ def test_binning_auto_batcher_with_fire( fe_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, ) -> None: - si_fire_state = ts.fire_init(lj_model, si_sim_state, cell_filter=ts.CellFilter.unit) + si_fire_state = ts.fire_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) fe_fire_state = ts.fire_init( - lj_model, fe_supercell_sim_state, cell_filter=ts.CellFilter.unit + fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit ) fire_states = [si_fire_state, fe_fire_state] * 5 @@ -528,7 +528,7 @@ def test_binning_auto_batcher_with_fire( for batch, _ in batcher: n_systems += 1 for _ in range(5): - batch = ts.fire_step(lj_model, batch) + batch = ts.fire_step(state=batch, model=lj_model) finished_states.extend(batch.split()) diff --git a/tests/test_elastic.py b/tests/test_elastic.py index e060945e4..d729175dc 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -308,8 +308,8 @@ def test_elastic_tensor_symmetries( # Relax positions and cell state = ts.fire_init( - model=model, state=state, + model=model, scalar_pressure=0.0, cell_filter=ts.CellFilter.frechet, ) @@ -322,7 +322,7 @@ def test_elastic_tensor_symmetries( current_fmax = torch.max(torch.abs(state.forces.squeeze())) if current_fmax < fmax and abs(pressure) < 1e-2: break - state = ts.fire_step(model=model, state=state) + state = ts.fire_step(state=state, model=model) # Verify the Bravais type of the relaxed structure actual_bravais_type = get_bravais_type(state) @@ -359,8 +359,8 @@ def test_copper_elastic_properties( # Relax positions and cell state = ts.fire_init( - model=mace_model, state=cu_sim_state, + model=mace_model, scalar_pressure=0.0, cell_filter=ts.CellFilter.frechet, ) @@ -372,7 +372,7 @@ def test_copper_elastic_properties( current_fmax = torch.max(torch.abs(state.forces.squeeze())) if current_fmax < fmax and abs(pressure) < 1e-2: break - state = ts.fire_step(model=mace_model, state=state) + state = ts.fire_step(state=state, model=mace_model) # Calculate elastic tensor bravais_type = get_bravais_type(state) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index a4bce4fcc..1b6cc97e3 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -28,13 +28,13 @@ def test_gradient_descent_optimization( # Initialize Gradient Descent optimizer state = ts.gradient_descent_init( - model=lj_model, state=ar_supercell_sim_state, lr=0.01 + state=ar_supercell_sim_state, model=lj_model, lr=0.01 ) # Run optimization for a few steps energies = [1000, state.energy.item()] while abs(energies[-2] - energies[-1]) > 1e-6: - state = ts.gradient_descent_step(model=lj_model, state=state, pos_lr=0.01) + state = ts.gradient_descent_step(state=state, model=lj_model, pos_lr=0.01) energies.append(state.energy.item()) energies = energies[1:] @@ -67,14 +67,14 @@ def test_unit_cell_gradient_descent_optimization( # Initialize Gradient Descent optimizer with unit cell filter state = ts.gradient_descent_init( - model=lj_model, state=ar_supercell_sim_state, cell_filter=ts.CellFilter.unit + state=ar_supercell_sim_state, model=lj_model, cell_filter=ts.CellFilter.unit ) # Run optimization for a few steps energies = [1000, state.energy.item()] while abs(energies[-2] - energies[-1]) > 1e-6: state = ts.gradient_descent_step( - model=lj_model, state=state, pos_lr=0.01, cell_lr=0.1 + state=state, model=lj_model, pos_lr=0.01, cell_lr=0.1 ) energies.append(state.energy.item()) @@ -123,14 +123,14 @@ def test_fire_optimization( initial_state_positions = current_sim_state.positions.clone() # Initialize FIRE optimizer - state = ts.fire_init(lj_model, current_sim_state, md_flavor=md_flavor, dt_start=0.1) + state = ts.fire_init(current_sim_state, lj_model, md_flavor=md_flavor, dt_start=0.1) # Run optimization for a few steps energies = [1000, state.energy.item()] max_steps = 1000 # Add max step to prevent infinite loop steps_taken = 0 while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: - state = ts.fire_step(lj_model, state, dt_max=0.3) + state = ts.fire_step(state=state, model=lj_model, dt_max=0.3) energies.append(state.energy.item()) steps_taken += 1 @@ -209,8 +209,8 @@ def test_fire_ase_negative_power_branch( dt_start_val = 0.1 state = ts.fire_init( - model=lj_model, state=ar_supercell_sim_state, + model=lj_model, md_flavor="ase_fire", alpha_start=alpha_start, dt_start=dt_start_val, @@ -232,8 +232,8 @@ def test_fire_ase_negative_power_branch( # Deepcopy state as step_fn modifies it in-place state_to_update = copy.deepcopy(state) updated_state = ts.fire_step( - lj_model, - state_to_update, + state=state_to_update, + model=lj_model, f_dec=f_dec, dt_max=1.0, max_step=10.0, # Large max_step to not interfere with velocity check @@ -276,8 +276,8 @@ def test_fire_vv_negative_power_branch( dt_max_val = 2.0 state = ts.fire_init( - model=lj_model, state=ar_supercell_sim_state, + model=lj_model, md_flavor="vv_fire", alpha_start=alpha_start, dt_start=dt_start_val, @@ -288,8 +288,8 @@ def test_fire_vv_negative_power_branch( state_to_update = copy.deepcopy(state) updated_state = ts.fire_step( - lj_model, - state_to_update, + state=state_to_update, + model=lj_model, f_dec=f_dec, dt_max=dt_max_val, n_min=0, # Allow dt to change immediately @@ -350,8 +350,8 @@ def test_unit_cell_fire_optimization( # Initialize FIRE optimizer with unit cell filter state = ts.fire_init( - model=lj_model, state=current_sim_state, + model=lj_model, dt_start=0.1, md_flavor=md_flavor, cell_filter=ts.CellFilter.unit, @@ -363,7 +363,7 @@ def test_unit_cell_fire_optimization( steps_taken = 0 while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: - state = ts.fire_step(lj_model, state, dt_max=0.3) + state = ts.fire_step(state=state, model=lj_model, dt_max=0.3) energies.append(state.energy.item()) steps_taken += 1 @@ -509,8 +509,8 @@ def test_unit_cell_fire_ase_non_positive_volume_warning( perturbed_state.cell[0] *= 2.0 state = ts.fire_init( - model=lj_model, state=perturbed_state, + model=lj_model, md_flavor="ase_fire", dt_start=1.0, alpha_start=0.99, # Aggressive alpha @@ -520,8 +520,8 @@ def test_unit_cell_fire_ase_non_positive_volume_warning( # Run a few steps hoping to trigger the warning for _ in range(5): state = ts.fire_step( - lj_model, - state, + state=state, + model=lj_model, dt_max=5.0, # Large dt max_step=2.0, # Large max_step f_dec=0.99, # Slow down dt decrease @@ -563,8 +563,8 @@ def test_frechet_cell_fire_optimization( initial_state_cell = current_sim_state.cell.clone() state = ts.fire_init( - model=lj_model, state=current_sim_state, + model=lj_model, dt_start=0.1, md_flavor=md_flavor, cell_filter=ts.CellFilter.frechet, @@ -576,7 +576,7 @@ def test_frechet_cell_fire_optimization( steps_taken = 0 while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: - state = ts.fire_step(model=lj_model, state=state, dt_max=0.3) + state = ts.fire_step(state=state, model=lj_model, dt_max=0.3) energies.append(state.energy.item()) steps_taken += 1 @@ -770,7 +770,7 @@ def test_unit_cell_fire_multi_batch( # Initialize FIRE optimizer with unit cell filter state = ts.fire_init( - model=lj_model, state=multi_state, dt_start=0.1, cell_filter=ts.CellFilter.unit + state=multi_state, model=lj_model, dt_start=0.1, cell_filter=ts.CellFilter.unit ) initial_state = copy.deepcopy(state) @@ -780,7 +780,7 @@ def test_unit_cell_fire_multi_batch( step = 0 while not torch.allclose(current_energy, prev_energy, atol=1e-9): prev_energy = current_energy - state = ts.fire_step(model=lj_model, state=state, dt_max=0.3) + state = ts.fire_step(state=state, model=lj_model, dt_max=0.3) current_energy = state.energy step += 1 @@ -841,8 +841,8 @@ def energy_converged(current_energy: torch.Tensor, prev_energy: torch.Tensor) -> for state in (ar_supercell_sim_state_1, ar_supercell_sim_state_2): state_opt = ts.fire_init( - lj_model, state, + lj_model, dt_start=0.1, cell_filter=ts.CellFilter.unit, hydrostatic_strain=True, @@ -856,7 +856,7 @@ def energy_converged(current_energy: torch.Tensor, prev_energy: torch.Tensor) -> step = 0 while energy_converged(current_energy, prev_energy): prev_energy = current_energy - state_opt = ts.fire_step(lj_model, state_opt, dt_max=0.3) + state_opt = ts.fire_step(state=state_opt, model=lj_model, dt_max=0.3) current_energy = state_opt.energy step += 1 if step > 1000: @@ -874,7 +874,7 @@ def energy_converged(current_energy: torch.Tensor, prev_energy: torch.Tensor) -> return not torch.allclose(current_energy, prev_energy, atol=1e-6) for state in (ar_supercell_sim_state_1, ar_supercell_sim_state_2): - state_opt = ts.fire_init(model=lj_model, state=state, dt_start=0.1) + state_opt = ts.fire_init(state=state, model=lj_model, dt_start=0.1) # Run optimization until convergence current_energy = state_opt.energy @@ -883,7 +883,7 @@ def energy_converged(current_energy: torch.Tensor, prev_energy: torch.Tensor) -> step = 0 while energy_converged(current_energy, prev_energy): prev_energy = current_energy - state_opt = ts.fire_step(model=lj_model, state=state_opt, dt_max=0.3) + state_opt = ts.fire_step(state=state_opt, model=lj_model, dt_max=0.3) current_energy = state_opt.energy step += 1 if step > 1000: diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py index 653fc5a81..3ff0cf2de 100644 --- a/torch_sim/optimizers/cell_filters.py +++ b/torch_sim/optimizers/cell_filters.py @@ -88,8 +88,8 @@ def deform_grad(reference_cell: torch.Tensor, current_cell: torch.Tensor) -> tor def unit_cell_filter_init[T: AnyCellState]( - model: ModelInterface, state: T, + model: ModelInterface, *, cell_factor: float | torch.Tensor | None = None, hydrostatic_strain: bool = False, @@ -137,8 +137,8 @@ def unit_cell_filter_init[T: AnyCellState]( def frechet_cell_filter_init[T: AnyCellState]( - model: ModelInterface, state: T, + model: ModelInterface, *, cell_factor: float | torch.Tensor | None = None, hydrostatic_strain: bool = False, diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 147324c36..e3feb3a77 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -22,8 +22,8 @@ def fire_init( - model: "ModelInterface", state: SimState | StateDict, + model: "ModelInterface", *, dt_start: float = 0.1, alpha_start: float = 0.1, @@ -99,7 +99,7 @@ def fire_init( cell_state = CellFireState(**common_args) # Initialize cell-specific attributes - init_fn(model, cell_state, **filter_kwargs) + init_fn(cell_state, model, **filter_kwargs) # Initialize cell velocities after cell_forces is set cell_state.cell_velocities = torch.full( @@ -112,8 +112,8 @@ def fire_init( def fire_step( - model: "ModelInterface", state: "FireState | CellFireState", + model: "ModelInterface", *, dt_max: float = 1.0, n_min: int = 5, @@ -405,7 +405,9 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 dr_scaling_atom = torch.sqrt(dr_scaling_system)[state.system_idx].unsqueeze(-1) dr_atom = torch.where( - dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom + dr_scaling_atom > max_step, + max_step * dr_atom / (dr_scaling_atom + eps), + dr_atom, ) # Position updates diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index 252cc88ac..bfdfcf3f9 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -17,8 +17,8 @@ def gradient_descent_init( - model: "ModelInterface", state: SimState | StateDict, + model: "ModelInterface", *, cell_filter: "CellFilter | CellFilterFuncs | None" = None, **filter_kwargs: Any, @@ -70,7 +70,7 @@ def gradient_descent_init( cell_state = CellOptimState(**common_args) # Initialize cell-specific attributes - init_fn(model, cell_state, **filter_kwargs) + init_fn(cell_state, model, **filter_kwargs) return cell_state # Create regular OptimState without cell optimization @@ -78,8 +78,8 @@ def gradient_descent_init( def gradient_descent_step( - model: "ModelInterface", state: "OptimState | CellOptimState", + model: "ModelInterface", *, pos_lr: float | torch.Tensor = 0.01, cell_lr: float | torch.Tensor = 0.1, diff --git a/torch_sim/workflows/a2c.py b/torch_sim/workflows/a2c.py index 2d5e2c618..801d94adb 100644 --- a/torch_sim/workflows/a2c.py +++ b/torch_sim/workflows/a2c.py @@ -307,7 +307,7 @@ def random_packed_structure( cell=cell, pbc=True, ) - state = ts.fire_init(model, state) + state = ts.fire_init(state, model) print(f"Initial energy: {state.energy.item():.4f}") # Run FIRE optimization until convergence or max iterations for _step in range(max_iter): @@ -317,7 +317,7 @@ def random_packed_structure( log.append(state.positions.cpu().numpy()) - state = ts.fire_step(model, state) + state = ts.fire_step(state, model) print(f"Final energy: {state.energy.item():.4f}") @@ -432,7 +432,7 @@ def random_packed_structure_multi( pbc=True, ) # Set up FIRE optimizer with unit masses for all atoms - state = ts.fire_init(model, state_dict) + state = ts.fire_init(state_dict, model) print(f"Initial energy: {state.energy.item():.4f}") # Run FIRE optimization until convergence or max iterations for _step in range(max_iter): @@ -440,7 +440,7 @@ def random_packed_structure_multi( min_dist = min_distance(state.positions, cell, distance_tolerance) if min_dist > diameter_matrix.min() * 0.95: break - state = ts.fire_step(model, state) + state = ts.fire_step(state, model) print(f"Final energy: {state.energy.item():.4f}") return state @@ -741,14 +741,14 @@ def get_unit_cell_relaxed_structure( f"Initial pressure: {[f'{p:.4f}' for p in init_pressure]} eV/A^3" ) - state = ts.fire_init(model=model, state=state, cell_filter=ts.CellFilter.unit) + state = ts.fire_init(state=state, model=model, cell_filter=ts.CellFilter.unit) def step_fn( step: int, state: ts.FireState, logger: dict[str, torch.Tensor] ) -> tuple[ts.FireState, dict[str, torch.Tensor]]: logger["energy"][step] = state.energy logger["stress"][step] = state.stress - state = ts.fire_step(model=model, state=state) + state = ts.fire_step(state=state, model=model) return state, logger for step in range(max_iter): @@ -814,14 +814,14 @@ def get_frechet_cell_relaxed_structure( f"Initial pressure: {[f'{p:.4f}' for p in init_pressure]} eV/A^3" ) - state = ts.fire_init(model=model, state=state, cell_filter=ts.CellFilter.frechet) + state = ts.fire_init(state=state, model=model, cell_filter=ts.CellFilter.frechet) def step_fn( step: int, state: ts.FireState, logger: dict[str, torch.Tensor] ) -> tuple[ts.FireState, dict]: logger["energy"][step] = state.energy logger["stress"][step] = state.stress - state = ts.fire_step(model=model, state=state) + state = ts.fire_step(state=state, model=model) return state, logger for step in range(max_iter): @@ -875,13 +875,13 @@ def get_relaxed_structure( if verbose: print(f"Initial energy: {[f'{e:.4f}' for e in init_energy]} eV") - state = ts.fire_init(model=model, state=state) + state = ts.fire_init(state=state, model=model) def step_fn( idx: int, state: FireState, logger: dict[str, torch.Tensor] ) -> tuple[FireState, dict[str, torch.Tensor]]: logger["energy"][idx] = state.energy - state = ts.fire_step(model=model, state=state) + state = ts.fire_step(state=state, model=model) return state, logger for idx in range(max_iter): From a26a1bba38e7f94ec2ed5f7c9572bbd65d12e237 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 21:27:38 -0400 Subject: [PATCH 38/40] standardize state and model order in integrate function in runners --- torch_sim/runners.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index ded757f6a..e1b798856 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -191,7 +191,7 @@ def integrate[T: SimState]( # noqa: C901 # Handle both BinningAutoBatcher and list of tuples for state, system_indices in batch_iterator: # Pass correct parameters based on integrator type - state = init_func(model=model, state=state, kT=kTs[0], dt=dt, **integrator_kwargs) + state = init_func(state=state, model=model, kT=kTs[0], dt=dt, **integrator_kwargs) # set up trajectory reporters if autobatcher and trajectory_reporter is not None and og_filenames is not None: @@ -202,7 +202,9 @@ def integrate[T: SimState]( # noqa: C901 # run the simulation for step in range(1, n_steps + 1): - state = step_func(model, state, dt=dt, kT=kTs[step - 1], **integrator_kwargs) + state = step_func( + state=state, model=model, dt=dt, kT=kTs[step - 1], **integrator_kwargs + ) if trajectory_reporter: trajectory_reporter.report(state, step, model=model) @@ -490,7 +492,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 if hasattr(state, "energy"): last_energy = state.energy - state = step_fn(model=model, state=state, **optimizer_kwargs) + state = step_fn(state=state, model=model, **optimizer_kwargs) if trajectory_reporter: trajectory_reporter.report(state, step, model=model) From 077218b58df65942af0c68b5ce893da5b8ef8f70 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 21:34:29 -0400 Subject: [PATCH 39/40] remove top level use of cell_filter in all optimize api calls --- README.md | 2 +- .../scripts/4_High_level_api/4.1_high_level_api.py | 4 ++-- examples/scripts/6_Phonons/6.1_Phonons_MACE.py | 5 +++-- .../scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py | 14 ++++++++++---- .../scripts/6_Phonons/6.3_Conductivity_MACE.py | 5 +++-- examples/tutorials/high_level_tutorial.py | 4 ++-- 6 files changed, 21 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index f5c52c54c..a6f8a0109 100644 --- a/README.md +++ b/README.md @@ -82,8 +82,8 @@ relaxed_state = ts.optimize( system=final_state, model=mace_model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.frechet, autobatcher=True, + init_kwargs=dict(cell_filter=ts.CellFilter.frechet), ) print(relaxed_state.energy) diff --git a/examples/scripts/4_High_level_api/4.1_high_level_api.py b/examples/scripts/4_High_level_api/4.1_high_level_api.py index 5568b3bc4..4caf068e3 100644 --- a/examples/scripts/4_High_level_api/4.1_high_level_api.py +++ b/examples/scripts/4_High_level_api/4.1_high_level_api.py @@ -160,8 +160,8 @@ system=systems, model=mace_model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.unit, max_steps=10 if SMOKE_TEST else 1000, + init_kwargs=dict(cell_filter=ts.CellFilter.unit), ) rng = np.random.default_rng() @@ -172,10 +172,10 @@ system=systems, model=mace_model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.unit, convergence_fn=lambda state, last_energy: last_energy - state.energy < 1e-6 * MetalUnits.energy, max_steps=10 if SMOKE_TEST else 1000, + init_kwargs=dict(cell_filter=ts.CellFilter.unit), ) diff --git a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py index 1508b0681..57cbe2f07 100644 --- a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py +++ b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py @@ -117,9 +117,10 @@ def get_labels_qpts(ph: Phonopy, n_points: int = 101) -> tuple[list[str], list[b system=struct, model=model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.frechet, max_steps=max_steps, - init_kwargs=dict(constant_volume=True, hydrostatic_strain=True), + init_kwargs=dict( + cell_filter=ts.CellFilter.frechet, constant_volume=True, hydrostatic_strain=True + ), ) # Define atoms and Phonopy object diff --git a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py index d5bfadca2..1d52d7749 100644 --- a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py +++ b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py @@ -63,12 +63,15 @@ def get_relaxed_structure( system=struct, model=model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.frechet, max_steps=max_steps, convergence_fn=converge_max_force, trajectory_reporter=reporter, autobatcher=use_autobatcher, - init_kwargs=dict(constant_volume=True, hydrostatic_strain=True), + init_kwargs=dict( + cell_filter=ts.CellFilter.frechet, + constant_volume=True, + hydrostatic_strain=True, + ), ) os.remove(trajectory_file) @@ -116,11 +119,14 @@ def get_qha_structures( system=scaled_structs, model=model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.frechet, max_steps=Nmax, convergence_fn=ts.runners.generate_force_convergence_fn(force_tol=fmax), autobatcher=use_autobatcher, - init_kwargs=dict(constant_volume=True, hydrostatic_strain=True), + init_kwargs=dict( + cell_filter=ts.CellFilter.frechet, + constant_volume=True, + hydrostatic_strain=True, + ), ) return scaled_state.to_phonopy() diff --git a/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py b/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py index 0fc842169..4bd3a705c 100644 --- a/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py +++ b/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py @@ -102,11 +102,12 @@ def print_relax_info(trajectory_file: str, device: torch.device) -> None: system=struct, model=model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.frechet, max_steps=max_steps, convergence_fn=converge_max_force, trajectory_reporter=reporter, - init_kwargs=dict(constant_volume=True, hydrostatic_strain=True), + init_kwargs=dict( + cell_filter=ts.CellFilter.frechet, constant_volume=True, hydrostatic_strain=True + ), ) print_relax_info(trajectory_file, device) diff --git a/examples/tutorials/high_level_tutorial.py b/examples/tutorials/high_level_tutorial.py index 834cac7f1..9a265e61e 100644 --- a/examples/tutorials/high_level_tutorial.py +++ b/examples/tutorials/high_level_tutorial.py @@ -357,7 +357,7 @@ system=systems, model=mace_model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.unit, + init_kwargs=dict(cell_filter=ts.CellFilter.unit), ) final_atoms = final_state.to_atoms() @@ -405,8 +405,8 @@ def default_energy_convergence(state, last_energy): system=systems, model=mace_model, optimizer=ts.OptimFlavor.fire, - cell_filter=ts.CellFilter.unit, convergence_fn=force_convergence_fn, # Custom convergence function + init_kwargs=dict(cell_filter=ts.CellFilter.unit), ) final_atoms = final_state.to_atoms() From 2954183005ae474fedef2cfbf57e605a1e70bd7d Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 8 Oct 2025 21:41:18 -0400 Subject: [PATCH 40/40] fix cell filter comparison --- tests/test_optimizers_vs_ase.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py index f8e83b386..b4eb75977 100644 --- a/tests/test_optimizers_vs_ase.py +++ b/tests/test_optimizers_vs_ase.py @@ -112,6 +112,7 @@ def _run_and_compare_optimizers( ts_mace_mpa: MaceModel, ase_mace_mpa: "MACECalculator", fire_type: ts.OptimFlavor, + cell_filter: ts.CellFilter, ase_filter_cls: FrechetCellFilter | UnitCellFilter, checkpoints: list[int], force_tol: float, @@ -162,6 +163,7 @@ def _run_and_compare_optimizers( convergence_fn=convergence_fn, steps_between_swaps=1, md_flavor="ase_fire", # optimizer kwargs + init_kwargs=dict(cell_filter=cell_filter), **optim_kwargs, ) state = updated_ts_state.clone()