Add oeq/hybrid acceleration, torch.compile support, and TorchSim interface#1390
Merged
ilyes319 merged 6 commits intoACEsuit:developfrom Mar 3, 2026
Merged
Add oeq/hybrid acceleration, torch.compile support, and TorchSim interface#1390ilyes319 merged 6 commits intoACEsuit:developfrom
ilyes319 merged 6 commits intoACEsuit:developfrom
Conversation
merge LoRA fixes and re-estimated E0s for ft
remove les from cfg because pypi release block
Add mace-polar models
…rface - Fix missing oeq_config in RealAgnosticResidualNonLinearInteractionBlock - Fix convert_e3nn_oeq: enabled=False -> enabled=True, add key guard - Add with_oeq_conv_fusion/with_oeq_scatter_sum wrappers in wrapper_ops - Add ir_mul/mul_ir layout transposes for hybrid (cueq+oeq) mode - Add convert_e3nn_hybrid converter (cueq symcon + oeq conv TP) - Support external displacement in get_symmetric_displacement for compile - Add padding_tools with isolated-system padding for torch.compile - Update ASE calculator: hybrid mode, compile fixes, padding support - Add MaceTorchSimModel with compile padding and kernel acceleration - Fix extract_model strict=False for converted models Made-with: Cursor
5b6a544 to
88ac388
Compare
When conv_fusion=True fails (non-uniform edge irreps like 128x0e+32x1o), automatically retry with conv_fusion=False instead of crashing. Made-with: Cursor
ab232b8 to
e2d8d0d
Compare
Made-with: Cursor
e2d8d0d to
b8161ca
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR adds support for accelerated inference backends and
torch.compileto the MACE ASE calculator, along with a new TorchSim interface:conv_fusionfallback for non-uniform edge irreps (e.g.128x0e+32x1o).torch.compilesupport: Works with both pure e3nn and accelerated backends. Handles oeq-specific constraints (displacement created outside compiled graph for stress,disallow_in_graph(autograd.grad)for oeq ops).2*r_maxto prevent spurious interactions with real atoms.mace_torchsim.py): Native integration with TorchSim for batched simulation, with padding and backend conversion built in.convert_e3nn_hybridCLI: New conversion script for e3nn-to-hybrid (cueq+oeq) model conversion.Files changed (10 files, +1140/-90)
mace/calculators/mace.py-- Refactored calculator with backend conversion, compile setup, and paddingmace/calculators/mace_torchsim.py-- New TorchSim MACE interfacemace/cli/convert_e3nn_hybrid.py-- New e3nn-to-hybrid conversion CLImace/cli/convert_e3nn_oeq.py-- Minor layout fixmace/data/__init__.py-- Export padding_toolsmace/data/padding_tools.py-- New padding utilities for graph paddingmace/modules/blocks.py-- 1-line fixmace/modules/utils.py-- Guardrequires_grad_withtorch.compiler.is_compiling()mace/modules/wrapper_ops.py-- cueq/oeq layout transposes for hybrid modemace/tools/scripts_utils.py-- 1-line fixTest plan
test_calculator.py-- 16 passedtest_compile.py-- Core compile tests passtest_models.py-- 6 passedtest_modules.py-- 8 passedtest_run_train.py-- 17 passedMade with Cursor