diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 469b6c008f..655e6ce0a5 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -1117,9 +1117,9 @@ def call( # nb x nloc x a_nnei x a_nnei x e_dim weighted_edge_angle_update = ( - edge_angle_update - * a_sw[:, :, :, xp.newaxis, xp.newaxis] + a_sw[:, :, :, xp.newaxis, xp.newaxis] * a_sw[:, :, xp.newaxis, :, xp.newaxis] + * edge_angle_update ) # nb x nloc x a_nnei x e_dim reduced_edge_angle_update = xp.sum(weighted_edge_angle_update, axis=-2) / ( diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 43cae8c746..712edf9b0a 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -698,9 +698,9 @@ def forward( # nb x nloc x a_nnei x a_nnei x e_dim weighted_edge_angle_update = ( - edge_angle_update - * a_sw[:, :, :, None, None] + a_sw[:, :, :, None, None] * a_sw[:, :, None, :, None] + * edge_angle_update ) # nb x nloc x a_nnei x e_dim reduced_edge_angle_update = torch.sum(