diff --git a/matdeeplearn/models/schnet.py b/matdeeplearn/models/schnet.py index ac23b7c0..acf86a15 100644 --- a/matdeeplearn/models/schnet.py +++ b/matdeeplearn/models/schnet.py @@ -15,7 +15,7 @@ from matdeeplearn.common.registry import registry from matdeeplearn.models.base_model import BaseModel, conditional_grad -from matdeeplearn.preprocessor.helpers import GaussianSmearing +from matdeeplearn.preprocessor.helpers import GaussianSmearing, node_rep_one_hot @registry.register_model("SchNet") class SchNet(BaseModel): @@ -246,4 +246,4 @@ def forward(self, data): output["pos_grad"] = None output["cell_grad"] = None - return output \ No newline at end of file + return output diff --git a/matdeeplearn/models/torchmd_etEarly.py b/matdeeplearn/models/torchmd_etEarly.py index 905c774d..ca9685a2 100644 --- a/matdeeplearn/models/torchmd_etEarly.py +++ b/matdeeplearn/models/torchmd_etEarly.py @@ -15,6 +15,7 @@ from matdeeplearn.models.base_model import BaseModel, conditional_grad from matdeeplearn.models.torchmd_output_modules import Scalar, EquivariantScalar from matdeeplearn.common.registry import registry +from matdeeplearn.preprocessor.helpers import node_rep_one_hot @registry.register_model("torchmd_etEarly") @@ -411,4 +412,4 @@ def aggregate( def update( self, inputs: Tuple[torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: - return inputs \ No newline at end of file + return inputs