diff --git a/.gitignore b/.gitignore index f9d99f49f..d75c18368 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,10 @@ __pycache__ build/ dist/ *.egg-info + +# logging +*.log +*log.txt + +# model checkpoints +*.model diff --git a/examples/4_High_level_api/4.1_high_level_api.py b/examples/4_High_level_api/4.1_high_level_api.py index 90baba464..7febcd533 100644 --- a/examples/4_High_level_api/4.1_high_level_api.py +++ b/examples/4_High_level_api/4.1_high_level_api.py @@ -179,9 +179,8 @@ system=systems, model=mace_model, optimizer=unit_cell_fire, - convergence_fn=lambda state, last_energy: torch.all( - last_energy - state.energy < 1e-6 * MetalUnits.energy - ), + convergence_fn=lambda state, last_energy: last_energy - state.energy + < 1e-6 * MetalUnits.energy, max_steps=10 if os.getenv("CI") else 1000, ) diff --git a/examples/4_High_level_api/4.2_auto_batching_api.py b/examples/4_High_level_api/4.2_auto_batching_api.py new file mode 100644 index 000000000..19f3cb816 --- /dev/null +++ b/examples/4_High_level_api/4.2_auto_batching_api.py @@ -0,0 +1,137 @@ +"""Examples of using the auto-batching API.""" + +# /// script +# dependencies = [ +# "mace-torch>=0.3.10", +# ] +# /// + +"""Run as a interactive script.""" +# ruff: noqa: E402 + + +# %% +import os + +import torch +from ase.build import bulk +from mace.calculators.foundations_models import mace_mp + +from torchsim.autobatching import ( + ChunkingAutoBatcher, + HotSwappingAutoBatcher, + calculate_memory_scaler, + split_state, +) +from torchsim.integrators import nvt_langevin +from torchsim.models.mace import MaceModel +from torchsim.optimizers import unit_cell_fire +from torchsim.runners import atoms_to_state +from torchsim.state import BaseState +from torchsim.units import MetalUnits + + +if not torch.cuda.is_available(): + raise SystemExit(0) + +si_atoms = bulk("Si", "fcc", a=5.43, cubic=True).repeat((3, 3, 3)) +fe_atoms = bulk("Fe", "fcc", a=5.43, cubic=True).repeat((3, 3, 3)) + +device = torch.device("cuda") + 
+mace = mace_mp(model="small", return_raw_model=True) +mace_model = MaceModel( + model=mace, + device=device, + periodic=True, + dtype=torch.float64, + compute_force=True, +) + +si_state = atoms_to_state(si_atoms, device=device, dtype=torch.float64) +fe_state = atoms_to_state(fe_atoms, device=device, dtype=torch.float64) + +fire_init, fire_update = unit_cell_fire(mace_model) + +si_fire_state = fire_init(si_state) +fe_fire_state = fire_init(fe_state) + +fire_states = [si_fire_state, fe_fire_state] * (2 if os.getenv("CI") else 20) +fire_states = [state.clone() for state in fire_states] +for state in fire_states: + state.positions += torch.randn_like(state.positions) * 0.01 + +len(fire_states) + + +# %% run hot swapping autobatcher +def convergence_fn(state: BaseState) -> bool: + """Check if the system has converged.""" + batch_wise_max_force = torch.zeros(state.n_batches, device=state.device) + max_forces = state.forces.norm(dim=1) + batch_wise_max_force = batch_wise_max_force.scatter_reduce( + dim=0, + index=state.batch, + src=max_forces, + reduce="amax", + ) + return batch_wise_max_force < 1e-1 + + +single_system_memory = calculate_memory_scaler(fire_states[0]) +batcher = HotSwappingAutoBatcher( + model=mace_model, + states=fire_states, + memory_scales_with="n_atoms_x_density", + max_memory_scaler=single_system_memory * 2.5 if os.getenv("CI") else None, +) + +all_completed_states, convergence_tensor = [], None +while True: + print(f"Starting new batch of {state.n_batches} states.") + + state, completed_states = batcher.next_batch(state, convergence_tensor) + print("Number of completed states", len(completed_states)) + + all_completed_states.extend(completed_states) + if state is None: + break + + # run 10 steps, arbitrary number + for _step in range(10): + state = fire_update(state) + convergence_tensor = convergence_fn(state) + + +# %% run chunking autobatcher +nvt_init, nvt_update = nvt_langevin( + model=mace_model, dt=0.001, kT=300 * MetalUnits.temperature +) + 
+ +si_state = atoms_to_state(si_atoms, device=device, dtype=torch.float64) +fe_state = atoms_to_state(fe_atoms, device=device, dtype=torch.float64) + +si_nvt_state = nvt_init(si_state) +fe_nvt_state = nvt_init(fe_state) + +nvt_states = [si_nvt_state, fe_nvt_state] * (2 if os.getenv("CI") else 20) +nvt_states = [state.clone() for state in nvt_states] +for state in nvt_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + +single_system_memory = calculate_memory_scaler(fire_states[0]) +batcher = ChunkingAutoBatcher( + model=mace_model, + states=nvt_states, + memory_scales_with="n_atoms_x_density", + max_memory_scaler=single_system_memory * 2.5 if os.getenv("CI") else None, +) + +finished_states = [] +for batch in batcher: + for _ in range(100): + batch = nvt_update(batch) + + finished_states.extend(split_state(batch)) diff --git a/pyproject.toml b/pyproject.toml index 23b2b43ae..54e72b592 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ requires-python = ">=3.11" dependencies = [ "ase>=3.24", + "binpacking>=1.5.2", "h5py>=3.12.1", "numpy>=1.26", "tables>=3.10.2", diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py new file mode 100644 index 000000000..25b61fd80 --- /dev/null +++ b/tests/test_autobatching.py @@ -0,0 +1,377 @@ +from typing import Any + +import pytest +import torch + +from torchsim.autobatching import ( + ChunkingAutoBatcher, + HotSwappingAutoBatcher, + calculate_memory_scaler, + determine_max_batch_size, +) +from torchsim.models.lennard_jones import LennardJonesModel +from torchsim.optimizers import unit_cell_fire +from torchsim.state import BaseState, split_state + + +def test_calculate_scaling_metric(si_base_state: BaseState) -> None: + """Test calculation of scaling metrics for a state.""" + # Test n_atoms metric + n_atoms_metric = calculate_memory_scaler(si_base_state, "n_atoms") + assert n_atoms_metric == si_base_state.n_atoms + + # Test n_atoms_x_density metric + 
density_metric = calculate_memory_scaler(si_base_state, "n_atoms_x_density") + volume = torch.abs(torch.linalg.det(si_base_state.cell[0])) / 1000 + expected = si_base_state.n_atoms * (si_base_state.n_atoms / volume.item()) + assert pytest.approx(density_metric, rel=1e-5) == expected + + # Test invalid metric + with pytest.raises(ValueError, match="Invalid metric"): + calculate_memory_scaler(si_base_state, "invalid_metric") + + +def test_split_state(si_double_base_state: BaseState) -> None: + """Test splitting a batched state into individual states.""" + split_states = split_state(si_double_base_state) + + # Check we get the right number of states + assert len(split_states) == 2 + + # Check each state has the correct properties + for state in enumerate(split_states): + assert state[1].n_batches == 1 + assert torch.all( + state[1].batch == 0 + ) # Each split state should have batch indices reset to 0 + assert state[1].n_atoms == si_double_base_state.n_atoms // 2 + assert state[1].positions.shape[0] == si_double_base_state.n_atoms // 2 + assert state[1].cell.shape[0] == 1 + + +def test_chunking_auto_batcher( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: + """Test ChunkingAutoBatcher with different states.""" + # Create a list of states with different sizes + states = [si_base_state, fe_fcc_state] + + # Initialize the batcher with a fixed max_metric to avoid GPU memory testing + batcher = ChunkingAutoBatcher( + states=states, + model=lj_calculator, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, # Set a small value to force multiple batches + ) + + # Check that the batcher correctly identified the metrics + assert len(batcher.memory_scalers) == 2 + assert batcher.memory_scalers[0] == si_base_state.n_atoms + assert batcher.memory_scalers[1] == fe_fcc_state.n_atoms + + # Get batches until None is returned + batches = list(batcher) + + # Check we got the expected number of batches + assert len(batches) == 
len(batcher.batched_states) + + # Test restore_original_order + restored_states = batcher.restore_original_order(batches) + assert len(restored_states) == len(states) + + # Check that the restored states match the original states in order + assert restored_states[0].n_atoms == states[0].n_atoms + assert restored_states[1].n_atoms == states[1].n_atoms + + # Check atomic numbers to verify the correct order + assert torch.all(restored_states[0].atomic_numbers == states[0].atomic_numbers) + assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) + + +def test_chunking_auto_batcher_with_indices( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: + """Test ChunkingAutoBatcher with return_indices=True.""" + states = [si_base_state, fe_fcc_state] + + batcher = ChunkingAutoBatcher( + states=states, + model=lj_calculator, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, + return_indices=True, + ) + + # Get batches with indices + batches_with_indices = [] + for batch, indices in batcher: + batches_with_indices.append((batch, indices)) + + # Check we got the expected number of batches + assert len(batches_with_indices) == len(batcher.batched_states) + + # Check that the indices match the expected bin indices + for i, (_, indices) in enumerate(batches_with_indices): + assert indices == batcher.index_bins[i] + + +def test_chunking_auto_batcher_restore_order_with_split_states( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: + """Test ChunkingAutoBatcher's restore_original_order method with split states.""" + # Create a list of states with different sizes + states = [si_base_state, fe_fcc_state] + + # Initialize the batcher with a fixed max_metric to avoid GPU memory testing + batcher = ChunkingAutoBatcher( + states=states, + model=lj_calculator, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, # Set a small value to force multiple batches + 
) + + # Get batches until None is returned + batches = [] + while True: + batch = batcher.next_batch() + if batch is None: + break + # Split each batch into individual states to simulate processing + # split_batch = split_state(batch) + batches.append(batch) + + # Test restore_original_order with split states + # This tests the chain.from_iterable functionality + restored_states = batcher.restore_original_order(batches) + + # Check we got the right number of states back + assert len(restored_states) == len(states) + + # Check that the restored states match the original states in order + assert restored_states[0].n_atoms == states[0].n_atoms + assert restored_states[1].n_atoms == states[1].n_atoms + + # Check atomic numbers to verify the correct order + assert torch.all(restored_states[0].atomic_numbers == states[0].atomic_numbers) + assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) + + +def test_hot_swapping_max_metric_too_small( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: + """Test HotSwappingAutoBatcher with different states.""" + # Create a list of states + states = [si_base_state, fe_fcc_state] + + # Initialize the batcher with a fixed max_metric + batcher = HotSwappingAutoBatcher( + states=states, + model=lj_calculator, + memory_scales_with="n_atoms", + max_memory_scaler=1.0, # Set a small value to force multiple batches + ) + + # Get the first batch + with pytest.raises(ValueError, match="is greater than max_metric"): + batcher.next_batch(None, None) + + +def test_hot_swapping_auto_batcher( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: + """Test HotSwappingAutoBatcher with different states.""" + # Create a list of states + states = [si_base_state, fe_fcc_state] + + # Initialize the batcher with a fixed max_metric + batcher = HotSwappingAutoBatcher( + states=states, + model=lj_calculator, + memory_scales_with="n_atoms", + 
max_memory_scaler=260, # Set a small value to force multiple batches + ) + + # Get the first batch + first_batch, [] = batcher.next_batch(states, None) + assert isinstance(first_batch, BaseState) + + # Create a convergence tensor where the first state has converged + convergence = torch.tensor([True]) + + # Get the next batch + next_batch, popped_batch, idx = batcher.next_batch( + first_batch, convergence, return_indices=True + ) + assert isinstance(next_batch, BaseState) + assert isinstance(popped_batch, list) + assert isinstance(popped_batch[0], BaseState) + assert idx == [1] + + # Check that the converged state was removed + assert len(batcher.current_scalers) == 1 + assert len(batcher.current_idx) == 1 + assert len(batcher.completed_idx_og_order) == 1 + + # Create a convergence tensor where the remaining state has converged + convergence = torch.tensor([True]) + + # Get the next batch, which should be None since all states have converged + final_batch, popped_batch = batcher.next_batch(next_batch, convergence) + assert final_batch is None + + # Check that all states are marked as completed + assert len(batcher.completed_idx_og_order) == 2 + + +def test_determine_max_batch_size_fibonacci( + si_base_state: BaseState, lj_calculator: LennardJonesModel, monkeypatch: Any +) -> None: + """Test that determine_max_batch_size uses Fibonacci sequence correctly.""" + + # Mock measure_model_memory_forward to avoid actual GPU memory testing + def mock_measure(*_args: Any, **_kwargs: Any) -> float: + return 0.1 # Return a small constant memory usage + + monkeypatch.setattr( + "torchsim.autobatching.measure_model_memory_forward", mock_measure + ) + + # Test with a small max_atoms value to limit the sequence + max_size = determine_max_batch_size(si_base_state, lj_calculator, max_atoms=10) + + # The Fibonacci sequence up to 10 is [1, 2, 3, 5, 8, 13] + # Since we're not triggering OOM errors with our mock, it should + # return the largest value < max_atoms + assert max_size == 8 
+ + +def test_hot_swapping_auto_batcher_restore_order( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: + """Test HotSwappingAutoBatcher's restore_original_order method.""" + states = [si_base_state, fe_fcc_state] + + batcher = HotSwappingAutoBatcher( + states=states, + model=lj_calculator, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, + ) + + # Get the first batch + first_batch, [] = batcher.next_batch(states, None) + + # Simulate convergence of all states + completed_states_list = [] + convergence = torch.tensor([True]) + next_batch, completed_states = batcher.next_batch(first_batch, convergence) + completed_states_list.extend(completed_states) + + # sample batch a second time + # sample batch a second time + next_batch, completed_states = batcher.next_batch(next_batch, convergence) + completed_states_list.extend(completed_states) + + # Test restore_original_order + restored_states = batcher.restore_original_order(completed_states_list) + assert len(restored_states) == 2 + + # Check that the restored states match the original states in order + assert restored_states[0].n_atoms == states[0].n_atoms + assert restored_states[1].n_atoms == states[1].n_atoms + + # Check atomic numbers to verify the correct order + assert torch.all(restored_states[0].atomic_numbers == states[0].atomic_numbers) + assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) + + # # Test error when number of states doesn't match + # with pytest.raises( + # ValueError, match="Number of completed states .* does not match" + # ): + # batcher.restore_original_order([si_base_state]) + + +def test_hot_swapping_with_fire( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: + fire_init, fire_update = unit_cell_fire(lj_calculator) + + si_fire_state = fire_init(si_base_state) + fe_fire_state = fire_init(fe_fcc_state) + + fire_states = [si_fire_state, fe_fire_state] * 5 + 
fire_states = [state.clone() for state in fire_states] + for state in fire_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + batcher = HotSwappingAutoBatcher( + states=fire_states, + model=lj_calculator, + memory_scales_with="n_atoms", + # max_metric=400_000, + max_memory_scaler=600, + ) + + def convergence_fn(state: BaseState) -> bool: + batch_wise_max_force = torch.zeros( + state.n_batches, device=state.device, dtype=torch.float64 + ) + max_forces = state.forces.norm(dim=1) + batch_wise_max_force = batch_wise_max_force.scatter_reduce( + dim=0, + index=state.batch, + src=max_forces, + reduce="amax", + ) + return batch_wise_max_force < 5e-1 + + all_completed_states, convergence_tensor = [], None + while True: + print(f"Starting new batch of {state.n_batches} states.") + + state, completed_states = batcher.next_batch(state, convergence_tensor) + print("Number of completed states", len(completed_states)) + + all_completed_states.extend(completed_states) + if state is None: + break + + # run 10 steps, arbitrary number + for _ in range(10): + state = fire_update(state) + convergence_tensor = convergence_fn(state) + + assert len(all_completed_states) == len(fire_states) + + +def test_chunking_auto_batcher_with_fire( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: + fire_init, fire_update = unit_cell_fire(lj_calculator) + + si_fire_state = fire_init(si_base_state) + fe_fire_state = fire_init(fe_fcc_state) + + fire_states = [si_fire_state, fe_fire_state] * 5 + fire_states = [state.clone() for state in fire_states] + for state in fire_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + batcher = ChunkingAutoBatcher( + states=fire_states, + model=lj_calculator, + memory_scales_with="n_atoms", + max_memory_scaler=400, + ) + + finished_states = [] + for batch in batcher: + for _ in range(5): + batch = fire_update(batch) + + finished_states.extend(split_state(batch)) + + 
restored_states = batcher.restore_original_order(finished_states) + assert len(restored_states) == len(fire_states) + for restored, original in zip(restored_states, fire_states, strict=True): + assert torch.all(restored.atomic_numbers == original.atomic_numbers) diff --git a/tests/test_state.py b/tests/test_state.py index 27b5526ba..e77fa30fe 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -6,7 +6,9 @@ BaseState, concatenate_states, infer_property_scope, + pop_states, slice_substate, + split_state, ) from torchsim.unbatched_integrators import MDState @@ -195,3 +197,56 @@ def test_concatenate_double_si_and_fe_states( si_slice_1.positions, slice_substate(si_double_base_state, 1).positions ) assert torch.allclose(fe_slice.positions, fe_fcc_state.positions) + + +def test_split_state(si_double_base_state: BaseState) -> None: + """Test splitting a state into a list of states.""" + states = split_state(si_double_base_state) + assert len(states) == si_double_base_state.n_batches + for state in states: + assert isinstance(state, BaseState) + assert state.positions.shape == (8, 3) + assert state.masses.shape == (8,) + assert state.cell.shape == (1, 3, 3) + assert state.atomic_numbers.shape == (8,) + assert torch.allclose(state.batch, torch.zeros_like(state.batch)) + + +def test_split_many_states( + si_base_state: BaseState, ar_base_state: BaseState, fe_fcc_state: BaseState +) -> None: + """Test splitting a state into a list of states.""" + states = [si_base_state, ar_base_state, fe_fcc_state] + concatenated = concatenate_states(states) + split_states = split_state(concatenated) + for state, sub_state in zip(states, split_states, strict=True): + assert isinstance(sub_state, BaseState) + assert torch.allclose(sub_state.positions, state.positions) + assert torch.allclose(sub_state.masses, state.masses) + assert torch.allclose(sub_state.cell, state.cell) + assert torch.allclose(sub_state.atomic_numbers, state.atomic_numbers) + assert torch.allclose(sub_state.batch, 
state.batch) + + assert len(states) == 3 + + +def test_pop_states( + si_base_state: BaseState, ar_base_state: BaseState, fe_fcc_state: BaseState +) -> None: + """Test popping states from a state.""" + states = [si_base_state, ar_base_state, fe_fcc_state] + concatenated_states = concatenate_states(states) + kept_state, popped_states = pop_states(concatenated_states, torch.tensor([0])) + + assert isinstance(kept_state, BaseState) + assert isinstance(popped_states, list) + assert len(popped_states) == 1 + assert isinstance(popped_states[0], BaseState) + assert popped_states[0].positions.shape == si_base_state.positions.shape + + len_kept = ar_base_state.n_atoms + fe_fcc_state.n_atoms + assert kept_state.positions.shape == (len_kept, 3) + assert kept_state.masses.shape == (len_kept,) + assert kept_state.cell.shape == (2, 3, 3) + assert kept_state.atomic_numbers.shape == (len_kept,) + assert kept_state.batch.shape == (len_kept,) diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py new file mode 100644 index 000000000..6842b974b --- /dev/null +++ b/torchsim/autobatching.py @@ -0,0 +1,551 @@ +"""Utilities for batching and memory management in torchsim.""" + +from collections.abc import Iterator +from itertools import chain +from typing import Literal + +import binpacking +import torch + +from torchsim.models.interface import ModelInterface +from torchsim.state import BaseState, concatenate_states, pop_states, split_state + + +def measure_model_memory_forward(state: BaseState, model: ModelInterface) -> float: + """Measure peak GPU memory usage during a model's forward pass. + + Clears GPU cache, runs a forward pass with the provided state, and measures + the maximum memory allocated during execution. + + Args: + state: Input state to pass to the model. + model: Model to measure memory usage for. + + Returns: + Peak memory usage in gigabytes. 
+ """ + # assert model device is not cpu + if model.device.type == "cpu": + raise ValueError( + "Memory estimation does not make sense on CPU and is unsupported." + ) + + # Clear GPU memory + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + torch.cuda.reset_peak_memory_stats() + + model( + positions=state.positions, + cell=state.cell, + batch=state.batch, + atomic_numbers=state.atomic_numbers, + ) + + return torch.cuda.max_memory_allocated() / 1024**3 # Convert to GB + + +def determine_max_batch_size( + state: BaseState, model: ModelInterface, max_atoms: int = 500_000 +) -> int: + """Determine maximum batch size that fits in GPU memory. + + Uses a Fibonacci sequence to efficiently search for the largest number of + batches that can be processed without running out of GPU memory. + + Args: + state: Base state to replicate for testing. + model: Model to test with. + max_atoms: Upper limit on number of atoms to try (for safety). + + Returns: + Maximum number of batches that fit in GPU memory. + """ + # create a list of integers following the fibonacci sequence + fib = [1, 2] + while fib[-1] < max_atoms: + fib.append(fib[-1] + fib[-2]) + + for i in range(len(fib)): + n_batches = fib[i] + concat_state = concatenate_states([state] * n_batches) + + try: + measure_model_memory_forward(concat_state, model) + except RuntimeError as e: + if "CUDA out of memory" in str(e): + return fib[i - 2] + raise + + return fib[-2] + + +def calculate_memory_scaler( + state: BaseState, + memory_scales_with: Literal["n_atoms_x_density", "n_atoms"] = "n_atoms_x_density", +) -> float: + """Calculate a metric that estimates memory requirements for a state. + + Provides different scaling metrics based on system properties that correlate + with memory usage. + + Args: + state: State to calculate metric for. 
+ memory_scales_with: Type of metric to use: + - "n_atoms": Uses only atom count + - "n_atoms_x_density": Uses atom count multiplied by number density + + Returns: + Calculated metric value. + """ + if state.n_batches > 1: + raise ValueError("State must be a single batch") + if memory_scales_with == "n_atoms": + return state.n_atoms + if memory_scales_with == "n_atoms_x_density": + volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000 + number_density = state.n_atoms / volume.item() + return state.n_atoms * number_density + raise ValueError(f"Invalid metric: {memory_scales_with}") + + +def estimate_max_memory_scaler( + model: ModelInterface, + state_list: list[BaseState], + metric_values: list[float], + max_atoms: int = 500_000, +) -> float: + """Estimate maximum memory scaling metric that fits in GPU memory. + + Tests both minimum and maximum metric states to determine a safe upper bound + for the memory scaling metric. + + Args: + model: Model to test with. + state_list: List of states to test. + metric_values: Corresponding metric values for each state. + max_atoms: Maximum number of atoms to try. + + Returns: + Maximum safe metric value that fits in GPU memory. + """ + metric_values = torch.tensor(metric_values) + + # select one state with the min n_atoms + min_metric = metric_values.min() + max_metric = metric_values.max() + + min_state = state_list[metric_values.argmin()] + max_state = state_list[metric_values.argmax()] + + min_state_max_batches = determine_max_batch_size(min_state, model, max_atoms) + max_state_max_batches = determine_max_batch_size(max_state, model, max_atoms) + + return min(min_state_max_batches * min_metric, max_state_max_batches * max_metric) + + +class ChunkingAutoBatcher: + """Batcher that groups states into bins of similar computational cost. + + Divides a collection of states into batches that can be processed efficiently + without exceeding GPU memory. 
States are grouped based on a memory scaling + metric to maximize GPU utilization. + """ + + def __init__( + self, + states: list[BaseState] | BaseState, + model: ModelInterface, + *, + memory_scales_with: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", + max_memory_scaler: float | None = None, + max_atoms_to_try: int = 500_000, + return_indices: bool = False, + ) -> None: + """Initialize the chunking auto-batcher. + + Args: + states: Collection of states to batch (either a list or a single state + that will be split). + model: Model to batch for, used to estimate memory requirements. + memory_scales_with: Metric to use for estimating memory requirements: + - "n_atoms": Uses only atom count + - "n_atoms_x_density": Uses atom count multiplied by number density + max_memory_scaler: Maximum metric value allowed per batch. If None, + will be automatically estimated. + max_atoms_to_try: Maximum number of atoms to try when estimating + max_memory_scaler. + return_indices: Whether to return original indices along with batches. 
+ """ + self.state_slices = ( + split_state(states) if isinstance(states, BaseState) else states + ) + self.memory_scalers = [ + calculate_memory_scaler(state_slice, memory_scales_with) + for state_slice in self.state_slices + ] + if not max_memory_scaler: + self.max_memory_scaler = estimate_max_memory_scaler( + model, self.state_slices, self.memory_scalers, max_atoms_to_try + ) + print(f"Max metric calculated: {self.max_memory_scaler}") + else: + self.max_memory_scaler = max_memory_scaler + + self.return_indices = return_indices + # verify that no systems are too large + max_metric_value = max(self.memory_scalers) + max_metric_idx = self.memory_scalers.index(max_metric_value) + if max_metric_value > self.max_memory_scaler: + raise ValueError( + f"Max metric of system with index {max_metric_idx} in states: " + f"{max(self.memory_scalers)} is greater than max_metric " + f"{self.max_memory_scaler}, please set a larger max_metric " + f"or run smaller systems metric." + ) + + self.index_to_scaler = dict(enumerate(self.memory_scalers)) + self.index_bins = binpacking.to_constant_volume( + self.index_to_scaler, V_max=self.max_memory_scaler + ) + self.batched_states = [] + for index_bin in self.index_bins: + self.batched_states.append([self.state_slices[i] for i in index_bin]) + self.current_state_bin = 0 + + def next_batch( + self, *, return_indices: bool = False + ) -> BaseState | tuple[list[BaseState], list[int]] | None: + """Get the next batch of states. + + Returns batches sequentially until all states have been processed. + + Args: + return_indices: Whether to return original indices along with the batch. + Overrides the value set during initialization. + + Returns: + - If return_indices is False: The next batch of states, + or None if no more batches. + - If return_indices is True: Tuple of (batch, indices), + or None if no more batches. 
+ """ + # TODO: need to think about how this intersects with reporting too + # TODO: definitely a clever treatment to be done with iterators here + if self.current_state_bin < len(self.batched_states): + state_bin = self.batched_states[self.current_state_bin] + state = concatenate_states(state_bin) + self.current_state_bin += 1 + if return_indices: + return state, self.index_bins[self.current_state_bin - 1] + return state + return None + + def __iter__(self) -> Iterator[BaseState]: + """Return self as an iterator. + + Allows using the batcher in a for loop. + + Returns: + Self as an iterator. + """ + return self + + def __next__(self) -> BaseState: + """Get the next batch for iteration. + + Implements the iterator protocol to allow using the batcher in a for loop. + + Returns: + The next batch of states. + + Raises: + StopIteration: When there are no more batches. + """ + next_batch = self.next_batch(return_indices=self.return_indices) + if next_batch is None: + raise StopIteration + return next_batch + + def restore_original_order(self, batched_states: list[BaseState]) -> list[BaseState]: + """Reorder processed states back to their original sequence. + + Takes states that were processed in batches and restores them to the + original order they were provided in. + + Args: + batched_states: List of state batches to reorder. + + Returns: + States in their original order. + + Raises: + ValueError: If the number of states doesn't match + the number of original indices. 
class HotSwappingAutoBatcher:
    """Batcher that dynamically swaps states based on convergence.

    Optimizes GPU utilization by removing converged states from the batch and
    adding new states to process. This approach is ideal for iterative processes
    where different states may converge at different rates.
    """

    def __init__(
        self,
        states: list[BaseState] | Iterator[BaseState] | BaseState,
        model: ModelInterface,
        memory_scales_with: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density",
        max_memory_scaler: float | None = None,
        max_atoms_to_try: int = 500_000,
    ) -> None:
        """Initialize the hot-swapping auto-batcher.

        Args:
            states: Collection of states to process (list, iterator, or single state
                that will be split).
            model: Model to batch for, used to estimate memory requirements.
            memory_scales_with: Metric to use for estimating memory requirements:
                - "n_atoms": Uses only atom count
                - "n_atoms_x_density": Uses atom count multiplied by number density
            max_memory_scaler: Maximum metric value allowed per batch. If None,
                will be automatically estimated.
            max_atoms_to_try: Maximum number of atoms to try when estimating
                max_memory_scaler.
        """
        # Normalize input into an iterator of single-batch states.
        if isinstance(states, BaseState):
            states = split_state(states)
        if isinstance(states, list):
            states = iter(states)

        self.model = model
        self.states_iterator = states
        self.memory_scales_with = memory_scales_with
        # Falsy values (None or 0) both mean "estimate automatically on first batch".
        self.max_memory_scaler = max_memory_scaler or None
        self.max_atoms_to_try = max_atoms_to_try

        # Parallel lists: memory metric and original (iterator-order) index for
        # every state currently in the live batch.
        self.current_scalers = []
        self.current_idx = []
        # Count of states consumed from the iterator so far; doubles as the
        # next original index to assign.
        self.iterator_idx = 0

        # Original indices of finished states, in completion order; used by
        # restore_original_order.
        self.completed_idx_og_order = []

    def _get_next_states(self) -> list[BaseState]:
        """Add states from the iterator until max_memory_scaler is reached.

        Pulls states from the iterator and adds them to the current batch until
        adding another would exceed the maximum memory scaling metric.

        Returns:
            List of new states added to the batch.
        """
        new_metrics = []
        new_idx = []
        new_states = []
        for state in self.states_iterator:
            metric = calculate_memory_scaler(state, self.memory_scales_with)
            # A single state larger than the cap can never fit in any batch.
            if metric > self.max_memory_scaler:
                raise ValueError(
                    f"State metric {metric} is greater than max_metric "
                    f"{self.max_memory_scaler}, please set a larger max_metric "
                    f"or run smaller systems metric."
                )
            if (
                sum(self.current_scalers) + sum(new_metrics) + metric
                > self.max_memory_scaler
            ):
                # put the state back in the iterator
                self.states_iterator = chain([state], self.states_iterator)
                break

            new_metrics.append(metric)
            new_idx.append(self.iterator_idx)
            new_states.append(state)
            self.iterator_idx += 1

        self.current_scalers.extend(new_metrics)
        self.current_idx.extend(new_idx)

        return new_states

    def _delete_old_states(self, completed_idx: list[int]) -> None:
        """Remove completed states from tracking lists.

        Updates internal tracking of states and their metrics when states are
        completed and removed from processing.

        Args:
            completed_idx: Indices of completed states to remove.
        """
        # Sort in descending order to avoid index shifting problems
        completed_idx.sort(reverse=True)

        # update state tracking lists
        for idx in completed_idx:
            og_idx = self.current_idx.pop(idx)
            self.current_scalers.pop(idx)
            self.completed_idx_og_order.append(og_idx)

    def _first_batch(self) -> tuple[BaseState, list[BaseState]]:
        """Create and return the first batch of states.

        Initializes the batcher by estimating memory requirements if needed
        and creating the first batch of states to process.

        Returns:
            Tuple of (first batch, empty list of completed states).
        """
        # we need to sample a state and use it to estimate the max metric
        # for the first batch
        first_state = next(self.states_iterator)
        first_metric = calculate_memory_scaler(first_state, self.memory_scales_with)
        self.current_scalers += [first_metric]
        self.current_idx += [0]
        self.iterator_idx += 1

        # if max_metric is not set, estimate it from the sample, shrink it by
        # a safety margin, fill the batch, then re-estimate with the full batch
        has_max_metric = bool(self.max_memory_scaler)
        if not has_max_metric:
            self.max_memory_scaler = estimate_max_memory_scaler(
                self.model,
                [first_state],
                [first_metric],
                max_atoms=self.max_atoms_to_try,
            )
            # conservative first guess; refined below once more states are known
            self.max_memory_scaler *= 0.8

        states = self._get_next_states()

        if not has_max_metric:
            self.max_memory_scaler = estimate_max_memory_scaler(
                self.model,
                [first_state, *states],
                self.current_scalers,
                max_atoms=self.max_atoms_to_try,
            )
            print(f"Max metric calculated: {self.max_memory_scaler}")
        return concatenate_states([first_state, *states]), []

    def next_batch(
        self,
        updated_state: BaseState | None,
        convergence_tensor: torch.Tensor | None,
        *,
        return_indices: bool = False,
    ) -> tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]]:
        """Get the next batch of states based on convergence.

        Removes converged states from the batch, adds new states if possible,
        and returns both the updated batch and the completed states.

        Args:
            updated_state: Current state after processing.
            convergence_tensor: Boolean tensor indicating which states have converged.
                If None, assumes this is the first call.
            return_indices: Whether to return original indices along with the batch.

        Returns:
            - If return_indices is False: Tuple of (next_batch, completed_states)
            - If return_indices is True: Tuple of (next_batch, completed_states, indices)

            When no states remain to process, next_batch will be None.
        """
        # TODO: this needs to be refactored to avoid so
        # many split and concatenate operations, we should
        # take the updated_concat_state and pop off
        # the states that have converged. with the pop_states function

        if convergence_tensor is None or updated_state is None:
            # Only legal before any batch has been produced.
            if self.iterator_idx > 0:
                raise ValueError(
                    "A convergence tensor must be provided after the "
                    "first batch has been run."
                )
            return self._first_batch()

        # assert statements helpful for debugging, should be moved to validate fn
        # the first two are most important
        assert len(convergence_tensor) == updated_state.n_batches
        assert len(self.current_idx) == len(self.current_scalers)
        assert len(convergence_tensor.shape) == 1
        assert updated_state.n_batches > 0

        # Descending order so later pops don't shift earlier indices.
        completed_idx = torch.where(convergence_tensor)[0].tolist()
        completed_idx.sort(reverse=True)

        remaining_state, completed_states = pop_states(updated_state, completed_idx)

        self._delete_old_states(completed_idx)
        next_states = self._get_next_states()

        # there are no states left to run, return the completed states
        if not self.current_idx:
            return (
                (None, completed_states, [])
                if return_indices
                else (None, completed_states)
            )

        # concatenate remaining state with next states
        if remaining_state.n_batches > 0:
            next_states = [remaining_state, *next_states]
        next_batch = concatenate_states(next_states)

        if return_indices:
            return next_batch, completed_states, self.current_idx

        return next_batch, completed_states

    def restore_original_order(
        self, completed_states: list[BaseState]
    ) -> list[BaseState]:
        """Reorder completed states back to their original sequence.

        Takes states that were completed in arbitrary order and restores them
        to the original order they were provided in.

        Args:
            completed_states: List of completed states to reorder.

        Returns:
            States in their original order.

        Raises:
            ValueError: If the number of completed states doesn't match the
                number of completed indices.
        """
        # TODO: should act on full states, not state slices

        if len(completed_states) != len(self.completed_idx_og_order):
            raise ValueError(
                f"Number of completed states ({len(completed_states)}) does not match "
                f"number of completed indices ({len(self.completed_idx_og_order)})"
            )

        # Create pairs of (original_index, state)
        indexed_states = list(
            zip(self.completed_idx_og_order, completed_states, strict=True)
        )

        # Sort by original index
        return [state for _, state in sorted(indexed_states, key=lambda x: x[0])]
