From a709d2571511689c043a24a8867fe533a1e117db Mon Sep 17 00:00:00 2001 From: Matteo Becchi Date: Fri, 21 Nov 2025 10:21:21 +0100 Subject: [PATCH 1/2] OnionSmoothInsight now loaded correctly from JSON. --- .../_internal/trajectory/cluster_insight.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/src/dynsight/_internal/trajectory/cluster_insight.py b/src/dynsight/_internal/trajectory/cluster_insight.py index 99c7f6cc..082ff8a5 100644 --- a/src/dynsight/_internal/trajectory/cluster_insight.py +++ b/src/dynsight/_internal/trajectory/cluster_insight.py @@ -10,11 +10,15 @@ from pathlib import Path from numpy.typing import NDArray - from tropea_clustering._internal.first_classes import StateMulti, StateUni from dynsight.trajectory import Insight, Trj +from tropea_clustering._internal.onion_smooth.first_classes import ( + StateMulti, + StateUni, +) + import dynsight from dynsight.logs import logger @@ -353,6 +357,26 @@ def load_from_json( logger.log(msg) raise ValueError(msg) + raw_list = data["state_list"] + state_list = [] + + for entry in raw_list: + # Decide which class to use + if isinstance(entry.get("mean"), list): + state_cls = StateMulti + else: + state_cls = StateUni + + # Rebuild kwargs (convert lists back to np.ndarrays) + kwargs = {} + for k, v in entry.items(): + if isinstance(v, list): + kwargs[k] = np.array(v) + else: + kwargs[k] = v + + state_list.append(state_cls(**kwargs)) + labels_path = file_path.parent / data["labels_file"] labels = np.load(labels_path, mmap_mode=mmap_mode) @@ -363,7 +387,7 @@ def load_from_json( return cls( labels=labels, - state_list=data["state_list"], + state_list=state_list, meta=data.get("meta", {}), ) From d10894bc6e2bb8d32408754e5bca1f436725a27c Mon Sep 17 00:00:00 2001 From: Matteo Becchi Date: Fri, 21 Nov 2025 10:27:57 +0100 Subject: [PATCH 2/2] OnionInsight now loaded correctly from JSON. --- .../_internal/trajectory/cluster_insight.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/dynsight/_internal/trajectory/cluster_insight.py b/src/dynsight/_internal/trajectory/cluster_insight.py index 082ff8a5..86c4ad57 100644 --- a/src/dynsight/_internal/trajectory/cluster_insight.py +++ b/src/dynsight/_internal/trajectory/cluster_insight.py @@ -171,6 +171,23 @@ def load_from_json( logger.log(msg) raise ValueError(msg) + raw_list = data["state_list"] + state_list = [] + + for entry in raw_list: + # Infer the correct class + if isinstance(entry.get("mean"), list): + state_cls = StateMulti + else: + state_cls = StateUni + + # Convert lists back to ndarray + kwargs = {} + for k, v in entry.items(): + kwargs[k] = np.array(v) if isinstance(v, list) else v + + state_list.append(state_cls(**kwargs)) + base_dir = file_path.parent labels = np.load(base_dir / data["labels_file"], mmap_mode=mmap_mode) reshaped = np.load( @@ -182,7 +199,7 @@ def load_from_json( return cls( labels=labels, reshaped_data=reshaped, - state_list=data["state_list"], + state_list=state_list, meta=data.get("meta", {}), )