diff --git a/source/train/calculator.py b/source/train/calculator.py index 02f2c77a09..37fb7ad412 100644 --- a/source/train/calculator.py +++ b/source/train/calculator.py @@ -32,12 +32,12 @@ class DP(Calculator): name = "DP" implemented_properties = ["energy", "forces", "stress"] - def __init__(self, model, label="DP", **kwargs): + def __init__(self, model, label="DP", type_dict=None, **kwargs): Calculator.__init__(self, label=label, **kwargs) self.dp = DeepPot(model) - try: - self.type_dict=kwargs['type_dict'] - except: + if type_dict: + self.type_dict=type_dict + else: self.type_dict = dict(zip(self.dp.get_type_map(), range(self.dp.get_ntypes()))) def calculate(self, atoms=None, properties=["energy", "forces", "stress"], system_changes=all_changes):