def split_state(
    state: BaseState,
    ambiguous_handling: Literal["error", "globalize"] = "error",
) -> list[BaseState]:
    """Split a batched state into a list of single-batch states.

    Args:
        state: Batched state to split.
        ambiguous_handling: Passed through to ``infer_property_scope``;
            controls whether ambiguous properties raise ("error") or are
            treated as global ("globalize").

    Returns:
        One BaseState per batch element, in batch order.
    """
    # TODO: make this more efficient
    scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling)

    # Number of atoms in each batch; used to split per-atom tensors.
    batch_sizes = torch.bincount(state.batch).tolist()

    # Global properties are shared unchanged by every split state.
    global_attrs = {name: getattr(state, name) for name in scope["global"]}

    sliced_attrs: dict[str, tuple[torch.Tensor, ...]] = {}

    # Per-atom properties: split along the atom dimension by batch size.
    for attr_name in scope["per_atom"]:
        if attr_name == "batch":
            continue  # rebuilt as zeros for each single-batch state below
        sliced_attrs[attr_name] = torch.split(
            getattr(state, attr_name), batch_sizes, dim=0
        )

    # Per-batch properties: one row per split state.
    for attr_name in scope["per_batch"]:
        sliced_attrs[attr_name] = torch.split(getattr(state, attr_name), 1, dim=0)

    # Hoist loop invariants. NOTE: the previous implementation rebound `state`
    # inside this loop, shadowing the input parameter; avoid that here.
    state_cls = type(state)
    device = state.device
    states = []
    for i in range(state.n_batches):
        states.append(
            state_cls(
                batch=torch.zeros(batch_sizes[i], device=device, dtype=torch.int64),
                **{name: values[i] for name, values in sliced_attrs.items()},
                **global_attrs,
            )
        )

    return states
