From b8d49fd18656c96c4a68f024ec1b7b547358a360 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Fri, 7 Mar 2025 19:51:34 +0000 Subject: [PATCH 01/36] in progres batching size determination --- torchsim/batching.py | 640 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 640 insertions(+) create mode 100644 torchsim/batching.py diff --git a/torchsim/batching.py b/torchsim/batching.py new file mode 100644 index 000000000..05f73f6e4 --- /dev/null +++ b/torchsim/batching.py @@ -0,0 +1,640 @@ +# %% +import copy +from typing import Literal + +import numpy as np +import torch +import torch.profiler +from ase.build import bulk +from propfolio.utils import composition_to_random_structure +from pymatgen.core import Composition + +from torchsim.models.soft_sphere import SoftSphereModel +from torchsim.optimizers import unit_cell_fire +from torchsim.runners import atoms_to_state, optimize +from torchsim.state import BaseState + + +def pack_soft_sphere( + comp: Composition, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + sigma: float = 2.5, + scale_volume: float = 1.0, +) -> BaseState: + device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + + struct = composition_to_random_structure(comp, scale_volume=scale_volume) + + ss_model = SoftSphereModel( + sigma=sigma, + device=device, + dtype=dtype, + use_neighbor_list=True, + compute_stress=True, + ) + return optimize( + system=struct, + model=ss_model, + optimizer=unit_cell_fire, + ) + + +def load_fairchem_or_mace( + model_path: str, + model_type: Literal["mace", "fairchem"], + device: torch.device, + dtype: torch.dtype, + compute_stress: bool | None = None, +): # TODO: replace with generic model + if model_type == "mace": + from torchsim.models.mace import MaceModel + + model = MaceModel( + model=torch.load(model_path, map_location=device), + device=device, + periodic=True, + compute_force=True, + dtype=dtype, + enable_cueq=True, + compute_stress=compute_stress if compute_stress is not None else False, + ) + elif model_type == "fairchem": + from torchsim.models.fairchem import FairChemModel + + model = FairChemModel( + model=model_path, + dtype=dtype, + compute_stress=compute_stress if compute_stress is not None else True, + ) + else: + raise ValueError(f"Unknown model type: {model_type}") + + return model + + +# def profile_model(model, input_tensors: dict[str, torch.Tensor]): + +# with torch.profiler.profile( +# activities=[torch.profiler.ProfilerActivity.CUDA], +# profile_memory=True, +# record_shapes=True, +# ) as prof: +# with torch.no_grad(): +# model(**input_tensors) + +# print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10)) + + +def measure_model_memory_forward(model, input_tensors: dict[str, torch.Tensor]): + torch.cuda.reset_peak_memory_stats() + + import time + + start = time.perf_counter() + with torch.no_grad(): + model(**input_tensors) + end = time.perf_counter() + + peak_memory = torch.cuda.max_memory_allocated() / 1024**3 # Convert to MB + return {"peak_memory": peak_memory, "time": end - start} + + +def measure_model_memory_optimize(model, state: BaseState): + def converge_forces(state) -> bool: + return torch.all(state.forces < 2e-1) + + optimize_state = optimize( + system=state, + model=model, + optimizer=unit_cell_fire, + convergence_fn=converge_forces, + cell_factor=10000, + ) + torch.cuda.reset_peak_memory_stats() + + import time + + start = time.perf_counter() + model( + positions=optimize_state.positions, + cell=optimize_state.cell, + 
batch=optimize_state.batch, + atomic_numbers=optimize_state.atomic_numbers, + ) + end = time.perf_counter() + + peak_memory = torch.cuda.max_memory_allocated() / 1024**3 # Convert to MB + return {"peak_memory": peak_memory, "time": end - start} + + +# %% +# load model +mace_model_path = "/lambdafs/assets/mace_checkpoints/2024-12-03-mace-mp-alex-0.model" +radsim_model_path = ( + "/lambdafs/assets/radsim_checkpoints/radsim-s-v4/FT-BMG-GN-S-OMat-noquad-cutoff6.pt" +) + +# model = load_fairchem_or_mace( +# model_path=mace_model_path, +# model_type="mace", +# device=torch.device("cuda"), +# dtype=torch.float64, +# compute_stress=True, +# ) + + +# %% +# gc gpu memory +torch.cuda.empty_cache() + + +# %% +profile_model( + model, + { + "positions": state.positions, + "cell": state.cell, + "batch": state.batch, + "atomic_numbers": state.atomic_numbers, + }, +) + + +# %% +import gc + + +gc.collect() +torch.cuda.synchronize() +torch.cuda.empty_cache() +torch.cuda.ipc_collect() + + +atoms = bulk("Al", "fcc", a=4.05).repeat((5, 5, 5)) +state = atoms_to_state([atoms] * 60, model.device, model.dtype) + +# model.forward( +# positions=state.positions, +# cell=state.cell, +# batch=state.batch, +# atomic_numbers=state.atomic_numbers, +# ) +print("state.n_atoms", state.n_atoms) + +stats = measure_model_memory( + model, + { + "positions": state.positions, + "cell": state.cell, + "batch": state.batch, + "atomic_numbers": state.atomic_numbers, + }, +) +print(stats) +del state + + +# %% +# generate arbitrary compositions in Al Fe Mg space +compositions = [ + Composition("Al10Fe10Mg10"), + Composition("Al30"), + Composition("Fe30"), + Composition("Mg30"), + Composition("Al15Fe15"), + Composition("Al15Mg15"), + Composition("Fe15Mg15"), +] +stats_records = [] +for comp in compositions: + state = pack_soft_sphere(comp * 50, model.device, model.dtype) + # number density + n_atoms = state.n_atoms + volume = (state.cell[0, 0, 0] ** 3 / 1000).item() + number_density = n_atoms / volume + stats = measure_model_memory_optimize( + model, + state, + # { + # "positions": state.positions, + # "cell": state.cell, + # "batch": state.batch, + # "atomic_numbers": state.atomic_numbers, + # }, + ) + stats["number_density"] = number_density + stats["composition"] = str(comp.formula) + stats_records.append(stats) + + +# %% +import pandas as pd + + +df = pd.DataFrame.from_records(stats_records) +df + +# use plotly express to plot number density vs peak memory +import plotly.express as px + + +# make plot origin 0,0 +fig = px.scatter(df, x="number_density", y="peak_memory", hover_data=["composition"]) +fig.update_xaxes(range=[0, 100]) +fig.update_yaxes(range=[0, 10]) +fig.show() + + +# %% +# visualize the number density of the periodic table of +# elements with pymatviz + +# ... existing code ... + + +def measure_model_memory_forward(model, input_tensors: dict[str, torch.Tensor]): + # Clear GPU memory + import gc + import time + + gc.collect() + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + torch.cuda.reset_peak_memory_stats() + + start = time.perf_counter() + + model(**input_tensors) + end = time.perf_counter() + + peak_memory = torch.cuda.max_memory_allocated() / 1024**3 # Convert to GB + return {"peak_memory": peak_memory, "time": end - start} + + +def generate_test_sizes(start: int = 100, max_size: int = 100000, factor: float = 1.4): + """Generate a sequence of test sizes with exponential growth. 
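+    Sizes grow geometrically, e.g. roughly 100, 140, 196, ... for the defaults.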
+ + Args: + start: Starting number of atoms + max_size: Maximum number of atoms to test + factor: Multiplicative factor between sizes + + Yields: + int: Next test size + """ + current = start + while current <= max_size: + yield current + current = int(current * factor) + + +def create_bulk_system( + element: str, + structure_type: str, + repeat_size: tuple[int, int, int], + device: torch.device, + dtype: torch.dtype, +) -> BaseState: + """Create a bulk crystal system with the specified parameters. + + Args: + element: Chemical element symbol + structure_type: Crystal structure type (fcc, bcc, etc.) + repeat_size: Tuple of repeat counts along each axis + device: Torch device + dtype: Data type for tensors + + Returns: + BaseState: System state + """ + # Create bulk structure with ASE + atoms = bulk(element, structure_type).repeat(repeat_size) + + # Convert to state + return atoms_to_state([atoms], device, dtype) + + +def calculate_number_density(state: BaseState) -> float: + """Calculate number density in atoms/nm³ for a state. + + Args: + state: System state + + Returns: + float: Number density in atoms/nm³ + """ + n_atoms = state.n_atoms + # Calculate volume in nm³ (convert from ų) + volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000 + return n_atoms / volume.item() + + +def test_model_memory_limit( + model, + element: str = "Al", + structure_type: str = "fcc", + device: torch.device = None, + dtype: torch.dtype = torch.float64, + start_size: int = 100, + max_size: int = 100000, + size_factor: float = 1.4, + density_factors: list[float] = None, + max_memory_gb: float = None, + safety_factor: float = 0.9, +): + """Test model memory limits using bulk systems of increasing size. + + Args: + model: The model to test + element: Element to use for bulk structure + structure_type: Crystal structure type + device: Torch device + dtype: Data type for tensors + start_size: Initial number of atoms to test + max_size: Maximum number of atoms to test + size_factor: Growth factor between test sizes + density_factors: List of factors to scale the lattice constant + max_memory_gb: Maximum GPU memory in GB (defaults to 90% of available) + safety_factor: Factor to reduce max_memory by + + Returns: + dict: Analysis results + """ + # Set defaults + device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + if density_factors is None: + density_factors = [0.7, 0.85, 1.0, 1.15, 1.3] + + # Prepare to store results + results = [] + + stop_loop = False + # Test with different system sizes + # for n_atoms_target in generate_test_sizes(start_size, max_size, size_factor): + for repeat_dim in range(5, 30): + # Determine repeat size to get close to target atom count + # First get atoms per unit cell + atoms = bulk(element, structure_type) + + # Calculate repeat dimensions + # repeat_dim = int(round((n_atoms_target / len(atoms)) ** (1 / 3))) + repeat_size = (repeat_dim, repeat_dim, repeat_dim) + atoms = atoms.repeat(repeat_size) + atoms.positions += np.random.randn(*atoms.positions.shape) * 0.1 + + print(f"Testing system with {len(atoms)} atoms") + + # Test each density factor + for density_factor in density_factors: + # Create bulk system + # For density variation, we'll scale the lattice constant + # Smaller lattice constant = higher density + atoms = copy.deepcopy(atoms) + atoms.set_cell(atoms.get_cell() / density_factor) + + # Convert to state + state = atoms_to_state([atoms], device, dtype) + + # Get actual number of atoms and density + number_density = 
calculate_number_density(state) + + try: + # Measure memory during forward pass + memory_stats = measure_model_memory_forward( + model, + { + "positions": state.positions, + "cell": state.cell, + "batch": state.batch, + "atomic_numbers": state.atomic_numbers, + }, + ) + + # Add to results + results.append( + { + "n_atoms": len(atoms), + "number_density": number_density, + "peak_memory_gb": memory_stats["peak_memory"], + "time": memory_stats["time"], + "element": element, + "structure": structure_type, + "density_factor": density_factor, + "success": True, + } + ) + + except RuntimeError as e: + if "CUDA out of memory" in str(e): + print( + f"Out of memory at {len(atoms)} atoms, density factor {density_factor}" + ) + results.append( + { + "n_atoms": len(atoms), + "number_density": number_density, + "peak_memory_gb": float("nan"), + "time": float("nan"), + "element": element, + "structure": structure_type, + "density_factor": density_factor, + "success": False, + } + ) + # Break inner loop and try smaller atom count + stop_loop = True + break + # Re-raise if it's not an OOM error + raise + + if stop_loop: + break + + return results + + +def analyze_memory_results(results): + """Analyze memory usage results and fit a predictive model. + + Args: + results: List of result dictionaries + max_memory_gb: Maximum GPU memory in GB + + Returns: + dict: Analysis results including interaction term between atoms and density + """ + import numpy as np + import pandas as pd + + # Convert results to DataFrame + df = pd.DataFrame.from_records(results) + + # Only proceed with model fitting if we have enough successful runs + successful_runs = df[df["success"]] + if len(successful_runs) > 3: + try: + from sklearn.linear_model import LinearRegression + from sklearn.preprocessing import PolynomialFeatures + + max_memory_gb = df["peak_memory_gb"].max() + + # Create features including the interaction term + # We want: memory ~ a*n_atoms + b*n_atoms*density + c + X_base = successful_runs[["n_atoms", "number_density"]].values + + # Create interaction term manually + successful_runs["interaction"] = ( + successful_runs["n_atoms"] * successful_runs["number_density"] + ) + X = successful_runs[["n_atoms", "interaction"]].values + y = successful_runs["peak_memory_gb"].values + + # Fit linear model with interaction term + model = LinearRegression() + model.fit(X, y) + + # Calculate estimated maximum atoms at each density + max_atoms_est = {} + for density in sorted(df["number_density"].unique()): + if np.isnan(density): + continue + + # Solve for n_atoms where: + # model.coef_[0] * n_atoms + model.coef_[1] * (n_atoms * density) + model.intercept_ = max_memory_gb + # Rearranging: n_atoms * (model.coef_[0] + model.coef_[1] * density) = max_memory_gb - model.intercept_ + + denominator = model.coef_[0] + model.coef_[1] * density + if denominator > 0: # Avoid division by zero or negative values + max_atoms = (max_memory_gb - model.intercept_) / denominator + max_atoms_est[density] = max(0, int(max_atoms)) + else: + max_atoms_est[density] = 0 + + # Create formula string + formula = ( + f"Memory (GB) = {model.coef_[0]:.2e} * n_atoms + " + f"{model.coef_[1]:.2e} * (n_atoms * density) + " + f"{model.intercept_:.2f}" + ) + + # define a function that calculates the peak memory from the formula + def peak_memory_from_formula(n_atoms, density): + return ( + model.coef_[0] * n_atoms + + model.coef_[1] * (n_atoms * density) + + model.intercept_ + ) + + return { + "fitted_model": model, + "max_atoms_estimated": max_atoms_est, + 
"coef_n_atoms": model.coef_[0], + "coef_interaction": model.coef_[1], + "intercept": model.intercept_, + "formula": formula, + "peak_memory_from_formula": peak_memory_from_formula, + "max_memory_used": max_memory_gb, + } + except ImportError: + return {"error": "sklearn not available for model fitting"} + + return {"error": "Not enough successful runs to fit model"} + + +# %% +# Load your model +model = load_fairchem_or_mace( + model_path=radsim_model_path, + model_type="fairchem", + device=torch.device("cuda"), + dtype=torch.float64, + compute_stress=False, +) + + +# %% +# Test memory limits +results = test_model_memory_limit( + model=model, + element="Al", + structure_type="fcc", + device=torch.device("cuda"), + dtype=torch.float64, + start_size=100, + max_size=50000, + density_factors=[0.8, 1.0, 1.2], +) + + +# %% +fit_results = analyze_memory_results(results["dataframe"], 10) + +all_results = {**results, **fit_results} + +# Print memory scaling formula +# if "formula" in results: +# print(results["formula"]) + +# # Print estimated maximum atoms at different densities +# print("\nEstimated maximum atoms before OOM:") +# for density, max_atoms in results["max_atoms_estimated"].items(): +# print(f" At density {density:.1f}: {max_atoms:,} atoms") + +# # Visualize results +# fig1, fig2 = visualize_memory_results(results) +# if fig1: +# fig1.show() +# if fig2: +# fig2.show() +fit_results + + +# %% +fit_results + + +# %% +df = results["dataframe"] + +df["number_density_x_n_atoms"] = df["number_density"] * df["n_atoms"] + +# plot peak_memory vs n_atoms +import plotly.express as px + + +fig = px.scatter( + df, + x="number_density_x_n_atoms", + y="peak_memory_gb", + title="MACE Memory Usage vs Number Density x Number of Atoms", +) +# width 800 +fig.update_layout(width=600) +fig.show() + + +# %% +fig = px.scatter(df, x="number_density", y="peak_memory_gb") +fig.show() + + +# %% +fig = px.scatter( + df, + x="n_atoms", + y="peak_memory_gb", + title="Fairchem Memory Usage vs Number of Atoms", +) + +fig.update_layout(width=600, xaxis_range=[0, 5000]) +fig.show() + + +# %% +df + + +# %% +results From f63b8dad7daad6e2e188635b6781846f0099d9f7 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sat, 8 Mar 2025 22:32:27 +0000 Subject: [PATCH 02/36] add autobatching logic --- torchsim/batching.py | 640 ------------------------------------- torchsim/batching_utils.py | 385 ++++++++++++++++++++++ 2 files changed, 385 insertions(+), 640 deletions(-) delete mode 100644 torchsim/batching.py create mode 100644 torchsim/batching_utils.py diff --git a/torchsim/batching.py b/torchsim/batching.py deleted file mode 100644 index 05f73f6e4..000000000 --- a/torchsim/batching.py +++ /dev/null @@ -1,640 +0,0 @@ -# %% -import copy -from typing import Literal - -import numpy as np -import torch -import torch.profiler -from ase.build import bulk -from propfolio.utils import composition_to_random_structure -from pymatgen.core import Composition - -from torchsim.models.soft_sphere import SoftSphereModel -from torchsim.optimizers import unit_cell_fire -from torchsim.runners import atoms_to_state, optimize -from torchsim.state import BaseState - - -def pack_soft_sphere( - comp: Composition, - device: torch.device | None = None, - dtype: torch.dtype = torch.float64, - sigma: float = 2.5, - scale_volume: float = 1.0, -) -> BaseState: - device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") - - struct = composition_to_random_structure(comp, scale_volume=scale_volume) - - ss_model = SoftSphereModel( - 
sigma=sigma, - device=device, - dtype=dtype, - use_neighbor_list=True, - compute_stress=True, - ) - return optimize( - system=struct, - model=ss_model, - optimizer=unit_cell_fire, - ) - - -def load_fairchem_or_mace( - model_path: str, - model_type: Literal["mace", "fairchem"], - device: torch.device, - dtype: torch.dtype, - compute_stress: bool | None = None, -): # TODO: replace with generic model - if model_type == "mace": - from torchsim.models.mace import MaceModel - - model = MaceModel( - model=torch.load(model_path, map_location=device), - device=device, - periodic=True, - compute_force=True, - dtype=dtype, - enable_cueq=True, - compute_stress=compute_stress if compute_stress is not None else False, - ) - elif model_type == "fairchem": - from torchsim.models.fairchem import FairChemModel - - model = FairChemModel( - model=model_path, - dtype=dtype, - compute_stress=compute_stress if compute_stress is not None else True, - ) - else: - raise ValueError(f"Unknown model type: {model_type}") - - return model - - -# def profile_model(model, input_tensors: dict[str, torch.Tensor]): - -# with torch.profiler.profile( -# activities=[torch.profiler.ProfilerActivity.CUDA], -# profile_memory=True, -# record_shapes=True, -# ) as prof: -# with torch.no_grad(): -# model(**input_tensors) - -# print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10)) - - -def measure_model_memory_forward(model, input_tensors: dict[str, torch.Tensor]): - torch.cuda.reset_peak_memory_stats() - - import time - - start = time.perf_counter() - with torch.no_grad(): - model(**input_tensors) - end = time.perf_counter() - - peak_memory = torch.cuda.max_memory_allocated() / 1024**3 # Convert to MB - return {"peak_memory": peak_memory, "time": end - start} - - -def measure_model_memory_optimize(model, state: BaseState): - def converge_forces(state) -> bool: - return torch.all(state.forces < 2e-1) - - optimize_state = optimize( - system=state, - model=model, - optimizer=unit_cell_fire, - convergence_fn=converge_forces, - cell_factor=10000, - ) - torch.cuda.reset_peak_memory_stats() - - import time - - start = time.perf_counter() - model( - positions=optimize_state.positions, - cell=optimize_state.cell, - batch=optimize_state.batch, - atomic_numbers=optimize_state.atomic_numbers, - ) - end = time.perf_counter() - - peak_memory = torch.cuda.max_memory_allocated() / 1024**3 # Convert to MB - return {"peak_memory": peak_memory, "time": end - start} - - -# %% -# load model -mace_model_path = "/lambdafs/assets/mace_checkpoints/2024-12-03-mace-mp-alex-0.model" -radsim_model_path = ( - "/lambdafs/assets/radsim_checkpoints/radsim-s-v4/FT-BMG-GN-S-OMat-noquad-cutoff6.pt" -) - -# model = load_fairchem_or_mace( -# model_path=mace_model_path, -# model_type="mace", -# device=torch.device("cuda"), -# dtype=torch.float64, -# compute_stress=True, -# ) - - -# %% -# gc gpu memory -torch.cuda.empty_cache() - - -# %% -profile_model( - model, - { - "positions": state.positions, - "cell": state.cell, - "batch": state.batch, - "atomic_numbers": state.atomic_numbers, - }, -) - - -# %% -import gc - - -gc.collect() -torch.cuda.synchronize() -torch.cuda.empty_cache() -torch.cuda.ipc_collect() - - -atoms = bulk("Al", "fcc", a=4.05).repeat((5, 5, 5)) -state = atoms_to_state([atoms] * 60, model.device, model.dtype) - -# model.forward( -# positions=state.positions, -# cell=state.cell, -# batch=state.batch, -# atomic_numbers=state.atomic_numbers, -# ) -print("state.n_atoms", state.n_atoms) - -stats = measure_model_memory( - model, - { - 
"positions": state.positions, - "cell": state.cell, - "batch": state.batch, - "atomic_numbers": state.atomic_numbers, - }, -) -print(stats) -del state - - -# %% -# generate arbitrary compositions in Al Fe Mg space -compositions = [ - Composition("Al10Fe10Mg10"), - Composition("Al30"), - Composition("Fe30"), - Composition("Mg30"), - Composition("Al15Fe15"), - Composition("Al15Mg15"), - Composition("Fe15Mg15"), -] -stats_records = [] -for comp in compositions: - state = pack_soft_sphere(comp * 50, model.device, model.dtype) - # number density - n_atoms = state.n_atoms - volume = (state.cell[0, 0, 0] ** 3 / 1000).item() - number_density = n_atoms / volume - stats = measure_model_memory_optimize( - model, - state, - # { - # "positions": state.positions, - # "cell": state.cell, - # "batch": state.batch, - # "atomic_numbers": state.atomic_numbers, - # }, - ) - stats["number_density"] = number_density - stats["composition"] = str(comp.formula) - stats_records.append(stats) - - -# %% -import pandas as pd - - -df = pd.DataFrame.from_records(stats_records) -df - -# use plotly express to plot number density vs peak memory -import plotly.express as px - - -# make plot origin 0,0 -fig = px.scatter(df, x="number_density", y="peak_memory", hover_data=["composition"]) -fig.update_xaxes(range=[0, 100]) -fig.update_yaxes(range=[0, 10]) -fig.show() - - -# %% -# visualize the number density of the periodic table of -# elements with pymatviz - -# ... existing code ... - - -def measure_model_memory_forward(model, input_tensors: dict[str, torch.Tensor]): - # Clear GPU memory - import gc - import time - - gc.collect() - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - torch.cuda.reset_peak_memory_stats() - - start = time.perf_counter() - - model(**input_tensors) - end = time.perf_counter() - - peak_memory = torch.cuda.max_memory_allocated() / 1024**3 # Convert to GB - return {"peak_memory": peak_memory, "time": end - start} - - -def generate_test_sizes(start: int = 100, max_size: int = 100000, factor: float = 1.4): - """Generate a sequence of test sizes with exponential growth. - - Args: - start: Starting number of atoms - max_size: Maximum number of atoms to test - factor: Multiplicative factor between sizes - - Yields: - int: Next test size - """ - current = start - while current <= max_size: - yield current - current = int(current * factor) - - -def create_bulk_system( - element: str, - structure_type: str, - repeat_size: tuple[int, int, int], - device: torch.device, - dtype: torch.dtype, -) -> BaseState: - """Create a bulk crystal system with the specified parameters. - - Args: - element: Chemical element symbol - structure_type: Crystal structure type (fcc, bcc, etc.) - repeat_size: Tuple of repeat counts along each axis - device: Torch device - dtype: Data type for tensors - - Returns: - BaseState: System state - """ - # Create bulk structure with ASE - atoms = bulk(element, structure_type).repeat(repeat_size) - - # Convert to state - return atoms_to_state([atoms], device, dtype) - - -def calculate_number_density(state: BaseState) -> float: - """Calculate number density in atoms/nm³ for a state. 
- - Args: - state: System state - - Returns: - float: Number density in atoms/nm³ - """ - n_atoms = state.n_atoms - # Calculate volume in nm³ (convert from ų) - volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000 - return n_atoms / volume.item() - - -def test_model_memory_limit( - model, - element: str = "Al", - structure_type: str = "fcc", - device: torch.device = None, - dtype: torch.dtype = torch.float64, - start_size: int = 100, - max_size: int = 100000, - size_factor: float = 1.4, - density_factors: list[float] = None, - max_memory_gb: float = None, - safety_factor: float = 0.9, -): - """Test model memory limits using bulk systems of increasing size. - - Args: - model: The model to test - element: Element to use for bulk structure - structure_type: Crystal structure type - device: Torch device - dtype: Data type for tensors - start_size: Initial number of atoms to test - max_size: Maximum number of atoms to test - size_factor: Growth factor between test sizes - density_factors: List of factors to scale the lattice constant - max_memory_gb: Maximum GPU memory in GB (defaults to 90% of available) - safety_factor: Factor to reduce max_memory by - - Returns: - dict: Analysis results - """ - # Set defaults - device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") - if density_factors is None: - density_factors = [0.7, 0.85, 1.0, 1.15, 1.3] - - # Prepare to store results - results = [] - - stop_loop = False - # Test with different system sizes - # for n_atoms_target in generate_test_sizes(start_size, max_size, size_factor): - for repeat_dim in range(5, 30): - # Determine repeat size to get close to target atom count - # First get atoms per unit cell - atoms = bulk(element, structure_type) - - # Calculate repeat dimensions - # repeat_dim = int(round((n_atoms_target / len(atoms)) ** (1 / 3))) - repeat_size = (repeat_dim, repeat_dim, repeat_dim) - atoms = atoms.repeat(repeat_size) - atoms.positions += np.random.randn(*atoms.positions.shape) * 0.1 - - print(f"Testing system with {len(atoms)} atoms") - - # Test each density factor - for density_factor in density_factors: - # Create bulk system - # For density variation, we'll scale the lattice constant - # Smaller lattice constant = higher density - atoms = copy.deepcopy(atoms) - atoms.set_cell(atoms.get_cell() / density_factor) - - # Convert to state - state = atoms_to_state([atoms], device, dtype) - - # Get actual number of atoms and density - number_density = calculate_number_density(state) - - try: - # Measure memory during forward pass - memory_stats = measure_model_memory_forward( - model, - { - "positions": state.positions, - "cell": state.cell, - "batch": state.batch, - "atomic_numbers": state.atomic_numbers, - }, - ) - - # Add to results - results.append( - { - "n_atoms": len(atoms), - "number_density": number_density, - "peak_memory_gb": memory_stats["peak_memory"], - "time": memory_stats["time"], - "element": element, - "structure": structure_type, - "density_factor": density_factor, - "success": True, - } - ) - - except RuntimeError as e: - if "CUDA out of memory" in str(e): - print( - f"Out of memory at {len(atoms)} atoms, density factor {density_factor}" - ) - results.append( - { - "n_atoms": len(atoms), - "number_density": number_density, - "peak_memory_gb": float("nan"), - "time": float("nan"), - "element": element, - "structure": structure_type, - "density_factor": density_factor, - "success": False, - } - ) - # Break inner loop and try smaller atom count - stop_loop = True - break - # Re-raise if 
it's not an OOM error - raise - - if stop_loop: - break - - return results - - -def analyze_memory_results(results): - """Analyze memory usage results and fit a predictive model. - - Args: - results: List of result dictionaries - max_memory_gb: Maximum GPU memory in GB - - Returns: - dict: Analysis results including interaction term between atoms and density - """ - import numpy as np - import pandas as pd - - # Convert results to DataFrame - df = pd.DataFrame.from_records(results) - - # Only proceed with model fitting if we have enough successful runs - successful_runs = df[df["success"]] - if len(successful_runs) > 3: - try: - from sklearn.linear_model import LinearRegression - from sklearn.preprocessing import PolynomialFeatures - - max_memory_gb = df["peak_memory_gb"].max() - - # Create features including the interaction term - # We want: memory ~ a*n_atoms + b*n_atoms*density + c - X_base = successful_runs[["n_atoms", "number_density"]].values - - # Create interaction term manually - successful_runs["interaction"] = ( - successful_runs["n_atoms"] * successful_runs["number_density"] - ) - X = successful_runs[["n_atoms", "interaction"]].values - y = successful_runs["peak_memory_gb"].values - - # Fit linear model with interaction term - model = LinearRegression() - model.fit(X, y) - - # Calculate estimated maximum atoms at each density - max_atoms_est = {} - for density in sorted(df["number_density"].unique()): - if np.isnan(density): - continue - - # Solve for n_atoms where: - # model.coef_[0] * n_atoms + model.coef_[1] * (n_atoms * density) + model.intercept_ = max_memory_gb - # Rearranging: n_atoms * (model.coef_[0] + model.coef_[1] * density) = max_memory_gb - model.intercept_ - - denominator = model.coef_[0] + model.coef_[1] * density - if denominator > 0: # Avoid division by zero or negative values - max_atoms = (max_memory_gb - model.intercept_) / denominator - max_atoms_est[density] = max(0, int(max_atoms)) - else: - max_atoms_est[density] = 0 - - # Create formula string - formula = ( - f"Memory (GB) = {model.coef_[0]:.2e} * n_atoms + " - f"{model.coef_[1]:.2e} * (n_atoms * density) + " - f"{model.intercept_:.2f}" - ) - - # define a function that calculates the peak memory from the formula - def peak_memory_from_formula(n_atoms, density): - return ( - model.coef_[0] * n_atoms - + model.coef_[1] * (n_atoms * density) - + model.intercept_ - ) - - return { - "fitted_model": model, - "max_atoms_estimated": max_atoms_est, - "coef_n_atoms": model.coef_[0], - "coef_interaction": model.coef_[1], - "intercept": model.intercept_, - "formula": formula, - "peak_memory_from_formula": peak_memory_from_formula, - "max_memory_used": max_memory_gb, - } - except ImportError: - return {"error": "sklearn not available for model fitting"} - - return {"error": "Not enough successful runs to fit model"} - - -# %% -# Load your model -model = load_fairchem_or_mace( - model_path=radsim_model_path, - model_type="fairchem", - device=torch.device("cuda"), - dtype=torch.float64, - compute_stress=False, -) - - -# %% -# Test memory limits -results = test_model_memory_limit( - model=model, - element="Al", - structure_type="fcc", - device=torch.device("cuda"), - dtype=torch.float64, - start_size=100, - max_size=50000, - density_factors=[0.8, 1.0, 1.2], -) - - -# %% -fit_results = analyze_memory_results(results["dataframe"], 10) - -all_results = {**results, **fit_results} - -# Print memory scaling formula -# if "formula" in results: -# print(results["formula"]) - -# # Print estimated maximum atoms at different 
densities -# print("\nEstimated maximum atoms before OOM:") -# for density, max_atoms in results["max_atoms_estimated"].items(): -# print(f" At density {density:.1f}: {max_atoms:,} atoms") - -# # Visualize results -# fig1, fig2 = visualize_memory_results(results) -# if fig1: -# fig1.show() -# if fig2: -# fig2.show() -fit_results - - -# %% -fit_results - - -# %% -df = results["dataframe"] - -df["number_density_x_n_atoms"] = df["number_density"] * df["n_atoms"] - -# plot peak_memory vs n_atoms -import plotly.express as px - - -fig = px.scatter( - df, - x="number_density_x_n_atoms", - y="peak_memory_gb", - title="MACE Memory Usage vs Number Density x Number of Atoms", -) -# width 800 -fig.update_layout(width=600) -fig.show() - - -# %% -fig = px.scatter(df, x="number_density", y="peak_memory_gb") -fig.show() - - -# %% -fig = px.scatter( - df, - x="n_atoms", - y="peak_memory_gb", - title="Fairchem Memory Usage vs Number of Atoms", -) - -fig.update_layout(width=600, xaxis_range=[0, 5000]) -fig.show() - - -# %% -df - - -# %% -results diff --git a/torchsim/batching_utils.py b/torchsim/batching_utils.py new file mode 100644 index 000000000..d2c86f6fc --- /dev/null +++ b/torchsim/batching_utils.py @@ -0,0 +1,385 @@ +"""Utilities for batching and memory management in torchsim.""" + +from collections.abc import Iterator +from itertools import chain +from typing import Literal + +import binpacking +import torch +from ase.build import bulk + +from torchsim.models.interface import ModelInterface +from torchsim.runners import atoms_to_state +from torchsim.state import BaseState, concatenate_states, slice_substate + + +def measure_model_memory_forward(model: ModelInterface, state: BaseState) -> float: + """Measure peak GPU memory usage during model forward pass. + + Args: + model: The model to measure memory usage for. + state: The input state to pass to the model. + + Returns: + Peak memory usage in GB. + """ + # Clear GPU memory + + # gc.collect() + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + torch.cuda.reset_peak_memory_stats() + + model( + positions=state.positions, + cell=state.cell, + batch=state.batch, + atomic_numbers=state.atomic_numbers, + ) + + return torch.cuda.max_memory_allocated() / 1024**3 # Convert to GB + + +def determine_max_batch_size( + model: ModelInterface, state: BaseState, max_atoms: int = 20000 +) -> int: + """Determine maximum batch size that fits in GPU memory. + + Args: + model: The model to test with. + state: The base state to replicate. + max_atoms: Maximum number of atoms to try. + + Returns: + Maximum number of batches that fit in GPU memory. + """ + # create a list of integers following the fibonacci sequence + fib = [1, 2] + while fib[-1] < max_atoms: + fib.append(fib[-1] + fib[-2]) + + for i in range(len(fib)): + n_batches = fib[i] + concat_state = concatenate_states([state] * n_batches) + + try: + measure_model_memory_forward(model, concat_state) + except RuntimeError as e: + if "CUDA out of memory" in str(e): + return fib[i - 2] + raise + + return fib[-1] + + +def split_state(state: BaseState) -> list[BaseState]: + """Split a state into a list of states, each containing a single batch element.""" + # TODO: make this more efficient + return [slice_substate(state, i) for i in range(state.n_batches)] + + +def calculate_baseline_memory(model: ModelInterface) -> float: + """Calculate baseline memory usage of the model. + + Args: + model: The model to measure baseline memory for. + + Returns: + Baseline memory usage in GB. 
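+
+    Note:
+        The returned value is the intercept of a linear fit of peak memory
+        against atom count, i.e. an estimate of the memory the model occupies
+        independent of system size.
+
+    Example:
+        A minimal usage sketch, assuming a CUDA-capable model wrapper:
+
+            overhead_gb = calculate_baseline_memory(model)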
+ """ + # Create baseline atoms with different sizes + baseline_atoms = [bulk("Al", "fcc").repeat((i, 1, 1)) for i in range(1, 9, 2)] + baseline_states = [ + atoms_to_state(atoms, model.device, model.dtype) for atoms in baseline_atoms + ] + + # Measure memory usage for each state + memory_list = [ + measure_model_memory_forward(model, state) for state in baseline_states + ] + + # Calculate number of atoms in each baseline state + n_atoms_list = [state.n_atoms for state in baseline_states] + + # Convert to tensors + n_atoms_tensor = torch.tensor(n_atoms_list, dtype=torch.float) + memory_tensor = torch.tensor(memory_list, dtype=torch.float) + + # Prepare design matrix (with column of ones for intercept) + X = torch.stack([torch.ones_like(n_atoms_tensor), n_atoms_tensor], dim=1) + + # Solve normal equations + beta = torch.linalg.lstsq(X, memory_tensor.unsqueeze(1)).solution.squeeze() + + # Extract intercept (b) and slope (m) + intercept, _ = beta[0].item(), beta[1].item() + + return intercept + + +def calculate_scaling_metric( + state_slice: BaseState, + metric: Literal["n_atoms_x_density", "n_atoms"] = "n_atoms_x_density", +) -> float: + """Calculate scaling metric for a state. + + Args: + state_slice: The state to calculate metric for. + metric: The type of metric to calculate. + + Returns: + The calculated metric value. + """ + if metric == "n_atoms": + return state_slice.n_atoms + if metric == "n_atoms_x_density": + volume = torch.abs(torch.linalg.det(state_slice.cell[0])) / 1000 + number_density = state_slice.n_atoms / volume.item() + return state_slice.n_atoms * number_density + raise ValueError(f"Invalid metric: {metric}") + + +def estimate_max_metric( + model: ModelInterface, + state_list: list[BaseState], + metric_values: list[float], + max_atoms: int = 20000, +) -> float: + """Estimate maximum metric value that fits in GPU memory. + + Args: + model: The model to test with. + state_list: List of states to test. + metric_values: Corresponding metric values for each state. + max_atoms: Maximum number of atoms to try. + + Returns: + Maximum metric value that fits in GPU memory. + """ + # all_metrics = torch.tensor( + # [calculate_scaling_metric(state_slice, metric) for state_slice in state_list] + # ) + + # select one state with the min n_atoms + min_metric = metric_values.min() + max_metric = metric_values.max() + + min_state = state_list[metric_values.argmin()] + max_state = state_list[metric_values.argmax()] + + min_state_max_batches = determine_max_batch_size(model, min_state, max_atoms) + max_state_max_batches = determine_max_batch_size(model, max_state, max_atoms) + + return min(min_state_max_batches * min_metric, max_state_max_batches * max_metric) + + +class ChunkingAutoBatcher: + """Batcher that chunks states into bins of similar computational cost.""" + + def __init__( + self, + model: ModelInterface, + states: list[BaseState] | BaseState, + metric: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", + max_metric: float | None = None, + max_atoms_to_try: int = 1_000_000, + ) -> None: + """Initialize the batcher. + + Args: + model: The model to batch for. + states: States to batch. + metric: Metric to use for batching. + max_metric: Maximum metric value per batch. + max_atoms_to_try: Maximum number of atoms to try when estimating max_metric. 
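+
+        Example:
+            A sketch of the intended loop; evolve is a placeholder for any
+            batched update function:
+
+                batcher = ChunkingAutoBatcher(model, states, max_metric=1000)
+                results = []
+                while (batch := batcher.next_batch()) is not None:
+                    results.append(evolve(batch))
+                ordered = batcher.restore_original_order(results)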
+ """ + self.state_slices = ( + split_state(states) if isinstance(states, BaseState) else states + ) + self.metrics = [ + calculate_scaling_metric(state_slice, metric) + for state_slice in self.state_slices + ] + if not max_metric: + self.max_metric = estimate_max_metric( + model, self.state_slices, self.metrics, max_atoms_to_try + ) + else: + self.max_metric = max_metric + self.index_to_metric = dict(enumerate(self.metrics)) + self.index_bins = binpacking.to_constant_volume( + self.index_to_metric, V_max=self.max_metric + ) + self.state_bins = [] + for index_bin in self.index_bins: + self.state_bins.append([self.state_slices[i] for i in index_bin]) + self.current_state_bin = 0 + + def next_batch( + self, *, return_indices: bool = False + ) -> BaseState | tuple[BaseState, list[int]] | None: + """Get the next batch of states. + + Args: + return_indices: Whether to return indices along with the batch. + + Returns: + The next batch of states, optionally with indices, or None if no more batches. + """ + # TODO: need to think about how this intersects with reporting too + # TODO: definitely a clever treatment to be done with iterators here + if self.current_state_bin < len(self.state_bins): + state_bin = self.state_bins[self.current_state_bin] + self.current_state_bin += 1 + if return_indices: + return ( + concatenate_states(state_bin), + self.index_bins[self.current_state_bin - 1], + ) + return concatenate_states(state_bin) + return None + + def restore_original_order(self, state_bins: list[BaseState]) -> list[BaseState]: + """Take the state bins and reorder them into a list. + + Args: + state_bins: List of state batches to reorder. + + Returns: + States in their original order. + """ + # Flatten lists + all_states = list(chain.from_iterable(state_bins)) + original_indices = list(chain.from_iterable(self.index_bins)) + + # sort states by original indices + indexed_states = list(zip(original_indices, all_states, strict=False)) + return [state for _, state in sorted(indexed_states)] + + +class HotswappingAutoBatcher: + """Batcher that dynamically swaps states in and out based on convergence.""" + + def __init__( + self, + model: ModelInterface, + states: list[BaseState] | Iterator[BaseState] | BaseState, + metric: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", + max_metric: float | None = None, + max_atoms_to_try: int = 1_000_000, + ) -> None: + """Initialize the batcher. + + Args: + model: The model to batch for. + states: States to batch. + metric: Metric to use for batching. + max_metric: Maximum metric value per batch. + max_atoms_to_try: Maximum number of atoms to try when estimating max_metric. 
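+
+        Example:
+            A sketch of the convergence-driven loop; evolve and converged are
+            placeholders for a batched update and a per-batch boolean check:
+
+                batcher = HotswappingAutoBatcher(model, states)
+                state = batcher.first_batch()
+                while state is not None:
+                    state = evolve(state)
+                    # gather finished sub-states before they are swapped out
+                    state = batcher.next_batch(converged(state))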
+ """ + if isinstance(states, BaseState): + states = split_state(states) + if isinstance(states, list): + states = iter(states) + + self.model = model + self.states_iterator = states + self.metric = metric + if not max_metric: + self.max_metric = estimate_max_metric( + model, self.state_slices, self.metrics, max_atoms_to_try + ) + else: + self.max_metric = max_metric + + self.current_metric = 0 + self.empty_states_iterator = False + self.current_states_list = [] + self.current_metrics_list = [] + self.completed_idx_og_order = [] + + def _insert_next_states(self) -> None: + """Insert states from the iterator until max_metric is reached.""" + if self.empty_states_iterator: + return + while self.current_metric < self.max_metric: + try: + state = next(self.states_iterator) + except StopIteration: + self.empty_states_iterator = True + break + metric = calculate_scaling_metric(state, self.metric) + self.current_metric += metric + self.current_metrics_list += [metric] + self.current_states_list += [state] + + def first_batch(self) -> BaseState: + """Get the first batch of states. + + Returns: + The first batch of states. + """ + self._insert_next_states() + return concatenate_states(self.current_states_list) + + def next_batch(self, convergence_tensor: torch.Tensor) -> BaseState | None: + """Get the next batch of states based on convergence. + + Args: + convergence_tensor: Boolean tensor indicating which states have converged. + + Returns: + The next batch of states. + """ + assert len(convergence_tensor) == len(self.current_states_list) + assert len(convergence_tensor.shape) == 1 + + # find indices of all convergence_tensor elements that are True + completed_idx = list(torch.where(convergence_tensor)[0]) + + # Sort in descending order to avoid index shifting problems + completed_idx.sort(reverse=True) + + # remove states at these indices + for idx in completed_idx: + self.current_states_list.pop(idx) + self.current_metric -= self.current_metrics_list.pop(idx) + self.completed_idx_og_order.append(idx + len(self.completed_idx_og_order)) + + # insert next states + self._insert_next_states() + + if not self.current_states_list: + return None + + return concatenate_states(self.current_states_list) + + def restore_original_order( + self, completed_states: list[BaseState] + ) -> list[BaseState]: + """Take the list of completed states and reconstruct the original order. + + Args: + completed_states: List of completed states to reorder. + + Returns: + States in their original order. + + Raises: + ValueError: If the number of completed states doesn't match + the number of indices. 
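+
+        Example:
+            If states finished in the order [2, 0, 1], passing their final
+            states here returns them re-sorted to the original order
+            [0, 1, 2].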
+ """ + if len(completed_states) != len(self.completed_idx_og_order): + raise ValueError( + f"Number of completed states ({len(completed_states)}) does not match " + f"number of completed indices ({len(self.completed_idx_og_order)})" + ) + + # Create pairs of (original_index, state) + indexed_states = list( + zip(self.completed_idx_og_order, completed_states, strict=False) + ) + + # Sort by original index + return [state for _, state in sorted(indexed_states)] From 8c90ac3ce90d48c281d2ee605d6b39234df3c69e Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sat, 8 Mar 2025 22:32:54 +0000 Subject: [PATCH 03/36] rename batching_utils -> autobatching --- torchsim/{batching_utils.py => autobatching.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename torchsim/{batching_utils.py => autobatching.py} (100%) diff --git a/torchsim/batching_utils.py b/torchsim/autobatching.py similarity index 100% rename from torchsim/batching_utils.py rename to torchsim/autobatching.py From 748dfa18810c9333375b144d5e70b79199f87e7e Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sat, 8 Mar 2025 22:59:07 +0000 Subject: [PATCH 04/36] fix logic for too big states and memory estimation --- torchsim/autobatching.py | 50 ++++++++++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index d2c86f6fc..ab743b30a 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -286,30 +286,22 @@ def __init__( self.model = model self.states_iterator = states self.metric = metric - if not max_metric: - self.max_metric = estimate_max_metric( - model, self.state_slices, self.metrics, max_atoms_to_try - ) - else: - self.max_metric = max_metric + self.max_metric = max_metric or None + self.max_atoms_to_try = max_atoms_to_try self.current_metric = 0 - self.empty_states_iterator = False self.current_states_list = [] self.current_metrics_list = [] self.completed_idx_og_order = [] def _insert_next_states(self) -> None: """Insert states from the iterator until max_metric is reached.""" - if self.empty_states_iterator: - return - while self.current_metric < self.max_metric: - try: - state = next(self.states_iterator) - except StopIteration: - self.empty_states_iterator = True - break + for state in self.states_iterator: metric = calculate_scaling_metric(state, self.metric) + if self.current_metric + metric > self.max_metric: + # put the state back in the iterator + self.states_iterator = chain([state], self.states_iterator) + break self.current_metric += metric self.current_metrics_list += [metric] self.current_states_list += [state] @@ -320,7 +312,35 @@ def first_batch(self) -> BaseState: Returns: The first batch of states. 
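+
+        Note:
+            If max_metric was not provided, it is estimated here from trial
+            forward passes on the first states drawn from the iterator.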
""" + # we need to estimate the max metric for the first batch + first_state = next(self.states_iterator) + first_metric = calculate_scaling_metric(first_state, self.metric) + + # if max_metric is not set, estimate it + has_max_metric = bool(self.max_metric) + if not has_max_metric: + self.max_metric = estimate_max_metric( + self.model, + [first_state], + [first_metric], + max_atoms_to_try=self.max_atoms_to_try, + ) + self.max_metric *= 0.8 + + self.current_metric = first_metric + self.current_states_list = [first_state] + self.current_metrics_list = [first_metric] + self._insert_next_states() + + # update estimate of max metric if it was not set + if not has_max_metric: + self.max_metric = estimate_max_metric( + self.model, + self.current_states_list, + self.current_metrics_list, + max_atoms_to_try=1_000_000, + ) return concatenate_states(self.current_states_list) def next_batch(self, convergence_tensor: torch.Tensor) -> BaseState | None: From 7606bb8fc193ed7564c1fb2fa155ea9a14677c3a Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sat, 8 Mar 2025 23:28:18 +0000 Subject: [PATCH 05/36] add tests and correct chunked autobatcher --- tests/test_autobatching.py | 196 +++++++++++++++++++++++++++++++++++++ torchsim/autobatching.py | 11 +-- 2 files changed, 200 insertions(+), 7 deletions(-) create mode 100644 tests/test_autobatching.py diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py new file mode 100644 index 000000000..bafa8982b --- /dev/null +++ b/tests/test_autobatching.py @@ -0,0 +1,196 @@ +import pytest +import torch + +from torchsim.autobatching import ( + ChunkingAutoBatcher, + HotswappingAutoBatcher, + calculate_scaling_metric, + split_state, + determine_max_batch_size, +) +from torchsim.state import BaseState, concatenate_states + + +def test_calculate_scaling_metric(si_base_state: BaseState): + """Test calculation of scaling metrics for a state.""" + # Test n_atoms metric + n_atoms_metric = calculate_scaling_metric(si_base_state, "n_atoms") + assert n_atoms_metric == si_base_state.n_atoms + + # Test n_atoms_x_density metric + density_metric = calculate_scaling_metric(si_base_state, "n_atoms_x_density") + volume = torch.abs(torch.linalg.det(si_base_state.cell[0])) / 1000 + expected = si_base_state.n_atoms * (si_base_state.n_atoms / volume.item()) + assert pytest.approx(density_metric, rel=1e-5) == expected + + # Test invalid metric + with pytest.raises(ValueError, match="Invalid metric"): + calculate_scaling_metric(si_base_state, "invalid_metric") + + +def test_split_state(si_double_base_state: BaseState): + """Test splitting a batched state into individual states.""" + split_states = split_state(si_double_base_state) + + # Check we get the right number of states + assert len(split_states) == 2 + + # Check each state has the correct properties + for i, state in enumerate(split_states): + assert state.n_batches == 1 + assert torch.all(state.batch == 0) # Each split state should have batch indices reset to 0 + assert state.n_atoms == si_double_base_state.n_atoms // 2 + assert state.positions.shape[0] == si_double_base_state.n_atoms // 2 + assert state.cell.shape[0] == 1 + + +def test_chunking_auto_batcher(si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator): + """Test ChunkingAutoBatcher with different states.""" + # Create a list of states with different sizes + states = [si_base_state, fe_fcc_state] + + # Initialize the batcher with a fixed max_metric to avoid GPU memory testing + batcher = ChunkingAutoBatcher( + model=lj_calculator, + states=states, + 
metric="n_atoms", + max_metric=100.0 # Set a small value to force multiple batches + ) + + # Check that the batcher correctly identified the metrics + assert len(batcher.metrics) == 2 + assert batcher.metrics[0] == si_base_state.n_atoms + assert batcher.metrics[1] == fe_fcc_state.n_atoms + + # Get batches until None is returned + batches = [] + while True: + batch = batcher.next_batch() + if batch is None: + break + batches.append(batch) + + # Check we got the expected number of batches + assert len(batches) == len(batcher.state_bins) + + # Test restore_original_order + restored_states = batcher.restore_original_order(batches) + assert len(restored_states) == len(states) + + # Check that the restored states match the original states in order + # (Note: We can't directly compare the states because they've been through concatenation) + assert restored_states[0].n_atoms == states[0].n_atoms + assert restored_states[1].n_atoms == states[1].n_atoms + + +def test_chunking_auto_batcher_with_indices(si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator): + """Test ChunkingAutoBatcher with return_indices=True.""" + states = [si_base_state, fe_fcc_state] + + batcher = ChunkingAutoBatcher( + model=lj_calculator, + states=states, + metric="n_atoms", + max_metric=100.0 + ) + + # Get batches with indices + batches_with_indices = [] + while True: + result = batcher.next_batch(return_indices=True) + if result is None: + break + batch, indices = result + batches_with_indices.append((batch, indices)) + + # Check we got the expected number of batches + assert len(batches_with_indices) == len(batcher.state_bins) + + # Check that the indices match the expected bin indices + for i, (batch, indices) in enumerate(batches_with_indices): + assert indices == batcher.index_bins[i] + + +def test_hotswapping_auto_batcher(si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator): + """Test HotswappingAutoBatcher with different states.""" + # Create a list of states + states = [si_base_state, fe_fcc_state] + + # Initialize the batcher with a fixed max_metric + batcher = HotswappingAutoBatcher( + model=lj_calculator, + states=states, + metric="n_atoms", + max_metric=100.0 # Set a small value to force multiple batches + ) + + # Get the first batch + first_batch = batcher.first_batch() + assert isinstance(first_batch, BaseState) + + # Create a convergence tensor where the first state has converged + convergence = torch.tensor([True]) + + # Get the next batch + next_batch = batcher.next_batch(convergence) + assert isinstance(next_batch, BaseState) + + # Check that the converged state was removed + assert len(batcher.current_states_list) == 1 + assert len(batcher.completed_idx_og_order) == 1 + + # Create a convergence tensor where the remaining state has converged + convergence = torch.tensor([True]) + + # Get the next batch, which should be None since all states have converged + final_batch = batcher.next_batch(convergence) + assert final_batch is None + + # Check that all states are marked as completed + assert len(batcher.completed_idx_og_order) == 2 + + +def test_determine_max_batch_size_fibonacci(si_base_state: BaseState, lj_calculator, monkeypatch): + """Test that determine_max_batch_size uses Fibonacci sequence correctly.""" + # Mock measure_model_memory_forward to avoid actual GPU memory testing + def mock_measure(*args, **kwargs): + return 0.1 # Return a small constant memory usage + + monkeypatch.setattr("torchsim.autobatching.measure_model_memory_forward", mock_measure) + + # Test with a small 
max_atoms value to limit the sequence + max_size = determine_max_batch_size(lj_calculator, si_base_state, max_atoms=10) + + # The Fibonacci sequence up to 10 is [1, 2, 3, 5, 8, 13] + # Since we're not triggering OOM errors with our mock, it should return the largest value < max_atoms + assert max_size == 8 + + +def test_hotswapping_auto_batcher_restore_order(si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator): + """Test HotswappingAutoBatcher's restore_original_order method.""" + states = [si_base_state, fe_fcc_state] + + batcher = HotswappingAutoBatcher( + model=lj_calculator, + states=states, + metric="n_atoms", + max_metric=100.0 + ) + + # Get the first batch + batcher.first_batch() + + # Simulate convergence of all states + convergence = torch.tensor([True, True]) + batcher.next_batch(convergence) + + # Create some completed states (doesn't matter what they are for this test) + completed_states = [si_base_state, fe_fcc_state] + + # Test restore_original_order + restored_states = batcher.restore_original_order(completed_states) + assert len(restored_states) == 2 + + # Test error when number of states doesn't match + with pytest.raises(ValueError, match="Number of completed states .* does not match"): + batcher.restore_original_order([si_base_state]) \ No newline at end of file diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index ab743b30a..7d7f2e07c 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -70,7 +70,7 @@ def determine_max_batch_size( return fib[i - 2] raise - return fib[-1] + return fib[-2] def split_state(state: BaseState) -> list[BaseState]: @@ -218,7 +218,7 @@ def __init__( def next_batch( self, *, return_indices: bool = False - ) -> BaseState | tuple[BaseState, list[int]] | None: + ) -> list[BaseState] | tuple[list[BaseState], list[int]] | None: """Get the next batch of states. 
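+        Bins are returned as plain lists of states; callers can apply
+        concatenate_states to a bin when a single batched state is needed.
+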
Args: @@ -233,11 +233,8 @@ def next_batch( state_bin = self.state_bins[self.current_state_bin] self.current_state_bin += 1 if return_indices: - return ( - concatenate_states(state_bin), - self.index_bins[self.current_state_bin - 1], - ) - return concatenate_states(state_bin) + return state_bin, self.index_bins[self.current_state_bin - 1] + return state_bin return None def restore_original_order(self, state_bins: list[BaseState]) -> list[BaseState]: From cd0e638cc6109544e3a7bd843e4a1c3898e25dd3 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sat, 8 Mar 2025 23:28:29 +0000 Subject: [PATCH 06/36] lint tests --- tests/test_autobatching.py | 111 ++++++++++++++++++++----------------- 1 file changed, 60 insertions(+), 51 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index bafa8982b..377284e88 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -5,10 +5,10 @@ ChunkingAutoBatcher, HotswappingAutoBatcher, calculate_scaling_metric, - split_state, determine_max_batch_size, + split_state, ) -from torchsim.state import BaseState, concatenate_states +from torchsim.state import BaseState def test_calculate_scaling_metric(si_base_state: BaseState): @@ -16,13 +16,13 @@ def test_calculate_scaling_metric(si_base_state: BaseState): # Test n_atoms metric n_atoms_metric = calculate_scaling_metric(si_base_state, "n_atoms") assert n_atoms_metric == si_base_state.n_atoms - + # Test n_atoms_x_density metric density_metric = calculate_scaling_metric(si_base_state, "n_atoms_x_density") volume = torch.abs(torch.linalg.det(si_base_state.cell[0])) / 1000 expected = si_base_state.n_atoms * (si_base_state.n_atoms / volume.item()) assert pytest.approx(density_metric, rel=1e-5) == expected - + # Test invalid metric with pytest.raises(ValueError, match="Invalid metric"): calculate_scaling_metric(si_base_state, "invalid_metric") @@ -31,37 +31,41 @@ def test_calculate_scaling_metric(si_base_state: BaseState): def test_split_state(si_double_base_state: BaseState): """Test splitting a batched state into individual states.""" split_states = split_state(si_double_base_state) - + # Check we get the right number of states assert len(split_states) == 2 - + # Check each state has the correct properties for i, state in enumerate(split_states): assert state.n_batches == 1 - assert torch.all(state.batch == 0) # Each split state should have batch indices reset to 0 + assert torch.all( + state.batch == 0 + ) # Each split state should have batch indices reset to 0 assert state.n_atoms == si_double_base_state.n_atoms // 2 assert state.positions.shape[0] == si_double_base_state.n_atoms // 2 assert state.cell.shape[0] == 1 -def test_chunking_auto_batcher(si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator): +def test_chunking_auto_batcher( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator +): """Test ChunkingAutoBatcher with different states.""" # Create a list of states with different sizes states = [si_base_state, fe_fcc_state] - + # Initialize the batcher with a fixed max_metric to avoid GPU memory testing batcher = ChunkingAutoBatcher( model=lj_calculator, states=states, metric="n_atoms", - max_metric=100.0 # Set a small value to force multiple batches + max_metric=100.0, # Set a small value to force multiple batches ) - + # Check that the batcher correctly identified the metrics assert len(batcher.metrics) == 2 assert batcher.metrics[0] == si_base_state.n_atoms assert batcher.metrics[1] == fe_fcc_state.n_atoms - + # Get batches until None is returned 
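+    # (each bin is yielded once; None signals that all bins are exhausted)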
batches = [] while True: @@ -69,31 +73,30 @@ def test_chunking_auto_batcher(si_base_state: BaseState, fe_fcc_state: BaseState if batch is None: break batches.append(batch) - + # Check we got the expected number of batches assert len(batches) == len(batcher.state_bins) - + # Test restore_original_order restored_states = batcher.restore_original_order(batches) assert len(restored_states) == len(states) - + # Check that the restored states match the original states in order # (Note: We can't directly compare the states because they've been through concatenation) assert restored_states[0].n_atoms == states[0].n_atoms assert restored_states[1].n_atoms == states[1].n_atoms -def test_chunking_auto_batcher_with_indices(si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator): +def test_chunking_auto_batcher_with_indices( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator +): """Test ChunkingAutoBatcher with return_indices=True.""" states = [si_base_state, fe_fcc_state] - + batcher = ChunkingAutoBatcher( - model=lj_calculator, - states=states, - metric="n_atoms", - max_metric=100.0 + model=lj_calculator, states=states, metric="n_atoms", max_metric=100.0 ) - + # Get batches with indices batches_with_indices = [] while True: @@ -102,95 +105,101 @@ def test_chunking_auto_batcher_with_indices(si_base_state: BaseState, fe_fcc_sta break batch, indices = result batches_with_indices.append((batch, indices)) - + # Check we got the expected number of batches assert len(batches_with_indices) == len(batcher.state_bins) - + # Check that the indices match the expected bin indices for i, (batch, indices) in enumerate(batches_with_indices): assert indices == batcher.index_bins[i] -def test_hotswapping_auto_batcher(si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator): +def test_hotswapping_auto_batcher( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator +): """Test HotswappingAutoBatcher with different states.""" # Create a list of states states = [si_base_state, fe_fcc_state] - + # Initialize the batcher with a fixed max_metric batcher = HotswappingAutoBatcher( model=lj_calculator, states=states, metric="n_atoms", - max_metric=100.0 # Set a small value to force multiple batches + max_metric=100.0, # Set a small value to force multiple batches ) - + # Get the first batch first_batch = batcher.first_batch() assert isinstance(first_batch, BaseState) - + # Create a convergence tensor where the first state has converged convergence = torch.tensor([True]) - + # Get the next batch next_batch = batcher.next_batch(convergence) assert isinstance(next_batch, BaseState) - + # Check that the converged state was removed assert len(batcher.current_states_list) == 1 assert len(batcher.completed_idx_og_order) == 1 - + # Create a convergence tensor where the remaining state has converged convergence = torch.tensor([True]) - + # Get the next batch, which should be None since all states have converged final_batch = batcher.next_batch(convergence) assert final_batch is None - + # Check that all states are marked as completed assert len(batcher.completed_idx_og_order) == 2 -def test_determine_max_batch_size_fibonacci(si_base_state: BaseState, lj_calculator, monkeypatch): +def test_determine_max_batch_size_fibonacci( + si_base_state: BaseState, lj_calculator, monkeypatch +): """Test that determine_max_batch_size uses Fibonacci sequence correctly.""" + # Mock measure_model_memory_forward to avoid actual GPU memory testing def mock_measure(*args, **kwargs): return 0.1 # Return a small 
constant memory usage - - monkeypatch.setattr("torchsim.autobatching.measure_model_memory_forward", mock_measure) - + + monkeypatch.setattr( + "torchsim.autobatching.measure_model_memory_forward", mock_measure + ) + # Test with a small max_atoms value to limit the sequence max_size = determine_max_batch_size(lj_calculator, si_base_state, max_atoms=10) - + # The Fibonacci sequence up to 10 is [1, 2, 3, 5, 8, 13] # Since we're not triggering OOM errors with our mock, it should return the largest value < max_atoms assert max_size == 8 -def test_hotswapping_auto_batcher_restore_order(si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator): +def test_hotswapping_auto_batcher_restore_order( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator +): """Test HotswappingAutoBatcher's restore_original_order method.""" states = [si_base_state, fe_fcc_state] - + batcher = HotswappingAutoBatcher( - model=lj_calculator, - states=states, - metric="n_atoms", - max_metric=100.0 + model=lj_calculator, states=states, metric="n_atoms", max_metric=100.0 ) - + # Get the first batch batcher.first_batch() - + # Simulate convergence of all states convergence = torch.tensor([True, True]) batcher.next_batch(convergence) - + # Create some completed states (doesn't matter what they are for this test) completed_states = [si_base_state, fe_fcc_state] - + # Test restore_original_order restored_states = batcher.restore_original_order(completed_states) assert len(restored_states) == 2 - + # Test error when number of states doesn't match with pytest.raises(ValueError, match="Number of completed states .* does not match"): - batcher.restore_original_order([si_base_state]) \ No newline at end of file + batcher.restore_original_order([si_base_state]) From fcc439d0c99f86be62b4dae6b6f0752ea4307133 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sun, 9 Mar 2025 00:42:01 +0000 Subject: [PATCH 07/36] add testing and make APIs more consistent --- tests/test_autobatching.py | 146 +++++++++++++++++++++++++++++-------- torchsim/autobatching.py | 55 ++++++++++++-- 2 files changed, 161 insertions(+), 40 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 377284e88..4613c2c34 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest import torch @@ -8,10 +10,11 @@ determine_max_batch_size, split_state, ) +from torchsim.models.lennard_jones import LennardJonesModel from torchsim.state import BaseState -def test_calculate_scaling_metric(si_base_state: BaseState): +def test_calculate_scaling_metric(si_base_state: BaseState) -> None: """Test calculation of scaling metrics for a state.""" # Test n_atoms metric n_atoms_metric = calculate_scaling_metric(si_base_state, "n_atoms") @@ -28,7 +31,7 @@ def test_calculate_scaling_metric(si_base_state: BaseState): calculate_scaling_metric(si_base_state, "invalid_metric") -def test_split_state(si_double_base_state: BaseState): +def test_split_state(si_double_base_state: BaseState) -> None: """Test splitting a batched state into individual states.""" split_states = split_state(si_double_base_state) @@ -36,19 +39,19 @@ def test_split_state(si_double_base_state: BaseState): assert len(split_states) == 2 # Check each state has the correct properties - for i, state in enumerate(split_states): - assert state.n_batches == 1 + for state in enumerate(split_states): + assert state[1].n_batches == 1 assert torch.all( - state.batch == 0 + state[1].batch == 0 ) # Each split state should 
have batch indices reset to 0 - assert state.n_atoms == si_double_base_state.n_atoms // 2 - assert state.positions.shape[0] == si_double_base_state.n_atoms // 2 - assert state.cell.shape[0] == 1 + assert state[1].n_atoms == si_double_base_state.n_atoms // 2 + assert state[1].positions.shape[0] == si_double_base_state.n_atoms // 2 + assert state[1].cell.shape[0] == 1 def test_chunking_auto_batcher( - si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator -): + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: """Test ChunkingAutoBatcher with different states.""" # Create a list of states with different sizes states = [si_base_state, fe_fcc_state] @@ -58,7 +61,7 @@ def test_chunking_auto_batcher( model=lj_calculator, states=states, metric="n_atoms", - max_metric=100.0, # Set a small value to force multiple batches + max_metric=260.0, # Set a small value to force multiple batches ) # Check that the batcher correctly identified the metrics @@ -82,19 +85,22 @@ def test_chunking_auto_batcher( assert len(restored_states) == len(states) # Check that the restored states match the original states in order - # (Note: We can't directly compare the states because they've been through concatenation) assert restored_states[0].n_atoms == states[0].n_atoms assert restored_states[1].n_atoms == states[1].n_atoms + # Check atomic numbers to verify the correct order + assert torch.all(restored_states[0].atomic_numbers == states[0].atomic_numbers) + assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) + def test_chunking_auto_batcher_with_indices( - si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator -): + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: """Test ChunkingAutoBatcher with return_indices=True.""" states = [si_base_state, fe_fcc_state] batcher = ChunkingAutoBatcher( - model=lj_calculator, states=states, metric="n_atoms", max_metric=100.0 + model=lj_calculator, states=states, metric="n_atoms", max_metric=260.0 ) # Get batches with indices @@ -110,13 +116,74 @@ def test_chunking_auto_batcher_with_indices( assert len(batches_with_indices) == len(batcher.state_bins) # Check that the indices match the expected bin indices - for i, (batch, indices) in enumerate(batches_with_indices): + for i, (_, indices) in enumerate(batches_with_indices): assert indices == batcher.index_bins[i] +def test_chunking_auto_batcher_restore_order_with_split_states( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: + """Test ChunkingAutoBatcher's restore_original_order method with split states.""" + # Create a list of states with different sizes + states = [si_base_state, fe_fcc_state] + + # Initialize the batcher with a fixed max_metric to avoid GPU memory testing + batcher = ChunkingAutoBatcher( + model=lj_calculator, + states=states, + metric="n_atoms", + max_metric=260.0, # Set a small value to force multiple batches + ) + + # Get batches until None is returned + batches = [] + while True: + batch = batcher.next_batch() + if batch is None: + break + # Split each batch into individual states to simulate processing + # split_batch = split_state(batch) + batches.append(batch) + + # Test restore_original_order with split states + # This tests the chain.from_iterable functionality + restored_states = batcher.restore_original_order(batches) + + # Check we got the right number of states back + assert len(restored_states) == len(states) + + # 
Check that the restored states match the original states in order + assert restored_states[0].n_atoms == states[0].n_atoms + assert restored_states[1].n_atoms == states[1].n_atoms + + # Check atomic numbers to verify the correct order + assert torch.all(restored_states[0].atomic_numbers == states[0].atomic_numbers) + assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) + + +def test_hotswapping_max_metric_too_small( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: + """Test HotswappingAutoBatcher with different states.""" + # Create a list of states + states = [si_base_state, fe_fcc_state] + + # Initialize the batcher with a fixed max_metric + batcher = HotswappingAutoBatcher( + model=lj_calculator, + states=states, + metric="n_atoms", + max_metric=1.0, # Set a small value to force multiple batches + ) + + # Get the first batch + with pytest.raises(ValueError, match="is greater than max_metric"): + batcher.first_batch() + + def test_hotswapping_auto_batcher( - si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator -): + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: """Test HotswappingAutoBatcher with different states.""" # Create a list of states states = [si_base_state, fe_fcc_state] @@ -126,7 +193,7 @@ def test_hotswapping_auto_batcher( model=lj_calculator, states=states, metric="n_atoms", - max_metric=100.0, # Set a small value to force multiple batches + max_metric=260, # Set a small value to force multiple batches ) # Get the first batch @@ -137,8 +204,9 @@ def test_hotswapping_auto_batcher( convergence = torch.tensor([True]) # Get the next batch - next_batch = batcher.next_batch(convergence) - assert isinstance(next_batch, BaseState) + next_batch, idx = batcher.next_batch(convergence, return_indices=True) + assert isinstance(next_batch, list) + assert idx == [1] # Check that the converged state was removed assert len(batcher.current_states_list) == 1 @@ -156,12 +224,12 @@ def test_hotswapping_auto_batcher( def test_determine_max_batch_size_fibonacci( - si_base_state: BaseState, lj_calculator, monkeypatch -): + si_base_state: BaseState, lj_calculator: LennardJonesModel, monkeypatch: Any +) -> None: """Test that determine_max_batch_size uses Fibonacci sequence correctly.""" # Mock measure_model_memory_forward to avoid actual GPU memory testing - def mock_measure(*args, **kwargs): + def mock_measure(*_args: Any, **_kwargs: Any) -> float: return 0.1 # Return a small constant memory usage monkeypatch.setattr( @@ -172,25 +240,29 @@ def mock_measure(*args, **kwargs): max_size = determine_max_batch_size(lj_calculator, si_base_state, max_atoms=10) # The Fibonacci sequence up to 10 is [1, 2, 3, 5, 8, 13] - # Since we're not triggering OOM errors with our mock, it should return the largest value < max_atoms + # Since we're not triggering OOM errors with our mock, it should + # return the largest value < max_atoms assert max_size == 8 def test_hotswapping_auto_batcher_restore_order( - si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator -): + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: """Test HotswappingAutoBatcher's restore_original_order method.""" states = [si_base_state, fe_fcc_state] batcher = HotswappingAutoBatcher( - model=lj_calculator, states=states, metric="n_atoms", max_metric=100.0 + model=lj_calculator, states=states, metric="n_atoms", max_metric=260.0 ) # Get the first batch 
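    # (Illustrative note: first_batch() packs as many states as fit under
    # max_metric into a single concatenated BaseState; convergence is only
    # checked on subsequent next_batch() calls.)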
batcher.first_batch()
 
     # Simulate convergence of all states
-    convergence = torch.tensor([True, True])
+    convergence = torch.tensor([True])
+    batcher.next_batch(convergence)
+
+    # sample batch a second time
     batcher.next_batch(convergence)
 
     # Create some completed states (doesn't matter what they are for this test)
@@ -200,6 +272,16 @@ def mock_measure(*args, **kwargs):
     restored_states = batcher.restore_original_order(completed_states)
     assert len(restored_states) == 2
 
-    # Test error when number of states doesn't match
-    with pytest.raises(ValueError, match="Number of completed states .* does not match"):
-        batcher.restore_original_order([si_base_state])
+    # Check that the restored states match the original states in order
+    assert restored_states[0].n_atoms == states[0].n_atoms
+    assert restored_states[1].n_atoms == states[1].n_atoms
+
+    # Check atomic numbers to verify the correct order
+    assert torch.all(restored_states[0].atomic_numbers == states[0].atomic_numbers)
+    assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers)
+
+    # # Test error when number of states doesn't match
+    # with pytest.raises(
+    #     ValueError, match="Number of completed states .* does not match"
+    # ):
+    #     batcher.restore_original_order([si_base_state])
diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py
index 7d7f2e07c..65611c0eb 100644
--- a/torchsim/autobatching.py
+++ b/torchsim/autobatching.py
@@ -207,6 +207,17 @@ def __init__(
             )
         else:
             self.max_metric = max_metric
+
+        # verify that no systems are too large
+        max_metric_value = max(self.metrics)
+        max_metric_idx = self.metrics.index(max_metric_value)
+        if max_metric_value > self.max_metric:
+            raise ValueError(
+                f"Metric of system at index {max_metric_idx} in states "
+                f"({max_metric_value}) is greater than max_metric {self.max_metric}, "
+                "please set a larger max_metric or run smaller systems."
+            )
+
         self.index_to_metric = dict(enumerate(self.metrics))
         self.index_bins = binpacking.to_constant_volume(
             self.index_to_metric, V_max=self.max_metric
         )
@@ -237,7 +248,9 @@ def next_batch(
             return state_bin
         return None
 
-    def restore_original_order(self, state_bins: list[BaseState]) -> list[BaseState]:
+    def restore_original_order(
+        self, state_bins: list[list[BaseState]]
+    ) -> list[BaseState]:
         """Take the state bins and reorder them into a list.
 
         Args:
@@ -246,6 +259,9 @@ def restore_original_order(self, state_bins: list[BaseState]) -> list[BaseState]
         Returns:
             States in their original order.
         """
+        # TODO: need to assert at some point that the input states list
+        # are all batch size 1
+
         # Flatten lists
         all_states = list(chain.from_iterable(state_bins))
         original_indices = list(chain.from_iterable(self.index_bins))
@@ -286,22 +302,36 @@ def __init__(
         self.max_metric = max_metric or None
         self.max_atoms_to_try = max_atoms_to_try
 
-        self.current_metric = 0
+        self.total_metric = 0
+
+        # TODO: could be smarter about making these all together
         self.current_states_list = []
         self.current_metrics_list = []
+        self.current_idx_list = []
+
         self.completed_idx_og_order = []
 
     def _insert_next_states(self) -> None:
         """Insert states from the iterator until max_metric is reached."""
         for state in self.states_iterator:
             metric = calculate_scaling_metric(state, self.metric)
-            if self.current_metric + metric > self.max_metric:
+            if metric > self.max_metric:
+                raise ValueError(
+                    f"State metric {metric} is greater than max_metric "
+                    f"{self.max_metric}, please set a larger max_metric "
+                    "or run smaller systems."
+ ) + if self.total_metric + metric > self.max_metric: # put the state back in the iterator self.states_iterator = chain([state], self.states_iterator) break - self.current_metric += metric + self.total_metric += metric + + # TODO: could be smarter about making these all together self.current_metrics_list += [metric] self.current_states_list += [state] + self.current_idx_list += [self.iterator_idx] + self.iterator_idx += 1 def first_batch(self) -> BaseState: """Get the first batch of states. @@ -324,9 +354,11 @@ def first_batch(self) -> BaseState: ) self.max_metric *= 0.8 - self.current_metric = first_metric + self.total_metric = first_metric self.current_states_list = [first_state] self.current_metrics_list = [first_metric] + self.current_idx_list = [0] + self.iterator_idx = 1 self._insert_next_states() @@ -340,11 +372,14 @@ def first_batch(self) -> BaseState: ) return concatenate_states(self.current_states_list) - def next_batch(self, convergence_tensor: torch.Tensor) -> BaseState | None: + def next_batch( + self, convergence_tensor: torch.Tensor, *, return_indices: bool = False + ) -> list[BaseState] | tuple[list[BaseState], list[int]] | None: """Get the next batch of states based on convergence. Args: convergence_tensor: Boolean tensor indicating which states have converged. + return_indices: Whether to return indices along with the batch. Returns: The next batch of states. @@ -361,8 +396,9 @@ def next_batch(self, convergence_tensor: torch.Tensor) -> BaseState | None: # remove states at these indices for idx in completed_idx: self.current_states_list.pop(idx) - self.current_metric -= self.current_metrics_list.pop(idx) + self.total_metric -= self.current_metrics_list.pop(idx) self.completed_idx_og_order.append(idx + len(self.completed_idx_og_order)) + self.current_idx_list.pop(idx) # insert next states self._insert_next_states() @@ -370,7 +406,10 @@ def next_batch(self, convergence_tensor: torch.Tensor) -> BaseState | None: if not self.current_states_list: return None - return concatenate_states(self.current_states_list) + if return_indices: + return self.current_states_list, self.current_idx_list + + return self.current_states_list def restore_original_order( self, completed_states: list[BaseState] From 6df12f9628f9a6dee27cf6bc3091867ef9dd6cc9 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sun, 9 Mar 2025 01:02:27 +0000 Subject: [PATCH 08/36] small reorganization --- tests/test_autobatching.py | 2 +- torchsim/autobatching.py | 42 +++++++++++++++----------------------- 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 4613c2c34..e15abc616 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -209,7 +209,7 @@ def test_hotswapping_auto_batcher( assert idx == [1] # Check that the converged state was removed - assert len(batcher.current_states_list) == 1 + assert len(batcher.current_state_metric_idx) == 1 assert len(batcher.completed_idx_og_order) == 1 # Create a convergence tensor where the remaining state has converged diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index 65611c0eb..be012c490 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -157,10 +157,6 @@ def estimate_max_metric( Returns: Maximum metric value that fits in GPU memory. 
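
    Example (an illustrative sketch; ``model`` and ``states`` are assumed to
    be a loaded ``ModelInterface`` and a list of single-batch states):

        metrics = [calculate_scaling_metric(s, "n_atoms") for s in states]
        max_metric = estimate_max_metric(model, states, metrics)
        # bins whose summed metric stays under this value should fit in memory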
""" - # all_metrics = torch.tensor( - # [calculate_scaling_metric(state_slice, metric) for state_slice in state_list] - # ) - # select one state with the min n_atoms min_metric = metric_values.min() max_metric = metric_values.max() @@ -304,11 +300,7 @@ def __init__( self.total_metric = 0 - # TODO: could be smarter about making these all together - self.current_states_list = [] - self.current_metrics_list = [] - self.current_idx_list = [] - + self.current_state_metric_idx = [] self.completed_idx_og_order = [] def _insert_next_states(self) -> None: @@ -328,9 +320,7 @@ def _insert_next_states(self) -> None: self.total_metric += metric # TODO: could be smarter about making these all together - self.current_metrics_list += [metric] - self.current_states_list += [state] - self.current_idx_list += [self.iterator_idx] + self.current_state_metric_idx += [(state, metric, self.iterator_idx)] self.iterator_idx += 1 def first_batch(self) -> BaseState: @@ -355,22 +345,22 @@ def first_batch(self) -> BaseState: self.max_metric *= 0.8 self.total_metric = first_metric - self.current_states_list = [first_state] - self.current_metrics_list = [first_metric] - self.current_idx_list = [0] + self.current_state_metric_idx = [(first_state, first_metric, 0)] self.iterator_idx = 1 self._insert_next_states() # update estimate of max metric if it was not set + current_states = [state for state, _, _ in self.current_state_metric_idx] + current_metrics = [metric for _, metric, _ in self.current_state_metric_idx] if not has_max_metric: self.max_metric = estimate_max_metric( self.model, - self.current_states_list, - self.current_metrics_list, + current_states, + current_metrics, max_atoms_to_try=1_000_000, ) - return concatenate_states(self.current_states_list) + return concatenate_states(current_states) def next_batch( self, convergence_tensor: torch.Tensor, *, return_indices: bool = False @@ -384,7 +374,7 @@ def next_batch( Returns: The next batch of states. 
""" - assert len(convergence_tensor) == len(self.current_states_list) + assert len(convergence_tensor) == len(self.current_state_metric_idx) assert len(convergence_tensor.shape) == 1 # find indices of all convergence_tensor elements that are True @@ -395,21 +385,23 @@ def next_batch( # remove states at these indices for idx in completed_idx: - self.current_states_list.pop(idx) - self.total_metric -= self.current_metrics_list.pop(idx) + _, metric, _ = self.current_state_metric_idx.pop(idx) + self.total_metric -= metric self.completed_idx_og_order.append(idx + len(self.completed_idx_og_order)) - self.current_idx_list.pop(idx) # insert next states self._insert_next_states() - if not self.current_states_list: + if not self.current_state_metric_idx: return None + current_states = [state for state, _, _ in self.current_state_metric_idx] + if return_indices: - return self.current_states_list, self.current_idx_list + current_idx = [idx for _, _, idx in self.current_state_metric_idx] + return current_states, current_idx - return self.current_states_list + return current_states def restore_original_order( self, completed_states: list[BaseState] From 29dea77fe41325890605525f44b8158f79e75736 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Mon, 10 Mar 2025 17:48:46 +0000 Subject: [PATCH 09/36] update state to work with hotswapping, split and concat logic needs rethinking --- torchsim/state.py | 81 +++++++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/torchsim/state.py b/torchsim/state.py index ed1b57cc1..462062707 100644 --- a/torchsim/state.py +++ b/torchsim/state.py @@ -310,7 +310,7 @@ def slice_substate( return type(state)(**sliced_attrs) -def concatenate_states( # noqa: C901 +def concatenate_states( states: list[BaseState], device: torch.device | None = None ) -> BaseState: """Concatenate a list of BaseStates into a single BaseState. 
@@ -336,57 +336,56 @@ def concatenate_states( # noqa: C901 if not all(isinstance(state, state_class) for state in states): raise TypeError("All states must be of the same type") - # Categorize properties by scope for each state - property_scopes = [infer_property_scope(state) for state in states] + # Use the target device or default to the first state's device + target_device = device or first_state.device - # Collect all property names across all states - all_props = set() - for scope in property_scopes: - for scope_type in scope.values(): - all_props.update(scope_type) + # Get property scopes from the first state to identify global/per-atom/per-batch properties + first_scope = infer_property_scope(first_state) + global_props = set(first_scope["global"]) + per_atom_props = set(first_scope["per_atom"]) + per_batch_props = set(first_scope["per_batch"]) - # Initialize dictionaries to hold concatenated properties - concatenated = {} + # Initialize result with global properties from first state + concatenated = {prop: getattr(first_state, prop) for prop in global_props} - # Process global properties (take from first state) - for prop_name in property_scopes[0]["global"]: - concatenated[prop_name] = getattr(first_state, prop_name) - - # Process per-atom properties (concatenate) - for prop_name in set().union(*[scope["per_atom"] for scope in property_scopes]): - tensors = [ - getattr(state, prop_name) for state in states if hasattr(state, prop_name) - ] - if tensors: - concatenated[prop_name] = torch.cat(tensors, dim=0) - - # Process per-batch properties (concatenate) - for prop_name in set().union(*[scope["per_batch"] for scope in property_scopes]): - tensors = [ - getattr(state, prop_name) for state in states if hasattr(state, prop_name) - ] - if tensors: - concatenated[prop_name] = torch.cat(tensors, dim=0) - - # Create new batch indices that account for existing batch structure + # Pre-allocate lists for tensors to concatenate + per_atom_tensors = {prop: [] for prop in per_atom_props} + per_batch_tensors = {prop: [] for prop in per_batch_props} new_batch_indices = [] batch_offset = 0 - device = device or states[0].device + # Process all states in a single pass for state in states: - state = state_to_device(state, device) - - # Get the number of unique batches in this state - num_batches = len(torch.unique(state.batch)) - - # For each atom, map its current batch index to a new index with the offset + # Move state to target device if needed + if state.device != target_device: + state = state_to_device(state, target_device) + + # Collect per-atom properties + for prop in per_atom_props: + # if hasattr(state, prop): + per_atom_tensors[prop].append(getattr(state, prop)) + + # Collect per-batch properties + for prop in per_batch_props: + # if hasattr(state, prop): + per_batch_tensors[prop].append(getattr(state, prop)) + + # Update batch indices + num_batches = state.n_batches new_indices = state.batch + batch_offset new_batch_indices.append(new_indices) - - # Update the offset for the next state batch_offset += num_batches - # Concatenate all batch indices + # Concatenate collected tensors + for prop, tensors in per_atom_tensors.items(): + # if tensors: + concatenated[prop] = torch.cat(tensors, dim=0) + + for prop, tensors in per_batch_tensors.items(): + # if tensors: + concatenated[prop] = torch.cat(tensors, dim=0) + + # Concatenate batch indices concatenated["batch"] = torch.cat(new_batch_indices) # Create a new instance of the same class From 78f7a250885b3b6f412509e24833658c1538a0aa Mon Sep 17 
00:00:00 2001 From: orionarcher Date: Mon, 10 Mar 2025 17:49:31 +0000 Subject: [PATCH 10/36] add iterator logic to chunking and fix logic for hot swapping --- torchsim/autobatching.py | 53 +++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index be012c490..8aad996a1 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -157,6 +157,8 @@ def estimate_max_metric( Returns: Maximum metric value that fits in GPU memory. """ + metric_values = torch.tensor(metric_values) + # select one state with the min n_atoms min_metric = metric_values.min() max_metric = metric_values.max() @@ -201,6 +203,7 @@ def __init__( self.max_metric = estimate_max_metric( model, self.state_slices, self.metrics, max_atoms_to_try ) + print(f"Max metric calculated: {self.max_metric}") else: self.max_metric = max_metric @@ -244,6 +247,15 @@ def next_batch( return state_bin return None + def __iter__(self): + return self + + def __next__(self): + next_batch = self.next_batch() + if next_batch is None: + raise StopIteration + return next_batch + def restore_original_order( self, state_bins: list[list[BaseState]] ) -> list[BaseState]: @@ -323,7 +335,7 @@ def _insert_next_states(self) -> None: self.current_state_metric_idx += [(state, metric, self.iterator_idx)] self.iterator_idx += 1 - def first_batch(self) -> BaseState: + def first_batch(self) -> list[BaseState]: """Get the first batch of states. Returns: @@ -340,7 +352,7 @@ def first_batch(self) -> BaseState: self.model, [first_state], [first_metric], - max_atoms_to_try=self.max_atoms_to_try, + max_atoms=self.max_atoms_to_try, ) self.max_metric *= 0.8 @@ -358,13 +370,21 @@ def first_batch(self) -> BaseState: self.model, current_states, current_metrics, - max_atoms_to_try=1_000_000, + max_atoms=self.max_atoms_to_try, ) - return concatenate_states(current_states) + print(f"Max metric calculated: {self.max_metric}") + return current_states def next_batch( - self, convergence_tensor: torch.Tensor, *, return_indices: bool = False - ) -> list[BaseState] | tuple[list[BaseState], list[int]] | None: + self, + updated_concat_state: BaseState, + convergence_tensor: torch.Tensor, + *, + return_indices: bool = False, + ) -> ( + tuple[list[BaseState], list[BaseState]] + | tuple[list[BaseState], list[BaseState], list[int]] + ): """Get the next batch of states based on convergence. Args: @@ -374,6 +394,19 @@ def next_batch( Returns: The next batch of states. 
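
        Example (illustrative driver loop; ``fire_update`` and
        ``convergence_fn`` are assumed to be supplied by the caller):

            states = batcher.first_batch()
            while states:
                state = concatenate_states(states)
                state = fire_update(state)
                states, done = batcher.next_batch(state, convergence_fn(state))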
""" + # TODO: this is bloated and we need to clean this up + # to make it more efficient + states_list = split_state(updated_concat_state) + new_current_state_metric_idx = [] + for i, state in enumerate(states_list): + new_tup = ( + state, + self.current_state_metric_idx[i][1], + self.current_state_metric_idx[i][2], + ) + new_current_state_metric_idx.append(new_tup) + self.current_state_metric_idx = new_current_state_metric_idx + assert len(convergence_tensor) == len(self.current_state_metric_idx) assert len(convergence_tensor.shape) == 1 @@ -384,8 +417,10 @@ def next_batch( completed_idx.sort(reverse=True) # remove states at these indices + completed_states = [] for idx in completed_idx: - _, metric, _ = self.current_state_metric_idx.pop(idx) + state, metric, _ = self.current_state_metric_idx.pop(idx) + completed_states.append(state) self.total_metric -= metric self.completed_idx_og_order.append(idx + len(self.completed_idx_og_order)) @@ -393,7 +428,7 @@ def next_batch( self._insert_next_states() if not self.current_state_metric_idx: - return None + return [], [] current_states = [state for state, _, _ in self.current_state_metric_idx] @@ -401,7 +436,7 @@ def next_batch( current_idx = [idx for _, _, idx in self.current_state_metric_idx] return current_states, current_idx - return current_states + return current_states, completed_states def restore_original_order( self, completed_states: list[BaseState] From 972266d8a1c951c654da40ffa96c9cea2c939a38 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Mon, 10 Mar 2025 17:49:45 +0000 Subject: [PATCH 11/36] update convergence handling in runner --- torchsim/runners.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchsim/runners.py b/torchsim/runners.py index 01eaff823..6af25f26a 100644 --- a/torchsim/runners.py +++ b/torchsim/runners.py @@ -113,7 +113,8 @@ def optimize( system: Input system to optimize (ASE Atoms, Pymatgen Structure, or BaseState) model: Neural network calculator module optimizer: Optimization algorithm function - convergence_fn: Condition for convergence + convergence_fn: Condition for convergence, should return a boolean tensor + of length n_batches unit_system: Unit system for energy tolerance optimizer_kwargs: Additional keyword arguments for optimizer trajectory_reporter: Optional reporter for tracking optimization trajectory @@ -124,9 +125,8 @@ def optimize( """ # TODO: document this behavior if convergence_fn is None: - def convergence_fn(state: BaseState, last_energy: torch.Tensor) -> bool: - return torch.all(last_energy - state.energy < 1e-6 * unit_system.energy) + return last_energy - state.energy < 1e-4 * unit_system.energy # we partially evaluate the function to create a new function with # an optional second argument, this can be set to state later on @@ -144,7 +144,7 @@ def convergence_fn(state: BaseState, last_energy: torch.Tensor) -> bool: step: int = 1 last_energy = state.energy + 1 - while not convergence_fn(state, last_energy): + while not torch.all(convergence_fn(state, last_energy)): last_energy = state.energy state = update_fn(state) From aa317942d2ddeaed25fead235e816daa416592e9 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Mon, 10 Mar 2025 17:50:01 +0000 Subject: [PATCH 12/36] update convergence function in high level api --- examples/4_High_level_api/4.1_high_level_api.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/4_High_level_api/4.1_high_level_api.py b/examples/4_High_level_api/4.1_high_level_api.py index 90baba464..7febcd533 100644 
--- a/examples/4_High_level_api/4.1_high_level_api.py +++ b/examples/4_High_level_api/4.1_high_level_api.py @@ -179,9 +179,8 @@ system=systems, model=mace_model, optimizer=unit_cell_fire, - convergence_fn=lambda state, last_energy: torch.all( - last_energy - state.energy < 1e-6 * MetalUnits.energy - ), + convergence_fn=lambda state, last_energy: last_energy - state.energy + < 1e-6 * MetalUnits.energy, max_steps=10 if os.getenv("CI") else 1000, ) From 7b84830e322df5e6cf8c9237e72571227ce6f41e Mon Sep 17 00:00:00 2001 From: orionarcher Date: Mon, 10 Mar 2025 17:50:19 +0000 Subject: [PATCH 13/36] add an autobatching example script --- .../4_High_level_api/4_2_auto_batching_api.py | 125 ++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 examples/4_High_level_api/4_2_auto_batching_api.py diff --git a/examples/4_High_level_api/4_2_auto_batching_api.py b/examples/4_High_level_api/4_2_auto_batching_api.py new file mode 100644 index 000000000..88adaf2d6 --- /dev/null +++ b/examples/4_High_level_api/4_2_auto_batching_api.py @@ -0,0 +1,125 @@ +# %% +import torch +from ase.build import bulk +from mace.calculators.foundations_models import mace_mp + +from torchsim.integrators import nvt_langevin +from torchsim.models.mace import MaceModel +from torchsim.optimizers import unit_cell_fire +from torchsim.runners import atoms_to_state +from torchsim.autobatching import HotswappingAutoBatcher, ChunkingAutoBatcher, split_state +from torchsim.units import MetalUnits +from torchsim.state import concatenate_states, BaseState + + +si_atoms = bulk("Si", "fcc", a=5.43, cubic=True).repeat((3, 3, 3)) +fe_atoms = bulk("Fe", "fcc", a=5.43, cubic=True).repeat((3, 3, 3)) + +device = torch.device("cuda") + +mace = mace_mp(model="small", return_raw_model=True) +mace_model = MaceModel( + model=mace, + device=device, + periodic=True, + dtype=torch.float64, + compute_force=True, +) + +si_state = atoms_to_state(si_atoms, device=device, dtype=torch.float64) +fe_state = atoms_to_state(fe_atoms, device=device, dtype=torch.float64) + +fire_init, fire_update = unit_cell_fire(mace_model) + +si_fire_state = fire_init(si_state) +fe_fire_state = fire_init(fe_state) + +fire_states = [si_fire_state, fe_fire_state] * 100 +fire_states = [state.clone() for state in fire_states] +for state in fire_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + +# %% + +batcher = HotswappingAutoBatcher( + model=mace_model, + states=fire_states, + metric="n_atoms_x_density", + # max_metric=400_000, + max_metric=100_000, +) + +def convergence_fn(state: BaseState) -> bool: + batch_wise_max_force = torch.zeros(state.n_batches, device=state.device) + max_forces = state.forces.norm(dim=1) + batch_wise_max_force = batch_wise_max_force.scatter_reduce( + dim=0, + index=state.batch, + src=max_forces, + reduce="amax", + ) + return batch_wise_max_force < 1e-1 + + +next_batch = batcher.first_batch() + +# %% +all_completed_states = [] +while True: + state = concatenate_states(next_batch) + + print("Starting new batch.") + # run 10 steps, arbitrary number + for i in range(10): + state = fire_update(state) + + convergence_tensor = convergence_fn(state) + + next_batch, completed_states = batcher.next_batch(state, convergence_tensor) + + print("number of completed states", len(completed_states)) + + if not next_batch: + print("No more batches to run.") + break + + all_completed_states.extend(completed_states) + + +# %% + +nvt_init, nvt_update = nvt_langevin(model=mace_model, dt=0.001, kT=300 * MetalUnits.temperature) + + +si_state 
= atoms_to_state(si_atoms, device=device, dtype=torch.float64) +fe_state = atoms_to_state(fe_atoms, device=device, dtype=torch.float64) + +si_nvt_state = nvt_init(si_state) +fe_nvt_state = nvt_init(fe_state) + +nvt_states = [si_nvt_state, fe_nvt_state] * 100 +nvt_states = [state.clone() for state in nvt_states] +for state in nvt_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + +batcher = ChunkingAutoBatcher( + model=mace_model, + states=nvt_states, + metric="n_atoms_x_density", + max_metric=100_000, +) + +finished_states = [] +for batch in batcher: + print(f"Starting new batch of size {len(batch)}") + full_state = concatenate_states(batch) + for _ in range(100): + + full_state = nvt_update(full_state) + + finished_states.extend(split_state(full_state)) + +# %% +len(finished_states) From a04276cee85786531fb3dfb4ec287b4c8ad897e3 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 11 Mar 2025 14:23:02 +0000 Subject: [PATCH 14/36] add optimized pop state and split state utilities to states --- tests/test_state.py | 59 ++++++++++++++++++++++- torchsim/state.py | 113 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 169 insertions(+), 3 deletions(-) diff --git a/tests/test_state.py b/tests/test_state.py index 27b5526ba..66710989c 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -7,6 +7,8 @@ concatenate_states, infer_property_scope, slice_substate, + split_state, + pop_states, ) from torchsim.unbatched_integrators import MDState @@ -87,7 +89,9 @@ def test_concatenate_two_si_states( assert concatenated.positions.shape == si_double_base_state.positions.shape assert concatenated.masses.shape == si_double_base_state.masses.shape assert concatenated.cell.shape == si_double_base_state.cell.shape - assert concatenated.atomic_numbers.shape == si_double_base_state.atomic_numbers.shape + assert ( + concatenated.atomic_numbers.shape == si_double_base_state.atomic_numbers.shape + ) assert concatenated.batch.shape == si_double_base_state.batch.shape # Check batch indices @@ -195,3 +199,56 @@ def test_concatenate_double_si_and_fe_states( si_slice_1.positions, slice_substate(si_double_base_state, 1).positions ) assert torch.allclose(fe_slice.positions, fe_fcc_state.positions) + + +def test_split_state(si_double_base_state: BaseState) -> None: + """Test splitting a state into a list of states.""" + states = split_state(si_double_base_state) + assert len(states) == si_double_base_state.n_batches + for state in states: + assert isinstance(state, BaseState) + assert state.positions.shape == (8, 3) + assert state.masses.shape == (8,) + assert state.cell.shape == (1, 3, 3) + assert state.atomic_numbers.shape == (8,) + assert torch.allclose(state.batch, torch.zeros_like(state.batch)) + + +def test_split_many_states( + si_base_state: BaseState, ar_base_state: BaseState, fe_fcc_state: BaseState +) -> None: + """Test splitting a state into a list of states.""" + states = [si_base_state, ar_base_state, fe_fcc_state] + concatenated = concatenate_states(states) + split_states = split_state(concatenated) + for state, sub_state in zip(states, split_states): + assert isinstance(sub_state, BaseState) + assert torch.allclose(sub_state.positions, state.positions) + assert torch.allclose(sub_state.masses, state.masses) + assert torch.allclose(sub_state.cell, state.cell) + assert torch.allclose(sub_state.atomic_numbers, state.atomic_numbers) + assert torch.allclose(sub_state.batch, state.batch) + + assert len(states) == 3 + + +def test_pop_states( + si_base_state: BaseState, 
ar_base_state: BaseState, fe_fcc_state: BaseState +) -> None: + """Test popping states from a state.""" + states = [si_base_state, ar_base_state, fe_fcc_state] + concatenated_states = concatenate_states(states) + kept_state, popped_states = pop_states(concatenated_states, torch.tensor([0])) + + assert isinstance(kept_state, BaseState) + assert isinstance(popped_states, list) + assert len(popped_states) == 1 + assert isinstance(popped_states[0], BaseState) + assert popped_states[0].positions.shape == si_base_state.positions.shape + + len_kept = ar_base_state.n_atoms + fe_fcc_state.n_atoms + assert kept_state.positions.shape == (len_kept, 3) + assert kept_state.masses.shape == (len_kept,) + assert kept_state.cell.shape == (2, 3, 3) + assert kept_state.atomic_numbers.shape == (len_kept,) + assert kept_state.batch.shape == (len_kept,) diff --git a/torchsim/state.py b/torchsim/state.py index 462062707..0cb41cb67 100644 --- a/torchsim/state.py +++ b/torchsim/state.py @@ -23,7 +23,7 @@ class ParticleState: import warnings from dataclasses import dataclass, field from typing import TYPE_CHECKING, Literal, Self - +from collections import defaultdict import torch @@ -96,7 +96,9 @@ def __post_init__(self) -> None: ) if self.batch is None: - self.batch = torch.zeros(self.n_atoms, device=self.device, dtype=torch.int64) + self.batch = torch.zeros( + self.n_atoms, device=self.device, dtype=torch.int64 + ) else: # assert that batch indices are unique consecutive integers _, counts = torch.unique_consecutive(self.batch, return_counts=True) @@ -263,6 +265,111 @@ def infer_property_scope( return scope +def split_state( + state: BaseState, + ambiguous_handling: Literal["error", "globalize"] = "error", +) -> list[BaseState]: + """Split a state into a list of states, each containing a single batch element. + This also needs to be optimized.""" + # TODO: make this more efficient + scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling) + + batch_sizes = torch.bincount(state.batch).tolist() + + global_attrs = {} + + # Process global properties (unchanged) + for attr_name in scope["global"]: + global_attrs[attr_name] = getattr(state, attr_name) + + sliced_attrs = {} + + # Process per-atom properties (filter by batch mask) + for attr_name in scope["per_atom"]: + if attr_name == "batch": + continue + attr_value = getattr(state, attr_name) + sliced_attrs[attr_name] = torch.split(attr_value, batch_sizes, dim=0) + + # Process per-batch properties (select the specific batch) + for attr_name in scope["per_batch"]: + attr_value = getattr(state, attr_name) + sliced_attrs[attr_name] = torch.split(attr_value, 1, dim=0) + + states = [] + for i in range(state.n_batches): + state = type(state)( + batch=torch.zeros(batch_sizes[i], device=state.device, dtype=torch.int64), + **{attr_name: sliced_attrs[attr_name][i] for attr_name in sliced_attrs}, + **global_attrs, + ) + states.append(state) + + return states + + +def pop_states( + state: BaseState, + pop_indices: torch.Tensor, + ambiguous_handling: Literal["error", "globalize"] = "error", +) -> tuple[BaseState, list[BaseState]]: + """Pop off the states with masking in a way that + minimizes memory operations. We can use the mask to make the popped + and remaining states in place then split the popped states. + + Infer batchwise atomwise should also be optimized. 
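+
+    Example (sketch mirroring test_pop_states in this patch; ``state`` is a
+    batched BaseState and ``pop_indices`` selects which batches to split off):
+
+        kept, popped = pop_states(state, torch.tensor([0]))
+        # kept holds the remaining batches; popped is a list of
+        # single-batch BaseState objects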
+ """ + scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling) + + # Process global properties (unchanged) + global_attrs = {} + for attr_name in scope["global"]: + global_attrs[attr_name] = getattr(state, attr_name) + + keep_attrs = {} + pop_attrs = {} + + # Process per-atom properties (filter by batch mask) + for attr_name in scope["per_atom"]: + keep_mask = torch.isin(state.batch, pop_indices, invert=True) + attr_value = getattr(state, attr_name) + + if attr_name == "batch": + n_popped = len(pop_indices) + n_kept = state.n_batches - n_popped + _, keep_counts = torch.unique_consecutive( + attr_value[keep_mask], return_counts=True + ) + keep_batch_indices = torch.repeat_interleave( + torch.arange(n_kept, device=state.device), keep_counts + ) + keep_attrs[attr_name] = keep_batch_indices + + _, pop_counts = torch.unique_consecutive( + attr_value[~keep_mask], return_counts=True + ) + pop_batch_indices = torch.repeat_interleave( + torch.arange(n_popped, device=state.device), pop_counts + ) + pop_attrs[attr_name] = pop_batch_indices + continue + + keep_attrs[attr_name] = attr_value[keep_mask] + pop_attrs[attr_name] = attr_value[~keep_mask] + + # Process per-batch properties (select the specific batch) + for attr_name in scope["per_batch"]: + attr_value = getattr(state, attr_name) + batch_range = torch.arange(state.n_batches, device=state.device) + keep_mask = torch.isin(batch_range, pop_indices, invert=True) + keep_attrs[attr_name] = attr_value[keep_mask] + pop_attrs[attr_name] = attr_value[~keep_mask] + + keep_state = type(state)(**keep_attrs, **global_attrs) + pop_states = split_state(type(state)(**pop_attrs, **global_attrs)) + return keep_state, pop_states + + def slice_substate( state: BaseState, batch_index: int, @@ -278,6 +385,8 @@ def slice_substate( Returns: A BaseState object containing the sliced substate """ + # TODO: should share more logic with pop_states, basically the same + # TODO: should be renamed slice_state scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling) # Create a mask for the atoms in the specified batch From 1100eb10476fce1ddbf1df59c2f992cb1e836b91 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 11 Mar 2025 15:59:50 +0000 Subject: [PATCH 15/36] update hot swapping autobatcher to return full state and pop states more efficiently --- tests/test_autobatching.py | 91 ++++++++++++++++++++---- torchsim/autobatching.py | 142 ++++++++++++++++++++++--------------- 2 files changed, 162 insertions(+), 71 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index e15abc616..0322fb37c 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -8,10 +8,10 @@ HotswappingAutoBatcher, calculate_scaling_metric, determine_max_batch_size, - split_state, ) from torchsim.models.lennard_jones import LennardJonesModel -from torchsim.state import BaseState +from torchsim.optimizers import unit_cell_fire +from torchsim.state import BaseState, split_state def test_calculate_scaling_metric(si_base_state: BaseState) -> None: @@ -204,19 +204,24 @@ def test_hotswapping_auto_batcher( convergence = torch.tensor([True]) # Get the next batch - next_batch, idx = batcher.next_batch(convergence, return_indices=True) - assert isinstance(next_batch, list) + next_batch, popped_batch, idx = batcher.next_batch( + first_batch, convergence, return_indices=True + ) + assert isinstance(next_batch, BaseState) + assert isinstance(popped_batch, list) + assert isinstance(popped_batch[0], BaseState) assert idx == [1] # Check 
that the converged state was removed - assert len(batcher.current_state_metric_idx) == 1 + assert len(batcher.current_metrics) == 1 + assert len(batcher.current_idx) == 1 assert len(batcher.completed_idx_og_order) == 1 # Create a convergence tensor where the remaining state has converged convergence = torch.tensor([True]) # Get the next batch, which should be None since all states have converged - final_batch = batcher.next_batch(convergence) + final_batch, popped_batch = batcher.next_batch(next_batch, convergence) assert final_batch is None # Check that all states are marked as completed @@ -256,20 +261,21 @@ def test_hotswapping_auto_batcher_restore_order( ) # Get the first batch - batcher.first_batch() + first_batch = batcher.first_batch() # Simulate convergence of all states + completed_states_list = [] convergence = torch.tensor([True]) - batcher.next_batch(convergence) + next_batch, completed_states = batcher.next_batch(first_batch, convergence) + completed_states_list.extend(completed_states) # sample batch a second time - batcher.next_batch(convergence) - - # Create some completed states (doesn't matter what they are for this test) - completed_states = [si_base_state, fe_fcc_state] + # sample batch a second time + next_batch, completed_states = batcher.next_batch(next_batch, convergence) + completed_states_list.extend(completed_states) # Test restore_original_order - restored_states = batcher.restore_original_order(completed_states) + restored_states = batcher.restore_original_order(completed_states_list) assert len(restored_states) == 2 # Check that the restored states match the original states in order @@ -285,3 +291,62 @@ def test_hotswapping_auto_batcher_restore_order( # ValueError, match="Number of completed states .* does not match" # ): # batcher.restore_original_order([si_base_state]) + + +def test_hotswapping_with_fire( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: + + fire_init, fire_update = unit_cell_fire(lj_calculator) + + si_fire_state = fire_init(si_base_state) + fe_fire_state = fire_init(fe_fcc_state) + + fire_states = [si_fire_state, fe_fire_state] * 5 + fire_states = [state.clone() for state in fire_states] + for state in fire_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + batcher = HotswappingAutoBatcher( + model=lj_calculator, + states=fire_states, + metric="n_atoms", + # max_metric=400_000, + max_metric=600, + ) + + def convergence_fn(state: BaseState) -> bool: + batch_wise_max_force = torch.zeros( + state.n_batches, device=state.device, dtype=torch.float64 + ) + max_forces = state.forces.norm(dim=1) + batch_wise_max_force = batch_wise_max_force.scatter_reduce( + dim=0, + index=state.batch, + src=max_forces, + reduce="amax", + ) + return batch_wise_max_force < 1e-1 + + state = batcher.first_batch() + + all_completed_states = [] + while True: + print("Starting new batch.") + # run 10 steps, arbitrary number + for i in range(10): + state = fire_update(state) + + convergence_tensor = convergence_fn(state) + + state, completed_states = batcher.next_batch(state, convergence_tensor) + + print("number of completed states", len(completed_states)) + + all_completed_states.extend(completed_states) + + if not state: + print("No more batches to run.") + break + + assert len(all_completed_states) == len(fire_states) diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index 8aad996a1..3f877cb84 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -10,7 +10,7 @@ from 
torchsim.models.interface import ModelInterface from torchsim.runners import atoms_to_state -from torchsim.state import BaseState, concatenate_states, slice_substate +from torchsim.state import BaseState, concatenate_states, pop_states, split_state def measure_model_memory_forward(model: ModelInterface, state: BaseState) -> float: @@ -73,12 +73,6 @@ def determine_max_batch_size( return fib[-2] -def split_state(state: BaseState) -> list[BaseState]: - """Split a state into a list of states, each containing a single batch element.""" - # TODO: make this more efficient - return [slice_substate(state, i) for i in range(state.n_batches)] - - def calculate_baseline_memory(model: ModelInterface) -> float: """Calculate baseline memory usage of the model. @@ -237,6 +231,10 @@ def next_batch( Returns: The next batch of states, optionally with indices, or None if no more batches. """ + # TODO: we need to refactor this to operate on the full states rather + # than the state slices, to be aligned with how the hotswapping batcher + # works. + # TODO: need to think about how this intersects with reporting too # TODO: definitely a clever treatment to be done with iterators here if self.current_state_bin < len(self.state_bins): @@ -270,6 +268,8 @@ def restore_original_order( # TODO: need to assert at some point that the input states list # are all batch size 1 + # TODO: should act on full states, not state slices + # Flatten lists all_states = list(chain.from_iterable(state_bins)) original_indices = list(chain.from_iterable(self.index_bins)) @@ -310,13 +310,19 @@ def __init__( self.max_metric = max_metric or None self.max_atoms_to_try = max_atoms_to_try - self.total_metric = 0 + # self.total_metric = 0 + self.current_metrics = [] + self.current_idx = [] + self.iterator_idx = 0 - self.current_state_metric_idx = [] + # self.current_metric_idx = [] self.completed_idx_og_order = [] - def _insert_next_states(self) -> None: + def _get_next_states(self) -> None: """Insert states from the iterator until max_metric is reached.""" + new_metrics = [] + new_states = [] + new_idx = [] for state in self.states_iterator: metric = calculate_scaling_metric(state, self.metric) if metric > self.max_metric: @@ -325,25 +331,34 @@ def _insert_next_states(self) -> None: f"{self.max_metric}, please set a larger max_metric " f"or run smaller systems metric." ) - if self.total_metric + metric > self.max_metric: + # new_metric += sum(new_metrics) + if sum(self.current_metrics) + sum(new_metrics) + metric > self.max_metric: # put the state back in the iterator self.states_iterator = chain([state], self.states_iterator) break - self.total_metric += metric - # TODO: could be smarter about making these all together - self.current_state_metric_idx += [(state, metric, self.iterator_idx)] - self.iterator_idx += 1 + new_metrics.append(metric) + new_states.append(state) + new_idx.append(self.iterator_idx) + # self.total_metric += metric + # self.iterator_idx += 1 - def first_batch(self) -> list[BaseState]: + return new_states, new_metrics, new_idx + + def first_batch(self) -> BaseState: """Get the first batch of states. Returns: The first batch of states. 
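
        Example (illustrative; ``convergence_fn`` is a caller-supplied helper
        returning one bool per batch):

            state = batcher.first_batch()
            state, completed = batcher.next_batch(state, convergence_fn(state))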
""" - # we need to estimate the max metric for the first batch + # we need to sample a state and use it to estimate the max metric + # for the first batch first_state = next(self.states_iterator) first_metric = calculate_scaling_metric(first_state, self.metric) + self.current_metrics += [first_metric] + self.current_idx += [0] + self.iterator_idx += 1 + # self.total_metric += first_metric # if max_metric is not set, estimate it has_max_metric = bool(self.max_metric) @@ -356,87 +371,96 @@ def first_batch(self) -> list[BaseState]: ) self.max_metric *= 0.8 - self.total_metric = first_metric - self.current_state_metric_idx = [(first_state, first_metric, 0)] - self.iterator_idx = 1 - - self._insert_next_states() + states, metrics, idx = self._get_next_states() + self.current_metrics += metrics + self.current_idx += idx + self.iterator_idx += len(idx) # update estimate of max metric if it was not set - current_states = [state for state, _, _ in self.current_state_metric_idx] - current_metrics = [metric for _, metric, _ in self.current_state_metric_idx] + # current_states = [state for state, _, _ in self.current_metric_idx] + # current_metrics = [metric for _, metric, _ in self.current_metric_idx] if not has_max_metric: self.max_metric = estimate_max_metric( self.model, - current_states, - current_metrics, + [first_state] + states, + metrics, max_atoms=self.max_atoms_to_try, ) print(f"Max metric calculated: {self.max_metric}") - return current_states + return concatenate_states([first_state] + states) def next_batch( self, - updated_concat_state: BaseState, + updated_state: BaseState, convergence_tensor: torch.Tensor, *, return_indices: bool = False, ) -> ( - tuple[list[BaseState], list[BaseState]] - | tuple[list[BaseState], list[BaseState], list[int]] + tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]] ): """Get the next batch of states based on convergence. Args: + updated_state: The updated state. convergence_tensor: Boolean tensor indicating which states have converged. return_indices: Whether to return indices along with the batch. Returns: The next batch of states. """ - # TODO: this is bloated and we need to clean this up - # to make it more efficient - states_list = split_state(updated_concat_state) - new_current_state_metric_idx = [] - for i, state in enumerate(states_list): - new_tup = ( - state, - self.current_state_metric_idx[i][1], - self.current_state_metric_idx[i][2], - ) - new_current_state_metric_idx.append(new_tup) - self.current_state_metric_idx = new_current_state_metric_idx + # TODO: this needs to be refactored to avoid so + # many split and concatenate operations, we should + # take the updated_concat_state and pop off + # the states that have converged. 
with the pop_states function - assert len(convergence_tensor) == len(self.current_state_metric_idx) + assert len(convergence_tensor) == len(self.current_metrics) + assert len(self.current_idx) == len(self.current_metrics) assert len(convergence_tensor.shape) == 1 # find indices of all convergence_tensor elements that are True - completed_idx = list(torch.where(convergence_tensor)[0]) + completed_idx_tensor = torch.where(convergence_tensor)[0] + completed_idx = completed_idx_tensor.tolist() # Sort in descending order to avoid index shifting problems completed_idx.sort(reverse=True) - # remove states at these indices - completed_states = [] + # update state tracking lists for idx in completed_idx: - state, metric, _ = self.current_state_metric_idx.pop(idx) - completed_states.append(state) - self.total_metric -= metric - self.completed_idx_og_order.append(idx + len(self.completed_idx_og_order)) + og_idx = self.current_idx.pop(idx) + self.current_metrics.pop(idx) + # self.total_metric -= metric + self.completed_idx_og_order.append( + og_idx + len(self.completed_idx_og_order) + ) - # insert next states - self._insert_next_states() + # pop completed states from updated state + assert updated_state.n_batches > 0 + remaining_state, completed_states = pop_states( + updated_state, completed_idx_tensor + ) - if not self.current_state_metric_idx: - return [], [] + # insert next states + next_states, metrics, idx = self._get_next_states() + self.current_metrics += metrics + self.current_idx += idx + self.iterator_idx += len(idx) + + if not self.current_idx: + return ( + (None, completed_states, []) + if return_indices + else (None, completed_states) + ) - current_states = [state for state, _, _ in self.current_state_metric_idx] + # concatenate remaining state with next states + if remaining_state.n_batches > 0: + next_states = [remaining_state] + next_states + next_batch = concatenate_states(next_states) if return_indices: - current_idx = [idx for _, _, idx in self.current_state_metric_idx] - return current_states, current_idx + return next_batch, completed_states, self.current_idx - return current_states, completed_states + return next_batch, completed_states def restore_original_order( self, completed_states: list[BaseState] @@ -453,6 +477,8 @@ def restore_original_order( ValueError: If the number of completed states doesn't match the number of indices. """ + # TODO: should act on full states, not state slices + if len(completed_states) != len(self.completed_idx_og_order): raise ValueError( f"Number of completed states ({len(completed_states)}) does not match " From 4ddb5c64abb9a4d76448c29efc99a806eb571aa2 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 11 Mar 2025 21:09:48 +0000 Subject: [PATCH 16/36] add improved and more efficient pop_states and split states methods --- torchsim/state.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchsim/state.py b/torchsim/state.py index 0cb41cb67..557981b37 100644 --- a/torchsim/state.py +++ b/torchsim/state.py @@ -319,6 +319,9 @@ def pop_states( Infer batchwise atomwise should also be optimized. 
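A hedged usage sketch of the pop_states helper added here, assuming `states` is a list of single-batch BaseState objects (for example from atoms_to_state) and that the converged systems sit at batch indices 0 and 2; at this point in the series pop_indices is a tensor, and a later patch below relaxes it to a plain list of ints:

import torch

from torchsim.state import concatenate_states, pop_states

# `states` is an assumed list of single-batch BaseState objects.
batched = concatenate_states(states)

# Split off the converged systems: `remaining` keeps the still-running
# systems, `popped` is a list of single-batch states.
remaining, popped = pop_states(batched, torch.tensor([0, 2]))
assert remaining.n_batches == batched.n_batches - 2
assert all(s.n_batches == 1 for s in popped)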
""" + if len(pop_indices) == 0: + return state, [] + scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling) # Process global properties (unchanged) From 0e21fd4c5784c39ee19ddca989fc096e6a761819 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 11 Mar 2025 21:10:11 +0000 Subject: [PATCH 17/36] finish hotswapping autobatching function --- tests/test_autobatching.py | 8 ++-- torchsim/autobatching.py | 82 +++++++++++++++++++------------------- 2 files changed, 46 insertions(+), 44 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 0322fb37c..69199f07d 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -178,7 +178,7 @@ def test_hotswapping_max_metric_too_small( # Get the first batch with pytest.raises(ValueError, match="is greater than max_metric"): - batcher.first_batch() + batcher._first_batch() def test_hotswapping_auto_batcher( @@ -197,7 +197,7 @@ def test_hotswapping_auto_batcher( ) # Get the first batch - first_batch = batcher.first_batch() + first_batch = batcher._first_batch() assert isinstance(first_batch, BaseState) # Create a convergence tensor where the first state has converged @@ -261,7 +261,7 @@ def test_hotswapping_auto_batcher_restore_order( ) # Get the first batch - first_batch = batcher.first_batch() + first_batch = batcher._first_batch() # Simulate convergence of all states completed_states_list = [] @@ -328,7 +328,7 @@ def convergence_fn(state: BaseState) -> bool: ) return batch_wise_max_force < 1e-1 - state = batcher.first_batch() + state = batcher._first_batch() all_completed_states = [] while True: diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index 3f877cb84..638f0e434 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -310,19 +310,17 @@ def __init__( self.max_metric = max_metric or None self.max_atoms_to_try = max_atoms_to_try - # self.total_metric = 0 self.current_metrics = [] self.current_idx = [] self.iterator_idx = 0 - # self.current_metric_idx = [] self.completed_idx_og_order = [] def _get_next_states(self) -> None: """Insert states from the iterator until max_metric is reached.""" new_metrics = [] - new_states = [] new_idx = [] + new_states = [] for state in self.states_iterator: metric = calculate_scaling_metric(state, self.metric) if metric > self.max_metric: @@ -338,14 +336,31 @@ def _get_next_states(self) -> None: break new_metrics.append(metric) - new_states.append(state) new_idx.append(self.iterator_idx) - # self.total_metric += metric - # self.iterator_idx += 1 + new_states.append(state) + self.iterator_idx += 1 + + self.current_metrics.extend(new_metrics) + self.current_idx.extend(new_idx) + + return new_states + + def _delete_old_states(self, completed_idx: torch.Tensor) -> None: + completed_idx = completed_idx.tolist() + + # Sort in descending order to avoid index shifting problems + completed_idx.sort(reverse=True) - return new_states, new_metrics, new_idx + # update state tracking lists + for idx in completed_idx: + og_idx = self.current_idx.pop(idx) + self.current_metrics.pop(idx) + # self.total_metric -= metric + self.completed_idx_og_order.append( + og_idx + len(self.completed_idx_og_order) + ) - def first_batch(self) -> BaseState: + def _first_batch(self) -> BaseState: """Get the first batch of states. 
Returns: @@ -371,28 +386,22 @@ def first_batch(self) -> BaseState: ) self.max_metric *= 0.8 - states, metrics, idx = self._get_next_states() - self.current_metrics += metrics - self.current_idx += idx - self.iterator_idx += len(idx) + states = self._get_next_states() - # update estimate of max metric if it was not set - # current_states = [state for state, _, _ in self.current_metric_idx] - # current_metrics = [metric for _, metric, _ in self.current_metric_idx] if not has_max_metric: self.max_metric = estimate_max_metric( self.model, [first_state] + states, - metrics, + self.current_metrics, max_atoms=self.max_atoms_to_try, ) print(f"Max metric calculated: {self.max_metric}") - return concatenate_states([first_state] + states) + return concatenate_states([first_state] + states), [] def next_batch( self, updated_state: BaseState, - convergence_tensor: torch.Tensor, + convergence_tensor: torch.Tensor | None = None, *, return_indices: bool = False, ) -> ( @@ -413,38 +422,31 @@ def next_batch( # take the updated_concat_state and pop off # the states that have converged. with the pop_states function - assert len(convergence_tensor) == len(self.current_metrics) + if convergence_tensor is None: + if self.iterator_idx > 0: + raise ValueError( + "A convergence tensor must be provided after the " + "first batch has been run." + ) + return self._first_batch() + + # assert statements helpful for debugging, should be moved to validate fn + # the first two are most important + assert len(convergence_tensor) == updated_state.n_batches assert len(self.current_idx) == len(self.current_metrics) assert len(convergence_tensor.shape) == 1 + assert updated_state.n_batches > 0 - # find indices of all convergence_tensor elements that are True completed_idx_tensor = torch.where(convergence_tensor)[0] - completed_idx = completed_idx_tensor.tolist() - # Sort in descending order to avoid index shifting problems - completed_idx.sort(reverse=True) - - # update state tracking lists - for idx in completed_idx: - og_idx = self.current_idx.pop(idx) - self.current_metrics.pop(idx) - # self.total_metric -= metric - self.completed_idx_og_order.append( - og_idx + len(self.completed_idx_og_order) - ) - - # pop completed states from updated state - assert updated_state.n_batches > 0 remaining_state, completed_states = pop_states( updated_state, completed_idx_tensor ) - # insert next states - next_states, metrics, idx = self._get_next_states() - self.current_metrics += metrics - self.current_idx += idx - self.iterator_idx += len(idx) + self._delete_old_states(completed_idx_tensor) + next_states = self._get_next_states() + # there are no states left to run, return the completed states if not self.current_idx: return ( (None, completed_states, []) From 6d8e68ce3b69c44bcbdad3f8fd67407ed74018c9 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 11 Mar 2025 21:38:15 +0000 Subject: [PATCH 18/36] make pop_states take list of ints --- torchsim/state.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchsim/state.py b/torchsim/state.py index 557981b37..2c904ce4e 100644 --- a/torchsim/state.py +++ b/torchsim/state.py @@ -310,7 +310,7 @@ def split_state( def pop_states( state: BaseState, - pop_indices: torch.Tensor, + pop_indices: list[int], ambiguous_handling: Literal["error", "globalize"] = "error", ) -> tuple[BaseState, list[BaseState]]: """Pop off the states with masking in a way that @@ -322,6 +322,8 @@ def pop_states( if len(pop_indices) == 0: return state, [] + pop_indices = torch.tensor(pop_indices, 
device=state.device, dtype=torch.int64) + scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling) # Process global properties (unchanged) From 468e561691b7ba46b9f9704349645a144cf88243 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 11 Mar 2025 21:38:29 +0000 Subject: [PATCH 19/36] update autobatching and example --- .../4_High_level_api/4_2_auto_batching_api.py | 64 +++++++++++-------- torchsim/autobatching.py | 52 +++++++-------- 2 files changed, 64 insertions(+), 52 deletions(-) diff --git a/examples/4_High_level_api/4_2_auto_batching_api.py b/examples/4_High_level_api/4_2_auto_batching_api.py index 88adaf2d6..21ac460d2 100644 --- a/examples/4_High_level_api/4_2_auto_batching_api.py +++ b/examples/4_High_level_api/4_2_auto_batching_api.py @@ -7,7 +7,11 @@ from torchsim.models.mace import MaceModel from torchsim.optimizers import unit_cell_fire from torchsim.runners import atoms_to_state -from torchsim.autobatching import HotswappingAutoBatcher, ChunkingAutoBatcher, split_state +from torchsim.autobatching import ( + HotswappingAutoBatcher, + ChunkingAutoBatcher, + split_state, +) from torchsim.units import MetalUnits from torchsim.state import concatenate_states, BaseState @@ -34,21 +38,14 @@ si_fire_state = fire_init(si_state) fe_fire_state = fire_init(fe_state) -fire_states = [si_fire_state, fe_fire_state] * 100 +fire_states = [si_fire_state, fe_fire_state] * 20 fire_states = [state.clone() for state in fire_states] for state in fire_states: state.positions += torch.randn_like(state.positions) * 0.01 - +len(fire_states) # %% -batcher = HotswappingAutoBatcher( - model=mace_model, - states=fire_states, - metric="n_atoms_x_density", - # max_metric=400_000, - max_metric=100_000, -) def convergence_fn(state: BaseState) -> bool: batch_wise_max_force = torch.zeros(state.n_batches, device=state.device) @@ -62,34 +59,43 @@ def convergence_fn(state: BaseState) -> bool: return batch_wise_max_force < 1e-1 -next_batch = batcher.first_batch() - # %% -all_completed_states = [] +batcher = HotswappingAutoBatcher( + model=mace_model, + states=fire_states, + metric="n_atoms_x_density", + max_metric=400_000, + # max_metric=400_000, +) + +all_completed_states, convergence_tensor = [], None while True: - state = concatenate_states(next_batch) + print(f"Starting new batch of {state.n_batches} states.") + + state, completed_states = batcher.next_batch(state, convergence_tensor) + print("Number of completed states", len(completed_states)) + + all_completed_states.extend(completed_states) + if state is None: + break - print("Starting new batch.") # run 10 steps, arbitrary number for i in range(10): state = fire_update(state) - convergence_tensor = convergence_fn(state) - next_batch, completed_states = batcher.next_batch(state, convergence_tensor) - - print("number of completed states", len(completed_states)) - - if not next_batch: - print("No more batches to run.") - break - all_completed_states.extend(completed_states) +# %% +batcher.restore_original_order(all_completed_states) +# %% +sorted(batcher.completed_idx_og_order) # %% -nvt_init, nvt_update = nvt_langevin(model=mace_model, dt=0.001, kT=300 * MetalUnits.temperature) +nvt_init, nvt_update = nvt_langevin( + model=mace_model, dt=0.001, kT=300 * MetalUnits.temperature +) si_state = atoms_to_state(si_atoms, device=device, dtype=torch.float64) @@ -113,7 +119,6 @@ def convergence_fn(state: BaseState) -> bool: finished_states = [] for batch in batcher: - print(f"Starting new batch of size {len(batch)}") full_state = 
concatenate_states(batch) for _ in range(100): @@ -123,3 +128,10 @@ def convergence_fn(state: BaseState) -> bool: # %% len(finished_states) + + +# %% +t = torch.tensor([1, 1, 3, 3, 3, 3]) +torch.bincount(t) +_, counts = torch.unique_consecutive(t, return_counts=True) +counts diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index 638f0e434..c0288bf0f 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -184,7 +184,7 @@ def __init__( states: States to batch. metric: Metric to use for batching. max_metric: Maximum metric value per batch. - max_atoms_to_try: Maximum number of atoms to try when estimating max_metric. + max_atoms_to_try: Maximum number of atoms to try when estimating max_metric. """ self.state_slices = ( split_state(states) if isinstance(states, BaseState) else states @@ -215,14 +215,14 @@ def __init__( self.index_bins = binpacking.to_constant_volume( self.index_to_metric, V_max=self.max_metric ) - self.state_bins = [] + self.batched_states = [] for index_bin in self.index_bins: - self.state_bins.append([self.state_slices[i] for i in index_bin]) + self.batched_states.append([self.state_slices[i] for i in index_bin]) self.current_state_bin = 0 def next_batch( self, *, return_indices: bool = False - ) -> list[BaseState] | tuple[list[BaseState], list[int]] | None: + ) -> BaseState | tuple[list[BaseState], list[int]] | None: """Get the next batch of states. Args: @@ -237,12 +237,13 @@ def next_batch( # TODO: need to think about how this intersects with reporting too # TODO: definitely a clever treatment to be done with iterators here - if self.current_state_bin < len(self.state_bins): - state_bin = self.state_bins[self.current_state_bin] + if self.current_state_bin < len(self.batched_states): + state_bin = self.batched_states[self.current_state_bin] + state = concatenate_states(state_bin) self.current_state_bin += 1 if return_indices: - return state_bin, self.index_bins[self.current_state_bin - 1] - return state_bin + return state, self.index_bins[self.current_state_bin - 1] + return state return None def __iter__(self): @@ -255,28 +256,31 @@ def __next__(self): return next_batch def restore_original_order( - self, state_bins: list[list[BaseState]] + self, batched_states: list[BaseState] ) -> list[BaseState]: """Take the state bins and reorder them into a list. Args: - state_bins: List of state batches to reorder. + batched_states: List of state batches to reorder. Returns: States in their original order. 
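The reordering documented above is a flatten-and-sort over the stored bin indices. A minimal, runnable illustration with strings standing in for states:

from itertools import chain

# Suppose the bins held original indices [[2, 0], [1]] and the per-bin
# results came back in that order; sorting by the stored indices recovers
# the input order.
index_bins = [[2, 0], [1]]
result_bins = [["c", "a"], ["b"]]
all_results = list(chain.from_iterable(result_bins))
original_indices = list(chain.from_iterable(index_bins))
restored = [r for _, r in sorted(zip(original_indices, all_results))]
assert restored == ["a", "b", "c"]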
""" - # TODO: need to assert at some point that the input states list - # are all batch size 1 - - # TODO: should act on full states, not state slices + state_bins = [split_state(state) for state in batched_states] # Flatten lists all_states = list(chain.from_iterable(state_bins)) original_indices = list(chain.from_iterable(self.index_bins)) + if len(all_states) != len(original_indices): + raise ValueError( + f"Number of states ({len(all_states)}) does not match " + f"number of original indices ({len(original_indices)})" + ) + # sort states by original indices indexed_states = list(zip(original_indices, all_states, strict=False)) - return [state for _, state in sorted(indexed_states)] + return [state for _, state in sorted(indexed_states, key=lambda x: x[0])] class HotswappingAutoBatcher: @@ -345,9 +349,7 @@ def _get_next_states(self) -> None: return new_states - def _delete_old_states(self, completed_idx: torch.Tensor) -> None: - completed_idx = completed_idx.tolist() - + def _delete_old_states(self, completed_idx: list[int]) -> None: # Sort in descending order to avoid index shifting problems completed_idx.sort(reverse=True) @@ -355,10 +357,7 @@ def _delete_old_states(self, completed_idx: torch.Tensor) -> None: for idx in completed_idx: og_idx = self.current_idx.pop(idx) self.current_metrics.pop(idx) - # self.total_metric -= metric - self.completed_idx_og_order.append( - og_idx + len(self.completed_idx_og_order) - ) + self.completed_idx_og_order.append(og_idx) def _first_batch(self) -> BaseState: """Get the first batch of states. @@ -437,13 +436,14 @@ def next_batch( assert len(convergence_tensor.shape) == 1 assert updated_state.n_batches > 0 - completed_idx_tensor = torch.where(convergence_tensor)[0] + completed_idx = torch.where(convergence_tensor)[0].tolist() + completed_idx.sort(reverse=True) remaining_state, completed_states = pop_states( - updated_state, completed_idx_tensor + updated_state, completed_idx ) - self._delete_old_states(completed_idx_tensor) + self._delete_old_states(completed_idx) next_states = self._get_next_states() # there are no states left to run, return the completed states @@ -493,4 +493,4 @@ def restore_original_order( ) # Sort by original index - return [state for _, state in sorted(indexed_states)] + return [state for _, state in sorted(indexed_states, key=lambda x: x[0])] From 69ffc285fc3b924e3da7726f830fab7a5c674a42 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 11 Mar 2025 18:59:54 -0400 Subject: [PATCH 20/36] add binpacking>=1.5.2 to pkg deps --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 23b2b43ae..54e72b592 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ requires-python = ">=3.11" dependencies = [ "ase>=3.24", + "binpacking>=1.5.2", "h5py>=3.12.1", "numpy>=1.26", "tables>=3.10.2", From bac08612001c87168c53b3bea9394c140642180e Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 11 Mar 2025 19:07:45 -0400 Subject: [PATCH 21/36] lint --- ...tching_api.py => 4.2_auto_batching_api.py} | 32 +++++++++++-------- tests/test_autobatching.py | 3 +- tests/test_state.py | 8 ++--- torchsim/autobatching.py | 20 ++++-------- torchsim/runners.py | 1 + torchsim/state.py | 9 +++--- 6 files changed, 35 insertions(+), 38 deletions(-) rename examples/4_High_level_api/{4_2_auto_batching_api.py => 4.2_auto_batching_api.py} (90%) diff --git a/examples/4_High_level_api/4_2_auto_batching_api.py b/examples/4_High_level_api/4.2_auto_batching_api.py similarity index 90% 
rename from examples/4_High_level_api/4_2_auto_batching_api.py rename to examples/4_High_level_api/4.2_auto_batching_api.py index 21ac460d2..069322b05 100644 --- a/examples/4_High_level_api/4_2_auto_batching_api.py +++ b/examples/4_High_level_api/4.2_auto_batching_api.py @@ -1,25 +1,29 @@ +# /// script +# dependencies = [ +# "mace-torch>=0.3.10", +# "pymatgen>=2025.2.18", +# ] +# /// + + # %% import torch from ase.build import bulk from mace.calculators.foundations_models import mace_mp +from torchsim.autobatching import ChunkingAutoBatcher, HotswappingAutoBatcher, split_state from torchsim.integrators import nvt_langevin from torchsim.models.mace import MaceModel from torchsim.optimizers import unit_cell_fire from torchsim.runners import atoms_to_state -from torchsim.autobatching import ( - HotswappingAutoBatcher, - ChunkingAutoBatcher, - split_state, -) +from torchsim.state import BaseState, concatenate_states from torchsim.units import MetalUnits -from torchsim.state import concatenate_states, BaseState si_atoms = bulk("Si", "fcc", a=5.43, cubic=True).repeat((3, 3, 3)) fe_atoms = bulk("Fe", "fcc", a=5.43, cubic=True).repeat((3, 3, 3)) -device = torch.device("cuda") +device = torch.device("cpu") mace = mace_mp(model="small", return_raw_model=True) mace_model = MaceModel( @@ -44,9 +48,9 @@ state.positions += torch.randn_like(state.positions) * 0.01 len(fire_states) -# %% +# %% def convergence_fn(state: BaseState) -> bool: batch_wise_max_force = torch.zeros(state.n_batches, device=state.device) max_forces = state.forces.norm(dim=1) @@ -80,19 +84,20 @@ def convergence_fn(state: BaseState) -> bool: break # run 10 steps, arbitrary number - for i in range(10): + for _step in range(10): state = fire_update(state) convergence_tensor = convergence_fn(state) - # %% batcher.restore_original_order(all_completed_states) + + # %% sorted(batcher.completed_idx_og_order) -# %% +# %% nvt_init, nvt_update = nvt_langevin( model=mace_model, dt=0.001, kT=300 * MetalUnits.temperature ) @@ -121,11 +126,11 @@ def convergence_fn(state: BaseState) -> bool: for batch in batcher: full_state = concatenate_states(batch) for _ in range(100): - full_state = nvt_update(full_state) finished_states.extend(split_state(full_state)) + # %% len(finished_states) @@ -134,4 +139,5 @@ def convergence_fn(state: BaseState) -> bool: t = torch.tensor([1, 1, 3, 3, 3, 3]) torch.bincount(t) _, counts = torch.unique_consecutive(t, return_counts=True) -counts + +print(f"{counts=}") diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 69199f07d..71affdc59 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -296,7 +296,6 @@ def test_hotswapping_auto_batcher_restore_order( def test_hotswapping_with_fire( si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel ) -> None: - fire_init, fire_update = unit_cell_fire(lj_calculator) si_fire_state = fire_init(si_base_state) @@ -334,7 +333,7 @@ def convergence_fn(state: BaseState) -> bool: while True: print("Starting new batch.") # run 10 steps, arbitrary number - for i in range(10): + for _step in range(10): state = fire_update(state) convergence_tensor = convergence_fn(state) diff --git a/tests/test_state.py b/tests/test_state.py index 66710989c..e77fa30fe 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -6,9 +6,9 @@ BaseState, concatenate_states, infer_property_scope, + pop_states, slice_substate, split_state, - pop_states, ) from torchsim.unbatched_integrators import MDState @@ -89,9 +89,7 @@ def 
test_concatenate_two_si_states( assert concatenated.positions.shape == si_double_base_state.positions.shape assert concatenated.masses.shape == si_double_base_state.masses.shape assert concatenated.cell.shape == si_double_base_state.cell.shape - assert ( - concatenated.atomic_numbers.shape == si_double_base_state.atomic_numbers.shape - ) + assert concatenated.atomic_numbers.shape == si_double_base_state.atomic_numbers.shape assert concatenated.batch.shape == si_double_base_state.batch.shape # Check batch indices @@ -221,7 +219,7 @@ def test_split_many_states( states = [si_base_state, ar_base_state, fe_fcc_state] concatenated = concatenate_states(states) split_states = split_state(concatenated) - for state, sub_state in zip(states, split_states): + for state, sub_state in zip(states, split_states, strict=True): assert isinstance(sub_state, BaseState) assert torch.allclose(sub_state.positions, state.positions) assert torch.allclose(sub_state.masses, state.masses) diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index c0288bf0f..7dfbea4e8 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -184,7 +184,7 @@ def __init__( states: States to batch. metric: Metric to use for batching. max_metric: Maximum metric value per batch. - max_atoms_to_try: Maximum number of atoms to try when estimating max_metric. + max_atoms_to_try: Max number of atoms to try when estimating max_metric. """ self.state_slices = ( split_state(states) if isinstance(states, BaseState) else states @@ -255,9 +255,7 @@ def __next__(self): raise StopIteration return next_batch - def restore_original_order( - self, batched_states: list[BaseState] - ) -> list[BaseState]: + def restore_original_order(self, batched_states: list[BaseState]) -> list[BaseState]: """Take the state bins and reorder them into a list. Args: @@ -390,12 +388,12 @@ def _first_batch(self) -> BaseState: if not has_max_metric: self.max_metric = estimate_max_metric( self.model, - [first_state] + states, + [first_state, *states], self.current_metrics, max_atoms=self.max_atoms_to_try, ) print(f"Max metric calculated: {self.max_metric}") - return concatenate_states([first_state] + states), [] + return concatenate_states([first_state, *states]), [] def next_batch( self, @@ -403,9 +401,7 @@ def next_batch( convergence_tensor: torch.Tensor | None = None, *, return_indices: bool = False, - ) -> ( - tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]] - ): + ) -> tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]]: """Get the next batch of states based on convergence. 
Args: @@ -439,9 +435,7 @@ def next_batch( completed_idx = torch.where(convergence_tensor)[0].tolist() completed_idx.sort(reverse=True) - remaining_state, completed_states = pop_states( - updated_state, completed_idx - ) + remaining_state, completed_states = pop_states(updated_state, completed_idx) self._delete_old_states(completed_idx) next_states = self._get_next_states() @@ -456,7 +450,7 @@ def next_batch( # concatenate remaining state with next states if remaining_state.n_batches > 0: - next_states = [remaining_state] + next_states + next_states = [remaining_state, *next_states] next_batch = concatenate_states(next_states) if return_indices: diff --git a/torchsim/runners.py b/torchsim/runners.py index 6af25f26a..eafe864c5 100644 --- a/torchsim/runners.py +++ b/torchsim/runners.py @@ -125,6 +125,7 @@ def optimize( """ # TODO: document this behavior if convergence_fn is None: + def convergence_fn(state: BaseState, last_energy: torch.Tensor) -> bool: return last_energy - state.energy < 1e-4 * unit_system.energy diff --git a/torchsim/state.py b/torchsim/state.py index 2c904ce4e..7f6604b2d 100644 --- a/torchsim/state.py +++ b/torchsim/state.py @@ -23,7 +23,7 @@ class ParticleState: import warnings from dataclasses import dataclass, field from typing import TYPE_CHECKING, Literal, Self -from collections import defaultdict + import torch @@ -96,9 +96,7 @@ def __post_init__(self) -> None: ) if self.batch is None: - self.batch = torch.zeros( - self.n_atoms, device=self.device, dtype=torch.int64 - ) + self.batch = torch.zeros(self.n_atoms, device=self.device, dtype=torch.int64) else: # assert that batch indices are unique consecutive integers _, counts = torch.unique_consecutive(self.batch, return_counts=True) @@ -270,7 +268,8 @@ def split_state( ambiguous_handling: Literal["error", "globalize"] = "error", ) -> list[BaseState]: """Split a state into a list of states, each containing a single batch element. - This also needs to be optimized.""" + This also needs to be optimized. 
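A quick round-trip sketch of the split_state/concatenate_states pair documented above, assuming `si_state` and `fe_state` are single-batch BaseState objects as in the tests:

import torch

from torchsim.state import concatenate_states, split_state

# `si_state` and `fe_state` are assumed single-batch BaseState objects.
joined = concatenate_states([si_state, fe_state])
parts = split_state(joined)

# Splitting should recover the original single-batch states in order.
assert joined.n_batches == 2
assert len(parts) == 2
assert torch.allclose(parts[0].positions, si_state.positions)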
+ """ # TODO: make this more efficient scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling) From fcc6da1b7078faea0e96a3e9da0cbd76c3e5302a Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 11 Mar 2025 19:07:52 -0400 Subject: [PATCH 22/36] .gitignore logging and model checkpoint --- .gitignore | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.gitignore b/.gitignore index f9d99f49f..d75c18368 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,10 @@ __pycache__ build/ dist/ *.egg-info + +# logging +*.log +*log.txt + +# model checkpoints +*.model From ebae9fb2fa53f8c4de40f71c06dc2664fb8428e5 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 12 Mar 2025 17:45:43 +0000 Subject: [PATCH 23/36] fix testing and add return indices to chunking auto batcher --- tests/test_autobatching.py | 46 ++++++++++++++------------------------ torchsim/autobatching.py | 4 +++- 2 files changed, 20 insertions(+), 30 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 71affdc59..f0c054f69 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -71,14 +71,11 @@ def test_chunking_auto_batcher( # Get batches until None is returned batches = [] - while True: - batch = batcher.next_batch() - if batch is None: - break + for batch in batcher: batches.append(batch) # Check we got the expected number of batches - assert len(batches) == len(batcher.state_bins) + assert len(batches) == len(batcher.batched_states) # Test restore_original_order restored_states = batcher.restore_original_order(batches) @@ -100,20 +97,16 @@ def test_chunking_auto_batcher_with_indices( states = [si_base_state, fe_fcc_state] batcher = ChunkingAutoBatcher( - model=lj_calculator, states=states, metric="n_atoms", max_metric=260.0 + model=lj_calculator, states=states, metric="n_atoms", max_metric=260.0, return_indices=True, ) # Get batches with indices batches_with_indices = [] - while True: - result = batcher.next_batch(return_indices=True) - if result is None: - break - batch, indices = result + for batch, indices in batcher: batches_with_indices.append((batch, indices)) # Check we got the expected number of batches - assert len(batches_with_indices) == len(batcher.state_bins) + assert len(batches_with_indices) == len(batcher.batched_states) # Check that the indices match the expected bin indices for i, (_, indices) in enumerate(batches_with_indices): @@ -197,7 +190,7 @@ def test_hotswapping_auto_batcher( ) # Get the first batch - first_batch = batcher._first_batch() + first_batch, [] = batcher.next_batch(states, None) assert isinstance(first_batch, BaseState) # Create a convergence tensor where the first state has converged @@ -261,7 +254,7 @@ def test_hotswapping_auto_batcher_restore_order( ) # Get the first batch - first_batch = batcher._first_batch() + first_batch, [] = batcher.next_batch(states, None) # Simulate convergence of all states completed_states_list = [] @@ -325,27 +318,22 @@ def convergence_fn(state: BaseState) -> bool: src=max_forces, reduce="amax", ) - return batch_wise_max_force < 1e-1 + return batch_wise_max_force < 5e-1 - state = batcher._first_batch() - - all_completed_states = [] + all_completed_states, convergence_tensor = [], None while True: - print("Starting new batch.") - # run 10 steps, arbitrary number - for _step in range(10): - state = fire_update(state) - - convergence_tensor = convergence_fn(state) + print(f"Starting new batch of {state.n_batches} states.") state, completed_states = batcher.next_batch(state, 
convergence_tensor) - - print("number of completed states", len(completed_states)) + print("Number of completed states", len(completed_states)) all_completed_states.extend(completed_states) - - if not state: - print("No more batches to run.") + if state is None: break + # run 10 steps, arbitrary number + for i in range(10): + state = fire_update(state) + convergence_tensor = convergence_fn(state) + assert len(all_completed_states) == len(fire_states) diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index 7dfbea4e8..b2708268a 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -176,6 +176,7 @@ def __init__( metric: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", max_metric: float | None = None, max_atoms_to_try: int = 1_000_000, + return_indices: bool = False, ) -> None: """Initialize the batcher. @@ -201,6 +202,7 @@ def __init__( else: self.max_metric = max_metric + self.return_indices = return_indices # verify that no systems are too large max_metric_value = max(self.metrics) max_metric_idx = self.metrics.index(max_metric_value) @@ -250,7 +252,7 @@ def __iter__(self): return self def __next__(self): - next_batch = self.next_batch() + next_batch = self.next_batch(return_indices=self.return_indices) if next_batch is None: raise StopIteration return next_batch From e5a8f2dca8115a9ff83f1483f04e5526d30d69a7 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 12 Mar 2025 21:56:26 +0000 Subject: [PATCH 24/36] tighten convergence on runners --- torchsim/runners.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchsim/runners.py b/torchsim/runners.py index eafe864c5..9e0005ac0 100644 --- a/torchsim/runners.py +++ b/torchsim/runners.py @@ -127,7 +127,7 @@ def optimize( if convergence_fn is None: def convergence_fn(state: BaseState, last_energy: torch.Tensor) -> bool: - return last_energy - state.energy < 1e-4 * unit_system.energy + return last_energy - state.energy < 1e-6 * unit_system.energy # we partially evaluate the function to create a new function with # an optional second argument, this can be set to state later on From cc3f1914ed4dd0acbe734fc1a2d7c19f7b851b0a Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 12 Mar 2025 22:15:08 +0000 Subject: [PATCH 25/36] finish chunking tests --- .../4_High_level_api/4.2_auto_batching_api.py | 14 +++++----- tests/test_autobatching.py | 27 +++++++++++++++++++ 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/examples/4_High_level_api/4.2_auto_batching_api.py b/examples/4_High_level_api/4.2_auto_batching_api.py index 069322b05..dac5c709a 100644 --- a/examples/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/4_High_level_api/4.2_auto_batching_api.py @@ -11,7 +11,11 @@ from ase.build import bulk from mace.calculators.foundations_models import mace_mp -from torchsim.autobatching import ChunkingAutoBatcher, HotswappingAutoBatcher, split_state +from torchsim.autobatching import ( + ChunkingAutoBatcher, + HotswappingAutoBatcher, + split_state, +) from torchsim.integrators import nvt_langevin from torchsim.models.mace import MaceModel from torchsim.optimizers import unit_cell_fire @@ -63,7 +67,6 @@ def convergence_fn(state: BaseState) -> bool: return batch_wise_max_force < 1e-1 -# %% batcher = HotswappingAutoBatcher( model=mace_model, states=fire_states, @@ -109,7 +112,7 @@ def convergence_fn(state: BaseState) -> bool: si_nvt_state = nvt_init(si_state) fe_nvt_state = nvt_init(fe_state) -nvt_states = [si_nvt_state, fe_nvt_state] * 100 +nvt_states = [si_nvt_state, 
fe_nvt_state] * 5 nvt_states = [state.clone() for state in nvt_states] for state in nvt_states: state.positions += torch.randn_like(state.positions) * 0.01 @@ -124,11 +127,10 @@ def convergence_fn(state: BaseState) -> bool: finished_states = [] for batch in batcher: - full_state = concatenate_states(batch) for _ in range(100): - full_state = nvt_update(full_state) + batch = nvt_update(batch) - finished_states.extend(split_state(full_state)) + finished_states.extend(split_state(batch)) # %% diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index f0c054f69..aa1eca4ba 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -337,3 +337,30 @@ def convergence_fn(state: BaseState) -> bool: convergence_tensor = convergence_fn(state) assert len(all_completed_states) == len(fire_states) + +def test_chunking_auto_batcher_with_fire( + si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel +) -> None: + fire_init, fire_update = unit_cell_fire(lj_calculator) + + si_fire_state = fire_init(si_base_state) + fe_fire_state = fire_init(fe_fcc_state) + + fire_states = [si_fire_state, fe_fire_state] * 5 + fire_states = [state.clone() for state in fire_states] + for state in fire_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + batcher = ChunkingAutoBatcher( + model=lj_calculator, + states=fire_states, + metric="n_atoms", + max_metric=400, + ) + + finished_states = [] + for batch in batcher: + for _ in range(100): + batch = fire_update(batch) + + finished_states.extend(split_state(batch)) From 3e54f4db9fb2223f1054b8478e0ecb5650b7ac37 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 13 Mar 2025 14:15:37 +0000 Subject: [PATCH 26/36] rename metric and associated vars --- .../4_High_level_api/4.2_auto_batching_api.py | 8 +- tests/test_autobatching.py | 63 ++++++++----- torchsim/autobatching.py | 93 ++++++++++--------- 3 files changed, 92 insertions(+), 72 deletions(-) diff --git a/examples/4_High_level_api/4.2_auto_batching_api.py b/examples/4_High_level_api/4.2_auto_batching_api.py index dac5c709a..4f945fe5a 100644 --- a/examples/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/4_High_level_api/4.2_auto_batching_api.py @@ -70,8 +70,8 @@ def convergence_fn(state: BaseState) -> bool: batcher = HotswappingAutoBatcher( model=mace_model, states=fire_states, - metric="n_atoms_x_density", - max_metric=400_000, + memory_scales_with="n_atoms_x_density", + max_memory_scaler=400_000, # max_metric=400_000, ) @@ -121,8 +121,8 @@ def convergence_fn(state: BaseState) -> bool: batcher = ChunkingAutoBatcher( model=mace_model, states=nvt_states, - metric="n_atoms_x_density", - max_metric=100_000, + memory_scales_with="n_atoms_x_density", + max_memory_scaler=100_000, ) finished_states = [] diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index aa1eca4ba..1bae2038b 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -58,16 +58,16 @@ def test_chunking_auto_batcher( # Initialize the batcher with a fixed max_metric to avoid GPU memory testing batcher = ChunkingAutoBatcher( - model=lj_calculator, states=states, - metric="n_atoms", - max_metric=260.0, # Set a small value to force multiple batches + model=lj_calculator, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, # Set a small value to force multiple batches ) # Check that the batcher correctly identified the metrics - assert len(batcher.metrics) == 2 - assert batcher.metrics[0] == si_base_state.n_atoms - assert 
batcher.metrics[1] == fe_fcc_state.n_atoms + assert len(batcher.memory_scalers) == 2 + assert batcher.memory_scalers[0] == si_base_state.n_atoms + assert batcher.memory_scalers[1] == fe_fcc_state.n_atoms # Get batches until None is returned batches = [] @@ -97,7 +97,11 @@ def test_chunking_auto_batcher_with_indices( states = [si_base_state, fe_fcc_state] batcher = ChunkingAutoBatcher( - model=lj_calculator, states=states, metric="n_atoms", max_metric=260.0, return_indices=True, + states=states, + model=lj_calculator, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, + return_indices=True, ) # Get batches with indices @@ -122,10 +126,10 @@ def test_chunking_auto_batcher_restore_order_with_split_states( # Initialize the batcher with a fixed max_metric to avoid GPU memory testing batcher = ChunkingAutoBatcher( - model=lj_calculator, states=states, - metric="n_atoms", - max_metric=260.0, # Set a small value to force multiple batches + model=lj_calculator, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, # Set a small value to force multiple batches ) # Get batches until None is returned @@ -163,10 +167,10 @@ def test_hotswapping_max_metric_too_small( # Initialize the batcher with a fixed max_metric batcher = HotswappingAutoBatcher( - model=lj_calculator, states=states, - metric="n_atoms", - max_metric=1.0, # Set a small value to force multiple batches + model=lj_calculator, + memory_scales_with="n_atoms", + max_memory_scaler=1.0, # Set a small value to force multiple batches ) # Get the first batch @@ -183,10 +187,10 @@ def test_hotswapping_auto_batcher( # Initialize the batcher with a fixed max_metric batcher = HotswappingAutoBatcher( - model=lj_calculator, states=states, - metric="n_atoms", - max_metric=260, # Set a small value to force multiple batches + model=lj_calculator, + memory_scales_with="n_atoms", + max_memory_scaler=260, # Set a small value to force multiple batches ) # Get the first batch @@ -206,7 +210,7 @@ def test_hotswapping_auto_batcher( assert idx == [1] # Check that the converged state was removed - assert len(batcher.current_metrics) == 1 + assert len(batcher.current_scalers) == 1 assert len(batcher.current_idx) == 1 assert len(batcher.completed_idx_og_order) == 1 @@ -250,7 +254,10 @@ def test_hotswapping_auto_batcher_restore_order( states = [si_base_state, fe_fcc_state] batcher = HotswappingAutoBatcher( - model=lj_calculator, states=states, metric="n_atoms", max_metric=260.0 + states=states, + model=lj_calculator, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, ) # Get the first batch @@ -300,11 +307,11 @@ def test_hotswapping_with_fire( state.positions += torch.randn_like(state.positions) * 0.01 batcher = HotswappingAutoBatcher( - model=lj_calculator, states=fire_states, - metric="n_atoms", + model=lj_calculator, + memory_scales_with="n_atoms", # max_metric=400_000, - max_metric=600, + max_memory_scaler=600, ) def convergence_fn(state: BaseState) -> bool: @@ -338,6 +345,7 @@ def convergence_fn(state: BaseState) -> bool: assert len(all_completed_states) == len(fire_states) + def test_chunking_auto_batcher_with_fire( si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel ) -> None: @@ -352,15 +360,20 @@ def test_chunking_auto_batcher_with_fire( state.positions += torch.randn_like(state.positions) * 0.01 batcher = ChunkingAutoBatcher( - model=lj_calculator, states=fire_states, - metric="n_atoms", - max_metric=400, + model=lj_calculator, + memory_scales_with="n_atoms", + max_memory_scaler=400, ) finished_states = [] 
for batch in batcher: - for _ in range(100): + for _ in range(5): batch = fire_update(batch) finished_states.extend(split_state(batch)) + + restored_states = batcher.restore_original_order(finished_states) + assert len(restored_states) == len(fire_states) + for restored, original in zip(restored_states, fire_states): + assert torch.all(restored.atomic_numbers == original.atomic_numbers) diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index b2708268a..49ad9b293 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -171,10 +171,12 @@ class ChunkingAutoBatcher: def __init__( self, - model: ModelInterface, states: list[BaseState] | BaseState, - metric: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", - max_metric: float | None = None, + model: ModelInterface, + memory_scales_with: Literal[ + "n_atoms", "n_atoms_x_density" + ] = "n_atoms_x_density", + max_memory_scaler: float | None = None, max_atoms_to_try: int = 1_000_000, return_indices: bool = False, ) -> None: @@ -190,32 +192,32 @@ def __init__( self.state_slices = ( split_state(states) if isinstance(states, BaseState) else states ) - self.metrics = [ - calculate_scaling_metric(state_slice, metric) + self.memory_scalers = [ + calculate_scaling_metric(state_slice, memory_scales_with) for state_slice in self.state_slices ] - if not max_metric: - self.max_metric = estimate_max_metric( - model, self.state_slices, self.metrics, max_atoms_to_try + if not max_memory_scaler: + self.max_memory_scaler = estimate_max_metric( + model, self.state_slices, self.memory_scalers, max_atoms_to_try ) - print(f"Max metric calculated: {self.max_metric}") + print(f"Max metric calculated: {self.max_memory_scaler}") else: - self.max_metric = max_metric + self.max_memory_scaler = max_memory_scaler self.return_indices = return_indices # verify that no systems are too large - max_metric_value = max(self.metrics) - max_metric_idx = self.metrics.index(max_metric_value) - if max_metric_value > self.max_metric: + max_metric_value = max(self.memory_scalers) + max_metric_idx = self.memory_scalers.index(max_metric_value) + if max_metric_value > self.max_memory_scaler: raise ValueError( f"Max metric of system with index {max_metric_idx} in states: " - f"{max(self.metrics)} is greater than max_metric {self.max_metric}, " + f"{max(self.memory_scalers)} is greater than max_metric {self.max_memory_scaler}, " f"please set a larger max_metric or run smaller systems metric." ) - self.index_to_metric = dict(enumerate(self.metrics)) + self.index_to_scaler = dict(enumerate(self.memory_scalers)) self.index_bins = binpacking.to_constant_volume( - self.index_to_metric, V_max=self.max_metric + self.index_to_scaler, V_max=self.max_memory_scaler ) self.batched_states = [] for index_bin in self.index_bins: @@ -233,10 +235,6 @@ def next_batch( Returns: The next batch of states, optionally with indices, or None if no more batches. """ - # TODO: we need to refactor this to operate on the full states rather - # than the state slices, to be aligned with how the hotswapping batcher - # works. 
- # TODO: need to think about how this intersects with reporting too # TODO: definitely a clever treatment to be done with iterators here if self.current_state_bin < len(self.batched_states): @@ -257,7 +255,9 @@ def __next__(self): raise StopIteration return next_batch - def restore_original_order(self, batched_states: list[BaseState]) -> list[BaseState]: + def restore_original_order( + self, batched_states: list[BaseState] + ) -> list[BaseState]: """Take the state bins and reorder them into a list. Args: @@ -288,10 +288,12 @@ class HotswappingAutoBatcher: def __init__( self, - model: ModelInterface, states: list[BaseState] | Iterator[BaseState] | BaseState, - metric: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", - max_metric: float | None = None, + model: ModelInterface, + memory_scales_with: Literal[ + "n_atoms", "n_atoms_x_density" + ] = "n_atoms_x_density", + max_memory_scaler: float | None = None, max_atoms_to_try: int = 1_000_000, ) -> None: """Initialize the batcher. @@ -310,11 +312,11 @@ def __init__( self.model = model self.states_iterator = states - self.metric = metric - self.max_metric = max_metric or None + self.memory_scales_with = memory_scales_with + self.max_memory_scaler = max_memory_scaler or None self.max_atoms_to_try = max_atoms_to_try - self.current_metrics = [] + self.current_scalers = [] self.current_idx = [] self.iterator_idx = 0 @@ -326,15 +328,18 @@ def _get_next_states(self) -> None: new_idx = [] new_states = [] for state in self.states_iterator: - metric = calculate_scaling_metric(state, self.metric) - if metric > self.max_metric: + metric = calculate_scaling_metric(state, self.memory_scales_with) + if metric > self.max_memory_scaler: raise ValueError( f"State metric {metric} is greater than max_metric " - f"{self.max_metric}, please set a larger max_metric " + f"{self.max_memory_scaler}, please set a larger max_metric " f"or run smaller systems metric." 
) # new_metric += sum(new_metrics) - if sum(self.current_metrics) + sum(new_metrics) + metric > self.max_metric: + if ( + sum(self.current_scalers) + sum(new_metrics) + metric + > self.max_memory_scaler + ): # put the state back in the iterator self.states_iterator = chain([state], self.states_iterator) break @@ -344,7 +349,7 @@ def _get_next_states(self) -> None: new_states.append(state) self.iterator_idx += 1 - self.current_metrics.extend(new_metrics) + self.current_scalers.extend(new_metrics) self.current_idx.extend(new_idx) return new_states @@ -356,7 +361,7 @@ def _delete_old_states(self, completed_idx: list[int]) -> None: # update state tracking lists for idx in completed_idx: og_idx = self.current_idx.pop(idx) - self.current_metrics.pop(idx) + self.current_scalers.pop(idx) self.completed_idx_og_order.append(og_idx) def _first_batch(self) -> BaseState: @@ -368,33 +373,33 @@ def _first_batch(self) -> BaseState: # we need to sample a state and use it to estimate the max metric # for the first batch first_state = next(self.states_iterator) - first_metric = calculate_scaling_metric(first_state, self.metric) - self.current_metrics += [first_metric] + first_metric = calculate_scaling_metric(first_state, self.memory_scales_with) + self.current_scalers += [first_metric] self.current_idx += [0] self.iterator_idx += 1 # self.total_metric += first_metric # if max_metric is not set, estimate it - has_max_metric = bool(self.max_metric) + has_max_metric = bool(self.max_memory_scaler) if not has_max_metric: - self.max_metric = estimate_max_metric( + self.max_memory_scaler = estimate_max_metric( self.model, [first_state], [first_metric], max_atoms=self.max_atoms_to_try, ) - self.max_metric *= 0.8 + self.max_memory_scaler *= 0.8 states = self._get_next_states() if not has_max_metric: - self.max_metric = estimate_max_metric( + self.max_memory_scaler = estimate_max_metric( self.model, [first_state, *states], - self.current_metrics, + self.current_scalers, max_atoms=self.max_atoms_to_try, ) - print(f"Max metric calculated: {self.max_metric}") + print(f"Max metric calculated: {self.max_memory_scaler}") return concatenate_states([first_state, *states]), [] def next_batch( @@ -403,7 +408,9 @@ def next_batch( convergence_tensor: torch.Tensor | None = None, *, return_indices: bool = False, - ) -> tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]]: + ) -> ( + tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]] + ): """Get the next batch of states based on convergence. 
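Taken together, the next_batch contract above supports a compact driver loop. A sketch, assuming `batcher`, `fire_update`, and `convergence_fn` as defined in the example script earlier in this series; note the first call passes convergence_tensor=None, which returns the first batch without consulting the state argument:

# `batcher`, `fire_update`, and `convergence_fn` are assumed from the example.
state, convergence_tensor = None, None
all_completed_states = []
while True:
    state, completed_states = batcher.next_batch(state, convergence_tensor)
    all_completed_states.extend(completed_states)
    if state is None:
        break  # everything has converged and been popped off
    for _ in range(10):  # arbitrary number of steps between convergence checks
        state = fire_update(state)
    convergence_tensor = convergence_fn(state)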
Args: @@ -430,7 +437,7 @@ def next_batch( # assert statements helpful for debugging, should be moved to validate fn # the first two are most important assert len(convergence_tensor) == updated_state.n_batches - assert len(self.current_idx) == len(self.current_metrics) + assert len(self.current_idx) == len(self.current_scalers) assert len(convergence_tensor.shape) == 1 assert updated_state.n_batches > 0 From ca84c0b2fc050316a4ab5bf10a586e90cc2e288b Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 13 Mar 2025 14:18:51 +0000 Subject: [PATCH 27/36] change names of utility functions --- tests/test_autobatching.py | 8 ++-- torchsim/autobatching.py | 81 ++++++++++---------------------------- 2 files changed, 25 insertions(+), 64 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 1bae2038b..bdbc1a2f1 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -6,7 +6,7 @@ from torchsim.autobatching import ( ChunkingAutoBatcher, HotswappingAutoBatcher, - calculate_scaling_metric, + calculate_memory_scaler, determine_max_batch_size, ) from torchsim.models.lennard_jones import LennardJonesModel @@ -17,18 +17,18 @@ def test_calculate_scaling_metric(si_base_state: BaseState) -> None: """Test calculation of scaling metrics for a state.""" # Test n_atoms metric - n_atoms_metric = calculate_scaling_metric(si_base_state, "n_atoms") + n_atoms_metric = calculate_memory_scaler(si_base_state, "n_atoms") assert n_atoms_metric == si_base_state.n_atoms # Test n_atoms_x_density metric - density_metric = calculate_scaling_metric(si_base_state, "n_atoms_x_density") + density_metric = calculate_memory_scaler(si_base_state, "n_atoms_x_density") volume = torch.abs(torch.linalg.det(si_base_state.cell[0])) / 1000 expected = si_base_state.n_atoms * (si_base_state.n_atoms / volume.item()) assert pytest.approx(density_metric, rel=1e-5) == expected # Test invalid metric with pytest.raises(ValueError, match="Invalid metric"): - calculate_scaling_metric(si_base_state, "invalid_metric") + calculate_memory_scaler(si_base_state, "invalid_metric") def test_split_state(si_double_base_state: BaseState) -> None: diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index 49ad9b293..057f46437 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -13,12 +13,12 @@ from torchsim.state import BaseState, concatenate_states, pop_states, split_state -def measure_model_memory_forward(model: ModelInterface, state: BaseState) -> float: +def measure_model_memory_forward(state: BaseState, model: ModelInterface) -> float: """Measure peak GPU memory usage during model forward pass. Args: - model: The model to measure memory usage for. state: The input state to pass to the model. + model: The model to measure memory usage for. Returns: Peak memory usage in GB. @@ -42,7 +42,7 @@ def measure_model_memory_forward(model: ModelInterface, state: BaseState) -> flo def determine_max_batch_size( - model: ModelInterface, state: BaseState, max_atoms: int = 20000 + state: BaseState, model: ModelInterface, max_atoms: int = 500_000 ) -> int: """Determine maximum batch size that fits in GPU memory. 
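The search above probes Fibonacci-spaced batch sizes and, on the first out-of-memory failure, falls back two entries in the sequence. A self-contained sketch of the size schedule (the exact starting values in the implementation may differ):

def fib_sizes(max_size: int) -> list[int]:
    # 1, 2, 3, 5, 8, ... grows quickly enough to find the OOM point in
    # a small number of forward-pass probes.
    sizes = [1, 2]
    while sizes[-1] < max_size:
        sizes.append(sizes[-1] + sizes[-2])
    return sizes

print(fib_sizes(100))  # [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144]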
@@ -64,7 +64,7 @@ def determine_max_batch_size( concat_state = concatenate_states([state] * n_batches) try: - measure_model_memory_forward(model, concat_state) + measure_model_memory_forward(concat_state, model) except RuntimeError as e: if "CUDA out of memory" in str(e): return fib[i - 2] @@ -73,48 +73,9 @@ def determine_max_batch_size( return fib[-2] -def calculate_baseline_memory(model: ModelInterface) -> float: - """Calculate baseline memory usage of the model. - - Args: - model: The model to measure baseline memory for. - - Returns: - Baseline memory usage in GB. - """ - # Create baseline atoms with different sizes - baseline_atoms = [bulk("Al", "fcc").repeat((i, 1, 1)) for i in range(1, 9, 2)] - baseline_states = [ - atoms_to_state(atoms, model.device, model.dtype) for atoms in baseline_atoms - ] - - # Measure memory usage for each state - memory_list = [ - measure_model_memory_forward(model, state) for state in baseline_states - ] - - # Calculate number of atoms in each baseline state - n_atoms_list = [state.n_atoms for state in baseline_states] - - # Convert to tensors - n_atoms_tensor = torch.tensor(n_atoms_list, dtype=torch.float) - memory_tensor = torch.tensor(memory_list, dtype=torch.float) - - # Prepare design matrix (with column of ones for intercept) - X = torch.stack([torch.ones_like(n_atoms_tensor), n_atoms_tensor], dim=1) - - # Solve normal equations - beta = torch.linalg.lstsq(X, memory_tensor.unsqueeze(1)).solution.squeeze() - - # Extract intercept (b) and slope (m) - intercept, _ = beta[0].item(), beta[1].item() - - return intercept - - -def calculate_scaling_metric( +def calculate_memory_scaler( state_slice: BaseState, - metric: Literal["n_atoms_x_density", "n_atoms"] = "n_atoms_x_density", + memory_scales_with: Literal["n_atoms_x_density", "n_atoms"] = "n_atoms_x_density", ) -> float: """Calculate scaling metric for a state. @@ -125,20 +86,20 @@ def calculate_scaling_metric( Returns: The calculated metric value. """ - if metric == "n_atoms": + if memory_scales_with == "n_atoms": return state_slice.n_atoms - if metric == "n_atoms_x_density": + if memory_scales_with == "n_atoms_x_density": volume = torch.abs(torch.linalg.det(state_slice.cell[0])) / 1000 number_density = state_slice.n_atoms / volume.item() return state_slice.n_atoms * number_density - raise ValueError(f"Invalid metric: {metric}") + raise ValueError(f"Invalid metric: {memory_scales_with}") -def estimate_max_metric( +def estimate_max_memory_scaler( model: ModelInterface, state_list: list[BaseState], metric_values: list[float], - max_atoms: int = 20000, + max_atoms: int = 500_000, ) -> float: """Estimate maximum metric value that fits in GPU memory. @@ -160,8 +121,8 @@ def estimate_max_metric( min_state = state_list[metric_values.argmin()] max_state = state_list[metric_values.argmax()] - min_state_max_batches = determine_max_batch_size(model, min_state, max_atoms) - max_state_max_batches = determine_max_batch_size(model, max_state, max_atoms) + min_state_max_batches = determine_max_batch_size(min_state, model, max_atoms) + max_state_max_batches = determine_max_batch_size(max_state, model, max_atoms) return min(min_state_max_batches * min_metric, max_state_max_batches * max_metric) @@ -177,7 +138,7 @@ def __init__( "n_atoms", "n_atoms_x_density" ] = "n_atoms_x_density", max_memory_scaler: float | None = None, - max_atoms_to_try: int = 1_000_000, + max_atoms_to_try: int = 500_000, return_indices: bool = False, ) -> None: """Initialize the batcher. 
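A worked example of the estimate above, which probes the smallest-metric and largest-metric systems separately and keeps the more conservative budget (all numbers are made up):

min_metric, max_metric = 250.0, 900.0  # smallest and largest system metrics
min_state_max_batches = 40  # copies of the smallest system that fit in memory
max_state_max_batches = 12  # copies of the largest system that fit in memory

budget = min(
    min_state_max_batches * min_metric,  # 10_000.0
    max_state_max_batches * max_metric,  # 10_800.0
)
print(budget)  # 10000.0, the safer of the two extrapolations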
@@ -193,11 +154,11 @@ def __init__( split_state(states) if isinstance(states, BaseState) else states ) self.memory_scalers = [ - calculate_scaling_metric(state_slice, memory_scales_with) + calculate_memory_scaler(state_slice, memory_scales_with) for state_slice in self.state_slices ] if not max_memory_scaler: - self.max_memory_scaler = estimate_max_metric( + self.max_memory_scaler = estimate_max_memory_scaler( model, self.state_slices, self.memory_scalers, max_atoms_to_try ) print(f"Max metric calculated: {self.max_memory_scaler}") @@ -294,7 +255,7 @@ def __init__( "n_atoms", "n_atoms_x_density" ] = "n_atoms_x_density", max_memory_scaler: float | None = None, - max_atoms_to_try: int = 1_000_000, + max_atoms_to_try: int = 500_000, ) -> None: """Initialize the batcher. @@ -328,7 +289,7 @@ def _get_next_states(self) -> None: new_idx = [] new_states = [] for state in self.states_iterator: - metric = calculate_scaling_metric(state, self.memory_scales_with) + metric = calculate_memory_scaler(state, self.memory_scales_with) if metric > self.max_memory_scaler: raise ValueError( f"State metric {metric} is greater than max_metric " @@ -373,7 +334,7 @@ def _first_batch(self) -> BaseState: # we need to sample a state and use it to estimate the max metric # for the first batch first_state = next(self.states_iterator) - first_metric = calculate_scaling_metric(first_state, self.memory_scales_with) + first_metric = calculate_memory_scaler(first_state, self.memory_scales_with) self.current_scalers += [first_metric] self.current_idx += [0] self.iterator_idx += 1 @@ -382,7 +343,7 @@ def _first_batch(self) -> BaseState: # if max_metric is not set, estimate it has_max_metric = bool(self.max_memory_scaler) if not has_max_metric: - self.max_memory_scaler = estimate_max_metric( + self.max_memory_scaler = estimate_max_memory_scaler( self.model, [first_state], [first_metric], @@ -393,7 +354,7 @@ def _first_batch(self) -> BaseState: states = self._get_next_states() if not has_max_metric: - self.max_memory_scaler = estimate_max_metric( + self.max_memory_scaler = estimate_max_memory_scaler( self.model, [first_state, *states], self.current_scalers, From 012518de5d159c0dc1be8895f42d4341f71e8ad8 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 13 Mar 2025 14:31:30 +0000 Subject: [PATCH 28/36] lint --- .../4_High_level_api/4.2_auto_batching_api.py | 14 +++--- tests/test_autobatching.py | 10 ++--- torchsim/autobatching.py | 43 ++++++++----------- torchsim/state.py | 3 +- 4 files changed, 33 insertions(+), 37 deletions(-) diff --git a/examples/4_High_level_api/4.2_auto_batching_api.py b/examples/4_High_level_api/4.2_auto_batching_api.py index 4f945fe5a..b67b47bca 100644 --- a/examples/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/4_High_level_api/4.2_auto_batching_api.py @@ -1,3 +1,5 @@ +"""Examples of using the auto-batching API.""" + # /// script # dependencies = [ # "mace-torch>=0.3.10", @@ -5,22 +7,21 @@ # ] # /// +"""Run as a interactive script.""" +# ruff: noqa: E402 + # %% import torch from ase.build import bulk from mace.calculators.foundations_models import mace_mp -from torchsim.autobatching import ( - ChunkingAutoBatcher, - HotswappingAutoBatcher, - split_state, -) +from torchsim.autobatching import ChunkingAutoBatcher, HotswappingAutoBatcher, split_state from torchsim.integrators import nvt_langevin from torchsim.models.mace import MaceModel from torchsim.optimizers import unit_cell_fire from torchsim.runners import atoms_to_state -from torchsim.state import BaseState, concatenate_states +from 
torchsim.state import BaseState from torchsim.units import MetalUnits @@ -56,6 +57,7 @@ # %% def convergence_fn(state: BaseState) -> bool: + """Check if the system has converged.""" batch_wise_max_force = torch.zeros(state.n_batches, device=state.device) max_forces = state.forces.norm(dim=1) batch_wise_max_force = batch_wise_max_force.scatter_reduce( diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index bdbc1a2f1..d32594a04 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -70,9 +70,7 @@ def test_chunking_auto_batcher( assert batcher.memory_scalers[1] == fe_fcc_state.n_atoms # Get batches until None is returned - batches = [] - for batch in batcher: - batches.append(batch) + batches = list(batcher) # Check we got the expected number of batches assert len(batches) == len(batcher.batched_states) @@ -175,7 +173,7 @@ def test_hotswapping_max_metric_too_small( # Get the first batch with pytest.raises(ValueError, match="is greater than max_metric"): - batcher._first_batch() + batcher.next_batch() def test_hotswapping_auto_batcher( @@ -339,7 +337,7 @@ def convergence_fn(state: BaseState) -> bool: break # run 10 steps, arbitrary number - for i in range(10): + for _ in range(10): state = fire_update(state) convergence_tensor = convergence_fn(state) @@ -375,5 +373,5 @@ def test_chunking_auto_batcher_with_fire( restored_states = batcher.restore_original_order(finished_states) assert len(restored_states) == len(fire_states) - for restored, original in zip(restored_states, fire_states): + for restored, original in zip(restored_states, fire_states, strict=False): assert torch.all(restored.atomic_numbers == original.atomic_numbers) diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index 057f46437..a6ae2beb5 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -6,10 +6,8 @@ import binpacking import torch -from ase.build import bulk from torchsim.models.interface import ModelInterface -from torchsim.runners import atoms_to_state from torchsim.state import BaseState, concatenate_states, pop_states, split_state @@ -81,7 +79,7 @@ def calculate_memory_scaler( Args: state_slice: The state to calculate metric for. - metric: The type of metric to calculate. + memory_scales_with: The type of metric to calculate. Returns: The calculated metric value. @@ -134,9 +132,8 @@ def __init__( self, states: list[BaseState] | BaseState, model: ModelInterface, - memory_scales_with: Literal[ - "n_atoms", "n_atoms_x_density" - ] = "n_atoms_x_density", + *, + memory_scales_with: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, return_indices: bool = False, @@ -146,9 +143,10 @@ def __init__( Args: model: The model to batch for. states: States to batch. - metric: Metric to use for batching. - max_metric: Maximum metric value per batch. - max_atoms_to_try: Max number of atoms to try when estimating max_metric. + memory_scales_with: Metric to use for batching. + max_memory_scaler: Maximum metric value per batch. + max_atoms_to_try: Max number of atoms to try when estimating max_metric. + return_indices: Whether to return indices along with the batch. 
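The convergence_fn touched in the example diff above collapses per-atom force norms into one maximum per batch with scatter_reduce. A standalone illustration of that reduction, with toy force values:

import torch

max_forces = torch.tensor([0.3, 0.1, 0.7, 0.2, 0.05])  # per-atom force norms
batch = torch.tensor([0, 0, 1, 1, 1])  # which batch each atom belongs to

batch_wise_max = torch.zeros(2).scatter_reduce(
    dim=0, index=batch, src=max_forces, reduce="amax"
)
# batch_wise_max == tensor([0.3000, 0.7000]), one maximum per batch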
""" self.state_slices = ( split_state(states) if isinstance(states, BaseState) else states @@ -172,8 +170,9 @@ def __init__( if max_metric_value > self.max_memory_scaler: raise ValueError( f"Max metric of system with index {max_metric_idx} in states: " - f"{max(self.memory_scalers)} is greater than max_metric {self.max_memory_scaler}, " - f"please set a larger max_metric or run smaller systems metric." + f"{max(self.memory_scalers)} is greater than max_metric " + f"{self.max_memory_scaler}, please set a larger max_metric " + f"or run smaller systems metric." ) self.index_to_scaler = dict(enumerate(self.memory_scalers)) @@ -207,18 +206,18 @@ def next_batch( return state return None - def __iter__(self): + def __iter__(self) -> Iterator[BaseState]: + """Iterate over the batches.""" return self - def __next__(self): + def __next__(self) -> BaseState: + """Get the next batch.""" next_batch = self.next_batch(return_indices=self.return_indices) if next_batch is None: raise StopIteration return next_batch - def restore_original_order( - self, batched_states: list[BaseState] - ) -> list[BaseState]: + def restore_original_order(self, batched_states: list[BaseState]) -> list[BaseState]: """Take the state bins and reorder them into a list. Args: @@ -251,9 +250,7 @@ def __init__( self, states: list[BaseState] | Iterator[BaseState] | BaseState, model: ModelInterface, - memory_scales_with: Literal[ - "n_atoms", "n_atoms_x_density" - ] = "n_atoms_x_density", + memory_scales_with: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, ) -> None: @@ -262,8 +259,8 @@ def __init__( Args: model: The model to batch for. states: States to batch. - metric: Metric to use for batching. - max_metric: Maximum metric value per batch. + memory_scales_with: Metric to use for batching. + max_memory_scaler: Maximum metric value per batch. max_atoms_to_try: Maximum number of atoms to try when estimating max_metric. """ if isinstance(states, BaseState): @@ -369,9 +366,7 @@ def next_batch( convergence_tensor: torch.Tensor | None = None, *, return_indices: bool = False, - ) -> ( - tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]] - ): + ) -> tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]]: """Get the next batch of states based on convergence. 
Args: diff --git a/torchsim/state.py b/torchsim/state.py index 7f6604b2d..f5bdec42c 100644 --- a/torchsim/state.py +++ b/torchsim/state.py @@ -452,7 +452,8 @@ def concatenate_states( # Use the target device or default to the first state's device target_device = device or first_state.device - # Get property scopes from the first state to identify global/per-atom/per-batch properties + # Get property scopes from the first state to identify + # global/per-atom/per-batch properties first_scope = infer_property_scope(first_state) global_props = set(first_scope["global"]) per_atom_props = set(first_scope["per_atom"]) From 8202ab523975033700df8e40c2fa7dd80a8e60bf Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 13 Mar 2025 14:41:13 +0000 Subject: [PATCH 29/36] fix testing --- tests/test_autobatching.py | 4 +- torchsim/autobatching.py | 198 +++++++++++++++++++++++++++---------- 2 files changed, 150 insertions(+), 52 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index d32594a04..3aad2a006 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -173,7 +173,7 @@ def test_hotswapping_max_metric_too_small( # Get the first batch with pytest.raises(ValueError, match="is greater than max_metric"): - batcher.next_batch() + batcher.next_batch(None, None) def test_hotswapping_auto_batcher( @@ -237,7 +237,7 @@ def mock_measure(*_args: Any, **_kwargs: Any) -> float: ) # Test with a small max_atoms value to limit the sequence - max_size = determine_max_batch_size(lj_calculator, si_base_state, max_atoms=10) + max_size = determine_max_batch_size(si_base_state, lj_calculator, max_atoms=10) # The Fibonacci sequence up to 10 is [1, 2, 3, 5, 8, 13] # Since we're not triggering OOM errors with our mock, it should diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index a6ae2beb5..d27a4d512 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -12,14 +12,17 @@ def measure_model_memory_forward(state: BaseState, model: ModelInterface) -> float: - """Measure peak GPU memory usage during model forward pass. + """Measure peak GPU memory usage during a model's forward pass. + + Clears GPU cache, runs a forward pass with the provided state, and measures + the maximum memory allocated during execution. Args: - state: The input state to pass to the model. - model: The model to measure memory usage for. + state: Input state to pass to the model. + model: Model to measure memory usage for. Returns: - Peak memory usage in GB. + Peak memory usage in gigabytes. """ # Clear GPU memory @@ -44,10 +47,13 @@ def determine_max_batch_size( ) -> int: """Determine maximum batch size that fits in GPU memory. + Uses a Fibonacci sequence to efficiently search for the largest number of + batches that can be processed without running out of GPU memory. + Args: - model: The model to test with. - state: The base state to replicate. - max_atoms: Maximum number of atoms to try. + state: Base state to replicate for testing. + model: Model to test with. + max_atoms: Upper limit on number of atoms to try (for safety). Returns: Maximum number of batches that fit in GPU memory. @@ -75,14 +81,19 @@ def calculate_memory_scaler( state_slice: BaseState, memory_scales_with: Literal["n_atoms_x_density", "n_atoms"] = "n_atoms_x_density", ) -> float: - """Calculate scaling metric for a state. + """Calculate a metric that estimates memory requirements for a state. + + Provides different scaling metrics based on system properties that correlate + with memory usage. 
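The clear-then-measure sequence described in the docstring above is the standard CUDA peak-memory recipe. A generic helper capturing just that pattern (a sketch, not the repo's API):

import torch

def peak_gib(workload) -> float:
    """Run `workload` and report the peak CUDA memory it allocated, in GiB."""
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.reset_peak_memory_stats()
    workload()  # e.g. one model forward pass
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**3  # bytes -> GiB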
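Likewise, the Fibonacci search that determine_max_batch_size's docstring describes can be sketched in isolation; measure_forward below stands in for a forward pass on n_batches replicas of the test state:

def max_batches_sketch(measure_forward, atoms_per_state: int, max_atoms: int = 500_000) -> int:
    fib = [1, 2]  # candidate batch counts, grown like a Fibonacci sequence
    while fib[-1] * atoms_per_state < max_atoms:
        fib.append(fib[-1] + fib[-2])
    for i, n_batches in enumerate(fib):
        try:
            measure_forward(n_batches)  # OOM here means n_batches is too large
        except RuntimeError as exc:
            if "CUDA out of memory" in str(exc):
                return fib[i - 2]  # back off two Fibonacci steps, as the patch does
            raise  # unrelated runtime errors propagate
    return fib[-2]  # never hit OOM: stay one step below the largest probe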
Args: - state_slice: The state to calculate metric for. - memory_scales_with: The type of metric to calculate. + state_slice: State to calculate metric for. + memory_scales_with: Type of metric to use: + - "n_atoms": Uses only atom count + - "n_atoms_x_density": Uses atom count multiplied by number density Returns: - The calculated metric value. + Calculated metric value. """ if memory_scales_with == "n_atoms": return state_slice.n_atoms @@ -99,16 +110,19 @@ def estimate_max_memory_scaler( metric_values: list[float], max_atoms: int = 500_000, ) -> float: - """Estimate maximum metric value that fits in GPU memory. + """Estimate maximum memory scaling metric that fits in GPU memory. + + Tests both minimum and maximum metric states to determine a safe upper bound + for the memory scaling metric. Args: - model: The model to test with. + model: Model to test with. state_list: List of states to test. metric_values: Corresponding metric values for each state. max_atoms: Maximum number of atoms to try. Returns: - Maximum metric value that fits in GPU memory. + Maximum safe metric value that fits in GPU memory. """ metric_values = torch.tensor(metric_values) @@ -126,27 +140,39 @@ def estimate_max_memory_scaler( class ChunkingAutoBatcher: - """Batcher that chunks states into bins of similar computational cost.""" + """Batcher that groups states into bins of similar computational cost. + + Divides a collection of states into batches that can be processed efficiently + without exceeding GPU memory. States are grouped based on a memory scaling + metric to maximize GPU utilization. + """ def __init__( self, states: list[BaseState] | BaseState, model: ModelInterface, *, - memory_scales_with: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", + memory_scales_with: Literal[ + "n_atoms", "n_atoms_x_density" + ] = "n_atoms_x_density", max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, return_indices: bool = False, ) -> None: - """Initialize the batcher. + """Initialize the chunking auto-batcher. Args: - model: The model to batch for. - states: States to batch. - memory_scales_with: Metric to use for batching. - max_memory_scaler: Maximum metric value per batch. - max_atoms_to_try: Max number of atoms to try when estimating max_metric. - return_indices: Whether to return indices along with the batch. + states: Collection of states to batch (either a list or a single state + that will be split). + model: Model to batch for, used to estimate memory requirements. + memory_scales_with: Metric to use for estimating memory requirements: + - "n_atoms": Uses only atom count + - "n_atoms_x_density": Uses atom count multiplied by number density + max_memory_scaler: Maximum metric value allowed per batch. If None, + will be automatically estimated. + max_atoms_to_try: Maximum number of atoms to try when estimating + max_memory_scaler. + return_indices: Whether to return original indices along with batches. """ self.state_slices = ( split_state(states) if isinstance(states, BaseState) else states @@ -189,11 +215,17 @@ def next_batch( ) -> BaseState | tuple[list[BaseState], list[int]] | None: """Get the next batch of states. + Returns batches sequentially until all states have been processed. + Args: - return_indices: Whether to return indices along with the batch. + return_indices: Whether to return original indices along with the batch. + Overrides the value set during initialization. Returns: - The next batch of states, optionally with indices, or None if no more batches. 
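The min/max probing described above reduces to a small calculation: probe the smallest-metric and largest-metric systems separately, then keep the more conservative product as the budget. With assumed numbers:

min_metric, max_metric = 100.0, 400.0  # smallest and largest system metrics
min_state_batches, max_state_batches = 40, 12  # from the Fibonacci probing

ceiling = min(min_state_batches * min_metric, max_state_batches * max_metric)
# min(4000.0, 4800.0) -> 4000.0 becomes the safe per-batch budget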
+ - If return_indices is False: The next batch of states, + or None if no more batches. + - If return_indices is True: Tuple of (batch, indices), + or None if no more batches. """ # TODO: need to think about how this intersects with reporting too # TODO: definitely a clever treatment to be done with iterators here @@ -207,24 +239,48 @@ def next_batch( return None def __iter__(self) -> Iterator[BaseState]: - """Iterate over the batches.""" + """Return self as an iterator. + + Allows using the batcher in a for loop. + + Returns: + Self as an iterator. + """ return self def __next__(self) -> BaseState: - """Get the next batch.""" + """Get the next batch for iteration. + + Implements the iterator protocol to allow using the batcher in a for loop. + + Returns: + The next batch of states. + + Raises: + StopIteration: When there are no more batches. + """ next_batch = self.next_batch(return_indices=self.return_indices) if next_batch is None: raise StopIteration return next_batch - def restore_original_order(self, batched_states: list[BaseState]) -> list[BaseState]: - """Take the state bins and reorder them into a list. + def restore_original_order( + self, batched_states: list[BaseState] + ) -> list[BaseState]: + """Reorder processed states back to their original sequence. + + Takes states that were processed in batches and restores them to the + original order they were provided in. Args: batched_states: List of state batches to reorder. Returns: States in their original order. + + Raises: + ValueError: If the number of states doesn't match + the number of original indices. """ state_bins = [split_state(state) for state in batched_states] @@ -244,24 +300,36 @@ def restore_original_order(self, batched_states: list[BaseState]) -> list[BaseSt class HotswappingAutoBatcher: - """Batcher that dynamically swaps states in and out based on convergence.""" + """Batcher that dynamically swaps states based on convergence. + + Optimizes GPU utilization by removing converged states from the batch and + adding new states to process. This approach is ideal for iterative processes + where different states may converge at different rates. + """ def __init__( self, states: list[BaseState] | Iterator[BaseState] | BaseState, model: ModelInterface, - memory_scales_with: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", + memory_scales_with: Literal[ + "n_atoms", "n_atoms_x_density" + ] = "n_atoms_x_density", max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, ) -> None: - """Initialize the batcher. + """Initialize the hotswapping auto-batcher. Args: - model: The model to batch for. - states: States to batch. - memory_scales_with: Metric to use for batching. - max_memory_scaler: Maximum metric value per batch. - max_atoms_to_try: Maximum number of atoms to try when estimating max_metric. + states: Collection of states to process (list, iterator, or single state + that will be split). + model: Model to batch for, used to estimate memory requirements. + memory_scales_with: Metric to use for estimating memory requirements: + - "n_atoms": Uses only atom count + - "n_atoms_x_density": Uses atom count multiplied by number density + max_memory_scaler: Maximum metric value allowed per batch. If None, + will be automatically estimated. + max_atoms_to_try: Maximum number of atoms to try when estimating + max_memory_scaler. 
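The module imports the binpacking package in an earlier hunk, which suggests the chunking step packs the index-to-scaler mapping into bins under this budget. The exact call is not shown in these hunks; a plausible sketch with that package's to_constant_volume helper:

import binpacking

index_to_scaler = {0: 250.0, 1: 100.0, 2: 300.0, 3: 120.0}
bins = binpacking.to_constant_volume(index_to_scaler, 400.0)
# e.g. [{2: 300.0, 1: 100.0}, {0: 250.0, 3: 120.0}], with each bin's total
# kept at or below the 400.0 budget where possible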
""" if isinstance(states, BaseState): states = split_state(states) @@ -281,7 +349,14 @@ def __init__( self.completed_idx_og_order = [] def _get_next_states(self) -> None: - """Insert states from the iterator until max_metric is reached.""" + """Add states from the iterator until max_memory_scaler is reached. + + Pulls states from the iterator and adds them to the current batch until + adding another would exceed the maximum memory scaling metric. + + Returns: + List of new states added to the batch. + """ new_metrics = [] new_idx = [] new_states = [] @@ -313,6 +388,14 @@ def _get_next_states(self) -> None: return new_states def _delete_old_states(self, completed_idx: list[int]) -> None: + """Remove completed states from tracking lists. + + Updates internal tracking of states and their metrics when states are + completed and removed from processing. + + Args: + completed_idx: Indices of completed states to remove. + """ # Sort in descending order to avoid index shifting problems completed_idx.sort(reverse=True) @@ -323,10 +406,13 @@ def _delete_old_states(self, completed_idx: list[int]) -> None: self.completed_idx_og_order.append(og_idx) def _first_batch(self) -> BaseState: - """Get the first batch of states. + """Create and return the first batch of states. + + Initializes the batcher by estimating memory requirements if needed + and creating the first batch of states to process. Returns: - The first batch of states. + Tuple of (first batch, empty list of completed states). """ # we need to sample a state and use it to estimate the max metric # for the first batch @@ -362,27 +448,36 @@ def _first_batch(self) -> BaseState: def next_batch( self, - updated_state: BaseState, - convergence_tensor: torch.Tensor | None = None, + updated_state: BaseState | None, + convergence_tensor: torch.Tensor | None, *, return_indices: bool = False, - ) -> tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]]: + ) -> ( + tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]] + ): """Get the next batch of states based on convergence. + Removes converged states from the batch, adds new states if possible, + and returns both the updated batch and the completed states. + Args: - updated_state: The updated state. + updated_state: Current state after processing. convergence_tensor: Boolean tensor indicating which states have converged. - return_indices: Whether to return indices along with the batch. + If None, assumes this is the first call. + return_indices: Whether to return original indices along with the batch. Returns: - The next batch of states. + - If return_indices is False: Tuple of (next_batch, completed_states) + - If return_indices is True: Tuple of (next_batch, completed_states, indices) + + When no states remain to process, next_batch will be None. """ # TODO: this needs to be refactored to avoid so # many split and concatenate operations, we should # take the updated_concat_state and pop off # the states that have converged. with the pop_states function - if convergence_tensor is None: + if convergence_tensor is None or updated_state is None: if self.iterator_idx > 0: raise ValueError( "A convergence tensor must be provided after the " @@ -426,7 +521,10 @@ def next_batch( def restore_original_order( self, completed_states: list[BaseState] ) -> list[BaseState]: - """Take the list of completed states and reconstruct the original order. + """Reorder completed states back to their original sequence. 
+ + Takes states that were completed in arbitrary order and restores them + to the original order they were provided in. Args: completed_states: List of completed states to reorder. @@ -435,8 +533,8 @@ def restore_original_order( States in their original order. Raises: - ValueError: If the number of completed states doesn't match - the number of indices. + ValueError: If the number of completed states doesn't match the + number of completed indices. """ # TODO: should act on full states, not state slices From e1f2140f9f49222e6d9a8319a6dc11770e4fe338 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 13 Mar 2025 14:42:42 +0000 Subject: [PATCH 30/36] final lint --- torchsim/autobatching.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index d27a4d512..af218f1e5 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -152,9 +152,7 @@ def __init__( states: list[BaseState] | BaseState, model: ModelInterface, *, - memory_scales_with: Literal[ - "n_atoms", "n_atoms_x_density" - ] = "n_atoms_x_density", + memory_scales_with: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, return_indices: bool = False, @@ -264,9 +262,7 @@ def __next__(self) -> BaseState: raise StopIteration return next_batch - def restore_original_order( - self, batched_states: list[BaseState] - ) -> list[BaseState]: + def restore_original_order(self, batched_states: list[BaseState]) -> list[BaseState]: """Reorder processed states back to their original sequence. Takes states that were processed in batches and restores them to the @@ -311,9 +307,7 @@ def __init__( self, states: list[BaseState] | Iterator[BaseState] | BaseState, model: ModelInterface, - memory_scales_with: Literal[ - "n_atoms", "n_atoms_x_density" - ] = "n_atoms_x_density", + memory_scales_with: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, ) -> None: @@ -452,9 +446,7 @@ def next_batch( convergence_tensor: torch.Tensor | None, *, return_indices: bool = False, - ) -> ( - tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]] - ): + ) -> tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]]: """Get the next batch of states based on convergence. 
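Both restore_original_order implementations reduce to the same idiom: pair each result with its original index, sort by index, and drop the index. On plain data, with strings standing in for states:

original_indices = [2, 0, 1]
results = ["c", "a", "b"]

pairs = sorted(zip(original_indices, results, strict=True), key=lambda p: p[0])
restored = [r for _, r in pairs]  # ['a', 'b', 'c']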
Removes converged states from the batch, adds new states if possible, From e69478516c7fdfee65f8121ad6d7595d2d36d9d9 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 13 Mar 2025 15:27:00 +0000 Subject: [PATCH 31/36] skip example if not cuda and correct case of hotswapping autobatcher --- .../4_High_level_api/4.2_auto_batching_api.py | 13 ++++++---- tests/test_autobatching.py | 10 ++++---- torchsim/autobatching.py | 24 +++++++++++++++---- 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/examples/4_High_level_api/4.2_auto_batching_api.py b/examples/4_High_level_api/4.2_auto_batching_api.py index b67b47bca..f4c247ec9 100644 --- a/examples/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/4_High_level_api/4.2_auto_batching_api.py @@ -9,14 +9,19 @@ """Run as a interactive script.""" # ruff: noqa: E402 - +if not torch.cuda.is_available(): + raise SystemExit(0) # %% import torch from ase.build import bulk from mace.calculators.foundations_models import mace_mp -from torchsim.autobatching import ChunkingAutoBatcher, HotswappingAutoBatcher, split_state +from torchsim.autobatching import ( + ChunkingAutoBatcher, + HotSwappingAutoBatcher, + split_state, +) from torchsim.integrators import nvt_langevin from torchsim.models.mace import MaceModel from torchsim.optimizers import unit_cell_fire @@ -28,7 +33,7 @@ si_atoms = bulk("Si", "fcc", a=5.43, cubic=True).repeat((3, 3, 3)) fe_atoms = bulk("Fe", "fcc", a=5.43, cubic=True).repeat((3, 3, 3)) -device = torch.device("cpu") +device = torch.device("cuda") mace = mace_mp(model="small", return_raw_model=True) mace_model = MaceModel( @@ -69,7 +74,7 @@ def convergence_fn(state: BaseState) -> bool: return batch_wise_max_force < 1e-1 -batcher = HotswappingAutoBatcher( +batcher = HotSwappingAutoBatcher( model=mace_model, states=fire_states, memory_scales_with="n_atoms_x_density", diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 3aad2a006..a3dca4b63 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -5,7 +5,7 @@ from torchsim.autobatching import ( ChunkingAutoBatcher, - HotswappingAutoBatcher, + HotSwappingAutoBatcher, calculate_memory_scaler, determine_max_batch_size, ) @@ -164,7 +164,7 @@ def test_hotswapping_max_metric_too_small( states = [si_base_state, fe_fcc_state] # Initialize the batcher with a fixed max_metric - batcher = HotswappingAutoBatcher( + batcher = HotSwappingAutoBatcher( states=states, model=lj_calculator, memory_scales_with="n_atoms", @@ -184,7 +184,7 @@ def test_hotswapping_auto_batcher( states = [si_base_state, fe_fcc_state] # Initialize the batcher with a fixed max_metric - batcher = HotswappingAutoBatcher( + batcher = HotSwappingAutoBatcher( states=states, model=lj_calculator, memory_scales_with="n_atoms", @@ -251,7 +251,7 @@ def test_hotswapping_auto_batcher_restore_order( """Test HotswappingAutoBatcher's restore_original_order method.""" states = [si_base_state, fe_fcc_state] - batcher = HotswappingAutoBatcher( + batcher = HotSwappingAutoBatcher( states=states, model=lj_calculator, memory_scales_with="n_atoms", @@ -304,7 +304,7 @@ def test_hotswapping_with_fire( for state in fire_states: state.positions += torch.randn_like(state.positions) * 0.01 - batcher = HotswappingAutoBatcher( + batcher = HotSwappingAutoBatcher( states=fire_states, model=lj_calculator, memory_scales_with="n_atoms", diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index af218f1e5..f64336d1c 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -124,6 +124,12 
@@ def estimate_max_memory_scaler( Returns: Maximum safe metric value that fits in GPU memory. """ + # assert model device is not cpu + if model.device.type == "cpu": + raise ValueError( + "Using the CPU to estimate max memory scaler is not supported." + ) + metric_values = torch.tensor(metric_values) # select one state with the min n_atoms @@ -152,7 +158,9 @@ def __init__( states: list[BaseState] | BaseState, model: ModelInterface, *, - memory_scales_with: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", + memory_scales_with: Literal[ + "n_atoms", "n_atoms_x_density" + ] = "n_atoms_x_density", max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, return_indices: bool = False, @@ -262,7 +270,9 @@ def __next__(self) -> BaseState: raise StopIteration return next_batch - def restore_original_order(self, batched_states: list[BaseState]) -> list[BaseState]: + def restore_original_order( + self, batched_states: list[BaseState] + ) -> list[BaseState]: """Reorder processed states back to their original sequence. Takes states that were processed in batches and restores them to the @@ -295,7 +305,7 @@ def restore_original_order(self, batched_states: list[BaseState]) -> list[BaseSt return [state for _, state in sorted(indexed_states, key=lambda x: x[0])] -class HotswappingAutoBatcher: +class HotSwappingAutoBatcher: """Batcher that dynamically swaps states based on convergence. Optimizes GPU utilization by removing converged states from the batch and @@ -307,7 +317,9 @@ def __init__( self, states: list[BaseState] | Iterator[BaseState] | BaseState, model: ModelInterface, - memory_scales_with: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", + memory_scales_with: Literal[ + "n_atoms", "n_atoms_x_density" + ] = "n_atoms_x_density", max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, ) -> None: @@ -446,7 +458,9 @@ def next_batch( convergence_tensor: torch.Tensor | None, *, return_indices: bool = False, - ) -> tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]]: + ) -> ( + tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]] + ): """Get the next batch of states based on convergence. 
Removes converged states from the batch, adds new states if possible, From 1124eb631c3aa040243172b099a01245843b3fdf Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 13 Mar 2025 15:28:09 +0000 Subject: [PATCH 32/36] clean script --- .../4_High_level_api/4.2_auto_batching_api.py | 24 ++----------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/examples/4_High_level_api/4.2_auto_batching_api.py b/examples/4_High_level_api/4.2_auto_batching_api.py index f4c247ec9..fe57b57b7 100644 --- a/examples/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/4_High_level_api/4.2_auto_batching_api.py @@ -60,7 +60,7 @@ len(fire_states) -# %% +# %% run hot swapping autobatcher def convergence_fn(state: BaseState) -> bool: """Check if the system has converged.""" batch_wise_max_force = torch.zeros(state.n_batches, device=state.device) @@ -99,15 +99,7 @@ def convergence_fn(state: BaseState) -> bool: convergence_tensor = convergence_fn(state) -# %% -batcher.restore_original_order(all_completed_states) - - -# %% -sorted(batcher.completed_idx_og_order) - - -# %% +# %% run chunking autobatcher nvt_init, nvt_update = nvt_langevin( model=mace_model, dt=0.001, kT=300 * MetalUnits.temperature ) @@ -138,15 +130,3 @@ def convergence_fn(state: BaseState) -> bool: batch = nvt_update(batch) finished_states.extend(split_state(batch)) - - -# %% -len(finished_states) - - -# %% -t = torch.tensor([1, 1, 3, 3, 3, 3]) -torch.bincount(t) -_, counts = torch.unique_consecutive(t, return_counts=True) - -print(f"{counts=}") From 04bbeef1703130fca99fc8d3a2f7ca4bb60e715a Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 13 Mar 2025 15:32:37 +0000 Subject: [PATCH 33/36] system exit in proper place and raise error if memory estimation is attempted on CPU --- examples/4_High_level_api/4.2_auto_batching_api.py | 4 ++-- torchsim/autobatching.py | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/4_High_level_api/4.2_auto_batching_api.py b/examples/4_High_level_api/4.2_auto_batching_api.py index fe57b57b7..8b8d49971 100644 --- a/examples/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/4_High_level_api/4.2_auto_batching_api.py @@ -9,8 +9,6 @@ """Run as a interactive script.""" # ruff: noqa: E402 -if not torch.cuda.is_available(): - raise SystemExit(0) # %% import torch @@ -29,6 +27,8 @@ from torchsim.state import BaseState from torchsim.units import MetalUnits +if not torch.cuda.is_available(): + raise SystemExit(0) si_atoms = bulk("Si", "fcc", a=5.43, cubic=True).repeat((3, 3, 3)) fe_atoms = bulk("Fe", "fcc", a=5.43, cubic=True).repeat((3, 3, 3)) diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index f64336d1c..aaee4e238 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -24,9 +24,13 @@ def measure_model_memory_forward(state: BaseState, model: ModelInterface) -> flo Returns: Peak memory usage in gigabytes. """ - # Clear GPU memory + # assert model device is not cpu + if model.device.type == "cpu": + raise ValueError( + "Memory estimation does not make sense on CPU and is unsupported." + ) - # gc.collect() + # Clear GPU memory torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.ipc_collect() @@ -124,11 +128,6 @@ def estimate_max_memory_scaler( Returns: Maximum safe metric value that fits in GPU memory. """ - # assert model device is not cpu - if model.device.type == "cpu": - raise ValueError( - "Using the CPU to estimate max memory scaler is not supported." 
- ) metric_values = torch.tensor(metric_values) From 97fb53b71d8f8749793c4629236952ac45116a1d Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 13 Mar 2025 15:48:48 +0000 Subject: [PATCH 34/36] try enabling running on CPU --- examples/4_High_level_api/4.2_auto_batching_api.py | 11 ++++++----- torchsim/autobatching.py | 14 ++++++++------ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/examples/4_High_level_api/4.2_auto_batching_api.py b/examples/4_High_level_api/4.2_auto_batching_api.py index 8b8d49971..a22da2437 100644 --- a/examples/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/4_High_level_api/4.2_auto_batching_api.py @@ -52,7 +52,7 @@ si_fire_state = fire_init(si_state) fe_fire_state = fire_init(fe_state) -fire_states = [si_fire_state, fe_fire_state] * 20 +fire_states = [si_fire_state, fe_fire_state] * (2 if os.getenv("CI") else 20) fire_states = [state.clone() for state in fire_states] for state in fire_states: state.positions += torch.randn_like(state.positions) * 0.01 @@ -74,12 +74,12 @@ def convergence_fn(state: BaseState) -> bool: return batch_wise_max_force < 1e-1 +single_system_memory = calculate_memory_scaler(fire_states[0]) batcher = HotSwappingAutoBatcher( model=mace_model, states=fire_states, memory_scales_with="n_atoms_x_density", - max_memory_scaler=400_000, - # max_metric=400_000, + max_memory_scaler=single_system_memory * 2.5 if os.getenv("CI") else None, ) all_completed_states, convergence_tensor = [], None @@ -111,17 +111,18 @@ def convergence_fn(state: BaseState) -> bool: si_nvt_state = nvt_init(si_state) fe_nvt_state = nvt_init(fe_state) -nvt_states = [si_nvt_state, fe_nvt_state] * 5 +nvt_states = [si_nvt_state, fe_nvt_state] * (2 if os.getenv("CI") else 20) nvt_states = [state.clone() for state in nvt_states] for state in nvt_states: state.positions += torch.randn_like(state.positions) * 0.01 +single_system_memory = calculate_memory_scaler(fire_states[0]) batcher = ChunkingAutoBatcher( model=mace_model, states=nvt_states, memory_scales_with="n_atoms_x_density", - max_memory_scaler=100_000, + max_memory_scaler=single_system_memory * 2.5 if os.getenv("CI") else None, ) finished_states = [] diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index aaee4e238..a528a86a5 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -82,7 +82,7 @@ def determine_max_batch_size( def calculate_memory_scaler( - state_slice: BaseState, + state: BaseState, memory_scales_with: Literal["n_atoms_x_density", "n_atoms"] = "n_atoms_x_density", ) -> float: """Calculate a metric that estimates memory requirements for a state. @@ -91,7 +91,7 @@ def calculate_memory_scaler( with memory usage. Args: - state_slice: State to calculate metric for. + state: State to calculate metric for. memory_scales_with: Type of metric to use: - "n_atoms": Uses only atom count - "n_atoms_x_density": Uses atom count multiplied by number density @@ -99,12 +99,14 @@ def calculate_memory_scaler( Returns: Calculated metric value. 
""" + if state.n_batches > 1: + raise ValueError("State must be a single batch") if memory_scales_with == "n_atoms": - return state_slice.n_atoms + return state.n_atoms if memory_scales_with == "n_atoms_x_density": - volume = torch.abs(torch.linalg.det(state_slice.cell[0])) / 1000 - number_density = state_slice.n_atoms / volume.item() - return state_slice.n_atoms * number_density + volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000 + number_density = state.n_atoms / volume.item() + return state.n_atoms * number_density raise ValueError(f"Invalid metric: {memory_scales_with}") From 264a66099543f892a25ecc63651576e5c85170b5 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 13 Mar 2025 15:49:18 +0000 Subject: [PATCH 35/36] lint --- .../4_High_level_api/4.2_auto_batching_api.py | 5 +++++ torchsim/autobatching.py | 17 ++++------------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/examples/4_High_level_api/4.2_auto_batching_api.py b/examples/4_High_level_api/4.2_auto_batching_api.py index a22da2437..93da9465c 100644 --- a/examples/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/4_High_level_api/4.2_auto_batching_api.py @@ -10,7 +10,10 @@ """Run as a interactive script.""" # ruff: noqa: E402 + # %% +import os + import torch from ase.build import bulk from mace.calculators.foundations_models import mace_mp @@ -18,6 +21,7 @@ from torchsim.autobatching import ( ChunkingAutoBatcher, HotSwappingAutoBatcher, + calculate_memory_scaler, split_state, ) from torchsim.integrators import nvt_langevin @@ -27,6 +31,7 @@ from torchsim.state import BaseState from torchsim.units import MetalUnits + if not torch.cuda.is_available(): raise SystemExit(0) diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index a528a86a5..f84139c98 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -130,7 +130,6 @@ def estimate_max_memory_scaler( Returns: Maximum safe metric value that fits in GPU memory. """ - metric_values = torch.tensor(metric_values) # select one state with the min n_atoms @@ -159,9 +158,7 @@ def __init__( states: list[BaseState] | BaseState, model: ModelInterface, *, - memory_scales_with: Literal[ - "n_atoms", "n_atoms_x_density" - ] = "n_atoms_x_density", + memory_scales_with: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, return_indices: bool = False, @@ -271,9 +268,7 @@ def __next__(self) -> BaseState: raise StopIteration return next_batch - def restore_original_order( - self, batched_states: list[BaseState] - ) -> list[BaseState]: + def restore_original_order(self, batched_states: list[BaseState]) -> list[BaseState]: """Reorder processed states back to their original sequence. 
Takes states that were processed in batches and restores them to the @@ -318,9 +313,7 @@ def __init__( self, states: list[BaseState] | Iterator[BaseState] | BaseState, model: ModelInterface, - memory_scales_with: Literal[ - "n_atoms", "n_atoms_x_density" - ] = "n_atoms_x_density", + memory_scales_with: Literal["n_atoms", "n_atoms_x_density"] = "n_atoms_x_density", max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, ) -> None: @@ -459,9 +452,7 @@ def next_batch( convergence_tensor: torch.Tensor | None, *, return_indices: bool = False, - ) -> ( - tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]] - ): + ) -> tuple[BaseState, list[BaseState]] | tuple[BaseState, list[BaseState], list[int]]: """Get the next batch of states based on convergence. Removes converged states from the batch, adds new states if possible, From 07addcc316d2b4a646e13551a318ceb70c0f090d Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Thu, 13 Mar 2025 13:27:01 -0400 Subject: [PATCH 36/36] propagate Hotswapping->HotSwapping rename to tests --- .../4_High_level_api/4.2_auto_batching_api.py | 1 - tests/test_autobatching.py | 16 ++++++++-------- torchsim/autobatching.py | 6 +++--- torchsim/workflows.py | 2 +- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/examples/4_High_level_api/4.2_auto_batching_api.py b/examples/4_High_level_api/4.2_auto_batching_api.py index 93da9465c..19f3cb816 100644 --- a/examples/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/4_High_level_api/4.2_auto_batching_api.py @@ -3,7 +3,6 @@ # /// script # dependencies = [ # "mace-torch>=0.3.10", -# "pymatgen>=2025.2.18", # ] # /// diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index a3dca4b63..25b61fd80 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -156,10 +156,10 @@ def test_chunking_auto_batcher_restore_order_with_split_states( assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) -def test_hotswapping_max_metric_too_small( +def test_hot_swapping_max_metric_too_small( si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel ) -> None: - """Test HotswappingAutoBatcher with different states.""" + """Test HotSwappingAutoBatcher with different states.""" # Create a list of states states = [si_base_state, fe_fcc_state] @@ -176,10 +176,10 @@ def test_hotswapping_max_metric_too_small( batcher.next_batch(None, None) -def test_hotswapping_auto_batcher( +def test_hot_swapping_auto_batcher( si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel ) -> None: - """Test HotswappingAutoBatcher with different states.""" + """Test HotSwappingAutoBatcher with different states.""" # Create a list of states states = [si_base_state, fe_fcc_state] @@ -245,10 +245,10 @@ def mock_measure(*_args: Any, **_kwargs: Any) -> float: assert max_size == 8 -def test_hotswapping_auto_batcher_restore_order( +def test_hot_swapping_auto_batcher_restore_order( si_base_state: BaseState, fe_fcc_state: BaseState, lj_calculator: LennardJonesModel ) -> None: - """Test HotswappingAutoBatcher's restore_original_order method.""" + """Test HotSwappingAutoBatcher's restore_original_order method.""" states = [si_base_state, fe_fcc_state] batcher = HotSwappingAutoBatcher( @@ -291,7 +291,7 @@ def test_hotswapping_auto_batcher_restore_order( # batcher.restore_original_order([si_base_state]) -def test_hotswapping_with_fire( +def test_hot_swapping_with_fire( si_base_state: BaseState, fe_fcc_state: 
BaseState, lj_calculator: LennardJonesModel ) -> None: fire_init, fire_update = unit_cell_fire(lj_calculator) @@ -373,5 +373,5 @@ def test_chunking_auto_batcher_with_fire( restored_states = batcher.restore_original_order(finished_states) assert len(restored_states) == len(fire_states) - for restored, original in zip(restored_states, fire_states, strict=False): + for restored, original in zip(restored_states, fire_states, strict=True): assert torch.all(restored.atomic_numbers == original.atomic_numbers) diff --git a/torchsim/autobatching.py b/torchsim/autobatching.py index f84139c98..6842b974b 100644 --- a/torchsim/autobatching.py +++ b/torchsim/autobatching.py @@ -297,7 +297,7 @@ def restore_original_order(self, batched_states: list[BaseState]) -> list[BaseSt ) # sort states by original indices - indexed_states = list(zip(original_indices, all_states, strict=False)) + indexed_states = list(zip(original_indices, all_states, strict=True)) return [state for _, state in sorted(indexed_states, key=lambda x: x[0])] @@ -317,7 +317,7 @@ def __init__( max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, ) -> None: - """Initialize the hotswapping auto-batcher. + """Initialize the hot-swapping auto-batcher. Args: states: Collection of states to process (list, iterator, or single state @@ -544,7 +544,7 @@ def restore_original_order( # Create pairs of (original_index, state) indexed_states = list( - zip(self.completed_idx_og_order, completed_states, strict=False) + zip(self.completed_idx_og_order, completed_states, strict=True) ) # Sort by original index diff --git a/torchsim/workflows.py b/torchsim/workflows.py index b1793639b..91eba7503 100644 --- a/torchsim/workflows.py +++ b/torchsim/workflows.py @@ -580,7 +580,7 @@ def get_subcells_to_crystallize( # Convert stoichiometries to composition formulas comps = [] for stoich in stoichs: - comp = dict(zip(elements, stoich, strict=False)) + comp = dict(zip(elements, stoich, strict=True)) comps.append(Composition.from_dict(comp).reduced_formula) restrict_to_compositions = set(comps)
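A closing note on the strict=False -> strict=True changes in this last patch: with strict=True (Python 3.10+), zip raises instead of silently truncating when its arguments differ in length, which is the safer behavior for pairings that must line up one-to-one, as below:

elements = ["Al", "Fe"]
stoich = (2, 1)
comp = dict(zip(elements, stoich, strict=True))  # {'Al': 2, 'Fe': 1}

# dict(zip(["Al", "Fe"], (2,), strict=True)) raises ValueError instead of
# silently dropping "Fe", which is what strict=False would do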