Added Multi Trajectory context window generation #9
- Modified `TrajectoryContextDataset` to accept multiple trajectory data.
It's worth noting that the HuggingFace dataset class seamlessly converts the data format to the appropriate backend (numpy, torch, jax, TensorFlow). Therefore, Kooplearn could focus primarily on trajectory/context slicing, leaving backend management to HuggingFace, which is already integrated within the framework.
kooplearn/data.py

```python
raise ImportError(
    "You selected the 'torch' backend, but kooplearn wasn't able to import it."
)
if isinstance(data_per_traj[0], np.ndarray):
```
By concatenating everything together, in the case of multi-trajectories it becomes hard to answer questions like "What was the n-th frame of the k-th trajectory?"
Of course, if every trajectory has the same length, this can be quickly recovered, but this might not always be the case (see my other comment).
The primary purpose of `idx_map`, although the API does not exploit it yet, is to help the user evaluate multi-step forecasting errors. For a single trajectory, this can be done, e.g., as:
```python
for t in times:
    ref_idxs = test_contexts.idx_map.data[:, 0, 0] + t
    Y_pred = model.predict(test_contexts, t=t)
    Y_true = test_contexts.data[ref_idxs]
```
It was unclear to me from the documentation and docstrings what the role of `idx_map` was.
I now follow what you mean. Will update.
In the multi-trajectory case, then, `idx_map` will have an index per trajectory.
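To make the role of a per-trajectory `idx_map` concrete, here is a minimal numpy sketch (the function name and the `(trajectory, frame)` layout are illustrative assumptions, not kooplearn's actual API) of context windows paired with an index map recording which trajectory and frame each entry came from:

```python
import numpy as np

def build_contexts_with_idx_map(trajs, context_len):
    """Slice each trajectory into context windows and record, for every
    window entry, the (trajectory, frame) pair it came from.

    trajs: list of arrays, each of shape (n_frames_k, *feature_dims);
    trajectories may have different lengths.
    Returns (contexts, idx_map), where idx_map[i, j] = (traj_idx, frame_idx).
    """
    contexts, idx_map = [], []
    for k, traj in enumerate(trajs):
        n_frames = traj.shape[0]
        for start in range(n_frames - context_len + 1):
            contexts.append(traj[start:start + context_len])
            idx_map.append([(k, t) for t in range(start, start + context_len)])
    return np.stack(contexts), np.asarray(idx_map)

# Two trajectories of different lengths, scalar features.
trajs = [np.arange(5.0)[:, None], np.arange(100.0, 103.0)[:, None]]
contexts, idx_map = build_contexts_with_idx_map(trajs, context_len=2)

# "What was the n-th frame of the k-th trajectory?" is now answerable
# directly, even with unequal trajectory lengths:
k, n = idx_map[4, 0]      # first entry of the fifth window
assert (k, n) == (1, 0)   # that window starts trajectory 1 at frame 0
```

With this layout, the multi-step evaluation loop above generalizes: shifting the frame column of `idx_map` by `t` stays within the correct trajectory.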
- After initializing `TensorContextDataset`, the `self.backend` attribute stores the desired backend: either `torch` or `numpy`.
- `__getitems__` is a method of `Dataset` classes that enables fast collection of samples already in batch form. As our datasets are assumed to be in memory, we can enable this fast indexing.
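As a minimal illustration of the idea (the class below is a hypothetical toy, not kooplearn's `TensorContextDataset`), `__getitems__` lets a batch of samples be gathered with a single fancy-indexing call instead of a Python-level loop over `__getitem__`:

```python
import numpy as np

class TinyContextDataset:
    """Toy in-memory dataset illustrating batched indexing."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Per-sample access: returns a single context window.
        return self.data[idx]

    def __getitems__(self, indices):
        # Batched access: one fancy-indexing call returns the whole
        # batch already stacked, with no per-sample Python loop.
        return self.data[indices]

ds = TinyContextDataset(np.arange(12).reshape(6, 2))
batch = ds.__getitems__([0, 3, 5])
assert batch.shape == (3, 2)
```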
kooplearn/data.py
```python
elif isinstance(idx, slice):
    return TensorContextDataset(self.data[idx])

def __getitems__(self, indices: list[int]) -> 'TensorContextDataset':
```
Nice, I didn't know of this trick. As far as I can tell, however, it only works with the torch backend, right? I propose to change `__getitem__` as follows:
```python
def __getitem__(self, idx) -> 'TensorContextDataset':
    if np.issubdtype(type(idx), np.integer):
        # TODO: The default collate behaviour is to return a list of [type] objects. This additional
        # dimension seems to be introduced only for a very specific custom collate_fn.
        return TensorContextDataset(self.data[idx][None, ...])
    elif isinstance(idx, slice):
        return TensorContextDataset(self.data[idx])
    # CHANGE STARTS HERE
    else:
        if self.backend == 'torch':
            return self.__getitems__(idx)
        else:
            return TensorContextDataset(self.data[idx])
```
Yes, this is torch-specific. The `__getitems__` method is called by the default data-fetching logic used by `DataLoader`s.
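A rough pure-Python mimic of that dispatch (simplified from what torch's map-style dataset fetcher does; the class names here are illustrative):

```python
def fetch(dataset, batch_indices):
    # Simplified sketch of a map-style fetcher: if the dataset exposes
    # __getitems__, fetch the whole batch in one call; otherwise fall
    # back to one __getitem__ call per index.
    if hasattr(dataset, "__getitems__"):
        return dataset.__getitems__(batch_indices)
    return [dataset[i] for i in batch_indices]

class ListDataset:
    def __init__(self, items):
        self.items = items

    def __getitem__(self, i):
        return self.items[i]

class BatchedDataset(ListDataset):
    def __getitems__(self, indices):
        return [self.items[i] for i in indices]

plain = ListDataset([10, 20, 30])
batched = BatchedDataset([10, 20, 30])
assert fetch(plain, [0, 2]) == [10, 30]    # per-index fallback
assert fetch(batched, [0, 2]) == [10, 30]  # single batched call
```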
@Danfoa any updates on this pull request?
I modified `TrajectoryContextDataset` to accept multiple trajectory data. As the data stream remains agnostic to the dimensions of the features, I had to add a `multi_traj` boolean argument to the class's initializer to indicate whether the first dimension of the array is considered the number of trajectories.
I suggest enforcing an expected input shape of `(n_trajs, n_frames, *features_dims)` to avoid this new, unnecessary argument, and relying on the auxiliary methods `traj_to_contexts` and `multi_traj_to_context` to deal with the more flexible shapes.

Comments: There seems to be a perhaps unnecessarily deep hierarchy of classes, `TrajectoryContextDataset` -> `TensorContextDataset` -> `ContextWindowDataset` -> `ContextWindow` -> `Sequence`, for handling the time-series data slicing. Several of these classes repeat data parsing to appropriate backends and shapes. This deep hierarchy makes it quite cumbersome for someone unfamiliar with the code to modify the data pipeline. I propose reducing this hierarchy to one or two classes at most.
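A sketch of what the suggested shape convention could look like with numpy (the function name and exact output layout are assumptions for illustration, not the actual kooplearn helpers):

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def multi_traj_to_contexts(data, context_len):
    """Take data of shape (n_trajs, n_frames, *feature_dims) and return
    context windows of shape (n_trajs * n_windows, context_len, *feature_dims),
    with n_windows = n_frames - context_len + 1 per trajectory."""
    # sliding_window_view appends the window dimension last, so move it
    # next to the trajectory/window axes before flattening.
    windows = sliding_window_view(data, context_len, axis=1)
    windows = np.moveaxis(windows, -1, 2)
    return windows.reshape(-1, context_len, *data.shape[2:])

data = np.arange(2 * 5 * 3, dtype=float).reshape(2, 5, 3)  # 2 trajs, 5 frames, 3 features
contexts = multi_traj_to_contexts(data, context_len=2)
assert contexts.shape == (8, 2, 3)                 # 4 windows per trajectory
assert np.array_equal(contexts[0], data[0, :2])    # first window of trajectory 0
assert np.array_equal(contexts[4], data[1, :2])    # first window of trajectory 1
```

With this fixed prefix shape, a single trajectory is just the `n_trajs == 1` case, so no `multi_traj` flag is needed.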