From d525009e2ceace7306c2b205b9be8e834500da45 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 30 Oct 2025 10:22:47 +0100 Subject: [PATCH 01/19] Add CSVR thermostat --- tests/test_integrators.py | 52 +++++++++ torch_sim/__init__.py | 2 + torch_sim/integrators/__init__.py | 22 +++- torch_sim/integrators/nvt.py | 181 ++++++++++++++++++++++++++++++ 4 files changed, 251 insertions(+), 6 deletions(-) diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 2e53cb302..2edb1ce5a 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -477,6 +477,58 @@ def test_nvt_nose_hoover_multi_kt( assert invariant_std / invariant_traj.mean() < 0.1 +def test_nvt_csvr(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): + n_steps = 100 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature + + # Initialize integrator + state = ts.nvt_csvr_init(state=ar_double_sim_state, model=lj_model, kT=kT, seed=42) + energies = [] + temperatures = [] + for _step in range(n_steps): + state = ts.nvt_csvr_step(model=lj_model, state=state, dt=dt, kT=kT) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + for mean_temp in mean_temps: + assert ( + abs(mean_temp - kT.item() / MetalUnits.temperature) < 100.0 + ) # Allow for thermal fluctuations + + # Check energy is stable for each trajectory + for traj in energies_list: + energy_std = torch.tensor(traj).std() + assert energy_std < 1.0 # Adjust threshold as needed + + # Check positions and momenta have correct shapes + n_atoms = 8 + + # Verify the two systems remain distinct + pos_diff = torch.norm( + state.positions[:n_atoms].mean(0) - state.positions[n_atoms:].mean(0) + ) + assert pos_diff > 0.0001 # Systems should remain separated + + def test_npt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): dtype = torch.float64 n_steps = 100 diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index a5a8af765..f99e7ab25 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -29,6 +29,8 @@ NVTNoseHooverState, nve_init, nve_step, + nvt_csvr_init, + nvt_csvr_step, nvt_langevin_init, nvt_langevin_step, nvt_nose_hoover_init, diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index 8fc88f47a..098833d0a 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -8,23 +8,27 @@ NVE: - Velocity Verlet integrator for constant energy simulations :func:`nve.nve_step` NVT: + - Canonical Sampling velocity rescaling (CSVR) thermostat integrator + :func:`nvt.nvt_csvr_step` [1] - Langevin thermostat integrator :func:`nvt.nvt_langevin_step` - using BAOAB scheme [1] - - Nosé-Hoover thermostat integrator :func:`nvt.nvt_nose_hoover_step` from [2] + using BAOAB scheme [2] + - Nosé-Hoover thermostat integrator :func:`nvt.nvt_nose_hoover_step` from [3] NPT: - Langevin barostat integrator :func:`npt.npt_langevin_step` [3, 4] - Nosé-Hoover barostat integrator :func:`npt.npt_nose_hoover_step` from [2] References: - [1] Leimkuhler B, Matthews C.2016 Efficient molecular dynamics using geodesic + [1] Bussi G, Donadio D, Parrinello M. "Canonical sampling through velocity rescaling." + The Journal of chemical physics, 126(1), 014101 (2007). + [2] Leimkuhler B, Matthews C.2016 Efficient molecular dynamics using geodesic integration and solvent-solute splitting. Proc. R. Soc. A 472: 20160138 - [2] Martyna, G. J., Tuckerman, M. E., Tobias, D. J., & Klein, M. L. (1996). + [3] Martyna, G. J., Tuckerman, M. E., Tobias, D. J., & Klein, M. L. (1996). Explicit reversible integrators for extended systems dynamics. Molecular Physics, 87(5), 1117-1157. - [3] Grønbech-Jensen, N., & Farago, O. (2014). + [4] Grønbech-Jensen, N., & Farago, O. (2014). Constant pressure and temperature discrete-time Langevin molecular dynamics. The Journal of chemical physics, 141(19). - [4] LAMMPS: https://docs.lammps.org/fix_press_langevin.html + [5] LAMMPS: https://docs.lammps.org/fix_press_langevin.html Examples: @@ -62,6 +66,8 @@ from .nve import nve_init, nve_step from .nvt import ( NVTNoseHooverState, + nvt_csvr_init, + nvt_csvr_step, nvt_langevin_init, nvt_langevin_step, nvt_nose_hoover_init, @@ -79,6 +85,7 @@ class Integrator(StrEnum): Available options: - ``nve``: Constant energy (microcanonical) ensemble. + - ``nvt_csvr``: CSVR thermostat for constant temperature. - ``nvt_langevin``: Langevin thermostat for constant temperature. - ``nvt_nose_hoover``: Nosé-Hoover thermostat for constant temperature. - ``npt_langevin``: Langevin barostat for constant temperature and pressure. @@ -93,6 +100,7 @@ class Integrator(StrEnum): """ nve = "nve" + nvt_csvr = "nvt_csvr" nvt_langevin = "nvt_langevin" nvt_nose_hoover = "nvt_nose_hoover" npt_langevin = "npt_langevin" @@ -116,6 +124,7 @@ class Integrator(StrEnum): #: The available integrators are: #: #: - ``Integrator.nve``: Velocity Verlet (microcanonical) +#: - ``Integrator.nvt_csvr``: Canonical Sampling velocity rescaling (CSVR) thermostat #: - ``Integrator.nvt_langevin``: Langevin thermostat #: - ``Integrator.nvt_nose_hoover``: Nosé-Hoover thermostat #: - ``Integrator.npt_langevin``: Langevin barostat @@ -126,6 +135,7 @@ class Integrator(StrEnum): dict[Integrator, tuple[Callable[..., Any], Callable[..., Any]]] ] = { Integrator.nve: (nve_init, nve_step), + Integrator.nvt_csvr: (nvt_csvr_init, nvt_csvr_step), Integrator.nvt_langevin: (nvt_langevin_init, nvt_langevin_step), Integrator.nvt_nose_hoover: (nvt_nose_hoover_init, nvt_nose_hoover_step), Integrator.npt_langevin: (npt_langevin_init, npt_langevin_step), diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 478abfa9f..30d06de83 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -458,3 +458,184 @@ def nvt_nose_hoover_invariant( e_tot = e_tot + chain_ke + chain_pe return e_tot + + +def _csvr_update( + state: MDState, + tau: float | torch.Tensor, + kT: float | torch.Tensor, + dt: float | torch.Tensor, +) -> MDState: + """Update the momentum by a scaling factor as described by Eq.A7 Bussi et al. + + Note that we don't implement the optimize code from Bussi, which won't be useful + on a high level framework like PyTorch. + + Args: + state: Current MD state + tau: Thermostat relaxation time + kT: Target temperature + dt: Integration timestep + + Returns: + Updated state with rescaled momenta + """ + device, dtype = state.device, state.dtype + + # Convert all inputs to tensors + tau_tensor = torch.as_tensor(tau, device=device, dtype=dtype) + kT_tensor = torch.as_tensor(kT, device=device, dtype=dtype) + dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype) + + # Calculate current temperature per system + current_kT = ts.quantities.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + + # Calculate degrees of freedom per system + n_atoms_per_system = torch.bincount(state.system_idx) + dof = n_atoms_per_system * state.positions.shape[-1] + + # Ensure kT and tau have proper batch dimensions + n_systems = current_kT.shape[0] + if kT_tensor.dim() == 0: + kT_tensor = kT_tensor.expand(n_systems) + if tau_tensor.dim() == 0: + tau_tensor = tau_tensor.expand(n_systems) + + # Calculate kinetic energies + KE_old = dof * current_kT / 2 + KE_new = dof * kT_tensor / 2 + + # Generate random numbers + r1 = torch.randn(n_systems, device=device, dtype=dtype) + # Sample Gamma((dof - 1)/2, 1/2) = \sum_2^{dof} X_i^2 where X_i ~ N(0,1) + r2 = torch.distributions.Gamma((dof - 1) / 2, torch.ones_like(dof) / 2).sample() + + # Calculate scaling coefficients + c1 = torch.exp(-dt_tensor / tau_tensor) + c2 = (1 - c1) * KE_new / KE_old / dof + + # Calculate scaling factor + scale = c1 + (c2 * (torch.square(r1) + r2)) + (2 * r1 * torch.sqrt(c1 * c2)) + lam = torch.sqrt(scale) + + # Apply scaling to momenta - map from system to atom indices + state.momenta = state.momenta * lam[state.system_idx].unsqueeze(-1) + return state + + +def nvt_csvr_init( + state: SimState | StateDict, + model: ModelInterface, + *, + kT: float | torch.Tensor, + seed: int | None = None, + **_kwargs: Any, +) -> MDState: + """Initialize an NVT state from input data for CSVR dynamics. + + Creates an initial state for NVT molecular dynamics using the canonical + sampling through velocity rescaling (CSVR) thermostat. This thermostat + samples from the canonical ensemble by rescaling velocities with an + appropriately chosen random factor. + + Args: + model: Neural network model that computes energies and forces. + Must return a dict with 'energy' and 'forces' keys. + state: Either a SimState object or a dictionary containing positions, + masses, cell, pbc, and other required state vars + kT: Temperature in energy units for initializing momenta, + either scalar or with shape [n_systems] + seed: Random seed for reproducibility + + Returns: + MDState: Initialized state for NVT integration containing positions, + momenta, forces, energy, and other required attributes + + Notes: + The initial momenta are sampled from a Maxwell-Boltzmann distribution + at the specified temperature. The CSVR thermostat provides proper + canonical sampling through stochastic velocity rescaling. + """ + if not isinstance(state, SimState): + state = SimState(**state) + + model_output = model(state) + + momenta = getattr( + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + ) + + return MDState( + positions=state.positions, + momenta=momenta, + energy=model_output["energy"], + forces=model_output["forces"], + masses=state.masses, + cell=state.cell, + pbc=state.pbc, + system_idx=state.system_idx, + atomic_numbers=state.atomic_numbers, + ) + + +def nvt_csvr_step( + model: ModelInterface, + state: MDState, + *, + dt: float | torch.Tensor, + kT: float | torch.Tensor, + tau: float | torch.Tensor | None = None, +) -> MDState: + """Perform one complete CSVR dynamics integration step. + + This function implements the canonical sampling through velocity rescaling (CSVR) + thermostat combined with velocity Verlet integration. The CSVR thermostat samples + the canonical distribution by rescaling velocities with a properly chosen random + factor that ensures correct canonical sampling. + + Args: + model: Neural network model that computes energies and forces. + Must return a dict with 'energy' and 'forces' keys. + state: Current system state containing positions, momenta, forces + dt: Integration timestep, either scalar or shape [n_systems] + kT: Target temperature in energy units, either scalar or + with shape [n_systems] + tau: Thermostat relaxation time controlling the coupling strength, + either scalar or with shape [n_systems]. Defaults to 100*dt. + seed: Random seed for reproducibility + + Returns: + MDState: Updated state after one complete CSVR step with new positions, + momenta, forces, and energy + + Notes: + - Uses CSVR thermostat for proper canonical ensemble sampling + - Unlike Berendsen thermostat, CSVR samples the true canonical distribution + - Integration sequence: CSVR rescaling + Velocity Verlet step + - The rescaling factor follows the distribution derived in Bussi et al. + + References: + Bussi G, Donadio D, Parrinello M. "Canonical sampling through velocity rescaling." + The Journal of chemical physics, 126(1), 014101 (2007). + """ + device, dtype = model.device, model.dtype + + if tau is None: + tau = 100 * dt + + if isinstance(tau, float): + tau = torch.tensor(tau, device=device, dtype=dtype) + if isinstance(dt, float): + dt = torch.tensor(dt, device=device, dtype=dtype) + if isinstance(kT, float): + kT = torch.tensor(kT, device=device, dtype=dtype) + + # Apply CSVR rescaling + state = _csvr_update(state, tau, kT, dt) + + # Perform velocity Verlet step + return velocity_verlet(state=state, dt=dt, model=model) From 4ddb14cac748c777835202cb4b79314fe485687c Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 30 Oct 2025 10:51:06 +0100 Subject: [PATCH 02/19] rename to vrescale --- tests/test_integrators.py | 6 +++--- torch_sim/__init__.py | 4 ++-- torch_sim/integrators/__init__.py | 16 ++++++++-------- torch_sim/integrators/nvt.py | 28 ++++++++++++++-------------- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 2edb1ce5a..1f838d20e 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -477,17 +477,17 @@ def test_nvt_nose_hoover_multi_kt( assert invariant_std / invariant_traj.mean() < 0.1 -def test_nvt_csvr(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): +def test_nvt_vrescale(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): n_steps = 100 dt = torch.tensor(0.001, dtype=DTYPE) kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature # Initialize integrator - state = ts.nvt_csvr_init(state=ar_double_sim_state, model=lj_model, kT=kT, seed=42) + state = ts.nvt_vrescale_init(state=ar_double_sim_state, model=lj_model, kT=kT, seed=42) energies = [] temperatures = [] for _step in range(n_steps): - state = ts.nvt_csvr_step(model=lj_model, state=state, dt=dt, kT=kT) + state = ts.nvt_vrescale_step(model=lj_model, state=state, dt=dt, kT=kT) # Calculate instantaneous temperature from kinetic energy temp = ts.calc_kT( diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index f99e7ab25..f35f58eae 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -29,8 +29,8 @@ NVTNoseHooverState, nve_init, nve_step, - nvt_csvr_init, - nvt_csvr_step, + nvt_vrescale_init, + nvt_vrescale_step, nvt_langevin_init, nvt_langevin_step, nvt_nose_hoover_init, diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index 098833d0a..cfd2bbd70 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -8,8 +8,8 @@ NVE: - Velocity Verlet integrator for constant energy simulations :func:`nve.nve_step` NVT: - - Canonical Sampling velocity rescaling (CSVR) thermostat integrator - :func:`nvt.nvt_csvr_step` [1] + - Velocity Rescaling thermostat integrator + :func:`nvt.nvt_vrescale_step` [1] - Langevin thermostat integrator :func:`nvt.nvt_langevin_step` using BAOAB scheme [2] - Nosé-Hoover thermostat integrator :func:`nvt.nvt_nose_hoover_step` from [3] @@ -66,8 +66,8 @@ from .nve import nve_init, nve_step from .nvt import ( NVTNoseHooverState, - nvt_csvr_init, - nvt_csvr_step, + nvt_vrescale_init, + nvt_vrescale_step, nvt_langevin_init, nvt_langevin_step, nvt_nose_hoover_init, @@ -85,7 +85,7 @@ class Integrator(StrEnum): Available options: - ``nve``: Constant energy (microcanonical) ensemble. - - ``nvt_csvr``: CSVR thermostat for constant temperature. + - ``nvt_vrescale``: Velocity rescaling thermostat for constant temperature. - ``nvt_langevin``: Langevin thermostat for constant temperature. - ``nvt_nose_hoover``: Nosé-Hoover thermostat for constant temperature. - ``npt_langevin``: Langevin barostat for constant temperature and pressure. @@ -100,7 +100,7 @@ class Integrator(StrEnum): """ nve = "nve" - nvt_csvr = "nvt_csvr" + nvt_vrescale = "nvt_vrescale" nvt_langevin = "nvt_langevin" nvt_nose_hoover = "nvt_nose_hoover" npt_langevin = "npt_langevin" @@ -124,7 +124,7 @@ class Integrator(StrEnum): #: The available integrators are: #: #: - ``Integrator.nve``: Velocity Verlet (microcanonical) -#: - ``Integrator.nvt_csvr``: Canonical Sampling velocity rescaling (CSVR) thermostat +#: - ``Integrator.nvt_vrescale``: V-Rescale thermostat #: - ``Integrator.nvt_langevin``: Langevin thermostat #: - ``Integrator.nvt_nose_hoover``: Nosé-Hoover thermostat #: - ``Integrator.npt_langevin``: Langevin barostat @@ -135,7 +135,7 @@ class Integrator(StrEnum): dict[Integrator, tuple[Callable[..., Any], Callable[..., Any]]] ] = { Integrator.nve: (nve_init, nve_step), - Integrator.nvt_csvr: (nvt_csvr_init, nvt_csvr_step), + Integrator.nvt_vrescale: (nvt_vrescale_init, nvt_vrescale_step), Integrator.nvt_langevin: (nvt_langevin_init, nvt_langevin_step), Integrator.nvt_nose_hoover: (nvt_nose_hoover_init, nvt_nose_hoover_step), Integrator.npt_langevin: (npt_langevin_init, npt_langevin_step), diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 30d06de83..d38e33d7a 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -460,7 +460,7 @@ def nvt_nose_hoover_invariant( return e_tot -def _csvr_update( +def _vrescale_update( state: MDState, tau: float | torch.Tensor, kT: float | torch.Tensor, @@ -525,7 +525,7 @@ def _csvr_update( return state -def nvt_csvr_init( +def nvt_vrescale_init( state: SimState | StateDict, model: ModelInterface, *, @@ -533,7 +533,7 @@ def nvt_csvr_init( seed: int | None = None, **_kwargs: Any, ) -> MDState: - """Initialize an NVT state from input data for CSVR dynamics. + """Initialize an NVT state from input data for velocity rescaling dynamics. Creates an initial state for NVT molecular dynamics using the canonical sampling through velocity rescaling (CSVR) thermostat. This thermostat @@ -555,7 +555,7 @@ def nvt_csvr_init( Notes: The initial momenta are sampled from a Maxwell-Boltzmann distribution - at the specified temperature. The CSVR thermostat provides proper + at the specified temperature. The V-Rescale thermostat provides proper canonical sampling through stochastic velocity rescaling. """ if not isinstance(state, SimState): @@ -582,7 +582,7 @@ def nvt_csvr_init( ) -def nvt_csvr_step( +def nvt_vrescale_step( model: ModelInterface, state: MDState, *, @@ -590,10 +590,10 @@ def nvt_csvr_step( kT: float | torch.Tensor, tau: float | torch.Tensor | None = None, ) -> MDState: - """Perform one complete CSVR dynamics integration step. + """Perform one complete V-Rescale dynamics integration step. - This function implements the canonical sampling through velocity rescaling (CSVR) - thermostat combined with velocity Verlet integration. The CSVR thermostat samples + This function implements the canonical sampling through velocity rescaling (V-Rescale) + thermostat combined with velocity Verlet integration. The V-Rescale thermostat samples the canonical distribution by rescaling velocities with a properly chosen random factor that ensures correct canonical sampling. @@ -609,13 +609,13 @@ def nvt_csvr_step( seed: Random seed for reproducibility Returns: - MDState: Updated state after one complete CSVR step with new positions, + MDState: Updated state after one complete V-Rescale step with new positions, momenta, forces, and energy Notes: - - Uses CSVR thermostat for proper canonical ensemble sampling - - Unlike Berendsen thermostat, CSVR samples the true canonical distribution - - Integration sequence: CSVR rescaling + Velocity Verlet step + - Uses V-Rescale thermostat for proper canonical ensemble sampling + - Unlike Berendsen thermostat, V-Rescale samples the true canonical distribution + - Integration sequence: V-Rescale rescaling + Velocity Verlet step - The rescaling factor follows the distribution derived in Bussi et al. References: @@ -634,8 +634,8 @@ def nvt_csvr_step( if isinstance(kT, float): kT = torch.tensor(kT, device=device, dtype=dtype) - # Apply CSVR rescaling - state = _csvr_update(state, tau, kT, dt) + # Apply V-Rescale rescaling + state = _vrescale_update(state, tau, kT, dt) # Perform velocity Verlet step return velocity_verlet(state=state, dt=dt, model=model) From e5bc24e13edd40412a9234b478f5dd690e0f7bf1 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 30 Oct 2025 11:32:18 +0100 Subject: [PATCH 03/19] remove redundunt variable with inheritance from MDState --- torch_sim/integrators/npt.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 187405f91..685c5a196 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -55,8 +55,6 @@ class NPTLangevinState(MDState): """ # System state variables - energy: torch.Tensor - forces: torch.Tensor stress: torch.Tensor # Cell variables From 5adf95ed51db49d06de5af287c1a7e2467d076bf Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 30 Oct 2025 17:58:48 +0100 Subject: [PATCH 04/19] Add Cell rescaling NPT --- tests/test_integrators.py | 77 ++++++- torch_sim/__init__.py | 6 +- torch_sim/integrators/__init__.py | 22 +- torch_sim/integrators/npt.py | 329 ++++++++++++++++++++++++++++++ torch_sim/integrators/nvt.py | 1 + 5 files changed, 428 insertions(+), 7 deletions(-) diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 1f838d20e..03fedc6cf 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -483,7 +483,9 @@ def test_nvt_vrescale(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature # Initialize integrator - state = ts.nvt_vrescale_init(state=ar_double_sim_state, model=lj_model, kT=kT, seed=42) + state = ts.nvt_vrescale_init( + state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 + ) energies = [] temperatures = [] for _step in range(n_steps): @@ -529,6 +531,79 @@ def test_nvt_vrescale(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo assert pos_diff > 0.0001 # Systems should remain separated +def test_npt_crescale( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: + n_steps = 200 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(100.0, dtype=DTYPE) * MetalUnits.temperature + external_pressure = torch.tensor(0.0, dtype=DTYPE) * MetalUnits.pressure + tau_p = torch.tensor(0.1, dtype=DTYPE) + isothermal_compressibility = torch.tensor(1e-4, dtype=DTYPE) + + # Initialize integrator using new direct API + state = ts.npt_crescale_init( + state=ar_double_sim_state, + model=lj_model, + dt=dt, + kT=kT, + tau_p=tau_p, + isothermal_compressibility=isothermal_compressibility, + seed=42, + ) + + # Run dynamics for several steps + energies = [] + temperatures = [] + for _step in range(n_steps): + state = ts.npt_crescale_step( + state=state, + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + ) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + for mean_temp in mean_temps: + assert ( + abs(mean_temp - kT.item() / MetalUnits.temperature) < 150.0 + ) # Allow for thermal fluctuations + + # Check energy is stable for each trajectory + for traj in energies_list: + energy_std = torch.tensor(traj).std() + assert energy_std < 1.0 # Adjust threshold as needed + + # Check positions and momenta have correct shapes + n_atoms = 8 + + # Verify the two systems remain distinct + pos_diff = torch.norm( + state.positions[:n_atoms].mean(0) - state.positions[n_atoms:].mean(0) + ) + assert pos_diff > 0.0001 # Systems should remain separated + + def test_npt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): dtype = torch.float64 n_steps = 100 diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index f35f58eae..ccf0fb32a 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -29,17 +29,19 @@ NVTNoseHooverState, nve_init, nve_step, - nvt_vrescale_init, - nvt_vrescale_step, nvt_langevin_init, nvt_langevin_step, nvt_nose_hoover_init, nvt_nose_hoover_invariant, nvt_nose_hoover_step, + nvt_vrescale_init, + nvt_vrescale_step, ) from torch_sim.integrators.npt import ( NPTLangevinState, NPTNoseHooverState, + npt_crescale_init, + npt_crescale_step, npt_langevin_init, npt_langevin_step, npt_nose_hoover_init, diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index cfd2bbd70..b8e67e9c8 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -14,8 +14,9 @@ using BAOAB scheme [2] - Nosé-Hoover thermostat integrator :func:`nvt.nvt_nose_hoover_step` from [3] NPT: - - Langevin barostat integrator :func:`npt.npt_langevin_step` [3, 4] - - Nosé-Hoover barostat integrator :func:`npt.npt_nose_hoover_step` from [2] + - Langevin barostat integrator :func:`npt.npt_langevin_step` [4, 5] + - Nosé-Hoover barostat integrator :func:`npt.npt_nose_hoover_step` from [3] + - C-Rescale barostat integrator :func:`npt.npt_crescale_step` from [6, 7, 8] References: [1] Bussi G, Donadio D, Parrinello M. "Canonical sampling through velocity rescaling." @@ -29,6 +30,14 @@ Constant pressure and temperature discrete-time Langevin molecular dynamics. The Journal of chemical physics, 141(19). [5] LAMMPS: https://docs.lammps.org/fix_press_langevin.html + [6] Bernetti, Mattia, and Giovanni Bussi. + "Pressure control using stochastic cell rescaling." + The Journal of Chemical Physics 153.11 (2020). + [7] Del Tatto, Vittorio, et al. "Molecular dynamics of solids at + constant pressure and stress using anisotropic stochastic cell rescaling." + Applied Sciences 12.3 (2022): 1139. + [8] Bussi Anisotropic C-Rescale SimpleMD implementation: + https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp Examples: @@ -57,6 +66,8 @@ from .npt import ( NPTLangevinState, NPTNoseHooverState, + npt_crescale_init, + npt_crescale_step, npt_langevin_init, npt_langevin_step, npt_nose_hoover_init, @@ -66,13 +77,13 @@ from .nve import nve_init, nve_step from .nvt import ( NVTNoseHooverState, - nvt_vrescale_init, - nvt_vrescale_step, nvt_langevin_init, nvt_langevin_step, nvt_nose_hoover_init, nvt_nose_hoover_invariant, nvt_nose_hoover_step, + nvt_vrescale_init, + nvt_vrescale_step, ) @@ -105,6 +116,7 @@ class Integrator(StrEnum): nvt_nose_hoover = "nvt_nose_hoover" npt_langevin = "npt_langevin" npt_nose_hoover = "npt_nose_hoover" + npt_crescale = "npt_crescale" #: Integrator registry - maps integrator names to (init_fn, step_fn) pairs. @@ -129,6 +141,7 @@ class Integrator(StrEnum): #: - ``Integrator.nvt_nose_hoover``: Nosé-Hoover thermostat #: - ``Integrator.npt_langevin``: Langevin barostat #: - ``Integrator.npt_nose_hoover``: Nosé-Hoover barostat +#: - ``Integrator.npt_crescale``: C-Rescale barostat #: #: :type: dict[Integrator, tuple[Callable[..., Any], Callable[..., Any]]] INTEGRATOR_REGISTRY: Final[ @@ -140,4 +153,5 @@ class Integrator(StrEnum): Integrator.nvt_nose_hoover: (nvt_nose_hoover_init, nvt_nose_hoover_step), Integrator.npt_langevin: (npt_langevin_init, npt_langevin_step), Integrator.npt_nose_hoover: (npt_nose_hoover_init, npt_nose_hoover_step), + Integrator.npt_crescale: (npt_crescale_init, npt_crescale_step), } diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 685c5a196..1758244f1 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -13,7 +13,9 @@ NoseHooverChainFns, calculate_momenta, construct_nose_hoover_chain, + momentum_step, ) +from torch_sim.integrators.nvt import _vrescale_update from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState from torch_sim.typing import StateDict @@ -1600,3 +1602,330 @@ def npt_nose_hoover_invariant( e_tot += torch.square(cell_momentum) / (2 * state.cell_mass) return e_tot + + +################### +# Implement full anisotropic NPT with cell rescaling barostat +# Choices: +# - Time reversible integrator +# - Instantenous kinetic energy (not not the average) +# - According to the authors should be better for constraints +# Inspiration from Bussi SimpleMD repo +# https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp#L681C6-L688C16 +# thermostat +# // update velocities +# // barostat +# // update positions +# // (eventually recompute neighbour list) +# // compute forces +# // update velocities +# // thermostat + + +@dataclass(kw_only=True) +class NPTCRescaleState(MDState): + """State for NPT ensemble with cell rescaling barostat. + + This class extends the MDState to include variables and properties + specific to the NPT ensemble with a cell rescaling barostat. + """ + + # System state variables + stress: torch.Tensor + isothermal_compressibility: torch.Tensor # shape: [n_systems] + tau_p: torch.Tensor # shape: [n_systems] + + _system_attributes = MDState._system_attributes | { # noqa: SLF001 + "stress", + "isothermal_compressibility", + "tau_p", + } + + +def _compute_instantaneous_internal_pressure( + state: NPTLangevinState, + volumes: torch.Tensor, +) -> torch.Tensor: + """Compute forces on the cell for NPT dynamics. + + This function calculates the forces acting on the simulation cell + based on the difference between internal stress and external pressure, + plus a kinetic contribution. These forces drive the volume changes + needed to maintain constant pressure. + + Args: + state (NPTLangevinState): Current NPT state + volumes (torch.Tensor): Current system volumes [n_systems] + kT (torch.Tensor): Temperature in energy units, either scalar or + shape [n_systems] + + Returns: + torch.Tensor: Force acting on the cell [n_systems, n_dim, n_dim] + """ + # Reshape for broadcasting + volumes = volumes.view(-1, 1, 1) # shape: (n_systems, 1, 1) + + # Calculate virials: 2/V * (K_{tensor} - Virial_{tensor}) + twice_kinetic_energy_tensor = torch.einsum( + "bi,bj,b->bij", state.momenta, state.momenta, 1 / state.masses + ) + twice_kinetic_energy_tensor = torch.scatter_add( + torch.zeros( + state.n_systems, + 3, + 3, + device=state.positions.device, + dtype=state.positions.dtype, + ), + 0, + state.system_idx.unsqueeze(-1) + .unsqueeze(-1) + .expand_as(twice_kinetic_energy_tensor), + twice_kinetic_energy_tensor, + ) + return twice_kinetic_energy_tensor / volumes - state.stress + + +def rotate_gram_schmidt(box: torch.Tensor) -> torch.Tensor: + """Convert a batch of 3x3 box matrices into lower-triangular form. + Correspond to a Gram-Schmidt orthogonalization of the box vectors. + + Args: + box (torch.Tensor): shape [n_systems, 3, 3] + + Returns: + torch.Tensor: shape [n_systems, 3, 3] lower-triangular boxes + """ + box_buffer = box.clone() + + # Row vectors (a, b, c) + a = box_buffer[:, 0, :] + b = box_buffer[:, 1, :] + c = box_buffer[:, 2, :] + + # --- Compute the lower-triangular entries --- + + # a-axis + box[:, 0, 0] = torch.norm(a, dim=1) + + # b projections + box[:, 1, 0] = torch.sum(a * b, dim=1) / box[:, 0, 0] + box[:, 1, 1] = torch.sqrt(torch.sum(b * b, dim=1) - box[:, 1, 0] ** 2) + + # c projections + box[:, 2, 0] = torch.sum(a * c, dim=1) / box[:, 0, 0] + box[:, 2, 1] = (torch.sum(b * c, dim=1) - box[:, 2, 0] * box[:, 1, 0]) / box[:, 1, 1] + box[:, 2, 2] = torch.sqrt( + torch.sum(c * c, dim=1) - box[:, 2, 0] ** 2 - box[:, 2, 1] ** 2 + ) + + # Upper-triangular entries are 0 by initialization + return box + +def batch_matrix_vector( + matrices: torch.Tensor, + vectors: torch.Tensor, +) -> torch.Tensor: + """Perform batch matrix-vector multiplication. + + Args: + matrices (torch.Tensor): shape [n_systems, n, n] + vectors (torch.Tensor): shape [n_systems, n, m] + + Returns: + torch.Tensor: shape [n_systems, n, m] result of multiplication + """ + return torch.matmul(matrices, vectors.unsqueeze(-1)).squeeze(-1) + + +def npt_crescale_step( + state: NPTCRescaleState, + model: ModelInterface, + *, + dt: torch.Tensor, + kT: torch.Tensor, + external_pressure: torch.Tensor, + tau: torch.Tensor | None = None, +) -> NPTCRescaleState: + """Perform one NPT integration step with cell rescaling barostat. + + This function performs a single integration step for NPT dynamics using + a cell rescaling barostat. It updates particle positions, momenta, and + the simulation cell based on the target temperature and pressure. + + Trotter based splitting: + 1. Half Thermostat (velocity scaling) + 2. Half Update momenta with forces + 3. Barostat (cell rescaling) + 4. Update positions (from barostat + half momenta) + 5. Update forces with new positions and cell + 6. Compute forces + 7. Update momenta with forces + 8. Thermostat (velocity scaling) + + Only allow isotropic external stress. Can run isotropic or anisotropic + cell rescaling. + + Args: + model (ModelInterface): Model to compute forces and energies + state (NPTCRescaleState): Current system state + dt (torch.Tensor): Integration timestep + kT (torch.Tensor): Target temperature + external_pressure (torch.Tensor): Target external pressure + tau (torch.Tensor | None): V-Rescale thermostat relaxation time. If None, + defaults to 100*dt + + Returns: + NPTCRescaleState: Updated state after one integration step + """ + # Note: would probably be better to have tau in NVTCRescaleState + if tau is None: + tau = 100 * dt + state = _vrescale_update(state, tau, kT, dt / 2) + + state = momentum_step(state, dt / 2) + + # Barostat step + ## Step 1: propagate sqrt(volume) for dt/2 + volume = torch.det(state.cell) # shape: (n_systems,) + P_int = _compute_instantaneous_internal_pressure(state, volume) + sqrt_vol = torch.sqrt(volume) + trace_P_int = torch.einsum("bii->b", P_int) + prefactor_random = torch.sqrt( + kT * state.isothermal_compressibility * dt / (4 * state.tau_p) + ) + prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p) + change_sqrt_vol = -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + new_sqrt_volume = sqrt_vol + change_sqrt_vol + prefactor_random_matrix = ( + torch.sqrt(2 * state.isothermal_compressibility * kT * dt / (3 * state.tau_p)) + / new_sqrt_volume + ) + # prefactor_random_matrix = prefactor_random_matrix + a_tilde = ( + -(state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] + * (P_int - trace_P_int[:, None, None] / 3) + ) + deformation_matrix = torch.matrix_exp( + a_tilde * dt + + prefactor_random_matrix[:, None, None] + * torch.randn( + state.n_systems, + 3, + 3, + device=state.positions.device, + dtype=state.positions.dtype, + ) + ) + deformation_matrix = rotate_gram_schmidt(deformation_matrix) + + new_sqrt_volume += -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + rscaling = deformation_matrix * torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).view( + -1, 1, 1 + ) + vscaling = torch.inverse(rscaling).transpose(-2, -1) + # print(rscaling[0, 0, 0], rscaling[0, 1, 1], rscaling[0, 2, 2]) + + # Update positions + state.positions = (batch_matrix_vector(rscaling[state.system_idx], state.positions) + + batch_matrix_vector((vscaling + rscaling)[state.system_idx], state.momenta) + * dt / (2 * state.masses.unsqueeze(-1)) + ) + state.momenta = batch_matrix_vector(vscaling[state.system_idx], state.momenta) + state.cell = rscaling @ state.cell + + # Forces + model_output = model(state) + state.forces = model_output["forces"] + state.energy = model_output["energy"] + state.stress = model_output["stress"] + + # Final momentum step + state = momentum_step(state, dt / 2) + + # Final thermostat step + return _vrescale_update(state, tau, kT, dt / 2) + + +def npt_crescale_init( + state: SimState | StateDict, + model: ModelInterface, + *, + kT: torch.Tensor, + dt: torch.Tensor, + tau_p: torch.Tensor | None = None, + isothermal_compressibility: torch.Tensor | None = None, + seed: int | None = None, +) -> NPTCRescaleState: + """Initialize the NPT cell rescaling state. + + This function initializes a state for NPT molecular dynamics with a + cell rescaling barostat. It sets up the system with appropriate initial + conditions including particle positions, momenta, and cell variables. + + Only allow isotropic external stress. + + Args: + state: Initial system state as MDState or dict containing positions, masses, + cell, and PBC information + model (ModelInterface): Model to compute forces and energies + kT: Target temperature in energy units + dt: Integration timestep + tau_p: Barostat relaxation time. Controls how quickly pressure equilibrates. + isothermal_compressibility: Isothermal compressibility of the system. + seed: Random seed for momenta initialization. + """ + device, dtype = model.device, model.dtype + + # Set default values if not provided + if tau_p is None: + tau_p = 5000 * dt # 5ps for dt=1fs + if isothermal_compressibility is None: + isothermal_compressibility = 1e-1 # (eV/A^3)^-1 + + # Convert all parameters to tensors with correct device and dtype + tau_p = torch.as_tensor(tau_p, device=device, dtype=dtype) + isothermal_compressibility = torch.as_tensor( + isothermal_compressibility, device=device, dtype=dtype + ) + if tau_p.ndim == 0: + tau_p = tau_p.expand(state.n_systems) + if isothermal_compressibility.ndim == 0: + isothermal_compressibility = isothermal_compressibility.expand(state.n_systems) + if isinstance(dt, float): + dt = torch.tensor(dt, device=device, dtype=dtype) + if isinstance(kT, float): + kT = torch.tensor(kT, device=device, dtype=dtype) + + if not isinstance(state, SimState): + state = SimState(**state) + + # Get model output to initialize forces and stress + model_output = model(state) + + # Initialize momenta if not provided + momenta = getattr( + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + ) + + # Create the initial state + return NPTCRescaleState( + positions=state.positions, + momenta=momenta, + energy=model_output["energy"], + forces=model_output["forces"], + stress=model_output["stress"], + masses=state.masses, + cell=state.cell, + pbc=state.pbc, + system_idx=state.system_idx, + atomic_numbers=state.atomic_numbers, + tau_p=tau_p, + isothermal_compressibility=isothermal_compressibility, + ) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index d38e33d7a..2b6087c64 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -460,6 +460,7 @@ def nvt_nose_hoover_invariant( return e_tot +### Number of DOF?###################################################################################### def _vrescale_update( state: MDState, tau: float | torch.Tensor, From fcaa9d51b1a63b208a7fc93fccd664db2f27cfeb Mon Sep 17 00:00:00 2001 From: thomasloux Date: Fri, 31 Oct 2025 17:34:39 +0100 Subject: [PATCH 05/19] move instantaneous pressure tensor --- torch_sim/quantities.py | 46 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index bcb824c02..2ff930705 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -141,6 +141,52 @@ def get_pressure( """ return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress)) +def compute_instantaneous_pressure_tensor( + *, + momenta: torch.Tensor, + masses: torch.Tensor, + system_idx: torch.Tensor, + stress: torch.Tensor, + volumes: torch.Tensor, +) -> torch.Tensor: + """Compute forces on the cell for NPT dynamics. + + This function calculates the instantaneous internal pressure tensor. + + Args: + momenta (torch.Tensor): Particle momenta, shape (n_particles, 3) + masses (torch.Tensor): Particle masses, shape (n_particles,) + system_idx (torch.Tensor): Tensor indicating system membership of each particle + stress (torch.Tensor): Stress tensor of the system, shape (n_systems, 3, 3) + volumes (torch.Tensor): Volumes of the systems, shape (n_systems,) + + Returns: + torch.Tensor: Instanteneous internal pressure tesnor [n_systems, 3, 3] + """ + # Reshape for broadcasting + volumes = volumes.view(-1, 1, 1) # shape: (n_systems, 1, 1) + + # Calculate virials: 2/V * (K_{tensor} - Virial_{tensor}) + twice_kinetic_energy_tensor = torch.einsum( + "bi,bj,b->bij", momenta, momenta, 1 / masses + ) + n_systems = stress.shape[0] + twice_kinetic_energy_tensor = torch.scatter_add( + torch.zeros( + n_systems, + 3, + 3, + device=momenta.device, + dtype=momenta.dtype, + ), + 0, + system_idx.unsqueeze(-1) + .unsqueeze(-1) + .expand_as(twice_kinetic_energy_tensor), + twice_kinetic_energy_tensor, + ) + return twice_kinetic_energy_tensor / volumes - stress + def calc_heat_flux( momenta: torch.Tensor | None, From 00583a8c5884063bdd45da9230f69176c9f58ea2 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Fri, 31 Oct 2025 17:35:06 +0100 Subject: [PATCH 06/19] correct gram schmidt to output an upper triangular matrix --- torch_sim/integrators/npt.py | 131 +++++++++++------------------------ 1 file changed, 40 insertions(+), 91 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 1758244f1..b591271c3 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1604,24 +1604,6 @@ def npt_nose_hoover_invariant( return e_tot -################### -# Implement full anisotropic NPT with cell rescaling barostat -# Choices: -# - Time reversible integrator -# - Instantenous kinetic energy (not not the average) -# - According to the authors should be better for constraints -# Inspiration from Bussi SimpleMD repo -# https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp#L681C6-L688C16 -# thermostat -# // update velocities -# // barostat -# // update positions -# // (eventually recompute neighbour list) -# // compute forces -# // update velocities -# // thermostat - - @dataclass(kw_only=True) class NPTCRescaleState(MDState): """State for NPT ensemble with cell rescaling barostat. @@ -1642,85 +1624,42 @@ class NPTCRescaleState(MDState): } -def _compute_instantaneous_internal_pressure( - state: NPTLangevinState, - volumes: torch.Tensor, -) -> torch.Tensor: - """Compute forces on the cell for NPT dynamics. - - This function calculates the forces acting on the simulation cell - based on the difference between internal stress and external pressure, - plus a kinetic contribution. These forces drive the volume changes - needed to maintain constant pressure. - - Args: - state (NPTLangevinState): Current NPT state - volumes (torch.Tensor): Current system volumes [n_systems] - kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_systems] - - Returns: - torch.Tensor: Force acting on the cell [n_systems, n_dim, n_dim] - """ - # Reshape for broadcasting - volumes = volumes.view(-1, 1, 1) # shape: (n_systems, 1, 1) - - # Calculate virials: 2/V * (K_{tensor} - Virial_{tensor}) - twice_kinetic_energy_tensor = torch.einsum( - "bi,bj,b->bij", state.momenta, state.momenta, 1 / state.masses - ) - twice_kinetic_energy_tensor = torch.scatter_add( - torch.zeros( - state.n_systems, - 3, - 3, - device=state.positions.device, - dtype=state.positions.dtype, - ), - 0, - state.system_idx.unsqueeze(-1) - .unsqueeze(-1) - .expand_as(twice_kinetic_energy_tensor), - twice_kinetic_energy_tensor, - ) - return twice_kinetic_energy_tensor / volumes - state.stress - - def rotate_gram_schmidt(box: torch.Tensor) -> torch.Tensor: - """Convert a batch of 3x3 box matrices into lower-triangular form. - Correspond to a Gram-Schmidt orthogonalization of the box vectors. + """Convert a batch of 3x3 box matrices into upper-triangular form. Args: box (torch.Tensor): shape [n_systems, 3, 3] Returns: - torch.Tensor: shape [n_systems, 3, 3] lower-triangular boxes + torch.Tensor: shape [n_systems, 3, 3] upper-triangular boxes """ - box_buffer = box.clone() + out = box.clone() - # Row vectors (a, b, c) - a = box_buffer[:, 0, :] - b = box_buffer[:, 1, :] - c = box_buffer[:, 2, :] + # Columns (a, b, c) correspond to box vectors in column form + a = out[:, :, 0] + b = out[:, :, 1] + c = out[:, :, 2] - # --- Compute the lower-triangular entries --- + # --- Compute upper-triangular entries --- - # a-axis - box[:, 0, 0] = torch.norm(a, dim=1) + # First vector (x-axis) + out[:, 0, 0] = torch.norm(a, dim=1) - # b projections - box[:, 1, 0] = torch.sum(a * b, dim=1) / box[:, 0, 0] - box[:, 1, 1] = torch.sqrt(torch.sum(b * b, dim=1) - box[:, 1, 0] ** 2) + # Project b onto a + out[:, 0, 1] = torch.sum(a * b, dim=1) / out[:, 0, 0] + out[:, 1, 1] = torch.sqrt(torch.sum(b * b, dim=1) - out[:, 0, 1]**2) - # c projections - box[:, 2, 0] = torch.sum(a * c, dim=1) / box[:, 0, 0] - box[:, 2, 1] = (torch.sum(b * c, dim=1) - box[:, 2, 0] * box[:, 1, 0]) / box[:, 1, 1] - box[:, 2, 2] = torch.sqrt( - torch.sum(c * c, dim=1) - box[:, 2, 0] ** 2 - box[:, 2, 1] ** 2 - ) + # Project c onto a and b + out[:, 0, 2] = torch.sum(a * c, dim=1) / out[:, 0, 0] + out[:, 1, 2] = (torch.sum(b * c, dim=1) - out[:, 0, 2] * out[:, 0, 1]) / out[:, 1, 1] + out[:, 2, 2] = torch.sqrt(torch.sum(c * c, dim=1) - out[:, 0, 2]**2 - out[:, 1, 2]**2) + + # Upper-triangular form → zero lower elements + out[:, 1, 0] = 0.0 + out[:, 2, 0] = 0.0 + out[:, 2, 1] = 0.0 - # Upper-triangular entries are 0 by initialization - return box + return out def batch_matrix_vector( matrices: torch.Tensor, @@ -1760,12 +1699,16 @@ def npt_crescale_step( 4. Update positions (from barostat + half momenta) 5. Update forces with new positions and cell 6. Compute forces - 7. Update momenta with forces - 8. Thermostat (velocity scaling) + 7. Half Update momenta with forces + 8. Half Thermostat (velocity scaling) - Only allow isotropic external stress. Can run isotropic or anisotropic + Only allow isotropic external stress. Can only run anisotropic cell rescaling. + Inspired from: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp + - Time reversible integrator + - Instantaneous kinetic energy (not not the average from equipartition) + Args: model (ModelInterface): Model to compute forces and energies state (NPTCRescaleState): Current system state @@ -1788,7 +1731,13 @@ def npt_crescale_step( # Barostat step ## Step 1: propagate sqrt(volume) for dt/2 volume = torch.det(state.cell) # shape: (n_systems,) - P_int = _compute_instantaneous_internal_pressure(state, volume) + P_int = ts.quantities.compute_instantaneous_pressure_tensor( + momenta=state.momenta, + masses=state.masses, + system_idx=state.system_idx, + stress=state.stress, + volumes=volume, + ) sqrt_vol = torch.sqrt(volume) trace_P_int = torch.einsum("bii->b", P_int) prefactor_random = torch.sqrt( @@ -1799,11 +1748,11 @@ def npt_crescale_step( external_pressure - trace_P_int / 3 - kT / (2 * volume) ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) new_sqrt_volume = sqrt_vol + change_sqrt_vol + ## Step 2: compute deformation matrix prefactor_random_matrix = ( torch.sqrt(2 * state.isothermal_compressibility * kT * dt / (3 * state.tau_p)) / new_sqrt_volume ) - # prefactor_random_matrix = prefactor_random_matrix a_tilde = ( -(state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] * (P_int - trace_P_int[:, None, None] / 3) @@ -1821,6 +1770,7 @@ def npt_crescale_step( ) deformation_matrix = rotate_gram_schmidt(deformation_matrix) + ## Step 3: propagate sqrt(volume) for dt/2 new_sqrt_volume += -prefactor * ( external_pressure - trace_P_int / 3 - kT / (2 * volume) ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) @@ -1828,9 +1778,8 @@ def npt_crescale_step( -1, 1, 1 ) vscaling = torch.inverse(rscaling).transpose(-2, -1) - # print(rscaling[0, 0, 0], rscaling[0, 1, 1], rscaling[0, 2, 2]) - # Update positions + # Update positions and momenta (barostat + half momentum step) state.positions = (batch_matrix_vector(rscaling[state.system_idx], state.positions) + batch_matrix_vector((vscaling + rscaling)[state.system_idx], state.momenta) * dt / (2 * state.masses.unsqueeze(-1)) From 3d983c2d162ddbd54190de97e0bf3e8d4d4bab01 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Fri, 31 Oct 2025 18:12:15 +0100 Subject: [PATCH 07/19] ruff --- torch_sim/quantities.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index 2ff930705..06c4ef21d 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -141,6 +141,7 @@ def get_pressure( """ return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress)) + def compute_instantaneous_pressure_tensor( *, momenta: torch.Tensor, @@ -180,9 +181,7 @@ def compute_instantaneous_pressure_tensor( dtype=momenta.dtype, ), 0, - system_idx.unsqueeze(-1) - .unsqueeze(-1) - .expand_as(twice_kinetic_energy_tensor), + system_idx.unsqueeze(-1).unsqueeze(-1).expand_as(twice_kinetic_energy_tensor), twice_kinetic_energy_tensor, ) return twice_kinetic_energy_tensor / volumes - stress From f566f9070ad81e05eb58727bd07ab2d37062ff1d Mon Sep 17 00:00:00 2001 From: thomasloux Date: Fri, 31 Oct 2025 18:14:22 +0100 Subject: [PATCH 08/19] take into account reduced dof because of fixed COM --- torch_sim/integrators/npt.py | 29 ++++++++++++++++++-------- torch_sim/integrators/nvt.py | 40 +++++++++++++++++++++++++++++------- 2 files changed, 53 insertions(+), 16 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index b591271c3..47998da26 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1623,6 +1623,14 @@ class NPTCRescaleState(MDState): "tau_p", } + def calc_dof(self) -> torch.Tensor: + """Calculate degrees of freedom for each system in the batch. + + Returns: + torch.Tensor: Degrees of freedom for each system, shape [n_systems] + """ + return super().calc_dof() - 3 # Subtract 3 for center of mass motion + def rotate_gram_schmidt(box: torch.Tensor) -> torch.Tensor: """Convert a batch of 3x3 box matrices into upper-triangular form. @@ -1647,12 +1655,14 @@ def rotate_gram_schmidt(box: torch.Tensor) -> torch.Tensor: # Project b onto a out[:, 0, 1] = torch.sum(a * b, dim=1) / out[:, 0, 0] - out[:, 1, 1] = torch.sqrt(torch.sum(b * b, dim=1) - out[:, 0, 1]**2) + out[:, 1, 1] = torch.sqrt(torch.sum(b * b, dim=1) - out[:, 0, 1] ** 2) # Project c onto a and b out[:, 0, 2] = torch.sum(a * c, dim=1) / out[:, 0, 0] out[:, 1, 2] = (torch.sum(b * c, dim=1) - out[:, 0, 2] * out[:, 0, 1]) / out[:, 1, 1] - out[:, 2, 2] = torch.sqrt(torch.sum(c * c, dim=1) - out[:, 0, 2]**2 - out[:, 1, 2]**2) + out[:, 2, 2] = torch.sqrt( + torch.sum(c * c, dim=1) - out[:, 0, 2] ** 2 - out[:, 1, 2] ** 2 + ) # Upper-triangular form → zero lower elements out[:, 1, 0] = 0.0 @@ -1661,6 +1671,7 @@ def rotate_gram_schmidt(box: torch.Tensor) -> torch.Tensor: return out + def batch_matrix_vector( matrices: torch.Tensor, vectors: torch.Tensor, @@ -1753,9 +1764,8 @@ def npt_crescale_step( torch.sqrt(2 * state.isothermal_compressibility * kT * dt / (3 * state.tau_p)) / new_sqrt_volume ) - a_tilde = ( - -(state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] - * (P_int - trace_P_int[:, None, None] / 3) + a_tilde = -(state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] * ( + P_int - trace_P_int[:, None, None] / 3 ) deformation_matrix = torch.matrix_exp( a_tilde * dt @@ -1780,10 +1790,11 @@ def npt_crescale_step( vscaling = torch.inverse(rscaling).transpose(-2, -1) # Update positions and momenta (barostat + half momentum step) - state.positions = (batch_matrix_vector(rscaling[state.system_idx], state.positions) - + batch_matrix_vector((vscaling + rscaling)[state.system_idx], state.momenta) - * dt / (2 * state.masses.unsqueeze(-1)) - ) + state.positions = batch_matrix_vector( + rscaling[state.system_idx], state.positions + ) + batch_matrix_vector( + (vscaling + rscaling)[state.system_idx], state.momenta + ) * dt / (2 * state.masses.unsqueeze(-1)) state.momenta = batch_matrix_vector(vscaling[state.system_idx], state.momenta) state.cell = rscaling @ state.cell diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 2b6087c64..87c8047f9 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -460,7 +460,34 @@ def nvt_nose_hoover_invariant( return e_tot -### Number of DOF?###################################################################################### +class NVTVRescaleState(MDState): + """State information for an NVT system with a V-Rescale thermostat. + + This class represents the complete state of a molecular system being integrated + in the NVT (constant particle number, volume, temperature) ensemble using a + Velocity Rescaling thermostat. The thermostat maintains constant temperature + through stochastic velocity rescaling. + + Attributes: + positions: Particle positions with shape [n_particles, n_dimensions] + masses: Particle masses with shape [n_particles] + cell: Simulation cell matrix with shape [n_dimensions, n_dimensions] + pbc: Whether to use periodic boundary conditions + momenta: Particle momenta with shape [n_particles, n_dimensions] + energy: Energy of the system + forces: Forces on particles with shape [n_particles, n_dimensions] + + Notes: + - The V-Rescale thermostat provides proper canonical sampling + - Stochastic velocity rescaling ensures correct temperature distribution + - Time-reversible when integrated with appropriate algorithms + """ + + def calc_dof(self) -> torch.Tensor: + """Calculate the degrees of freedom per system.""" + return super().calc_dof() - 3 # Subtract 3 for center of mass motion + + def _vrescale_update( state: MDState, tau: float | torch.Tensor, @@ -494,8 +521,7 @@ def _vrescale_update( ) # Calculate degrees of freedom per system - n_atoms_per_system = torch.bincount(state.system_idx) - dof = n_atoms_per_system * state.positions.shape[-1] + dof = state.calc_dof() # Ensure kT and tau have proper batch dimensions n_systems = current_kT.shape[0] @@ -533,7 +559,7 @@ def nvt_vrescale_init( kT: float | torch.Tensor, seed: int | None = None, **_kwargs: Any, -) -> MDState: +) -> NVTVRescaleState: """Initialize an NVT state from input data for velocity rescaling dynamics. Creates an initial state for NVT molecular dynamics using the canonical @@ -570,7 +596,7 @@ def nvt_vrescale_init( calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - return MDState( + return NVTVRescaleState( positions=state.positions, momenta=momenta, energy=model_output["energy"], @@ -585,12 +611,12 @@ def nvt_vrescale_init( def nvt_vrescale_step( model: ModelInterface, - state: MDState, + state: NVTVRescaleState, *, dt: float | torch.Tensor, kT: float | torch.Tensor, tau: float | torch.Tensor | None = None, -) -> MDState: +) -> NVTVRescaleState: """Perform one complete V-Rescale dynamics integration step. This function implements the canonical sampling through velocity rescaling (V-Rescale) From db92c4beca2f8baef517a987e144f429997f02cd Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 3 Nov 2025 11:33:17 +0100 Subject: [PATCH 09/19] add isotropic c-rescale --- tests/test_integrators.py | 77 ++++++++++- torch_sim/__init__.py | 3 +- torch_sim/integrators/__init__.py | 25 +++- torch_sim/integrators/npt.py | 218 +++++++++++++++++++++++------- 4 files changed, 265 insertions(+), 58 deletions(-) diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 03fedc6cf..1a9b89b01 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -531,7 +531,7 @@ def test_nvt_vrescale(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo assert pos_diff > 0.0001 # Systems should remain separated -def test_npt_crescale( +def test_npt_anisotropic_crescale( ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel ) -> None: n_steps = 200 @@ -556,7 +556,80 @@ def test_npt_crescale( energies = [] temperatures = [] for _step in range(n_steps): - state = ts.npt_crescale_step( + state = ts.npt_crescale_anisotropic_step( + state=state, + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + ) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + for mean_temp in mean_temps: + assert ( + abs(mean_temp - kT.item() / MetalUnits.temperature) < 150.0 + ) # Allow for thermal fluctuations + + # Check energy is stable for each trajectory + for traj in energies_list: + energy_std = torch.tensor(traj).std() + assert energy_std < 1.0 # Adjust threshold as needed + + # Check positions and momenta have correct shapes + n_atoms = 8 + + # Verify the two systems remain distinct + pos_diff = torch.norm( + state.positions[:n_atoms].mean(0) - state.positions[n_atoms:].mean(0) + ) + assert pos_diff > 0.0001 # Systems should remain separated + + +def test_npt_isotropic_crescale( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: + n_steps = 200 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(100.0, dtype=DTYPE) * MetalUnits.temperature + external_pressure = torch.tensor(0.0, dtype=DTYPE) * MetalUnits.pressure + tau_p = torch.tensor(0.1, dtype=DTYPE) + isothermal_compressibility = torch.tensor(1e-4, dtype=DTYPE) + + # Initialize integrator using new direct API + state = ts.npt_crescale_init( + state=ar_double_sim_state, + model=lj_model, + dt=dt, + kT=kT, + tau_p=tau_p, + isothermal_compressibility=isothermal_compressibility, + seed=42, + ) + + # Run dynamics for several steps + energies = [] + temperatures = [] + for _step in range(n_steps): + state = ts.npt_crescale_isotropic_step( state=state, model=lj_model, dt=dt, diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index ccf0fb32a..2d22aaaf6 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -40,8 +40,9 @@ from torch_sim.integrators.npt import ( NPTLangevinState, NPTNoseHooverState, + npt_crescale_anisotropic_step, npt_crescale_init, - npt_crescale_step, + npt_crescale_isotropic_step, npt_langevin_init, npt_langevin_step, npt_nose_hoover_init, diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index b8e67e9c8..fc91fd375 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -16,7 +16,10 @@ NPT: - Langevin barostat integrator :func:`npt.npt_langevin_step` [4, 5] - Nosé-Hoover barostat integrator :func:`npt.npt_nose_hoover_step` from [3] - - C-Rescale barostat integrator :func:`npt.npt_crescale_step` from [6, 7, 8] + - Isotropic C-Rescale barostat integrator :func:`npt.npt_crescale_isotropic_step` + from [6, 8, 9] + - Anisotropic C-Rescale barostat integrator :func:`npt.npt_crescale_anisotropic_step` + from [7, 8, 9] References: [1] Bussi G, Donadio D, Parrinello M. "Canonical sampling through velocity rescaling." @@ -38,6 +41,7 @@ Applied Sciences 12.3 (2022): 1139. [8] Bussi Anisotropic C-Rescale SimpleMD implementation: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp + [9] Supplementary Information for [6]. Examples: @@ -66,8 +70,9 @@ from .npt import ( NPTLangevinState, NPTNoseHooverState, + npt_crescale_anisotropic_step, npt_crescale_init, - npt_crescale_step, + npt_crescale_isotropic_step, npt_langevin_init, npt_langevin_step, npt_nose_hoover_init, @@ -102,6 +107,10 @@ class Integrator(StrEnum): - ``npt_langevin``: Langevin barostat for constant temperature and pressure. - ``npt_nose_hoover``: Nosé-Hoover barostat for constant temperature and constant pressure. + - ``npt_isotropic_crescale``: Isotropic C-Rescale barostat for constant + temperature and pressure with fixed cell shape. + - ``npt_anisotropic_crescale``: Anisotropic C-Rescale barostat for constant + temperature and pressure with variable cell shape. Example: >>> integrator = Integrator.nvt_langevin @@ -116,7 +125,8 @@ class Integrator(StrEnum): nvt_nose_hoover = "nvt_nose_hoover" npt_langevin = "npt_langevin" npt_nose_hoover = "npt_nose_hoover" - npt_crescale = "npt_crescale" + npt_isotropic_crescale = "npt_isotropic_crescale" + npt_anisotropic_crescale = "npt_anisotropic_crescale" #: Integrator registry - maps integrator names to (init_fn, step_fn) pairs. @@ -141,7 +151,8 @@ class Integrator(StrEnum): #: - ``Integrator.nvt_nose_hoover``: Nosé-Hoover thermostat #: - ``Integrator.npt_langevin``: Langevin barostat #: - ``Integrator.npt_nose_hoover``: Nosé-Hoover barostat -#: - ``Integrator.npt_crescale``: C-Rescale barostat +#: - ``Integrator.npt_isotropic_crescale``: Isotropic NPT C-Rescale barostat +#: - ``Integrator.npt_anisotropic_crescale``: Anisotropic NPT C-Rescale barostat #: #: :type: dict[Integrator, tuple[Callable[..., Any], Callable[..., Any]]] INTEGRATOR_REGISTRY: Final[ @@ -153,5 +164,9 @@ class Integrator(StrEnum): Integrator.nvt_nose_hoover: (nvt_nose_hoover_init, nvt_nose_hoover_step), Integrator.npt_langevin: (npt_langevin_init, npt_langevin_step), Integrator.npt_nose_hoover: (npt_nose_hoover_init, npt_nose_hoover_step), - Integrator.npt_crescale: (npt_crescale_init, npt_crescale_step), + Integrator.npt_isotropic_crescale: (npt_crescale_init, npt_crescale_isotropic_step), + Integrator.npt_anisotropic_crescale: ( + npt_crescale_init, + npt_crescale_anisotropic_step, + ), } diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 47998da26..5b2258084 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1688,59 +1688,12 @@ def batch_matrix_vector( return torch.matmul(matrices, vectors.unsqueeze(-1)).squeeze(-1) -def npt_crescale_step( +def _crescale_anisotropic_barostat_step( state: NPTCRescaleState, - model: ModelInterface, - *, - dt: torch.Tensor, kT: torch.Tensor, + dt: torch.Tensor, external_pressure: torch.Tensor, - tau: torch.Tensor | None = None, ) -> NPTCRescaleState: - """Perform one NPT integration step with cell rescaling barostat. - - This function performs a single integration step for NPT dynamics using - a cell rescaling barostat. It updates particle positions, momenta, and - the simulation cell based on the target temperature and pressure. - - Trotter based splitting: - 1. Half Thermostat (velocity scaling) - 2. Half Update momenta with forces - 3. Barostat (cell rescaling) - 4. Update positions (from barostat + half momenta) - 5. Update forces with new positions and cell - 6. Compute forces - 7. Half Update momenta with forces - 8. Half Thermostat (velocity scaling) - - Only allow isotropic external stress. Can only run anisotropic - cell rescaling. - - Inspired from: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp - - Time reversible integrator - - Instantaneous kinetic energy (not not the average from equipartition) - - Args: - model (ModelInterface): Model to compute forces and energies - state (NPTCRescaleState): Current system state - dt (torch.Tensor): Integration timestep - kT (torch.Tensor): Target temperature - external_pressure (torch.Tensor): Target external pressure - tau (torch.Tensor | None): V-Rescale thermostat relaxation time. If None, - defaults to 100*dt - - Returns: - NPTCRescaleState: Updated state after one integration step - """ - # Note: would probably be better to have tau in NVTCRescaleState - if tau is None: - tau = 100 * dt - state = _vrescale_update(state, tau, kT, dt / 2) - - state = momentum_step(state, dt / 2) - - # Barostat step - ## Step 1: propagate sqrt(volume) for dt/2 volume = torch.det(state.cell) # shape: (n_systems,) P_int = ts.quantities.compute_instantaneous_pressure_tensor( momenta=state.momenta, @@ -1797,6 +1750,170 @@ def npt_crescale_step( ) * dt / (2 * state.masses.unsqueeze(-1)) state.momenta = batch_matrix_vector(vscaling[state.system_idx], state.momenta) state.cell = rscaling @ state.cell + return state + + +def _crescale_isotropic_barostat_step( + state: NPTCRescaleState, + kT: torch.Tensor, + dt: torch.Tensor, + external_pressure: torch.Tensor, +) -> NPTCRescaleState: + volume = torch.det(state.cell) # shape: (n_systems,) + P_int = ts.quantities.compute_instantaneous_pressure_tensor( + momenta=state.momenta, + masses=state.masses, + system_idx=state.system_idx, + stress=state.stress, + volumes=volume, + ) + sqrt_vol = torch.sqrt(volume) + trace_P_int = torch.einsum("bii->b", P_int) + prefactor_random = torch.sqrt( + kT * state.isothermal_compressibility * dt / (4 * state.tau_p) + ) + prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p) + change_sqrt_vol = -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt + prefactor_random * torch.randn_like(sqrt_vol) + new_sqrt_volume = sqrt_vol + change_sqrt_vol + + # Update positions and momenta (barostat + half momentum step) + # SI (S13ab): notice there is a typo in the SI where q_i(t) + # should be scaled as well by rscaling + rscaling = torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).view(-1, 1, 1) + state.positions = rscaling[state.system_idx] * state.positions + ( + rscaling + 1 / rscaling + )[state.system_idx] * state.momenta * dt / (2 * state.masses.unsqueeze(-1)) + state.momenta = (1 / rscaling)[state.system_idx] * state.momenta + state.cell = rscaling * state.cell + return state + + +def npt_crescale_anisotropic_step( + state: NPTCRescaleState, + model: ModelInterface, + *, + dt: torch.Tensor, + kT: torch.Tensor, + external_pressure: torch.Tensor, + tau: torch.Tensor | None = None, +) -> NPTCRescaleState: + """Perform one NPT integration step with cell rescaling barostat. + + This function performs a single integration step for NPT dynamics using + a cell rescaling barostat. It updates particle positions, momenta, and + the simulation cell based on the target temperature and pressure. + + Trotter based splitting: + 1. Half Thermostat (velocity scaling) + 2. Half Update momenta with forces + 3. Barostat (cell rescaling) + 4. Update positions (from barostat + half momenta) + 5. Update forces with new positions and cell + 6. Compute forces + 7. Half Update momenta with forces + 8. Half Thermostat (velocity scaling) + + Only allow isotropic external stress. Can only run anisotropic + cell rescaling. + + Inspired from: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp + - Time reversible integrator + - Instantaneous kinetic energy (not not the average from equipartition) + + Args: + model (ModelInterface): Model to compute forces and energies + state (NPTCRescaleState): Current system state + dt (torch.Tensor): Integration timestep + kT (torch.Tensor): Target temperature + external_pressure (torch.Tensor): Target external pressure + tau (torch.Tensor | None): V-Rescale thermostat relaxation time. If None, + defaults to 100*dt + + Returns: + NPTCRescaleState: Updated state after one integration step + """ + # Note: would probably be better to have tau in NVTCRescaleState + if tau is None: + tau = 100 * dt + state = _vrescale_update(state, tau, kT, dt / 2) + + state = momentum_step(state, dt / 2) + + # Barostat step + state = _crescale_anisotropic_barostat_step(state, kT, dt, external_pressure) + + # Forces + model_output = model(state) + state.forces = model_output["forces"] + state.energy = model_output["energy"] + state.stress = model_output["stress"] + + # Final momentum step + state = momentum_step(state, dt / 2) + + # Final thermostat step + return _vrescale_update(state, tau, kT, dt / 2) + + +def npt_crescale_isotropic_step( + state: NPTCRescaleState, + model: ModelInterface, + *, + dt: torch.Tensor, + kT: torch.Tensor, + external_pressure: torch.Tensor, + tau: torch.Tensor | None = None, +) -> NPTCRescaleState: + """Perform one NPT integration step with cell rescaling barostat. + + This function performs a single integration step for NPT dynamics using + a cell rescaling barostat. It updates particle positions, momenta, and + the simulation cell based on the target temperature and pressure. + + Trotter based splitting: + 1. Half Thermostat (velocity scaling) + 2. Half Update momenta with forces + 3. Barostat (cell rescaling) + 4. Update positions (from barostat + half momenta) + 5. Update forces with new positions and cell + 6. Compute forces + 7. Half Update momenta with forces + 8. Half Thermostat (velocity scaling) + + Only allow isotropic external stress. This performs isotropic + cell rescaling: cell shape is preserved, cell lengths are scaled equally. + For anisotropic cell rescaling, use npt_crescale_anisotropic_step. + + References: + - Bernetti, Mattia, and Giovanni Bussi. + "Pressure control using stochastic cell rescaling." + The Journal of Chemical Physics 153.11 (2020). + - And the corresponding Supplementary Information which details + the integration scheme. Notice an error in scaling of positions in SI Eq. S13a. + + Args: + model (ModelInterface): Model to compute forces and energies + state (NPTCRescaleState): Current system state + dt (torch.Tensor): Integration timestep + kT (torch.Tensor): Target temperature + external_pressure (torch.Tensor): Target external pressure + tau (torch.Tensor | None): V-Rescale thermostat relaxation time. If None, + defaults to 100*dt + + Returns: + NPTCRescaleState: Updated state after one integration step + """ + # Note: would probably be better to have tau in NVTCRescaleState + if tau is None: + tau = 100 * dt + state = _vrescale_update(state, tau, kT, dt / 2) + + state = momentum_step(state, dt / 2) + + # Barostat step + state = _crescale_isotropic_barostat_step(state, kT, dt, external_pressure) # Forces model_output = model(state) @@ -1827,7 +1944,8 @@ def npt_crescale_init( cell rescaling barostat. It sets up the system with appropriate initial conditions including particle positions, momenta, and cell variables. - Only allow isotropic external stress. + Only allow isotropic external stress, but can run both isotropic and + anisotropic cell rescaling. Args: state: Initial system state as MDState or dict containing positions, masses, From e1e36739b9286f8991fd35c796a81e4567dd3080 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 3 Nov 2025 11:37:57 +0100 Subject: [PATCH 10/19] change calc_dof name and fix iso rscaling shape --- torch_sim/integrators/npt.py | 7 ++++--- torch_sim/integrators/nvt.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index d6f5c5d95..e9d14ac7a 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1619,13 +1619,13 @@ class NPTCRescaleState(MDState): "tau_p", } - def calc_dof(self) -> torch.Tensor: + def get_number_of_degrees_of_freedom(self) -> torch.Tensor: """Calculate degrees of freedom for each system in the batch. Returns: torch.Tensor: Degrees of freedom for each system, shape [n_systems] """ - return super().calc_dof() - 3 # Subtract 3 for center of mass motion + return super().get_number_of_degrees_of_freedom() - 3 # Subtract 3 for center of mass motion def rotate_gram_schmidt(box: torch.Tensor) -> torch.Tensor: @@ -1777,11 +1777,12 @@ def _crescale_isotropic_barostat_step( # Update positions and momenta (barostat + half momentum step) # SI (S13ab): notice there is a typo in the SI where q_i(t) # should be scaled as well by rscaling - rscaling = torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).view(-1, 1, 1) + rscaling = torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).unsqueeze(-1) state.positions = rscaling[state.system_idx] * state.positions + ( rscaling + 1 / rscaling )[state.system_idx] * state.momenta * dt / (2 * state.masses.unsqueeze(-1)) state.momenta = (1 / rscaling)[state.system_idx] * state.momenta + rscaling = rscaling.unsqueeze(-1) # make [n_systems, 1, 1] state.cell = rscaling * state.cell return state diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 220540f0c..c0451a9c1 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -490,9 +490,9 @@ class NVTVRescaleState(MDState): - Time-reversible when integrated with appropriate algorithms """ - def calc_dof(self) -> torch.Tensor: + def get_number_of_degrees_of_freedom(self) -> torch.Tensor: """Calculate the degrees of freedom per system.""" - return super().calc_dof() - 3 # Subtract 3 for center of mass motion + return super().get_number_of_degrees_of_freedom() - 3 # Subtract 3 for center of mass motion def _vrescale_update( @@ -528,7 +528,7 @@ def _vrescale_update( ) # Calculate degrees of freedom per system - dof = state.calc_dof() + dof = state.get_number_of_degrees_of_freedom() # Ensure kT and tau have proper batch dimensions n_systems = current_kT.shape[0] From 2a5ccccfa30fa9841e235eef9484508f1a337a88 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 3 Nov 2025 11:39:13 +0100 Subject: [PATCH 11/19] lint --- torch_sim/integrators/npt.py | 5 +++-- torch_sim/integrators/nvt.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index e9d14ac7a..a3985ac39 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1625,7 +1625,8 @@ def get_number_of_degrees_of_freedom(self) -> torch.Tensor: Returns: torch.Tensor: Degrees of freedom for each system, shape [n_systems] """ - return super().get_number_of_degrees_of_freedom() - 3 # Subtract 3 for center of mass motion + # Subtract 3 for center of mass motion + return super().get_number_of_degrees_of_freedom() - 3 def rotate_gram_schmidt(box: torch.Tensor) -> torch.Tensor: @@ -1782,7 +1783,7 @@ def _crescale_isotropic_barostat_step( rscaling + 1 / rscaling )[state.system_idx] * state.momenta * dt / (2 * state.masses.unsqueeze(-1)) state.momenta = (1 / rscaling)[state.system_idx] * state.momenta - rscaling = rscaling.unsqueeze(-1) # make [n_systems, 1, 1] + rscaling = rscaling.unsqueeze(-1) # make [n_systems, 1, 1] state.cell = rscaling * state.cell return state diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index c0451a9c1..b5fdf1219 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -492,7 +492,8 @@ class NVTVRescaleState(MDState): def get_number_of_degrees_of_freedom(self) -> torch.Tensor: """Calculate the degrees of freedom per system.""" - return super().get_number_of_degrees_of_freedom() - 3 # Subtract 3 for center of mass motion + # Subtract 3 for center of mass motion + return super().get_number_of_degrees_of_freedom() - 3 def _vrescale_update( From 3d00e94c075356c2fce698cd3b33e5ceb68d2166 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 3 Nov 2025 18:47:14 +0100 Subject: [PATCH 12/19] update take into account dof --- torch_sim/integrators/nvt.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index b5fdf1219..955983cc2 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -524,9 +524,7 @@ def _vrescale_update( dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype) # Calculate current temperature per system - current_kT = ts.quantities.calc_kT( - masses=state.masses, momenta=state.momenta, system_idx=state.system_idx - ) + current_kT = state.calc_kT() # Calculate degrees of freedom per system dof = state.get_number_of_degrees_of_freedom() From 6180a28d96d09ccdd20726d7a67b79bb0fde350f Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 3 Nov 2025 18:47:29 +0100 Subject: [PATCH 13/19] calc_kT taking into account dof --- torch_sim/integrators/md.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 92de1ccc9..e18add9df 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -7,7 +7,7 @@ from torch_sim import transforms from torch_sim.models.interface import ModelInterface -from torch_sim.quantities import calc_temperature +from torch_sim.quantities import calc_kT, calc_temperature from torch_sim.state import SimState from torch_sim.units import MetalUnits @@ -76,6 +76,19 @@ def calc_temperature( units=units, ) + def calc_kT(self) -> torch.Tensor: # noqa: N802 + """Calculate kT from momenta, masses, and system indices. + + Returns: + torch.Tensor: Calculated kT + """ + return calc_kT( + masses=self.masses, + momenta=self.momenta, + system_idx=self.system_idx, + dof_per_system=self.get_number_of_degrees_of_freedom(), + ) + def calculate_momenta( positions: torch.Tensor, From 715166226eae081aae99f14011f1d225c9549348 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 3 Nov 2025 18:47:59 +0100 Subject: [PATCH 14/19] add average anisotropic, correct error in equation (a_tilde IS traceless) --- torch_sim/integrators/npt.py | 146 ++++++++++++++++++++++++++++++++--- 1 file changed, 134 insertions(+), 12 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index a3985ac39..9a297cccb 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1715,18 +1715,27 @@ def _crescale_anisotropic_barostat_step( / new_sqrt_volume ) a_tilde = -(state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] * ( - P_int - trace_P_int[:, None, None] / 3 + P_int + - trace_P_int[:, None, None] + / 3 + * torch.eye( + 3, device=state.positions.device, dtype=state.positions.dtype + ).expand_as(P_int) ) + random_matrix = torch.randn( + state.n_systems, + 3, + 3, + device=state.positions.device, + dtype=state.positions.dtype, + ) + random_matrix_tilde = random_matrix - torch.einsum("bii->b", random_matrix)[ + :, None, None + ] / 3 * torch.eye( + 3, device=state.positions.device, dtype=state.positions.dtype + ).expand_as(random_matrix) deformation_matrix = torch.matrix_exp( - a_tilde * dt - + prefactor_random_matrix[:, None, None] - * torch.randn( - state.n_systems, - 3, - 3, - device=state.positions.device, - dtype=state.positions.dtype, - ) + a_tilde * dt + prefactor_random_matrix[:, None, None] * random_matrix_tilde ) deformation_matrix = rotate_gram_schmidt(deformation_matrix) @@ -1750,6 +1759,118 @@ def _crescale_anisotropic_barostat_step( return state +def compute_average_pressure_tensor( + *, + degrees_of_freedom: torch.Tensor, + kT: torch.Tensor, + stress: torch.Tensor, + volumes: torch.Tensor, +) -> torch.Tensor: + """Compute forces on the cell for NPT dynamics. + + This function calculates the instantaneous internal pressure tensor. + + Args: + degrees_of_freedom (torch.Tensor): Degrees of freedom of + the system, shape (n_systems,) + kT (torch.Tensor): Thermal energy (k_B * T), shape (n_system + stress (torch.Tensor): Stress tensor of the system, shape (n_systems, 3, 3) + volumes (torch.Tensor): Volumes of the systems, shape (n_systems,) + + Returns: + torch.Tensor: Instanteneous internal pressure tesnor [n_systems, 3, 3] + """ + # Reshape for broadcasting + volumes = volumes.view(-1, 1, 1) # shape: (n_systems, 1, 1) + + # Calculate virials: 2/V * (N_{atoms}k_B T / 2 - Virial_{tensor}) + n_systems = stress.shape[0] + average_kinetic_energy_tensor = ( + degrees_of_freedom + * kT + / volumes + * torch.eye(3, device=stress.device, dtype=stress.dtype).expand(n_systems, 3, 3) + ) + return average_kinetic_energy_tensor - stress + + +def _crescale_average_anisotropic_barostat_step( + state: NPTCRescaleState, + kT: torch.Tensor, + dt: torch.Tensor, + external_pressure: torch.Tensor, +) -> NPTCRescaleState: + volume = torch.det(state.cell) # shape: (n_systems,) + P_int = compute_average_pressure_tensor( + degrees_of_freedom=state.get_number_of_degrees_of_freedom() / 3, + kT=kT, + stress=state.stress, + volumes=volume, + ) + sqrt_vol = torch.sqrt(volume) + trace_P_int = torch.einsum("bii->b", P_int) + prefactor_random = torch.sqrt( + kT * state.isothermal_compressibility * dt / (4 * state.tau_p) + ) + prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p) + change_sqrt_vol = -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + new_sqrt_volume = sqrt_vol + change_sqrt_vol + ## Step 2: compute deformation matrix + prefactor_random_matrix = ( + torch.sqrt(2 * state.isothermal_compressibility * kT * dt / (3 * state.tau_p)) + / new_sqrt_volume + ) + a_tilde = -(state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] * ( + P_int + - trace_P_int[:, None, None] + / 3 + * torch.eye( + 3, device=state.positions.device, dtype=state.positions.dtype + ).expand_as(P_int) + ) + random_matrix = torch.randn( + state.n_systems, + 3, + 3, + device=state.positions.device, + dtype=state.positions.dtype, + ) + random_matrix_tilde = random_matrix - torch.einsum("bii->b", random_matrix)[ + :, None, None + ] / 3 * torch.eye( + 3, device=state.positions.device, dtype=state.positions.dtype + ).expand_as(random_matrix) + deformation_matrix = torch.matrix_exp( + a_tilde * dt + prefactor_random_matrix[:, None, None] * random_matrix_tilde + ) + deformation_matrix = rotate_gram_schmidt(deformation_matrix) + + ## Step 3: propagate sqrt(volume) for dt/2 + new_sqrt_volume += -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + rscaling = deformation_matrix * torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).view( + -1, 1, 1 + ) + + # Update positions and momenta (barostat + half momentum step) + state.positions = batch_matrix_vector( + rscaling[state.system_idx], state.positions + ) + batch_matrix_vector( + ( + torch.eye( + 3, device=state.positions.device, dtype=state.positions.dtype + ).expand_as(rscaling) + + rscaling + )[state.system_idx], + state.momenta, + ) * dt / (2 * state.masses.unsqueeze(-1)) + state.cell = rscaling @ state.cell + return state + + def _crescale_isotropic_barostat_step( state: NPTCRescaleState, kT: torch.Tensor, @@ -1781,7 +1902,7 @@ def _crescale_isotropic_barostat_step( rscaling = torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).unsqueeze(-1) state.positions = rscaling[state.system_idx] * state.positions + ( rscaling + 1 / rscaling - )[state.system_idx] * state.momenta * dt / (2 * state.masses.unsqueeze(-1)) + )[state.system_idx] * state.momenta * (0.5 * dt) / state.masses.unsqueeze(-1) state.momenta = (1 / rscaling)[state.system_idx] * state.momenta rscaling = rscaling.unsqueeze(-1) # make [n_systems, 1, 1] state.cell = rscaling * state.cell @@ -1840,7 +1961,8 @@ def npt_crescale_anisotropic_step( state = momentum_step(state, dt / 2) # Barostat step - state = _crescale_anisotropic_barostat_step(state, kT, dt, external_pressure) + # state = _crescale_anisotropic_barostat_step(state, kT, dt, external_pressure) + state = _crescale_average_anisotropic_barostat_step(state, kT, dt, external_pressure) # Forces model_output = model(state) From 35cf275ad3685e29cd963b71ea96217af561da8a Mon Sep 17 00:00:00 2001 From: thomasloux Date: Tue, 4 Nov 2025 11:20:00 +0100 Subject: [PATCH 15/19] correct shape for average pressure tensor and added new crescale step fn --- torch_sim/integrators/npt.py | 79 ++++++++++++++++++++++++++++++++---- 1 file changed, 70 insertions(+), 9 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 9a297cccb..8f6245711 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1773,22 +1773,18 @@ def compute_average_pressure_tensor( Args: degrees_of_freedom (torch.Tensor): Degrees of freedom of the system, shape (n_systems,) - kT (torch.Tensor): Thermal energy (k_B * T), shape (n_system + kT (torch.Tensor): Thermal energy (k_B * T), shape (n_systems,) stress (torch.Tensor): Stress tensor of the system, shape (n_systems, 3, 3) volumes (torch.Tensor): Volumes of the systems, shape (n_systems,) Returns: torch.Tensor: Instanteneous internal pressure tesnor [n_systems, 3, 3] """ - # Reshape for broadcasting - volumes = volumes.view(-1, 1, 1) # shape: (n_systems, 1, 1) - # Calculate virials: 2/V * (N_{atoms}k_B T / 2 - Virial_{tensor}) n_systems = stress.shape[0] + prefactor = degrees_of_freedom * kT / volumes # shape: (n_systems,) average_kinetic_energy_tensor = ( - degrees_of_freedom - * kT - / volumes + prefactor[:, None, None] * torch.eye(3, device=stress.device, dtype=stress.dtype).expand(n_systems, 3, 3) ) return average_kinetic_energy_tensor - stress @@ -1961,8 +1957,7 @@ def npt_crescale_anisotropic_step( state = momentum_step(state, dt / 2) # Barostat step - # state = _crescale_anisotropic_barostat_step(state, kT, dt, external_pressure) - state = _crescale_average_anisotropic_barostat_step(state, kT, dt, external_pressure) + state = _crescale_anisotropic_barostat_step(state, kT, dt, external_pressure) # Forces model_output = model(state) @@ -1977,6 +1972,72 @@ def npt_crescale_anisotropic_step( return _vrescale_update(state, tau, kT, dt / 2) +def npt_crescale_average_anisotropic_step( + state: NPTCRescaleState, + model: ModelInterface, + *, + dt: torch.Tensor, + kT: torch.Tensor, + external_pressure: torch.Tensor, + tau: torch.Tensor | None = None, +) -> NPTCRescaleState: + """Perform one NPT integration step with cell rescaling barostat. + + This function performs a single integration step for NPT dynamics using + a cell rescaling barostat. It updates particle positions, momenta, and + the simulation cell based on the target temperature and pressure. + + Trotter based splitting: + 1. Half Thermostat (velocity scaling) + 2. Half Update momenta with forces + 3. Barostat (cell rescaling) + 4. Update positions (from barostat + half momenta) + 5. Update forces with new positions and cell + 6. Compute forces + 7. Half Update momenta with forces + 8. Half Thermostat (velocity scaling) + + Only allow isotropic external stress. Can only run anisotropic + cell rescaling. + + Inspired from: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp + - Time reversible integrator + - Average kinetic energy, scaling only positions + + Args: + model (ModelInterface): Model to compute forces and energies + state (NPTCRescaleState): Current system state + dt (torch.Tensor): Integration timestep + kT (torch.Tensor): Target temperature + external_pressure (torch.Tensor): Target external pressure + tau (torch.Tensor | None): V-Rescale thermostat relaxation time. If None, + defaults to 100*dt + + Returns: + NPTCRescaleState: Updated state after one integration step + """ + # Note: would probably be better to have tau in NVTCRescaleState + if tau is None: + tau = 100 * dt + state = _vrescale_update(state, tau, kT, dt / 2) + + state = momentum_step(state, dt / 2) + + # Barostat step + state = _crescale_average_anisotropic_barostat_step(state, kT, dt, external_pressure) + + # Forces + model_output = model(state) + state.forces = model_output["forces"] + state.energy = model_output["energy"] + state.stress = model_output["stress"] + + # Final momentum step + state = momentum_step(state, dt / 2) + + # Final thermostat step + return _vrescale_update(state, tau, kT, dt / 2) + def npt_crescale_isotropic_step( state: NPTCRescaleState, model: ModelInterface, From cb304411b3e48f8510025a1d67083af8d7cb75df Mon Sep 17 00:00:00 2001 From: thomasloux Date: Tue, 4 Nov 2025 11:21:30 +0100 Subject: [PATCH 16/19] lint --- torch_sim/integrators/npt.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 8f6245711..cbfd60778 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1782,11 +1782,10 @@ def compute_average_pressure_tensor( """ # Calculate virials: 2/V * (N_{atoms}k_B T / 2 - Virial_{tensor}) n_systems = stress.shape[0] - prefactor = degrees_of_freedom * kT / volumes # shape: (n_systems,) - average_kinetic_energy_tensor = ( - prefactor[:, None, None] - * torch.eye(3, device=stress.device, dtype=stress.dtype).expand(n_systems, 3, 3) - ) + prefactor = degrees_of_freedom * kT / volumes # shape: (n_systems,) + average_kinetic_energy_tensor = prefactor[:, None, None] * torch.eye( + 3, device=stress.device, dtype=stress.dtype + ).expand(n_systems, 3, 3) return average_kinetic_energy_tensor - stress @@ -2038,6 +2037,7 @@ def npt_crescale_average_anisotropic_step( # Final thermostat step return _vrescale_update(state, tau, kT, dt / 2) + def npt_crescale_isotropic_step( state: NPTCRescaleState, model: ModelInterface, From 373ff460ca4be9e3892c3a47904834fed86375ac Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 6 Nov 2025 16:55:32 +0100 Subject: [PATCH 17/19] fix equation deformation for anisotropic barostat --- torch_sim/integrators/npt.py | 40 ++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index cbfd60778..740b2e729 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1630,42 +1630,38 @@ def get_number_of_degrees_of_freedom(self) -> torch.Tensor: def rotate_gram_schmidt(box: torch.Tensor) -> torch.Tensor: - """Convert a batch of 3x3 box matrices into upper-triangular form. + """Convert a batch of 3x3 box matrices into lower-triangular form. Args: box (torch.Tensor): shape [n_systems, 3, 3] Returns: - torch.Tensor: shape [n_systems, 3, 3] upper-triangular boxes + torch.Tensor: shape [n_systems, 3, 3] lower-triangular boxes """ - out = box.clone() + out = torch.zeros_like(box) # Columns (a, b, c) correspond to box vectors in column form - a = out[:, :, 0] - b = out[:, :, 1] - c = out[:, :, 2] + a = box[:, :, 0] + b = box[:, :, 1] + c = box[:, :, 2] - # --- Compute upper-triangular entries --- + # --- Compute the lower-triangular entries --- - # First vector (x-axis) + # a-axis out[:, 0, 0] = torch.norm(a, dim=1) - # Project b onto a - out[:, 0, 1] = torch.sum(a * b, dim=1) / out[:, 0, 0] - out[:, 1, 1] = torch.sqrt(torch.sum(b * b, dim=1) - out[:, 0, 1] ** 2) + # b projections + out[:, 1, 0] = torch.sum(a * b, dim=1) / out[:, 0, 0] + out[:, 1, 1] = torch.sqrt(torch.sum(b * b, dim=1) - out[:, 1, 0] ** 2) - # Project c onto a and b - out[:, 0, 2] = torch.sum(a * c, dim=1) / out[:, 0, 0] - out[:, 1, 2] = (torch.sum(b * c, dim=1) - out[:, 0, 2] * out[:, 0, 1]) / out[:, 1, 1] + # c projections + out[:, 2, 0] = torch.sum(a * c, dim=1) / out[:, 0, 0] + out[:, 2, 1] = (torch.sum(b * c, dim=1) - out[:, 2, 0] * out[:, 1, 0]) / out[:, 1, 1] out[:, 2, 2] = torch.sqrt( - torch.sum(c * c, dim=1) - out[:, 0, 2] ** 2 - out[:, 1, 2] ** 2 + torch.sum(c * c, dim=1) - out[:, 2, 0] ** 2 - out[:, 2, 1] ** 2 ) - # Upper-triangular form → zero lower elements - out[:, 1, 0] = 0.0 - out[:, 2, 0] = 0.0 - out[:, 2, 1] = 0.0 - + # Upper-triangular entries are 0 by initialization return out @@ -1755,7 +1751,7 @@ def _crescale_anisotropic_barostat_step( (vscaling + rscaling)[state.system_idx], state.momenta ) * dt / (2 * state.masses.unsqueeze(-1)) state.momenta = batch_matrix_vector(vscaling[state.system_idx], state.momenta) - state.cell = rscaling @ state.cell + state.cell = rscaling.mT @ state.cell return state @@ -1862,7 +1858,7 @@ def _crescale_average_anisotropic_barostat_step( )[state.system_idx], state.momenta, ) * dt / (2 * state.masses.unsqueeze(-1)) - state.cell = rscaling @ state.cell + state.cell = rscaling.mT @ state.cell return state From 1a4ae9901953b16a9a2d49d7252acb123b5b4510 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Wed, 12 Nov 2025 11:56:05 +0100 Subject: [PATCH 18/19] update to match paper equation for average anisotropic NPT --- torch_sim/integrators/npt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 740b2e729..3f30ba816 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1793,7 +1793,8 @@ def _crescale_average_anisotropic_barostat_step( ) -> NPTCRescaleState: volume = torch.det(state.cell) # shape: (n_systems,) P_int = compute_average_pressure_tensor( - degrees_of_freedom=state.get_number_of_degrees_of_freedom() / 3, + # Should it be degrees_of_freedom=state.get_number_of_degrees_of_freedom() / 3, + degrees_of_freedom=state.n_atoms_per_system, kT=kT, stress=state.stress, volumes=volume, From c40ad9e58a2d64051ca2a53eed06a34900ce4e2f Mon Sep 17 00:00:00 2001 From: thomasloux Date: Wed, 12 Nov 2025 18:20:37 +0100 Subject: [PATCH 19/19] add anisotropic NPT only for cell lengths (not angles) --- torch_sim/integrators/__init__.py | 6 +- torch_sim/integrators/npt.py | 142 +++++++++++++++++++++++++++++- 2 files changed, 142 insertions(+), 6 deletions(-) diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index fc91fd375..f93cb47e0 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -18,8 +18,10 @@ - Nosé-Hoover barostat integrator :func:`npt.npt_nose_hoover_step` from [3] - Isotropic C-Rescale barostat integrator :func:`npt.npt_crescale_isotropic_step` from [6, 8, 9] - - Anisotropic C-Rescale barostat integrator :func:`npt.npt_crescale_anisotropic_step` - from [7, 8, 9] + - C-Rescale barostat integrator :func:`npt.npt_crescale_anisotropic_step` + from [7, 8, 9]. Available implementations include isotropic and + anisotropic cell rescaling, allowing to change cell lengths, and potentially angles + as well. References: [1] Bussi G, Donadio D, Parrinello M. "Canonical sampling through velocity rescaling." diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 3f30ba816..968a981cd 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1755,6 +1755,69 @@ def _crescale_anisotropic_barostat_step( return state +def _crescale_independent_lengths_barostat_step( + state: NPTCRescaleState, + kT: torch.Tensor, + dt: torch.Tensor, + external_pressure: torch.Tensor, +) -> NPTCRescaleState: + volume = torch.det(state.cell) # shape: (n_systems,) + P_int = ts.quantities.compute_instantaneous_pressure_tensor( + momenta=state.momenta, + masses=state.masses, + system_idx=state.system_idx, + stress=state.stress, + volumes=volume, + ) + sqrt_vol = torch.sqrt(volume) + trace_P_int = torch.einsum("bii->b", P_int) + prefactor_random = torch.sqrt( + kT * state.isothermal_compressibility * dt / (4 * state.tau_p) + ) + prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p) + change_sqrt_vol = -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + new_sqrt_volume = sqrt_vol + change_sqrt_vol + ## Step 2: compute deformation matrix + prefactor_random_matrix = ( + torch.sqrt(2 * state.isothermal_compressibility * kT * dt / (3 * state.tau_p)) + / new_sqrt_volume + ) + # Note: it corresponds to using a diagonal isothermal compressibility tensor + P_int_diagonal = torch.diagonal(P_int, dim1=-2, dim2=-1) + a_tilde = -(state.isothermal_compressibility / (3 * state.tau_p))[:, None] * ( + P_int_diagonal - trace_P_int[:, None] / 3 + ) + + random_matrix = torch.randn( + state.n_systems, + 3, + device=state.positions.device, + dtype=state.positions.dtype, + ) + random_matrix_tilde = random_matrix - torch.mean(random_matrix, dim=1, keepdim=True) + deformation_matrix = torch.exp( + a_tilde * dt + prefactor_random_matrix[:, None] * random_matrix_tilde + ) + + ## Step 3: propagate sqrt(volume) for dt/2 + new_sqrt_volume += -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + rscaling = deformation_matrix * torch.pow( + (new_sqrt_volume / sqrt_vol), 2 / 3 + ).unsqueeze(-1) + + # Update positions and momenta (barostat + half momentum step) + state.positions = rscaling[state.system_idx] * state.positions + ( + rscaling + 1 / rscaling + )[state.system_idx] * state.momenta * dt / (2 * state.masses.unsqueeze(-1)) + state.momenta = (1 / rscaling)[state.system_idx] * state.momenta + state.cell = torch.diag_embed(rscaling) @ state.cell + return state + + def compute_average_pressure_tensor( *, degrees_of_freedom: torch.Tensor, @@ -1926,8 +1989,9 @@ def npt_crescale_anisotropic_step( 7. Half Update momenta with forces 8. Half Thermostat (velocity scaling) - Only allow isotropic external stress. Can only run anisotropic - cell rescaling. + Only allow isotropic external stress. This method performs anisotropic + cell rescaling. Lengths and angles can change independently. Based on + pressure using kinetic energy. Positions and momenta are scaled when scaling the cell. Inspired from: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp - Time reversible integrator @@ -1968,6 +2032,74 @@ def npt_crescale_anisotropic_step( return _vrescale_update(state, tau, kT, dt / 2) +def npt_crescale_independent_lengths_step( + state: NPTCRescaleState, + model: ModelInterface, + *, + dt: torch.Tensor, + kT: torch.Tensor, + external_pressure: torch.Tensor, + tau: torch.Tensor | None = None, +) -> NPTCRescaleState: + """Perform one NPT integration step with cell rescaling barostat. + + This function performs a single integration step for NPT dynamics using + a cell rescaling barostat. It updates particle positions, momenta, and + the simulation cell based on the target temperature and pressure. + + Trotter based splitting: + 1. Half Thermostat (velocity scaling) + 2. Half Update momenta with forces + 3. Barostat (cell rescaling) + 4. Update positions (from barostat + half momenta) + 5. Update forces with new positions and cell + 6. Compute forces + 7. Half Update momenta with forces + 8. Half Thermostat (velocity scaling) + + Only allow isotropic external stress. + This method has 3 degrees of freedom for each cell length, + allowing independent scaling of each cell vector. + + Inspired from: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp + - Time reversible integrator + - Instantaneous kinetic energy (not not the average from equipartition) + + Args: + model (ModelInterface): Model to compute forces and energies + state (NPTCRescaleState): Current system state + dt (torch.Tensor): Integration timestep + kT (torch.Tensor): Target temperature + external_pressure (torch.Tensor): Target external pressure + tau (torch.Tensor | None): V-Rescale thermostat relaxation time. If None, + defaults to 100*dt + + Returns: + NPTCRescaleState: Updated state after one integration step + """ + # Note: would probably be better to have tau in NVTCRescaleState + if tau is None: + tau = 100 * dt + state = _vrescale_update(state, tau, kT, dt / 2) + + state = momentum_step(state, dt / 2) + + # Barostat step + state = _crescale_independent_lengths_barostat_step(state, kT, dt, external_pressure) + + # Forces + model_output = model(state) + state.forces = model_output["forces"] + state.energy = model_output["energy"] + state.stress = model_output["stress"] + + # Final momentum step + state = momentum_step(state, dt / 2) + + # Final thermostat step + return _vrescale_update(state, tau, kT, dt / 2) + + def npt_crescale_average_anisotropic_step( state: NPTCRescaleState, model: ModelInterface, @@ -1993,8 +2125,10 @@ def npt_crescale_average_anisotropic_step( 7. Half Update momenta with forces 8. Half Thermostat (velocity scaling) - Only allow isotropic external stress. Can only run anisotropic - cell rescaling. + Only allow isotropic external stress. This method performs anisotropic + cell rescaling. Lengths and angles can change independently. Based on + pressure using average kinetic energy from equipartition theorem. + Only positions are scaled when scaling the cell. Inspired from: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp - Time reversible integrator