From 51e12e39b2694f16591433171542ebdf60fa339f Mon Sep 17 00:00:00 2001 From: yuecideng Date: Tue, 10 Mar 2026 07:44:11 +0000 Subject: [PATCH 1/2] wip --- embodichain/lab/gym/utils/gym_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index be021fc5..001bc7b3 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -705,14 +705,12 @@ def fetch_data_from_dict( return current_data -def assign_data_to_dict( - data_dict: Dict[str, Union[Any, Dict[str, Any]]], name: str, value: Any -) -> None: - """Assign data to a nested dictionary using a '/' separated key. - Missing intermediate dictionaries will be created automatically. +def assign_data_to_dict(data_dict: TensorDict, name: str, value: Any) -> None: + """Assign data to a TensorDict using a '/' separated key. + Missing intermediate TensorDicts will be created automatically. Args: - data_dict (Dict[str, Union[Any, Dict[str, Any]]]): The nested dictionary to assign data to. + data_dict (TensorDict): The TensorDict to assign data to. name (str): The '/' separated key string. value (Any): The value to assign. """ @@ -720,8 +718,10 @@ def assign_data_to_dict( current_data = data_dict for key in keys[:-1]: - if key not in current_data or not isinstance(current_data[key], dict): - current_data[key] = {} # create intermediate dict if missing + if key not in current_data or not isinstance(current_data.get(key), TensorDict): + current_data[key] = TensorDict( + {}, batch_size=current_data.batch_size, device=current_data.device + ) current_data = current_data[key] last_key = keys[-1] From ec61b1cd74f5c6c80023bce32fa1b87bca5eb1e3 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Tue, 10 Mar 2026 09:39:12 +0000 Subject: [PATCH 2/2] wip --- embodichain/lab/gym/utils/gym_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 001bc7b3..20ac316a 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -716,6 +716,7 @@ def assign_data_to_dict(data_dict: TensorDict, name: str, value: Any) -> None: """ keys = name.split("/") current_data = data_dict + batch_size = current_data.batch_size for key in keys[:-1]: if key not in current_data or not isinstance(current_data.get(key), TensorDict): @@ -725,6 +726,7 @@ def assign_data_to_dict(data_dict: TensorDict, name: str, value: Any) -> None: current_data = current_data[key] last_key = keys[-1] + current_data.batch_size = batch_size # Ensure the batch size is consistent current_data[last_key] = value