In `dump_colored_trj`, the second `self.labels.shape[0]` in the error message (the one reported as the frame count) should be replaced with `self.labels.shape[1]`:
```python
if self.labels.shape != (n_atoms, n_frames):
    msg = (
        f"Shape mismatch: Trj should have {self.labels.shape[0]} "
        f"atoms, {self.labels.shape[0]} frames, but has {n_atoms} "
        f"atoms, {n_frames} frames."
    )
    logger.log(msg)
    raise ValueError(msg)
```
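For reference, here is a minimal sketch of the corrected check, pulled out into a standalone function so it can be run outside the class. The function name `check_labels_shape` and the module-level `logger` are hypothetical, and the sketch assumes (as the comparison above implies) that the labels have shape `(n_atoms, n_frames)`. The substantive change is `shape[1]` for the frame count. It also assumes `logger` is a standard `logging.Logger`; in that case `logger.log()` requires a level argument, so the sketch calls `logger.error(msg)` instead.

```python
import logging

# Hypothetical module-level logger; the real code's logger may differ.
logger = logging.getLogger(__name__)


def check_labels_shape(labels_shape, n_atoms, n_frames):
    """Hypothetical helper: raise ValueError unless labels_shape == (n_atoms, n_frames)."""
    if tuple(labels_shape) != (n_atoms, n_frames):
        msg = (
            f"Shape mismatch: Trj should have {labels_shape[0]} "
            # Index 1 holds the frame count, not index 0.
            f"atoms, {labels_shape[1]} frames, but has {n_atoms} "
            f"atoms, {n_frames} frames."
        )
        # logging.Logger.log() needs a level; error() logs at ERROR level.
        logger.error(msg)
        raise ValueError(msg)
```

With labels of shape `(10, 5)` and a trajectory of 10 atoms and 4 frames, the message now reads "Trj should have 10 atoms, 5 frames, but has 10 atoms, 4 frames." instead of reporting 10 frames.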