Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ To understand how TorchSim works, start with the [comprehensive tutorials](https

TorchSim's package structure is summarized in the [API reference](https://radical-ai.github.io/torch-sim/reference/index.html) documentation and drawn as a treemap below.

![TorchSim package treemap](https://github.com/user-attachments/assets/1e67879b-cdca-4ebc-bbbd-061fed90dfed)
![TorchSim package treemap](https://github.com/user-attachments/assets/1ccb3a15-233d-4bc0-b11c-35a676a2bcf3)

## License

Expand Down
12 changes: 5 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,12 @@ docs = [
Repo = "https://github.com/radical-ai/torch-sim"

[build-system]
requires = ["hatchling>=1.27.0"]
build-backend = "hatchling.build"
requires = ["uv_build>=0.7.12"]
build-backend = "uv_build"

[tool.hatch.build.targets.wheel]
packages = ["torch_sim"]

[tool.hatch.build.targets.sdist]
include = ["/torch_sim"]
[tool.uv.build-backend]
module-name = "torch_sim"
module-root = ""

[tool.ruff]
target-version = "py311"
Expand Down
239 changes: 237 additions & 2 deletions tests/test_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
calculate_elastic_moduli,
calculate_elastic_tensor,
get_bravais_type,
get_cart_deformed_cell,
get_elementary_deformations,
get_strain,
)
from torch_sim.optimizers import frechet_cell_fire
from torch_sim.typing import BravaisType
Expand All @@ -20,8 +23,238 @@
pytest.skip("MACE not installed", allow_module_level=True)


def test_get_strain_zero_deformation(cu_sim_state: ts.SimState) -> None:
"""Test that zero deformation produces zero strain."""
# Test with same state as reference and deformed - should give zero strain
strain = get_strain(cu_sim_state, cu_sim_state)

expected_strain = torch.zeros(6, device=cu_sim_state.device, dtype=cu_sim_state.dtype)
torch.testing.assert_close(strain, expected_strain, atol=1e-12, rtol=1e-12)


def test_get_strain_pure_normal_strain(cu_sim_state: ts.SimState) -> None:
"""Test pure normal strain calculations (uniaxial extension/compression)."""
device = cu_sim_state.device
dtype = cu_sim_state.dtype

# Test pure xx strain (axis 0)
strain_magnitude = 0.05
deformed_state = get_cart_deformed_cell(cu_sim_state, axis=0, size=strain_magnitude)
calculated_strain = get_strain(deformed_state, cu_sim_state)

# Expected: only εxx should be non-zero and equal to strain_magnitude
# For pure normal strain, the symmetric tensor should give εxx = strain_magnitude
expected_strain = torch.zeros(6, device=device, dtype=dtype)
expected_strain[0] = strain_magnitude # εxx

torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12)

# Test pure yy strain (axis 1)
deformed_state = get_cart_deformed_cell(cu_sim_state, axis=1, size=strain_magnitude)
calculated_strain = get_strain(deformed_state, cu_sim_state)

expected_strain = torch.zeros(6, device=device, dtype=dtype)
expected_strain[1] = strain_magnitude # εyy

torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12)

# Test pure zz strain (axis 2)
deformed_state = get_cart_deformed_cell(cu_sim_state, axis=2, size=strain_magnitude)
calculated_strain = get_strain(deformed_state, cu_sim_state)

expected_strain = torch.zeros(6, device=device, dtype=dtype)
expected_strain[2] = strain_magnitude # εzz

torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12)


def test_get_strain_pure_shear_strain(cu_sim_state: ts.SimState) -> None:
"""Test pure shear strain calculations and verify symmetric strain tensor."""
device = cu_sim_state.device
dtype = cu_sim_state.dtype

# Test yz shear strain (axis 3)
shear_magnitude = 0.08
deformed_state = get_cart_deformed_cell(cu_sim_state, axis=3, size=shear_magnitude)
calculated_strain = get_strain(deformed_state, cu_sim_state)

