From 346cbf3708c9083a731cab01e2beb1ef3e745b26 Mon Sep 17 00:00:00 2001 From: Ian Slagle Date: Tue, 19 Sep 2023 10:49:38 -0400 Subject: [PATCH 1/2] Import node representation --- matdeeplearn/models/torchmd_etEarly.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/matdeeplearn/models/torchmd_etEarly.py b/matdeeplearn/models/torchmd_etEarly.py index 905c774d..ca9685a2 100644 --- a/matdeeplearn/models/torchmd_etEarly.py +++ b/matdeeplearn/models/torchmd_etEarly.py @@ -15,6 +15,7 @@ from matdeeplearn.models.base_model import BaseModel, conditional_grad from matdeeplearn.models.torchmd_output_modules import Scalar, EquivariantScalar from matdeeplearn.common.registry import registry +from matdeeplearn.preprocessor.helpers import node_rep_one_hot @registry.register_model("torchmd_etEarly") @@ -411,4 +412,4 @@ def aggregate( def update( self, inputs: Tuple[torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: - return inputs \ No newline at end of file + return inputs From 29d9774a7e04237b23b40022133da3156d29c7ca Mon Sep 17 00:00:00 2001 From: Ian Slagle Date: Tue, 19 Sep 2023 10:50:33 -0400 Subject: [PATCH 2/2] Import node_rep_one_hot for schnet --- matdeeplearn/models/schnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/matdeeplearn/models/schnet.py b/matdeeplearn/models/schnet.py index ac23b7c0..acf86a15 100644 --- a/matdeeplearn/models/schnet.py +++ b/matdeeplearn/models/schnet.py @@ -15,7 +15,7 @@ from matdeeplearn.common.registry import registry from matdeeplearn.models.base_model import BaseModel, conditional_grad -from matdeeplearn.preprocessor.helpers import GaussianSmearing +from matdeeplearn.preprocessor.helpers import GaussianSmearing, node_rep_one_hot @registry.register_model("SchNet") class SchNet(BaseModel): @@ -246,4 +246,4 @@ def forward(self, data): output["pos_grad"] = None output["cell_grad"] = None - return output \ No newline at end of file + return output