diff --git a/README.md b/README.md index bc283dad1..1fe0c16ed 100644 --- a/README.md +++ b/README.md @@ -125,7 +125,7 @@ To understand how TorchSim works, start with the [comprehensive tutorials](https TorchSim's package structure is summarized in the [API reference](https://radical-ai.github.io/torch-sim/reference/index.html) documentation and drawn as a treemap below. -![TorchSim package treemap](https://github.com/user-attachments/assets/1e67879b-cdca-4ebc-bbbd-061fed90dfed) +![TorchSim package treemap](https://github.com/user-attachments/assets/1ccb3a15-233d-4bc0-b11c-35a676a2bcf3) ## License diff --git a/pyproject.toml b/pyproject.toml index 28e812f84..a87cadc1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,14 +71,12 @@ docs = [ Repo = "https://github.com/radical-ai/torch-sim" [build-system] -requires = ["hatchling>=1.27.0"] -build-backend = "hatchling.build" +requires = ["uv_build>=0.7.12"] +build-backend = "uv_build" -[tool.hatch.build.targets.wheel] -packages = ["torch_sim"] - -[tool.hatch.build.targets.sdist] -include = ["/torch_sim"] +[tool.uv.build-backend] +module-name = "torch_sim" +module-root = "" [tool.ruff] target-version = "py311" diff --git a/tests/test_elastic.py b/tests/test_elastic.py index 7a1479830..91a063da0 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -6,6 +6,9 @@ calculate_elastic_moduli, calculate_elastic_tensor, get_bravais_type, + get_cart_deformed_cell, + get_elementary_deformations, + get_strain, ) from torch_sim.optimizers import frechet_cell_fire from torch_sim.typing import BravaisType @@ -20,8 +23,238 @@ pytest.skip("MACE not installed", allow_module_level=True) +def test_get_strain_zero_deformation(cu_sim_state: ts.SimState) -> None: + """Test that zero deformation produces zero strain.""" + # Test with same state as reference and deformed - should give zero strain + strain = get_strain(cu_sim_state, cu_sim_state) + + expected_strain = torch.zeros(6, device=cu_sim_state.device, dtype=cu_sim_state.dtype) + torch.testing.assert_close(strain, expected_strain, atol=1e-12, rtol=1e-12) + + +def test_get_strain_pure_normal_strain(cu_sim_state: ts.SimState) -> None: + """Test pure normal strain calculations (uniaxial extension/compression).""" + device = cu_sim_state.device + dtype = cu_sim_state.dtype + + # Test pure xx strain (axis 0) + strain_magnitude = 0.05 + deformed_state = get_cart_deformed_cell(cu_sim_state, axis=0, size=strain_magnitude) + calculated_strain = get_strain(deformed_state, cu_sim_state) + + # Expected: only εxx should be non-zero and equal to strain_magnitude + # For pure normal strain, the symmetric tensor should give εxx = strain_magnitude + expected_strain = torch.zeros(6, device=device, dtype=dtype) + expected_strain[0] = strain_magnitude # εxx + + torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12) + + # Test pure yy strain (axis 1) + deformed_state = get_cart_deformed_cell(cu_sim_state, axis=1, size=strain_magnitude) + calculated_strain = get_strain(deformed_state, cu_sim_state) + + expected_strain = torch.zeros(6, device=device, dtype=dtype) + expected_strain[1] = strain_magnitude # εyy + + torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12) + + # Test pure zz strain (axis 2) + deformed_state = get_cart_deformed_cell(cu_sim_state, axis=2, size=strain_magnitude) + calculated_strain = get_strain(deformed_state, cu_sim_state) + + expected_strain = torch.zeros(6, device=device, dtype=dtype) + expected_strain[2] = strain_magnitude # εzz + + torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12) + + +def test_get_strain_pure_shear_strain(cu_sim_state: ts.SimState) -> None: + """Test pure shear strain calculations and verify symmetric strain tensor.""" + device = cu_sim_state.device + dtype = cu_sim_state.dtype + + # Test yz shear strain (axis 3) + shear_magnitude = 0.08 + deformed_state = get_cart_deformed_cell(cu_sim_state, axis=3, size=shear_magnitude) + calculated_strain = get_strain(deformed_state, cu_sim_state) + + # For shear deformation, the displacement gradient u will have: + # u[1, 2] = shear_magnitude, but the symmetric strain is (u + u^T)/2 + # So εyz = (u[1,2] + u[2,1])/2 = (shear_magnitude + 0)/2 = shear_magnitude/2 + # This demonstrates the key symmetric strain tensor calculation at line 815 + expected_strain = torch.zeros(6, device=device, dtype=dtype) + expected_strain[3] = shear_magnitude / 2 # εyz = symmetric shear strain + + torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12) + + # Test xz shear strain (axis 4) + deformed_state = get_cart_deformed_cell(cu_sim_state, axis=4, size=shear_magnitude) + calculated_strain = get_strain(deformed_state, cu_sim_state) + + expected_strain = torch.zeros(6, device=device, dtype=dtype) + expected_strain[4] = shear_magnitude / 2 # εxz = symmetric shear strain + + torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12) + + # Test xy shear strain (axis 5) + deformed_state = get_cart_deformed_cell(cu_sim_state, axis=5, size=shear_magnitude) + calculated_strain = get_strain(deformed_state, cu_sim_state) + + expected_strain = torch.zeros(6, device=device, dtype=dtype) + expected_strain[5] = shear_magnitude / 2 # εxy = symmetric shear strain + + torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12) + + +def test_get_strain_hydrostatic_strain(cu_sim_state: ts.SimState) -> None: + """Test hydrostatic strain (equal expansion/compression in all directions).""" + device = cu_sim_state.device + dtype = cu_sim_state.dtype + + # Create hydrostatic deformation by scaling all cell vectors equally + hydro_strain = 0.03 + original_cell = cu_sim_state.row_vector_cell.squeeze() + + # Scale the cell uniformly (hydrostatic deformation) + hydro_deformation = torch.eye(3, device=device, dtype=dtype) * (1 + hydro_strain) + deformed_cell = torch.matmul(original_cell, hydro_deformation) + + # Create deformed state manually + deformed_positions = cu_sim_state.positions * (1 + hydro_strain) + deformed_state = ts.SimState( + positions=deformed_positions, + cell=deformed_cell.mT.unsqueeze(0), + masses=cu_sim_state.masses, + pbc=cu_sim_state.pbc, + atomic_numbers=cu_sim_state.atomic_numbers, + ) + + calculated_strain = get_strain(deformed_state, cu_sim_state) + + # For hydrostatic strain, εxx = εyy = εzz = hydro_strain, all shear components = 0 + expected_strain = torch.zeros(6, device=device, dtype=dtype) + expected_strain[0] = hydro_strain # εxx + expected_strain[1] = hydro_strain # εyy + expected_strain[2] = hydro_strain # εzz + # εyz, εxz, εxy should remain zero + + torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12) + + +def test_get_strain_symmetry_property(cu_sim_state: ts.SimState) -> None: + """Test that the strain tensor calculation properly enforces symmetry (u + u^T)/2.""" + device = cu_sim_state.device + dtype = cu_sim_state.dtype + + # Create a deformation that would produce an asymmetric displacement gradient + # We'll manually create a deformed cell that would result in u[0,1] != u[1,0] + # but the symmetric strain tensor should symmetrize this + + original_cell = cu_sim_state.row_vector_cell.squeeze() + + # Create an asymmetric deformation matrix + asymmetric_deformation = torch.tensor( + [ + [1.02, 0.03, 0.0], # This creates both normal and shear components + [0.0, 1.01, 0.0], # Different from symmetric case + [0.0, 0.0, 1.0], + ], + device=device, + dtype=dtype, + ) + + deformed_cell = torch.matmul(original_cell, asymmetric_deformation) + + # Convert positions to fractional, then back with new cell + frac_coords = torch.matmul(cu_sim_state.positions, torch.linalg.inv(original_cell)) + deformed_positions = torch.matmul(frac_coords, deformed_cell) + + deformed_state = ts.SimState( + positions=deformed_positions, + cell=deformed_cell.mT.unsqueeze(0), + masses=cu_sim_state.masses, + pbc=cu_sim_state.pbc, + atomic_numbers=cu_sim_state.atomic_numbers, + ) + + calculated_strain = get_strain(deformed_state, cu_sim_state) + + # Manually calculate what the symmetric strain should be + cell_diff = deformed_cell - original_cell + u = torch.matmul(torch.linalg.inv(original_cell), cell_diff) + symmetric_strain_tensor = (u + u.mT) / 2 + + expected_strain = torch.tensor( + [ + symmetric_strain_tensor[0, 0], # εxx + symmetric_strain_tensor[1, 1], # εyy + symmetric_strain_tensor[2, 2], # εzz + symmetric_strain_tensor[2, 1], # εyz + symmetric_strain_tensor[2, 0], # εxz + symmetric_strain_tensor[1, 0], # εxy + ], + device=device, + dtype=dtype, + ) + + torch.testing.assert_close(calculated_strain, expected_strain, atol=1e-12, rtol=1e-12) + + # Verify that the shear components are properly symmetrized + # εxy should equal the average of the off-diagonal terms + expected_xy_strain = (u[1, 0] + u[0, 1]) / 2 + assert torch.allclose(calculated_strain[5], expected_xy_strain, atol=1e-12) + + +def test_get_elementary_deformations_strain_consistency( + cu_sim_state: ts.SimState, +) -> None: + """Test that deformations generated by get_elementary_deformations produce expected + strains.""" + max_strain_normal = 0.02 + max_strain_shear = 0.05 + n_deform = 3 + + deformed_states = get_elementary_deformations( + cu_sim_state, + n_deform=n_deform, + max_strain_normal=max_strain_normal, + max_strain_shear=max_strain_shear, + bravais_type=BravaisType.TRICLINIC, # Test all axes + ) + + # Should generate deformations for all 6 axes (triclinic) + # Each axis generates n_deform-1 strains when n_deform is odd (excluding zero), + # or n_deform strains when n_deform is even (zero not included in linspace) + strains_per_axis = n_deform - 1 if n_deform % 2 == 1 else n_deform + expected_n_states = 6 * strains_per_axis + assert len(deformed_states) == expected_n_states + + # Check that each deformed state produces a strain with expected dominant component + axis_to_strain_idx = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5} # axis -> Voigt index + + for i, deformed_state in enumerate(deformed_states): + strain = get_strain(deformed_state, cu_sim_state) + + # Determine which axis this deformation corresponds to + axis = i // strains_per_axis # Integer division to get axis index + strain_idx = axis_to_strain_idx[axis] + + # The strain component corresponding to this axis should be the largest + max_strain_component = torch.max(torch.abs(strain)) + assert torch.isclose( + torch.abs(strain[strain_idx]), max_strain_component, rtol=1e-10, atol=1e-12 + ) + + # Verify strain magnitude is within expected bounds + if axis < 3: # Normal strain + assert torch.abs(strain[strain_idx]) <= max_strain_normal + 1e-12 + else: # Shear strain (factor of 2 due to symmetric strain tensor) + assert torch.abs(strain[strain_idx]) <= max_strain_shear / 2 + 1e-12 + + @pytest.fixture def mace_model(device: torch.device) -> MaceModel: + """Create a MACE model fixture for testing.""" mace_model = mace_mp(model="medium", default_dtype="float64", return_raw_model=True) return MaceModel( @@ -55,7 +288,7 @@ def test_elastic_tensor_symmetries( Args: sim_state_name: Name of the fixture containing the simulation state - model_fixture_name: Name of the model fixture to use + mace_model: MACE model fixture expected_bravais_type: Expected Bravais lattice type atol: Absolute tolerance for comparing elastic tensors request: Pytest fixture request object @@ -109,7 +342,9 @@ def test_elastic_tensor_symmetries( ) -def test_copper_elastic_properties(mace_model: MaceModel, cu_sim_state: ts.SimState): +def test_copper_elastic_properties( + mace_model: MaceModel, cu_sim_state: ts.SimState +) -> None: """Test calculation of elastic properties for copper.""" # Relax positions and cell diff --git a/tests/test_runners.py b/tests/test_runners.py index 1b7c6260f..d25da25cf 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from pathlib import Path import numpy as np @@ -798,3 +799,202 @@ def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None: ) print(relaxed_state.energy) + + +@pytest.fixture +def mock_state() -> Callable: + """Create a mock state for testing convergence functions.""" + device = torch.device("cpu") + dtype = torch.float64 + n_batches, n_atoms = 2, 8 + torch.manual_seed(0) # deterministic forces + + class MockState: + def __init__(self, *, include_cell_forces: bool = True) -> None: + self.forces = torch.randn(n_atoms, 3, device=device, dtype=dtype) + self.batch = torch.repeat_interleave( + torch.arange(n_batches), n_atoms // n_batches + ) + self.device = device + self.dtype = dtype + self.n_batches = n_batches + if include_cell_forces: + self.cell_forces = torch.randn( + n_batches, 3, 3, device=device, dtype=dtype + ) + + return MockState + + +@pytest.mark.parametrize( + ("force_tol", "include_cell_forces", "has_cell_forces", "should_error"), + [ + (1e-2, True, True, False), # Standard case with cell forces + (1e-2, False, False, False), # Standard case without cell forces + (1e2, True, True, False), # High tolerance - should converge + (1e-6, True, True, False), # Low tolerance - may not converge + (1e-2, True, False, True), # Error case - cell forces required but missing + ], +) +def test_generate_force_convergence_fn( + *, + ar_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, + mock_state: Callable, + force_tol: float, + include_cell_forces: bool, + has_cell_forces: bool, + should_error: bool, +) -> None: + """Test generate_force_convergence_fn with various parameter combinations.""" + # Use mock state for error case, real state otherwise + if should_error: + state = mock_state(include_cell_forces=False) + else: + # Prepare real state + model_output = lj_model(ar_supercell_sim_state) + ar_supercell_sim_state.forces = model_output["forces"] + ar_supercell_sim_state.energy = model_output["energy"] + + if has_cell_forces: + ar_supercell_sim_state.cell_forces = torch.randn( + ar_supercell_sim_state.n_batches, + 3, + 3, + device=ar_supercell_sim_state.device, + dtype=ar_supercell_sim_state.dtype, + ) + state = ar_supercell_sim_state + + convergence_fn = ts.generate_force_convergence_fn( + force_tol=force_tol, include_cell_forces=include_cell_forces + ) + + if should_error: + with pytest.raises(ValueError, match="cell_forces not found in state"): + convergence_fn(state) + else: + result = convergence_fn(state) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.bool + assert result.shape == (state.n_batches,) + + +def test_generate_force_convergence_fn_tolerance_ordering( + ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: + """Test that higher tolerances are less restrictive than lower ones.""" + model_output = lj_model(ar_supercell_sim_state) + ar_supercell_sim_state.forces = model_output["forces"] + ar_supercell_sim_state.energy = model_output["energy"] + ar_supercell_sim_state.cell_forces = torch.randn( + ar_supercell_sim_state.n_batches, + 3, + 3, + device=ar_supercell_sim_state.device, + dtype=ar_supercell_sim_state.dtype, + ) + + tolerances = [1e-4, 1e-2, 1e0, 1e2] + results = [ + ts.generate_force_convergence_fn(force_tol=tol)(ar_supercell_sim_state) + for tol in tolerances + ] + + # If converged at lower tolerance, must be converged at higher tolerance + for idx in range(len(tolerances) - 1): + # Logical implication: results[idx] → results[idx + 1] + # Equivalent to: ~results[idx] | results[idx + 1] + implication = torch.logical_or(torch.logical_not(results[idx]), results[idx + 1]) + assert implication.all() + + +@pytest.mark.parametrize( + ("atomic_forces", "cell_forces", "force_tol", "expected_convergence"), + [ + ([0.05, 0.05], [0.05, 0.05], 0.1, [True, True]), # Both converged + ([0.15, 0.05], [0.05, 0.05], 0.1, [False, True]), # Only second converged + ([0.05, 0.05], [0.15, 0.05], 0.1, [False, True]), # Cell forces block first + ([0.15, 0.15], [0.15, 0.15], 0.1, [False, False]), # None converged + ], +) +def test_generate_force_convergence_fn_logic( + atomic_forces: list[float], + cell_forces: list[float], + force_tol: float, + expected_convergence: list[bool], +) -> None: + """Test convergence logic with controlled force values.""" + device, dtype = torch.device("cpu"), torch.float64 + n_batches, n_atoms = len(atomic_forces), 8 + + class ControlledMockState: + def __init__(self) -> None: + self.n_batches = n_batches + self.device, self.dtype = device, dtype + self.batch = torch.repeat_interleave( + torch.arange(n_batches), n_atoms // n_batches + ) + + # Set specific force magnitudes per batch + self.forces = torch.zeros(n_atoms, 3, device=device, dtype=dtype) + self.cell_forces = torch.zeros(n_batches, 3, 3, device=device, dtype=dtype) + + for batch_idx, (atomic_force, cell_force) in enumerate( + zip(atomic_forces, cell_forces, strict=False) + ): + batch_mask = self.batch == batch_idx + self.forces[batch_mask, 0] = atomic_force + self.cell_forces[batch_idx, 0, 0] = cell_force + + state = ControlledMockState() + convergence_fn = ts.generate_force_convergence_fn( + force_tol=force_tol, include_cell_forces=True + ) + result = convergence_fn(state) + + assert result.tolist() == expected_convergence + + +def test_generate_force_convergence_fn_ignores_last_energy( + ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: + """Test that convergence function ignores last_energy parameter.""" + model_output = lj_model(ar_supercell_sim_state) + ar_supercell_sim_state.forces = model_output["forces"] + ar_supercell_sim_state.energy = model_output["energy"] + + convergence_fn = ts.generate_force_convergence_fn( + force_tol=1e-2, include_cell_forces=False + ) + + results = [ + convergence_fn(ar_supercell_sim_state), + convergence_fn(ar_supercell_sim_state, last_energy=torch.tensor([1.0])), + convergence_fn(ar_supercell_sim_state, last_energy=None), + ] + + # All results should be identical + assert all(torch.equal(results[0], result) for result in results[1:]) + + +def test_generate_force_convergence_fn_default_behavior( + mock_state: Callable, +) -> None: + """Test that default behavior includes cell forces.""" + state = mock_state(include_cell_forces=True) + # Set very small forces to ensure convergence + state.forces.fill_(0.01) + state.cell_forces.fill_(0.01) + + # Default and explicit should give same results + default_fn = ts.generate_force_convergence_fn(force_tol=0.1) + explicit_fn = ts.generate_force_convergence_fn( + force_tol=0.1, include_cell_forces=True + ) + + result_default = default_fn(state) + result_explicit = explicit_fn(state) + + assert torch.equal(result_default, result_explicit) + assert result_default.all() # Should converge with low forces diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 622cb643b..291034b7b 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -285,7 +285,7 @@ def _chunked_apply( def generate_force_convergence_fn( - force_tol: float = 1e-1, *, include_cell_forces: bool = False + force_tol: float = 1e-1, *, include_cell_forces: bool = True ) -> Callable: """Generate a force-based convergence function for the convergence_fn argument of the optimize function. @@ -293,7 +293,7 @@ def generate_force_convergence_fn( Args: force_tol (float): Force tolerance for convergence include_cell_forces (bool): Whether to include the `cell_forces` in - the convergence check. + the convergence check. Defaults to True. Returns: Convergence function that takes a state and last energy and @@ -303,8 +303,13 @@ def generate_force_convergence_fn( def convergence_fn( state: SimState, last_energy: torch.Tensor | None = None, # noqa: ARG001 - ) -> bool: - """Check if the system has converged.""" + ) -> torch.Tensor: + """Check if the system has converged. + + Returns: + torch.Tensor: Boolean tensor of shape (n_batches,) indicating + convergence status for each batch. + """ force_conv = batchwise_max_force(state) < force_tol if include_cell_forces: @@ -334,8 +339,13 @@ def generate_energy_convergence_fn(energy_tol: float = 1e-3) -> Callable: def convergence_fn( state: SimState, last_energy: torch.Tensor | None = None, - ) -> bool: - """Check if the system has converged.""" + ) -> torch.Tensor: + """Check if the system has converged. + + Returns: + torch.Tensor: Boolean tensor of shape (n_batches,) indicating + convergence status for each batch. + """ return torch.abs(state.energy - last_energy) < energy_tol return convergence_fn