# For shear deformation, the displacement gradient u will have:
# u[1, 2] = shear_magnitude, but the symmetric strain is (u + u^T)/2
# So εyz = (u[1,2] + u[2,1])/2 = (shear_magnitude + 0)/2 = shear_magnitude/2
# This demonstrates the key symmetric strain tensor calculation at line 815
expected_strain = torch.zeros(6, device=device, dtype=dtype)
expected_strain[3] = shear_magnitude / 2 # εyz = symmetric shear strain

torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12)

# Test xz shear strain (axis 4)
deformed_state = get_cart_deformed_cell(cu_sim_state, axis=4, size=shear_magnitude)
calculated_strain = get_strain(deformed_state, cu_sim_state)

expected_strain = torch.zeros(6, device=device, dtype=dtype)
expected_strain[4] = shear_magnitude / 2 # εxz = symmetric shear strain

torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12)

# Test xy shear strain (axis 5)
deformed_state = get_cart_deformed_cell(cu_sim_state, axis=5, size=shear_magnitude)
calculated_strain = get_strain(deformed_state, cu_sim_state)

expected_strain = torch.zeros(6, device=device, dtype=dtype)
expected_strain[5] = shear_magnitude / 2 # εxy = symmetric shear strain

torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12)


def test_get_strain_hydrostatic_strain(cu_sim_state: ts.SimState) -> None:
"""Test hydrostatic strain (equal expansion/compression in all directions)."""
device = cu_sim_state.device
dtype = cu_sim_state.dtype

# Create hydrostatic deformation by scaling all cell vectors equally
hydro_strain = 0.03
original_cell = cu_sim_state.row_vector_cell.squeeze()

# Scale the cell uniformly (hydrostatic deformation)
hydro_deformation = torch.eye(3, device=device, dtype=dtype) * (1 + hydro_strain)
deformed_cell = torch.matmul(original_cell, hydro_deformation)

# Create deformed state manually
deformed_positions = cu_sim_state.positions * (1 + hydro_strain)
deformed_state = ts.SimState(
positions=deformed_positions,
cell=deformed_cell.mT.unsqueeze(0),
masses=cu_sim_state.masses,
pbc=cu_sim_state.pbc,
atomic_numbers=cu_sim_state.atomic_numbers,
)

calculated_strain = get_strain(deformed_state, cu_sim_state)

# For hydrostatic strain, εxx = εyy = εzz = hydro_strain, all shear components = 0
expected_strain = torch.zeros(6, device=device, dtype=dtype)
expected_strain[0] = hydro_strain # εxx
expected_strain[1] = hydro_strain # εyy
expected_strain[2] = hydro_strain # εzz
# εyz, εxz, εxy should remain zero

torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12)


def test_get_strain_symmetry_property(cu_sim_state: ts.SimState) -> None:
"""Test that the strain tensor calculation properly enforces symmetry (u + u^T)/2."""
device = cu_sim_state.device
dtype = cu_sim_state.dtype

# Create a deformation that would produce an asymmetric displacement gradient
# We'll manually create a deformed cell that would result in u[0,1] != u[1,0]
# but the symmetric strain tensor should symmetrize this

original_cell = cu_sim_state.row_vector_cell.squeeze()

# Create an asymmetric deformation matrix
asymmetric_deformation = torch.tensor(
[
[1.02, 0.03, 0.0], # This creates both normal and shear components
[0.0, 1.01, 0.0], # Different from symmetric case
[0.0, 0.0, 1.0],
],
device=device,
dtype=dtype,
)

deformed_cell = torch.matmul(original_cell, asymmetric_deformation)

# Convert positions to fractional, then back with new cell
frac_coords = torch.matmul(cu_sim_state.positions, torch.linalg.inv(original_cell))
deformed_positions = torch.matmul(frac_coords, deformed_cell)

deformed_state = ts.SimState(
positions=deformed_positions,
cell=deformed_cell.mT.unsqueeze(0),
masses=cu_sim_state.masses,
pbc=cu_sim_state.pbc,
atomic_numbers=cu_sim_state.atomic_numbers,
)