def pop_states(
    state: BaseState,
    pop_indices: list[int],
    ambiguous_handling: Literal["error", "globalize"] = "error",
) -> tuple[BaseState, list[BaseState]]:
    """Split off the batches at ``pop_indices`` from a batched state.

    Uses boolean masks so the kept and popped sub-states are each built in a
    single pass, minimizing memory operations; only the popped portion is then
    split into individual states.

    Args:
        state: Batched state to pop from.
        pop_indices: Batch indices to remove from ``state``.
        ambiguous_handling: Passed through to ``infer_property_scope``.

    Returns:
        Tuple of (state with remaining batches, list of popped single-batch
        states). Popped states are ordered by their batch index in ``state``.
    """
    if len(pop_indices) == 0:
        return state, []

    # Keep the list parameter intact; work with a tensor copy for torch.isin.
    pop_idx = torch.tensor(pop_indices, device=state.device, dtype=torch.int64)

    scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling)

    # Global properties are shared unchanged by both resulting states.
    global_attrs = {name: getattr(state, name) for name in scope["global"]}

    keep_attrs = {}
    pop_attrs = {}

    # Atom-level mask: True for atoms whose batch survives. Hoisted out of the
    # loop — it is identical for every per-atom property.
    atom_keep_mask = torch.isin(state.batch, pop_idx, invert=True)
    n_popped = len(pop_indices)
    n_kept = state.n_batches - n_popped

    # Per-atom properties: filter by the atom mask.
    for attr_name in scope["per_atom"]:
        attr_value = getattr(state, attr_name)

        if attr_name == "batch":
            # Relabel surviving/popped batches to contiguous 0..n-1 ranges.
            # unique_consecutive is valid because batch indices are sorted.
            _, keep_counts = torch.unique_consecutive(
                attr_value[atom_keep_mask], return_counts=True
            )
            keep_attrs[attr_name] = torch.repeat_interleave(
                torch.arange(n_kept, device=state.device), keep_counts
            )
            _, pop_counts = torch.unique_consecutive(
                attr_value[~atom_keep_mask], return_counts=True
            )
            pop_attrs[attr_name] = torch.repeat_interleave(
                torch.arange(n_popped, device=state.device), pop_counts
            )
            continue

        keep_attrs[attr_name] = attr_value[atom_keep_mask]
        pop_attrs[attr_name] = attr_value[~atom_keep_mask]

    # Batch-level mask: True for batch rows that survive. Also loop-invariant.
    batch_range = torch.arange(state.n_batches, device=state.device)
    batch_keep_mask = torch.isin(batch_range, pop_idx, invert=True)

    # Per-batch properties: select surviving / popped rows.
    for attr_name in scope["per_batch"]:
        attr_value = getattr(state, attr_name)
        keep_attrs[attr_name] = attr_value[batch_keep_mask]
        pop_attrs[attr_name] = attr_value[~batch_keep_mask]

    keep_state = type(state)(**keep_attrs, **global_attrs)
    # NOTE: previously this local was named `pop_states`, shadowing the
    # function itself; renamed for clarity.
    popped_states = split_state(type(state)(**pop_attrs, **global_attrs))
    return keep_state, popped_states
