diff --git a/tests/test_yuanrong_storage_manager.py b/tests/test_yuanrong_client_zero_copy.py similarity index 84% rename from tests/test_yuanrong_storage_manager.py rename to tests/test_yuanrong_client_zero_copy.py index 9433d6a..3048ec5 100644 --- a/tests/test_yuanrong_storage_manager.py +++ b/tests/test_yuanrong_client_zero_copy.py @@ -25,7 +25,7 @@ sys.path.append(str(parent_dir)) from transfer_queue.storage.clients.yuanrong_client import ( # noqa: E402 - YuanrongStorageClient, + GeneralKVClientAdapter, ) @@ -37,7 +37,7 @@ def MutableData(self): return self.data -class TestYuanrongStorageZCopy: +class TestYuanrongKVClientZCopy: @pytest.fixture def mock_kv_client(self, mocker): mock_client = MagicMock() @@ -45,13 +45,12 @@ def mock_kv_client(self, mocker): mocker.patch("yr.datasystem.KVClient", return_value=mock_client) mocker.patch("yr.datasystem.DsTensorClient") - mocker.patch("transfer_queue.storage.clients.yuanrong_client.TORCH_NPU_IMPORTED", False) return mock_client @pytest.fixture def storage_client(self, mock_kv_client): - return YuanrongStorageClient({"host": "127.0.0.1", "port": 31501}) + return GeneralKVClientAdapter({"host": "127.0.0.1", "port": 31501}) def test_mset_mget_p2p(self, storage_client, mocker): # Mock serialization/deserialization @@ -80,13 +79,13 @@ def side_effect_mcreate(keys, sizes): stored_raw_buffers.append(b.MutableData()) return buffers - storage_client._cpu_ds_client.mcreate.side_effect = side_effect_mcreate - storage_client._cpu_ds_client.get_buffers.return_value = stored_raw_buffers + storage_client._ds_client.mcreate.side_effect = side_effect_mcreate + storage_client._ds_client.get_buffers.return_value = stored_raw_buffers - storage_client.mset_zcopy( + storage_client.mset_zero_copy( ["tensor_key", "string_key"], [torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32), "hello yuanrong"] ) - results = storage_client.mget_zcopy(["tensor_key", "string_key"]) + results = storage_client.mget_zero_copy(["tensor_key", "string_key"]) assert torch.allclose(results[0], torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)) assert results[1] == "hello yuanrong" diff --git a/tests/test_yuanrong_storage_client_e2e.py b/tests/test_yuanrong_storage_client_e2e.py new file mode 100644 index 0000000..519335f --- /dev/null +++ b/tests/test_yuanrong_storage_client_e2e.py @@ -0,0 +1,214 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from unittest import mock + +import pytest +import torch + +try: + import torch_npu # noqa: F401 +except ImportError: + pass + + +# --- Mock Backend Implementation --- +# In real scenarios, multiple DsTensorClients or KVClients share storage. +# Here, each mockClient is implemented with independent storage using a simple dictionary, +# and is only suitable for unit testing. 
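+# For reference, a minimal round trip through the MockKVClient defined below (illustrative only;
+# the real KVClient hands out zero-copy buffers backed by the datasystem, not plain bytearrays):
+#
+#     kv = MockKVClient("127.0.0.1", 31501)
+#     kv.init()
+#     buffers = kv.mcreate(["k"], [3])
+#     buffers[0].MutableData()[:] = b"abc"
+#     kv.mset_buffer(buffers)
+#     assert bytes(kv.get_buffers(["k"])[0]) == b"abc"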
+ + +class MockDsTensorClient: + def __init__(self, host, port, device_id): + self.storage = {} + + def init(self): + pass + + def dev_mset(self, keys, values): + for k, v in zip(keys, values, strict=True): + assert v.device.type == "npu" + self.storage[k] = v + + def dev_mget(self, keys, out_tensors): + for i, k in enumerate(keys): + # Note: If key is missing, tensor remains unchanged (mock limitation) + if k in self.storage: + out_tensors[i].copy_(self.storage[k]) + + def dev_delete(self, keys): + for k in keys: + self.storage.pop(k, None) + + +class MockKVClient: + def __init__(self, host, port): + self.storage = {} + + def init(self): + pass + + def mcreate(self, keys, sizes): + class MockBuffer: + def __init__(self, size): + self._data = bytearray(size) + + def MutableData(self): + return memoryview(self._data) + + self._current_keys = keys + return [MockBuffer(s) for s in sizes] + + def mset_buffer(self, buffers): + for key, buf in zip(self._current_keys, buffers, strict=True): + self.storage[key] = bytes(buf.MutableData()) + + def get_buffers(self, keys): + return [memoryview(self.storage[k]) if k in self.storage else None for k in keys] + + def delete(self, keys): + for k in keys: + self.storage.pop(k, None) + + +# --- Fixtures --- + + +@pytest.fixture +def mock_yr_datasystem(): + """Wipe real 'yr' modules and inject mocks.""" + + # 1. Clean up sys.modules to force a fresh import under mock conditions + # This ensures top-level code in yuanrong_client.py is re-executed + to_delete = [k for k in sys.modules if k.startswith("yr")] + for mod in to_delete: + del sys.modules[mod] + + # 2. Setup Mock Objects + ds_mock = mock.MagicMock() + ds_mock.DsTensorClient = MockDsTensorClient + ds_mock.KVClient = MockKVClient + + yr_mock = mock.MagicMock(datasystem=ds_mock) + + # 3. 
Apply patches + # - sys.modules: Redirects 'import yr' to our mocks + # - YUANRONG_DATASYSTEM_IMPORTED: Forces the existence check to True so initialize the client successfully + # - datasystem: Direct attribute patch for the module + with ( + mock.patch.dict("sys.modules", {"yr": yr_mock, "yr.datasystem": ds_mock}), + mock.patch("transfer_queue.storage.clients.yuanrong_client.YUANRONG_DATASYSTEM_IMPORTED", True, create=True), + mock.patch("transfer_queue.storage.clients.yuanrong_client.datasystem", ds_mock), + ): + yield + + +@pytest.fixture +def config(): + return {"host": "127.0.0.1", "port": 12345, "enable_yr_npu_optimization": True} + + +def assert_tensors_equal(a: torch.Tensor, b: torch.Tensor): + assert a.shape == b.shape and a.dtype == b.dtype + # Move to CPU for cross-device comparison + assert torch.equal(a.cpu(), b.cpu()) + + +# --- Test Suite --- + + +class TestYuanrongStorageE2E: + @pytest.fixture(autouse=True) + def setup_client(self, mock_yr_datasystem, config): + # Lazy import to ensure mocks are active + from transfer_queue.storage.clients.yuanrong_client import YuanrongStorageClient + + self.client_cls = YuanrongStorageClient + self.config = config + + def _create_data(self, mode="cpu"): + if mode == "cpu": + keys = ["t", "s", "i"] + vals = [torch.randn(2), "hi", 1] + elif mode == "npu": + if not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU required") + keys = ["n1", "n2"] + vals = [torch.randn(2).npu(), torch.tensor([1]).npu()] + else: # mixed + if not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU required") + keys = ["n1", "c1"] + vals = [torch.randn(2).npu(), "cpu"] + + shapes = [list(v.shape) if isinstance(v, torch.Tensor) else [] for v in vals] + dtypes = [v.dtype if isinstance(v, torch.Tensor) else None for v in vals] + return keys, vals, shapes, dtypes + + def test_mock_can_work(self, config): + mock_class = (MockDsTensorClient, MockKVClient) + client = self.client_cls(config) + for strategy in client._strategies: + assert isinstance(strategy._ds_client, mock_class) + + def test_cpu_only_flow(self, config): + client = self.client_cls(config) + keys, vals, shp, dt = self._create_data("cpu") + + # Put & Verify Meta + meta = client.put(keys, vals) + # "2" is a tag added by YuanrongStorageClient, indicating that it is processed via General KV path. + assert all(m == "2" for m in meta) + + # Get & Verify Values + ret = client.get(keys, shp, dt, meta) + for o, r in zip(vals, ret, strict=True): + if isinstance(o, torch.Tensor): + assert_tensors_equal(o, r) + else: + assert o == r + + # Clear & Verify + client.clear(keys, meta) + assert all(v is None for v in client.get(keys, shp, dt, meta)) + + def test_npu_only_flow(self, config): + keys, vals, shp, dt = self._create_data("npu") + client = self.client_cls(config) + + meta = client.put(keys, vals) + # "1" is a tag added by YuanrongStorageClient, indicating that it is processed via NPU path. 
+ assert all(m == "1" for m in meta) + + ret = client.get(keys, shp, dt, meta) + for o, r in zip(vals, ret, strict=True): + assert_tensors_equal(o, r) + + client.clear(keys, meta) + + def test_mixed_flow(self, config): + keys, vals, shp, dt = self._create_data("mixed") + client = self.client_cls(config) + + meta = client.put(keys, vals) + assert set(meta) == {"1", "2"} + + ret = client.get(keys, shp, dt, meta) + for o, r in zip(vals, ret, strict=True): + if isinstance(o, torch.Tensor): + assert_tensors_equal(o, r) + else: + assert o == r diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py index 90c33fe..a6d63f6 100644 --- a/transfer_queue/storage/clients/base.py +++ b/transfer_queue/storage/clients/base.py @@ -65,6 +65,6 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non raise NotImplementedError("Subclasses must implement get") @abstractmethod - def clear(self, keys: list[str]) -> None: + def clear(self, keys: list[str], custom_backend_meta=None) -> None: """Clear key-value pairs in the storage backend.""" raise NotImplementedError("Subclasses must implement clear") diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 6b730ca..4b4d9a3 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -139,7 +139,7 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non keys (List[str]): Keys to fetch. shapes (List[List[int]]): Expected tensor shapes (use [] for scalars). dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data. - custom_backend_meta (List[str], optional): Device type (npu/cpu) for each key + custom_backend_meta (List[str], optional): ... Returns: List[Any]: Retrieved values in the same order as input keys. @@ -216,11 +216,12 @@ def _batch_get_bytes(self, keys: list[str]) -> list[bytes]: results.extend(batch_results) return results - def clear(self, keys: list[str]): + def clear(self, keys: list[str], custom_backend_meta=None): """Deletes multiple keys from MooncakeStore. Args: keys (List[str]): List of keys to remove. + custom_backend_meta (List[Any], optional): ... """ for key in keys: ret = self._store.remove(key) diff --git a/transfer_queue/storage/clients/ray_storage_client.py b/transfer_queue/storage/clients/ray_storage_client.py index 8bd4468..c290f6f 100644 --- a/transfer_queue/storage/clients/ray_storage_client.py +++ b/transfer_queue/storage/clients/ray_storage_client.py @@ -106,10 +106,11 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non raise RuntimeError(f"Failed to retrieve value for key '{keys}': {e}") from e return values - def clear(self, keys: list[str]): + def clear(self, keys: list[str], custom_backend_meta=None): """ Delete entries from storage by keys. Args: keys (list): List of keys to delete + custom_backend_meta (List[Any], optional): ... 
""" ray.get(self.storage_actor.clear_obj_ref.remote(keys)) diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index b408b93..41219c2 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -15,10 +15,10 @@ import logging import os -import pickle import struct +from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from typing import Any, Optional, TypeAlias +from typing import Any, Callable, Optional import torch from torch import Tensor @@ -27,146 +27,141 @@ from transfer_queue.storage.clients.factory import StorageClientFactory from transfer_queue.utils.serial_utils import _decoder, _encoder -bytestr: TypeAlias = bytes | bytearray | memoryview - logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) -NPU_DS_CLIENT_KEYS_LIMIT: int = 9999 -CPU_DS_CLIENT_KEYS_LIMIT: int = 1999 YUANRONG_DATASYSTEM_IMPORTED: bool = True -TORCH_NPU_IMPORTED: bool = True -DS_MAX_WORKERS: int = 16 + try: from yr import datasystem except ImportError: YUANRONG_DATASYSTEM_IMPORTED = False -# Header: number of entries (uint32, little-endian) -HEADER_FMT = " int: - """ - Calculate the total size (in bytes) required to pack a list of memoryview items - into the structured binary format used by pack_into. + @staticmethod + @abstractmethod + def init(config: dict) -> Optional["StorageStrategy"]: + """Initialize strategy from config; return None if not applicable.""" - Args: - items: List of memoryview objects to be packed. + @abstractmethod + def strategy_tag(self) -> Any: + """Return metadata identifying this strategy (e.g., string name, byte tag).""" - Returns: - Total buffer size in bytes. - """ - return HEADER_SIZE + len(items) * ENTRY_SIZE + sum(item.nbytes for item in items) + @abstractmethod + def supports_put(self, value: Any) -> bool: + """Check if this strategy can store the given value.""" + @abstractmethod + def put(self, keys: list[str], values: list[Any]): + """Store key-value pairs using this strategy.""" -def pack_into(target: memoryview, items: list[memoryview]): - """ - Pack multiple contiguous buffers into a single buffer. - ┌───────────────┐ - │ item_count │ uint32 - ├───────────────┤ - │ entries │ N * item entries - ├───────────────┤ - │ payload blob │ N * concatenated buffers - └───────────────┘ - - Args: - target (memoryview): A writable memoryview returned by StateValueBuffer.MutableData(). - It must be large enough to accommodate the total number of bytes of HEADER + ENTRY_TABLE + all items. - This buffer is usually mapped to shared memory or Zero-Copy memory area. - items (List[memoryview]): List of read-only memory views (e.g., from serialized objects). Each item must support - the buffer protocol and be readable as raw bytes. 
- - """ - struct.pack_into(HEADER_FMT, target, 0, len(items)) + @abstractmethod + def supports_get(self, strategy_tag: Any) -> bool: + """Check if this strategy can retrieve data with given tag.""" - entry_offset = HEADER_SIZE - payload_offset = HEADER_SIZE + len(items) * ENTRY_SIZE + @abstractmethod + def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]: + """Retrieve values by keys; kwargs may include shapes/dtypes.""" - target_tensor = torch.frombuffer(target, dtype=torch.uint8) + @abstractmethod + def supports_clear(self, strategy_tag: Any) -> bool: + """Check if this strategy owns data identified by metadata.""" - for item in items: - struct.pack_into(ENTRY_FMT, target, entry_offset, payload_offset, item.nbytes) - src_tensor = torch.frombuffer(item, dtype=torch.uint8) - target_tensor[payload_offset : payload_offset + item.nbytes].copy_(src_tensor) - entry_offset += ENTRY_SIZE - payload_offset += item.nbytes + @abstractmethod + def clear(self, keys: list[str]): + """Delete keys from storage.""" -def unpack_from(source: memoryview) -> list[bytestr]: - """ - Unpack multiple contiguous buffers from a single packed buffer. - Args: - source (memoryview): The packed source buffer. - Returns: - list[bytestr]: List of unpacked contiguous buffers. +class NPUTensorKVClientAdapter(StorageStrategy): + """Adapter for YuanRong's high-performance NPU tensor storage. + Using yr.datasystem.DsTensorClient to connect datasystem backends. """ - mv = memoryview(source) - item_count = struct.unpack_from(HEADER_FMT, mv, 0)[0] - offsets = [] - for i in range(item_count): - offset, length = struct.unpack_from(ENTRY_FMT, mv, HEADER_SIZE + i * ENTRY_SIZE) - offsets.append((offset, length)) - return [mv[offset : offset + length] for offset, length in offsets] + KEYS_LIMIT: int = 10_000 -@StorageClientFactory.register("YuanrongStorageClient") -class YuanrongStorageClient(TransferQueueStorageKVClient): - """ - Storage client for YuanRong DataSystem. + def __init__(self, config: dict): + host = config.get("host") + port = config.get("port") - Supports storing and fetching both: - - NPU tensors via DsTensorClient (for high performance). - - General objects (CPU tensors, str, bool, list, etc.) via KVClient with pickle serialization. - """ + self.device_id = torch.npu.current_device() + torch.npu.set_device(self.device_id) - def __init__(self, config: dict[str, Any]): - if not YUANRONG_DATASYSTEM_IMPORTED: - raise ImportError("YuanRong DataSystem not installed.") + self._ds_client = datasystem.DsTensorClient(host, port, self.device_id) + self._ds_client.init() + logger.info("YuanrongStorageClient: Create DsTensorClient to connect with yuanrong-datasystem backend!") - global TORCH_NPU_IMPORTED + @staticmethod + def init(config: dict) -> Optional["StorageStrategy"]: + """Initialize only if NPU and torch_npu are available.""" + torch_npu_imported: bool = True try: import torch_npu # noqa: F401 except ImportError: - TORCH_NPU_IMPORTED = False - - self.host = config.get("host") - self.port = config.get("port") - - self.device_id = None - self._npu_ds_client = None - self._cpu_ds_client = None - - if not TORCH_NPU_IMPORTED: - logger.warning( - "'torch_npu' import failed. " - "It results in the inability to quickly put/get tensors on the NPU side, which may affect performance." - ) - elif not torch.npu.is_available(): - logger.warning( - "NPU is not available. " - "It results in the inability to quickly put/get tensors on the NPU side, which may affect performance." 
- ) - else: - self.device_id = torch.npu.current_device() - self._npu_ds_client = datasystem.DsTensorClient(self.host, self.port, self.device_id) - self._npu_ds_client.init() - - self._cpu_ds_client = datasystem.KVClient(self.host, self.port) - self._cpu_ds_client.init() - - def npu_ds_client_is_available(self): - """Check if NPU client is available.""" - return self._npu_ds_client is not None + torch_npu_imported = False + enable = config.get("enable_yr_npu_transport", True) + if not (enable and torch_npu_imported and torch.npu.is_available()): + return None + + return NPUTensorKVClientAdapter(config) + + def strategy_tag(self) -> str: + """Strategy tag for NPU tensor storage. Using a single byte is for better performance.""" + return "1" + + def supports_put(self, value: Any) -> bool: + """Supports contiguous NPU tensors only.""" + if not (isinstance(value, torch.Tensor) and value.device.type == "npu"): + return False + # Only contiguous NPU tensors are supported by this adapter. + return value.is_contiguous() + + def put(self, keys: list[str], values: list[Any]): + """Store NPU tensors in batches; deletes before overwrite.""" + for i in range(0, len(keys), self.KEYS_LIMIT): + batch_keys = keys[i : i + self.KEYS_LIMIT] + batch_values = values[i : i + self.KEYS_LIMIT] + # _npu_ds_client.dev_mset doesn't support to overwrite + try: + self._ds_client.dev_delete(batch_keys) + except Exception: + pass + self._ds_client.dev_mset(batch_keys, batch_values) + + def supports_get(self, strategy_tag: str) -> bool: + """Matches 'DsTensorClient' Strategy tag.""" + return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag() + + def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]: + """Fetch NPU tensors using pre-allocated empty buffers.""" + shapes = kwargs.get("shapes", None) + dtypes = kwargs.get("dtypes", None) + if shapes is None or dtypes is None: + raise ValueError("YuanrongStorageClient needs Expected shapes and dtypes") + results = [] + for i in range(0, len(keys), self.KEYS_LIMIT): + batch_keys = keys[i : i + self.KEYS_LIMIT] + batch_shapes = shapes[i : i + self.KEYS_LIMIT] + batch_dtypes = dtypes[i : i + self.KEYS_LIMIT] + + batch_values = self._create_empty_npu_tensorlist(batch_shapes, batch_dtypes) + self._ds_client.dev_mget(batch_keys, batch_values) + # Todo(dpj): consider checking and logging keys that fail during dev_mget + results.extend(batch_values) + return results + + def supports_clear(self, strategy_tag: str) -> bool: + """Matches 'DsTensorClient' strategy tag.""" + return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag() - def cpu_ds_client_is_available(self): - """Check if CPU client is available.""" - return self._cpu_ds_client is not None + def clear(self, keys: list[str]): + """Delete NPU tensor keys in batches.""" + for i in range(0, len(keys), self.KEYS_LIMIT): + batch = keys[i : i + self.KEYS_LIMIT] + # Todo(dpj): Test call clear when no (key,value) put in ds + self._ds_client.dev_delete(batch) def _create_empty_npu_tensorlist(self, shapes, dtypes): """ @@ -179,28 +174,162 @@ def _create_empty_npu_tensorlist(self, shapes, dtypes): list: List of uninitialized NPU tensors """ tensors: list[Tensor] = [] - for shape, dtype in zip(shapes, dtypes, strict=False): + for shape, dtype in zip(shapes, dtypes, strict=True): tensor = torch.empty(shape, dtype=dtype, device=f"npu:{self.device_id}") tensors.append(tensor) return tensors - def mset_zcopy(self, keys: list[str], objs: list[Any]): + +class 
GeneralKVClientAdapter(StorageStrategy):
+    """Adapter for general-purpose KV storage with serialization.
+    Uses yr.datasystem.KVClient to connect to the datasystem backend.
+    Serialization uses '_decoder' and '_encoder' from 'transfer_queue.utils.serial_utils'.
+    """
+
+    PUT_KEYS_LIMIT: int = 2_000
+    GET_CLEAR_KEYS_LIMIT: int = 10_000
+
+    # Header: number of entries (uint32, little-endian)
+    HEADER_FMT = "<I"
+    HEADER_SIZE = struct.calcsize(HEADER_FMT)
+    # Entry: payload offset + payload length (uint32 each, little-endian)
+    ENTRY_FMT = "<II"
+    ENTRY_SIZE = struct.calcsize(ENTRY_FMT)
+
+    DS_MAX_WORKERS: int = 16
+
+    def __init__(self, config: dict):
+        host = config.get("host")
+        port = config.get("port")
+
+        self._ds_client = datasystem.KVClient(host, port)
+        self._ds_client.init()
+
+    @staticmethod
+    def init(config: dict) -> Optional["StorageStrategy"]:
+        """Always enabled for general objects."""
+        return GeneralKVClientAdapter(config)
+
+    def strategy_tag(self) -> str:
+        """Strategy tag for general KV storage. A single-character tag is used for performance."""
+        return "2"
+
+    def supports_put(self, value: Any) -> bool:
+        """Accepts any Python object."""
+        return True
+
+    def put(self, keys: list[str], values: list[Any]):
+        """Store objects via zero-copy serialization in batches."""
+        for i in range(0, len(keys), self.PUT_KEYS_LIMIT):
+            batch_keys = keys[i : i + self.PUT_KEYS_LIMIT]
+            batch_vals = values[i : i + self.PUT_KEYS_LIMIT]
+            self.mset_zero_copy(batch_keys, batch_vals)
+
+    def supports_get(self, strategy_tag: str) -> bool:
+        """Matches the 'KVClient' strategy tag."""
+        return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()
+
+    def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]:
+        """Retrieve and deserialize objects in batches."""
+        results = []
+        for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT):
+            batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
+            objects = self.mget_zero_copy(batch_keys)
+            results.extend(objects)
+        return results
+
+    def supports_clear(self, strategy_tag: str) -> bool:
+        """Matches the 'KVClient' strategy tag."""
+        return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()
+
+    def clear(self, keys: list[str]):
+        """Delete keys in batches."""
+        for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT):
+            batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
+            self._ds_client.delete(batch_keys)
+
+    @classmethod
+    def calc_packed_size(cls, items: list[memoryview]) -> int:
+        """
+        Calculate the total size (in bytes) required to pack a list of memoryview items
+        into the structured binary format used by pack_into.
+
+        Args:
+            items: List of memoryview objects to be packed.
+
+        Returns:
+            Total buffer size in bytes.
+        """
+        return cls.HEADER_SIZE + len(items) * cls.ENTRY_SIZE + sum(item.nbytes for item in items)
+
+    @classmethod
+    def pack_into(cls, target: memoryview, items: list[memoryview]):
+        """
+        Pack multiple contiguous buffers into a single buffer.
+        ┌───────────────┐
+        │ item_count    │  uint32
+        ├───────────────┤
+        │ entries       │  N * item entries
+        ├───────────────┤
+        │ payload blob  │  N * concatenated buffers
+        └───────────────┘
+
+        Args:
+            target (memoryview): A writable memoryview returned by StateValueBuffer.MutableData().
+                It must be large enough to accommodate HEADER + ENTRY_TABLE + all items.
+                This buffer is usually mapped to a shared-memory or zero-copy memory area.
+            items (List[memoryview]): List of read-only memory views (e.g., from serialized objects).
+                Each item must support the buffer protocol and be readable as raw bytes.
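+
+        Example (illustrative sketch; assumes writable source buffers and the "<I"/"<II"
+        header/entry formats defined on this class):
+            items = [memoryview(bytearray(b"ab")), memoryview(bytearray(b"c"))]
+            buf = bytearray(GeneralKVClientAdapter.calc_packed_size(items))
+            GeneralKVClientAdapter.pack_into(memoryview(buf), items)
+            parts = GeneralKVClientAdapter.unpack_from(memoryview(buf))
+            # [bytes(p) for p in parts] == [b"ab", b"c"]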
+ + """ + struct.pack_into(cls.HEADER_FMT, target, 0, len(items)) + + entry_offset = cls.HEADER_SIZE + payload_offset = cls.HEADER_SIZE + len(items) * cls.ENTRY_SIZE + + target_tensor = torch.frombuffer(target, dtype=torch.uint8) + + for item in items: + struct.pack_into(cls.ENTRY_FMT, target, entry_offset, payload_offset, item.nbytes) + src_tensor = torch.frombuffer(item, dtype=torch.uint8) + target_tensor[payload_offset : payload_offset + item.nbytes].copy_(src_tensor) + entry_offset += cls.ENTRY_SIZE + payload_offset += item.nbytes + + @classmethod + def unpack_from(cls, source: memoryview) -> list[memoryview]: + """ + Unpack multiple contiguous buffers from a single packed buffer. + Args: + source (memoryview): The packed source buffer. + Returns: + list[memoryview]: List of unpacked contiguous buffers. + """ + mv = memoryview(source) + item_count = struct.unpack_from(cls.HEADER_FMT, mv, 0)[0] + offsets = [] + for i in range(item_count): + offset, length = struct.unpack_from(cls.ENTRY_FMT, mv, cls.HEADER_SIZE + i * cls.ENTRY_SIZE) + offsets.append((offset, length)) + return [mv[offset : offset + length] for offset, length in offsets] + + def mset_zero_copy(self, keys: list[str], objs: list[Any]): """Store multiple objects in zero-copy mode using parallel serialization and buffer packing. Args: keys (list[str]): List of string keys under which the objects will be stored. objs (list[Any]): List of Python objects to store (e.g., tensors, strings). """ - assert self._cpu_ds_client is not None, "CPU DS client is not available" items_list = [[memoryview(b) for b in _encoder.encode(obj)] for obj in objs] - packed_sizes = [calc_packed_size(items) for items in items_list] - buffers = self._cpu_ds_client.mcreate(keys, packed_sizes) - tasks = [(target.MutableData(), item) for target, item in zip(buffers, items_list, strict=False)] - with ThreadPoolExecutor(max_workers=DS_MAX_WORKERS) as executor: - list(executor.map(lambda p: pack_into(*p), tasks)) - self._cpu_ds_client.mset_buffer(buffers) - - def mget_zcopy(self, keys: list[str]) -> list[Any]: + packed_sizes = [self.calc_packed_size(items) for items in items_list] + buffers = self._ds_client.mcreate(keys, packed_sizes) + tasks = [(target.MutableData(), item) for target, item in zip(buffers, items_list, strict=True)] + with ThreadPoolExecutor(max_workers=self.DS_MAX_WORKERS) as executor: + list(executor.map(lambda p: self.pack_into(*p), tasks)) + self._ds_client.mset_buffer(buffers) + + def mget_zero_copy(self, keys: list[str]) -> list[Any]: """Retrieve multiple objects in zero-copy mode by directly deserializing from shared memory buffers. Args: @@ -209,67 +338,40 @@ def mget_zcopy(self, keys: list[str]) -> list[Any]: Returns: list[Any]: List of deserialized objects corresponding to the input keys. """ - assert self._cpu_ds_client is not None, "CPU DS client is not available" - buffers = self._cpu_ds_client.get_buffers(keys) - return [_decoder.decode(unpack_from(buffer)) if buffer is not None else None for buffer in buffers] + buffers = self._ds_client.get_buffers(keys) + return [_decoder.decode(self.unpack_from(buffer)) if buffer is not None else None for buffer in buffers] - def _batch_put(self, keys: list[str], values: list[Any]): - """Stores a batch of key-value pairs to remote storage, splitting by device type. - NPU tensors are sent via DsTensorClient (with higher batch limit), - while all other objects are pickled and sent via KVClient. 
+@StorageClientFactory.register("YuanrongStorageClient") +class YuanrongStorageClient(TransferQueueStorageKVClient): + """ + Storage client for YuanRong DataSystem. - Args: - keys (List[str]): List of string keys. - values (List[Any]): Corresponding values (tensors or general objects). - """ - if self.npu_ds_client_is_available(): - # Classify NPU and CPU data - npu_keys = [] - npu_values = [] - - cpu_keys = [] - cpu_values = [] - - for key, value in zip(keys, values, strict=True): - if isinstance(value, torch.Tensor) and value.device.type == "npu": - if not value.is_contiguous(): - raise ValueError(f"NPU Tensor is not contiguous: {value}") - npu_keys.append(key) - npu_values.append(value) - - else: - cpu_keys.append(key) - cpu_values.append(pickle.dumps(value)) - - # put NPU data - assert self._npu_ds_client is not None, "NPU DS client is not available" - for i in range(0, len(npu_keys), NPU_DS_CLIENT_KEYS_LIMIT): - batch_keys = npu_keys[i : i + NPU_DS_CLIENT_KEYS_LIMIT] - batch_values = npu_values[i : i + NPU_DS_CLIENT_KEYS_LIMIT] - - # _npu_ds_client.dev_mset doesn't support to overwrite - try: - self._npu_ds_client.dev_delete(batch_keys) - except Exception as e: - logger.warning(f"dev_delete error({e}) before dev_mset") - self._npu_ds_client.dev_mset(batch_keys, batch_values) - - # put CPU data - assert self._cpu_ds_client is not None, "CPU DS client is not available" - for i in range(0, len(cpu_keys), CPU_DS_CLIENT_KEYS_LIMIT): - batch_keys = cpu_keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT] - batch_values = cpu_values[i : i + CPU_DS_CLIENT_KEYS_LIMIT] - self.mset_zcopy(batch_keys, batch_values) - - else: - # All data goes through CPU path - for i in range(0, len(keys), CPU_DS_CLIENT_KEYS_LIMIT): - batch_keys = keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT] - batch_vals = values[i : i + CPU_DS_CLIENT_KEYS_LIMIT] - self.mset_zcopy(batch_keys, batch_vals) - - def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: + Use different storage strategies depending on the data type. + Supports storing and fetching both: + - NPU tensors via NPUTensorKVClientAdapter (for high performance). + - General objects (CPU tensors, str, bool, list, etc.) via GeneralKVClientAdapter with serialization. + """ + + def __init__(self, config: dict[str, Any]): + if not YUANRONG_DATASYSTEM_IMPORTED: + raise ImportError("YuanRong DataSystem not installed.") + + super().__init__(config) + + # Storage strategies are prioritized in ascending order of list element index. + # In other words, the later in the order, the lower the priority. + storage_strategies_priority = [NPUTensorKVClientAdapter, GeneralKVClientAdapter] + self._strategies: list[StorageStrategy] = [] + for strategy_cls in storage_strategies_priority: + strategy = strategy_cls.init(config) + if strategy is not None: + self._strategies.append(strategy) + + if not self._strategies: + raise RuntimeError("No storage strategy available for YuanrongStorageClient") + + def put(self, keys: list[str], values: list[Any]) -> list[str]: """Stores multiple key-value pairs to remote storage. Automatically routes NPU tensors to high-performance tensor storage, @@ -278,101 +380,29 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: Args: keys (List[str]): List of unique string identifiers. values (List[Any]): List of values to store (tensors, scalars, dicts, etc.). + + Returns: + List[str]: storage strategy tag of YuanrongStorageClient in the same order as input keys. 
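+
+        Example (illustrative sketch; assumes 'client' is a constructed YuanrongStorageClient
+        and a yuanrong-datasystem backend is reachable):
+            tags = client.put(["k0", "k1"], [torch.randn(2), "hello"])
+            # tags == ["2", "2"] when only the general KV strategy is active (no NPU available)
+            values = client.get(["k0", "k1"], shapes=[[2], []], dtypes=[torch.float32, None],
+                                custom_backend_meta=tags)
+            client.clear(["k0", "k1"], custom_backend_meta=tags)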
""" if not isinstance(keys, list) or not isinstance(values, list): raise ValueError("keys and values must be lists") if len(keys) != len(values): raise ValueError("Number of keys must match number of values") - self._batch_put(keys, values) - return None - def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]: - """Retrieves a batch of values from remote storage using expected metadata. + routed_indexes = self._route_to_strategies(values, lambda strategy_, item_: strategy_.supports_put(item_)) - NPU tensors are fetched via DsTensorClient using pre-allocated buffers. - Other objects are fetched via KVClient and unpickled. + # Define the 'put_task': Slicing the input list and calling the backend strategy. + # The closure captures local 'keys' and 'values' for zero-overhead parameter passing. + def put_task(strategy, indexes): + strategy.put([keys[i] for i in indexes], [values[i] for i in indexes]) + return strategy.strategy_tag(), indexes - Args: - keys (List[str]): Keys to fetch. - shapes (List[List[int]]): Expected shapes for each key (empty list for scalars). - dtypes (List[Optional[torch.dtype]]): Expected dtypes; None indicates non-tensor data. - - Returns: - List[Any]: Retrieved values in the same order as input keys. - """ - - if self.npu_ds_client_is_available(): - # classify npu and cpu queries - npu_indices = [] - npu_keys = [] - npu_shapes = [] - npu_dtypes = [] - - cpu_indices = [] - cpu_keys = [] - - for idx, (key, shape, dtype) in enumerate(zip(keys, shapes, dtypes, strict=False)): - if dtype is not None: - npu_indices.append(idx) - npu_keys.append(key) - npu_shapes.append(shape) - npu_dtypes.append(dtype) - else: - cpu_indices.append(idx) - cpu_keys.append(key) - - results = [None] * len(keys) - - # Fetch NPU tensors - assert self._npu_ds_client is not None, "NPU DS client is not available" - for i in range(0, len(npu_keys), NPU_DS_CLIENT_KEYS_LIMIT): - batch_keys = npu_keys[i : i + NPU_DS_CLIENT_KEYS_LIMIT] - batch_shapes = npu_shapes[i : i + NPU_DS_CLIENT_KEYS_LIMIT] - batch_dtypes = npu_dtypes[i : i + NPU_DS_CLIENT_KEYS_LIMIT] - batch_indices = npu_indices[i : i + NPU_DS_CLIENT_KEYS_LIMIT] - - batch_values = self._create_empty_npu_tensorlist(batch_shapes, batch_dtypes) - failed_subkeys = [] - try: - failed_subkeys = self._npu_ds_client.dev_mget(batch_keys, batch_values) - # failed_keys = f'{key},{npu_device_id}' - failed_subkeys = [f_key.rsplit(",", 1)[0] for f_key in failed_subkeys] - except Exception: - failed_subkeys = batch_keys - - # Fill successfully retrieved tensors - failed_set = set(failed_subkeys) - for idx, key, value in zip(batch_indices, batch_keys, batch_values, strict=False): - if key not in failed_set: - results[idx] = value - - # Add failed keys to CPU fallback queue - if failed_subkeys: - cpu_keys.extend(failed_subkeys) - cpu_indices.extend([batch_indices[j] for j, k in enumerate(batch_keys) if k in failed_set]) - - # Fetch CPU/general objects (including NPU fallbacks) - assert self._cpu_ds_client is not None, "CPU DS client is not available" - for i in range(0, len(cpu_keys), CPU_DS_CLIENT_KEYS_LIMIT): - batch_keys = cpu_keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT] - batch_indices = cpu_indices[i : i + CPU_DS_CLIENT_KEYS_LIMIT] - objects = self.mget_zcopy(batch_keys) - for idx, obj in zip(batch_indices, objects, strict=False): - results[idx] = obj - - return results - - else: - results = [None] * len(keys) - cpu_indices = list(range(len(keys))) - - for i in range(0, len(keys), CPU_DS_CLIENT_KEYS_LIMIT): - batch_keys = keys[i : i + 
CPU_DS_CLIENT_KEYS_LIMIT] - batch_indices = cpu_indices[i : i + CPU_DS_CLIENT_KEYS_LIMIT] - objects = self.mget_zcopy(batch_keys) - for idx, obj in zip(batch_indices, objects, strict=False): - results[idx] = obj - return results + # Dispatch tasks and map strategy_tag back to original positions + strategy_tags: list[str] = [""] * len(keys) + for tag, indexes in self._dispatch_tasks(routed_indexes, put_task): + for original_index in indexes: + strategy_tags[original_index] = tag + return strategy_tags def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=None) -> list[Any]: """Retrieves multiple values from remote storage with expected metadata. @@ -383,48 +413,123 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non keys (List[str]): Keys to fetch. shapes (List[List[int]]): Expected tensor shapes (use [] for scalars). dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data. - custom_backend_meta (List[str], optional): Device type (npu/cpu) for each key + custom_backend_meta (List[str]): StorageStrategy tag for each key Returns: List[Any]: Retrieved values in the same order as input keys. """ - if shapes is None or dtypes is None: - raise ValueError("YuanrongStorageClient needs Expected shapes and dtypes") - if not (len(keys) == len(shapes) == len(dtypes)): - raise ValueError("Lengths of keys, shapes, dtypes must match") - return self._batch_get(keys, shapes, dtypes) + if shapes is None or dtypes is None or custom_backend_meta is None: + raise ValueError("YuanrongStorageClient.get() needs Expected shapes, dtypes and custom_backend_meta") - def _batch_clear(self, keys: list[str]): - """Deletes a batch of keys from remote storage. + if not (len(keys) == len(shapes) == len(dtypes) == len(custom_backend_meta)): + raise ValueError("Lengths of keys, shapes, dtypes, custom_backend_meta must match") - Attempts deletion via NPU client first (if available), then falls back to CPU client - for any keys not handled by NPU. + strategy_tags = custom_backend_meta + routed_indexes = self._route_to_strategies( + strategy_tags, lambda strategy_, item_: strategy_.supports_get(item_) + ) - Args: - keys (List[str]): Keys to delete. - """ - if self.npu_ds_client_is_available(): - assert self._npu_ds_client is not None, "NPU DS client is not available" - assert self._cpu_ds_client is not None, "CPU DS client is not available" - # Try to delete all keys via npu client - for i in range(0, len(keys), NPU_DS_CLIENT_KEYS_LIMIT): - batch = keys[i : i + NPU_DS_CLIENT_KEYS_LIMIT] - # Return the keys that failed to delete - self._npu_ds_client.dev_delete(batch) - # Delete failed keys via CPU client - for j in range(0, len(keys), CPU_DS_CLIENT_KEYS_LIMIT): - sub_batch = keys[j : j + CPU_DS_CLIENT_KEYS_LIMIT] - self._cpu_ds_client.delete(sub_batch) - else: - assert self._cpu_ds_client is not None, "CPU DS client is not available" - for i in range(0, len(keys), CPU_DS_CLIENT_KEYS_LIMIT): - batch = keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT] - self._cpu_ds_client.delete(batch) + # Define the 'get_task': handles slicing of keys, shapes, and dtypes simultaneously. 
+ def get_task(strategy, indexes): + res = strategy.get( + [keys[i] for i in indexes], shapes=[shapes[i] for i in indexes], dtypes=[dtypes[i] for i in indexes] + ) + return res, indexes - def clear(self, keys: list[str]): + # Gather results and restore original order + results = [None] * len(keys) + for strategy_res, indexes in self._dispatch_tasks(routed_indexes, get_task): + for value, original_index in zip(strategy_res, indexes, strict=True): + results[original_index] = value + return results + + def clear(self, keys: list[str], custom_backend_meta=None): """Deletes multiple keys from remote storage. Args: keys (List[str]): List of keys to remove. + custom_backend_meta (List[str]): StorageStrategy tag for each key + """ + if not isinstance(keys, list) or not isinstance(custom_backend_meta, list): + raise ValueError("keys and custom_backend_meta must be a list") + + if len(custom_backend_meta) != len(keys): + raise ValueError("custom_backend_meta length must match keys") + + strategy_tags = custom_backend_meta + routed_indexes = self._route_to_strategies( + strategy_tags, lambda strategy_, item_: strategy_.supports_clear(item_) + ) + + def clear_task(strategy, indexes): + strategy.clear([keys[i] for i in indexes]) + + # Execute deletions (no return values needed) + self._dispatch_tasks(routed_indexes, clear_task) + + def _route_to_strategies( + self, + items: list[Any], + selector: Callable[[StorageStrategy, Any], bool], + ) -> dict[StorageStrategy, list[int]]: + """Groups item indices by the first strategy that supports them. + + Used to route data to storage strategies by grouped indexes. + + Args: + items: A list used to distinguish which storage strategy the data is routed to. + e.g., route for put based on types of values, + or route for get/clear based on strategy_tags. + The order must correspond to the original keys. + selector: A function that determines whether a strategy supports an item. + Signature: `(strategy: StorageStrategy, item: Any) -> bool`. + + Returns: + A dictionary mapping each active strategy to a list of indexes in `items` + that it should handle. Every index appears exactly once. + """ + routed_indexes: dict[StorageStrategy, list[int]] = {s: [] for s in self._strategies} + for i, item in enumerate(items): + for strategy in self._strategies: + if selector(strategy, item): + routed_indexes[strategy].append(i) + break + else: + raise ValueError( + f"No strategy supports item of type {type(item).__name__}: {item}. " + f"Available strategies: {[type(s).__name__ for s in self._strategies]}" + ) + return routed_indexes + + @staticmethod + def _dispatch_tasks(routed_tasks: dict[StorageStrategy, list[int]], task_function: Callable) -> list[Any]: + """Executes tasks across one or more storage strategies, optionally in parallel. + + Optimizes for common case: if only one strategy is involved, runs synchronously + to avoid thread overhead. Otherwise, uses thread pool for concurrency. + + Args: + routed_tasks: Mapping from strategy to list of indexes it should process. + task_function: Callable accepting `(strategy, list_of_indexes)` and returning any result. + + Returns: + List of results from `task_function`, one per active strategy, in arbitrary order. + Each result typically includes data and the corresponding indices for reassembly. 
""" - self._batch_clear(keys) + active_tasks = [(strategy, indexes) for strategy, indexes in routed_tasks.items() if indexes] + + if not active_tasks: + return [] + + # Fast path: single strategy → avoid threading + if len(active_tasks) == 1: + return [task_function(*active_tasks[0])] + + # Parallel path: overlap NPU and CPU operations + # Cap the number of worker threads to avoid resource exhaustion if many + # strategies are added in the future. + max_workers = min(len(active_tasks), 4) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # futures' results are from task_function + futures = [executor.submit(task_function, strategy, indexes) for strategy, indexes in active_tasks] + return [f.result() for f in futures] diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 2c76e91..c07ec1f 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -555,6 +555,8 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: keys = self._generate_keys(data.keys(), metadata.global_indexes) values = self._generate_values(data) loop = asyncio.get_event_loop() + + # put to storage backends custom_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) per_field_dtypes: dict[int, dict[str, Any]] = {} @@ -632,4 +634,5 @@ async def clear_data(self, metadata: BatchMeta) -> None: logger.warning("Attempted to clear data, but metadata contains no fields.") return keys = self._generate_keys(metadata.field_names, metadata.global_indexes) - self.storage_client.clear(keys=keys) + _, _, custom_meta = self._get_shape_type_custom_backend_meta_list(metadata) + self.storage_client.clear(keys=keys, custom_backend_meta=custom_meta)