calculated_strain = get_strain(deformed_state, cu_sim_state)

# Manually calculate what the symmetric strain should be
cell_diff = deformed_cell - original_cell
u = torch.matmul(torch.linalg.inv(original_cell), cell_diff)
symmetric_strain_tensor = (u + u.mT) / 2

expected_strain = torch.tensor(
[
symmetric_strain_tensor[0, 0], # εxx
symmetric_strain_tensor[1, 1], # εyy
symmetric_strain_tensor[2, 2], # εzz
symmetric_strain_tensor[2, 1], # εyz
symmetric_strain_tensor[2, 0], # εxz
symmetric_strain_tensor[1, 0], # εxy
],
device=device,
dtype=dtype,
)

torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12)

# Verify that the shear components are properly symmetrized
# εxy should equal the average of the off-diagonal terms
expected_xy_strain = (u[1, 0] + u[0, 1]) / 2
assert torch.allclose(calculated_strain[5], expected_xy_strain, atol=1e-12)


def test_get_elementary_deformations_strain_consistency(
cu_sim_state: ts.SimState,
) -> None:
"""Test that deformations generated by get_elementary_deformations produce expected
strains."""
max_strain_normal = 0.02
max_strain_shear = 0.05
n_deform = 3

deformed_states = get_elementary_deformations(
cu_sim_state,
n_deform=n_deform,
max_strain_normal=max_strain_normal,
max_strain_shear=max_strain_shear,
bravais_type=BravaisType.TRICLINIC, # Test all axes
)

# Should generate deformations for all 6 axes (triclinic)
# Each axis generates n_deform-1 strains when n_deform is odd (excluding zero),
# or n_deform strains when n_deform is even (zero not included in linspace)
strains_per_axis = n_deform - 1 if n_deform % 2 == 1 else n_deform
expected_n_states = 6 * strains_per_axis
assert len(deformed_states) == expected_n_states

# Check that each deformed state produces a strain with expected dominant component
axis_to_strain_idx = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5} # axis -> Voigt index

for i, deformed_state in enumerate(deformed_states):
strain = get_strain(deformed_state, cu_sim_state)

# Determine which axis this deformation corresponds to
axis = i // strains_per_axis # Integer division to get axis index
strain_idx = axis_to_strain_idx[axis]

# The strain component corresponding to this axis should be the largest
max_strain_component = torch.max(torch.abs(strain))
assert torch.isclose(
torch.abs(strain[strain_idx]), max_strain_component, rtol=1e-10, atol=1e-12
)

# Verify strain magnitude is within expected bounds
if axis < 3: # Normal strain
assert torch.abs(strain[strain_idx]) <= max_strain_normal + 1e-12
else: # Shear strain (factor of 2 due to symmetric strain tensor)
assert torch.abs(strain[strain_idx]) <= max_strain_shear / 2 + 1e-12


@pytest.fixture
def mace_model(device: torch.device) -> MaceModel:
"""Create a MACE model fixture for testing."""
mace_model = mace_mp(model="medium", default_dtype="float64", return_raw_model=True)

return MaceModel(
Expand Down Expand Up @@ -55,7 +288,7 @@ def test_elastic_tensor_symmetries(

Args:
sim_state_name: Name of the fixture containing the simulation state
model_fixture_name: Name of the model fixture to use
mace_model: MACE model fixture
expected_bravais_type: Expected Bravais lattice type
atol: Absolute tolerance for comparing elastic tensors
request: Pytest fixture request object
Expand Down Expand Up @@ -109,7 +342,9 @@ def test_elastic_tensor_symmetries(
)


def test_copper_elastic_properties(mace_model: MaceModel, cu_sim_state: ts.SimState):
def test_copper_elastic_properties(
mace_model: MaceModel, cu_sim_state: ts.SimState
) -> None:
"""Test calculation of elastic properties for copper."""

# Relax positions and cell
Expand Down
Loading