From 7620a3530628efaefc392a5efb3d15f3fb18a329 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Fri, 28 Mar 2025 12:11:14 +0800 Subject: [PATCH] perf: change order of element-wise op in edge angle update calculations --- deepmd/dpmodel/descriptor/repflows.py | 4 ++-- deepmd/pt/model/descriptor/repflow_layer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 469b6c008f..655e6ce0a5 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -1117,9 +1117,9 @@ def call( # nb x nloc x a_nnei x a_nnei x e_dim weighted_edge_angle_update = ( - edge_angle_update - * a_sw[:, :, :, xp.newaxis, xp.newaxis] + a_sw[:, :, :, xp.newaxis, xp.newaxis] * a_sw[:, :, xp.newaxis, :, xp.newaxis] + * edge_angle_update ) # nb x nloc x a_nnei x e_dim reduced_edge_angle_update = xp.sum(weighted_edge_angle_update, axis=-2) / ( diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 43cae8c746..712edf9b0a 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -698,9 +698,9 @@ def forward( # nb x nloc x a_nnei x a_nnei x e_dim weighted_edge_angle_update = ( - edge_angle_update - * a_sw[:, :, :, None, None] + a_sw[:, :, :, None, None] * a_sw[:, :, None, :, None] + * edge_angle_update ) # nb x nloc x a_nnei x e_dim reduced_edge_angle_update = torch.sum(