diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index 04d3d7b08d..3b92351a11 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -229,7 +229,7 @@ def get_sel(jdata, rcut): max_rcut = get_rcut(jdata) type_map = get_type_map(jdata) - if len(type_map) == 0: + if type_map and len(type_map) == 0: type_map = None train_data = get_data(jdata["training"]["training_data"], max_rcut, type_map, None) train_data.get_batch()