# Process per-atom properties (concatenate) - for prop_name in set().union(*[scope["per_atom"] for scope in property_scopes]): - tensors = [ - getattr(state, prop_name) for state in states if hasattr(state, prop_name) - ] - if tensors: - concatenated[prop_name] = torch.cat(tensors, dim=0) - - # Process per-batch properties (concatenate) - for prop_name in set().union(*[scope["per_batch"] for scope in property_scopes]): - tensors = [ - getattr(state, prop_name) for state in states if hasattr(state, prop_name) - ] - if tensors: - concatenated[prop_name] = torch.cat(tensors, dim=0) + # Initialize result with global properties from first state + concatenated = {prop: getattr(first_state, prop) for prop in global_props} - # Create new batch indices that account for existing batch structure + # Pre-allocate lists for tensors to concatenate + per_atom_tensors = {prop: [] for prop in per_atom_props} + per_batch_tensors = {prop: [] for prop in per_batch_props} new_batch_indices = [] batch_offset = 0 - device = device or states[0].device + # Process all states in a single pass for state in states: - state = state_to_device(state, device) - - # Get the number of unique batches in this state - num_batches = len(torch.unique(state.batch)) - - # For each atom, map its current batch index to a new index with the offset + # Move state to target device if needed + if state.device != target_device: + state = state_to_device(state, target_device) + + # Collect per-atom properties + for prop in per_atom_props: + # if hasattr(state, prop): + per_atom_tensors[prop].append(getattr(state, prop)) + + # Collect per-batch properties + for prop in per_batch_props: + # if hasattr(state, prop): + per_batch_tensors[prop].append(getattr(state, prop)) + + # Update batch indices + num_batches = state.n_batches new_indices = state.batch + batch_offset new_batch_indices.append(new_indices) - - # Update the offset for the next state batch_offset += num_batches - # Concatenate all batch indices + # 
Concatenate collected tensors + for prop, tensors in per_atom_tensors.items(): + # if tensors: + concatenated[prop] = torch.cat(tensors, dim=0) + + for prop, tensors in per_batch_tensors.items(): + # if tensors: + concatenated[prop] = torch.cat(tensors, dim=0) + + # Concatenate batch indices concatenated["batch"] = torch.cat(new_batch_indices) # Create a new instance of the same class diff --git a/torchsim/workflows.py b/torchsim/workflows.py index b1793639b..91eba7503 100644 --- a/torchsim/workflows.py +++ b/torchsim/workflows.py @@ -580,7 +580,7 @@ def get_subcells_to_crystallize( # Convert stoichiometries to composition formulas comps = [] for stoich in stoichs: - comp = dict(zip(elements, stoich, strict=False)) + comp = dict(zip(elements, stoich, strict=True)) comps.append(Composition.from_dict(comp).reduced_formula) restrict_to_compositions = set(comps)