Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions embodichain/lab/gym/utils/gym_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,26 +705,28 @@ def fetch_data_from_dict(
return current_data


def assign_data_to_dict(
data_dict: Dict[str, Union[Any, Dict[str, Any]]], name: str, value: Any
) -> None:
"""Assign data to a nested dictionary using a '/' separated key.
Missing intermediate dictionaries will be created automatically.
def assign_data_to_dict(data_dict: TensorDict, name: str, value: Any) -> None:
"""Assign data to a TensorDict using a '/' separated key.
Missing intermediate TensorDicts will be created automatically.

Args:
data_dict (Dict[str, Union[Any, Dict[str, Any]]]): The nested dictionary to assign data to.
data_dict (TensorDict): The TensorDict to assign data to.
name (str): The '/' separated key string.
value (Any): The value to assign.
"""
Comment on lines +708 to 716
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This helper now enforces TensorDict for assignment, but the paired fetch_data_from_dict helper (used in ObservationManager) is still typed/documented as operating on Dict[...]. Consider updating the fetch helper’s type hints/docstring (and any referenced docs) to TensorDict/EnvObs too, so callers don’t get misleading typing and the TensorDict-only contract is consistent.

Copilot uses AI. Check for mistakes.
keys = name.split("/")
current_data = data_dict
batch_size = current_data.batch_size

for key in keys[:-1]:
if key not in current_data or not isinstance(current_data[key], dict):
current_data[key] = {} # create intermediate dict if missing
if key not in current_data or not isinstance(current_data.get(key), TensorDict):
current_data[key] = TensorDict(
{}, batch_size=current_data.batch_size, device=current_data.device
)
current_data = current_data[key]

last_key = keys[-1]
current_data.batch_size = batch_size # Ensure the batch size is consistent
current_data[last_key] = value
Comment on lines 728 to 730
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

current_data.batch_size = batch_size is likely invalid for tensordict.TensorDict (batch_size is typically derived/validated and may be read-only), and even if it works it can silently desynchronize the nested TensorDict from the shapes of its contained tensors. Prefer enforcing batch size by creating intermediate TensorDicts with the intended batch_size (e.g., from the root) and, when traversing existing nested entries, validate their batch_size matches instead of mutating it.

Copilot uses AI. Check for mistakes.


Expand Down