Merged

26 commits
3338e37
Renamed 'test_yuanrong_storage_manager.py' to 'test_yuanrong_storage_…
dpj135 Jan 28, 2026
146cbd2
Added abstract interface 'StorageStrategy'
dpj135 Jan 28, 2026
4d184b8
Added DsTensorClient
dpj135 Jan 28, 2026
d5ec724
Added 'KVClientAdapter'
dpj135 Jan 28, 2026
28320da
Refactored 'YuanrongStorageClient.put&get'
dpj135 Jan 28, 2026
8c8e9a2
Added 'route_to_strategy' to class 'YuanrongStorageClient' & Adjust t…
dpj135 Jan 29, 2026
4d8ae64
Fixed the order about '@staticmethod' and 'abstractmethod' (Now yuanr…
dpj135 Jan 29, 2026
a9235e6
Added custom_meta to clear for all TransferQueueKVStorageClient
dpj135 Jan 29, 2026
ffa5970
Added multi-threads optimization to 'put/get/clear' of 'YuanrongStora…
dpj135 Jan 30, 2026
39c72bb
Added more annotation for methods
dpj135 Jan 30, 2026
ed07b61
Added end-to-end test(generated by AI) for 'YuanrongStorageClient'
dpj135 Jan 30, 2026
91f2532
Fixed tests about yuanrong_clint
dpj135 Jan 30, 2026
2faa9eb
Added license to test_yuanrong_client
dpj135 Jan 30, 2026
7531714
Added method 'test_mock_can_work' to test_yuanrong_client
dpj135 Jan 30, 2026
a8293ef
Added an annotation to class 'StorageStrategy'
dpj135 Jan 30, 2026
e120eb2
Fixed test_yuanrong_client_zero_copy
dpj135 Jan 30, 2026
c0fc536
Apply suggestions from code review
dpj135 Feb 2, 2026
c196940
Renamed adapter classes & rename 'custom_meta()' to 'strategy_tag()' …
dpj135 Feb 2, 2026
50c5284
Fixed 'KVClientAdapter' imported error
dpj135 Feb 2, 2026
a7303cf
Modified docstrings
dpj135 Feb 2, 2026
1fd53f4
Fixed 'test_yuanrong_storage_client_e2e.py' about strategy_tag
dpj135 Feb 2, 2026
7bb83a8
Adjusted annotations of test_yuanrong_storage_client_e2e.py
dpj135 Feb 3, 2026
9fac913
Fixed reviews about yuanrong_client(modified strategy_tag, rename cus…
dpj135 Feb 3, 2026
20ae39b
Rename custom_meta to custom_backend_meta
dpj135 Feb 3, 2026
8f3417c
Modified annotations about clients
dpj135 Feb 3, 2026
4a3be0f
Adjusted expression of annotations and renamed one variable
dpj135 Feb 3, 2026
@@ -25,7 +25,7 @@
sys.path.append(str(parent_dir))

from transfer_queue.storage.clients.yuanrong_client import ( # noqa: E402
YuanrongStorageClient,
GeneralKVClientAdapter,
)


@@ -37,21 +37,20 @@ def MutableData(self):
return self.data


class TestYuanrongStorageZCopy:
class TestYuanrongKVClientZCopy:
@pytest.fixture
def mock_kv_client(self, mocker):
mock_client = MagicMock()
mock_client.init.return_value = None

mocker.patch("yr.datasystem.KVClient", return_value=mock_client)
mocker.patch("yr.datasystem.DsTensorClient")
mocker.patch("transfer_queue.storage.clients.yuanrong_client.TORCH_NPU_IMPORTED", False)

return mock_client

@pytest.fixture
def storage_client(self, mock_kv_client):
return YuanrongStorageClient({"host": "127.0.0.1", "port": 31501})
return GeneralKVClientAdapter({"host": "127.0.0.1", "port": 31501})

def test_mset_mget_p2p(self, storage_client, mocker):
# Mock serialization/deserialization
@@ -80,13 +79,13 @@ def side_effect_mcreate(keys, sizes):
stored_raw_buffers.append(b.MutableData())
return buffers

storage_client._cpu_ds_client.mcreate.side_effect = side_effect_mcreate
storage_client._cpu_ds_client.get_buffers.return_value = stored_raw_buffers
storage_client._ds_client.mcreate.side_effect = side_effect_mcreate
storage_client._ds_client.get_buffers.return_value = stored_raw_buffers

storage_client.mset_zcopy(
storage_client.mset_zero_copy(
["tensor_key", "string_key"], [torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32), "hello yuanrong"]
)
results = storage_client.mget_zcopy(["tensor_key", "string_key"])
results = storage_client.mget_zero_copy(["tensor_key", "string_key"])

assert torch.allclose(results[0], torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32))
assert results[1] == "hello yuanrong"
214 changes: 214 additions & 0 deletions tests/test_yuanrong_storage_client_e2e.py
@@ -0,0 +1,214 @@
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from unittest import mock

import pytest
import torch

try:
import torch_npu # noqa: F401
except ImportError:
pass


# --- Mock Backend Implementation ---
# In real scenarios, multiple DsTensorClients or KVClients share the same storage backend.
# Here, each mock client keeps its own storage in a plain dictionary,
# so these mocks are suitable only for unit testing.


class MockDsTensorClient:
def __init__(self, host, port, device_id):
self.storage = {}

def init(self):
pass

def dev_mset(self, keys, values):
for k, v in zip(keys, values, strict=True):
assert v.device.type == "npu"
self.storage[k] = v

def dev_mget(self, keys, out_tensors):
for i, k in enumerate(keys):
# Note: If key is missing, tensor remains unchanged (mock limitation)
if k in self.storage:
out_tensors[i].copy_(self.storage[k])

def dev_delete(self, keys):
for k in keys:
self.storage.pop(k, None)


class MockKVClient:
def __init__(self, host, port):
self.storage = {}

def init(self):
pass

def mcreate(self, keys, sizes):
class MockBuffer:
def __init__(self, size):
self._data = bytearray(size)

def MutableData(self):
return memoryview(self._data)

self._current_keys = keys
return [MockBuffer(s) for s in sizes]

def mset_buffer(self, buffers):
for key, buf in zip(self._current_keys, buffers, strict=True):
self.storage[key] = bytes(buf.MutableData())

def get_buffers(self, keys):
return [memoryview(self.storage[k]) if k in self.storage else None for k in keys]

def delete(self, keys):
for k in keys:
self.storage.pop(k, None)


# --- Fixtures ---


@pytest.fixture
def mock_yr_datasystem():
"""Wipe real 'yr' modules and inject mocks."""

# 1. Clean up sys.modules to force a fresh import under mock conditions
# This ensures top-level code in yuanrong_client.py is re-executed
to_delete = [k for k in sys.modules if k.startswith("yr")]
for mod in to_delete:
del sys.modules[mod]

# 2. Setup Mock Objects
ds_mock = mock.MagicMock()
ds_mock.DsTensorClient = MockDsTensorClient
ds_mock.KVClient = MockKVClient

yr_mock = mock.MagicMock(datasystem=ds_mock)

# 3. Apply patches
# - sys.modules: Redirects 'import yr' to our mocks
# - YUANRONG_DATASYSTEM_IMPORTED: Forces the existence check to True so the client can be initialized successfully
# - datasystem: Direct attribute patch for the module
with (
mock.patch.dict("sys.modules", {"yr": yr_mock, "yr.datasystem": ds_mock}),
mock.patch("transfer_queue.storage.clients.yuanrong_client.YUANRONG_DATASYSTEM_IMPORTED", True, create=True),
mock.patch("transfer_queue.storage.clients.yuanrong_client.datasystem", ds_mock),
):
yield


@pytest.fixture
def config():
return {"host": "127.0.0.1", "port": 12345, "enable_yr_npu_optimization": True}


def assert_tensors_equal(a: torch.Tensor, b: torch.Tensor):
assert a.shape == b.shape and a.dtype == b.dtype
# Move to CPU for cross-device comparison
assert torch.equal(a.cpu(), b.cpu())


# --- Test Suite ---


class TestYuanrongStorageE2E:
@pytest.fixture(autouse=True)
def setup_client(self, mock_yr_datasystem, config):
# Lazy import to ensure mocks are active
from transfer_queue.storage.clients.yuanrong_client import YuanrongStorageClient

self.client_cls = YuanrongStorageClient
self.config = config

def _create_data(self, mode="cpu"):
if mode == "cpu":
keys = ["t", "s", "i"]
vals = [torch.randn(2), "hi", 1]
elif mode == "npu":
if not (hasattr(torch, "npu") and torch.npu.is_available()):
pytest.skip("NPU required")
keys = ["n1", "n2"]
vals = [torch.randn(2).npu(), torch.tensor([1]).npu()]
else: # mixed
if not (hasattr(torch, "npu") and torch.npu.is_available()):
pytest.skip("NPU required")
keys = ["n1", "c1"]
vals = [torch.randn(2).npu(), "cpu"]

shapes = [list(v.shape) if isinstance(v, torch.Tensor) else [] for v in vals]
dtypes = [v.dtype if isinstance(v, torch.Tensor) else None for v in vals]
return keys, vals, shapes, dtypes

def test_mock_can_work(self, config):
mock_class = (MockDsTensorClient, MockKVClient)
client = self.client_cls(config)
for strategy in client._strategies:
assert isinstance(strategy._ds_client, mock_class)

def test_cpu_only_flow(self, config):
client = self.client_cls(config)
keys, vals, shp, dt = self._create_data("cpu")

# Put & Verify Meta
meta = client.put(keys, vals)
# "2" is a tag added by YuanrongStorageClient, indicating that it is processed via General KV path.
assert all(m == "2" for m in meta)

# Get & Verify Values
ret = client.get(keys, shp, dt, meta)
for o, r in zip(vals, ret, strict=True):
if isinstance(o, torch.Tensor):
assert_tensors_equal(o, r)
else:
assert o == r

# Clear & Verify
client.clear(keys, meta)
assert all(v is None for v in client.get(keys, shp, dt, meta))

def test_npu_only_flow(self, config):
keys, vals, shp, dt = self._create_data("npu")
client = self.client_cls(config)

meta = client.put(keys, vals)
# "1" is a tag added by YuanrongStorageClient, indicating that it is processed via NPU path.
assert all(m == "1" for m in meta)

ret = client.get(keys, shp, dt, meta)
for o, r in zip(vals, ret, strict=True):
assert_tensors_equal(o, r)

client.clear(keys, meta)

def test_mixed_flow(self, config):
keys, vals, shp, dt = self._create_data("mixed")
client = self.client_cls(config)

meta = client.put(keys, vals)
assert set(meta) == {"1", "2"}

ret = client.get(keys, shp, dt, meta)
for o, r in zip(vals, ret, strict=True):
if isinstance(o, torch.Tensor):
assert_tensors_equal(o, r)
else:
assert o == r
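The tags asserted above ("1" for the NPU path, "2" for the general KV path) come from the strategy routing this PR introduces. As a rough illustration only — not the actual yuanrong_client.py implementation — routing by strategy tag could look like the sketch below. Only StorageStrategy, strategy_tag, and the tag values are taken from the PR; the RoutingClient class and its internals are invented for illustration.

```python
# Illustrative sketch of tag-based strategy routing (hypothetical names, not this PR's code).
from abc import ABC, abstractmethod
from typing import Any


class StorageStrategy(ABC):
    """One backend path (e.g. NPU device memory or a general KV store)."""

    @staticmethod
    @abstractmethod
    def strategy_tag() -> str:
        """Tag recorded in custom_backend_meta for keys handled by this strategy."""

    @abstractmethod
    def get(self, keys: list[str], shapes, dtypes) -> list[Any]: ...


class RoutingClient:
    def __init__(self, strategies: list[StorageStrategy]):
        # Map each tag (e.g. "1", "2") to the strategy that produced it.
        self._by_tag = {s.strategy_tag(): s for s in strategies}

    def get(self, keys, shapes, dtypes, custom_backend_meta):
        # Group keys by their recorded tag, fetch per strategy, then restore input order.
        results: list[Any] = [None] * len(keys)
        groups: dict[str, list[int]] = {}
        for i, tag in enumerate(custom_backend_meta):
            groups.setdefault(tag, []).append(i)
        for tag, indices in groups.items():
            strategy = self._by_tag[tag]
            values = strategy.get(
                [keys[i] for i in indices],
                [shapes[i] for i in indices],
                [dtypes[i] for i in indices],
            )
            for i, v in zip(indices, values, strict=True):
                results[i] = v
        return results
```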
2 changes: 1 addition & 1 deletion transfer_queue/storage/clients/base.py
@@ -65,6 +65,6 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non
raise NotImplementedError("Subclasses must implement get")

@abstractmethod
def clear(self, keys: list[str]) -> None:
def clear(self, keys: list[str], custom_backend_meta=None) -> None:
"""Clear key-value pairs in the storage backend."""
raise NotImplementedError("Subclasses must implement clear")
5 changes: 3 additions & 2 deletions transfer_queue/storage/clients/mooncake_client.py
@@ -139,7 +139,7 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non
keys (List[str]): Keys to fetch.
shapes (List[List[int]]): Expected tensor shapes (use [] for scalars).
dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data.
custom_backend_meta (List[str], optional): Device type (npu/cpu) for each key
custom_backend_meta (List[str], optional): ...
Returns:
List[Any]: Retrieved values in the same order as input keys.
@@ -216,11 +216,12 @@ def _batch_get_bytes(self, keys: list[str]) -> list[bytes]:
results.extend(batch_results)
return results

def clear(self, keys: list[str]):
def clear(self, keys: list[str], custom_backend_meta=None):
"""Deletes multiple keys from MooncakeStore.
Args:
keys (List[str]): List of keys to remove.
custom_backend_meta (List[Any], optional): ...
"""
for key in keys:
ret = self._store.remove(key)
3 changes: 2 additions & 1 deletion transfer_queue/storage/clients/ray_storage_client.py
@@ -106,10 +106,11 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non
raise RuntimeError(f"Failed to retrieve value for key '{keys}': {e}") from e
return values

def clear(self, keys: list[str]):
def clear(self, keys: list[str], custom_backend_meta=None):
"""
Delete entries from storage by keys.
Args:
keys (list): List of keys to delete
custom_backend_meta (List[Any], optional): ...
"""
ray.get(self.storage_actor.clear_obj_ref.remote(keys))