From 8a8521aefd6178ce7236815d66b3244029901ad9 Mon Sep 17 00:00:00 2001 From: Stefan Bringuier Date: Wed, 9 Apr 2025 10:19:01 -0700 Subject: [PATCH 1/2] enh: adds heat flux function with tests --- tests/test_quantities.py | 147 +++++++++++++++++++++++++++++++++++++++ torch_sim/quantities.py | 126 +++++++++++++++++++++++++++++++++ 2 files changed, 273 insertions(+) create mode 100644 tests/test_quantities.py diff --git a/tests/test_quantities.py b/tests/test_quantities.py new file mode 100644 index 000000000..0c7b3afb8 --- /dev/null +++ b/tests/test_quantities.py @@ -0,0 +1,147 @@ +"""Tests for quantities module functions.""" + +import pytest +import torch +from numpy.testing import assert_allclose + +from torch_sim.quantities import calc_heat_flux + + +class TestHeatFlux: + """Test suite for heat flux calculations.""" + + @pytest.fixture + def mock_simple_system( + self, + device: torch.device, + ) -> dict[str, torch.Tensor]: + """Simple system with known values.""" + return { + "velocities": torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 3.0], + ], + device=device, + ), + "energies": torch.tensor([1.0, 2.0, 3.0], device=device), + "stress": torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 3.0, 0.0, 0.0, 0.0], + ], + device=device, + ), + "masses": torch.ones(3, device=device), + } + + def test_unbatched_total_flux( + self, mock_simple_system: dict[str, torch.Tensor] + ) -> None: + """Test total heat flux calculation for unbatched case.""" + flux = calc_heat_flux( + momenta=None, + masses=mock_simple_system["masses"], + velocities=mock_simple_system["velocities"], + energies=mock_simple_system["energies"], + stress=mock_simple_system["stress"], + is_virial_only=False, + ) + + # Heat flux parts should cancel out + expected = torch.zeros(3, device=flux.device) + assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) + + def test_unbatched_virial_only( + self, mock_simple_system: dict[str, torch.Tensor] + ) -> None: + """Test virial-only heat flux calculation for unbatched case.""" + virial = calc_heat_flux( + momenta=None, + masses=mock_simple_system["masses"], + velocities=mock_simple_system["velocities"], + energies=mock_simple_system["energies"], + stress=mock_simple_system["stress"], + is_virial_only=True, + ) + + expected = -torch.tensor([1.0, 4.0, 9.0], device=virial.device) + assert_allclose(virial.cpu().numpy(), expected.cpu().numpy()) + + def test_batched_calculation(self, device: torch.device) -> None: + """Test heat flux calculation with batched data.""" + velocities = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 3.0], + ], + device=device, + ) + energies = torch.tensor([1.0, 2.0, 3.0], device=device) + stress = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 3.0, 0.0, 0.0, 0.0], + ], + device=device, + ) + batch = torch.tensor([0, 0, 1], device=device) + + flux = calc_heat_flux( + momenta=None, + masses=torch.ones(3, device=device), + velocities=velocities, + energies=energies, + stress=stress, + batch=batch, + ) + + # Each batch should cancel heat flux parts + expected = torch.zeros((2, 3), device=device) + assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) + + def test_centroid_stress(self, device: torch.device) -> None: + """Test heat flux with centroid stress formulation.""" + velocities = torch.tensor([[1.0, 1.0, 1.0]], device=device) + energies = torch.tensor([1.0], device=device) + + # Symmetric cross-terms + stress = torch.tensor( + [[1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]], device=device + ) + + flux = calc_heat_flux( + momenta=None, + masses=torch.ones(1, device=device), + velocities=velocities, + energies=energies, + stress=stress, + is_centroid_stress=True, + ) + + # Heatflux should be [-1,-1,-1] + expected = torch.full((3,), -1.0, device=device) + assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) + + def test_momenta_input(self, device: torch.device) -> None: + """Test heat flux calculation using momenta instead.""" + momenta = torch.tensor([[1.0, 0.0, 0.0]], device=device) + masses = torch.tensor([2.0], device=device) + energies = torch.tensor([1.0], device=device) + stress = torch.tensor([[1.0, 0.0, 0.0, 0.0, 0.0, 0.0]], device=device) + + flux = calc_heat_flux( + momenta=momenta, + masses=masses, + velocities=None, + energies=energies, + stress=stress, + ) + + # Heat flux terms should cancel out + expected = torch.zeros(3, device=device) + assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index 594ebe5b0..9ed82e2ee 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -133,6 +133,132 @@ def calc_kinetic_energy( ) +def calc_heat_flux( + momenta: torch.Tensor | None, + masses: torch.Tensor, + velocities: torch.Tensor | None, + energies: torch.Tensor, + stress: torch.Tensor, + batch: torch.Tensor | None = None, + *, # Force keyword arguments for booleans + is_centroid_stress: bool = False, + is_virial_only: bool = False, +) -> torch.Tensor: + r"""Calculate the heat flux vector. + + Computes the microscopic heat flux, :math:`\mathbf{J}` + defined as: + + .. math:: + \mathbf{J} = \mathbf{J}^c + \mathbf{J}^v + + where the convective part :math:`\mathbf{J}^c` and virial part + :math:`\mathbf{J}^v` are: + + .. math:: + \mathbf{J}^c &= \sum_i \epsilon_i \mathbf{v}_i \\ + \mathbf{J}^v &= \sum_i \sum_j \mathbf{S}_{ij} \cdot \mathbf{v}_j + + where :math:`\epsilon_i` is the per-atom energy (p.e. + k.e.), + :math:`\mathbf{v}_i` is velocity, and :math:`\mathbf{S}_{ij}` is the + per-atom stress tensor. + + Args: + momenta: Particle momenta, shape (n_particles, n_dim) + masses: Particle masses, shape (n_particles,) + velocities: Particle velocities, shape (n_particles, n_dim) + energies: Per-atom energies (p.e. + k.e.), shape (n_particles,) + stress: Per-atom stress tensor components: + - If is_centroid_stress=False: shape (n_particles, 6) for + :math:`[\sigma_{xx}, \sigma_{yy}, \sigma_{zz}, + \sigma_{xy}, \sigma_{xz}, \sigma_{yz}]` + - If is_centroid_stress=True: shape (n_particles, 9) for + :math:`[\mathbf{r}_{ix}f_{ix}, \mathbf{r}_{iy}f_{iy}, + \mathbf{r}_{iz}f_{iz}, \mathbf{r}_{ix}f_{iy}, + \mathbf{r}_{ix}f_{iz}, \mathbf{r}_{iy}f_{iz}, + \mathbf{r}_{iy}f_{ix}, \mathbf{r}_{iz}f_{ix}, + \mathbf{r}_{iz}f_{iy}]` + batch: Optional tensor indicating batch membership + is_centroid_stress: Whether stress uses centroid formulation + is_virial_only: If True, returns only virial part :math:`\mathbf{J}^v` + + Returns: + Heat flux vector of shape (3,) or (n_batches, 3) + """ + if momenta is not None and velocities is not None: + raise ValueError("Must pass either momenta or velocities, not both") + if momenta is None and velocities is None: + raise ValueError("Must pass either momenta or velocities") + + # Deduce velocities + if velocities is None: + velocities = momenta / masses.unsqueeze(-1) + + convective_flux = energies.unsqueeze(-1) * velocities + + # Calculate virial flux + if is_centroid_stress: + # Centroid formulation: r_i[x,y,z] . f_i[x,y,z] + virial_x = -( + stress[:, 0] * velocities[:, 0] # r_ix.f_ix.v_x + + stress[:, 3] * velocities[:, 1] # r_ix.f_iy.v_y + + stress[:, 4] * velocities[:, 2] # r_ix.f_iz.v_z + ) + virial_y = -( + stress[:, 6] * velocities[:, 0] # r_iy.f_ix.v_x + + stress[:, 1] * velocities[:, 1] # r_iy.f_iy.v_y + + stress[:, 5] * velocities[:, 2] # r_iy.f_iz.v_z + ) + virial_z = -( + stress[:, 7] * velocities[:, 0] # r_iz.f_ix.v_x + + stress[:, 8] * velocities[:, 1] # r_iz.f_iy.v_y + + stress[:, 2] * velocities[:, 2] # r_iz.f_iz.v_z + ) + else: + # Standard stress tensor components + virial_x = -( + stress[:, 0] * velocities[:, 0] # s_xx.v_x + + stress[:, 3] * velocities[:, 1] # s_xy.v_y + + stress[:, 4] * velocities[:, 2] # s_xz.v_z + ) + virial_y = -( + stress[:, 3] * velocities[:, 0] # s_xy.v_x + + stress[:, 1] * velocities[:, 1] # s_yy.v_y + + stress[:, 5] * velocities[:, 2] # s_yz.v_z + ) + virial_z = -( + stress[:, 4] * velocities[:, 0] # s_xz.v_x + + stress[:, 5] * velocities[:, 1] # s_yz.v_y + + stress[:, 2] * velocities[:, 2] # s_zz.v_z + ) + + virial_flux = torch.stack([virial_x, virial_y, virial_z], dim=-1) + + if batch is None: + # All atoms + virial_sum = torch.sum(virial_flux, dim=0) + if is_virial_only: + return virial_sum + conv_sum = torch.sum(convective_flux, dim=0) + return conv_sum + virial_sum + + # All atoms in each batch + n_batches = int(torch.max(batch).item() + 1) + virial_sum = torch.zeros( + (n_batches, 3), device=velocities.device, dtype=velocities.dtype + ) + virial_sum.scatter_add_(0, batch.unsqueeze(-1).expand(-1, 3), virial_flux) + + if is_virial_only: + return virial_sum + + conv_sum = torch.zeros( + (n_batches, 3), device=velocities.device, dtype=velocities.dtype + ) + conv_sum.scatter_add_(0, batch.unsqueeze(-1).expand(-1, 3), convective_flux) + return conv_sum + virial_sum + + def batchwise_max_force(state: SimState) -> torch.Tensor: """Compute the maximum force per batch. From c0d30c15b5911e8d35ae0978ce3e178e30b89ae8 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 1 Oct 2025 17:43:30 -0400 Subject: [PATCH 2/2] pick relevant changes from #264 --- tests/test_quantities.py | 97 +++++++++++++++++---------------- torch_sim/quantities.py | 115 ++++++++++++++++++--------------------- 2 files changed, 102 insertions(+), 110 deletions(-) diff --git a/tests/test_quantities.py b/tests/test_quantities.py index c3a885ba9..ee17a6fd3 100644 --- a/tests/test_quantities.py +++ b/tests/test_quantities.py @@ -5,23 +5,24 @@ from numpy.testing import assert_allclose from torch._tensor import Tensor -from torch_sim import quantities -from torch_sim.quantities import calc_heat_flux +from torch_sim.quantities import ( + calc_heat_flux, + calc_kinetic_energy, + calc_kT, + calc_temperature, +) from torch_sim.units import MetalUnits -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") -DTYPE = torch.double +DTYPE = torch.float64 +DEVICE = torch.device("cpu") class TestHeatFlux: """Test suite for heat flux calculations.""" @pytest.fixture - def mock_simple_system( - self, - device: torch.device, - ) -> dict[str, torch.Tensor]: + def mock_simple_system(self) -> dict[str, torch.Tensor]: """Simple system with known values.""" return { "velocities": torch.tensor( @@ -30,18 +31,18 @@ def mock_simple_system( [0.0, 2.0, 0.0], [0.0, 0.0, 3.0], ], - device=device, + device=DEVICE, ), - "energies": torch.tensor([1.0, 2.0, 3.0], device=device), + "energies": torch.tensor([1.0, 2.0, 3.0], device=DEVICE), "stress": torch.tensor( [ [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.0, 0.0, 0.0, 0.0], ], - device=device, + device=DEVICE, ), - "masses": torch.ones(3, device=device), + "masses": torch.ones(3, device=DEVICE), } def test_unbatched_total_flux( @@ -53,7 +54,7 @@ def test_unbatched_total_flux( masses=mock_simple_system["masses"], velocities=mock_simple_system["velocities"], energies=mock_simple_system["energies"], - stress=mock_simple_system["stress"], + stresses=mock_simple_system["stress"], is_virial_only=False, ) @@ -70,14 +71,14 @@ def test_unbatched_virial_only( masses=mock_simple_system["masses"], velocities=mock_simple_system["velocities"], energies=mock_simple_system["energies"], - stress=mock_simple_system["stress"], + stresses=mock_simple_system["stress"], is_virial_only=True, ) expected = -torch.tensor([1.0, 4.0, 9.0], device=virial.device) assert_allclose(virial.cpu().numpy(), expected.cpu().numpy()) - def test_batched_calculation(self, device: torch.device) -> None: + def test_batched_calculation(self) -> None: """Test heat flux calculation with batched data.""" velocities = torch.tensor( [ @@ -85,72 +86,72 @@ def test_batched_calculation(self, device: torch.device) -> None: [0.0, 2.0, 0.0], [0.0, 0.0, 3.0], ], - device=device, + device=DEVICE, ) - energies = torch.tensor([1.0, 2.0, 3.0], device=device) + energies = torch.tensor([1.0, 2.0, 3.0], device=DEVICE) stress = torch.tensor( [ [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.0, 0.0, 0.0, 0.0], ], - device=device, + device=DEVICE, ) - batch = torch.tensor([0, 0, 1], device=device) + batch = torch.tensor([0, 0, 1], device=DEVICE) flux = calc_heat_flux( momenta=None, - masses=torch.ones(3, device=device), + masses=torch.ones(3, device=DEVICE), velocities=velocities, energies=energies, - stress=stress, + stresses=stress, batch=batch, ) # Each batch should cancel heat flux parts - expected = torch.zeros((2, 3), device=device) + expected = torch.zeros((2, 3), device=DEVICE) assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) - def test_centroid_stress(self, device: torch.device) -> None: + def test_centroid_stress(self) -> None: """Test heat flux with centroid stress formulation.""" - velocities = torch.tensor([[1.0, 1.0, 1.0]], device=device) - energies = torch.tensor([1.0], device=device) + velocities = torch.tensor([[1.0, 1.0, 1.0]], device=DEVICE) + energies = torch.tensor([1.0], device=DEVICE) # Symmetric cross-terms stress = torch.tensor( - [[1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]], device=device + [[1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]], device=DEVICE ) flux = calc_heat_flux( momenta=None, - masses=torch.ones(1, device=device), + masses=torch.ones(1, device=DEVICE), velocities=velocities, energies=energies, - stress=stress, + stresses=stress, is_centroid_stress=True, ) # Heatflux should be [-1,-1,-1] - expected = torch.full((3,), -1.0, device=device) + expected = torch.full((3,), -1.0, device=DEVICE) assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) - def test_momenta_input(self, device: torch.device) -> None: + def test_momenta_input(self) -> None: """Test heat flux calculation using momenta instead.""" - momenta = torch.tensor([[1.0, 0.0, 0.0]], device=device) - masses = torch.tensor([2.0], device=device) - energies = torch.tensor([1.0], device=device) - stress = torch.tensor([[1.0, 0.0, 0.0, 0.0, 0.0, 0.0]], device=device) + momenta = torch.tensor([[1.0, 0.0, 0.0]], device=DEVICE) + masses = torch.tensor([2.0], device=DEVICE) + energies = torch.tensor([1.0], device=DEVICE) + stress = torch.tensor([[1.0, 0.0, 0.0, 0.0, 0.0, 0.0]], device=DEVICE) flux = calc_heat_flux( momenta=momenta, masses=masses, velocities=None, energies=energies, - stress=stress, + stresses=stress, ) # Heat flux terms should cancel out - expected = torch.zeros(3, device=device) + expected = torch.zeros(3, device=DEVICE) assert_allclose(flux.cpu().numpy(), expected.cpu().numpy()) @@ -190,14 +191,14 @@ def batched_system_data() -> dict[str, Tensor]: def test_calc_kinetic_energy_single_system(single_system_data: dict[str, Tensor]) -> None: # With velocities - ke_vel = quantities.calc_kinetic_energy( + ke_vel = calc_kinetic_energy( masses=single_system_data["masses"], velocities=single_system_data["velocities"], ) assert torch.allclose(ke_vel, single_system_data["ke"]) # With momenta - ke_mom = quantities.calc_kinetic_energy( + ke_mom = calc_kinetic_energy( masses=single_system_data["masses"], momenta=single_system_data["momenta"] ) assert torch.allclose(ke_mom, single_system_data["ke"]) @@ -207,7 +208,7 @@ def test_calc_kinetic_energy_batched_system( batched_system_data: dict[str, Tensor], ) -> None: # With velocities - ke_vel = quantities.calc_kinetic_energy( + ke_vel = calc_kinetic_energy( masses=batched_system_data["masses"], velocities=batched_system_data["velocities"], system_idx=batched_system_data["system_idx"], @@ -215,7 +216,7 @@ def test_calc_kinetic_energy_batched_system( assert torch.allclose(ke_vel, batched_system_data["ke"]) # With momenta - ke_mom = quantities.calc_kinetic_energy( + ke_mom = calc_kinetic_energy( masses=batched_system_data["masses"], momenta=batched_system_data["momenta"], system_idx=batched_system_data["system_idx"], @@ -225,26 +226,26 @@ def test_calc_kinetic_energy_batched_system( def test_calc_kinetic_energy_errors(single_system_data: dict[str, Tensor]) -> None: with pytest.raises(ValueError, match="Must pass either one of momenta or velocities"): - quantities.calc_kinetic_energy( + calc_kinetic_energy( masses=single_system_data["masses"], momenta=single_system_data["momenta"], velocities=single_system_data["velocities"], ) with pytest.raises(ValueError, match="Must pass either one of momenta or velocities"): - quantities.calc_kinetic_energy(masses=single_system_data["masses"]) + calc_kinetic_energy(masses=single_system_data["masses"]) def test_calc_kt_single_system(single_system_data: dict[str, Tensor]) -> None: # With velocities - kt_vel = quantities.calc_kT( + kt_vel = calc_kT( masses=single_system_data["masses"], velocities=single_system_data["velocities"], ) assert torch.allclose(kt_vel, single_system_data["kt"]) # With momenta - kt_mom = quantities.calc_kT( + kt_mom = calc_kT( masses=single_system_data["masses"], momenta=single_system_data["momenta"] ) assert torch.allclose(kt_mom, single_system_data["kt"]) @@ -252,7 +253,7 @@ def test_calc_kt_single_system(single_system_data: dict[str, Tensor]) -> None: def test_calc_kt_batched_system(batched_system_data: dict[str, Tensor]) -> None: # With velocities - kt_vel = quantities.calc_kT( + kt_vel = calc_kT( masses=batched_system_data["masses"], velocities=batched_system_data["velocities"], system_idx=batched_system_data["system_idx"], @@ -260,7 +261,7 @@ def test_calc_kt_batched_system(batched_system_data: dict[str, Tensor]) -> None: assert torch.allclose(kt_vel, batched_system_data["kt"]) # With momenta - kt_mom = quantities.calc_kT( + kt_mom = calc_kT( masses=batched_system_data["masses"], momenta=batched_system_data["momenta"], system_idx=batched_system_data["system_idx"], @@ -269,11 +270,11 @@ def test_calc_kt_batched_system(batched_system_data: dict[str, Tensor]) -> None: def test_calc_temperature(single_system_data: dict[str, Tensor]) -> None: - temp = quantities.calc_temperature( + temp = calc_temperature( masses=single_system_data["masses"], velocities=single_system_data["velocities"], ) - kt = quantities.calc_kT( + kt = calc_kT( masses=single_system_data["masses"], velocities=single_system_data["velocities"], ) diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index 97d4589f8..404d07f32 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -1,26 +1,11 @@ """Functions for computing physical quantities.""" -from typing import cast - import torch from torch_sim.state import SimState from torch_sim.units import MetalUnits -# @torch.jit.script -def count_dof(tensor: torch.Tensor) -> int: - """Count the degrees of freedom in the system. - - Args: - tensor: Tensor to count the degrees of freedom in - - Returns: - Number of degrees of freedom - """ - return tensor.numel() - - # @torch.jit.script def calc_kT( # noqa: N802 *, @@ -44,17 +29,18 @@ def calc_kT( # noqa: N802 if not ((momenta is not None) ^ (velocities is not None)): raise ValueError("Must pass either one of momenta or velocities") - if momenta is None: + if momenta is None and velocities is not None: # If velocity provided, calculate mv^2 - velocities = cast("torch.Tensor", velocities) - squared_term = (velocities**2) * masses.unsqueeze(-1) - else: + squared_term = torch.square(velocities) * masses.unsqueeze(-1) + elif momenta is not None and velocities is None: # If momentum provided, calculate v^2 = p^2/m^2 - squared_term = (momenta**2) / masses.unsqueeze(-1) + squared_term = torch.square(momenta) / masses.unsqueeze(-1) + else: + raise ValueError("Must pass either one of momenta or velocities") if system_idx is None: # Count total degrees of freedom - dof = count_dof(squared_term) + dof = squared_term.numel() return torch.sum(squared_term) / dof # Sum squared terms for each system flattened_squared = torch.sum(squared_term, dim=-1) @@ -121,10 +107,12 @@ def calc_kinetic_energy( if not ((momenta is not None) ^ (velocities is not None)): raise ValueError("Must pass either one of momenta or velocities") - if momenta is None: # Using velocities - squared_term = (velocities**2) * masses.unsqueeze(-1) - else: # Using momenta - squared_term = (momenta**2) / masses.unsqueeze(-1) + if momenta is None and velocities is not None: # Using velocities + squared_term = torch.square(velocities) * masses.unsqueeze(-1) + elif momenta is not None and velocities is None: # Using momenta + squared_term = torch.square(momenta) / masses.unsqueeze(-1) + else: + raise ValueError("Must pass either one of momenta or velocities") if system_idx is None: return 0.5 * torch.sum(squared_term) @@ -134,12 +122,26 @@ def calc_kinetic_energy( ) +def get_pressure( + stress: torch.Tensor, + kinetic_energy: float | torch.Tensor, + volume: torch.Tensor, + dim: int = 3, +) -> torch.Tensor: + """Compute the pressure from the stress tensor. + + The stress tensor is defined as 1/volume * dU/de_ij + So the pressure is -1/volume * trace(dU/de_ij) + """ + return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress)) + + def calc_heat_flux( momenta: torch.Tensor | None, masses: torch.Tensor, velocities: torch.Tensor | None, energies: torch.Tensor, - stress: torch.Tensor, + stresses: torch.Tensor, batch: torch.Tensor | None = None, *, # Force keyword arguments for booleans is_centroid_stress: bool = False, @@ -169,7 +171,7 @@ def calc_heat_flux( masses: Particle masses, shape (n_particles,) velocities: Particle velocities, shape (n_particles, n_dim) energies: Per-atom energies (p.e. + k.e.), shape (n_particles,) - stress: Per-atom stress tensor components: + stresses: Per-atom stress tensor components: - If is_centroid_stress=False: shape (n_particles, 6) for :math:`[\sigma_{xx}, \sigma_{yy}, \sigma_{zz}, \sigma_{xy}, \sigma_{xz}, \sigma_{yz}]` @@ -179,12 +181,12 @@ def calc_heat_flux( \mathbf{r}_{ix}f_{iz}, \mathbf{r}_{iy}f_{iz}, \mathbf{r}_{iy}f_{ix}, \mathbf{r}_{iz}f_{ix}, \mathbf{r}_{iz}f_{iy}]` - batch: Optional tensor indicating batch membership + batch: Optional tensor indicating system membership is_centroid_stress: Whether stress uses centroid formulation is_virial_only: If True, returns only virial part :math:`\mathbf{J}^v` Returns: - Heat flux vector of shape (3,) or (n_batches, 3) + Heat flux vector of shape (3,) or (n_systems, 3) """ if momenta is not None and velocities is not None: raise ValueError("Must pass either momenta or velocities, not both") @@ -201,36 +203,36 @@ def calc_heat_flux( if is_centroid_stress: # Centroid formulation: r_i[x,y,z] . f_i[x,y,z] virial_x = -( - stress[:, 0] * velocities[:, 0] # r_ix.f_ix.v_x - + stress[:, 3] * velocities[:, 1] # r_ix.f_iy.v_y - + stress[:, 4] * velocities[:, 2] # r_ix.f_iz.v_z + stresses[:, 0] * velocities[:, 0] # r_ix.f_ix.v_x + + stresses[:, 3] * velocities[:, 1] # r_ix.f_iy.v_y + + stresses[:, 4] * velocities[:, 2] # r_ix.f_iz.v_z ) virial_y = -( - stress[:, 6] * velocities[:, 0] # r_iy.f_ix.v_x - + stress[:, 1] * velocities[:, 1] # r_iy.f_iy.v_y - + stress[:, 5] * velocities[:, 2] # r_iy.f_iz.v_z + stresses[:, 6] * velocities[:, 0] # r_iy.f_ix.v_x + + stresses[:, 1] * velocities[:, 1] # r_iy.f_iy.v_y + + stresses[:, 5] * velocities[:, 2] # r_iy.f_iz.v_z ) virial_z = -( - stress[:, 7] * velocities[:, 0] # r_iz.f_ix.v_x - + stress[:, 8] * velocities[:, 1] # r_iz.f_iy.v_y - + stress[:, 2] * velocities[:, 2] # r_iz.f_iz.v_z + stresses[:, 7] * velocities[:, 0] # r_iz.f_ix.v_x + + stresses[:, 8] * velocities[:, 1] # r_iz.f_iy.v_y + + stresses[:, 2] * velocities[:, 2] # r_iz.f_iz.v_z ) else: # Standard stress tensor components virial_x = -( - stress[:, 0] * velocities[:, 0] # s_xx.v_x - + stress[:, 3] * velocities[:, 1] # s_xy.v_y - + stress[:, 4] * velocities[:, 2] # s_xz.v_z + stresses[:, 0] * velocities[:, 0] # s_xx.v_x + + stresses[:, 3] * velocities[:, 1] # s_xy.v_y + + stresses[:, 4] * velocities[:, 2] # s_xz.v_z ) virial_y = -( - stress[:, 3] * velocities[:, 0] # s_xy.v_x - + stress[:, 1] * velocities[:, 1] # s_yy.v_y - + stress[:, 5] * velocities[:, 2] # s_yz.v_z + stresses[:, 3] * velocities[:, 0] # s_xy.v_x + + stresses[:, 1] * velocities[:, 1] # s_yy.v_y + + stresses[:, 5] * velocities[:, 2] # s_yz.v_z ) virial_z = -( - stress[:, 4] * velocities[:, 0] # s_xz.v_x - + stress[:, 5] * velocities[:, 1] # s_yz.v_y - + stress[:, 2] * velocities[:, 2] # s_zz.v_z + stresses[:, 4] * velocities[:, 0] # s_xz.v_x + + stresses[:, 5] * velocities[:, 1] # s_yz.v_y + + stresses[:, 2] * velocities[:, 2] # s_zz.v_z ) virial_flux = torch.stack([virial_x, virial_y, virial_z], dim=-1) @@ -243,10 +245,10 @@ def calc_heat_flux( conv_sum = torch.sum(convective_flux, dim=0) return conv_sum + virial_sum - # All atoms in each batch - n_batches = int(torch.max(batch).item() + 1) + # All atoms in each system + n_systems = int(torch.max(batch) + 1) virial_sum = torch.zeros( - (n_batches, 3), device=velocities.device, dtype=velocities.dtype + (n_systems, 3), device=velocities.device, dtype=velocities.dtype ) virial_sum.scatter_add_(0, batch.unsqueeze(-1).expand(-1, 3), virial_flux) @@ -254,23 +256,12 @@ def calc_heat_flux( return virial_sum conv_sum = torch.zeros( - (n_batches, 3), device=velocities.device, dtype=velocities.dtype + (n_systems, 3), device=velocities.device, dtype=velocities.dtype ) conv_sum.scatter_add_(0, batch.unsqueeze(-1).expand(-1, 3), convective_flux) return conv_sum + virial_sum -def get_pressure( - stress: torch.Tensor, kinetic_energy: torch.Tensor, volume: torch.Tensor, dim: int = 3 -) -> torch.Tensor: - """Compute the pressure from the stress tensor. - - The stress tensor is defined as 1/volume * dU/de_ij - So the pressure is -1/volume * trace(dU/de_ij) - """ - return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress)) - - def systemwise_max_force(state: SimState) -> torch.Tensor: """Compute the maximum force per system.