feat(pt_expt): full model #5244
Conversation
```python
def forward(
    self,
    coord: torch.Tensor,
    atype: torch.Tensor,
    box: torch.Tensor | None = None,
    fparam: torch.Tensor | None = None,
    aparam: torch.Tensor | None = None,
    do_atomic_virial: bool = False,
) -> dict[str, torch.Tensor]:
```
Check warning — Code scanning / CodeQL: Signature mismatch in overriding method (Warning)
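This CodeQL warning fires when an override's parameter list diverges from the base method's. A minimal illustration of the flagged pattern (class and parameter names are hypothetical, not the actual deepmd classes):

```python
class BaseModel:
    def forward(self, coord, atype, box=None):
        return {}


class EnergyLikeModel(BaseModel):
    # Extra keyword parameters with defaults beyond the base signature are
    # what typically triggers "Signature mismatch in overriding method".
    # Calls written against BaseModel.forward still work, which is why
    # such warnings are often benign.
    def forward(self, coord, atype, box=None, fparam=None, aparam=None):
        return {"fparam": fparam, "aparam": aparam}


out = EnergyLikeModel().forward([0.0], [1])
```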
```python
    return
return super().__setattr__(name, value)
```

```python
def call(self, x: torch.Tensor) -> torch.Tensor:
```
Check warning — Code scanning / CodeQL: Signature mismatch in overriding method (Warning)
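Here the flagged `call` override is narrower than what the base declares. The hazard only surfaces at runtime, when code invokes the method through the base-class interface; a sketch with hypothetical names:

```python
class BaseFitting:
    def call(self, x, extra=None):
        raise NotImplementedError


class ScaledFitting(BaseFitting):
    def call(self, x):  # drops `extra` from the base signature
        return x * 2


def evaluate(fitting: BaseFitting, x):
    # Fine as long as nobody passes `extra`...
    return fitting.call(x)


result = evaluate(ScaledFitting(), 3)

# ...but a caller relying on the base signature breaks:
try:
    ScaledFitting().call(3, None)
    narrowing_detected = False
except TypeError:
    narrowing_detected = True
```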
```python
        "energy": energy_data,
        "find_energy": np.float32(1.0),
    }
]
```
Check notice — Code scanning / CodeQL: Imprecise assert (Note, test)
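CodeQL's "Imprecise assert" note usually points at `assertTrue(a == b)`-style checks, whose failure message hides the compared values; `assertEqual` reports both sides. A small demonstration:

```python
import unittest


class _Demo(unittest.TestCase):
    def runTest(self):  # allows direct instantiation
        pass


t = _Demo()

# Imprecise: the failure message is just "False is not true".
try:
    t.assertTrue([1] == [2])
except AssertionError as e:
    imprecise_msg = str(e)

# Precise: the failure message shows both operands.
try:
    t.assertEqual([1], [2])
except AssertionError as e:
    precise_msg = str(e)
```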
📝 Walkthrough

This PR adds a PyTorch-exportable backend (pt_expt) with model, fitting, and network utilities; refactors dpmodel casting helpers to private variants; adds output-bias management and a descriptor accessor; makes tensor/device handling and assignment changes in transform/utility code; and expands tests for pt_expt (autodiff, exportable tracing, cross-backend API checks).

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant App as Application
    participant EM as EnergyModel (pt_expt)
    participant CM as make_model.CM
    participant Atomic as AtomicModel
    participant Transform as transform_output.fit_output_to_model_output
    participant Autograd as torch.autograd
    App->>EM: forward(...)/forward_lower(...)
    EM->>CM: call / call_lower
    CM->>Atomic: forward_common_atomic (extended_coord, ...)
    Atomic-->>CM: fit_ret (atom energies etc.)
    CM->>Transform: fit_output_to_model_output(fit_ret, ...)
    Transform->>Autograd: compute gradients (forces/virial) if needed
    Autograd-->>Transform: derivative tensors
    Transform-->>CM: model-formatted outputs (energy, atom_energy, force, virial, ...)
    CM-->>EM: dict[str, Tensor]
    EM-->>App: return outputs
```
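The Transform→Autograd step above obtains forces as the negative gradient of the energy with respect to the coordinates. A toy sketch of that derivative call (the quadratic energy is made up for illustration; it is not the actual deepmd potential):

```python
import torch

coord = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
energy = (coord ** 2).sum()  # toy energy, stand-in for the atomic model

# Force = -dE/dcoord, exactly the gradient the transform step requests.
(de_dr,) = torch.autograd.grad(energy, coord)
force = -de_dr
```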
```mermaid
sequenceDiagram
    participant Dev as Developer/Test
    participant EM as EnergyModel
    participant Trace as torch.fx.make_fx
    participant Export as torch.export.export
    Dev->>EM: forward_lower_exportable(inputs)
    EM->>Trace: make_fx(_forward_lower) with grad enabled
    Trace-->>EM: traced Module
    Dev->>Export: torch.export.export(traced Module)
    Export-->>Dev: exportable artifact
```
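The trace-then-export flow in the diagram can be sketched with a toy function in place of the model's `_forward_lower` (everything here is illustrative; the real code wires the model and `do_atomic_virial` into the closure). Tracing with `make_fx` while grad is enabled unrolls the `torch.autograd.grad` call into explicit ops, so the resulting graph no longer needs autograd at run time:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def energy_and_force(coord):
    coord = coord.detach().requires_grad_(True)
    energy = (coord ** 2).sum()  # toy stand-in for the atomic model
    (grad,) = torch.autograd.grad(energy, coord)
    return energy, -grad


example = torch.ones(4, dtype=torch.float64)
traced = make_fx(energy_and_force)(example)  # GraphModule with backward ops inlined

energy, force = traced(torch.full((4,), 2.0, dtype=torch.float64))
```

The traced `GraphModule` can then be handed to `torch.export.export` to produce the exportable artifact.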
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
🧹 Nitpick comments (4)
source/tests/consistent/model/common.py (1)

110-121: Consider using a shared utility for tensor-to-numpy conversion.

The other eval methods (e.g., `eval_pt_model`) use dedicated conversion utilities (`torch_to_numpy`), while this method uses raw `.detach().cpu().numpy()`. For consistency, consider importing a `pt_expt` equivalent if one exists — or keep as-is since it's just test code.

The `natoms` parameter is unused but matches the interface contract of all sibling `eval_*_model` methods, so the static analysis warning is a false positive.

deepmd/pt_expt/model/ener_model.py (1)
120-171: Note: `do_atomic_virial` is baked into the traced module.

The `do_atomic_virial` flag is captured in the closure (line 166) and becomes a constant in the `make_fx`-traced module. This is documented in the docstring and appears intentional, but callers should be aware that a separate trace is needed for each value of `do_atomic_virial`.

deepmd/pt_expt/model/transform_output.py (1)
107-110: Consider adding `strict=True` to the `zip` call.

`split_vv1` and `split_svv1` are guaranteed to have the same length (both split with `[1] * size`), so this is safe. However, adding `strict=True` provides a defensive check and silences the Ruff B905 warning.

Proposed fix:

```diff
- for vvi, svvi in zip(split_vv1, split_svv1):
+ for vvi, svvi in zip(split_vv1, split_svv1, strict=True):
```

source/tests/consistent/model/test_ener.py (1)
137-138: Inconsistent guard: PT_EXPT checks `pt_expt_class is not None` but other backends don't.

Lines 133–134 for PT only check `self.skip_pt`, while PT_EXPT additionally checks `self.pt_expt_class is not None`. This is defensive and harmless, but the inconsistency is worth noting. If `INSTALLED_PT_EXPT` is `True`, `EnergyModelPTExpt` should never be `None`.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7e51e9d4fb
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
🧹 Nitpick comments (1)
deepmd/pt_expt/model/ener_model.py (1)
148-167: `self` and `do_atomic_virial` are captured by the closure — both become constants in the traced graph.

This is fine for single-configuration export, but worth documenting explicitly:

- `do_atomic_virial` (bool) is resolved statically during tracing — the exported module will always (or never) produce virial outputs, depending on the value passed here.
- `model = self` means the traced graph inlines all model parameters. Any parameter updates after tracing won't be reflected.

If this is the intended contract, a brief inline comment on lines 148 and 166-167 noting this would save future readers from wondering.
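Both reviews make the same point: values captured by the closure are frozen at trace time. A plain-Python analogue of that specialization (names are made up; the real mechanism is `make_fx` baking the branch into the graph):

```python
def build_forward(do_atomic_virial: bool):
    # The flag is fixed for the returned function, mirroring how the
    # traced module either always or never produces virial outputs.
    def forward(atom_energies):
        out = {"energy": sum(atom_energies)}
        if do_atomic_virial:
            out["atom_virial"] = [0.0 for _ in atom_energies]
        return out

    return forward


with_virial = build_forward(True)([1.0, 2.0])
without_virial = build_forward(False)([1.0, 2.0])
```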
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5244      +/-   ##
==========================================
+ Coverage   82.12%   82.16%   +0.03%
==========================================
  Files         736      740       +4
  Lines       74237    74420     +183
  Branches     3615     3616       +1
==========================================
+ Hits        60966    61144     +178
- Misses      12107    12114       +7
+ Partials     1164     1162       -2
```

☔ View full report in Codecov by Sentry.
Summary by CodeRabbit
New Features
Bug Fixes
Tests