From 4874218e37eb055e0359ca10a5f1fc8617490868 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 2 Jan 2025 14:38:37 +0800 Subject: [PATCH] Perf: replace unnecessary torch.split with indexing --- deepmd/pt/model/descriptor/repflow_layer.py | 2 +- deepmd/pt/model/descriptor/repformer_layer.py | 2 +- deepmd/pt/utils/nlist.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 94c4945c76..2278527ef3 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -314,7 +314,7 @@ def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor: # nb x nloc x 3 x e_dim nb, nloc, _, e_dim = h2g2.shape # nb x nloc x 3 x axis - h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0] + h2g2m = h2g2[..., :axis_neuron] # nb x nloc x axis x e_dim g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) # nb x nloc x (axisxng2) diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 86b09e9b40..1e2cba66d6 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -1003,7 +1003,7 @@ def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor: # nb x nloc x 3 x ng2 nb, nloc, _, ng2 = h2g2.shape # nb x nloc x 3 x axis - h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0] + h2g2m = h2g2[..., :axis_neuron] # nb x nloc x axis x ng2 g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) # nb x nloc x (axisxng2) diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index db1e87785b..ec94e8cd60 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -310,7 +310,7 @@ def nlist_distinguish_types( inlist = torch.gather(nlist, 2, imap) inlist = inlist.masked_fill(~(pick_mask.to(torch.bool)), -1) # nloc x nsel[ii] - ret_nlist.append(torch.split(inlist, [ss, snsel - ss], dim=-1)[0]) + ret_nlist.append(inlist[..., :ss]) return torch.concat(ret_nlist, dim=-1)