Description
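The line:
edge_types[edge_types == 4] = 1.5
does not do what it intends. edge_types is a long (integer) tensor, so PyTorch silently truncates the assigned value 1.5 to 1 instead of raising an error. Entries equal to 4 are therefore replaced with 1 rather than 1.5, and every computation that uses these values downstream receives the wrong number.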
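A minimal reproduction of the truncation, assuming only that PyTorch is installed. The tensor values below are arbitrary and are not taken from MiDi; they just mimic a long tensor of integer codes.

```python
import torch

# A small long (int64) tensor of integer codes; values are arbitrary.
edge_types = torch.tensor([0, 1, 4, 2, 4])
assert edge_types.dtype == torch.long

# In-place masked assignment converts the Python float 1.5 to the
# tensor's dtype (int64), so it is stored as 1 with no error or warning.
edge_types[edge_types == 4] = 1.5
print(edge_types.tolist())  # [0, 1, 1, 2, 1]
```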
Impact:
The incorrect values propagate downstream and affect reported metrics such as molecule and atom stability, including in works like EQGAT-Diff and Semla-Flow.
Suggested Fix:
Convert edge_types to a floating-point tensor before replacement:
edge_types = edge_types.float()
edge_types[edge_types == 4] = 1.5
Affected code: MiDi/midi/metrics/molecular_metrics.py, line 304 (commit 775b731).
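To illustrate why the truncation matters for stability checks, here is a hypothetical sketch. It assumes the bond-code encoding 0 = no bond, 1 = single, 2 = double, 3 = triple, 4 = aromatic (suggested by the replacement value 1.5, but not confirmed in this issue) and uses a toy benzene graph with explicit hydrogens. The names bond_orders and benzene_edge_types are illustrative, not MiDi functions. Summing bond orders per atom gives each ring carbon a valency of 4 with the fix but only 3 with the bug, so a carbon-valency check would flag those atoms as unstable.

```python
import torch

# Assumed bond-type encoding, for illustration only:
# 0 = no bond, 1 = single, 2 = double, 3 = triple, 4 = aromatic.
AROMATIC = 4


def bond_orders(edge_types: torch.Tensor) -> torch.Tensor:
    """Map integer bond codes to bond orders, with aromatic -> 1.5.

    Converting to float first keeps 1.5 from being truncated to 1.
    masked_fill is out-of-place, so the caller's tensor is left unchanged.
    """
    orders = edge_types.float()
    return orders.masked_fill(orders == AROMATIC, 1.5)


def benzene_edge_types() -> torch.Tensor:
    """Toy n x n matrix of bond codes: an aromatic ring of six carbons
    (atoms 0-5), each single-bonded to one explicit hydrogen (atoms 6-11)."""
    n = 12
    edge_types = torch.zeros(n, n, dtype=torch.long)
    for i in range(6):
        j = (i + 1) % 6
        edge_types[i, j] = edge_types[j, i] = AROMATIC  # ring bonds
        edge_types[i, i + 6] = edge_types[i + 6, i] = 1  # C-H bonds
    return edge_types


edge_types = benzene_edge_types()

# Buggy path: replacement on the long tensor truncates 1.5 to 1.
buggy = edge_types.clone()
buggy[buggy == AROMATIC] = 1.5
buggy_valency = buggy.sum(dim=-1)

# Fixed path: convert to float before replacing aromatic codes.
fixed_valency = bond_orders(edge_types).sum(dim=-1)

print(buggy_valency[:6].tolist())  # [3, 3, 3, 3, 3, 3]  ring carbons
print(fixed_valency[:6].tolist())  # [4.0, 4.0, 4.0, 4.0, 4.0, 4.0]
```

The helper uses masked_fill rather than in-place indexing so it never mutates its input, but it computes the same values as the suggested two-line fix.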