-
Notifications
You must be signed in to change notification settings - Fork 8
Fix issue in obs manager for adding new key #169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -705,26 +705,28 @@ def fetch_data_from_dict( | |
| return current_data | ||
|
|
||
|
|
||
| def assign_data_to_dict( | ||
| data_dict: Dict[str, Union[Any, Dict[str, Any]]], name: str, value: Any | ||
| ) -> None: | ||
| """Assign data to a nested dictionary using a '/' separated key. | ||
| Missing intermediate dictionaries will be created automatically. | ||
| def assign_data_to_dict(data_dict: TensorDict, name: str, value: Any) -> None: | ||
| """Assign data to a TensorDict using a '/' separated key. | ||
| Missing intermediate TensorDicts will be created automatically. | ||
|
|
||
| Args: | ||
| data_dict (Dict[str, Union[Any, Dict[str, Any]]]): The nested dictionary to assign data to. | ||
| data_dict (TensorDict): The TensorDict to assign data to. | ||
| name (str): The '/' separated key string. | ||
| value (Any): The value to assign. | ||
| """ | ||
| keys = name.split("/") | ||
| current_data = data_dict | ||
| batch_size = current_data.batch_size | ||
|
|
||
| for key in keys[:-1]: | ||
| if key not in current_data or not isinstance(current_data[key], dict): | ||
| current_data[key] = {} # create intermediate dict if missing | ||
| if key not in current_data or not isinstance(current_data.get(key), TensorDict): | ||
| current_data[key] = TensorDict( | ||
| {}, batch_size=current_data.batch_size, device=current_data.device | ||
| ) | ||
yuecideng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| current_data = current_data[key] | ||
|
|
||
| last_key = keys[-1] | ||
| current_data.batch_size = batch_size # Ensure the batch size is consistent | ||
| current_data[last_key] = value | ||
|
Comment on lines
728
to
730
|
||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This helper now enforces
TensorDictfor assignment, but the pairedfetch_data_from_dicthelper (used inObservationManager) is still typed/documented as operating onDict[...]. Consider updating the fetch helper’s type hints/docstring (and any referenced docs) toTensorDict/EnvObstoo, so callers don’t get misleading typing and the TensorDict-only contract is consistent.