diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 89eb5fc10..6ddafa064 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -485,6 +485,206 @@ def test_nvt_nose_hoover_multi_kt( assert invariant_std / invariant_traj.mean() < 0.1 +def test_nvt_vrescale(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): + n_steps = 100 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature + + # Initialize integrator + state = ts.nvt_vrescale_init( + state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 + ) + energies = [] + temperatures = [] + for _step in range(n_steps): + state = ts.nvt_vrescale_step(model=lj_model, state=state, dt=dt, kT=kT) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + for mean_temp in mean_temps: + assert ( + abs(mean_temp - kT.item() / MetalUnits.temperature) < 100.0 + ) # Allow for thermal fluctuations + + # Check energy is stable for each trajectory + for traj in energies_list: + energy_std = torch.tensor(traj).std() + assert energy_std < 1.0 # Adjust threshold as needed + + # Check positions and momenta have correct shapes + n_atoms = 8 + + # Verify the two systems remain distinct + pos_diff = torch.norm( + state.positions[:n_atoms].mean(0) - state.positions[n_atoms:].mean(0) + ) + assert pos_diff > 0.0001 # Systems should remain separated + + +def test_npt_anisotropic_crescale( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: + n_steps = 200 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(100.0, dtype=DTYPE) * MetalUnits.temperature + external_pressure = torch.tensor(0.0, dtype=DTYPE) * MetalUnits.pressure + tau_p = torch.tensor(0.1, dtype=DTYPE) + isothermal_compressibility = torch.tensor(1e-4, dtype=DTYPE) + + # Initialize integrator using new direct API + state = ts.npt_crescale_init( + state=ar_double_sim_state, + model=lj_model, + dt=dt, + kT=kT, + tau_p=tau_p, + isothermal_compressibility=isothermal_compressibility, + seed=42, + ) + + # Run dynamics for several steps + energies = [] + temperatures = [] + for _step in range(n_steps): + state = ts.npt_crescale_anisotropic_step( + state=state, + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + ) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + for mean_temp in mean_temps: + assert ( + abs(mean_temp - kT.item() / MetalUnits.temperature) < 150.0 + ) # Allow for thermal fluctuations + + # Check energy is stable for each trajectory + for traj in energies_list: + energy_std = torch.tensor(traj).std() + assert energy_std < 1.0 # Adjust threshold as needed + + # Check positions and momenta have correct shapes + n_atoms = 8 + + # Verify the two systems remain distinct + pos_diff = torch.norm( + state.positions[:n_atoms].mean(0) - state.positions[n_atoms:].mean(0) + ) + assert pos_diff > 0.0001 # Systems should remain separated + + +def test_npt_isotropic_crescale( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: + n_steps = 200 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(100.0, dtype=DTYPE) * MetalUnits.temperature + external_pressure = torch.tensor(0.0, dtype=DTYPE) * MetalUnits.pressure + tau_p = torch.tensor(0.1, dtype=DTYPE) + isothermal_compressibility = torch.tensor(1e-4, dtype=DTYPE) + + # Initialize integrator using new direct API + state = ts.npt_crescale_init( + state=ar_double_sim_state, + model=lj_model, + dt=dt, + kT=kT, + tau_p=tau_p, + isothermal_compressibility=isothermal_compressibility, + seed=42, + ) + + # Run dynamics for several steps + energies = [] + temperatures = [] + for _step in range(n_steps): + state = ts.npt_crescale_isotropic_step( + state=state, + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + ) + + # Calculate instantaneous temperature from kinetic energy + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + for mean_temp in mean_temps: + assert ( + abs(mean_temp - kT.item() / MetalUnits.temperature) < 150.0 + ) # Allow for thermal fluctuations + + # Check energy is stable for each trajectory + for traj in energies_list: + energy_std = torch.tensor(traj).std() + assert energy_std < 1.0 # Adjust threshold as needed + + # Check positions and momenta have correct shapes + n_atoms = 8 + + # Verify the two systems remain distinct + pos_diff = torch.norm( + state.positions[:n_atoms].mean(0) - state.positions[n_atoms:].mean(0) + ) + assert pos_diff > 0.0001 # Systems should remain separated + + def test_npt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): dtype = torch.float64 n_steps = 100 diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index 3bc1bc922..f632cbfa8 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -34,10 +34,15 @@ nvt_nose_hoover_init, nvt_nose_hoover_invariant, nvt_nose_hoover_step, + nvt_vrescale_init, + nvt_vrescale_step, ) from torch_sim.integrators.npt import ( NPTLangevinState, NPTNoseHooverState, + npt_crescale_anisotropic_step, + npt_crescale_init, + npt_crescale_isotropic_step, npt_langevin_init, npt_langevin_step, npt_nose_hoover_init, diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index 258efcd4b..bae4b5cd3 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -8,23 +8,42 @@ NVE: - Velocity Verlet integrator for constant energy simulations :func:`nve.nve_step` NVT: + - Velocity Rescaling thermostat integrator + :func:`nvt.nvt_vrescale_step` [1] - Langevin thermostat integrator :func:`nvt.nvt_langevin_step` - using BAOAB scheme [1] - - Nosé-Hoover thermostat integrator :func:`nvt.nvt_nose_hoover_step` from [2] + using BAOAB scheme [2] + - Nosé-Hoover thermostat integrator :func:`nvt.nvt_nose_hoover_step` from [3] NPT: - - Langevin barostat integrator :func:`npt.npt_langevin_step` [3, 4] - - Nosé-Hoover barostat integrator :func:`npt.npt_nose_hoover_step` from [2] + - Langevin barostat integrator :func:`npt.npt_langevin_step` [4, 5] + - Nosé-Hoover barostat integrator :func:`npt.npt_nose_hoover_step` from [3] + - Isotropic C-Rescale barostat integrator :func:`npt.npt_crescale_isotropic_step` + from [6, 8, 9] + - C-Rescale barostat integrator :func:`npt.npt_crescale_anisotropic_step` + from [7, 8, 9]. Available implementations include isotropic and + anisotropic cell rescaling, allowing to change cell lengths, and potentially angles + as well. References: - [1] Leimkuhler B, Matthews C.2016 Efficient molecular dynamics using geodesic + [1] Bussi G, Donadio D, Parrinello M. "Canonical sampling through velocity rescaling." + The Journal of chemical physics, 126(1), 014101 (2007). + [2] Leimkuhler B, Matthews C.2016 Efficient molecular dynamics using geodesic integration and solvent-solute splitting. Proc. R. Soc. A 472: 20160138 - [2] Martyna, G. J., Tuckerman, M. E., Tobias, D. J., & Klein, M. L. (1996). + [3] Martyna, G. J., Tuckerman, M. E., Tobias, D. J., & Klein, M. L. (1996). Explicit reversible integrators for extended systems dynamics. Molecular Physics, 87(5), 1117-1157. - [3] Grønbech-Jensen, N., & Farago, O. (2014). + [4] Grønbech-Jensen, N., & Farago, O. (2014). Constant pressure and temperature discrete-time Langevin molecular dynamics. The Journal of chemical physics, 141(19). - [4] LAMMPS: https://docs.lammps.org/fix_press_langevin.html + [5] LAMMPS: https://docs.lammps.org/fix_press_langevin.html + [6] Bernetti, Mattia, and Giovanni Bussi. + "Pressure control using stochastic cell rescaling." + The Journal of Chemical Physics 153.11 (2020). + [7] Del Tatto, Vittorio, et al. "Molecular dynamics of solids at + constant pressure and stress using anisotropic stochastic cell rescaling." + Applied Sciences 12.3 (2022): 1139. + [8] Bussi Anisotropic C-Rescale SimpleMD implementation: + https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp + [9] Supplementary Information for [6]. Examples: @@ -53,6 +72,9 @@ from .npt import ( NPTLangevinState, NPTNoseHooverState, + npt_crescale_anisotropic_step, + npt_crescale_init, + npt_crescale_isotropic_step, npt_langevin_init, npt_langevin_step, npt_nose_hoover_init, @@ -67,6 +89,8 @@ nvt_nose_hoover_init, nvt_nose_hoover_invariant, nvt_nose_hoover_step, + nvt_vrescale_init, + nvt_vrescale_step, ) @@ -79,11 +103,16 @@ class Integrator(StrEnum): Available options: - ``nve``: Constant energy (microcanonical) ensemble. + - ``nvt_vrescale``: Velocity rescaling thermostat for constant temperature. - ``nvt_langevin``: Langevin thermostat for constant temperature. - ``nvt_nose_hoover``: Nosé-Hoover thermostat for constant temperature. - ``npt_langevin``: Langevin barostat for constant temperature and pressure. - ``npt_nose_hoover``: Nosé-Hoover barostat for constant temperature and constant pressure. + - ``npt_isotropic_crescale``: Isotropic C-Rescale barostat for constant + temperature and pressure with fixed cell shape. + - ``npt_anisotropic_crescale``: Anisotropic C-Rescale barostat for constant + temperature and pressure with variable cell shape. Example: >>> integrator = Integrator.nvt_langevin @@ -93,10 +122,13 @@ class Integrator(StrEnum): """ nve = "nve" + nvt_vrescale = "nvt_vrescale" nvt_langevin = "nvt_langevin" nvt_nose_hoover = "nvt_nose_hoover" npt_langevin = "npt_langevin" npt_nose_hoover = "npt_nose_hoover" + npt_isotropic_crescale = "npt_isotropic_crescale" + npt_anisotropic_crescale = "npt_anisotropic_crescale" #: Integrator registry - maps integrator names to (init_fn, step_fn) pairs. @@ -116,18 +148,27 @@ class Integrator(StrEnum): #: The available integrators are: #: #: - ``Integrator.nve``: Velocity Verlet (microcanonical) +#: - ``Integrator.nvt_vrescale``: V-Rescale thermostat #: - ``Integrator.nvt_langevin``: Langevin thermostat #: - ``Integrator.nvt_nose_hoover``: Nosé-Hoover thermostat #: - ``Integrator.npt_langevin``: Langevin barostat #: - ``Integrator.npt_nose_hoover``: Nosé-Hoover barostat +#: - ``Integrator.npt_isotropic_crescale``: Isotropic NPT C-Rescale barostat +#: - ``Integrator.npt_anisotropic_crescale``: Anisotropic NPT C-Rescale barostat #: #: :type: dict[Integrator, tuple[Callable[..., Any], Callable[..., Any]]] INTEGRATOR_REGISTRY: Final[ dict[Integrator, tuple[Callable[..., Any], Callable[..., Any]]] ] = { Integrator.nve: (nve_init, nve_step), + Integrator.nvt_vrescale: (nvt_vrescale_init, nvt_vrescale_step), Integrator.nvt_langevin: (nvt_langevin_init, nvt_langevin_step), Integrator.nvt_nose_hoover: (nvt_nose_hoover_init, nvt_nose_hoover_step), Integrator.npt_langevin: (npt_langevin_init, npt_langevin_step), Integrator.npt_nose_hoover: (npt_nose_hoover_init, npt_nose_hoover_step), + Integrator.npt_isotropic_crescale: (npt_crescale_init, npt_crescale_isotropic_step), + Integrator.npt_anisotropic_crescale: ( + npt_crescale_init, + npt_crescale_anisotropic_step, + ), } diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index b03fb7a30..1196468d7 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -7,7 +7,7 @@ from torch_sim import transforms from torch_sim.models.interface import ModelInterface -from torch_sim.quantities import calc_temperature +from torch_sim.quantities import calc_kT, calc_temperature from torch_sim.state import SimState from torch_sim.units import MetalUnits @@ -76,6 +76,19 @@ def calc_temperature( units=units, ) + def calc_kT(self) -> torch.Tensor: # noqa: N802 + """Calculate kT from momenta, masses, and system indices. + + Returns: + torch.Tensor: Calculated kT + """ + return calc_kT( + masses=self.masses, + momenta=self.momenta, + system_idx=self.system_idx, + dof_per_system=self.get_number_of_degrees_of_freedom(), + ) + def calculate_momenta( positions: torch.Tensor, diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index db6e2b153..300035245 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -13,7 +13,9 @@ NoseHooverChainFns, calculate_momenta, construct_nose_hoover_chain, + momentum_step, ) +from torch_sim.integrators.nvt import _vrescale_update from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState from torch_sim.typing import StateDict @@ -55,8 +57,6 @@ class NPTLangevinState(MDState): """ # System state variables - energy: torch.Tensor - forces: torch.Tensor stress: torch.Tensor alpha: torch.Tensor @@ -1598,3 +1598,724 @@ def npt_nose_hoover_invariant( e_tot += torch.square(cell_momentum) / (2 * state.cell_mass) return e_tot + + +@dataclass(kw_only=True) +class NPTCRescaleState(MDState): + """State for NPT ensemble with cell rescaling barostat. + + This class extends the MDState to include variables and properties + specific to the NPT ensemble with a cell rescaling barostat. + """ + + # System state variables + stress: torch.Tensor + isothermal_compressibility: torch.Tensor # shape: [n_systems] + tau_p: torch.Tensor # shape: [n_systems] + + _system_attributes = MDState._system_attributes | { # noqa: SLF001 + "stress", + "isothermal_compressibility", + "tau_p", + } + + def get_number_of_degrees_of_freedom(self) -> torch.Tensor: + """Calculate degrees of freedom for each system in the batch. + + Returns: + torch.Tensor: Degrees of freedom for each system, shape [n_systems] + """ + # Subtract 3 for center of mass motion + return super().get_number_of_degrees_of_freedom() - 3 + + +def rotate_gram_schmidt(box: torch.Tensor) -> torch.Tensor: + """Convert a batch of 3x3 box matrices into lower-triangular form. + + Args: + box (torch.Tensor): shape [n_systems, 3, 3] + + Returns: + torch.Tensor: shape [n_systems, 3, 3] lower-triangular boxes + """ + out = torch.zeros_like(box) + + # Columns (a, b, c) correspond to box vectors in column form + a = box[:, :, 0] + b = box[:, :, 1] + c = box[:, :, 2] + + # --- Compute the lower-triangular entries --- + + # a-axis + out[:, 0, 0] = torch.norm(a, dim=1) + + # b projections + out[:, 1, 0] = torch.sum(a * b, dim=1) / out[:, 0, 0] + out[:, 1, 1] = torch.sqrt(torch.sum(b * b, dim=1) - out[:, 1, 0] ** 2) + + # c projections + out[:, 2, 0] = torch.sum(a * c, dim=1) / out[:, 0, 0] + out[:, 2, 1] = (torch.sum(b * c, dim=1) - out[:, 2, 0] * out[:, 1, 0]) / out[:, 1, 1] + out[:, 2, 2] = torch.sqrt( + torch.sum(c * c, dim=1) - out[:, 2, 0] ** 2 - out[:, 2, 1] ** 2 + ) + + # Upper-triangular entries are 0 by initialization + return out + + +def batch_matrix_vector( + matrices: torch.Tensor, + vectors: torch.Tensor, +) -> torch.Tensor: + """Perform batch matrix-vector multiplication. + + Args: + matrices (torch.Tensor): shape [n_systems, n, n] + vectors (torch.Tensor): shape [n_systems, n, m] + + Returns: + torch.Tensor: shape [n_systems, n, m] result of multiplication + """ + return torch.matmul(matrices, vectors.unsqueeze(-1)).squeeze(-1) + + +def _crescale_anisotropic_barostat_step( + state: NPTCRescaleState, + kT: torch.Tensor, + dt: torch.Tensor, + external_pressure: torch.Tensor, +) -> NPTCRescaleState: + volume = torch.det(state.cell) # shape: (n_systems,) + P_int = ts.quantities.compute_instantaneous_pressure_tensor( + momenta=state.momenta, + masses=state.masses, + system_idx=state.system_idx, + stress=state.stress, + volumes=volume, + ) + sqrt_vol = torch.sqrt(volume) + trace_P_int = torch.einsum("bii->b", P_int) + prefactor_random = torch.sqrt( + kT * state.isothermal_compressibility * dt / (4 * state.tau_p) + ) + prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p) + change_sqrt_vol = -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + new_sqrt_volume = sqrt_vol + change_sqrt_vol + ## Step 2: compute deformation matrix + prefactor_random_matrix = ( + torch.sqrt(2 * state.isothermal_compressibility * kT * dt / (3 * state.tau_p)) + / new_sqrt_volume + ) + a_tilde = -(state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] * ( + P_int + - trace_P_int[:, None, None] + / 3 + * torch.eye( + 3, device=state.positions.device, dtype=state.positions.dtype + ).expand_as(P_int) + ) + random_matrix = torch.randn( + state.n_systems, + 3, + 3, + device=state.positions.device, + dtype=state.positions.dtype, + ) + random_matrix_tilde = random_matrix - torch.einsum("bii->b", random_matrix)[ + :, None, None + ] / 3 * torch.eye( + 3, device=state.positions.device, dtype=state.positions.dtype + ).expand_as(random_matrix) + deformation_matrix = torch.matrix_exp( + a_tilde * dt + prefactor_random_matrix[:, None, None] * random_matrix_tilde + ) + deformation_matrix = rotate_gram_schmidt(deformation_matrix) + + ## Step 3: propagate sqrt(volume) for dt/2 + new_sqrt_volume += -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + rscaling = deformation_matrix * torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).view( + -1, 1, 1 + ) + vscaling = torch.inverse(rscaling).transpose(-2, -1) + + # Update positions and momenta (barostat + half momentum step) + state.positions = batch_matrix_vector( + rscaling[state.system_idx], state.positions + ) + batch_matrix_vector( + (vscaling + rscaling)[state.system_idx], state.momenta + ) * dt / (2 * state.masses.unsqueeze(-1)) + state.momenta = batch_matrix_vector(vscaling[state.system_idx], state.momenta) + state.cell = rscaling.mT @ state.cell + return state + + +def _crescale_independent_lengths_barostat_step( + state: NPTCRescaleState, + kT: torch.Tensor, + dt: torch.Tensor, + external_pressure: torch.Tensor, +) -> NPTCRescaleState: + volume = torch.det(state.cell) # shape: (n_systems,) + P_int = ts.quantities.compute_instantaneous_pressure_tensor( + momenta=state.momenta, + masses=state.masses, + system_idx=state.system_idx, + stress=state.stress, + volumes=volume, + ) + sqrt_vol = torch.sqrt(volume) + trace_P_int = torch.einsum("bii->b", P_int) + prefactor_random = torch.sqrt( + kT * state.isothermal_compressibility * dt / (4 * state.tau_p) + ) + prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p) + change_sqrt_vol = -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + new_sqrt_volume = sqrt_vol + change_sqrt_vol + ## Step 2: compute deformation matrix + prefactor_random_matrix = ( + torch.sqrt(2 * state.isothermal_compressibility * kT * dt / (3 * state.tau_p)) + / new_sqrt_volume + ) + # Note: it corresponds to using a diagonal isothermal compressibility tensor + P_int_diagonal = torch.diagonal(P_int, dim1=-2, dim2=-1) + a_tilde = -(state.isothermal_compressibility / (3 * state.tau_p))[:, None] * ( + P_int_diagonal - trace_P_int[:, None] / 3 + ) + + random_matrix = torch.randn( + state.n_systems, + 3, + device=state.positions.device, + dtype=state.positions.dtype, + ) + random_matrix_tilde = random_matrix - torch.mean(random_matrix, dim=1, keepdim=True) + deformation_matrix = torch.exp( + a_tilde * dt + prefactor_random_matrix[:, None] * random_matrix_tilde + ) + + ## Step 3: propagate sqrt(volume) for dt/2 + new_sqrt_volume += -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + rscaling = deformation_matrix * torch.pow( + (new_sqrt_volume / sqrt_vol), 2 / 3 + ).unsqueeze(-1) + + # Update positions and momenta (barostat + half momentum step) + state.positions = rscaling[state.system_idx] * state.positions + ( + rscaling + 1 / rscaling + )[state.system_idx] * state.momenta * dt / (2 * state.masses.unsqueeze(-1)) + state.momenta = (1 / rscaling)[state.system_idx] * state.momenta + state.cell = torch.diag_embed(rscaling) @ state.cell + return state + + +def compute_average_pressure_tensor( + *, + degrees_of_freedom: torch.Tensor, + kT: torch.Tensor, + stress: torch.Tensor, + volumes: torch.Tensor, +) -> torch.Tensor: + """Compute forces on the cell for NPT dynamics. + + This function calculates the instantaneous internal pressure tensor. + + Args: + degrees_of_freedom (torch.Tensor): Degrees of freedom of + the system, shape (n_systems,) + kT (torch.Tensor): Thermal energy (k_B * T), shape (n_systems,) + stress (torch.Tensor): Stress tensor of the system, shape (n_systems, 3, 3) + volumes (torch.Tensor): Volumes of the systems, shape (n_systems,) + + Returns: + torch.Tensor: Instanteneous internal pressure tesnor [n_systems, 3, 3] + """ + # Calculate virials: 2/V * (N_{atoms}k_B T / 2 - Virial_{tensor}) + n_systems = stress.shape[0] + prefactor = degrees_of_freedom * kT / volumes # shape: (n_systems,) + average_kinetic_energy_tensor = prefactor[:, None, None] * torch.eye( + 3, device=stress.device, dtype=stress.dtype + ).expand(n_systems, 3, 3) + return average_kinetic_energy_tensor - stress + + +def _crescale_average_anisotropic_barostat_step( + state: NPTCRescaleState, + kT: torch.Tensor, + dt: torch.Tensor, + external_pressure: torch.Tensor, +) -> NPTCRescaleState: + volume = torch.det(state.cell) # shape: (n_systems,) + P_int = compute_average_pressure_tensor( + # Should it be degrees_of_freedom=state.get_number_of_degrees_of_freedom() / 3, + degrees_of_freedom=state.n_atoms_per_system, + kT=kT, + stress=state.stress, + volumes=volume, + ) + sqrt_vol = torch.sqrt(volume) + trace_P_int = torch.einsum("bii->b", P_int) + prefactor_random = torch.sqrt( + kT * state.isothermal_compressibility * dt / (4 * state.tau_p) + ) + prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p) + change_sqrt_vol = -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + new_sqrt_volume = sqrt_vol + change_sqrt_vol + ## Step 2: compute deformation matrix + prefactor_random_matrix = ( + torch.sqrt(2 * state.isothermal_compressibility * kT * dt / (3 * state.tau_p)) + / new_sqrt_volume + ) + a_tilde = -(state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] * ( + P_int + - trace_P_int[:, None, None] + / 3 + * torch.eye( + 3, device=state.positions.device, dtype=state.positions.dtype + ).expand_as(P_int) + ) + random_matrix = torch.randn( + state.n_systems, + 3, + 3, + device=state.positions.device, + dtype=state.positions.dtype, + ) + random_matrix_tilde = random_matrix - torch.einsum("bii->b", random_matrix)[ + :, None, None + ] / 3 * torch.eye( + 3, device=state.positions.device, dtype=state.positions.dtype + ).expand_as(random_matrix) + deformation_matrix = torch.matrix_exp( + a_tilde * dt + prefactor_random_matrix[:, None, None] * random_matrix_tilde + ) + deformation_matrix = rotate_gram_schmidt(deformation_matrix) + + ## Step 3: propagate sqrt(volume) for dt/2 + new_sqrt_volume += -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt / 2 + prefactor_random * torch.randn_like(sqrt_vol) + rscaling = deformation_matrix * torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).view( + -1, 1, 1 + ) + + # Update positions and momenta (barostat + half momentum step) + state.positions = batch_matrix_vector( + rscaling[state.system_idx], state.positions + ) + batch_matrix_vector( + ( + torch.eye( + 3, device=state.positions.device, dtype=state.positions.dtype + ).expand_as(rscaling) + + rscaling + )[state.system_idx], + state.momenta, + ) * dt / (2 * state.masses.unsqueeze(-1)) + state.cell = rscaling.mT @ state.cell + return state + + +def _crescale_isotropic_barostat_step( + state: NPTCRescaleState, + kT: torch.Tensor, + dt: torch.Tensor, + external_pressure: torch.Tensor, +) -> NPTCRescaleState: + volume = torch.det(state.cell) # shape: (n_systems,) + P_int = ts.quantities.compute_instantaneous_pressure_tensor( + momenta=state.momenta, + masses=state.masses, + system_idx=state.system_idx, + stress=state.stress, + volumes=volume, + ) + sqrt_vol = torch.sqrt(volume) + trace_P_int = torch.einsum("bii->b", P_int) + prefactor_random = torch.sqrt( + kT * state.isothermal_compressibility * dt / (4 * state.tau_p) + ) + prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p) + change_sqrt_vol = -prefactor * ( + external_pressure - trace_P_int / 3 - kT / (2 * volume) + ) * dt + prefactor_random * torch.randn_like(sqrt_vol) + new_sqrt_volume = sqrt_vol + change_sqrt_vol + + # Update positions and momenta (barostat + half momentum step) + # SI (S13ab): notice there is a typo in the SI where q_i(t) + # should be scaled as well by rscaling + rscaling = torch.pow((new_sqrt_volume / sqrt_vol), 2 / 3).unsqueeze(-1) + state.positions = rscaling[state.system_idx] * state.positions + ( + rscaling + 1 / rscaling + )[state.system_idx] * state.momenta * (0.5 * dt) / state.masses.unsqueeze(-1) + state.momenta = (1 / rscaling)[state.system_idx] * state.momenta + rscaling = rscaling.unsqueeze(-1) # make [n_systems, 1, 1] + state.cell = rscaling * state.cell + return state + + +def npt_crescale_anisotropic_step( + state: NPTCRescaleState, + model: ModelInterface, + *, + dt: torch.Tensor, + kT: torch.Tensor, + external_pressure: torch.Tensor, + tau: torch.Tensor | None = None, +) -> NPTCRescaleState: + """Perform one NPT integration step with cell rescaling barostat. + + This function performs a single integration step for NPT dynamics using + a cell rescaling barostat. It updates particle positions, momenta, and + the simulation cell based on the target temperature and pressure. + + Trotter based splitting: + 1. Half Thermostat (velocity scaling) + 2. Half Update momenta with forces + 3. Barostat (cell rescaling) + 4. Update positions (from barostat + half momenta) + 5. Update forces with new positions and cell + 6. Compute forces + 7. Half Update momenta with forces + 8. Half Thermostat (velocity scaling) + + Only allow isotropic external stress. This method performs anisotropic + cell rescaling. Lengths and angles can change independently. Based on + pressure using kinetic energy. Positions and momenta are scaled when scaling the cell. + + Inspired from: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp + - Time reversible integrator + - Instantaneous kinetic energy (not not the average from equipartition) + + Args: + model (ModelInterface): Model to compute forces and energies + state (NPTCRescaleState): Current system state + dt (torch.Tensor): Integration timestep + kT (torch.Tensor): Target temperature + external_pressure (torch.Tensor): Target external pressure + tau (torch.Tensor | None): V-Rescale thermostat relaxation time. If None, + defaults to 100*dt + + Returns: + NPTCRescaleState: Updated state after one integration step + """ + # Note: would probably be better to have tau in NVTCRescaleState + if tau is None: + tau = 100 * dt + state = _vrescale_update(state, tau, kT, dt / 2) + + state = momentum_step(state, dt / 2) + + # Barostat step + state = _crescale_anisotropic_barostat_step(state, kT, dt, external_pressure) + + # Forces + model_output = model(state) + state.forces = model_output["forces"] + state.energy = model_output["energy"] + state.stress = model_output["stress"] + + # Final momentum step + state = momentum_step(state, dt / 2) + + # Final thermostat step + return _vrescale_update(state, tau, kT, dt / 2) + + +def npt_crescale_independent_lengths_step( + state: NPTCRescaleState, + model: ModelInterface, + *, + dt: torch.Tensor, + kT: torch.Tensor, + external_pressure: torch.Tensor, + tau: torch.Tensor | None = None, +) -> NPTCRescaleState: + """Perform one NPT integration step with cell rescaling barostat. + + This function performs a single integration step for NPT dynamics using + a cell rescaling barostat. It updates particle positions, momenta, and + the simulation cell based on the target temperature and pressure. + + Trotter based splitting: + 1. Half Thermostat (velocity scaling) + 2. Half Update momenta with forces + 3. Barostat (cell rescaling) + 4. Update positions (from barostat + half momenta) + 5. Update forces with new positions and cell + 6. Compute forces + 7. Half Update momenta with forces + 8. Half Thermostat (velocity scaling) + + Only allow isotropic external stress. + This method has 3 degrees of freedom for each cell length, + allowing independent scaling of each cell vector. + + Inspired from: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp + - Time reversible integrator + - Instantaneous kinetic energy (not not the average from equipartition) + + Args: + model (ModelInterface): Model to compute forces and energies + state (NPTCRescaleState): Current system state + dt (torch.Tensor): Integration timestep + kT (torch.Tensor): Target temperature + external_pressure (torch.Tensor): Target external pressure + tau (torch.Tensor | None): V-Rescale thermostat relaxation time. If None, + defaults to 100*dt + + Returns: + NPTCRescaleState: Updated state after one integration step + """ + # Note: would probably be better to have tau in NVTCRescaleState + if tau is None: + tau = 100 * dt + state = _vrescale_update(state, tau, kT, dt / 2) + + state = momentum_step(state, dt / 2) + + # Barostat step + state = _crescale_independent_lengths_barostat_step(state, kT, dt, external_pressure) + + # Forces + model_output = model(state) + state.forces = model_output["forces"] + state.energy = model_output["energy"] + state.stress = model_output["stress"] + + # Final momentum step + state = momentum_step(state, dt / 2) + + # Final thermostat step + return _vrescale_update(state, tau, kT, dt / 2) + + +def npt_crescale_average_anisotropic_step( + state: NPTCRescaleState, + model: ModelInterface, + *, + dt: torch.Tensor, + kT: torch.Tensor, + external_pressure: torch.Tensor, + tau: torch.Tensor | None = None, +) -> NPTCRescaleState: + """Perform one NPT integration step with cell rescaling barostat. + + This function performs a single integration step for NPT dynamics using + a cell rescaling barostat. It updates particle positions, momenta, and + the simulation cell based on the target temperature and pressure. + + Trotter based splitting: + 1. Half Thermostat (velocity scaling) + 2. Half Update momenta with forces + 3. Barostat (cell rescaling) + 4. Update positions (from barostat + half momenta) + 5. Update forces with new positions and cell + 6. Compute forces + 7. Half Update momenta with forces + 8. Half Thermostat (velocity scaling) + + Only allow isotropic external stress. This method performs anisotropic + cell rescaling. Lengths and angles can change independently. Based on + pressure using average kinetic energy from equipartition theorem. + Only positions are scaled when scaling the cell. + + Inspired from: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp + - Time reversible integrator + - Average kinetic energy, scaling only positions + + Args: + model (ModelInterface): Model to compute forces and energies + state (NPTCRescaleState): Current system state + dt (torch.Tensor): Integration timestep + kT (torch.Tensor): Target temperature + external_pressure (torch.Tensor): Target external pressure + tau (torch.Tensor | None): V-Rescale thermostat relaxation time. If None, + defaults to 100*dt + + Returns: + NPTCRescaleState: Updated state after one integration step + """ + # Note: would probably be better to have tau in NVTCRescaleState + if tau is None: + tau = 100 * dt + state = _vrescale_update(state, tau, kT, dt / 2) + + state = momentum_step(state, dt / 2) + + # Barostat step + state = _crescale_average_anisotropic_barostat_step(state, kT, dt, external_pressure) + + # Forces + model_output = model(state) + state.forces = model_output["forces"] + state.energy = model_output["energy"] + state.stress = model_output["stress"] + + # Final momentum step + state = momentum_step(state, dt / 2) + + # Final thermostat step + return _vrescale_update(state, tau, kT, dt / 2) + + +def npt_crescale_isotropic_step( + state: NPTCRescaleState, + model: ModelInterface, + *, + dt: torch.Tensor, + kT: torch.Tensor, + external_pressure: torch.Tensor, + tau: torch.Tensor | None = None, +) -> NPTCRescaleState: + """Perform one NPT integration step with cell rescaling barostat. + + This function performs a single integration step for NPT dynamics using + a cell rescaling barostat. It updates particle positions, momenta, and + the simulation cell based on the target temperature and pressure. + + Trotter based splitting: + 1. Half Thermostat (velocity scaling) + 2. Half Update momenta with forces + 3. Barostat (cell rescaling) + 4. Update positions (from barostat + half momenta) + 5. Update forces with new positions and cell + 6. Compute forces + 7. Half Update momenta with forces + 8. Half Thermostat (velocity scaling) + + Only allow isotropic external stress. This performs isotropic + cell rescaling: cell shape is preserved, cell lengths are scaled equally. + For anisotropic cell rescaling, use npt_crescale_anisotropic_step. + + References: + - Bernetti, Mattia, and Giovanni Bussi. + "Pressure control using stochastic cell rescaling." + The Journal of Chemical Physics 153.11 (2020). + - And the corresponding Supplementary Information which details + the integration scheme. Notice an error in scaling of positions in SI Eq. S13a. + + Args: + model (ModelInterface): Model to compute forces and energies + state (NPTCRescaleState): Current system state + dt (torch.Tensor): Integration timestep + kT (torch.Tensor): Target temperature + external_pressure (torch.Tensor): Target external pressure + tau (torch.Tensor | None): V-Rescale thermostat relaxation time. If None, + defaults to 100*dt + + Returns: + NPTCRescaleState: Updated state after one integration step + """ + # Note: would probably be better to have tau in NVTCRescaleState + if tau is None: + tau = 100 * dt + state = _vrescale_update(state, tau, kT, dt / 2) + + state = momentum_step(state, dt / 2) + + # Barostat step + state = _crescale_isotropic_barostat_step(state, kT, dt, external_pressure) + + # Forces + model_output = model(state) + state.forces = model_output["forces"] + state.energy = model_output["energy"] + state.stress = model_output["stress"] + + # Final momentum step + state = momentum_step(state, dt / 2) + + # Final thermostat step + return _vrescale_update(state, tau, kT, dt / 2) + + +def npt_crescale_init( + state: SimState | StateDict, + model: ModelInterface, + *, + kT: torch.Tensor, + dt: torch.Tensor, + tau_p: torch.Tensor | None = None, + isothermal_compressibility: torch.Tensor | None = None, + seed: int | None = None, +) -> NPTCRescaleState: + """Initialize the NPT cell rescaling state. + + This function initializes a state for NPT molecular dynamics with a + cell rescaling barostat. It sets up the system with appropriate initial + conditions including particle positions, momenta, and cell variables. + + Only allow isotropic external stress, but can run both isotropic and + anisotropic cell rescaling. + + Args: + state: Initial system state as MDState or dict containing positions, masses, + cell, and PBC information + model (ModelInterface): Model to compute forces and energies + kT: Target temperature in energy units + dt: Integration timestep + tau_p: Barostat relaxation time. Controls how quickly pressure equilibrates. + isothermal_compressibility: Isothermal compressibility of the system. + seed: Random seed for momenta initialization. + """ + device, dtype = model.device, model.dtype + + # Set default values if not provided + if tau_p is None: + tau_p = 5000 * dt # 5ps for dt=1fs + if isothermal_compressibility is None: + isothermal_compressibility = 1e-1 # (eV/A^3)^-1 + + # Convert all parameters to tensors with correct device and dtype + tau_p = torch.as_tensor(tau_p, device=device, dtype=dtype) + isothermal_compressibility = torch.as_tensor( + isothermal_compressibility, device=device, dtype=dtype + ) + if tau_p.ndim == 0: + tau_p = tau_p.expand(state.n_systems) + if isothermal_compressibility.ndim == 0: + isothermal_compressibility = isothermal_compressibility.expand(state.n_systems) + if isinstance(dt, float): + dt = torch.tensor(dt, device=device, dtype=dtype) + if isinstance(kT, float): + kT = torch.tensor(kT, device=device, dtype=dtype) + + if not isinstance(state, SimState): + state = SimState(**state) + + # Get model output to initialize forces and stress + model_output = model(state) + + # Initialize momenta if not provided + momenta = getattr( + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + ) + + # Create the initial state + return NPTCRescaleState( + positions=state.positions, + momenta=momenta, + energy=model_output["energy"], + forces=model_output["forces"], + stress=model_output["stress"], + masses=state.masses, + cell=state.cell, + pbc=state.pbc, + system_idx=state.system_idx, + atomic_numbers=state.atomic_numbers, + tau_p=tau_p, + isothermal_compressibility=isothermal_compressibility, + ) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index e04984b4b..4bbcdb63b 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -465,3 +465,210 @@ def nvt_nose_hoover_invariant( e_tot = e_tot + chain_ke + chain_pe return e_tot + + +class NVTVRescaleState(MDState): + """State information for an NVT system with a V-Rescale thermostat. + + This class represents the complete state of a molecular system being integrated + in the NVT (constant particle number, volume, temperature) ensemble using a + Velocity Rescaling thermostat. The thermostat maintains constant temperature + through stochastic velocity rescaling. + + Attributes: + positions: Particle positions with shape [n_particles, n_dimensions] + masses: Particle masses with shape [n_particles] + cell: Simulation cell matrix with shape [n_dimensions, n_dimensions] + pbc: Whether to use periodic boundary conditions + momenta: Particle momenta with shape [n_particles, n_dimensions] + energy: Energy of the system + forces: Forces on particles with shape [n_particles, n_dimensions] + + Notes: + - The V-Rescale thermostat provides proper canonical sampling + - Stochastic velocity rescaling ensures correct temperature distribution + - Time-reversible when integrated with appropriate algorithms + """ + + def get_number_of_degrees_of_freedom(self) -> torch.Tensor: + """Calculate the degrees of freedom per system.""" + # Subtract 3 for center of mass motion + return super().get_number_of_degrees_of_freedom() - 3 + + +def _vrescale_update( + state: MDState, + tau: float | torch.Tensor, + kT: float | torch.Tensor, + dt: float | torch.Tensor, +) -> MDState: + """Update the momentum by a scaling factor as described by Eq.A7 Bussi et al. + + Note that we don't implement the optimize code from Bussi, which won't be useful + on a high level framework like PyTorch. + + Args: + state: Current MD state + tau: Thermostat relaxation time + kT: Target temperature + dt: Integration timestep + + Returns: + Updated state with rescaled momenta + """ + device, dtype = state.device, state.dtype + + # Convert all inputs to tensors + tau_tensor = torch.as_tensor(tau, device=device, dtype=dtype) + kT_tensor = torch.as_tensor(kT, device=device, dtype=dtype) + dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype) + + # Calculate current temperature per system + current_kT = state.calc_kT() + + # Calculate degrees of freedom per system + dof = state.get_number_of_degrees_of_freedom() + + # Ensure kT and tau have proper batch dimensions + n_systems = current_kT.shape[0] + if kT_tensor.dim() == 0: + kT_tensor = kT_tensor.expand(n_systems) + if tau_tensor.dim() == 0: + tau_tensor = tau_tensor.expand(n_systems) + + # Calculate kinetic energies + KE_old = dof * current_kT / 2 + KE_new = dof * kT_tensor / 2 + + # Generate random numbers + r1 = torch.randn(n_systems, device=device, dtype=dtype) + # Sample Gamma((dof - 1)/2, 1/2) = \sum_2^{dof} X_i^2 where X_i ~ N(0,1) + r2 = torch.distributions.Gamma((dof - 1) / 2, torch.ones_like(dof) / 2).sample() + + # Calculate scaling coefficients + c1 = torch.exp(-dt_tensor / tau_tensor) + c2 = (1 - c1) * KE_new / KE_old / dof + + # Calculate scaling factor + scale = c1 + (c2 * (torch.square(r1) + r2)) + (2 * r1 * torch.sqrt(c1 * c2)) + lam = torch.sqrt(scale) + + # Apply scaling to momenta - map from system to atom indices + state.momenta = state.momenta * lam[state.system_idx].unsqueeze(-1) + return state + + +def nvt_vrescale_init( + state: SimState | StateDict, + model: ModelInterface, + *, + kT: float | torch.Tensor, + seed: int | None = None, + **_kwargs: Any, +) -> NVTVRescaleState: + """Initialize an NVT state from input data for velocity rescaling dynamics. + + Creates an initial state for NVT molecular dynamics using the canonical + sampling through velocity rescaling (CSVR) thermostat. This thermostat + samples from the canonical ensemble by rescaling velocities with an + appropriately chosen random factor. + + Args: + model: Neural network model that computes energies and forces. + Must return a dict with 'energy' and 'forces' keys. + state: Either a SimState object or a dictionary containing positions, + masses, cell, pbc, and other required state vars + kT: Temperature in energy units for initializing momenta, + either scalar or with shape [n_systems] + seed: Random seed for reproducibility + + Returns: + MDState: Initialized state for NVT integration containing positions, + momenta, forces, energy, and other required attributes + + Notes: + The initial momenta are sampled from a Maxwell-Boltzmann distribution + at the specified temperature. The V-Rescale thermostat provides proper + canonical sampling through stochastic velocity rescaling. + """ + if not isinstance(state, SimState): + state = SimState(**state) + + model_output = model(state) + + momenta = getattr( + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), + ) + + return NVTVRescaleState( + positions=state.positions, + momenta=momenta, + energy=model_output["energy"], + forces=model_output["forces"], + masses=state.masses, + cell=state.cell, + pbc=state.pbc, + system_idx=state.system_idx, + atomic_numbers=state.atomic_numbers, + ) + + +def nvt_vrescale_step( + model: ModelInterface, + state: NVTVRescaleState, + *, + dt: float | torch.Tensor, + kT: float | torch.Tensor, + tau: float | torch.Tensor | None = None, +) -> NVTVRescaleState: + """Perform one complete V-Rescale dynamics integration step. + + This function implements the canonical sampling through velocity rescaling (V-Rescale) + thermostat combined with velocity Verlet integration. The V-Rescale thermostat samples + the canonical distribution by rescaling velocities with a properly chosen random + factor that ensures correct canonical sampling. + + Args: + model: Neural network model that computes energies and forces. + Must return a dict with 'energy' and 'forces' keys. + state: Current system state containing positions, momenta, forces + dt: Integration timestep, either scalar or shape [n_systems] + kT: Target temperature in energy units, either scalar or + with shape [n_systems] + tau: Thermostat relaxation time controlling the coupling strength, + either scalar or with shape [n_systems]. Defaults to 100*dt. + seed: Random seed for reproducibility + + Returns: + MDState: Updated state after one complete V-Rescale step with new positions, + momenta, forces, and energy + + Notes: + - Uses V-Rescale thermostat for proper canonical ensemble sampling + - Unlike Berendsen thermostat, V-Rescale samples the true canonical distribution + - Integration sequence: V-Rescale rescaling + Velocity Verlet step + - The rescaling factor follows the distribution derived in Bussi et al. + + References: + Bussi G, Donadio D, Parrinello M. "Canonical sampling through velocity rescaling." + The Journal of chemical physics, 126(1), 014101 (2007). + """ + device, dtype = model.device, model.dtype + + if tau is None: + tau = 100 * dt + + if isinstance(tau, float): + tau = torch.tensor(tau, device=device, dtype=dtype) + if isinstance(dt, float): + dt = torch.tensor(dt, device=device, dtype=dtype) + if isinstance(kT, float): + kT = torch.tensor(kT, device=device, dtype=dtype) + + # Apply V-Rescale rescaling + state = _vrescale_update(state, tau, kT, dt) + + # Perform velocity Verlet step + return velocity_verlet(state=state, dt=dt, model=model) diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index 35f7d6f06..30c8e6903 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -153,6 +153,51 @@ def get_pressure( return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress)) +def compute_instantaneous_pressure_tensor( + *, + momenta: torch.Tensor, + masses: torch.Tensor, + system_idx: torch.Tensor, + stress: torch.Tensor, + volumes: torch.Tensor, +) -> torch.Tensor: + """Compute forces on the cell for NPT dynamics. + + This function calculates the instantaneous internal pressure tensor. + + Args: + momenta (torch.Tensor): Particle momenta, shape (n_particles, 3) + masses (torch.Tensor): Particle masses, shape (n_particles,) + system_idx (torch.Tensor): Tensor indicating system membership of each particle + stress (torch.Tensor): Stress tensor of the system, shape (n_systems, 3, 3) + volumes (torch.Tensor): Volumes of the systems, shape (n_systems,) + + Returns: + torch.Tensor: Instanteneous internal pressure tesnor [n_systems, 3, 3] + """ + # Reshape for broadcasting + volumes = volumes.view(-1, 1, 1) # shape: (n_systems, 1, 1) + + # Calculate virials: 2/V * (K_{tensor} - Virial_{tensor}) + twice_kinetic_energy_tensor = torch.einsum( + "bi,bj,b->bij", momenta, momenta, 1 / masses + ) + n_systems = stress.shape[0] + twice_kinetic_energy_tensor = torch.scatter_add( + torch.zeros( + n_systems, + 3, + 3, + device=momenta.device, + dtype=momenta.dtype, + ), + 0, + system_idx.unsqueeze(-1).unsqueeze(-1).expand_as(twice_kinetic_energy_tensor), + twice_kinetic_energy_tensor, + ) + return twice_kinetic_energy_tensor / volumes - stress + + def calc_heat_flux( momenta: torch.Tensor | None, masses: torch.Tensor,