From 30fdc31d17eb71532405a6fc79e3edf2a6002841 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 30 Jan 2026 10:02:39 +0800 Subject: [PATCH 01/23] try Signed-off-by: 0oshowero0 --- transfer_queue/client.py | 14 ++++++-- transfer_queue/metadata.py | 69 +++++++++++++++++++++----------------- 2 files changed, 51 insertions(+), 32 deletions(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index cdce339..e09dd1e 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -209,10 +209,10 @@ async def async_get_meta( >>> print(batch_meta.is_ready) # True if all samples ready >>> >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, - >>> so may include unready samples. Consumed samples will not be fetched.) + >>> # so may include unready samples. Consumed samples will not be fetched.) >>> batch_meta = asyncio.run(client.async_get_meta( ... data_fields=["input_ids", "attention_mask"], - ... batch_size=4, + ... batch_size=4, # this is optional when using force_fetch ... partition_id="train_0", ... mode="force_fetch", ... task_name="generate_sequences" @@ -253,6 +253,14 @@ async def async_get_meta( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in get_meta: {str(e)}") from e + + @dynamic_socket(socket_name="request_handle_socket") + async def async_set_meta_extra_info( + self, + metadata: BatchMeta, + ): + + async def async_put( self, data: TensorDict, @@ -338,6 +346,8 @@ async def async_put( ): await self.storage_manager.put_data(data, metadata) + await self.async_set_meta_extra_info(metadata) + logger.debug( f"[{self.client_id}]: partition {partition_id} put {metadata.size} samples to storage units successfully." ) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 404a54a..2737e90 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -201,9 +201,15 @@ class BatchMeta: """Records the metadata of a batch of data samples.""" samples: list[SampleMeta] + + # external meta for non-sample level information extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) - # internal data for different storage backends: _custom_meta[global_index][field] - _custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) + + # external user-defined meta for each sample + custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) + + # internal meta for different storage backends for each sample + _custom_backend_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) def __post_init__(self): """Initialize all computed properties during initialization""" @@ -260,44 +266,47 @@ def partition_ids(self) -> list[str]: """Get partition ids for all samples in this batch as a list (one per sample)""" return getattr(self, "_partition_ids", []) - # Custom meta methods for different storage backends - def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: - """Get the entire custom meta dictionary""" - return copy.deepcopy(self._custom_meta) - - def update_custom_meta(self, new_custom_meta: Optional[dict[int, dict[str, Any]]]): - """Update custom meta with a new dictionary""" - if new_custom_meta: - self._custom_meta.update(new_custom_meta) - - # Extra info interface methods - def get_extra_info(self, key: str, default: Any = None) -> Any: - """Get extra info by key""" - return self.extra_info.get(key, default) - - def set_extra_info(self, key: str, value: Any) -> None: - """Set extra info by key""" - self.extra_info[key] = 
value + def get_all_extra_info(self) -> dict[str, Any]: + """Get all extra info as a dictionary""" + return copy.deepcopy(self.extra_info) def update_extra_info(self, info_dict: dict[str, Any]) -> None: """Update extra info with multiple key-value pairs""" self.extra_info.update(info_dict) - def remove_extra_info(self, key: str) -> Any: - """Remove extra info by key and return its value""" - return self.extra_info.pop(key, None) - def clear_extra_info(self) -> None: """Clear all extra info""" self.extra_info.clear() - def has_extra_info(self, key: str) -> bool: - """Check if extra info contains a specific key""" - return key in self.extra_info + def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: + """Get the entire custom_meta dictionary""" + return copy.deepcopy(self.custom_meta) + + def update_custom_meta(self, new_custom_meta: dict[int, dict[str, Any]]): + """Update custom_meta with a new dictionary""" + self._custom_meta.update(new_custom_meta) + + def clear_custom_meta(self) -> None: + """Clear custom_meta""" + self._custom_meta.clear() + + def get_all_custom_backend_meta(self) -> dict[int, dict[str, Any]]: + """Get the entire _custom_backend_meta dictionary""" + return copy.deepcopy(self._custom_backend_meta) + + def update_custom_backend_meta(self, new_custom_meta: Optional[dict[int, dict[str, Any]]]): + """Update _custom_backend_meta with a new dictionary""" + if new_custom_meta: + self._custom_backend_meta.update(new_custom_meta) + + def clear_custom_backend_meta(self) -> None: + """Clear _custom_backend_meta""" + self._custom_backend_meta.clear() + + + + - def get_all_extra_info(self) -> dict[str, Any]: - """Get all extra info as a dictionary""" - return copy.deepcopy(self.extra_info) def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta": """ From d2d8c3ec7b9f2a2ca7a123aaa89410d5d8b1a8da Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 30 Jan 2026 15:00:02 +0800 Subject: [PATCH 02/23] try Signed-off-by: 0oshowero0 --- tests/test_controller_data_partitions.py | 2 +- transfer_queue/client.py | 94 ++++++++++++++++++------ transfer_queue/controller.py | 84 +++++++++++++++++---- transfer_queue/metadata.py | 78 ++++++++++++++++---- transfer_queue/utils/zmq_utils.py | 2 + 5 files changed, 205 insertions(+), 55 deletions(-) diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index 976b277..84c8b63 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -482,7 +482,7 @@ def test_custom_meta_in_data_partition_status(): assert partition.field_custom_metas[1]["attention_mask"]["mask_ratio"] == 0.2 # Retrieval via helper for a subset of fields - retrieved = partition.get_field_custom_meta([0, 1], ["input_ids", "attention_mask"]) + retrieved = partition.get_field_custom_backend_meta([0, 1], ["input_ids", "attention_mask"]) assert 0 in retrieved and "input_ids" in retrieved[0] assert 1 in retrieved and "attention_mask" in retrieved[1] diff --git a/transfer_queue/client.py b/transfer_queue/client.py index e09dd1e..271d0aa 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -16,6 +16,7 @@ import asyncio import logging import os +from collections import defaultdict import threading from functools import wraps from typing import Any, Callable, Optional, Union @@ -211,11 +212,8 @@ async def async_get_meta( >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, >>> # so may include unready samples. 
Consumed samples will not be fetched.) >>> batch_meta = asyncio.run(client.async_get_meta( - ... data_fields=["input_ids", "attention_mask"], - ... batch_size=4, # this is optional when using force_fetch - ... partition_id="train_0", + ... partition_id="train_0", # optional ... mode="force_fetch", - ... task_name="generate_sequences" ... )) >>> print(batch_meta.is_ready) # May be False if some samples not ready """ @@ -234,31 +232,81 @@ async def async_get_meta( }, ) - try: - await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() - response_msg = ZMQMessage.deserialize(response_serialized) - logger.debug( - f"[{self.client_id}]: Client get_meta response: {response_msg} from controller {self._controller.id}" + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart() + response_msg = ZMQMessage.deserialize(response_serialized) + logger.debug( + f"[{self.client_id}]: Client get_meta response: {response_msg} from controller {self._controller.id}" + ) + + if response_msg.request_type == ZMQRequestType.GET_META_RESPONSE: + metadata_dict = response_msg.body["metadata"] + return BatchMeta.from_dict(metadata_dict) if isinstance(metadata_dict, dict) else metadata_dict + else: + raise RuntimeError( + f"[{self.client_id}]: Failed to get metadata from controller {self._controller.id}: " + f"{response_msg.body.get('message', 'Unknown error')}" ) - if response_msg.request_type == ZMQRequestType.GET_META_RESPONSE: - metadata_dict = response_msg.body["metadata"] - return BatchMeta.from_dict(metadata_dict) if isinstance(metadata_dict, dict) else metadata_dict - else: - raise RuntimeError( - f"[{self.client_id}]: Failed to get metadata from controller {self._controller.id}: " - f"{response_msg.body.get('message', 'Unknown error')}" - ) - except Exception as e: - raise RuntimeError(f"[{self.client_id}]: Error in get_meta: {str(e)}") from e @dynamic_socket(socket_name="request_handle_socket") - async def async_set_meta_extra_info( + async def async_set_custom_meta( self, metadata: BatchMeta, - ): + socket: Optional[zmq.asyncio.Socket] = None, + ) -> None: + assert socket is not None + + if not self._controller: + raise RuntimeError("No controller registered") + + partition_ids = metadata.partition_ids + global_indexes = metadata.global_indexes + custom_meta = metadata.get_all_custom_meta() + + + if len(global_indexes) == 0 or len(custom_meta) == 0: + logger.warning(f"[{self.client_id}]: Empty BatchMeta or custom_meta provided. No action taken.") + return + + + non_exist_global_indexes = set(custom_meta.keys()) - set(global_indexes) + if bool(non_exist_global_indexes): + raise ValueError(f"Trying to update custom_meta with non-exist global_indexes! 
" + f"{non_exist_global_indexes} do not exist in this batch.") + + # chunk metadata according to partition_ids + metadata_chunks = metadata.chunk_by_partition() + + partition_custom_meta = {k:[] for k in metadata.partition_ids} + + for meta in metadata_chunks: + partition_custom_meta[meta.partition_ids[0]].append(meta.custom_meta) + + + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.SET_CUSTOM_META, + sender_id=self.client_id, + receiver_id=self._controller.id, + body={ + "partition_custom_meta": partition_custom_meta + }, + ) + + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart() + response_msg = ZMQMessage.deserialize(response_serialized) + logger.debug( + f"[{self.client_id}]: Client set_custom_meta response: {response_msg} from " + f"controller {self._controller.id}" + ) + + if response_msg.request_type != ZMQRequestType.SET_CUSTOM_META_RESPONSE: + raise RuntimeError( + f"[{self.client_id}]: Failed to set custom metadata from controller {self._controller.id}: " + f"{response_msg.body.get('message', 'Unknown error')}" + ) async def async_put( @@ -346,7 +394,7 @@ async def async_put( ): await self.storage_manager.put_data(data, metadata) - await self.async_set_meta_extra_info(metadata) + await self.async_set_custom_meta(metadata) logger.debug( f"[{self.client_id}]: partition {partition_id} put {metadata.size} samples to storage units successfully." diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 39c890c..fc06015 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -222,11 +222,13 @@ class DataPartitionStatus: default_factory=set ) # set of global indexes that pre-allocated, but not active in this partition - # Field metadata + # Metadata field_name_mapping: dict[str, int] = field(default_factory=dict) # field_name -> column_index field_dtypes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: dtype} field_shapes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: shape} - field_custom_metas: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: custom_meta} + field_custom_backend_metas: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: custom_backend_meta} + # User-defined metadata that may not apply to field level + custom_metas: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {} # Threading lock for concurrency control; only for preventing mask operation error when expanding production_status. # No need to strictly lock for every read/write operation since freshness is not critical. 
@@ -480,9 +482,9 @@ def _update_field_metadata( # Only create and update custom_meta mapping if a custom_meta value was provided if custom_meta_value[i] is not None: - if global_idx not in self.field_custom_metas: - self.field_custom_metas[global_idx] = {} - self.field_custom_metas[global_idx].update(custom_meta_value[i]) + if global_idx not in self.field_custom_backend_metas: + self.field_custom_backend_metas[global_idx] = {} + self.field_custom_backend_metas[global_idx].update(custom_meta_value[i]) def mark_consumed(self, task_name: str, global_indices: list[int]): """ @@ -629,7 +631,7 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: return ready_sample_indices - # ==================== Field Metadata Methods ==================== + # ==================== Metadata Methods ==================== def get_field_dtype(self, global_index: int, field_name: str) -> Optional[Any]: """Get dtype for a specific sample and field.""" @@ -639,14 +641,28 @@ def get_field_shape(self, global_index: int, field_name: str) -> Optional[Any]: """Get shape for a specific sample and field.""" return self.field_shapes.get(global_index, {}).get(field_name) - def get_field_custom_meta(self, global_indices: list[int], field_names: list[str]) -> dict[int, dict[str, Any]]: - """Get custom_meta for multiple samples and fields.""" + def get_field_custom_backend_meta(self, global_indices: list[int], field_names: list[str]) -> dict[int, dict[str, Any]]: + """Get custom_backend_meta for multiple samples and fields.""" return { - idx: {f: v for f, v in self.field_custom_metas[idx].items() if f in field_names} + idx: {f: v for f, v in self.field_custom_backend_metas[idx].items() if f in field_names} for idx in global_indices - if idx in self.field_custom_metas + if idx in self.field_custom_backend_metas } + def get_custom_meta(self, global_indices: list[int]) -> dict[int, dict]: + """Get custom_meta for multiple samples.""" + return { + idx: self.custom_metas[idx] + for idx in global_indices + if idx in self.custom_metas + } + + + def set_custom_meta(self, custom_meta: dict[int, dict]) -> None: + """Set custom_meta for multiple samples.""" + for k in custom_meta.keys(): + self.custom_metas[k] = custom_meta[k] + # ==================== Statistics and Monitoring ==================== def get_statistics(self) -> dict[str, Any]: @@ -982,6 +998,22 @@ def get_production_status( return partition.get_production_status_for_fields(data_fields, mask=True) + def set_custom_meta( + self, partition_custom_meta: dict[str, dict[int, dict]] + ) -> None: + """ + Set custom meta for samples in a partition. 
+ + Args: + partition_id: ID of the partition + + """ + + for partition_id, custom_meta in partition_custom_meta.items(): + partition = self._get_partition(partition_id) + if partition: + partition.set_custom_meta(custom_meta) + def get_metadata( self, data_fields: list[str], @@ -1095,11 +1127,11 @@ def get_metadata( ) elif mode == "force_fetch": - global_indexes_range = self.index_manager.get_indexes_for_partition(partition_id) - consumer_status = self.get_consumption_status(partition_id, task_name) - not_consumed_idx = [i for i in global_indexes_range if consumer_status[i] == 0] - batch_global_indexes = not_consumed_idx - consumed_indexes = [] + if partition_id is not None: + batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id) + consumed_indexes = [] + else: + batch_global_indexes = list(sorted(self.index_manager.allocated_indexes)) # Package into metadata metadata = self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) @@ -1216,10 +1248,12 @@ def generate_batch_meta( ) samples.append(sample) - custom_meta = partition.get_field_custom_meta(batch_global_indexes, data_fields) + custom_meta = partition.get_custom_meta(batch_global_indexes) + custom_backend_meta = partition.get_field_custom_backend_meta(batch_global_indexes, data_fields) batch_meta = BatchMeta(samples=samples) batch_meta.update_custom_meta(custom_meta) + batch_meta.update_custom_backend_meta(custom_backend_meta) return batch_meta def clear_partition(self, partition_id: str, clear_consumption: bool = True): @@ -1459,6 +1493,24 @@ def _process_request(self): receiver_id=request_msg.sender_id, body={"metadata": metadata}, ) + elif request_msg.request_type == ZMQRequestType.SET_CUSTOM_META: + with perf_monitor.measure(op_type="SET_CUSTOM_META"): + params = request_msg.body + global_indexes = params["global_indexes"] + custom_metadata = params["custom_metadata"] + + self.set_custom_meta( + global_indexes=global_indexes, + custom_metadata=custom_metadata, + ) + + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.SET_CUSTOM_META, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={"message": f"Successfully set custom_meta"}, + ) + elif request_msg.request_type == ZMQRequestType.CLEAR_META: with perf_monitor.measure(op_type="CLEAR_META"): params = request_msg.body diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 2737e90..9fcd6e6 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -208,13 +208,15 @@ class BatchMeta: # external user-defined meta for each sample custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) - # internal meta for different storage backends for each sample + # internal meta for different storage backends in per-sample per-field level _custom_backend_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) def __post_init__(self): """Initialize all computed properties during initialization""" self.samples = copy.deepcopy(self.samples) self.extra_info = copy.deepcopy(self.extra_info) + self.custom_meta = copy.deepcopy(self.custom_meta) + self._custom_backend_meta = copy.deepcopy(self._custom_backend_meta) # Basic properties object.__setattr__(self, "_size", len(self.samples)) @@ -235,7 +237,12 @@ def __post_init__(self): object.__setattr__(self, "_partition_ids", [sample.partition_id for sample in self.samples]) + # filter custom_meta and _custom_backend_meta + self.custom_meta = copy.deepcopy({k: self.custom_meta[k] for k in 
self.global_indexes if k in self.custom_meta}) + self._custom_backend_meta = copy.deepcopy({k: self._custom_backend_meta[k] for k in self.global_indexes if k in self._custom_backend_meta}) else: + self.custom_meta = {} + self._custom_backend_meta = {} object.__setattr__(self, "_global_indexes", []) object.__setattr__(self, "_field_names", []) object.__setattr__(self, "_partition_ids", []) @@ -266,48 +273,73 @@ def partition_ids(self) -> list[str]: """Get partition ids for all samples in this batch as a list (one per sample)""" return getattr(self, "_partition_ids", []) + def set_extra_info(self, key: str, value: Any) -> None: + """Set extra_info by key""" + self.extra_info[key] = value + def get_all_extra_info(self) -> dict[str, Any]: - """Get all extra info as a dictionary""" + """Get all extra_info as a dictionary""" return copy.deepcopy(self.extra_info) def update_extra_info(self, info_dict: dict[str, Any]) -> None: - """Update extra info with multiple key-value pairs""" + """Update extra_info with multiple key-value pairs""" self.extra_info.update(info_dict) def clear_extra_info(self) -> None: - """Clear all extra info""" + """Clear all extra_info""" self.extra_info.clear() + def set_custom_meta(self, key: int, value: dict[str, Any]) -> None: + """Set custom_meta by key""" + if key not in self.global_indexes: + raise ValueError(f"key {key} not found in global_indexes {self.global_indexes}.") + + self.custom_meta[key] = copy.deepcopy(value) + def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: """Get the entire custom_meta dictionary""" return copy.deepcopy(self.custom_meta) - def update_custom_meta(self, new_custom_meta: dict[int, dict[str, Any]]): + def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]): """Update custom_meta with a new dictionary""" - self._custom_meta.update(new_custom_meta) + + non_exist_global_indexes = set(new_meta.keys()) - set(self.global_indexes) + if bool(non_exist_global_indexes): + raise ValueError(f"Trying to update custom_meta with non-exist global_indexes! {non_exist_global_indexes} " + f"do not exist in this batch.") + + self.custom_meta.update(new_meta) def clear_custom_meta(self) -> None: """Clear custom_meta""" - self._custom_meta.clear() + self.custom_meta.clear() + + def set_custom_backend_meta(self, key: int, value: dict[str, Any]) -> None: + """Set custom_meta by key""" + if key not in self.global_indexes: + raise ValueError(f"key {key} not found in global_indexes {self.global_indexes}.") + + self._custom_backend_meta[key] = copy.deepcopy(value) def get_all_custom_backend_meta(self) -> dict[int, dict[str, Any]]: """Get the entire _custom_backend_meta dictionary""" return copy.deepcopy(self._custom_backend_meta) - def update_custom_backend_meta(self, new_custom_meta: Optional[dict[int, dict[str, Any]]]): + def update_custom_backend_meta(self, new_meta: Optional[dict[int, dict[str, Any]]]): """Update _custom_backend_meta with a new dictionary""" - if new_custom_meta: - self._custom_backend_meta.update(new_custom_meta) + + non_exist_global_indexes = set(new_meta.keys()) - set(self.global_indexes) + if bool(non_exist_global_indexes): + raise ValueError(f"Trying to update _custom_backend_meta with non-exist global_indexes! 
" + f"{non_exist_global_indexes} do not exist in this batch.") + + if new_meta: + self._custom_backend_meta.update(new_meta) def clear_custom_backend_meta(self) -> None: """Clear _custom_backend_meta""" self._custom_backend_meta.clear() - - - - - def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta": """ Add new fields from a TensorDict to all samples in this batch. @@ -417,6 +449,22 @@ def chunk(self, chunks: int) -> list["BatchMeta"]: start = end return chunk_list + def chunk_by_partition(self, ) -> list["BatchMeta"]: + """ + Split this batch into smaller chunks according to partition_ids. + + Return: + List of smaller BatchMeta chunks, each chunk has samples with identical partition_id + """ + + grouped_global_indexes = defaultdict(list) + for partition_id, global_index in zip(self.partition_ids, self.global_indexes): + grouped_global_indexes[partition_id].append(global_index) + + chunk_list = [self.select_samples(samples) for samples in grouped_global_indexes.values()] + + return chunk_list + @classmethod def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": """ diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 9e1cadc..0a39113 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -74,6 +74,8 @@ class ZMQRequestType(ExplicitEnum): GET_META_RESPONSE = "GET_META_RESPONSE" GET_PARTITION_META = "GET_PARTITION_META" GET_PARTITION_META_RESPONSE = "GET_PARTITION_META_RESPONSE" + SET_CUSTOM_META = "SET_CUSTOM_META" + SET_CUSTOM_META_RESPONSE = "SET_CUSTOM_META_RESPONSE" CLEAR_META = "CLEAR_META" CLEAR_META_RESPONSE = "CLEAR_META_RESPONSE" CLEAR_PARTITION = "CLEAR_PARTITION" From 9a1e75b0bed92ad4905e1b591eff46f872e86956 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 30 Jan 2026 16:25:39 +0800 Subject: [PATCH 03/23] support custom_meta Signed-off-by: 0oshowero0 --- transfer_queue/client.py | 70 ++++++++++---- transfer_queue/controller.py | 139 +++++++++++++++++++-------- transfer_queue/metadata.py | 181 ++++++++++++++++++++++++++++------- 3 files changed, 293 insertions(+), 97 deletions(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 271d0aa..fe32b72 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -16,7 +16,6 @@ import asyncio import logging import os -from collections import defaultdict import threading from functools import wraps from typing import Any, Callable, Optional, Union @@ -248,49 +247,61 @@ async def async_get_meta( f"{response_msg.body.get('message', 'Unknown error')}" ) - - @dynamic_socket(socket_name="request_handle_socket") async def async_set_custom_meta( self, metadata: BatchMeta, socket: Optional[zmq.asyncio.Socket] = None, ) -> None: + """ + Asynchronously send custom metadata to the controller. + + This method sends per-sample custom metadata (custom_meta) to the controller. + The custom_meta is stored in the controller and can be retrieved along with + the BatchMeta in subsequent get_meta calls. + + Args: + metadata: BatchMeta containing the samples and their custom metadata to store. + The custom_meta should be set using BatchMeta.update_custom_meta() or + BatchMeta.set_custom_meta() before calling this method. 
+ socket: ZMQ async socket for message transmission (injected by decorator) + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Create batch with custom metadata + >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...) + >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) + >>> asyncio.run(client.async_set_custom_meta(batch_meta)) + """ assert socket is not None if not self._controller: raise RuntimeError("No controller registered") - partition_ids = metadata.partition_ids global_indexes = metadata.global_indexes custom_meta = metadata.get_all_custom_meta() - if len(global_indexes) == 0 or len(custom_meta) == 0: logger.warning(f"[{self.client_id}]: Empty BatchMeta or custom_meta provided. No action taken.") return - - non_exist_global_indexes = set(custom_meta.keys()) - set(global_indexes) - if bool(non_exist_global_indexes): - raise ValueError(f"Trying to update custom_meta with non-exist global_indexes! " - f"{non_exist_global_indexes} do not exist in this batch.") - # chunk metadata according to partition_ids metadata_chunks = metadata.chunk_by_partition() - partition_custom_meta = {k:[] for k in metadata.partition_ids} + # Build partition_custom_meta in format: {partition_id: {global_index: {meta1:xxx, meta2:xxx}}} + partition_custom_meta: dict[str, dict[int, dict]] = {pid: {} for pid in set(metadata.partition_ids)} for meta in metadata_chunks: - partition_custom_meta[meta.partition_ids[0]].append(meta.custom_meta) - + partition_custom_meta[meta.partition_ids[0]].update(meta.get_all_custom_meta()) request_msg = ZMQMessage.create( request_type=ZMQRequestType.SET_CUSTOM_META, sender_id=self.client_id, receiver_id=self._controller.id, body={ - "partition_custom_meta": partition_custom_meta + "partition_custom_meta": partition_custom_meta, }, ) @@ -298,8 +309,7 @@ async def async_set_custom_meta( response_serialized = await socket.recv_multipart() response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( - f"[{self.client_id}]: Client set_custom_meta response: {response_msg} from " - f"controller {self._controller.id}" + f"[{self.client_id}]: Client set_custom_meta response: {response_msg} from controller {self._controller.id}" ) if response_msg.request_type != ZMQRequestType.SET_CUSTOM_META_RESPONSE: @@ -308,7 +318,6 @@ async def async_set_custom_meta( f"{response_msg.body.get('message', 'Unknown error')}" ) - async def async_put( self, data: TensorDict, @@ -906,6 +915,7 @@ def wrapper(*args, **kwargs): self._check_consumption_status = _make_sync(self.async_check_consumption_status) self._check_production_status = _make_sync(self.async_check_production_status) self._get_partition_list = _make_sync(self.async_get_partition_list) + self._set_custom_meta = _make_sync(self.async_set_custom_meta) def put( self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None @@ -1062,6 +1072,30 @@ def get_partition_list( """ return self._get_partition_list() + def set_custom_meta(self, metadata: BatchMeta) -> None: + """Synchronously send custom metadata to the controller. + + This method sends per-sample custom metadata (custom_meta) to the controller. + The custom_meta is stored in the controller and can be retrieved along with + the BatchMeta in subsequent get_meta calls. + + Args: + metadata: BatchMeta containing the samples and their custom metadata to store. 
+ The custom_meta should be set using BatchMeta.update_custom_meta() or + BatchMeta.set_custom_meta() before calling this method. + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Create batch with custom metadata + >>> batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=4, ...) + >>> batch_meta.update_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) + >>> client.set_custom_meta(batch_meta) + """ + + return self._set_custom_meta(metadata=metadata) + def close(self) -> None: """Close the client and cleanup resources including event loop and thread.""" diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index fc06015..f69e66a 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -226,9 +226,11 @@ class DataPartitionStatus: field_name_mapping: dict[str, int] = field(default_factory=dict) # field_name -> column_index field_dtypes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: dtype} field_shapes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: shape} - field_custom_backend_metas: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: custom_backend_meta} + field_custom_backend_meta: dict[int, dict[str, Any]] = field( + default_factory=dict + ) # global_idx -> {field: custom_backend_meta} # User-defined metadata that may not apply to field level - custom_metas: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {} + custom_meta: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {} # Threading lock for concurrency control; only for preventing mask operation error when expanding production_status. # No need to strictly lock for every read/write operation since freshness is not critical. @@ -429,7 +431,7 @@ def _update_field_metadata( global_indices: list[int], dtypes: Optional[dict[int, dict[str, Any]]], shapes: Optional[dict[int, dict[str, Any]]], - custom_meta: Optional[dict[int, dict[str, Any]]], + custom_backend_meta: Optional[dict[int, dict[str, Any]]], ): """Update field dtype and shape metadata.""" if not global_indices: @@ -440,9 +442,9 @@ def _update_field_metadata( raise ValueError(f"`global_indices` {len(global_indices)} and `dtypes` {len(dtypes)} length mismatch.") if shapes and len(global_indices) != len(shapes): raise ValueError(f"`global_indices` {len(global_indices)} and `shapes` {len(shapes)} length mismatch.") - if custom_meta and len(global_indices) != len(custom_meta): + if custom_backend_meta and len(global_indices) != len(custom_backend_meta): raise ValueError( - f"`global_indices` {len(global_indices)} and `custom_meta` {len(custom_meta)} length mismatch." + f"`global_indices` {len(global_indices)} and `custom_meta` {len(custom_backend_meta)} length mismatch." 
) # Extract values for each provided mapping; if a mapping is absent, use Nones @@ -460,12 +462,12 @@ def _update_field_metadata( else: shape_value = tuple([None] * len(global_indices)) - if custom_meta: - custom_meta_value = itemgetter(*global_indices)(custom_meta) - if not isinstance(custom_meta_value, tuple): - custom_meta_value = (custom_meta_value,) + if custom_backend_meta: + custom_backend_meta_value = itemgetter(*global_indices)(custom_backend_meta) + if not isinstance(custom_backend_meta_value, tuple): + custom_backend_meta_value = (custom_backend_meta_value,) else: - custom_meta_value = tuple([None] * len(global_indices)) + custom_backend_meta_value = tuple([None] * len(global_indices)) for i, global_idx in enumerate(global_indices): # Only create and update dtype mapping if a dtype value was provided @@ -480,11 +482,11 @@ def _update_field_metadata( self.field_shapes[global_idx] = {} self.field_shapes[global_idx].update(shape_value[i]) - # Only create and update custom_meta mapping if a custom_meta value was provided - if custom_meta_value[i] is not None: - if global_idx not in self.field_custom_backend_metas: - self.field_custom_backend_metas[global_idx] = {} - self.field_custom_backend_metas[global_idx].update(custom_meta_value[i]) + # Only create and update custom_backend_meta mapping if a custom_backend_meta value was provided + if custom_backend_meta_value[i] is not None: + if global_idx not in self.field_custom_backend_meta: + self.field_custom_backend_meta[global_idx] = {} + self.field_custom_backend_meta[global_idx].update(custom_backend_meta_value[i]) def mark_consumed(self, task_name: str, global_indices: list[int]): """ @@ -641,27 +643,65 @@ def get_field_shape(self, global_index: int, field_name: str) -> Optional[Any]: """Get shape for a specific sample and field.""" return self.field_shapes.get(global_index, {}).get(field_name) - def get_field_custom_backend_meta(self, global_indices: list[int], field_names: list[str]) -> dict[int, dict[str, Any]]: - """Get custom_backend_meta for multiple samples and fields.""" + def get_field_custom_backend_meta( + self, global_indices: list[int], field_names: list[str] + ) -> dict[int, dict[str, Any]]: + """ + Get custom_backend_meta for multiple samples and fields. + + This method retrieves backend-specific metadata stored at per-sample per-field level. + The returned dictionary maps global_index to a dictionary of field_name to metadata. + + Args: + global_indices: List of global sample indices to retrieve metadata for + field_names: List of field names to filter by. Only metadata for these + fields will be included in the result. + + Returns: + Dictionary mapping global_index to field-name-to-metadata mapping. + Only includes indices that have custom_backend_meta set. + + Example: + >>> partition.get_field_custom_backend_meta([0, 1], ["field_a", "field_b"]) + {0: {'field_a': {'meta1': 'xxx'}, 'field_b': {'meta1': 'xxx'}}, 1: {...}} + """ return { - idx: {f: v for f, v in self.field_custom_backend_metas[idx].items() if f in field_names} + idx: {f: v for f, v in self.field_custom_backend_meta[idx].items() if f in field_names} for idx in global_indices - if idx in self.field_custom_backend_metas + if idx in self.field_custom_backend_meta } def get_custom_meta(self, global_indices: list[int]) -> dict[int, dict]: - """Get custom_meta for multiple samples.""" - return { - idx: self.custom_metas[idx] - for idx in global_indices - if idx in self.custom_metas - } + """ + Get custom_meta for multiple samples. 
+ This method retrieves user-defined per-sample metadata. + + Args: + global_indices: List of global sample indices to retrieve metadata for + + Returns: + Dictionary mapping global_index to custom metadata dict. + Only includes indices that have custom_meta set. + + Example: + >>> partition.get_custom_meta([0, 2]) + {0: {'score': 0.9}, 2: {'label': 'positive'}} + """ + return {idx: self.custom_meta[idx] for idx in global_indices if idx in self.custom_meta} def set_custom_meta(self, custom_meta: dict[int, dict]) -> None: - """Set custom_meta for multiple samples.""" + """ + Set custom_meta for multiple samples. + + This method sets or updates user-defined per-sample metadata. + + Args: + custom_meta: Dictionary mapping global_index to custom metadata dict. + Existing entries will be overwritten. + """ for k in custom_meta.keys(): - self.custom_metas[k] = custom_meta[k] + self.custom_meta[k] = custom_meta[k] # ==================== Statistics and Monitoring ==================== @@ -758,7 +798,8 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr for idx in indexes_to_release: self.field_dtypes.pop(idx, None) self.field_shapes.pop(idx, None) - self.field_custom_metas.pop(idx, None) + self.field_custom_backend_meta.pop(idx, None) + self.custom_meta.pop(idx, None) except Exception as e: logger.error( @@ -998,15 +1039,33 @@ def get_production_status( return partition.get_production_status_for_fields(data_fields, mask=True) - def set_custom_meta( - self, partition_custom_meta: dict[str, dict[int, dict]] - ) -> None: + def set_custom_meta(self, partition_custom_meta: dict[str, dict[int, dict]]) -> None: """ - Set custom meta for samples in a partition. + Set custom_meta for samples in partitions. - Args: - partition_id: ID of the partition + This method allows setting per-sample custom metadata (custom_meta) for samples + identified by their global indexes within specific partitions. Custom metadata + is stored per-sample and can be retrieved along with BatchMeta in subsequent + get_meta calls. + Args: + partition_custom_meta: Dictionary mapping partition_id to custom metadata dict. + Format: {partition_id: {global_index: {metadata_key: metadata_value}}} + - partition_id: ID of the partition + - global_index: Global index of the sample + - metadata_key/value: User-defined metadata key-value pairs + + Example: + >>> # Set custom metadata for samples in different partitions + >>> controller.set_custom_meta({ + ... "train_0": { + ... 0: {"score": 0.9, "label": "positive"}, + ... 1: {"score": 0.8, "label": "negative"} + ... }, + ... "train_1": { + ... 10: {"score": 0.95, "label": "positive"} + ... } + ... 
}) """ for partition_id, custom_meta in partition_custom_meta.items(): @@ -1496,19 +1555,15 @@ def _process_request(self): elif request_msg.request_type == ZMQRequestType.SET_CUSTOM_META: with perf_monitor.measure(op_type="SET_CUSTOM_META"): params = request_msg.body - global_indexes = params["global_indexes"] - custom_metadata = params["custom_metadata"] + partition_custom_meta = params["partition_custom_meta"] - self.set_custom_meta( - global_indexes=global_indexes, - custom_metadata=custom_metadata, - ) + self.set_custom_meta(partition_custom_meta=partition_custom_meta) response_msg = ZMQMessage.create( - request_type=ZMQRequestType.SET_CUSTOM_META, + request_type=ZMQRequestType.SET_CUSTOM_META_RESPONSE, sender_id=self.controller_id, receiver_id=request_msg.sender_id, - body={"message": f"Successfully set custom_meta"}, + body={"message": "Successfully set custom_meta"}, ) elif request_msg.request_type == ZMQRequestType.CLEAR_META: diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 9fcd6e6..a47991d 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -137,7 +137,7 @@ def select_fields(self, field_names: list[str]) -> "SampleMeta": selected_fields = {name: self.fields[name] for name in field_names if name in self.fields} # construct new SampleMeta instance - # TODO(tianyi): move custom_meta to FieldMeta level + # TODO(tianyi): (maybe) move _custom_backend_meta and _custom_meta to FieldMeta level? selected_sample_meta = SampleMeta( fields=selected_fields, partition_id=self.partition_id, @@ -205,8 +205,8 @@ class BatchMeta: # external meta for non-sample level information extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) - # external user-defined meta for each sample - custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) + # internal user-defined meta for each sample + _custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) # internal meta for different storage backends in per-sample per-field level _custom_backend_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) @@ -215,7 +215,7 @@ def __post_init__(self): """Initialize all computed properties during initialization""" self.samples = copy.deepcopy(self.samples) self.extra_info = copy.deepcopy(self.extra_info) - self.custom_meta = copy.deepcopy(self.custom_meta) + self._custom_meta = copy.deepcopy(self._custom_meta) self._custom_backend_meta = copy.deepcopy(self._custom_backend_meta) # Basic properties @@ -237,11 +237,15 @@ def __post_init__(self): object.__setattr__(self, "_partition_ids", [sample.partition_id for sample in self.samples]) - # filter custom_meta and _custom_backend_meta - self.custom_meta = copy.deepcopy({k: self.custom_meta[k] for k in self.global_indexes if k in self.custom_meta}) - self._custom_backend_meta = copy.deepcopy({k: self._custom_backend_meta[k] for k in self.global_indexes if k in self._custom_backend_meta}) + # filter _custom_meta and _custom_backend_meta + self._custom_meta = copy.deepcopy( + {k: self._custom_meta[k] for k in self.global_indexes if k in self._custom_meta} + ) + self._custom_backend_meta = copy.deepcopy( + {k: self._custom_backend_meta[k] for k in self.global_indexes if k in self._custom_backend_meta} + ) else: - self.custom_meta = {} + self._custom_meta = {} self._custom_backend_meta = {} object.__setattr__(self, "_global_indexes", []) object.__setattr__(self, "_field_names", []) @@ -274,70 +278,171 @@ def partition_ids(self) -> list[str]: return 
getattr(self, "_partition_ids", []) def set_extra_info(self, key: str, value: Any) -> None: - """Set extra_info by key""" + """ + Set extra_info value for a specific key. + + Args: + key: The key to set in extra_info + value: The value to associate with the key + """ self.extra_info[key] = value def get_all_extra_info(self) -> dict[str, Any]: - """Get all extra_info as a dictionary""" + """Get all extra_info as a dictionary (deep copy for immutability). + + Returns: + A deep copy of the extra_info dictionary + """ return copy.deepcopy(self.extra_info) def update_extra_info(self, info_dict: dict[str, Any]) -> None: - """Update extra_info with multiple key-value pairs""" + """ + Update extra_info with multiple key-value pairs. + + This method updates the extra_info dictionary with the provided key-value pairs. + Existing keys will be overwritten with new values. + + Args: + info_dict: Dictionary of key-value pairs to add/update in extra_info + """ self.extra_info.update(info_dict) def clear_extra_info(self) -> None: - """Clear all extra_info""" + """ + Clear all extra_info. + + This method removes all key-value pairs from the extra_info dictionary. + """ self.extra_info.clear() - def set_custom_meta(self, key: int, value: dict[str, Any]) -> None: - """Set custom_meta by key""" - if key not in self.global_indexes: - raise ValueError(f"key {key} not found in global_indexes {self.global_indexes}.") + def set_custom_meta(self, global_index: int, meta_dict: dict[str, Any]) -> None: + """ + Set _custom_meta for a specific sample by global_index. - self.custom_meta[key] = copy.deepcopy(value) + Custom metadata is user-defined per-sample metadata that can be stored + and retrieved along with the BatchMeta. + + Args: + global_index: The global_index of the sample to set custom meta for + meta_dict: Dictionary containing custom metadata for the sample + + Raises: + ValueError: If the key is not in global_indexes + """ + + if global_index not in self.global_indexes: + raise ValueError(f"key {global_index} not found in global_indexes {self.global_indexes}.") + + self._custom_meta[global_index] = copy.deepcopy(meta_dict) def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: - """Get the entire custom_meta dictionary""" - return copy.deepcopy(self.custom_meta) + """ + Get all _custom_meta as a dictionary. + + Returns: + A deep copy of the _custom_meta dictionary + """ + return copy.deepcopy(self._custom_meta) def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]): - """Update custom_meta with a new dictionary""" + """ + Update _custom_meta with a dictionary of new metadata. + + This method updates the _custom_meta dictionary with the provided metadata. + Existing keys will be overwritten with new values. + + Args: + new_meta: Dictionary of new metadata + + Raises: + ValueError: If any key in new_meta is not in global_indexes + """ + + if new_meta is None: + return non_exist_global_indexes = set(new_meta.keys()) - set(self.global_indexes) if bool(non_exist_global_indexes): - raise ValueError(f"Trying to update custom_meta with non-exist global_indexes! {non_exist_global_indexes} " - f"do not exist in this batch.") + raise ValueError( + f"Trying to update _custom_meta with non-exist global_indexes! {non_exist_global_indexes} " + f"do not exist in this batch." + ) - self.custom_meta.update(new_meta) + self._custom_meta.update(new_meta) def clear_custom_meta(self) -> None: - """Clear custom_meta""" - self.custom_meta.clear() + """ + Clear all _custom_meta. 
- def set_custom_backend_meta(self, key: int, value: dict[str, Any]) -> None: - """Set custom_meta by key""" - if key not in self.global_indexes: - raise ValueError(f"key {key} not found in global_indexes {self.global_indexes}.") + This method removes all entries from the _custom_meta dictionary. + """ + self._custom_meta.clear() + + def set_custom_backend_meta(self, global_index: int, meta_dict: dict[str, Any]) -> None: + """ + Set _custom_backend_meta for a specific sample by global_index. - self._custom_backend_meta[key] = copy.deepcopy(value) + Custom backend metadata is internal metadata for storage backends, + stored at per-sample per-field level. This is typically used by + storage backends to store backend-specific information. + + Args: + global_index: The global_index of the sample to set backend meta for + meta_dict: Dictionary mapping field names to backend metadata + + Raises: + ValueError: If the key is not in global_indexes + """ + + if global_index not in self.global_indexes: + raise ValueError(f"key {global_index} not found in global_indexes {self.global_indexes}.") + + self._custom_backend_meta[global_index] = copy.deepcopy(meta_dict) def get_all_custom_backend_meta(self) -> dict[int, dict[str, Any]]: - """Get the entire _custom_backend_meta dictionary""" + """ + Get all _custom_backend_meta as a dictionary. + + Returns: + A deep copy of the _custom_backend_meta dictionary + """ + return copy.deepcopy(self._custom_backend_meta) def update_custom_backend_meta(self, new_meta: Optional[dict[int, dict[str, Any]]]): - """Update _custom_backend_meta with a new dictionary""" + """ + Update _custom_backend_meta with a dictionary of new metadata. + + This method updates the _custom_backend_meta dictionary with the provided metadata. + Existing keys will be overwritten with new values. + + Args: + new_meta: Dictionary of new metadata + + Raises: + ValueError: If any key in new_meta is not in global_indexes + """ + + if new_meta is None: + return non_exist_global_indexes = set(new_meta.keys()) - set(self.global_indexes) if bool(non_exist_global_indexes): - raise ValueError(f"Trying to update _custom_backend_meta with non-exist global_indexes! " - f"{non_exist_global_indexes} do not exist in this batch.") + raise ValueError( + f"Trying to update _custom_backend_meta with non-exist global_indexes! " + f"{non_exist_global_indexes} do not exist in this batch." + ) if new_meta: self._custom_backend_meta.update(new_meta) def clear_custom_backend_meta(self) -> None: - """Clear _custom_backend_meta""" + """ + Clear all _custom_backend_meta. + + This method removes all entries from the _custom_backend_meta dictionary. + + """ self._custom_backend_meta.clear() def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta": @@ -449,7 +554,9 @@ def chunk(self, chunks: int) -> list["BatchMeta"]: start = end return chunk_list - def chunk_by_partition(self, ) -> list["BatchMeta"]: + def chunk_by_partition( + self, + ) -> list["BatchMeta"]: """ Split this batch into smaller chunks according to partition_ids. 
@@ -458,7 +565,7 @@ def chunk_by_partition(self, ) -> list["BatchMeta"]: """ grouped_global_indexes = defaultdict(list) - for partition_id, global_index in zip(self.partition_ids, self.global_indexes): + for partition_id, global_index in zip(self.partition_ids, self.global_indexes, strict=False): grouped_global_indexes[partition_id].append(global_index) chunk_list = [self.select_samples(samples) for samples in grouped_global_indexes.values()] From dfb07e5aa9562b3babb3e38a705a0ced7f0b2a4d Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 30 Jan 2026 18:53:29 +0800 Subject: [PATCH 04/23] add CI Signed-off-by: 0oshowero0 --- tests/test_client.py | 48 +++++++ tests/test_controller.py | 153 ++++++++++++++++++++ tests/test_controller_data_partitions.py | 171 +++++++++++++++++++---- transfer_queue/controller.py | 1 + 4 files changed, 347 insertions(+), 26 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index b9e478a..5bb97d1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -774,3 +774,51 @@ async def test_sync_and_async_methods_mixed_usage(client_setup): assert async_data is not None print("✓ Mixed async and sync method calls work correctly") + + +# ===================================================== +# Custom Meta Interface Tests +# ===================================================== + + +class TestClientCustomMetaInterface: + """Tests for client custom_meta interface methods.""" + + def test_set_custom_meta_sync(self, client_setup): + """Test synchronous set_custom_meta method.""" + client, _, _ = client_setup + + # Test synchronous set_custom_meta + + # First get metadata + metadata = client.get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0") + # Set custom_meta on the metadata + metadata.update_custom_meta( + { + 0: {"input_ids": {"token_count": 100}}, + 1: {"input_ids": {"token_count": 120}}, + } + ) + + # Call set_custom_meta with metadata (BatchMeta) + client.set_custom_meta(metadata) + print("✓ set_custom_meta sync method works") + + @pytest.mark.asyncio + async def test_set_custom_meta_async(self, client_setup): + """Test asynchronous async_set_custom_meta method.""" + client, _, _ = client_setup + + # First get metadata + metadata = await client.async_get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0") + # Set custom_meta on the metadata + metadata.update_custom_meta( + { + 0: {"input_ids": {"token_count": 100}}, + 1: {"input_ids": {"token_count": 120}}, + } + ) + + # Call async_set_custom_meta with metadata (BatchMeta) + await client.async_set_custom_meta(metadata) + print("✓ async_set_custom_meta async method works") diff --git a/tests/test_controller.py b/tests/test_controller.py index 354f292..192395d 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -450,3 +450,156 @@ def test_controller_clear_meta(self, ray_setup): assert set(partition_after.global_indexes) == set([4, 5, 7]) print("✓ Clear meta correct") + + +class TestTransferQueueControllerCustomMeta: + """Integration tests for TransferQueueController custom_meta and custom_backend_meta methods. 
+ + Note: In this codebase: + - custom_meta: per-sample metadata (simple key-value pairs per sample) + - custom_backend_meta: per-sample per-field metadata (stored via update_production_status) + """ + + def test_controller_with_custom_meta(self, ray_setup): + """Test TransferQueueController with custom_backend_meta and custom_meta functionality""" + + batch_size = 3 + partition_id = "custom_meta_test" + + tq_controller = TransferQueueController.remote() + + # Create metadata in insert mode + data_fields = ["prompt_ids", "attention_mask"] + metadata = ray.get( + tq_controller.get_metadata.remote( + data_fields=data_fields, + batch_size=batch_size, + partition_id=partition_id, + mode="insert", + ) + ) + + assert metadata.global_indexes == list(range(batch_size)) + + # Build custom_backend_meta (per-sample per-field metadata) + custom_backend_meta = { + 0: {"prompt_ids": {"token_count": 100}, "attention_mask": {"mask_ratio": 0.1}}, + 1: {"prompt_ids": {"token_count": 120}, "attention_mask": {"mask_ratio": 0.15}}, + 2: {"prompt_ids": {"token_count": 90}, "attention_mask": {"mask_ratio": 0.12}}, + } + + # Update production status with custom_backend_meta + dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in metadata.global_indexes} + shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,)} for k in metadata.global_indexes} + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=partition_id, + global_indexes=metadata.global_indexes, + field_names=metadata.field_names, + dtypes=dtypes, + shapes=shapes, + custom_meta=custom_backend_meta, + ) + ) + assert success + + # Get partition snapshot and verify custom_backend_meta is stored + partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) + assert partition is not None + + # Verify custom_backend_meta via get_field_custom_backend_meta + result = partition.get_field_custom_backend_meta(list(range(batch_size)), ["prompt_ids", "attention_mask"]) + assert len(result) == batch_size + assert result[0]["prompt_ids"]["token_count"] == 100 + assert result[2]["attention_mask"]["mask_ratio"] == 0.12 + + print("✓ Controller set custom_backend_meta via update_production_status correct") + + # Now set custom_meta (per-sample metadata) + # Format: {partition_id: {global_index: custom_meta_dict}} + custom_meta = { + partition_id: { + 0: {"sample_score": 0.9, "quality": "high"}, + 1: {"sample_score": 0.8, "quality": "medium"}, + # You can set partial samples with custom_meta. 
+ } + } + + # Verify set_custom_meta method exists and can be called + ray.get(tq_controller.set_custom_meta.remote(partition_custom_meta=custom_meta)) + + # Verify via partition snapshot + partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) + result = partition.get_custom_meta([0, 1]) + assert 0 in result + assert result[0]["sample_score"] == 0.9 + assert result[0]["quality"] == "high" + assert 1 in result + assert result[1]["sample_score"] == 0.8 + assert 2 not in result + + # Init another partition + new_partition_id = "custom_meta_test2" + # Create metadata in insert mode + data_fields = ["prompt_ids", "attention_mask"] + new_metadata = ray.get( + tq_controller.get_metadata.remote( + data_fields=data_fields, + batch_size=batch_size, + partition_id=new_partition_id, + mode="insert", + ) + ) + + # Update production status + dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in new_metadata.global_indexes} + shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,)} for k in new_metadata.global_indexes} + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=new_partition_id, + global_indexes=new_metadata.global_indexes, + field_names=new_metadata.field_names, + dtypes=dtypes, + shapes=shapes, + custom_meta=None, + ) + ) + assert success + + # Provide complicated case: update custom_meta with mixed partitions, and update previous custom_meta + new_custom_meta = { + new_partition_id: { + 3: {"sample_score": 1, "quality": "high"}, + 4: {"sample_score": 0, "quality": "low"}, + }, + partition_id: { + 2: {"sample_score": 0.7, "quality": "high"}, + 0: {"sample_score": 0.001, "quality": "low"}, + }, + } + + # update with new_custom_meta + ray.get(tq_controller.set_custom_meta.remote(partition_custom_meta=new_custom_meta)) + + # Verify via partition snapshot + partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) + result = partition.get_custom_meta([0, 1, 2]) + assert 0 in result + assert result[0]["sample_score"] == 0.001 # updated! + assert result[0]["quality"] == "low" # updated! 
+ assert 1 in result # unchanged + assert result[1]["sample_score"] == 0.8 # unchanged + assert 2 in result # unchanged + assert result[2]["sample_score"] == 0.7 # new + + new_partition = ray.get(tq_controller.get_partition_snapshot.remote(new_partition_id)) + result = new_partition.get_custom_meta([3, 4, 5]) + assert 3 in result + assert result[3]["sample_score"] == 1 + assert result[3]["quality"] == "high" + assert 4 in result + assert result[4]["sample_score"] == 0 + assert 5 not in result # 5 has not custom_meta, it will not return even we retrieve for 5 + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index 84c8b63..e5854c1 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -448,20 +448,22 @@ def test_performance_characteristics(): def test_custom_meta_in_data_partition_status(): - """Simplified tests for custom_meta functionality in DataPartitionStatus.""" + """Simple tests for custom_meta and custom_backend_meta functionality in DataPartitionStatus.""" - print("Testing simplified custom_meta in DataPartitionStatus...") + print("Testing custom_meta and custom_backend_meta in DataPartitionStatus...") from transfer_queue.controller import DataPartitionStatus partition = DataPartitionStatus(partition_id="custom_meta_test") - # Basic custom_meta storage via update_production_status + # First, set up production status global_indices = [0, 1, 2] field_names = ["input_ids", "attention_mask"] dtypes = {i: {"input_ids": "torch.int32", "attention_mask": "torch.bool"} for i in global_indices} shapes = {i: {"input_ids": (512,), "attention_mask": (512,)} for i in global_indices} - custom_meta = { + + # custom_backend_meta goes to field_custom_backend_meta (per-sample per-field metadata) + custom_backend_meta = { 0: {"input_ids": {"token_count": 100}}, 1: {"attention_mask": {"mask_ratio": 0.2}}, 2: {"input_ids": {"token_count": 300}}, @@ -472,29 +474,44 @@ def test_custom_meta_in_data_partition_status(): field_names=field_names, dtypes=dtypes, shapes=shapes, - custom_meta=custom_meta, + custom_meta=custom_backend_meta, ) assert success - # Verify some stored values - assert partition.field_custom_metas[0]["input_ids"]["token_count"] == 100 - assert partition.field_custom_metas[1]["attention_mask"]["mask_ratio"] == 0.2 + # Verify custom_backend_meta stored via get_field_custom_backend_meta + retrieved_backend = partition.get_field_custom_backend_meta([0, 1, 2], ["input_ids", "attention_mask"]) + assert 0 in retrieved_backend + assert retrieved_backend[0]["input_ids"]["token_count"] == 100 + assert 1 in retrieved_backend + assert retrieved_backend[1]["attention_mask"]["mask_ratio"] == 0.2 + + # Test set_custom_meta (goes to custom_meta, per-sample metadata) + partition.set_custom_meta({0: {"sample_score": 0.9}, 1: {"sample_score": 0.8}}) + retrieved_custom = partition.get_custom_meta([0, 1]) + assert 0 in retrieved_custom + assert retrieved_custom[0]["sample_score"] == 0.9 + assert 1 in retrieved_custom + assert retrieved_custom[1]["sample_score"] == 0.8 + + # Clearing a sample should remove both custom_meta and custom_backend_meta + partition.clear_data([0], clear_consumption=True) - # Retrieval via helper for a subset of fields - retrieved = partition.get_field_custom_backend_meta([0, 1], ["input_ids", "attention_mask"]) - assert 0 in retrieved and "input_ids" in retrieved[0] - assert 1 in retrieved and "attention_mask" in 
retrieved[1] + # Verify custom_meta is cleared + result_custom = partition.get_custom_meta([0, 1]) + assert 0 not in result_custom + assert 1 in result_custom - # Clearing a sample should remove its custom_meta - partition.clear_data([0], clear_consumption=True) - assert 0 not in partition.field_custom_metas + # Verify custom_backend_meta is cleared + result_backend = partition.get_field_custom_backend_meta([0, 1, 2], ["input_ids", "attention_mask"]) + assert 0 not in result_backend + assert 2 in result_backend # Sample 2 should still be there - print("✓ Custom_meta tests passed") + print("✓ Custom_meta and custom_backend_meta tests passed") def test_update_field_metadata_variants(): - """Test _update_field_metadata handles dtypes/shapes/custom_meta being optional and merging.""" + """Test _update_field_metadata handles dtypes/shapes/custom_backend_meta being optional and merging.""" from transfer_queue.controller import DataPartitionStatus partition = DataPartitionStatus(partition_id="update_meta_test") @@ -503,29 +520,29 @@ def test_update_field_metadata_variants(): global_indices = [0, 1] dtypes = {0: {"f1": "torch.int32"}, 1: {"f1": "torch.bool"}} - partition._update_field_metadata(global_indices, dtypes, shapes=None, custom_meta=None) + partition._update_field_metadata(global_indices, dtypes, shapes=None, custom_backend_meta=None) assert partition.field_dtypes[0]["f1"] == "torch.int32" assert partition.field_dtypes[1]["f1"] == "torch.bool" assert partition.field_shapes == {} - assert partition.field_custom_metas == {} + assert partition.field_custom_backend_meta == {} # Only shapes provided for a new index - partition._update_field_metadata([2], dtypes=None, shapes={2: {"f2": (16,)}}, custom_meta=None) + partition._update_field_metadata([2], dtypes=None, shapes={2: {"f2": (16,)}}, custom_backend_meta=None) assert partition.field_shapes[2]["f2"] == (16,) - # Only custom_meta provided and merged with existing entries - partition._update_field_metadata([2], dtypes=None, shapes=None, custom_meta={2: {"f2": {"meta": 1}}}) - assert 2 in partition.field_custom_metas - assert partition.field_custom_metas[2]["f2"]["meta"] == 1 + # Only custom_backend_meta provided and merged with existing entries + partition._update_field_metadata([2], dtypes=None, shapes=None, custom_backend_meta={2: {"f2": {"meta": 1}}}) + assert 2 in partition.field_custom_backend_meta + assert partition.field_custom_backend_meta[2]["f2"]["meta"] == 1 # Merging dtypes on an existing index should preserve previous keys - partition._update_field_metadata([0], dtypes={0: {"f2": "torch.float32"}}, shapes=None, custom_meta=None) + partition._update_field_metadata([0], dtypes={0: {"f2": "torch.float32"}}, shapes=None, custom_backend_meta=None) assert partition.field_dtypes[0]["f1"] == "torch.int32" assert partition.field_dtypes[0]["f2"] == "torch.float32" # Length mismatch should raise ValueError when provided mapping lengths differ from global_indices with pytest.raises(ValueError): - partition._update_field_metadata([0, 1, 2], dtypes={0: {}}, shapes=None, custom_meta=None) + partition._update_field_metadata([0, 1, 2], dtypes={0: {}}, shapes=None, custom_backend_meta=None) def test_get_production_status_for_fields(): @@ -822,3 +839,105 @@ def test_pre_allocated_indexes_mixed_with_dynamic(): print("✓ Mixed pre-allocated and dynamic indexes work correctly") print("Mixed indexes tests passed!\n") + + +class TestDataPartitionStatusCustomMeta: + """Unit tests for DataPartitionStatus custom_meta methods.""" + + def 
test_set_custom_meta_single_partition(self): + """Test set_custom_meta sets custom metadata for samples in a partition.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="train_0") + + # Set custom_meta for specific samples + custom_meta = { + 0: {"score": 0.9, "label": "positive"}, + 1: {"score": 0.8, "label": "negative"}, + } + partition.set_custom_meta(custom_meta) + + # Verify + result = partition.get_custom_meta([0, 1, 2]) + assert 0 in result + assert result[0]["score"] == 0.9 + assert 1 in result + assert result[1]["label"] == "negative" + + def test_set_custom_meta_updates_existing(self): + """Test set_custom_meta updates existing custom metadata.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="train_0") + + # Initial custom_meta + partition.set_custom_meta({0: {"score": 0.5}}) + + # Update with new values + partition.set_custom_meta({0: {"score": 0.9, "label": "updated"}}) + + result = partition.get_custom_meta([0]) + assert result[0]["score"] == 0.9 + assert result[0]["label"] == "updated" + + def test_get_custom_meta_returns_only_requested(self): + """Test get_custom_meta only returns metadata for requested indices.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="train_0") + + partition.set_custom_meta( + { + 0: {"data": "sample_0"}, + 1: {"data": "sample_1"}, + 2: {"data": "sample_2"}, + } + ) + + # Request only specific indices + result = partition.get_custom_meta([0, 2]) + + assert 0 in result + assert 2 in result + assert 1 not in result + assert result[0]["data"] == "sample_0" + assert result[2]["data"] == "sample_2" + + def test_get_custom_meta_empty_for_missing(self): + """Test get_custom_meta returns empty dict for indices without metadata.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="train_0") + + # Set custom_meta only for sample 0 + partition.set_custom_meta({0: {"score": 0.9}}) + + # Request indices that don't have metadata + result = partition.get_custom_meta([1, 2]) + + assert 0 not in result + assert 1 not in result + assert 2 not in result + + def test_custom_meta_cleared_with_data(self): + """Test custom_meta is cleared when clearing sample data.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="train_0") + + # Set production status and custom_meta + partition.update_production_status( + global_indices=[0, 1], + field_names=["input_ids"], + dtypes={0: {"input_ids": "torch.int32"}, 1: {"input_ids": "torch.int32"}}, + shapes={0: {"input_ids": (512,)}, 1: {"input_ids": (512,)}}, + ) + partition.set_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}}) + + # Clear sample 0 + partition.clear_data([0], clear_consumption=True) + + # Verify sample 0 custom_meta is cleared + result = partition.get_custom_meta([0, 1]) + assert 0 not in result + assert 1 in result # Sample 1 should still have custom_meta diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index f69e66a..88a35ef 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -700,6 +700,7 @@ def set_custom_meta(self, custom_meta: dict[int, dict]) -> None: custom_meta: Dictionary mapping global_index to custom metadata dict. Existing entries will be overwritten. 
""" + for k in custom_meta.keys(): self.custom_meta[k] = custom_meta[k] From b6b0d00320a77e4ece96239909ebca0f7581b7d1 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 30 Jan 2026 19:20:10 +0800 Subject: [PATCH 05/23] update batchmeta & ci Signed-off-by: 0oshowero0 --- tests/test_metadata.py | 395 +++++++++++++++++++++++-------------- transfer_queue/metadata.py | 117 ++--------- 2 files changed, 267 insertions(+), 245 deletions(-) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index b3c988b..59ce567 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -683,103 +683,111 @@ def test_batch_meta_select_samples_with_extra_info(self): assert selected_batch.extra_info["number"] == 42 assert selected_batch.extra_info["list"] == [1, 2, 3] - def test_batch_meta_extra_info_operations(self): - """Example: Extra info management operations.""" + # ===================================================== + # Custom Meta Tests + # ===================================================== + + def test_batch_meta_set_custom_meta_basic(self): + """Test set_custom_meta sets metadata for a sample by global_index.""" fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + "field_b": FieldMeta(name="field_b", dtype=torch.int64, shape=(3,)), } - batch = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - - # Set and get - batch.set_extra_info("key1", "value1") - assert batch.get_extra_info("key1") == "value1" - assert batch.has_extra_info("key1") is True - - # Update multiple - batch.update_extra_info({"key2": "value2", "key3": "value3"}) - assert batch.has_extra_info("key2") is True + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + ] + batch = BatchMeta(samples=samples) - # Remove - removed = batch.remove_extra_info("key2") - assert removed == "value2" - assert batch.has_extra_info("key2") is False + # Set custom_meta for sample 0 + batch.set_custom_meta(0, {"sample_score": 0.9, "quality": "high"}) - # Get all - all_info = batch.get_all_extra_info() - assert "key1" in all_info - assert "key3" in all_info + result = batch.get_all_custom_meta() + assert 0 in result + assert result[0]["sample_score"] == 0.9 + assert result[0]["quality"] == "high" + # Sample 1 should not have custom_meta + assert 1 not in result + + def test_batch_meta_set_custom_meta_overwrites(self): + """Test set_custom_meta overwrites existing metadata.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = BatchMeta(samples=samples) - # Clear - batch.clear_extra_info() - assert len(batch.extra_info) == 0 + # Set initial custom_meta + batch.set_custom_meta(0, {"sample_score": 0.9, "quality": "high"}) + # Overwrite with new custom_meta + batch.set_custom_meta(0, {"sample_score": 0.1, "quality": "low"}) -class TestEdgeCases: - """Edge cases and important boundaries.""" + result = batch.get_all_custom_meta() + assert result[0]["sample_score"] == 0.1 + assert result[0]["quality"] == "low" - def test_batch_meta_chunk_with_more_chunks_than_samples(self): - """Example: Chunking when chunks > samples produces empty chunks.""" + def 
test_batch_meta_set_custom_meta_invalid_global_index(self): + """Test set_custom_meta raises error for invalid global_index.""" fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), } samples = [ SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), ] batch = BatchMeta(samples=samples) - # 5 chunks for 2 samples - chunks = batch.chunk(5) - - assert len(chunks) == 5 - # First 2 chunks have samples - assert len(chunks[0]) == 1 - assert len(chunks[1]) == 1 - # Last 3 chunks are empty - assert len(chunks[2]) == 0 - assert len(chunks[3]) == 0 - assert len(chunks[4]) == 0 + # Try to set with non-existent global index + with pytest.raises(ValueError) as exc_info: + batch.set_custom_meta(999, {"sample_score": 0.9}) + assert "not found in global_indexes" in str(exc_info.value) - def test_batch_meta_concat_with_empty_batches(self): - """Example: Concat handles empty batches gracefully.""" + def test_batch_meta_update_custom_meta(self): + """Test update_custom_meta adds metadata for different global indices.""" fields = { - "test_field": FieldMeta( - name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME - ) + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + "field_b": FieldMeta(name="field_b", dtype=torch.int64, shape=(3,)), } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + ] + batch = BatchMeta(samples=samples) - batch1 = BatchMeta(samples=[]) - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) - batch3 = BatchMeta(samples=[]) - - # Empty batches are filtered out - result = BatchMeta.concat([batch1, batch2, batch3]) - assert len(result) == 1 - assert result.global_indexes == [0] + # Initial custom_meta for sample 0 + batch.update_custom_meta({0: {"sample_score": 0.9}}) - def test_batch_meta_concat_validation_error(self): - """Example: Concat validation catches field name mismatches.""" - fields1 = {"field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,))} - fields2 = {"field2": FieldMeta(name="field2", dtype=torch.float32, shape=(2,))} + # Update with metadata for sample 1 + batch.update_custom_meta({1: {"sample_score": 0.1}}) - batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields1)]) + result = batch.get_all_custom_meta() + assert result[0]["sample_score"] == 0.9 + assert result[1]["sample_score"] == 0.1 - batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields2)]) + def test_batch_meta_update_custom_meta_overwrites(self): + """Test update_custom_meta overwrites existing metadata at same key.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = BatchMeta(samples=samples) - with pytest.raises(ValueError) as exc_info: - BatchMeta.concat([batch1, batch2], validate=True) - assert "Field names do not match" in str(exc_info.value) + # Initial custom_meta + batch.update_custom_meta({0: {"sample_score": 0.9, "quality": "high"}}) + # Update with new value for same field - dict.update 
replaces + batch.update_custom_meta({0: {"sample_score": 0.1, "quality": "low"}}) -class TestCustomMeta: - """Unit tests for BatchMeta custom meta methods.""" + result = batch.get_all_custom_meta() + assert result[0]["sample_score"] == 0.1 + assert result[0]["quality"] == "low" - def test_get_all_custom_meta_returns_deep_copy(self): - """Test get_all_custom_meta returns a deep copy of the custom meta dict.""" + def test_batch_meta_update_custom_meta_with_none(self): + """Test update_custom_meta with None does nothing.""" fields = { "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), } @@ -788,20 +796,17 @@ def test_get_all_custom_meta_returns_deep_copy(self): ] batch = BatchMeta(samples=samples) - custom_meta = {0: {"field_a": {"nested": "value"}}} - batch.update_custom_meta(custom_meta) - - # Get all custom meta - result = batch.get_all_custom_meta() + # Set initial value + batch.update_custom_meta({0: {"sample_score": 0.9}}) - # Verify it's a deep copy - modifying result should not affect original - result[0]["field_a"]["nested"] = "modified" + # Update with None should not change anything + batch.update_custom_meta(None) - original = batch.get_all_custom_meta() - assert original[0]["field_a"]["nested"] == "value" + result = batch.get_all_custom_meta() + assert result[0]["sample_score"] == 0.9 - def test_get_all_custom_meta_empty(self): - """Test get_all_custom_meta with no custom meta returns empty dict.""" + def test_batch_meta_update_custom_meta_with_empty_dict(self): + """Test update_custom_meta with empty dict does nothing.""" fields = { "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), } @@ -810,74 +815,77 @@ def test_get_all_custom_meta_empty(self): ] batch = BatchMeta(samples=samples) - result = batch.get_all_custom_meta() + # Set initial value + batch.update_custom_meta({0: {"sample_score": 0.9}}) - assert result == {} + # Update with empty dict should not change anything + batch.update_custom_meta({}) + + result = batch.get_all_custom_meta() + assert result[0]["sample_score"] == 0.9 - def test_update_custom_meta_basic(self): - """Test update_custom_meta adds new entries.""" + def test_batch_meta_update_custom_meta_invalid_global_index(self): + """Test update_custom_meta raises error for invalid global_index.""" fields = { "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), } samples = [ SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), ] batch = BatchMeta(samples=samples) - # Update with custom meta - custom_meta = { - 0: {"field_a": "value_0"}, - 1: {"field_a": "value_1"}, - } - batch.update_custom_meta(custom_meta) - - result = batch.get_all_custom_meta() - assert result[0]["field_a"] == "value_0" - assert result[1]["field_a"] == "value_1" + # Try to update with non-existent global index + with pytest.raises(ValueError) as exc_info: + batch.update_custom_meta({999: {"sample_score": 0.9}}) + assert "non-exist global_indexes" in str(exc_info.value) - def test_update_custom_meta_overwrites_existing(self): - """Test update_custom_meta overwrites existing entries at the top level.""" + def test_batch_meta_clear_custom_meta(self): + """Test clear_custom_meta removes all custom metadata.""" fields = { "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), } samples = [ SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields), ] 
batch = BatchMeta(samples=samples) - # Initial custom meta - batch.update_custom_meta({0: {"field_a": "original"}}) + # Set custom_meta + batch.set_custom_meta(0, {"sample_score": 0.9}) + batch.set_custom_meta(1, {"sample_score": 0.1}) - # Update with new value - dict.update replaces the entire value for key 0 - batch.update_custom_meta({0: {"field_a": "updated"}}) + # Clear all + batch.clear_custom_meta() result = batch.get_all_custom_meta() - assert result[0]["field_a"] == "updated" + assert result == {} - def test_update_custom_meta_merges_different_keys(self): - """Test update_custom_meta merges different top-level keys.""" + def test_batch_meta_get_all_custom_meta_returns_deep_copy(self): + """Test get_all_custom_meta returns a deep copy.""" fields = { "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), } samples = [ SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), ] batch = BatchMeta(samples=samples) - # First update - batch.update_custom_meta({0: {"field_a": "value_0"}}) - - # Second update with different key - batch.update_custom_meta({1: {"field_a": "value_1"}}) + custom_meta = {0: {"sample_score": 0.9, "nested": {"value": 1}}} + batch.update_custom_meta(custom_meta) + # Get all custom_meta result = batch.get_all_custom_meta() - assert result[0]["field_a"] == "value_0" - assert result[1]["field_a"] == "value_1" - def test_update_custom_meta_with_none(self): - """Test update_custom_meta with None does nothing.""" + # Verify it's a deep copy - modifying result should not affect original + result[0]["sample_score"] = 0.1 + result[0]["nested"]["value"] = 999 + + original = batch.get_all_custom_meta() + assert original[0]["sample_score"] == 0.9 + assert original[0]["nested"]["value"] == 1 + + def test_batch_meta_get_all_custom_meta_empty(self): + """Test get_all_custom_meta with no custom_meta returns empty dict.""" fields = { "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), } @@ -886,17 +894,11 @@ def test_update_custom_meta_with_none(self): ] batch = BatchMeta(samples=samples) - # Set initial value - batch.update_custom_meta({0: {"field_a": "value"}}) - - # Update with None should not change anything - batch.update_custom_meta(None) - result = batch.get_all_custom_meta() - assert result[0]["field_a"] == "value" + assert result == {} - def test_update_custom_meta_with_empty_dict(self): - """Test update_custom_meta with empty dict does nothing.""" + def test_batch_meta_custom_meta_with_nested_data(self): + """Test custom_meta supports nested dictionary data.""" fields = { "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), } @@ -905,38 +907,135 @@ def test_update_custom_meta_with_empty_dict(self): ] batch = BatchMeta(samples=samples) - # Set initial value - batch.update_custom_meta({0: {"field_a": "value"}}) - - # Update with empty dict should not change anything - batch.update_custom_meta({}) + nested_meta = { + "model_info": {"name": "llama", "version": "7b", "config": {"hidden_size": 4096, "num_layers": 32}}, + "tags": ["training", "inference"], + } + batch.set_custom_meta(0, nested_meta) result = batch.get_all_custom_meta() - assert result[0]["field_a"] == "value" + assert result[0]["model_info"]["name"] == "llama" + assert result[0]["model_info"]["version"] == "7b" + assert result[0]["model_info"]["config"]["hidden_size"] == 4096 + assert result[0]["tags"] == ["training", "inference"] - def 
test_custom_meta_with_complex_values(self): - """Test custom meta can store complex values like dicts, lists, tensors.""" + # ===================================================== + # Extra Info Methods Tests + # ===================================================== + + def test_batch_meta_update_extra_info(self): + """Test update_extra_info adds multiple values.""" fields = { - "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + "test_field": FieldMeta( + name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + ) + } + batch = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) + + # Update with multiple values + batch.update_extra_info({"key1": "value1", "key2": "value2", "key3": "value3"}) + + # Verify all exist + assert "key1" in batch.extra_info + assert "key2" in batch.extra_info + assert "key3" in batch.extra_info + assert batch.extra_info["key1"] == "value1" + assert batch.extra_info["key2"] == "value2" + + def test_batch_meta_extra_info_preserved_in_operations(self): + """Test extra_info is preserved in batch operations.""" + fields = { + "test_field": FieldMeta( + name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + ) + } + batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) + batch1.extra_info["test_key1"] = "test_value" + + batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) + batch2.extra_info["test_key2"] = "test_value_2" + + result = BatchMeta.concat([batch1, batch2]) + + # Extra info is preserved + assert "test_key1" in result.extra_info + + def test_batch_meta_extra_info_with_concat(self): + """Test extra_info handling in concat with mixed types.""" + fields = { + "test_field": FieldMeta( + name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + ) + } + + batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) + batch1.extra_info["string"] = "hello" + batch1.extra_info["number"] = 42 + + batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields)]) + batch2.extra_info["string"] = "world" + batch2.extra_info["number"] = 100 + + result = BatchMeta.concat([batch1, batch2]) + + # String: last value wins + assert result.extra_info["string"] == "world" + + +class TestEdgeCases: + """Edge cases and important boundaries.""" + + def test_batch_meta_chunk_with_more_chunks_than_samples(self): + """Example: Chunking when chunks > samples produces empty chunks.""" + fields = { + "test_field": FieldMeta( + name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + ) } samples = [ SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields), ] batch = BatchMeta(samples=samples) - # Store complex values - custom_meta = { - 0: { - "field_a": { - "nested_dict": {"key": "value"}, - "list": [1, 2, 3], - "number": 42, - } - } + # 5 chunks for 2 samples + chunks = batch.chunk(5) + + assert len(chunks) == 5 + # First 2 chunks have samples + assert len(chunks[0]) == 1 + assert len(chunks[1]) == 1 + # Last 3 chunks are empty + assert len(chunks[2]) == 0 + assert len(chunks[3]) == 0 + assert len(chunks[4]) == 0 + + def test_batch_meta_concat_with_empty_batches(self): + """Example: 
Concat handles empty batches gracefully.""" + fields = { + "test_field": FieldMeta( + name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + ) } - batch.update_custom_meta(custom_meta) - result = batch.get_all_custom_meta() - assert result[0]["field_a"]["nested_dict"]["key"] == "value" - assert result[0]["field_a"]["list"] == [1, 2, 3] - assert result[0]["field_a"]["number"] == 42 + batch1 = BatchMeta(samples=[]) + batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields)]) + batch3 = BatchMeta(samples=[]) + + # Empty batches are filtered out + result = BatchMeta.concat([batch1, batch2, batch3]) + assert len(result) == 1 + assert result.global_indexes == [0] + + def test_batch_meta_concat_validation_error(self): + """Example: Concat validation catches field name mismatches.""" + fields1 = {"field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,))} + fields2 = {"field2": FieldMeta(name="field2", dtype=torch.float32, shape=(2,))} + + batch1 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=0, fields=fields1)]) + + batch2 = BatchMeta(samples=[SampleMeta(partition_id="partition_0", global_index=1, fields=fields2)]) + + with pytest.raises(ValueError) as exc_info: + BatchMeta.concat([batch1, batch2], validate=True) + assert "Field names do not match" in str(exc_info.value) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index a47991d..d701269 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -137,7 +137,7 @@ def select_fields(self, field_names: list[str]) -> "SampleMeta": selected_fields = {name: self.fields[name] for name in field_names if name in self.fields} # construct new SampleMeta instance - # TODO(tianyi): (maybe) move _custom_backend_meta and _custom_meta to FieldMeta level? + # TODO(tianyi): (maybe) move _custom_backend_meta and custom_meta to FieldMeta level? 
selected_sample_meta = SampleMeta( fields=selected_fields, partition_id=self.partition_id, @@ -205,8 +205,8 @@ class BatchMeta: # external meta for non-sample level information extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) - # internal user-defined meta for each sample - _custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) + # user-defined meta for each sample + custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) # internal meta for different storage backends in per-sample per-field level _custom_backend_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) @@ -215,7 +215,7 @@ def __post_init__(self): """Initialize all computed properties during initialization""" self.samples = copy.deepcopy(self.samples) self.extra_info = copy.deepcopy(self.extra_info) - self._custom_meta = copy.deepcopy(self._custom_meta) + self.custom_meta = copy.deepcopy(self.custom_meta) self._custom_backend_meta = copy.deepcopy(self._custom_backend_meta) # Basic properties @@ -237,15 +237,15 @@ def __post_init__(self): object.__setattr__(self, "_partition_ids", [sample.partition_id for sample in self.samples]) - # filter _custom_meta and _custom_backend_meta - self._custom_meta = copy.deepcopy( - {k: self._custom_meta[k] for k in self.global_indexes if k in self._custom_meta} + # filter custom_meta and _custom_backend_meta + self.custom_meta = copy.deepcopy( + {k: self.custom_meta[k] for k in self.global_indexes if k in self.custom_meta} ) self._custom_backend_meta = copy.deepcopy( {k: self._custom_backend_meta[k] for k in self.global_indexes if k in self._custom_backend_meta} ) else: - self._custom_meta = {} + self.custom_meta = {} self._custom_backend_meta = {} object.__setattr__(self, "_global_indexes", []) object.__setattr__(self, "_field_names", []) @@ -277,16 +277,6 @@ def partition_ids(self) -> list[str]: """Get partition ids for all samples in this batch as a list (one per sample)""" return getattr(self, "_partition_ids", []) - def set_extra_info(self, key: str, value: Any) -> None: - """ - Set extra_info value for a specific key. - - Args: - key: The key to set in extra_info - value: The value to associate with the key - """ - self.extra_info[key] = value - def get_all_extra_info(self) -> dict[str, Any]: """Get all extra_info as a dictionary (deep copy for immutability). @@ -317,7 +307,7 @@ def clear_extra_info(self) -> None: def set_custom_meta(self, global_index: int, meta_dict: dict[str, Any]) -> None: """ - Set _custom_meta for a specific sample by global_index. + Set custom_meta for a specific sample by global_index. Custom metadata is user-defined per-sample metadata that can be stored and retrieved along with the BatchMeta. @@ -333,22 +323,22 @@ def set_custom_meta(self, global_index: int, meta_dict: dict[str, Any]) -> None: if global_index not in self.global_indexes: raise ValueError(f"key {global_index} not found in global_indexes {self.global_indexes}.") - self._custom_meta[global_index] = copy.deepcopy(meta_dict) + self.custom_meta[global_index] = copy.deepcopy(meta_dict) def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: """ - Get all _custom_meta as a dictionary. + Get all custom_meta as a dictionary. 
Returns: - A deep copy of the _custom_meta dictionary + A deep copy of the custom_meta dictionary """ - return copy.deepcopy(self._custom_meta) + return copy.deepcopy(self.custom_meta) def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]): """ - Update _custom_meta with a dictionary of new metadata. + Update custom_meta with a dictionary of new metadata. - This method updates the _custom_meta dictionary with the provided metadata. + This method updates the custom_meta dictionary with the provided metadata. Existing keys will be overwritten with new values. Args: @@ -364,86 +354,19 @@ def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]): non_exist_global_indexes = set(new_meta.keys()) - set(self.global_indexes) if bool(non_exist_global_indexes): raise ValueError( - f"Trying to update _custom_meta with non-exist global_indexes! {non_exist_global_indexes} " + f"Trying to update custom_meta with non-exist global_indexes! {non_exist_global_indexes} " f"do not exist in this batch." ) - self._custom_meta.update(new_meta) + self.custom_meta.update(new_meta) def clear_custom_meta(self) -> None: """ - Clear all _custom_meta. - - This method removes all entries from the _custom_meta dictionary. - """ - self._custom_meta.clear() - - def set_custom_backend_meta(self, global_index: int, meta_dict: dict[str, Any]) -> None: - """ - Set _custom_backend_meta for a specific sample by global_index. - - Custom backend metadata is internal metadata for storage backends, - stored at per-sample per-field level. This is typically used by - storage backends to store backend-specific information. - - Args: - global_index: The global_index of the sample to set backend meta for - meta_dict: Dictionary mapping field names to backend metadata - - Raises: - ValueError: If the key is not in global_indexes - """ - - if global_index not in self.global_indexes: - raise ValueError(f"key {global_index} not found in global_indexes {self.global_indexes}.") - - self._custom_backend_meta[global_index] = copy.deepcopy(meta_dict) - - def get_all_custom_backend_meta(self) -> dict[int, dict[str, Any]]: - """ - Get all _custom_backend_meta as a dictionary. - - Returns: - A deep copy of the _custom_backend_meta dictionary - """ - - return copy.deepcopy(self._custom_backend_meta) - - def update_custom_backend_meta(self, new_meta: Optional[dict[int, dict[str, Any]]]): - """ - Update _custom_backend_meta with a dictionary of new metadata. - - This method updates the _custom_backend_meta dictionary with the provided metadata. - Existing keys will be overwritten with new values. - - Args: - new_meta: Dictionary of new metadata - - Raises: - ValueError: If any key in new_meta is not in global_indexes - """ - - if new_meta is None: - return - - non_exist_global_indexes = set(new_meta.keys()) - set(self.global_indexes) - if bool(non_exist_global_indexes): - raise ValueError( - f"Trying to update _custom_backend_meta with non-exist global_indexes! " - f"{non_exist_global_indexes} do not exist in this batch." - ) - - if new_meta: - self._custom_backend_meta.update(new_meta) - - def clear_custom_backend_meta(self) -> None: - """ - Clear all _custom_backend_meta. - - This method removes all entries from the _custom_backend_meta dictionary. + Clear all custom_meta. + This method removes all entries from the custom_meta dictionary. 
""" - self._custom_backend_meta.clear() + self.custom_meta.clear() def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta": """ From 098b7d14fdccd21d1547f3a560443aeb8fc05e87 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 30 Jan 2026 19:33:43 +0800 Subject: [PATCH 06/23] fix Signed-off-by: 0oshowero0 --- tests/test_client.py | 3 +++ tests/test_metadata.py | 27 +++++++++++++++++++++++++++ transfer_queue/controller.py | 2 +- 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/test_client.py b/tests/test_client.py index 5bb97d1..604ced8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -134,6 +134,9 @@ def _handle_requests(self): "partition_ids": ["partition_0", "partition_1", "test_partition"], } response_type = ZMQRequestType.LIST_PARTITIONS_RESPONSE + elif request_msg.request_type == ZMQRequestType.SET_CUSTOM_META: + response_body = {"message": "success"} + response_type = ZMQRequestType.SET_CUSTOM_META_RESPONSE else: response_body = {"error": f"Unknown request type: {request_msg.request_type}"} response_type = ZMQRequestType.CLEAR_META_RESPONSE diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 59ce567..f74c54b 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -188,6 +188,33 @@ def test_batch_meta_chunk(self): assert len(chunks[1]) == 3 assert len(chunks[2]) == 3 + def test_batch_meta_chunk_by_partition(self): + """Example: Split a batch into multiple chunks.""" + fields = { + "test_field": FieldMeta( + name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME + ) + } + samples = [SampleMeta(partition_id=f"partition_{i % 4}", global_index=i, fields=fields) for i in range(10)] + batch = BatchMeta(samples=samples) + + # Chunk according to partition_id + chunks = batch.chunk_by_partition() + + assert len(chunks) == 4 + assert len(chunks[0]) == 3 + assert chunks[0].partition_ids == ["partition_0", "partition_0", "partition_0"] + assert chunks[0].global_indexes == [0, 4, 8] + assert len(chunks[1]) == 3 + assert chunks[1].partition_ids == ["partition_1", "partition_1", "partition_1"] + assert chunks[1].global_indexes == [1, 5, 9] + assert len(chunks[2]) == 2 + assert chunks[2].partition_ids == ["partition_2", "partition_2"] + assert chunks[2].global_indexes == [2, 6] + assert len(chunks[3]) == 2 + assert chunks[3].partition_ids == ["partition_3", "partition_3"] + assert chunks[3].global_indexes == [3, 7] + def test_batch_meta_init_validation_error_different_field_names(self): """Example: Init validation catches samples with different field names.""" # Create first sample with field1 diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 88a35ef..cd482a4 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1313,7 +1313,7 @@ def generate_batch_meta( batch_meta = BatchMeta(samples=samples) batch_meta.update_custom_meta(custom_meta) - batch_meta.update_custom_backend_meta(custom_backend_meta) + batch_meta._custom_backend_meta.update(custom_backend_meta) return batch_meta def clear_partition(self, partition_id: str, clear_consumption: bool = True): From 02e1b4824943e4fc6db8a1bd0725a87155a23ca8 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 31 Jan 2026 09:55:17 +0800 Subject: [PATCH 07/23] add tutorial & update CI Signed-off-by: 0oshowero0 --- .github/workflows/tutorial-check.yml | 38 +++++++++++++++++++ transfer_queue/client.py | 2 +- tutorial/02_metadata_concepts.py | 56 
+++++++++++++++++++--------- 3 files changed, 78 insertions(+), 18 deletions(-) create mode 100644 .github/workflows/tutorial-check.yml diff --git a/.github/workflows/tutorial-check.yml b/.github/workflows/tutorial-check.yml new file mode 100644 index 0000000..afd42ca --- /dev/null +++ b/.github/workflows/tutorial-check.yml @@ -0,0 +1,38 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Tutorial check + +on: + push: + branches: + - main + - v0.* + pull_request: + branches: + - main + - v0.* + +jobs: + build: + runs-on: ubuntu-latest + timeout-minutes: 10 + strategy: + fail-fast: false + matrix: + python-version: ["3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu + pip install -e ".[yuanrong]" + - name: Run tutorials + run: | + for file in tutorial/*.py; do python3 "$file"; done \ No newline at end of file diff --git a/transfer_queue/client.py b/transfer_queue/client.py index fe32b72..3322bce 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -284,7 +284,7 @@ async def async_set_custom_meta( custom_meta = metadata.get_all_custom_meta() if len(global_indexes) == 0 or len(custom_meta) == 0: - logger.warning(f"[{self.client_id}]: Empty BatchMeta or custom_meta provided. No action taken.") + logger.debug(f"[{self.client_id}]: Empty BatchMeta or custom_meta provided. 
No action taken.") return # chunk metadata according to partition_ids diff --git a/tutorial/02_metadata_concepts.py b/tutorial/02_metadata_concepts.py index a2c6d59..bb11f9f 100644 --- a/tutorial/02_metadata_concepts.py +++ b/tutorial/02_metadata_concepts.py @@ -16,6 +16,7 @@ import os import sys import textwrap +import uuid import warnings from pathlib import Path @@ -203,37 +204,45 @@ def demonstrate_batch_meta(): print(f" Size: {batch.size}") # Example 2: Add extra_info - print("[Example 2] Adding batch-level information...") - batch.set_extra_info("epoch", 1) - batch.set_extra_info("batch_idx", 0) + print("[Example 2] Adding batch-level information through extra_info...") + print("Note: The extra info will not be stored into TransferQueueController.") + batch.extra_info["epoch"] = 1 + batch.extra_info["batch_idx"] = 0 print(f"✓ Extra info: {batch.get_all_extra_info()}") - # Example 3: Chunk a batch - print("[Example 3] Chunking a batch into parts...") + print("[Example 3] Adding sample-level information through custom_meta...") + batch.set_custom_meta( + global_index=0, meta_dict={"uid": "prompt@0", "session_id": "session@0", "model_version": "epoch@0"} + ) + batch.update_custom_meta({1: {"uid": "prompt@1", "session_id": "session@0", "model_version": "epoch@0"}}) + print(f"✓ Extra info: {batch.get_all_custom_meta()}") + + # Example 4: Chunk a batch + print("[Example 4] Chunking a batch into parts...") chunks = batch.chunk(3) print(f"✓ Split into {len(chunks)} chunks:") for i, chunk in enumerate(chunks): print(f" Chunk {i}: {len(chunk)} samples, indexes={chunk.global_indexes}") - # Example 4: Select specific fields - print("[Example 4] Selecting specific fields...") + # Example 5: Select specific fields + print("[Example 5] Selecting specific fields...") selected_batch = batch.select_fields(["input_ids", "responses"]) print(f"✓ Selected fields: {selected_batch.field_names}") print(f" Original fields: {batch.field_names}") - # Example 5: Select specific samples - print("[Example 5] Selecting specific samples...") + # Example 6: Select specific samples + print("[Example 6] Selecting specific samples...") selected_samples = batch.select_samples([0, 2, 4]) print(f"✓ Selected samples at indexes: {selected_samples.global_indexes}") - # Example 6: Reorder samples - print("[Example 6] Reordering samples...") + # Example 7: Reorder samples + print("[Example 7] Reordering samples...") print(f" Original order: {batch.global_indexes}") batch.reorder([4, 3, 2, 1, 0]) print(f" After reorder: {batch.global_indexes}") - # Example 7: Concat batches - print("[Example 7] Concatenating batches...") + # Example 8: Concat batches + print("[Example 8] Concatenating batches...") batch1 = BatchMeta(samples=[SampleMeta(partition_id="train_0", global_index=i, fields=fields) for i in range(3)]) batch2 = BatchMeta(samples=[SampleMeta(partition_id="train_0", global_index=i, fields=fields) for i in range(3, 6)]) concatenated = BatchMeta.concat([batch1, batch2]) @@ -241,8 +250,8 @@ def demonstrate_batch_meta(): print(f" Global indexes: {concatenated.global_indexes}") print(" Note: concat combines multiple batches into one (same structure)") - # Example 8: Union batches - print("[Example 8] Unioning batches (different fields, same samples)...") + # Example 9: Union batches + print("[Example 9] Unioning batches (different fields, same samples)...") batch_with_input = BatchMeta( samples=[ SampleMeta( @@ -342,6 +351,18 @@ def demonstrate_real_workflow(): batch_meta = client.put(data=data_batch, partition_id=partition_id) 
print(f"✓ Put {data_batch.batch_size[0]} samples into partition '{partition_id}', got BatchMeta back {batch_meta}.") + print("[Step 2] [Optional] Setting sample-level custom_meta...") + + custom_meta = { + global_index: {"uid": uuid.uuid4().hex[:4], "session_id": uuid.uuid4().hex[:4], "model_version": 0} + for global_index in batch_meta.global_indexes + } + batch_meta.set_custom_meta(custom_meta) + print(f"✓ Set custom_meta into BatchMeta: {batch_meta.get_all_custom_meta()}") + + client.set_custom_meta(batch_meta) + print("✓ Successful to store custom_meta into TQ controller. Now you can retrieve the custom_meta from anywhere.") + print("[Step 2] Try to get metadata from TransferQueue from other places...") batch_meta = client.get_meta( data_fields=["input_ids", "attention_mask"], @@ -355,6 +376,7 @@ def demonstrate_real_workflow(): print(f" Field names: {batch_meta.field_names}") print(f" Partition ID: {batch_meta.samples[0].partition_id}") print(f" Sample structure: {batch_meta.samples[0]}") + print(f" Custom Meta: {batch_meta.get_all_custom_meta()}") print("[Step 3] Retrieve samples with specific fields..") selected_meta = batch_meta.select_fields(["input_ids"]) @@ -404,8 +426,8 @@ def main(): Key Concepts: - Metadata tracks data structure without storing actual data - - Production status tracks whether data is ready for consumption - - BatchMeta provides operations: chunk, concat, union, select, reorder + - User can set their own custom metadata into BatchMeta, and use TQ controller to store them. + - BatchMeta provides operations: chunk, concat, union, select, reorder... - Metadata is lightweight and can be passed around efficiently - Union requires samples to have identical partition_id and global_index """ From 2b5eb43de1ae1bae28089b1935dfb340da3c102d Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 31 Jan 2026 10:03:49 +0800 Subject: [PATCH 08/23] fix Signed-off-by: 0oshowero0 --- .github/workflows/tutorial-check.yml | 3 +++ tutorial/02_metadata_concepts.py | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tutorial-check.yml b/.github/workflows/tutorial-check.yml index afd42ca..9a2b193 100644 --- a/.github/workflows/tutorial-check.yml +++ b/.github/workflows/tutorial-check.yml @@ -33,6 +33,9 @@ jobs: python -m pip install --upgrade pip pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install -e ".[yuanrong]" + - name: Export env vars + run: | + export TQ_NUM_THREADS=2 - name: Run tutorials run: | for file in tutorial/*.py; do python3 "$file"; done \ No newline at end of file diff --git a/tutorial/02_metadata_concepts.py b/tutorial/02_metadata_concepts.py index bb11f9f..b399007 100644 --- a/tutorial/02_metadata_concepts.py +++ b/tutorial/02_metadata_concepts.py @@ -357,7 +357,7 @@ def demonstrate_real_workflow(): global_index: {"uid": uuid.uuid4().hex[:4], "session_id": uuid.uuid4().hex[:4], "model_version": 0} for global_index in batch_meta.global_indexes } - batch_meta.set_custom_meta(custom_meta) + batch_meta.update_custom_meta(custom_meta) print(f"✓ Set custom_meta into BatchMeta: {batch_meta.get_all_custom_meta()}") client.set_custom_meta(batch_meta) @@ -449,7 +449,8 @@ def main(): print("2. SampleMeta describes a single data sample") print("3. BatchMeta manages collections of samples with operations") print("4. Metadata operations: chunk, concat, union, select, reorder... You can retrieve subsets easily!") - print("5. concat combines batches; union merges fields of same samples") + print("5. 
extra_info is in batch-level, and custom_meta is in sample-level.") + print("6. You can put custom_meta into TQ controller, so you can retrieve them from anywhere!") # Cleanup ray.shutdown() From 177fe5b6e940fafad4d9334baf1395818a6879bd Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sat, 31 Jan 2026 10:52:40 +0800 Subject: [PATCH 09/23] update var name Signed-off-by: 0oshowero0 --- transfer_queue/storage/clients/base.py | 8 ++++---- transfer_queue/storage/clients/mooncake_client.py | 4 ++-- transfer_queue/storage/clients/ray_storage_client.py | 4 ++-- transfer_queue/storage/clients/yuanrong_client.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py index 1457b36..90c33fe 100644 --- a/transfer_queue/storage/clients/base.py +++ b/transfer_queue/storage/clients/base.py @@ -44,7 +44,7 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: raise NotImplementedError("Subclasses must implement put") @abstractmethod - def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=None) -> list[Any]: """ Retrieve values from the storage backend by key. Args: @@ -55,9 +55,9 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li dtypes: Optional data type information for the expected values. The structure and interpretation of this argument are determined by the concrete storage backend implementation. - custom_meta: Optional backend-specific metadata used to control - or optimize the retrieval process. Its format is defined by - the concrete storage backend implementation. + custom_backend_meta: Optional backend-specific metadata used to + control or optimize the retrieval process. Its format is + defined by the concrete storage backend implementation. Returns: list[Any]: List of values retrieved from the storage backend, in the same order as the provided keys. diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 9f69a97..6b730ca 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -132,14 +132,14 @@ def _batch_put_bytes(self, keys: list[str], values: list[bytes]): if ret != 0: raise RuntimeError(f"put_batch failed with error code: {ret}") - def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=None) -> list[Any]: """Get multiple key-value pairs from MooncakeStore. Args: keys (List[str]): Keys to fetch. shapes (List[List[int]]): Expected tensor shapes (use [] for scalars). dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data. - custom_meta (List[str], optional): Device type (npu/cpu) for each key + custom_backend_meta (List[str], optional): Device type (npu/cpu) for each key Returns: List[Any]: Retrieved values in the same order as input keys. 
diff --git a/transfer_queue/storage/clients/ray_storage_client.py b/transfer_queue/storage/clients/ray_storage_client.py index 78a6c07..8bd4468 100644 --- a/transfer_queue/storage/clients/ray_storage_client.py +++ b/transfer_queue/storage/clients/ray_storage_client.py @@ -84,14 +84,14 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: ray.get(self.storage_actor.put_obj_ref.remote(keys, obj_refs)) return None - def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=None) -> list[Any]: """ Retrieve objects from remote storage. Args: keys (list): List of string keys to fetch. shapes (list, optional): Ignored. For compatibility with KVStorageManager. dtypes (list, optional): Ignored. For compatibility with KVStorageManager. - custom_meta (list, optional): Ray object ref for each key + custom_backend_meta (list, optional): Ray object ref for each key Returns: list: List of retrieved objects """ diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index b082faf..bf845b5 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -374,7 +374,7 @@ def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]: results[idx] = obj return results - def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=None) -> list[Any]: """Retrieves multiple values from remote storage with expected metadata. Requires shape and dtype hints to reconstruct NPU tensors correctly. @@ -383,7 +383,7 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li keys (List[str]): Keys to fetch. shapes (List[List[int]]): Expected tensor shapes (use [] for scalars). dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data. - custom_meta (List[str], optional): Device type (npu/cpu) for each key + custom_backend_meta (List[str], optional): Device type (npu/cpu) for each key Returns: List[Any]: Retrieved values in the same order as input keys. From b3b1abe790303409e6f7c73c5011929106a7ce14 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Sun, 1 Feb 2026 19:45:50 +0800 Subject: [PATCH 10/23] fix Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/base.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 94854ce..a2b3880 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import asyncio +import copy import itertools import logging import os @@ -508,9 +510,9 @@ def process_field(field_idx: int): return TensorDict(merged_data, batch_size=num_samples) @staticmethod - def _get_shape_type_custom_meta_list(metadata: BatchMeta): + def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta): """ - Extract the expected shape, dtype, and custom meta for each field-sample pair in metadata. + Extract the expected shape, dtype, and custom_backend_meta for each field-sample pair in metadata. 
The order matches the key/value order: sorted by field name, then by global index. Args: @@ -521,8 +523,8 @@ def _get_shape_type_custom_meta_list(metadata: BatchMeta): """ shapes = [] dtypes = [] - custom_meta_list = [] - all_custom_meta = metadata.get_all_custom_meta() + custom_backend_meta_list = [] + all_custom_backend_meta = copy.deepcopy(metadata._custom_backend_meta) for field_name in sorted(metadata.field_names): for index in range(len(metadata)): field = metadata.samples[index].get_field_by_name(field_name) @@ -530,8 +532,8 @@ def _get_shape_type_custom_meta_list(metadata: BatchMeta): shapes.append(field.shape) dtypes.append(field.dtype) global_index = metadata.global_indexes[index] - custom_meta_list.append(all_custom_meta.get(global_index, {}).get(field_name, None)) - return shapes, dtypes, custom_meta_list + custom_backend_meta_list.append(all_custom_backend_meta.get(global_index, {}).get(field_name, None)) + return shapes, dtypes, custom_backend_meta_list async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: """ @@ -618,8 +620,10 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: logger.warning("Attempted to get data, but metadata contains no fields.") return TensorDict({}, batch_size=len(metadata)) keys = self._generate_keys(metadata.field_names, metadata.global_indexes) - shapes, dtypes, custom_meta = self._get_shape_type_custom_meta_list(metadata) - values = self.storage_client.get(keys=keys, shapes=shapes, dtypes=dtypes, custom_meta=custom_meta) + shapes, dtypes, custom_backend_meta = self._get_shape_type_custom_backend_meta_list(metadata) + values = self.storage_client.get( + keys=keys, shapes=shapes, dtypes=dtypes, custom_backend_meta=custom_backend_meta + ) return self._merge_tensors_to_tensordict(metadata, values) async def clear_data(self, metadata: BatchMeta) -> None: From f5c95ac006ea20ebc466e6afa3283baa155fadfa Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 2 Feb 2026 09:21:55 +0800 Subject: [PATCH 11/23] fix ut Signed-off-by: 0oshowero0 --- tests/test_kv_storage_manager.py | 42 +++++++++++++++++--------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py index 5cfb6fe..41296dd 100644 --- a/tests/test_kv_storage_manager.py +++ b/tests/test_kv_storage_manager.py @@ -21,13 +21,13 @@ import torch from tensordict import TensorDict -from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta -from transfer_queue.storage.managers.base import KVStorageManager - # Setup path parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) +from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402 +from transfer_queue.storage.managers.base import KVStorageManager # noqa: E402 + def get_meta(data, global_indexes=None): if not global_indexes: @@ -163,9 +163,11 @@ def test_merge_tensors_to_tensordict(mock_create, test_data): assert complex_tensordict[key] == complex_data[key] -def test_get_shape_type_custom_meta_list_without_custom_meta(test_data): - """Test _get_shape_type_custom_meta_list returns correct shapes and dtypes without custom_meta.""" - shapes, dtypes, custom_meta_list = KVStorageManager._get_shape_type_custom_meta_list(test_data["metadata"]) +def test_get_shape_type_custom_backend_meta_list_without_custom_meta(test_data): + """Test _get_shape_type_custom_backend_meta_list returns correct shapes and dtypes without custom_meta.""" + shapes, dtypes, custom_backend_meta_list = 
KVStorageManager._get_shape_type_custom_backend_meta_list( + test_data["metadata"] + ) # Expected order: sorted by field name (label, mask, text), then by global_index order # 3 fields * 3 samples = 9 entries @@ -184,25 +186,25 @@ def test_get_shape_type_custom_meta_list_without_custom_meta(test_data): ] expected_dtypes = [torch.int64] * (len(test_data["field_names"]) * len(test_data["global_indexes"])) # No custom_meta provided, so all should be None - expected_custom_meta = [None] * (len(test_data["field_names"]) * len(test_data["global_indexes"])) + expected_custom_backend_meta = [None] * (len(test_data["field_names"]) * len(test_data["global_indexes"])) assert shapes == expected_shapes assert dtypes == expected_dtypes - assert custom_meta_list == expected_custom_meta + assert custom_backend_meta_list == expected_custom_backend_meta -def test_get_shape_type_custom_meta_list_with_custom_meta(test_data): +def test_get_shape_type_custom_backend_meta_list_with_custom_meta(test_data): """Test _get_shape_type_custom_meta_list returns correct custom_meta when provided.""" # Add custom_meta to metadata - custom_meta = { + custom_backend_meta = { 8: {"text": {"key1": "value1"}, "label": {"key2": "value2"}, "mask": {"key3": "value3"}}, 9: {"text": {"key4": "value4"}, "label": {"key5": "value5"}, "mask": {"key6": "value6"}}, 10: {"text": {"key7": "value7"}, "label": {"key8": "value8"}, "mask": {"key9": "value9"}}, } metadata = test_data["metadata"] - metadata.update_custom_meta(custom_meta) + metadata._custom_backend_meta.update(custom_backend_meta) - shapes, dtypes, custom_meta_list = KVStorageManager._get_shape_type_custom_meta_list(metadata) + shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index expected_custom_meta = [ @@ -216,24 +218,24 @@ def test_get_shape_type_custom_meta_list_with_custom_meta(test_data): {"key4": "value4"}, # text, global_index=9 {"key7": "value7"}, # text, global_index=10 ] - assert custom_meta_list == expected_custom_meta + assert custom_backend_meta_list == expected_custom_meta -def test_get_shape_type_custom_meta_list_with_partial_custom_meta(test_data): - """Test _get_shape_type_custom_meta_list handles partial custom_meta correctly.""" +def test_get_shape_type_custom_backend_meta_list_with_partial_custom_meta(test_data): + """Test _get_shape_type_custom_backend_meta_list handles partial custom_meta correctly.""" # Add custom_meta only for some global_indexes and fields - custom_meta = { + custom_backend_meta = { 8: {"text": {"key1": "value1"}}, # Only text field # global_index 9 has no custom_meta 10: {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # label and mask only } metadata = test_data["metadata"] - metadata.update_custom_meta(custom_meta) + metadata._custom_backend_meta.update(custom_backend_meta) - shapes, dtypes, custom_meta_list = KVStorageManager._get_shape_type_custom_meta_list(metadata) + shapes, dtypes, custom_backend_meta_list = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata) # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index - expected_custom_meta = [ + expected_custom_backend_meta = [ None, # label, global_index=8 (not in custom_meta) None, # label, global_index=9 (not in custom_meta) {"key2": "value2"}, # label, global_index=10 @@ -244,7 +246,7 @@ def test_get_shape_type_custom_meta_list_with_partial_custom_meta(test_data): None, # 
text, global_index=9 (not in custom_meta) None, # text, global_index=10 (not in custom_meta for text) ] - assert custom_meta_list == expected_custom_meta + assert custom_backend_meta_list == expected_custom_backend_meta @pytest.fixture From 4ed69ef1c342cb4838203191bf87be2a398bc34a Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 2 Feb 2026 11:03:55 +0800 Subject: [PATCH 12/23] fix metadata chunk & concat & union Signed-off-by: 0oshowero0 # Conflicts: # transfer_queue/client.py --- .github/workflows/tutorial-check.yml | 4 +- tests/test_controller.py | 2 +- tests/test_metadata.py | 89 +++++++++++++++++++++++--- transfer_queue/client.py | 46 ++++++++++++-- transfer_queue/controller.py | 1 + transfer_queue/metadata.py | 93 +++++++++++++++++++++++++--- tutorial/02_metadata_concepts.py | 2 +- 7 files changed, 212 insertions(+), 25 deletions(-) diff --git a/.github/workflows/tutorial-check.yml b/.github/workflows/tutorial-check.yml index 9a2b193..1202e21 100644 --- a/.github/workflows/tutorial-check.yml +++ b/.github/workflows/tutorial-check.yml @@ -33,9 +33,7 @@ jobs: python -m pip install --upgrade pip pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install -e ".[yuanrong]" - - name: Export env vars - run: | - export TQ_NUM_THREADS=2 - name: Run tutorials run: | + export TQ_NUM_THREADS=2 for file in tutorial/*.py; do python3 "$file"; done \ No newline at end of file diff --git a/tests/test_controller.py b/tests/test_controller.py index 192395d..e373250 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -599,7 +599,7 @@ def test_controller_with_custom_meta(self, ray_setup): assert result[3]["quality"] == "high" assert 4 in result assert result[4]["sample_score"] == 0 - assert 5 not in result # 5 has not custom_meta, it will not return even we retrieve for 5 + assert 5 not in result # 5 has no custom_meta, it will not return even we retrieve for 5 # Clean up ray.get(tq_controller.clear_partition.remote(partition_id)) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index f74c54b..b2293ea 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -178,7 +178,11 @@ def test_batch_meta_chunk(self): ) } samples = [SampleMeta(partition_id="partition_0", global_index=i, fields=fields) for i in range(10)] - batch = BatchMeta(samples=samples) + batch = BatchMeta( + samples=samples, + custom_meta={i: {"uid": i} for i in range(10)}, + _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in range(10)}, + ) # Chunk into 3 parts chunks = batch.chunk(3) @@ -188,6 +192,22 @@ def test_batch_meta_chunk(self): assert len(chunks[1]) == 3 assert len(chunks[2]) == 3 + # validate custom_meta is chunked + assert 0 in chunks[0].custom_meta + assert 1 in chunks[0].custom_meta + assert 2 in chunks[0].custom_meta + assert 3 in chunks[0].custom_meta + assert 4 not in chunks[0].custom_meta + assert 4 in chunks[1].custom_meta + + # validate _custom_backend_meta is chunked + assert 0 in chunks[0]._custom_backend_meta + assert 1 in chunks[0]._custom_backend_meta + assert 2 in chunks[0]._custom_backend_meta + assert 3 in chunks[0]._custom_backend_meta + assert 4 not in chunks[0]._custom_backend_meta + assert 4 in chunks[1]._custom_backend_meta + def test_batch_meta_chunk_by_partition(self): """Example: Split a batch into multiple chunks.""" fields = { @@ -196,7 +216,11 @@ def test_batch_meta_chunk_by_partition(self): ) } samples = [SampleMeta(partition_id=f"partition_{i % 4}", global_index=i, fields=fields) for i in range(10)] - 
batch = BatchMeta(samples=samples) + batch = BatchMeta( + samples=samples, + custom_meta={i: {"uid": i} for i in range(10)}, + _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in range(10)}, + ) # Chunk according to partition_id chunks = batch.chunk_by_partition() @@ -215,6 +239,20 @@ def test_batch_meta_chunk_by_partition(self): assert chunks[3].partition_ids == ["partition_3", "partition_3"] assert chunks[3].global_indexes == [3, 7] + # validate _custom_backend_meta is chunked + assert 0 in chunks[0].custom_meta + assert 4 in chunks[0].custom_meta + assert 8 in chunks[0].custom_meta + assert 1 not in chunks[0].custom_meta + assert 1 in chunks[1].custom_meta + + # validate _custom_backend_meta is chunked + assert 0 in chunks[0]._custom_backend_meta + assert 4 in chunks[0]._custom_backend_meta + assert 8 in chunks[0]._custom_backend_meta + assert 1 not in chunks[0]._custom_backend_meta + assert 1 in chunks[1]._custom_backend_meta + def test_batch_meta_init_validation_error_different_field_names(self): """Example: Init validation catches samples with different field names.""" # Create first sample with field1 @@ -243,14 +281,18 @@ def test_batch_meta_concat(self): samples=[ SampleMeta(partition_id="partition_0", global_index=0, fields=fields), SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - ] + ], + custom_meta={i: {"uid": i} for i in [0, 1]}, + _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in [0, 1]}, ) batch2 = BatchMeta( samples=[ SampleMeta(partition_id="partition_0", global_index=2, fields=fields), SampleMeta(partition_id="partition_0", global_index=3, fields=fields), - ] + ], + custom_meta={i: {"uid": i} for i in [2, 3]}, + _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in [2, 3]}, ) # Concatenate batches @@ -258,6 +300,8 @@ def test_batch_meta_concat(self): assert len(result) == 4 assert result.global_indexes == [0, 1, 2, 3] + assert result.custom_meta == {i: {"uid": i} for i in [0, 1, 2, 3]} + assert result._custom_backend_meta == {i: {"test_field": {"dtype": torch.float32}} for i in [0, 1, 2, 3]} def test_batch_meta_concat_with_tensor_extra_info(self): """Example: Concat handles tensor extra_info by concatenating along dim=0.""" @@ -364,7 +408,10 @@ def test_batch_meta_union(self): samples=[ SampleMeta(partition_id="partition_0", global_index=0, fields=fields1), SampleMeta(partition_id="partition_0", global_index=1, fields=fields1), - ] + ], + _custom_backend_meta={ + i: {"field1": {"dtype": torch.float32}, "field2": {"dtype": torch.int64}} for i in [0, 1] + }, ) batch1.extra_info["info1"] = "value1" @@ -372,7 +419,10 @@ def test_batch_meta_union(self): samples=[ SampleMeta(partition_id="partition_0", global_index=0, fields=fields2), SampleMeta(partition_id="partition_0", global_index=1, fields=fields2), - ] + ], + _custom_backend_meta={ + i: {"field2": {"dtype": torch.int64}, "field3": {"dtype": torch.bool}} for i in [0, 1] + }, ) batch2.extra_info["info2"] = "value2" @@ -388,6 +438,12 @@ def test_batch_meta_union(self): assert result.extra_info["info1"] == "value1" assert result.extra_info["info2"] == "value2" + # _custom_backend_meta is merged + assert result._custom_backend_meta == { + i: {"field1": {"dtype": torch.float32}, "field2": {"dtype": torch.int64}, "field3": {"dtype": torch.bool}} + for i in [0, 1] + } + def test_batch_meta_union_validation(self): """Example: Union validation catches mismatched conditions.""" fields = {"test_field": FieldMeta(name="test_field", 
dtype=torch.float32, shape=(2,))} @@ -463,7 +519,18 @@ def test_batch_meta_select_fields(self): SampleMeta(partition_id="partition_0", global_index=0, fields=fields), SampleMeta(partition_id="partition_0", global_index=1, fields=fields), ] - batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"}) + batch = BatchMeta( + samples=samples, + extra_info={"test_key": "test_value"}, + _custom_backend_meta={ + i: { + "field1": {"dtype": torch.float32}, + "field2": {"dtype": torch.int64}, + "field3": {"dtype": torch.bool}, + } + for i in [0, 1] + }, + ) # Select only field1 and field3 selected_batch = batch.select_fields(["field1", "field3"]) @@ -481,6 +548,14 @@ def test_batch_meta_select_fields(self): # Global indexes are preserved assert selected_batch.global_indexes == [0, 1] + # _custom_backend_meta is selected + assert "field1" in selected_batch._custom_backend_meta[0] + assert "field2" not in selected_batch._custom_backend_meta[0] + assert "field3" in selected_batch._custom_backend_meta[0] + assert "field1" in selected_batch._custom_backend_meta[1] + assert "field2" not in selected_batch._custom_backend_meta[1] + assert "field3" in selected_batch._custom_backend_meta[1] + def test_batch_meta_select_fields_with_nonexistent_fields(self): """Example: Select fields ignores non-existent field names in batch.""" fields = { diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 3322bce..19e0403 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -209,7 +209,7 @@ async def async_get_meta( >>> print(batch_meta.is_ready) # True if all samples ready >>> >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, - >>> # so may include unready samples. Consumed samples will not be fetched.) + >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) >>> batch_meta = asyncio.run(client.async_get_meta( ... partition_id="train_0", # optional ... mode="force_fetch", @@ -941,17 +941,55 @@ def get_meta( task_name: Optional[str] = None, sampling_config: Optional[dict[str, Any]] = None, ) -> BatchMeta: - """Synchronously fetch data metadata from controller. + """Synchronously fetch data metadata from the controller via ZMQ. Args: data_fields: List of data field names to retrieve metadata for batch_size: Number of samples to request in the batch - partition_id: Target data partition id + partition_id: Current data partition id + mode: Data fetch mode. Options: + - 'fetch': Get ready data only + - 'force_fetch': Get data regardless of readiness (may return unready samples) + - 'insert': Internal usage - should not be used by users task_name: Optional task name associated with the request sampling_config: Optional sampling configuration for custom samplers. + Returns: - BatchMeta: Batch metadata containing data location information + BatchMeta: Metadata object containing data structure, sample information, and readiness status + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Example 1: Basic fetch metadata + >>> batch_meta = client.get_meta( + ... data_fields=["input_ids", "attention_mask"], + ... batch_size=4, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences" + ... ) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example) + >>> batch_meta = client.get_meta( + ... 
data_fields=["input_ids", "attention_mask"], + ... batch_size=8, + ... partition_id="train_0", + ... mode="fetch", + ... task_name="generate_sequences", + ... sampling_config={"n_samples_per_prompt": 4} + ... ) + >>> print(batch_meta.is_ready) # True if all samples ready + >>> + >>> # Example 3: Force fetch metadata (bypass production status check and Sampler, + >>> # so may include unready and already-consumed samples. No filtering by consumption status is applied.) + >>> batch_meta = client.get_meta( + ... partition_id="train_0", # optional + ... mode="force_fetch", + ... ) + >>> print(batch_meta.is_ready) # May be False if some samples not ready """ return self._get_meta( data_fields=data_fields, diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index cd482a4..faf52f4 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1192,6 +1192,7 @@ def get_metadata( consumed_indexes = [] else: batch_global_indexes = list(sorted(self.index_manager.allocated_indexes)) + consumed_indexes = [] # Package into metadata metadata = self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index d701269..76d4ed5 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -408,8 +408,18 @@ def select_samples(self, sample_indices: list[int]) -> "BatchMeta": selected_samples = [self.samples[i] for i in sample_indices] + selected_custom_meta = {i: self.custom_meta[i] for i in sample_indices if i in self.custom_meta} + selected_custom_backend_meta = { + i: self._custom_backend_meta[i] for i in sample_indices if i in self._custom_backend_meta + } + # construct new BatchMeta instance - selected_batch_meta = BatchMeta(samples=selected_samples, extra_info=self.extra_info) + selected_batch_meta = BatchMeta( + samples=selected_samples, + extra_info=self.extra_info, + custom_meta=selected_custom_meta, + _custom_backend_meta=selected_custom_backend_meta, + ) return selected_batch_meta @@ -427,8 +437,23 @@ def select_fields(self, field_names: list[str]) -> "BatchMeta": # select fields for each SampleMeta new_samples = [sample.select_fields(field_names=field_names) for sample in self.samples] + # select fields in _custom_backend_meta + selected_custom_backend_meta = {} + for idx in self.global_indexes: + if idx in self._custom_backend_meta: + custom_backend_meta_idx = self._custom_backend_meta[idx] + + selected_custom_backend_meta[idx] = { + field: custom_backend_meta_idx[field] for field in field_names if field in custom_backend_meta_idx + } + # construct new BatchMeta instance - new_batch_meta = BatchMeta(samples=new_samples, extra_info=self.extra_info) + new_batch_meta = BatchMeta( + samples=new_samples, + extra_info=self.extra_info, + custom_meta=self.custom_meta, + _custom_backend_meta=selected_custom_backend_meta, + ) return new_batch_meta @@ -439,7 +464,12 @@ def __len__(self) -> int: def __getitem__(self, item): if isinstance(item, int | np.integer): sample_meta = self.samples[item] if self.samples else [] - return BatchMeta(samples=[sample_meta], extra_info=self.extra_info) + return BatchMeta( + samples=[sample_meta], + extra_info=self.extra_info, + custom_meta=self.custom_meta, + _custom_backend_meta=self._custom_backend_meta, + ) else: raise TypeError(f"Indexing with {type(item)} is not supported now!") @@ -472,7 +502,17 @@ def chunk(self, chunks: int) -> list["BatchMeta"]: current_chunk_size = base_size + 1 if i < remainder else base_size end = start + 
current_chunk_size chunk_samples = self.samples[start:end] - chunk = BatchMeta(samples=chunk_samples, extra_info=self.extra_info) + global_indexes = self.global_indexes[start:end] + chunk_custom_meta = {i: self.custom_meta[i] for i in global_indexes if i in self.custom_meta} + chunk_custom_backend_meta = { + i: self._custom_backend_meta[i] for i in global_indexes if i in self._custom_backend_meta + } + chunk = BatchMeta( + samples=chunk_samples, + extra_info=self.extra_info, + custom_meta=chunk_custom_meta, + _custom_backend_meta=chunk_custom_backend_meta, + ) chunk_list.append(chunk) start = end return chunk_list @@ -512,14 +552,14 @@ def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": """ if not data: logger.warning("Try to concat empty BatchMeta chunks. Returning empty BatchMeta.") - return BatchMeta(samples=[], extra_info={}) + return BatchMeta(samples=[], extra_info={}, custom_meta={}, _custom_backend_meta={}) # skip empty chunks data = [chunk for chunk in data if chunk and len(chunk.samples) > 0] if len(data) == 0: logger.warning("No valid BatchMeta chunks to concatenate. Returning empty BatchMeta.") - return BatchMeta(samples=[], extra_info={}) + return BatchMeta(samples=[], extra_info={}, custom_meta={}, _custom_backend_meta={}) if validate: base_fields = data[0].field_names @@ -533,11 +573,20 @@ def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": # Merge all extra_info dictionaries from the chunks merged_extra_info = dict() + merged_custom_meta = dict() + merged_custom_backend_meta = dict() values_by_key = defaultdict(list) for chunk in data: + # For the sample-level custom_meta and field-level _custom_backend_meta, we directly update the dict. + merged_custom_meta.update(chunk.custom_meta) + merged_custom_backend_meta.update(chunk._custom_backend_meta) + for key, value in chunk.extra_info.items(): values_by_key[key].append(value) + + # For the batch-level extra_info, we concat the tensor/NonTensorStack/NonTensorData/list + # objects to prevent information losses. 
for key, values in values_by_key.items(): if all(isinstance(v, torch.Tensor) for v in values): try: @@ -558,7 +607,12 @@ def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": else: merged_extra_info[key] = values[-1] - return BatchMeta(samples=all_samples, extra_info=merged_extra_info) + return BatchMeta( + samples=all_samples, + extra_info=merged_extra_info, + custom_meta=merged_custom_meta, + _custom_backend_meta=merged_custom_backend_meta, + ) def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMeta"]: """ @@ -603,7 +657,26 @@ def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMet # Merge extra info dictionaries merged_extra_info = {**self.extra_info, **other.extra_info} - return BatchMeta(samples=merged_samples, extra_info=merged_extra_info) + + # Merge custom_meta dictionaries + merged_custom_meta = {**self.custom_meta, **other.custom_meta} + + # Merge custom_backend_meta dictionaries + merged_custom_backend_meta = {} + for idx in self.global_indexes: + if idx in self._custom_backend_meta and other._custom_backend_meta[idx]: + merged_custom_backend_meta[idx] = {**self._custom_backend_meta[idx], **other._custom_backend_meta[idx]} + elif idx in self._custom_backend_meta: + merged_custom_backend_meta[idx] = {**self._custom_backend_meta[idx]} + elif idx in other._custom_backend_meta: + merged_custom_backend_meta[idx] = other._custom_backend_meta[idx] + + return BatchMeta( + samples=merged_samples, + extra_info=merged_extra_info, + custom_meta=merged_custom_meta, + _custom_backend_meta=merged_custom_backend_meta, + ) def reorder(self, indices: list[int]): """ @@ -699,7 +772,7 @@ def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta": """ if extra_info is None: extra_info = {} - return cls(samples=[], extra_info=extra_info) + return cls(samples=[], extra_info=extra_info, custom_meta={}, _custom_backend_meta={}) def __str__(self): sample_strs = ", ".join(str(sample) for sample in self.samples) @@ -718,6 +791,8 @@ def from_dict(cls, data: dict) -> "BatchMeta": return cls( samples=samples, extra_info=data.get("extra_info", {}), + custom_meta=data.get("custom_meta", {}), + _custom_backend_meta=data.get("_custom_backend_meta", {}), ) diff --git a/tutorial/02_metadata_concepts.py b/tutorial/02_metadata_concepts.py index b399007..8bdffef 100644 --- a/tutorial/02_metadata_concepts.py +++ b/tutorial/02_metadata_concepts.py @@ -215,7 +215,7 @@ def demonstrate_batch_meta(): global_index=0, meta_dict={"uid": "prompt@0", "session_id": "session@0", "model_version": "epoch@0"} ) batch.update_custom_meta({1: {"uid": "prompt@1", "session_id": "session@0", "model_version": "epoch@0"}}) - print(f"✓ Extra info: {batch.get_all_custom_meta()}") + print(f"✓ Custom meta: {batch.get_all_custom_meta()}") # Example 4: Chunk a batch print("[Example 4] Chunking a batch into parts...") From a01e6bdfedc0fdf41b2ed3828c77e80fe9cb59c6 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 2 Feb 2026 11:11:18 +0800 Subject: [PATCH 13/23] modify select_samples by global_index Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 3 ++- transfer_queue/metadata.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index faf52f4..8c9e653 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -444,7 +444,8 @@ def _update_field_metadata( raise ValueError(f"`global_indices` {len(global_indices)} and `shapes` 
{len(shapes)} length mismatch.") if custom_backend_meta and len(global_indices) != len(custom_backend_meta): raise ValueError( - f"`global_indices` {len(global_indices)} and `custom_meta` {len(custom_backend_meta)} length mismatch." + f"`global_indices` {len(global_indices)} and `custom_backend_meta` {len(custom_backend_meta)} " + f"length mismatch." ) # Extract values for each provided mapping; if a mapping is absent, use Nones diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 76d4ed5..19b0e88 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -391,26 +391,27 @@ def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "Ba object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples)) return self - def select_samples(self, sample_indices: list[int]) -> "BatchMeta": + def select_samples(self, global_indices: list[int]) -> "BatchMeta": """ Select specific samples from this batch. This will construct a new BatchMeta instance containing only the specified samples. Args: - sample_indices (list[int]): List of sample indices to retain. + global_indices (list[int]): List of sample indices to retain. It used the relative + Returns: BatchMeta: A new BatchMeta instance containing only the specified samples. """ - if any(i < 0 or i >= len(self.samples) for i in sample_indices): - raise ValueError(f"Sample indices must be in range [0, {len(self.samples)})") + if any(i not in self.global_indexes for i in global_indices): + raise ValueError("selected global_indices do not exist in this batch!)") - selected_samples = [self.samples[i] for i in sample_indices] + selected_samples = [self.samples[i] for i in global_indices] - selected_custom_meta = {i: self.custom_meta[i] for i in sample_indices if i in self.custom_meta} + selected_custom_meta = {i: self.custom_meta[i] for i in global_indices if i in self.custom_meta} selected_custom_backend_meta = { - i: self._custom_backend_meta[i] for i in sample_indices if i in self._custom_backend_meta + i: self._custom_backend_meta[i] for i in global_indices if i in self._custom_backend_meta } # construct new BatchMeta instance From bdadf4c3574d1350b57d8a6d066833f7d46bb867 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 2 Feb 2026 11:11:57 +0800 Subject: [PATCH 14/23] fix Signed-off-by: 0oshowero0 --- tutorial/02_metadata_concepts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorial/02_metadata_concepts.py b/tutorial/02_metadata_concepts.py index 8bdffef..bbeec54 100644 --- a/tutorial/02_metadata_concepts.py +++ b/tutorial/02_metadata_concepts.py @@ -363,7 +363,7 @@ def demonstrate_real_workflow(): client.set_custom_meta(batch_meta) print("✓ Successful to store custom_meta into TQ controller. 
Now you can retrieve the custom_meta from anywhere.") - print("[Step 2] Try to get metadata from TransferQueue from other places...") + print("[Step 3] Try to get metadata from TransferQueue from other places...") batch_meta = client.get_meta( data_fields=["input_ids", "attention_mask"], batch_size=8, From 339d9edbd4ab42fdab4abf3af1d56b80e415119e Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 2 Feb 2026 11:39:24 +0800 Subject: [PATCH 15/23] fix param name Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 14 ++++++++------ transfer_queue/storage/managers/base.py | 6 +++--- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 8c9e653..e0239ba 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -373,7 +373,7 @@ def update_production_status( field_names: list[str], dtypes: Optional[dict[int, dict[str, Any]]], shapes: Optional[dict[int, dict[str, Any]]], - custom_meta: Optional[dict[int, dict[str, Any]]] = None, + custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, ) -> bool: """ Update production status for specific samples and fields. @@ -384,7 +384,8 @@ def update_production_status( field_names: List of field names to mark as produced dtypes: Optional per-sample field dtype information shapes: Optional per-sample field shape information - custom_meta: Optional per-sample field custom metadata + custom_backend_meta: Optional per-sample per-field + custom metadata provided by storage backend Returns: True if update was successful, False on error @@ -415,7 +416,7 @@ def update_production_status( self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1 # Update field metadata - self._update_field_metadata(global_indices, dtypes, shapes, custom_meta) + self._update_field_metadata(global_indices, dtypes, shapes, custom_backend_meta) # Save these global_indexes self.global_indexes.update(global_indices) @@ -969,7 +970,7 @@ def update_production_status( field_names: list[str], dtypes: Optional[dict[int, dict[str, Any]]], shapes: Optional[dict[int, dict[str, Any]]], - custom_meta: Optional[dict[int, dict[str, Any]]] = None, + custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, ) -> bool: """ Update production status for specific samples and fields in a partition. 
@@ -981,6 +982,7 @@ def update_production_status( field_names: List of field names to mark as produced dtypes: Optional per-sample field dtype information shapes: Optional per-sample field shape information + custom_backend_meta: Optional custom backend metadata Returns: True if update was successful, False otherwise @@ -990,7 +992,7 @@ def update_production_status( logger.error(f"Partition {partition_id} not found") return False - success = partition.update_production_status(global_indexes, field_names, dtypes, shapes, custom_meta) + success = partition.update_production_status(global_indexes, field_names, dtypes, shapes, custom_backend_meta) if success: logger.debug( f"[{self.controller_id}]: Updated production status for partition {partition_id}: " @@ -1679,7 +1681,7 @@ def _update_data_status(self): field_names=message_data.get("fields", []), dtypes=message_data.get("dtypes", {}), shapes=message_data.get("shapes", {}), - custom_meta=message_data.get("custom_meta", {}), + custom_backend_meta=message_data.get("custom_backend_meta", {}), ) if success: diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index a2b3880..2c76e91 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -198,7 +198,7 @@ async def notify_data_update( global_indexes: list[int], dtypes: dict[int, dict[str, Any]], shapes: dict[int, dict[str, Any]], - custom_meta: Optional[dict[int, dict[str, Any]]] = None, + custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, ) -> None: """ Notify controller that new data is ready. @@ -209,7 +209,7 @@ async def notify_data_update( global_indexes: Data update related global_indexes. dtypes: Per-field dtypes for each field, in {global_index: {field: dtype}} format. shapes: Per-field shapes for each field, in {global_index: {field: shape}} format. - custom_meta: Per-field custom_meta for each field, in {global_index: {field: custom_meta}} format. + custom_backend_meta: Per-field custom_meta for each sample, in {global_index: {field: custom_meta}} format. """ # Create zmq poller for notifying data update information @@ -234,7 +234,7 @@ async def notify_data_update( "global_indexes": global_indexes, "dtypes": dtypes, "shapes": shapes, - "custom_meta": custom_meta, + "custom_backend_meta": custom_backend_meta, }, ).serialize() From 44fcea0dac306a61e1b0b643e470e0fc4c5cc0dd Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 2 Feb 2026 11:39:36 +0800 Subject: [PATCH 16/23] optimize docstring Signed-off-by: 0oshowero0 --- transfer_queue/client.py | 55 +++++++++++++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 19e0403..3948153 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -435,7 +435,6 @@ async def async_get_data(self, metadata: BatchMeta) -> TensorDict: >>> batch = asyncio.run(client.async_get_data(batch_meta)) >>> print(batch) >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes - """ if not hasattr(self, "storage_manager") or self.storage_manager is None: @@ -1000,29 +999,48 @@ def get_meta( ) def get_data(self, metadata: BatchMeta) -> TensorDict: - """Synchronously fetch data from storage units. + """Synchronously fetch data from storage units and organize into TensorDict. 

         Args:
-            metadata: Batch metadata containing data location information
+            metadata: Batch metadata containing data location information and global indexes

         Returns:
-            TensorDict containing requested data fields
+            TensorDict containing:
+                - Requested data fields (e.g., "prompts", "attention_mask")
+
+        Example:
+            >>> batch_meta = client.get_meta(
+            ...     data_fields=["prompts", "attention_mask"],
+            ...     batch_size=4,
+            ...     partition_id="train_0",
+            ...     mode="fetch",
+            ...     task_name="generate_sequences",
+            ... )
+            >>> batch = client.get_data(batch_meta)
+            >>> print(batch)
+            >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
         """
         return self._get_data(metadata=metadata)

     def clear_partition(self, partition_id: str):
-        """Synchronously clear the whole partition from storage units and controller.
+        """Synchronously clear the whole partition from all storage units and the controller.

         Args:
             partition_id: The partition id to clear data for
+
+        Raises:
+            RuntimeError: If clear operation fails
         """
         return self._clear_partition(partition_id=partition_id)

     def clear_samples(self, metadata: BatchMeta):
-        """Synchronously clear specific samples from storage units and controller metadata.
+        """Synchronously clear specific samples from all storage units and the controller.

         Args:
             metadata: The BatchMeta of the corresponding data to be cleared
+
+        Raises:
+            RuntimeError: If clear operation fails
         """
         return self._clear_samples(metadata=metadata)

@@ -1035,6 +1053,17 @@ def check_consumption_status(self, task_name: str, partition_id: str) -> bool:

         Returns:
             bool: True if all samples have been consumed by the task, False otherwise
+
+        Raises:
+            RuntimeError: If communication fails or controller returns error response
+
+        Example:
+            >>> # Check if all samples have been consumed
+            >>> is_consumed = client.check_consumption_status(
+            ...     task_name="generate_sequences",
+            ...     partition_id="train_0"
+            ... )
+            >>> print(f"All samples consumed: {is_consumed}")
         """
         return self._check_consumption_status(task_name=task_name, partition_id=partition_id)

@@ -1070,8 +1099,16 @@ def check_production_status(self, data_fields: list[str], partition_id: str) ->
             data_fields: Data fields to check production status for
             partition_id: Partition id to check production status for

-        Returns:
-            bool: True if all samples have been produced and ready, False otherwise
+        Raises:
+            RuntimeError: If communication fails or controller returns error response
+
+        Example:
+            >>> # Check if all samples are ready for consumption
+            >>> is_ready = client.check_production_status(
+            ...     data_fields=["input_ids", "attention_mask"],
+            ...     partition_id="train_0"
+            ... )
+            >>> print(f"All samples ready: {is_ready}")
         """
         return self._check_production_status(data_fields=data_fields, partition_id=partition_id)

@@ -1102,7 +1139,7 @@ def get_production_status(

     def get_partition_list(
         self,
-    ):
+    ) -> list[str]:
         """Synchronously fetch the list of partition ids from the controller.
Returns: From e66561e7bff024d51c6a19f8e274968def4372c3 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 2 Feb 2026 11:44:17 +0800 Subject: [PATCH 17/23] fix ci Signed-off-by: 0oshowero0 --- tests/test_controller.py | 6 +++--- tests/test_controller_data_partitions.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_controller.py b/tests/test_controller.py index e373250..14fd3aa 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -89,7 +89,7 @@ def test_controller_with_single_partition(self, ray_setup): field_names=metadata.field_names, dtypes=dtypes, shapes=shapes, - custom_meta=None, + custom_backend_meta=None, ) ) assert success @@ -498,7 +498,7 @@ def test_controller_with_custom_meta(self, ray_setup): field_names=metadata.field_names, dtypes=dtypes, shapes=shapes, - custom_meta=custom_backend_meta, + custom_backend_meta=custom_backend_meta, ) ) assert success @@ -561,7 +561,7 @@ def test_controller_with_custom_meta(self, ray_setup): field_names=new_metadata.field_names, dtypes=dtypes, shapes=shapes, - custom_meta=None, + custom_backend_meta=None, ) ) assert success diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index e5854c1..31478ba 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -65,7 +65,7 @@ def test_data_partition_status(): 1: {"input_ids": (512,), "attention_mask": (512,)}, 2: {"input_ids": (512,), "attention_mask": (512,)}, }, - custom_meta=None, + custom_backend_meta=None, ) assert success @@ -175,7 +175,7 @@ def test_dynamic_expansion_scenarios(): 5: {"field_1": (32,)}, 10: {"field_1": (32,)}, }, - custom_meta=None, + custom_backend_meta=None, ) assert partition.total_samples_num == 3 assert partition.allocated_samples_num >= 11 # Should accommodate index 10 @@ -474,7 +474,7 @@ def test_custom_meta_in_data_partition_status(): field_names=field_names, dtypes=dtypes, shapes=shapes, - custom_meta=custom_backend_meta, + custom_backend_meta=custom_backend_meta, ) assert success From bc34a6a38193871d23615924b3679881ac58999f Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 2 Feb 2026 14:28:15 +0800 Subject: [PATCH 18/23] fix Signed-off-by: 0oshowero0 --- tests/test_metadata.py | 28 ++++++++--------- transfer_queue/client.py | 54 +++++++++++++++++++++++++++++--- transfer_queue/controller.py | 8 ++--- transfer_queue/metadata.py | 49 ++++++++++++++++------------- tutorial/02_metadata_concepts.py | 6 ++-- 5 files changed, 96 insertions(+), 49 deletions(-) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index b2293ea..01ad823 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -469,16 +469,16 @@ def test_batch_meta_reorder(self): ) } samples = [ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields), - SampleMeta(partition_id="partition_0", global_index=2, fields=fields), + SampleMeta(partition_id="partition_0", global_index=4, fields=fields), + SampleMeta(partition_id="partition_0", global_index=5, fields=fields), + SampleMeta(partition_id="partition_0", global_index=6, fields=fields), ] batch = BatchMeta(samples=samples) # Reorder to [2, 0, 1] batch.reorder([2, 0, 1]) - assert batch.global_indexes == [2, 0, 1] + assert batch.global_indexes == [6, 4, 5] # Batch indexes are updated assert batch.samples[0].batch_index == 0 assert batch.samples[1].batch_index == 1 @@ -645,20 +645,20 @@ def 
test_batch_meta_select_samples(self):
             "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)),
         }
         samples = [
-            SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
-            SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
-            SampleMeta(partition_id="partition_0", global_index=2, fields=fields),
-            SampleMeta(partition_id="partition_0", global_index=3, fields=fields),
+            SampleMeta(partition_id="partition_0", global_index=4, fields=fields),
+            SampleMeta(partition_id="partition_0", global_index=5, fields=fields),
+            SampleMeta(partition_id="partition_0", global_index=6, fields=fields),
+            SampleMeta(partition_id="partition_0", global_index=7, fields=fields),
         ]
         batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"})

         # Select samples at indices [0, 2]
-        selected_batch = batch.select_samples([0, 2])
+        selected_batch = batch.select_samples([0, 2])  # Selects positions 0 and 2, i.e. global_index=4 and 6

         # Check number of samples
         assert len(selected_batch) == 2
         # Check global indexes
-        assert selected_batch.global_indexes == [0, 2]
+        assert selected_batch.global_indexes == [4, 6]
         # Check fields are preserved
         for sample in selected_batch.samples:
             assert "field1" in sample.fields
@@ -676,9 +676,9 @@ def test_batch_meta_select_samples_all_indices(self):
             )
         }
         samples = [
-            SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
-            SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
-            SampleMeta(partition_id="partition_0", global_index=2, fields=fields),
+            SampleMeta(partition_id="partition_0", global_index=4, fields=fields),
+            SampleMeta(partition_id="partition_0", global_index=5, fields=fields),
+            SampleMeta(partition_id="partition_0", global_index=6, fields=fields),
         ]
         batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"})

@@ -687,7 +687,7 @@ def test_batch_meta_select_samples_all_indices(self):

         # All samples are selected
         assert len(selected_batch) == 3
-        assert selected_batch.global_indexes == [0, 1, 2]
+        assert selected_batch.global_indexes == [4, 5, 6]

         # Extra info is preserved
         assert selected_batch.extra_info["test_key"] == "test_value"

diff --git a/transfer_queue/client.py b/transfer_queue/client.py
index 3948153..475b229 100644
--- a/transfer_queue/client.py
+++ b/transfer_queue/client.py
@@ -329,6 +329,9 @@ async def async_put(
         If metadata is not provided, it will be created automatically using insert mode
         with the provided data fields and partition_id.

+        During put, the custom_meta in metadata will update the corresponding custom_meta in
+        the TransferQueue Controller.
+
         Note:
             When using multiple workers for distributed execution, there may be data
             ordering inconsistencies between workers during put operations.
@@ -374,8 +377,7 @@ async def async_put(
             >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
             >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
             >>> # This will create metadata in "insert" mode internally.
-            >>> asyncio.run(client.async_put(data=prompts_repeated_batch, partition_id=current_partition_id))
-
+            >>> metadata = asyncio.run(client.async_put(data=prompts_repeated_batch, partition_id=current_partition_id))

         """

@@ -919,16 +921,60 @@ def wrapper(*args, **kwargs):
     def put(
         self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None
     ) -> BatchMeta:
-        """Synchronously write data to storage units.
+        """Synchronously write data to storage units based on metadata.
+
+        If metadata is not provided, it will be created automatically using insert mode
+        with the provided data fields and partition_id.
+
+        During put, the custom_meta in metadata will update the corresponding custom_meta in
+        the TransferQueue Controller.
+
+        Note:
+            When using multiple workers for distributed execution, there may be data
+            ordering inconsistencies between workers during put operations.

         Args:
             data: Data to write as TensorDict
-            metadata: Optional metadata containing index and storage unit information
+            metadata: Records the metadata of a batch of data samples, containing index and
+                storage unit information. If None, metadata will be auto-generated.
             partition_id: Target data partition id (required if metadata is not provided)

         Returns:
             BatchMeta: The metadata used for the put operation (currently returns the input metadata
                 or auto-retrieved metadata; will be updated in a future version to reflect the post-put state)
+
+        Raises:
+            ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
+            RuntimeError: If storage operation fails
+
+        Example:
+            >>> batch_size = 4
+            >>> seq_len = 16
+            >>> current_partition_id = "train_0"
+            >>> # Example 1: Normal usage with existing metadata
+            >>> batch_meta = client.get_meta(
+            ...     data_fields=["prompts", "attention_mask"],
+            ...     batch_size=batch_size,
+            ...     partition_id=current_partition_id,
+            ...     mode="fetch",
+            ...     task_name="generate_sequences",
+            ... )
+            >>> batch = client.get_data(batch_meta)
+            >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
+            >>> client.put(data=output, metadata=batch_meta)
+            >>>
+            >>> # Example 2: Initial data insertion without pre-existing metadata
+            >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
+            >>> # Please make sure the corresponding partition_id is empty before calling put()
+            >>> # without metadata.
+            >>> # Currently we only support putting all data for the corresponding partition_id at once. You should
+            >>> # repeat-interleave the initial data if n_samples > 1 before calling put().
+            >>> original_prompts = torch.randn(batch_size, seq_len)
+            >>> n_samples = 4
+            >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
+            >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
+            >>> # This will create metadata in "insert" mode internally.
+ >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id) """ return self._put(data=data, metadata=metadata, partition_id=partition_id) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index e0239ba..c104219 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1190,12 +1190,8 @@ def get_metadata( ) elif mode == "force_fetch": - if partition_id is not None: - batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id) - consumed_indexes = [] - else: - batch_global_indexes = list(sorted(self.index_manager.allocated_indexes)) - consumed_indexes = [] + batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id) + consumed_indexes = [] # Package into metadata metadata = self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 19b0e88..5e56ebf 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -391,27 +391,31 @@ def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "Ba object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples)) return self - def select_samples(self, global_indices: list[int]) -> "BatchMeta": + def select_samples(self, indexes: list[int]) -> "BatchMeta": """ Select specific samples from this batch. This will construct a new BatchMeta instance containing only the specified samples. Args: - global_indices (list[int]): List of sample indices to retain. It used the relative - + indexes (list[int]): List of indexes (relative to this batch, not global_indexes) + to retain. Returns: BatchMeta: A new BatchMeta instance containing only the specified samples. 
""" - if any(i not in self.global_indexes for i in global_indices): - raise ValueError("selected global_indices do not exist in this batch!)") + if any(i < 0 or i >= len(self.samples) for i in indexes): + raise ValueError(f"Sample indices must be in range [0, {len(self.samples)})") - selected_samples = [self.samples[i] for i in global_indices] + selected_samples = [self.samples[i] for i in indexes] - selected_custom_meta = {i: self.custom_meta[i] for i in global_indices if i in self.custom_meta} + selected_custom_meta = { + i: self.custom_meta[self.global_indexes[i]] for i in indexes if self.global_indexes[i] in self.custom_meta + } selected_custom_backend_meta = { - i: self._custom_backend_meta[i] for i in global_indices if i in self._custom_backend_meta + i: self._custom_backend_meta[self.global_indexes[i]] + for i in indexes + if self.global_indexes[i] in self._custom_backend_meta } # construct new BatchMeta instance @@ -465,11 +469,12 @@ def __len__(self) -> int: def __getitem__(self, item): if isinstance(item, int | np.integer): sample_meta = self.samples[item] if self.samples else [] + global_idx = self.global_indexes[item] return BatchMeta( samples=[sample_meta], extra_info=self.extra_info, - custom_meta=self.custom_meta, - _custom_backend_meta=self._custom_backend_meta, + custom_meta={global_idx: self.custom_meta[global_idx]}, + _custom_backend_meta={global_idx: self._custom_backend_meta[global_idx]}, ) else: raise TypeError(f"Indexing with {type(item)} is not supported now!") @@ -532,7 +537,7 @@ def chunk_by_partition( for partition_id, global_index in zip(self.partition_ids, self.global_indexes, strict=False): grouped_global_indexes[partition_id].append(global_index) - chunk_list = [self.select_samples(samples) for samples in grouped_global_indexes.values()] + chunk_list = [self.select_samples(global_indices) for global_indices in grouped_global_indexes.values()] return chunk_list @@ -665,7 +670,7 @@ def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMet # Merge custom_backend_meta dictionaries merged_custom_backend_meta = {} for idx in self.global_indexes: - if idx in self._custom_backend_meta and other._custom_backend_meta[idx]: + if idx in self._custom_backend_meta and idx in other._custom_backend_meta: merged_custom_backend_meta[idx] = {**self._custom_backend_meta[idx], **other._custom_backend_meta[idx]} elif idx in self._custom_backend_meta: merged_custom_backend_meta[idx] = {**self._custom_backend_meta[idx]} @@ -679,38 +684,38 @@ def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMet _custom_backend_meta=merged_custom_backend_meta, ) - def reorder(self, indices: list[int]): + def reorder(self, indexes: list[int]): """ - Reorder the SampleMeta in the BatchMeta according to the given indices (must equal to the length of samples). + Reorder the SampleMeta in the BatchMeta according to the given indexes (must equal to the length of samples). The operation is performed in-place, modifying the current BatchMeta's SampleMeta order. To select a subset of samples or repeat specific samples, please use the non-inplace method select_samples(). Args: - indices : list[int] + indexes : list[int] A list of integers specifying the new order of SampleMeta. Each integer represents the current index of the SampleMeta in the BatchMeta. 
""" - if len(indices) != self.size: + if len(indexes) != self.size: raise ValueError( - f"Attempted to reorder with indices length {len(indices)} that does not match samples length " + f"Attempted to reorder with indexes length {len(indexes)} that does not match samples length " f"{self.size}. Please use non-inplace method select_samples() instead if you want to " f"select a subset of samples or repeat specific samples." ) - if len(set(indices)) != self.size: + if len(set(indexes)) != self.size: raise ValueError( - f"Indices={indices} contain duplicates. Please use non-inplace method " + f"Indexes={indexes} contain duplicates. Please use non-inplace method " f"select_samples() instead if you want to select a subset of samples or repeat specific samples." ) - if any(i < 0 or i >= len(self.samples) for i in indices): - raise ValueError(f"Reorder indices must be in the range [0, {self.size}).") + if any(i < 0 or i >= len(self.samples) for i in indexes): + raise ValueError(f"Reorder indexes must be in the range [0, {self.size}).") # Reorder the samples - reordered_samples = [self.samples[i] for i in indices] + reordered_samples = [self.samples[i] for i in indexes] object.__setattr__(self, "samples", reordered_samples) # Update necessary attributes diff --git a/tutorial/02_metadata_concepts.py b/tutorial/02_metadata_concepts.py index bbeec54..93e59ac 100644 --- a/tutorial/02_metadata_concepts.py +++ b/tutorial/02_metadata_concepts.py @@ -378,7 +378,7 @@ def demonstrate_real_workflow(): print(f" Sample structure: {batch_meta.samples[0]}") print(f" Custom Meta: {batch_meta.get_all_custom_meta()}") - print("[Step 3] Retrieve samples with specific fields..") + print("[Step 4] Retrieve samples with specific fields..") selected_meta = batch_meta.select_fields(["input_ids"]) print("✓ Selected 'input_ids' field only:") print(f" New field names: {selected_meta.field_names}") @@ -386,7 +386,7 @@ def demonstrate_real_workflow(): retrieved_data = client.get_data(selected_meta) print(f" Retrieved data keys: {list(retrieved_data.keys())}") - print("[Step 4] Select specific samples from the retrieved BatchMeta...") + print("[Step 5] Select specific samples from the retrieved BatchMeta...") partial_meta = batch_meta.select_samples([0, 2, 4, 6]) print("✓ Selected samples at indices [0, 2, 4, 6]:") print(f" New global indexes: {partial_meta.global_indexes}") @@ -394,7 +394,7 @@ def demonstrate_real_workflow(): retrieved_data = client.get_data(partial_meta) print(f" Retrieved data samples: {retrieved_data}, all the data samples: {data_batch}") - print("[Step 5] Demonstrate chunk operation...") + print("[Step 6] Demonstrate chunk operation...") chunks = batch_meta.chunk(2) print(f"✓ Chunked into {len(chunks)} parts:") for i, chunk in enumerate(chunks): From f422aeb0fa789c41f7fd9330847f486ddbbf6957 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 2 Feb 2026 15:04:25 +0800 Subject: [PATCH 19/23] fix Signed-off-by: 0oshowero0 --- tests/test_metadata.py | 34 +++++++++++++++++----------------- transfer_queue/metadata.py | 32 ++++++++++++++++++++------------ 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 01ad823..ef34e61 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -215,11 +215,11 @@ def test_batch_meta_chunk_by_partition(self): name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME ) } - samples = [SampleMeta(partition_id=f"partition_{i % 4}", global_index=i, fields=fields) 
for i in range(10)] + samples = [SampleMeta(partition_id=f"partition_{i % 4}", global_index=i + 10, fields=fields) for i in range(10)] batch = BatchMeta( samples=samples, - custom_meta={i: {"uid": i} for i in range(10)}, - _custom_backend_meta={i: {"test_field": {"dtype": torch.float32}} for i in range(10)}, + custom_meta={i + 10: {"uid": i + 10} for i in range(10)}, + _custom_backend_meta={i + 10: {"test_field": {"dtype": torch.float32}} for i in range(10)}, ) # Chunk according to partition_id @@ -228,30 +228,30 @@ def test_batch_meta_chunk_by_partition(self): assert len(chunks) == 4 assert len(chunks[0]) == 3 assert chunks[0].partition_ids == ["partition_0", "partition_0", "partition_0"] - assert chunks[0].global_indexes == [0, 4, 8] + assert chunks[0].global_indexes == [10, 14, 18] assert len(chunks[1]) == 3 assert chunks[1].partition_ids == ["partition_1", "partition_1", "partition_1"] - assert chunks[1].global_indexes == [1, 5, 9] + assert chunks[1].global_indexes == [11, 15, 19] assert len(chunks[2]) == 2 assert chunks[2].partition_ids == ["partition_2", "partition_2"] - assert chunks[2].global_indexes == [2, 6] + assert chunks[2].global_indexes == [12, 16] assert len(chunks[3]) == 2 assert chunks[3].partition_ids == ["partition_3", "partition_3"] - assert chunks[3].global_indexes == [3, 7] + assert chunks[3].global_indexes == [13, 17] # validate _custom_backend_meta is chunked - assert 0 in chunks[0].custom_meta - assert 4 in chunks[0].custom_meta - assert 8 in chunks[0].custom_meta - assert 1 not in chunks[0].custom_meta - assert 1 in chunks[1].custom_meta + assert 10 in chunks[0].custom_meta + assert 14 in chunks[0].custom_meta + assert 18 in chunks[0].custom_meta + assert 11 not in chunks[0].custom_meta + assert 11 in chunks[1].custom_meta # validate _custom_backend_meta is chunked - assert 0 in chunks[0]._custom_backend_meta - assert 4 in chunks[0]._custom_backend_meta - assert 8 in chunks[0]._custom_backend_meta - assert 1 not in chunks[0]._custom_backend_meta - assert 1 in chunks[1]._custom_backend_meta + assert 10 in chunks[0]._custom_backend_meta + assert 14 in chunks[0]._custom_backend_meta + assert 18 in chunks[0]._custom_backend_meta + assert 11 not in chunks[0]._custom_backend_meta + assert 11 in chunks[1]._custom_backend_meta def test_batch_meta_init_validation_error_different_field_names(self): """Example: Init validation catches samples with different field names.""" diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 5e56ebf..b54a125 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -409,13 +409,10 @@ def select_samples(self, indexes: list[int]) -> "BatchMeta": selected_samples = [self.samples[i] for i in indexes] - selected_custom_meta = { - i: self.custom_meta[self.global_indexes[i]] for i in indexes if self.global_indexes[i] in self.custom_meta - } + global_indexes = [self.global_indexes[i] for i in indexes] + selected_custom_meta = {i: self.custom_meta[i] for i in global_indexes if i in self.custom_meta} selected_custom_backend_meta = { - i: self._custom_backend_meta[self.global_indexes[i]] - for i in indexes - if self.global_indexes[i] in self._custom_backend_meta + i: self._custom_backend_meta[i] for i in global_indexes if i in self._custom_backend_meta } # construct new BatchMeta instance @@ -470,11 +467,22 @@ def __getitem__(self, item): if isinstance(item, int | np.integer): sample_meta = self.samples[item] if self.samples else [] global_idx = self.global_indexes[item] + + if global_idx in self.custom_meta: + 
custom_meta = {global_idx: self.custom_meta[global_idx]} + else: + custom_meta = {} + + if global_idx in self._custom_backend_meta: + custom_backend_meta = {global_idx: self._custom_backend_meta[global_idx]} + else: + custom_backend_meta = {} + return BatchMeta( samples=[sample_meta], extra_info=self.extra_info, - custom_meta={global_idx: self.custom_meta[global_idx]}, - _custom_backend_meta={global_idx: self._custom_backend_meta[global_idx]}, + custom_meta=custom_meta, + _custom_backend_meta=custom_backend_meta, ) else: raise TypeError(f"Indexing with {type(item)} is not supported now!") @@ -533,11 +541,11 @@ def chunk_by_partition( List of smaller BatchMeta chunks, each chunk has samples with identical partition_id """ - grouped_global_indexes = defaultdict(list) - for partition_id, global_index in zip(self.partition_ids, self.global_indexes, strict=False): - grouped_global_indexes[partition_id].append(global_index) + grouped_indexes = defaultdict(list) + for partition_id, indexes in zip(self.partition_ids, range(self.size), strict=False): + grouped_indexes[partition_id].append(indexes) - chunk_list = [self.select_samples(global_indices) for global_indices in grouped_global_indexes.values()] + chunk_list = [self.select_samples(idx) for idx in grouped_indexes.values()] return chunk_list From abf2112b485a70b3fa1e86a7cb9d82c9ac94b5ef Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 2 Feb 2026 16:46:15 +0800 Subject: [PATCH 20/23] minor update Signed-off-by: 0oshowero0 --- tests/test_metadata.py | 14 +++++++------- transfer_queue/client.py | 2 +- transfer_queue/controller.py | 8 ++++++-- transfer_queue/metadata.py | 6 ++---- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index ef34e61..98c4b57 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -239,7 +239,7 @@ def test_batch_meta_chunk_by_partition(self): assert chunks[3].partition_ids == ["partition_3", "partition_3"] assert chunks[3].global_indexes == [13, 17] - # validate _custom_backend_meta is chunked + # validate custom_meta is chunked assert 10 in chunks[0].custom_meta assert 14 in chunks[0].custom_meta assert 18 in chunks[0].custom_meta @@ -406,22 +406,22 @@ def test_batch_meta_union(self): batch1 = BatchMeta( samples=[ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields1), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields1), + SampleMeta(partition_id="partition_0", global_index=8, fields=fields1), + SampleMeta(partition_id="partition_0", global_index=9, fields=fields1), ], _custom_backend_meta={ - i: {"field1": {"dtype": torch.float32}, "field2": {"dtype": torch.int64}} for i in [0, 1] + i: {"field1": {"dtype": torch.float32}, "field2": {"dtype": torch.int64}} for i in [8, 9] }, ) batch1.extra_info["info1"] = "value1" batch2 = BatchMeta( samples=[ - SampleMeta(partition_id="partition_0", global_index=0, fields=fields2), - SampleMeta(partition_id="partition_0", global_index=1, fields=fields2), + SampleMeta(partition_id="partition_0", global_index=8, fields=fields2), + SampleMeta(partition_id="partition_0", global_index=9, fields=fields2), ], _custom_backend_meta={ - i: {"field2": {"dtype": torch.int64}, "field3": {"dtype": torch.bool}} for i in [0, 1] + i: {"field2": {"dtype": torch.int64}, "field3": {"dtype": torch.bool}} for i in [8, 9] }, ) batch2.extra_info["info2"] = "value2" diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 475b229..1bfc9ff 100644 --- a/transfer_queue/client.py +++ 
b/transfer_queue/client.py
@@ -314,7 +314,7 @@ async def async_set_custom_meta(

         if response_msg.request_type != ZMQRequestType.SET_CUSTOM_META_RESPONSE:
             raise RuntimeError(
-                f"[{self.client_id}]: Failed to set custom metadata from controller {self._controller.id}: "
+                f"[{self.client_id}]: Failed to set custom metadata to controller {self._controller.id}: "
                 f"{response_msg.body.get('message', 'Unknown error')}"
             )

diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index c104219..0049c4e 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -703,8 +703,7 @@ def set_custom_meta(self, custom_meta: dict[int, dict]) -> None:
             Existing entries will be overwritten.
         """

-        for k in custom_meta.keys():
-            self.custom_meta[k] = custom_meta[k]
+        self.custom_meta.update(custom_meta)

     # ==================== Statistics and Monitoring ====================

@@ -1076,6 +1075,11 @@ def set_custom_meta(self, partition_custom_meta: dict[str, dict[int, dict]]) ->
             partition = self._get_partition(partition_id)
             if partition:
                 partition.set_custom_meta(custom_meta)
+            else:
+                logger.warning(
+                    f"set_custom_meta: partition {partition_id} not found; "
+                    f"custom_metadata for this partition will be ignored"
+                )

     def get_metadata(
         self,
diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py
index b54a125..5e25404 100644
--- a/transfer_queue/metadata.py
+++ b/transfer_queue/metadata.py
@@ -215,8 +215,6 @@ def __post_init__(self):
         """Initialize all computed properties during initialization"""
         self.samples = copy.deepcopy(self.samples)
         self.extra_info = copy.deepcopy(self.extra_info)
-        self.custom_meta = copy.deepcopy(self.custom_meta)
-        self._custom_backend_meta = copy.deepcopy(self._custom_backend_meta)

         # Basic properties
         object.__setattr__(self, "_size", len(self.samples))
@@ -352,7 +350,7 @@ def update_custom_meta(self, new_meta: dict[int, dict[str, Any]]):
             return

         non_exist_global_indexes = set(new_meta.keys()) - set(self.global_indexes)
-        if bool(non_exist_global_indexes):
+        if non_exist_global_indexes:
             raise ValueError(
                 f"Trying to update custom_meta with non-exist global_indexes! {non_exist_global_indexes} "
                 f"do not exist in this batch."
@@ -683,7 +681,7 @@ def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMet elif idx in self._custom_backend_meta: merged_custom_backend_meta[idx] = {**self._custom_backend_meta[idx]} elif idx in other._custom_backend_meta: - merged_custom_backend_meta[idx] = other._custom_backend_meta[idx] + merged_custom_backend_meta[idx] = {**other._custom_backend_meta[idx]} return BatchMeta( samples=merged_samples, From c2d75ce38573644f4ae28f2f4ef24639b39f54c4 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 2 Feb 2026 16:54:35 +0800 Subject: [PATCH 21/23] fix Signed-off-by: 0oshowero0 --- tests/test_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 98c4b57..2bbf40c 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -441,7 +441,7 @@ def test_batch_meta_union(self): # _custom_backend_meta is merged assert result._custom_backend_meta == { i: {"field1": {"dtype": torch.float32}, "field2": {"dtype": torch.int64}, "field3": {"dtype": torch.bool}} - for i in [0, 1] + for i in [8, 9] } def test_batch_meta_union_validation(self): From c5ea293275523e6da660c485002de69f5f543712 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 3 Feb 2026 09:22:09 +0800 Subject: [PATCH 22/23] update Signed-off-by: 0oshowero0 --- transfer_queue/dataloader/streaming_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transfer_queue/dataloader/streaming_dataset.py b/transfer_queue/dataloader/streaming_dataset.py index 4b54fb3..eb2afe4 100644 --- a/transfer_queue/dataloader/streaming_dataset.py +++ b/transfer_queue/dataloader/streaming_dataset.py @@ -121,8 +121,8 @@ def __init__( self.partition_id = partition_id self.task_name = task_name self.dp_rank = dp_rank - self.get_batch_func = fetch_batch_fn if fetch_batch_fn else default_fetch_batch_fn - self.post_process_for_micro_func = process_batch_fn if process_batch_fn else chunk_batch_fn + self.fetch_batch_fn = fetch_batch_fn if fetch_batch_fn else default_fetch_batch_fn + self.process_batch_fn = process_batch_fn if process_batch_fn else chunk_batch_fn # Build sampling config for controller self.sampling_config = { @@ -195,10 +195,10 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: if self.batch_index <= len(self.buffer) - 1: current_data = self.buffer[self.batch_index] self.batch_index += 1 - yield from self.post_process_for_micro_func(*current_data, micro_batch_size=self.micro_batch_size) + yield from self.process_batch_fn(*current_data, micro_batch_size=self.micro_batch_size) else: - batch_data, batch_meta = self.get_batch_func( + batch_data, batch_meta = self.fetch_batch_fn( self._tq_client, self.data_fields, self.batch_size, From 45394fd92243bbd32d1e85eaab46f49b27f926cb Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 3 Feb 2026 09:40:40 +0800 Subject: [PATCH 23/23] fix Signed-off-by: 0oshowero0 --- tutorial/05_streaming_dataloader.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/tutorial/05_streaming_dataloader.py b/tutorial/05_streaming_dataloader.py index 96ac33f..916be00 100644 --- a/tutorial/05_streaming_dataloader.py +++ b/tutorial/05_streaming_dataloader.py @@ -22,11 +22,11 @@ Key Components: - StreamingDataset: PyTorch IterableDataset that integrates with TransferQueue - StreamingDataLoader: DataLoader wrapper that yields (batch, batch_meta) tuples -- RankAwareSampler: Enables data replica group coordination for consistent +- RankAwareSampler: 
Enables DP group coordination for consistent sampling across multiple ranks
 
 Use Cases:
-- Distributed training with multiple data replica groups
+- Distributed training with multiple DP groups
 - Fine-grained micro-batch-level data retrieval
 """
 
@@ -94,7 +94,7 @@ def setup_transfer_queue():
     print("[Setup]: Setup TransferQueue components")
     print(
         "Note: Using RankAwareSampler when each rank retrieves data independently. It guarantees that "
-        "all ranks within the same data replica group receive the same sample indices."
+        "ranks with the same DP rank receive the same sample indices."
     )
     print(
         "Note: When using streaming data retrieval, please set polling_mode=True when initializing "
         "available data cannot meet the consumption requirements. User side need to retry later."
     )
     controller = TransferQueueController.remote(
-        sampler=RankAwareSampler,  # RankAwareSampler enables consistent sampling across ranks in same replica group
+        sampler=RankAwareSampler,  # RankAwareSampler enables consistent sampling for each DP rank
         polling_mode=True,  # Enable polling mode for streaming data retrieval
     )
 
@@ -186,7 +186,7 @@ def update_worker(
         max_steps: Maximum number of batches to consume
 
     Returns:
-        dict: Contains data_replica_rank, data_replica_group, and consumed_ids
+        dict: Contains dp_rank and consumed_ids
 
     Example:
         For a setup with 2 data rank (0 and 1):
@@ -317,8 +317,8 @@ def main():
     Key Concepts:
     - StreamingDataset: PyTorch IterableDataset that integrates with TransferQueue
     - StreamingDataLoader: DataLoader wrapper yielding (batch, batch_meta) tuples
-    - RankAwareSampler: Enables correct data consumption across data replica ranks
-    - Data Replica Group: Ranks that should receive identical data samples (TP, PP, ...)
+    - RankAwareSampler: Enables correct data consumption across DP ranks
+    - DP Rank: Ranks sharing the same DP rank should receive identical data samples
     """
         )
     )
@@ -358,10 +358,7 @@ def main():
     print("Results Summary")
     print("=" * 80)
     for result in update_results:
-        print(
-            f"  Rank {result['data_replica_rank']} (Group {result['data_replica_group']}): "
-            f"consumed {len(result['consumed_ids'])} samples"
-        )
+        print(f"  DP Rank {result['dp_rank']}: consumed {len(result['consumed_ids'])} samples")
 
     print("\n" + "=" * 80)
     print("Tutorial Complete!")
@@ -369,7 +366,7 @@ def main():
     print("Key Takeaways:")
     print("1. StreamingDataset provides PyTorch IterableDataset interface for TransferQueue")
    print("2. StreamingDataLoader wraps the dataset and yields (batch, batch_meta) tuples")
-    print("3. Ranks in the same data_replica_group receive identical samples")
+    print("3. Ranks with the same DP rank receive identical samples")
    print("4. The system enables efficient streaming capabilities")
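
For orientation, a minimal consumption sketch of the streaming interface touched by the last two patches. The import path and keyword arguments are assumptions inferred from the attributes set in StreamingDataset.__init__ and used in __iter__ above, not a verified constructor signature; the field and partition names are illustrative only.

    # Minimal sketch, assuming StreamingDataset accepts the attributes shown in the
    # streaming_dataset.py diff as constructor keyword arguments (illustrative only).
    from transfer_queue.dataloader.streaming_dataset import StreamingDataset

    dataset = StreamingDataset(
        data_fields=["input_ids", "attention_mask"],  # fields fetched per sample
        batch_size=8,               # samples requested per fetch_batch_fn call
        micro_batch_size=2,         # chunk size handed to process_batch_fn
        partition_id="train_0",
        task_name="generate_sequences",
        dp_rank=0,                  # ranks sharing a DP rank receive identical samples
        # fetch_batch_fn / process_batch_fn fall back to
        # default_fetch_batch_fn / chunk_batch_fn when omitted
    )

    for micro_batch, micro_batch_meta in dataset:  # yields (TensorDict, BatchMeta) tuples
        ...  # consume the micro-batch in the training step

Under these assumptions, the fetch_batch_fn/process_batch_fn split lets the dataset pull one full batch per controller round-trip and then chunk it locally into micro_batch_size pieces, which is what the buffering logic in __iter__ implements.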