diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 200e538b7..96e9ec319 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.12 + rev: v0.11.13 hooks: - id: ruff args: [--fix] diff --git a/codecov.yml b/codecov.yml index a630999c0..196afc60e 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,5 +1,3 @@ -ignore: [torch_sim/unbatched] - comment: false coverage: diff --git a/docs/_static/get_module_graph_dot_file.py b/docs/_static/get_module_graph_dot_file.py deleted file mode 100644 index 5d27f623d..000000000 --- a/docs/_static/get_module_graph_dot_file.py +++ /dev/null @@ -1,394 +0,0 @@ -"""Generate a direct colored DOT representation of the torch_sim module dependencies. - -This script analyzes the torch_sim package structure by directly inspecting imports -and creates a DOT file with: -1. Nodes colored by their connectedness (number of connections) -2. No "torch_sim." prefix in node labels -3. Line count information for each module -4. Customizable layout parameters via command-line options -""" - -import argparse -import ast -import colorsys -import os -import tomllib -from collections import defaultdict - - -with open("pyproject.toml", "rb") as pyproject: - github_base_url = tomllib.load(pyproject)["project"]["urls"]["Repo"] - - -def rgb_to_hex(red: float, green: float, blue: float) -> str: - """Convert RGB values to hex color code.""" - return f"#{int(red * 255):02x}{int(green * 255):02x}{int(blue * 255):02x}" - - -def generate_heat_colors(num_colors: int) -> list: - """Generate a color gradient from blue (cool) to red (hot).""" - colors = [] - for index in range(num_colors): - # Hue range from blue (0.66) to red (0) - hue = 0.66 * (1 - index / (num_colors - 1)) if num_colors > 1 else 0 - red, green, blue = colorsys.hsv_to_rgb(hue, 0.8, 0.9) - colors.append(rgb_to_hex(red, green, blue)) - return colors - - -def get_file_imports(file_path: str) -> list: - """Extract all import statements from a Python file.""" - try: - with open(file_path, encoding="utf-8") as file_handle: - tree = ast.parse(file_handle.read()) - - imports = [] - - # Extract regular imports (import x, import y) - for node in ast.walk(tree): - if isinstance(node, ast.Import): - imports.extend(name.name for name in node.names) - elif isinstance(node, ast.ImportFrom) and node.module: - imports.append(node.module) - except SyntaxError: - print(f"Syntax error in {file_path}") - return [] - else: - return imports - - -def count_lines(file_path: str) -> int: - """Count non-empty, non-comment lines in a Python file.""" - with open(file_path, encoding="utf-8") as file_handle: - return sum( - bool(line.strip() and not line.strip().startswith("#")) - for line in file_handle - ) - - -def analyze_package( - package_path: str, package_name: str = "torch_sim" -) -> dict[str, set[str]]: - """Build a dependency graph of package modules.""" - dependency_graph: defaultdict[str, set[str]] = defaultdict(set) - - # Get Python files excluding those starting with _ - python_files = [ - f"{root}/{file}" - for root, _, files in os.walk(package_path) - for file in files - if file.endswith(".py") and not file.startswith("_") - ] - - print(f"Found {len(python_files)} Python files in {package_path}") - - for file_path in python_files: - # Convert file path to module name - rel_path = os.path.relpath(file_path, os.path.dirname(package_path)) - module_name = rel_path.replace(".py", "").replace("/", ".").replace("\\", ".") - - # Get imports and add only torch_sim imports to the graph - imports = get_file_imports(file_path) - torch_sim_imports = [ - import_name - for import_name in imports - if import_name == package_name or import_name.startswith(f"{package_name}.") - ] - - for imported_module in torch_sim_imports: - dependency_graph[module_name].add(imported_module) - - return dependency_graph - - -def simplify_module_name(full_name: str, base_package: str = "torch_sim") -> str: - """Remove the base package prefix from module names.""" - if full_name == base_package: - return base_package - if full_name.startswith(f"{base_package}."): - return full_name[len(base_package) + 1 :] - return full_name - - -def module_to_node_id(module_name: str) -> str: - """Convert module name to valid DOT node ID.""" - return module_name.replace(".", "_") - - -def generate_dot_file( # noqa: C901, PLR0915 - dependency_graph: dict[str, set[str]], - output_file: str, - package_path: str, - args: argparse.Namespace, -) -> None: - """Generate a DOT file with connectedness-based coloring.""" - connections: defaultdict[str, int] = defaultdict(int) - all_modules = set(dependency_graph.keys()).union( - {dep for deps in dependency_graph.values() for dep in deps} - ) - - # Count connections (outgoing + incoming) - for module, deps in dependency_graph.items(): - connections[module] += len(deps) # Outgoing - for dep in deps: - if dep in all_modules: - connections[dep] += 1 # Incoming - - # Define connection ranges and colors - connection_ranges = ["0-1", "2-3", "4-5", "6-7", "8-10", "11-15", "16+"] - range_thresholds = [1, 3, 5, 7, 10, 15, float("inf")] # Upper bounds for each range - colors = generate_heat_colors(len(connection_ranges)) - - # Map connection counts to ranges - def get_range(count: int) -> str: - for idx, threshold in enumerate(range_thresholds): - if count <= threshold: - return connection_ranges[idx] - return connection_ranges[-1] # Should never reach here - - # Start generating DOT content - lines = ["digraph G {", " layout=dot;"] - - # Add layout parameters - if args.engine != "dot": - lines.append(f" layout = {args.engine};") - - lines += [f" concentrate = {str(args.concentrate).lower()};"] - lines += [ - f" {key} = {getattr(args, key)};" - for key in "ratio nodesep ranksep rankdir overlap splines maxiter".split() # noqa: SIM905 - if getattr(args, key, None) is not None - ] - - if args.pack: - lines.append(" pack = true;") - - # Engine-specific options - if args.engine == "fdp": - lines.append(f" K = {args.K};") - elif args.engine == "neato": - lines.append(" epsilon = 0.00001;") - - # Node styling - node_attrs = [ - "style=filled", - 'fillcolor="#ffffff"', - 'fontcolor="#000000"', - "fontname=Helvetica", - "fontsize=10", - f'margin="{args.margin}"', - ] - if args.node_height: - node_attrs.append(f"height={args.node_height}") - - lines.append(f" node [{','.join(node_attrs)}];") - lines.append("") - - # Add color legend - lines.append(" // Color legend by node connectedness") - range_to_color = { - range_name: colors[idx] for idx, range_name in enumerate(connection_ranges) - } - for range_name in connection_ranges: - lines.append(f" // {range_to_color[range_name]} = {range_name} connections") - lines.append("") - - # Add nodes - for module in sorted(all_modules): - simple_name = simplify_module_name(module) - node_id = module_to_node_id(module) - range_name = get_range(connections[module]) - color = range_to_color[range_name] - - # Count lines if it's a torch_sim module - label = simple_name - github_url = "" - - if module.startswith("torch_sim"): - relative_module = module.replace(".", os.sep) - module_name = relative_module[len("torch_sim") + 1 :] - module_path = f"{package_path}/{module_name}.py" - # Include 'torch_sim' in the GitHub URL path - github_url = f"{github_base_url}/blob/main/torch_sim/{module_name}.py" - - if os.path.isfile(module_path): - line_count = count_lines(module_path) - label = f"{simple_name}\\n({line_count} lines)" - - node_style = f'fillcolor="{color}",fontcolor="white",label="{label}",shape="box"' - if github_url: - node_style += f',URL="{github_url}",tooltip="View source on GitHub"' - - lines.append(f" {node_id} [{node_style}];") - - # Add edges - lines.append("") - for module, deps in sorted(dependency_graph.items()): - source_id = module_to_node_id(module) - for dep in sorted(deps): - if dep in all_modules: - target_id = module_to_node_id(dep) - lines.append(f" {source_id} -> {target_id};") - - lines.append("}") - - # Write to file - with open(output_file, "w", encoding="utf-8") as file_handle: - file_handle.write("\n".join(lines) + "\n") - - print(f"Generated DOT file: {output_file}") - - # Print statistics - print("\nNode connectedness statistics:") - stats = defaultdict(list) - for module, count in connections.items(): - stats[get_range(count)].append((module, count)) - - for range_name in connection_ranges: - if nodes := stats.get(range_name): - avg = sum(node[1] for node in nodes) / len(nodes) - print(f" {range_name}: {len(nodes)} nodes, avg {avg:.1f} connections") - if range_name in ["11-15", "16+"]: - hub_modules = ", ".join(simplify_module_name(node[0]) for node in nodes) - print(f" - Hub modules: {hub_modules}") - - -def main() -> None: - """Entry point that parses args and generates the DOT file.""" - parser = argparse.ArgumentParser( - description="Generate a colored DOT file of torch_sim module dependencies" - ) - - # Graph layout options - parser.add_argument( - "--engine", - choices=["dot", "neato", "fdp", "circo", "twopi", "osage"], - default="dot", - help="GraphViz layout engine (default: dot)", - ) - parser.add_argument( - "--concentrate", - action="store_true", - default=True, - help="Concentrate edges (default: True)", - ) - parser.add_argument( - "--ratio", type=float, default=0.8, help="Aspect ratio (default: 0.8)" - ) - parser.add_argument( - "--nodesep", - type=float, - default=0.08, - help="Horizontal node separation (default: 0.08)", - ) - parser.add_argument( - "--ranksep", - type=float, - default=0.1, - help="Vertical rank separation (default: 0.1)", - ) - parser.add_argument( - "--overlap", - choices=["true", "false", "scale", "compress", "vpsc", "prism", "none"], - default="false", - help="Overlap handling (default: false)", - ) - parser.add_argument( - "--splines", - choices=["true", "false", "ortho", "curved", "line", "polyline", "none"], - default=None, - help="Edge spline style (default: none)", - ) - parser.add_argument( - "--margin", default="0.08,0.02", help="Node margin (default: 0.08,0.02)" - ) - parser.add_argument( - "--node-height", type=float, default=0.5, help="Fixed node height (default: 0.5)" - ) - parser.add_argument( - "--rankdir", - choices=["TB", "LR", "BT", "RL"], - default="LR", - help="Rank direction (default: LR)", - ) - - # Advanced options - parser.add_argument( - "--pack", - action="store_true", - default=False, - help="Pack graph components tightly (default: False)", - ) - parser.add_argument( - "--maxiter", - type=int, - default=None, - help="Max iterations for force-directed layouts", - ) - parser.add_argument( - "--compact", - action="store_true", - default=False, - help="Enable compact layout preset", - ) - parser.add_argument( - "--K", - type=float, - default=0.1, - help="Spring constant for force-directed layouts (default: 0.1)", - ) - - # Output options - parser.add_argument( - "--output-dir", - default="docs/_static", - help="Output directory (default: docs/_static)", - ) - parser.add_argument( - "--output-file", - default="torch-sim-module-graph.dot", - help="Output filename (default: torch-sim-module-graph.dot)", - ) - - args, _unknown = parser.parse_known_args() - - # Apply compact preset if selected - if args.compact: - args.rankdir = args.rankdir or "LR" - if args.engine == "dot": - args.ranksep, args.nodesep, args.overlap, args.splines = ( - 0.01, - 0.05, - "compress", - "ortho", - ) - args.node_height = args.node_height or 0.1 - elif args.engine in ["fdp", "neato"]: - args.overlap, args.pack, args.maxiter = "prism", True, 500 - - # Set up paths - package_name = "torch_sim" - current_dir = os.path.dirname(os.path.abspath(__file__)) - # Look for the package at the project root - project_root = os.path.dirname(os.path.dirname(current_dir)) - package_path = f"{project_root}/{package_name}" - - # Set up output directory - output_dir = args.output_dir - os.makedirs(output_dir, exist_ok=True) - output_file = f"{output_dir}/{args.output_file}" - - print(f"Analyzing package: {package_name} at {package_path}") - - dependency_graph = analyze_package(package_path, package_name) - generate_dot_file(dependency_graph, output_file, package_path, args) - - print(f"\nDone! Generated DOT file: {output_file}") - print(f"Layout engine: {args.engine}") - print("\nTo render directly with GraphViz (if installed):") - svg_path = output_file.replace(".dot", ".svg") - print(f" dot -T{args.engine} -Tsvg {output_file} -o {svg_path}") - - -if __name__ == "__main__": - main() diff --git a/docs/conf.py b/docs/conf.py index acb691174..34e59dd5d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,4 @@ -# ruff: noqa: E501 +# ruff: noqa: E501, INP001 """Sphinx configuration file.""" diff --git a/examples/scripts/7_Others/7.1_Soft_sphere_autograd.py b/examples/scripts/7_Others/7.1_Soft_sphere_autograd.py index 4e42897f2..4a71bcdfd 100644 --- a/examples/scripts/7_Others/7.1_Soft_sphere_autograd.py +++ b/examples/scripts/7_Others/7.1_Soft_sphere_autograd.py @@ -8,6 +8,8 @@ # ] # /// +# ruff: noqa: RUF001 + import numpy as np import torch from plotly.subplots import make_subplots @@ -62,10 +64,7 @@ ) # Add figure titles and labels -fig.update_layout( - title="Soft Sphere Potential", - xaxis_title="Distance (r/σ)", -) +fig.update_layout(title="Soft Sphere Potential", xaxis_title="Distance (r/σ)") # Update y-axes labels fig.update_yaxes(title_text="Energy (ε)", secondary_y=False) diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index abda6ee0d..82fdf8f0a 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -484,7 +484,8 @@ def run_optimization_ase( # noqa: C901, PLR0915 final_state_opt: SimState | GDState | None = None if optimizer_type_val == "torch_sim": - assert ts_md_flavor_val is not None, "ts_md_flavor must be provided for torch_sim" + if ts_md_flavor_val is None: + raise ValueError(f"{ts_md_flavor_val=} must be provided for torch_sim") steps, final_state_opt = run_optimization_ts( initial_state=state.clone(), ts_md_flavor=ts_md_flavor_val, diff --git a/pyproject.toml b/pyproject.toml index a87cadc1e..84a3d46dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,28 +96,18 @@ ignore = [ "EM102", # Exception must not use an f-string literal, assign to variable first "ERA001", # Found commented-out code "FIX002", # Line contains TODO, consider resolving the issue - "G003", # logging-string-concat - "G004", # logging uses f-string - "INP001", # implicit-namespace-package - "ISC001", # avoid conflicts with the formatter "N803", # Variable name should be lowercase "N806", # Uppercase letters in variable names - "PD010", # .pivot_table is preferred to .pivot or .unstack; provides same functionality - "PD015", # pandas-use-of-pd-merge "PLR0912", # too many branches "PLR0913", # too many function arguments "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable "PLW2901", # Outer for loop variable overwritten by inner assignment target "PTH", # flake8-use-pathlib - "RUF001", # String contains ambiguous - "S101", # Use of assertion statements "S301", # pickle and modules that wrap it can be unsafe, possible security issue "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes "SIM105", # Use contextlib.suppress instead of try-except-pass - "T201", # print found "TD", # flake8-todos "TRY003", # Avoid specifying long messages outside the exception class - "TRY301", # Abstract raise to an inner function ] pydocstyle.convention = "google" isort.split-on-trailing-comma = false @@ -125,8 +115,8 @@ isort.lines-after-imports = 2 pep8-naming.ignore-names = ["get_kT", "kT"] [tool.ruff.lint.per-file-ignores] -"**/tests/*" = ["ANN201", "D", "S101"] -"examples/**/*" = ["B018"] +"**/tests/*" = ["ANN201", "D", "INP001", "S101"] +"examples/**/*" = ["B018", "T201"] "examples/tutorials/**/*" = ["ALL"] [tool.ruff.format] diff --git a/tests/models/test_morse.py b/tests/models/test_morse.py index 3c59320ba..5dfce5f42 100644 --- a/tests/models/test_morse.py +++ b/tests/models/test_morse.py @@ -28,7 +28,6 @@ def test_morse_pair_asymptotic() -> None: dr = torch.tensor([[1.0]]) # Large distance epsilon = 5.0 energy = morse_pair(dr, epsilon=epsilon) - print(energy, -epsilon * torch.ones_like(energy)) torch.testing.assert_close( energy, -epsilon * torch.ones_like(energy), rtol=1e-2, atol=1e-5 ) @@ -55,8 +54,6 @@ def test_morse_force_energy_consistency() -> None: force_from_grad = -torch.autograd.grad(energy.sum(), dr, create_graph=True)[0] # Compare forces - print(force_direct) - print(force_from_grad) assert torch.allclose(force_direct, force_from_grad, rtol=1e-4, atol=1e-4) diff --git a/tests/models/test_soft_sphere.py b/tests/models/test_soft_sphere.py index 5284db3c5..f14cf833f 100644 --- a/tests/models/test_soft_sphere.py +++ b/tests/models/test_soft_sphere.py @@ -327,8 +327,8 @@ def test_matrix_symmetry_validation(matrix_name: str, matrix: torch.Tensor) -> N # Replace one matrix with the non-symmetric version params[matrix_name] = matrix - # Should raise AssertionError due to asymmetric matrix - with pytest.raises(AssertionError): + # Should raise ValueError due to asymmetric matrix + with pytest.raises(ValueError, match="is not symmetric"): ss.SoftSphereMultiModel(**params) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 53ea4932f..2c974ef64 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -457,10 +457,7 @@ def convergence_fn(state: ts.SimState) -> bool: all_completed_states, convergence_tensor = [], None while True: - print(f"Starting new batch of {state.n_batches} states.") - state, completed_states = batcher.next_batch(state, convergence_tensor) - print("Number of completed states", len(completed_states)) all_completed_states.extend(completed_states) if state is None: diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index d32f29c51..87ff8697f 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -152,8 +152,9 @@ def test_fire_optimization( energies.append(state.energy.item()) steps_taken += 1 - if steps_taken == max_steps: - print(f"FIRE optimization for {md_flavor=} did not converge in {max_steps} steps") + assert steps_taken < max_steps, ( + f"FIRE optimization for {md_flavor=} did not converge in {max_steps=}" + ) energies = energies[1:] @@ -327,7 +328,6 @@ def test_unit_cell_fire_optimization( ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module, md_flavor: MdFlavor ) -> None: """Test that the Unit Cell FIRE optimizer actually minimizes energy.""" - print(f"\n--- Starting test_unit_cell_fire_optimization for {md_flavor=} ---") # Add random displacement to positions and cell current_positions = ( @@ -347,48 +347,33 @@ def test_unit_cell_fire_optimization( atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), batch=ar_supercell_sim_state.batch.clone(), ) - print(f"[{md_flavor}] Initial SimState created.") initial_state_positions = current_sim_state.positions.clone() initial_state_cell = current_sim_state.cell.clone() # Initialize FIRE optimizer - print(f"Initializing {md_flavor} optimizer...") init_fn, update_fn = unit_cell_fire( model=lj_model, dt_max=0.3, dt_start=0.1, md_flavor=md_flavor, ) - print(f"[{md_flavor}] Optimizer functions obtained.") state = init_fn(current_sim_state) - energy = float(getattr(state, "energy", "nan")) - print(f"[{md_flavor}] Initial state created by init_fn. {energy=:.4f}") # Run optimization for a few steps energies = [1000.0, state.energy.item()] max_steps = 1000 steps_taken = 0 - print(f"[{md_flavor}] Entering optimization loop (max_steps: {max_steps})...") while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: state = update_fn(state) energies.append(state.energy.item()) steps_taken += 1 - print(f"[{md_flavor}] Loop finished after {steps_taken} steps.") - - if steps_taken == max_steps and abs(energies[-2] - energies[-1]) > 1e-6: - print( - f"WARNING: Unit Cell FIRE {md_flavor=} optimization did not converge " - f"in {max_steps} steps. Final energy: {energies[-1]:.4f}" - ) - else: - print( - f"Unit Cell FIRE {md_flavor=} optimization converged in {steps_taken} " - f"steps. Final energy: {energies[-1]:.4f}" - ) + assert steps_taken < max_steps, ( + f"Unit Cell FIRE {md_flavor=} optimization did not converge in {max_steps=}" + ) energies = energies[1:] @@ -522,7 +507,6 @@ def test_frechet_cell_fire_optimization( ) -> None: """Test that the Frechet Cell FIRE optimizer actually minimizes energy for different md_flavors.""" - print(f"\n--- Starting test_frechet_cell_fire_optimization for {md_flavor=} ---") # Add random displacement to positions and cell # Create a fresh copy for each test run to avoid interference @@ -543,48 +527,33 @@ def test_frechet_cell_fire_optimization( atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), batch=ar_supercell_sim_state.batch.clone(), ) - print(f"[{md_flavor}] Initial SimState created for Frechet test.") initial_state_positions = current_sim_state.positions.clone() initial_state_cell = current_sim_state.cell.clone() # Initialize FIRE optimizer - print(f"Initializing Frechet {md_flavor} optimizer...") init_fn, update_fn = frechet_cell_fire( model=lj_model, dt_max=0.3, dt_start=0.1, md_flavor=md_flavor, ) - print(f"[{md_flavor}] Frechet optimizer functions obtained.") state = init_fn(current_sim_state) - energy = float(getattr(state, "energy", "nan")) - print(f"[{md_flavor}] Initial state created by Frechet init_fn. {energy=:.4f}") # Run optimization for a few steps energies = [1000.0, state.energy.item()] # Ensure float for comparison max_steps = 1000 steps_taken = 0 - print(f"[{md_flavor}] Entering Frechet optimization loop (max_steps: {max_steps})...") while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: state = update_fn(state) energies.append(state.energy.item()) steps_taken += 1 - print(f"[{md_flavor}] Frechet loop finished after {steps_taken} steps.") - - if steps_taken == max_steps and abs(energies[-2] - energies[-1]) > 1e-6: - print( - f"WARNING: Frechet Cell FIRE {md_flavor=} optimization did not converge " - f"in {max_steps} steps. Final energy: {energies[-1]:.4f}" - ) - else: - print( - f"Frechet Cell FIRE {md_flavor=} optimization converged in {steps_taken} " - f"steps. Final energy: {energies[-1]:.4f}" - ) + assert steps_taken < max_steps, ( + f"Frechet FIRE {md_flavor=} optimization did not converge in {max_steps=}" + ) energies = energies[1:] @@ -600,8 +569,7 @@ def test_frechet_cell_fire_optimization( pressure = torch.trace(state.stress.squeeze(0)) / 3.0 # Adjust tolerances if needed, Frechet might behave slightly differently - pressure_tol = 0.01 - force_tol = 0.2 + pressure_tol, force_tol = 0.01, 0.2 assert torch.abs(pressure) < pressure_tol, ( f"{md_flavor=} pressure should be below {pressure_tol=} after Frechet " diff --git a/tests/test_runners.py b/tests/test_runners.py index d25da25cf..5d7201c92 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -768,12 +768,12 @@ def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None: cu_atoms = bulk("Cu", "fcc", a=3.58, cubic=True).repeat((2, 2, 2)) many_cu_atoms = [cu_atoms] * 5 - trajectory_files = [tmp_path / f"Cu_traj_{i}" for i in range(len(many_cu_atoms))] + trajectory_files = [tmp_path / f"Cu_traj_{i}.h5md" for i in range(len(many_cu_atoms))] # run them all simultaneously with batching final_state = ts.integrate( system=many_cu_atoms, - model=lj_model, + model=lj_model, # using LJ instead of MACE for testing n_steps=50, timestep=0.002, temperature=1000, @@ -788,17 +788,17 @@ def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None: with ts.TorchSimTrajectory(filename) as traj: final_energies.append(traj.get_array("potential_energy")[-1]) - print(final_energies) + assert len(final_energies) == len(trajectory_files) # relax all of the high temperature states relaxed_state = ts.optimize( system=final_state, model=lj_model, optimizer=ts.frechet_cell_fire, - # autobatcher=True, + # autobatcher=True, # disabled for CPU-based LJ model in test ) - print(relaxed_state.energy) + assert relaxed_state.energy.shape == (final_state.n_batches,) @pytest.fixture diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index e7781d26b..f2eb32c72 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -20,7 +20,6 @@ model architectures and GPU configurations. """ -import logging from collections.abc import Callable, Iterator from itertools import chain from typing import Any, get_args @@ -233,7 +232,7 @@ def measure_model_memory_forward(state: SimState, model: ModelInterface) -> floa "Memory estimation does not make sense on CPU and is unsupported." ) - logging.info( # noqa: LOG015 + print( # noqa: T201 "Model Memory Estimation: Running forward pass on state with " f"{state.n_atoms} atoms and {state.n_batches} batches.", ) @@ -403,7 +402,7 @@ def estimate_max_memory_scaler( min_state = state_list[metric_values.argmin()] max_state = state_list[metric_values.argmax()] - logging.info( # noqa: LOG015 + print( # noqa: T201 "Model Memory Estimation: Estimating memory from worst case of " f"largest and smallest system. Largest system has {max_state.n_atoms} atoms " f"and {max_state.n_batches} batches, and smallest system has " @@ -955,10 +954,8 @@ def _get_first_batch(self) -> SimState: self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding return concatenate_states([first_state, *states]) - def next_batch( - self, - updated_state: SimState | None, - convergence_tensor: torch.Tensor | None, + def next_batch( # noqa: C901 + self, updated_state: SimState | None, convergence_tensor: torch.Tensor | None ) -> ( tuple[SimState | None, list[SimState]] | tuple[SimState | None, list[SimState], list[int]] @@ -1022,10 +1019,14 @@ def next_batch( # assert statements helpful for debugging, should be moved to validate fn # the first two are most important - assert len(convergence_tensor) == updated_state.n_batches - assert len(self.current_idx) == len(self.current_scalers) - assert len(convergence_tensor.shape) == 1 - assert updated_state.n_batches > 0 + if len(convergence_tensor) != updated_state.n_batches: + raise ValueError(f"{len(convergence_tensor)=} != {updated_state.n_batches=}") + if len(self.current_idx) != len(self.current_scalers): + raise ValueError(f"{len(self.current_idx)=} != {len(self.current_scalers)=}") + if len(convergence_tensor.shape) != 1: + raise ValueError(f"{len(convergence_tensor.shape)=} != 1") + if updated_state.n_batches <= 0: + raise ValueError(f"{updated_state.n_batches=} <= 0") # Increment attempt counters and check for max attempts in a single loop for cur_idx, abs_idx in enumerate(self.current_idx): diff --git a/torch_sim/math.py b/torch_sim/math.py index 07037e0e8..40228ba4d 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -984,7 +984,7 @@ def matrix_log_33( "Falling back to scipy" ) if fallback_warning: - print(msg) + print(msg) # noqa: T201 # Fall back to scipy implementation return matrix_log_scipy(matrix).to(sim_dtype) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 48b3004fc..3197c6aa9 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -14,6 +14,8 @@ pretrained model checkpoints. """ +# ruff: noqa: T201 + from __future__ import annotations import copy @@ -173,7 +175,8 @@ def __init__( # noqa: C901, PLR0915 ) # Either the config path or the checkpoint path needs to be provided - assert config_yml or model is not None + if not config_yml and model is None: + raise ValueError("Either config_yml or model must be provided") checkpoint = None if config_yml is not None: diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 57dc08a42..a6ba406b6 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -199,10 +199,8 @@ def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tens """ -def validate_model_outputs( - model: ModelInterface, - device: torch.device, - dtype: torch.dtype, +def validate_model_outputs( # noqa: C901, PLR0915 + model: ModelInterface, device: torch.device, dtype: torch.dtype ) -> None: """Validate the outputs of a model implementation against the interface requirements. @@ -233,10 +231,9 @@ def validate_model_outputs( """ from ase.build import bulk - assert model.dtype is not None - assert model.device is not None - assert model.compute_stress is not None - assert model.compute_forces is not None + for attr in ("dtype", "device", "compute_stress", "compute_forces"): + if not hasattr(model, attr): + raise ValueError(f"model.{attr} is not set") try: if not model.compute_stress: @@ -265,52 +262,56 @@ def validate_model_outputs( model_output = model.forward(sim_state) # assert model did not mutate the input - assert torch.allclose(og_positions, sim_state.positions) - assert torch.allclose(og_cell, sim_state.cell) - assert torch.allclose(og_batch, sim_state.batch) - assert torch.allclose(og_atomic_numbers, sim_state.atomic_numbers) + if not torch.allclose(og_positions, sim_state.positions): + raise ValueError(f"{og_positions=} != {sim_state.positions=}") + if not torch.allclose(og_cell, sim_state.cell): + raise ValueError(f"{og_cell=} != {sim_state.cell=}") + if not torch.allclose(og_batch, sim_state.batch): + raise ValueError(f"{og_batch=} != {sim_state.batch=}") + if not torch.allclose(og_atomic_numbers, sim_state.atomic_numbers): + raise ValueError(f"{og_atomic_numbers=} != {sim_state.atomic_numbers=}") # assert model output has the correct keys - assert "energy" in model_output - assert "forces" in model_output if force_computed else True - assert "stress" in model_output if stress_computed else True + if "energy" not in model_output: + raise ValueError("energy not in model output") + if force_computed and "forces" not in model_output: + raise ValueError("forces not in model output") + if stress_computed and "stress" not in model_output: + raise ValueError("stress not in model output") # assert model output shapes are correct - assert model_output["energy"].shape == (2,) - assert model_output["forces"].shape == (20, 3) if force_computed else True - assert model_output["stress"].shape == (2, 3, 3) if stress_computed else True + if model_output["energy"].shape != (2,): + raise ValueError(f"{model_output['energy'].shape=} != (2,)") + if force_computed and model_output["forces"].shape != (20, 3): + raise ValueError(f"{model_output['forces'].shape=} != (20, 3)") + if stress_computed and model_output["stress"].shape != (2, 3, 3): + raise ValueError(f"{model_output['stress'].shape=} != (2, 3, 3)") si_state = ts.io.atoms_to_state([si_atoms], device, dtype) fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype) si_model_output = model.forward(si_state) - assert torch.allclose( + if not torch.allclose( si_model_output["energy"], model_output["energy"][0], atol=10e-3 - ) - assert torch.allclose( - si_model_output["forces"], - model_output["forces"][: si_state.n_atoms], + ): + raise ValueError(f"{si_model_output['energy']=} != {model_output['energy'][0]=}") + if not torch.allclose( + forces := si_model_output["forces"], + expected_forces := model_output["forces"][: si_state.n_atoms], atol=10e-3, - ) - # assert torch.allclose( - # si_model_output["stress"], - # model_output["stress"][0], - # atol=10e-3, - # ) + ): + raise ValueError(f"{forces=} != {expected_forces=}") fe_model_output = model.forward(fe_state) si_model_output = model.forward(si_state) - assert torch.allclose( + if not torch.allclose( fe_model_output["energy"], model_output["energy"][1], atol=10e-2 - ) - assert torch.allclose( - fe_model_output["forces"], - model_output["forces"][si_state.n_atoms :], + ): + raise ValueError(f"{fe_model_output['energy']=} != {model_output['energy'][1]=}") + if not torch.allclose( + forces := fe_model_output["forces"], + expected_forces := model_output["forces"][si_state.n_atoms :], atol=10e-2, - ) - # assert torch.allclose( - # arr_model_output["stress"], - # model_output["stress"][1], - # atol=10e-3, - # ) + ): + raise ValueError(f"{forces=} != {expected_forces=}") diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 2f56fe4fe..d9d92b9b0 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -169,7 +169,7 @@ def __init__( self.model = self.model.to(dtype=self.dtype) if enable_cueq: - print("Converting models to CuEq for acceleration") + print("Converting models to CuEq for acceleration") # noqa: T201 self.model = run_e3nn_to_cueq(self.model) # Set model properties diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index f855fd51c..8cbc0e7fc 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -642,9 +642,10 @@ def __init__( ) # Ensure parameter matrices are symmetric (required for energy conservation) - assert torch.allclose(self.sigma_matrix, self.sigma_matrix.T) - assert torch.allclose(self.epsilon_matrix, self.epsilon_matrix.T) - assert torch.allclose(self.alpha_matrix, self.alpha_matrix.T) + for matrix_name in ("sigma_matrix", "epsilon_matrix", "alpha_matrix"): + matrix = getattr(self, matrix_name) + if not torch.allclose(matrix, matrix.T): + raise ValueError(f"{matrix_name} is not symmetric") # Set interaction cutoff distance self.cutoff = torch.tensor( diff --git a/torch_sim/neighbors.py b/torch_sim/neighbors.py index 1b0bf1b80..cd4c439dd 100644 --- a/torch_sim/neighbors.py +++ b/torch_sim/neighbors.py @@ -79,19 +79,18 @@ def primitive_neighbor_list( # noqa: C901, PLR0915 """ # Naming conventions: Suffixes indicate the dimension of an array. The # following convention is used here: - # c: Cartesian index, can have values 0, 1, 2 - # i: Global atom index, can have values 0..len(a)-1 - # xyz: Bin index, three values identifying x-, y- and z-component of a - # spatial bin that is used to make neighbor search O(n) - # b: Linearized version of the 'xyz' bin index - # a: Bin-local atom index, i.e. index identifying an atom *within* a - # bin - # p: Pair index, can have value 0 or 1 - # n: (Linear) neighbor index - - # Return empty neighbor list if no atoms are passed here + # c: Cartesian index, can have values 0, 1, 2 + # i: Global atom index, can have values 0..len(a)-1 + # xyz: Bin index, three values identifying x-, y- and z-component of a + # spatial bin that is used to make neighbor search O(n) + # b: Linearized version of the 'xyz' bin index + # a: Bin-local atom index, i.e. index identifying an atom *within* a + # bin + # p: Pair index, can have value 0 or 1 + # n: (Linear) neighbor index + if len(positions) == 0: - raise AssertionError("No atoms provided") + raise RuntimeError("No atoms provided") # Compute reciprocal lattice vectors. recip_cell = torch.linalg.pinv(cell).T @@ -109,7 +108,8 @@ def primitive_neighbor_list( # noqa: C901, PLR0915 1 / l3 if l3 > 0 else pytorch_scalar_1, ] ) - assert face_dist_c.shape == (3,) + if face_dist_c.shape != (3,): + raise ValueError(f"face_dist_c.shape={face_dist_c.shape} != (3,)") # we don't handle other fancier cutoffs max_cutoff: torch.Tensor = cutoff @@ -214,8 +214,10 @@ def primitive_neighbor_list( # noqa: C901, PLR0915 bin_index_i = bin_index_i[mask] # Make sure that all atoms have been sorted into bins. - assert len(atom_i) == 0 - assert len(bin_index_i) == 0 + if len(atom_i) != 0: + raise ValueError(f"len(atom_i)={len(atom_i)} != 0") + if len(bin_index_i) != 0: + raise ValueError(f"len(bin_index_i)={len(bin_index_i)} != 0") # Now we construct neighbor pairs by pairing up all atoms within a bin or # between bin and neighboring bin. atom_pairs_pn is a helper buffer that diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index bdc37239e..5c98dafec 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1273,7 +1273,8 @@ def _vv_fire_step( # noqa: C901, PLR0915 if is_cell_optimization: cell_factor_reshaped = state.cell_factor.view(n_batches, 1, 1) if is_frechet: - assert isinstance(state, FrechetCellFIREState) + if not isinstance(state, expected_cls := FrechetCellFIREState): + raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") cur_deform_grad = state.deform_grad() deform_grad_log = torch.zeros_like(cur_deform_grad) for b in range(n_batches): @@ -1291,7 +1292,8 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.row_vector_cell = new_row_vector_cell state.cell_positions = cell_positions_log_scaled_new else: - assert isinstance(state, UnitCellFireState) + if not isinstance(state, expected_cls := UnitCellFireState): + raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") cur_deform_grad = state.deform_grad() cell_factor_expanded = state.cell_factor.expand(n_batches, 3, 1) current_cell_positions_scaled = ( @@ -1329,7 +1331,8 @@ def _vv_fire_step( # noqa: C901, PLR0915 ).unsqueeze(0).expand(n_batches, -1, -1) if is_frechet: - assert isinstance(state, FrechetCellFIREState) + if not isinstance(state, expected_cls := FrechetCellFIREState): + raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") ucf_cell_grad = torch.bmm( virial, torch.linalg.inv(torch.transpose(deform_grad_new, 1, 2)) ) @@ -1353,7 +1356,8 @@ def _vv_fire_step( # noqa: C901, PLR0915 new_cell_forces[b] = forces_flat.reshape(3, 3) state.cell_forces = new_cell_forces / cell_factor_reshaped else: - assert isinstance(state, UnitCellFireState) + if not isinstance(state, expected_cls := UnitCellFireState): + raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") state.cell_forces = virial / cell_factor_reshaped state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) @@ -1564,7 +1568,8 @@ def _ase_fire_step( # noqa: C901, PLR0915 ) if is_frechet: - assert isinstance(state, FrechetCellFIREState) + if not isinstance(state, expected_cls := FrechetCellFIREState): + raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") new_logm_F_scaled = state.cell_positions + dr_cell state.cell_positions = new_logm_F_scaled logm_F_new = new_logm_F_scaled / (state.cell_factor + eps) @@ -1572,7 +1577,8 @@ def _ase_fire_step( # noqa: C901, PLR0915 new_row_vector_cell = torch.bmm(state.reference_row_vector_cell, F_new.mT) state.row_vector_cell = new_row_vector_cell else: - assert isinstance(state, UnitCellFireState) + if not isinstance(state, expected_cls := UnitCellFireState): + raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") F_current = state.deform_grad() cell_factor_exp_mult = state.cell_factor.expand(n_batches, 3, 1) current_F_scaled = F_current * cell_factor_exp_mult @@ -1599,7 +1605,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) if torch.any(volumes <= 0): bad_indices = torch.where(volumes <= 0)[0].tolist() - print( + print( # noqa: T201 f"WARNING: Non-positive volume(s) detected during _ase_fire_step: " f"{volumes[bad_indices].tolist()} at {bad_indices=} ({is_frechet=})" ) @@ -1619,13 +1625,16 @@ def _ase_fire_step( # noqa: C901, PLR0915 ).unsqueeze(0).expand(n_batches, -1, -1) if is_frechet: - assert isinstance(state, FrechetCellFIREState) - assert F_new is not None, ( - "F_new should be defined for Frechet cell force calculation" - ) - assert logm_F_new is not None, ( - "logm_F_new should be defined for Frechet cell force calculation" - ) + if not isinstance(state, expected_cls := FrechetCellFIREState): + raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") + if F_new is None: + raise ValueError( + "F_new should be defined for Frechet cell force calculation" + ) + if logm_F_new is None: + raise ValueError( + "logm_F_new should be defined for Frechet cell force calculation" + ) ucf_cell_grad = torch.bmm( virial, torch.linalg.inv(torch.transpose(F_new, 1, 2)) ) @@ -1649,7 +1658,8 @@ def _ase_fire_step( # noqa: C901, PLR0915 new_cell_forces_log_space[b_idx] = forces_flat.reshape(3, 3) state.cell_forces = new_cell_forces_log_space / (state.cell_factor + eps) else: - assert isinstance(state, UnitCellFireState) + if not isinstance(state, expected_cls := UnitCellFireState): + raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}") state.cell_forces = virial / state.cell_factor return state diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 9cc9092cc..44d271818 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -521,8 +521,10 @@ def compute_distances_with_cell_shifts( torch.Tensor: A tensor of shape (n_pairs,) containing the computed distances for each pair. """ - assert mapping.dim() == 2 - assert mapping.shape[0] == 2 + if mapping.dim() != 2: + raise ValueError(f"Mapping must be a 2D tensor, got {mapping.shape}") + if mapping.shape[0] != 2: + raise ValueError(f"Mapping must have 2 rows, got {mapping.shape[0]}") if cell_shifts is None: dr = pos[mapping[1]] - pos[mapping[0]] diff --git a/torch_sim/workflows/a2c.py b/torch_sim/workflows/a2c.py index 6f2043bf2..140c0031b 100644 --- a/torch_sim/workflows/a2c.py +++ b/torch_sim/workflows/a2c.py @@ -7,6 +7,8 @@ - Converting between different structural representations """ +# ruff: noqa: T201 + import itertools from collections.abc import Sequence from typing import Any