diff --git a/dpdata/abacus/stru.py b/dpdata/abacus/stru.py index eecfa8f18..50ec2cb72 100644 --- a/dpdata/abacus/stru.py +++ b/dpdata/abacus/stru.py @@ -751,66 +751,68 @@ def process_file_input(file_input, atom_names, input_name): out += "0.0\n" out += str(data["atom_numbs"][iele]) + "\n" for iatom in range(data["atom_numbs"][iele]): - iatomtype = np.nonzero(data["atom_types"] == iele)[0][iatom] + iatomtype = np.nonzero(data["atom_types"] == iele)[0][ + iatom + ] # it is the atom index iout = f"{data['coords'][frame_idx][iatomtype, 0]:.12f} {data['coords'][frame_idx][iatomtype, 1]:.12f} {data['coords'][frame_idx][iatomtype, 2]:.12f}" # add flags for move, velocity, mag, angle1, angle2, and sc if move is not None: if ( - isinstance(ndarray2list(move[natom_tot]), (list, tuple)) - and len(move[natom_tot]) == 3 + isinstance(ndarray2list(move[iatomtype]), (list, tuple)) + and len(move[iatomtype]) == 3 ): iout += " " + " ".join( - ["1" if ii else "0" for ii in move[natom_tot]] + ["1" if ii else "0" for ii in move[iatomtype]] ) - elif isinstance(ndarray2list(move[natom_tot]), (int, float, bool)): - iout += " 1 1 1" if move[natom_tot] else " 0 0 0" + elif isinstance(ndarray2list(move[iatomtype]), (int, float, bool)): + iout += " 1 1 1" if move[iatomtype] else " 0 0 0" else: iout += " 1 1 1" if ( velocity is not None - and isinstance(ndarray2list(velocity[natom_tot]), (list, tuple)) - and len(velocity[natom_tot]) == 3 + and isinstance(ndarray2list(velocity[iatomtype]), (list, tuple)) + and len(velocity[iatomtype]) == 3 ): - iout += " v " + " ".join([f"{ii:.12f}" for ii in velocity[natom_tot]]) + iout += " v " + " ".join([f"{ii:.12f}" for ii in velocity[iatomtype]]) if mag is not None: - if isinstance(ndarray2list(mag[natom_tot]), (list, tuple)) and len( - mag[natom_tot] + if isinstance(ndarray2list(mag[iatomtype]), (list, tuple)) and len( + mag[iatomtype] ) in [1, 3]: - iout += " mag " + " ".join([f"{ii:.12f}" for ii in mag[natom_tot]]) - elif isinstance(ndarray2list(mag[natom_tot]), (int, float)): - iout += " mag " + f"{mag[natom_tot]:.12f}" + iout += " mag " + " ".join([f"{ii:.12f}" for ii in mag[iatomtype]]) + elif isinstance(ndarray2list(mag[iatomtype]), (int, float)): + iout += " mag " + f"{mag[iatomtype]:.12f}" if angle1 is not None and isinstance( - ndarray2list(angle1[natom_tot]), (int, float) + ndarray2list(angle1[iatomtype]), (int, float) ): - iout += " angle1 " + f"{angle1[natom_tot]:.12f}" + iout += " angle1 " + f"{angle1[iatomtype]:.12f}" if angle2 is not None and isinstance( - ndarray2list(angle2[natom_tot]), (int, float) + ndarray2list(angle2[iatomtype]), (int, float) ): - iout += " angle2 " + f"{angle2[natom_tot]:.12f}" + iout += " angle2 " + f"{angle2[iatomtype]:.12f}" if sc is not None: - if isinstance(ndarray2list(sc[natom_tot]), (list, tuple)) and len( - sc[natom_tot] + if isinstance(ndarray2list(sc[iatomtype]), (list, tuple)) and len( + sc[iatomtype] ) in [1, 3]: iout += " sc " + " ".join( - ["1" if ii else "0" for ii in sc[natom_tot]] + ["1" if ii else "0" for ii in sc[iatomtype]] ) - elif isinstance(ndarray2list(sc[natom_tot]), (int, float, bool)): - iout += " sc " + "1" if sc[natom_tot] else "0" + elif isinstance(ndarray2list(sc[iatomtype]), (int, float, bool)): + iout += " sc " + "1" if sc[iatomtype] else "0" if lambda_ is not None: - if isinstance(ndarray2list(lambda_[natom_tot]), (list, tuple)) and len( - lambda_[natom_tot] + if isinstance(ndarray2list(lambda_[iatomtype]), (list, tuple)) and len( + lambda_[iatomtype] ) in [1, 3]: iout += " lambda " + " ".join( - [f"{ii:.12f}" for ii in lambda_[natom_tot]] + [f"{ii:.12f}" for ii in lambda_[iatomtype]] ) - elif isinstance(ndarray2list(lambda_[natom_tot]), (int, float)): - iout += " lambda " + f"{lambda_[natom_tot]:.12f}" + elif isinstance(ndarray2list(lambda_[iatomtype]), (int, float)): + iout += " lambda " + f"{lambda_[iatomtype]:.12f}" out += iout + "\n" natom_tot += 1 diff --git a/tests/test_abacus_stru_dump.py b/tests/test_abacus_stru_dump.py index 29480d860..cf071920d 100644 --- a/tests/test_abacus_stru_dump.py +++ b/tests/test_abacus_stru_dump.py @@ -206,6 +206,61 @@ def test_dump_move_from_vasp(self): """ self.assertTrue(stru_ref in c) + def test_dump_chaotic_atomic_species(self): + import copy + + import numpy as np + + temp_system = copy.deepcopy(self.system_ch4) + temp_system.data["atom_types"] = np.array([1, 0, 1, 1, 1]) + temp_system.data["coords"] = np.array( + [[[1, 1, 1], [0, 0, 0], [2, 2, 2], [3, 3, 3], [4, 4, 4]]] + ) + temp_system.data["move"] = np.array( + [[[1, 0, 0], [0, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]] + ) + velocity = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5]]) + mag = np.array( + [[11, 11, 11], [22, 22, 22], [33, 33, 33], [44, 44, 44], [55, 55, 55]] + ) + constrain = np.array([1, 0, 1, 0, 1]) + sc = np.array([[0, 1, 1], [0, 0, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]) + lambda_ = np.array( + [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9], + [1.0, 1.1, 1.2], + [1.3, 1.4, 1.5], + ] + ) + temp_system.to( + "stru", + "STRU_tmp", + velocity=velocity, + mag=mag, + constrain=constrain, + sc=sc, + lambda_=lambda_, + ) + + assert os.path.isfile("STRU_tmp") + with open("STRU_tmp") as f: + lines = f.read() + ref_c = """C +0.0 +1 +0.000000000000 0.000000000000 0.000000000000 0 1 1 v 2.000000000000 2.000000000000 2.000000000000 mag 22.000000000000 22.000000000000 22.000000000000 sc 0 0 1 lambda 0.400000000000 0.500000000000 0.600000000000 +H +0.0 +4 +1.000000000000 1.000000000000 1.000000000000 1 0 0 v 1.000000000000 1.000000000000 1.000000000000 mag 11.000000000000 11.000000000000 11.000000000000 sc 0 1 1 lambda 0.100000000000 0.200000000000 0.300000000000 +2.000000000000 2.000000000000 2.000000000000 1 1 1 v 3.000000000000 3.000000000000 3.000000000000 mag 33.000000000000 33.000000000000 33.000000000000 sc 1 1 1 lambda 0.700000000000 0.800000000000 0.900000000000 +3.000000000000 3.000000000000 3.000000000000 1 1 1 v 4.000000000000 4.000000000000 4.000000000000 mag 44.000000000000 44.000000000000 44.000000000000 sc 1 1 1 lambda 1.000000000000 1.100000000000 1.200000000000 +4.000000000000 4.000000000000 4.000000000000 1 1 1 v 5.000000000000 5.000000000000 5.000000000000 mag 55.000000000000 55.000000000000 55.000000000000 sc 1 1 1 lambda 1.300000000000 1.400000000000 1.500000000000""" + + self.assertTrue(ref_c in lines) + class TestABACUSParseStru(unittest.TestCase): def test_parse_pos_oneline(self):