Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions src/dynsight/_internal/trajectory/cluster_insight.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@
from pathlib import Path

from numpy.typing import NDArray
from tropea_clustering._internal.first_classes import StateMulti, StateUni

from dynsight.trajectory import Insight, Trj


from tropea_clustering._internal.onion_smooth.first_classes import (
StateMulti,
StateUni,
)

import dynsight
from dynsight.logs import logger

Expand Down Expand Up @@ -167,6 +171,23 @@ def load_from_json(
logger.log(msg)
raise ValueError(msg)

raw_list = data["state_list"]
state_list = []

for entry in raw_list:
# Infer the correct class
if isinstance(entry.get("mean"), list):
state_cls = StateMulti
else:
state_cls = StateUni

# Convert lists back to ndarray
kwargs = {}
for k, v in entry.items():
kwargs[k] = np.array(v) if isinstance(v, list) else v

state_list.append(state_cls(**kwargs))

base_dir = file_path.parent
labels = np.load(base_dir / data["labels_file"], mmap_mode=mmap_mode)
reshaped = np.load(
Expand All @@ -178,7 +199,7 @@ def load_from_json(
return cls(
labels=labels,
reshaped_data=reshaped,
state_list=data["state_list"],
state_list=state_list,
meta=data.get("meta", {}),
)

Expand Down Expand Up @@ -353,6 +374,26 @@ def load_from_json(
logger.log(msg)
raise ValueError(msg)

raw_list = data["state_list"]
state_list = []

for entry in raw_list:
# Decide which class to use
if isinstance(entry.get("mean"), list):
state_cls = StateMulti
else:
state_cls = StateUni

# Rebuild kwargs (convert lists back to np.ndarrays)
kwargs = {}
for k, v in entry.items():
if isinstance(v, list):
kwargs[k] = np.array(v)
else:
kwargs[k] = v

state_list.append(state_cls(**kwargs))

labels_path = file_path.parent / data["labels_file"]
labels = np.load(labels_path, mmap_mode=mmap_mode)

Expand All @@ -363,7 +404,7 @@ def load_from_json(

return cls(
labels=labels,
state_list=data["state_list"],
state_list=state_list,
meta=data.get("meta", {}),
)

Expand Down