diff --git a/src/dynsight/_internal/trajectory/cluster_insight.py b/src/dynsight/_internal/trajectory/cluster_insight.py index 99c7f6cc..86c4ad57 100644 --- a/src/dynsight/_internal/trajectory/cluster_insight.py +++ b/src/dynsight/_internal/trajectory/cluster_insight.py @@ -10,11 +10,15 @@ from pathlib import Path from numpy.typing import NDArray - from tropea_clustering._internal.first_classes import StateMulti, StateUni from dynsight.trajectory import Insight, Trj +from tropea_clustering._internal.onion_smooth.first_classes import ( + StateMulti, + StateUni, +) + import dynsight from dynsight.logs import logger @@ -167,6 +171,23 @@ def load_from_json( logger.log(msg) raise ValueError(msg) + raw_list = data["state_list"] + state_list = [] + + for entry in raw_list: + # Infer the correct class + if isinstance(entry.get("mean"), list): + state_cls = StateMulti + else: + state_cls = StateUni + + # Convert lists back to ndarray + kwargs = {} + for k, v in entry.items(): + kwargs[k] = np.array(v) if isinstance(v, list) else v + + state_list.append(state_cls(**kwargs)) + base_dir = file_path.parent labels = np.load(base_dir / data["labels_file"], mmap_mode=mmap_mode) reshaped = np.load( @@ -178,7 +199,7 @@ def load_from_json( return cls( labels=labels, reshaped_data=reshaped, - state_list=data["state_list"], + state_list=state_list, meta=data.get("meta", {}), ) @@ -353,6 +374,26 @@ def load_from_json( logger.log(msg) raise ValueError(msg) + raw_list = data["state_list"] + state_list = [] + + for entry in raw_list: + # Decide which class to use + if isinstance(entry.get("mean"), list): + state_cls = StateMulti + else: + state_cls = StateUni + + # Rebuild kwargs (convert lists back to np.ndarrays) + kwargs = {} + for k, v in entry.items(): + if isinstance(v, list): + kwargs[k] = np.array(v) + else: + kwargs[k] = v + + state_list.append(state_cls(**kwargs)) + labels_path = file_path.parent / data["labels_file"] labels = np.load(labels_path, mmap_mode=mmap_mode) @@ -363,7 +404,7 @@ def load_from_json( return cls( labels=labels, - state_list=data["state_list"], + state_list=state_list, meta=data.get("meta", {}), )