diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index be021fc5..20ac316a 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -705,26 +705,28 @@ def fetch_data_from_dict( return current_data -def assign_data_to_dict( - data_dict: Dict[str, Union[Any, Dict[str, Any]]], name: str, value: Any -) -> None: - """Assign data to a nested dictionary using a '/' separated key. - Missing intermediate dictionaries will be created automatically. +def assign_data_to_dict(data_dict: TensorDict, name: str, value: Any) -> None: + """Assign data to a TensorDict using a '/' separated key. + Missing intermediate TensorDicts will be created automatically. Args: - data_dict (Dict[str, Union[Any, Dict[str, Any]]]): The nested dictionary to assign data to. + data_dict (TensorDict): The TensorDict to assign data to. name (str): The '/' separated key string. value (Any): The value to assign. """ keys = name.split("/") current_data = data_dict + batch_size = current_data.batch_size for key in keys[:-1]: - if key not in current_data or not isinstance(current_data[key], dict): - current_data[key] = {} # create intermediate dict if missing + if key not in current_data or not isinstance(current_data.get(key), TensorDict): + current_data[key] = TensorDict( + {}, batch_size=current_data.batch_size, device=current_data.device + ) current_data = current_data[key] last_key = keys[-1] + current_data.batch_size = batch_size # Ensure the batch size is consistent current_data[last_key] = value