Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .github/workflows/tutorial-check.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# This workflow installs the package together with its Python dependencies and
# runs every script under tutorial/ as a smoke test on a single Python version.
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Tutorial check

on:
  push:
    branches:
      - main
      - v0.*
  pull_request:
    branches:
      - main
      - v0.*

jobs:
  build:
    runs-on: ubuntu-latest
    timeout-minutes: 10
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.11"]

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        # v5 runs on a supported Node runtime; v3 relied on the deprecated Node 16.
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          # CPU-only wheels: this job does not need (and the runner does not have) CUDA.
          pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
          pip install -e ".[yuanrong]"
      - name: Run tutorials
        # NOTE: the default `run` shell on Linux runners is `bash -e`, so the loop
        # stops and the step fails at the first tutorial that exits non-zero.
        run: |
          export TQ_NUM_THREADS=2
          for file in tutorial/*.py; do python3 "$file"; done
51 changes: 51 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def _handle_requests(self):
"partition_ids": ["partition_0", "partition_1", "test_partition"],
}
response_type = ZMQRequestType.LIST_PARTITIONS_RESPONSE
elif request_msg.request_type == ZMQRequestType.SET_CUSTOM_META:
response_body = {"message": "success"}
response_type = ZMQRequestType.SET_CUSTOM_META_RESPONSE
else:
response_body = {"error": f"Unknown request type: {request_msg.request_type}"}
response_type = ZMQRequestType.CLEAR_META_RESPONSE
Expand Down Expand Up @@ -774,3 +777,51 @@ async def test_sync_and_async_methods_mixed_usage(client_setup):
assert async_data is not None

print("✓ Mixed async and sync method calls work correctly")


# =====================================================
# Custom Meta Interface Tests
# =====================================================


class TestClientCustomMetaInterface:
    """Smoke tests for the client's custom_meta setters (sync and async variants)."""

    @staticmethod
    def _token_counts():
        """Return a fresh custom_meta payload: ``input_ids`` token counts keyed 0 and 1."""
        return {
            0: {"input_ids": {"token_count": 100}},
            1: {"input_ids": {"token_count": 120}},
        }

    def test_set_custom_meta_sync(self, client_setup):
        """set_custom_meta accepts metadata previously fetched through get_meta."""
        client, _, _ = client_setup

        # Fetch metadata, then attach the custom_meta payload to it.
        batch_meta = client.get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0")
        batch_meta.update_custom_meta(self._token_counts())

        # The test passes as long as the call returns without raising.
        client.set_custom_meta(batch_meta)
        print("✓ set_custom_meta sync method works")

    @pytest.mark.asyncio
    async def test_set_custom_meta_async(self, client_setup):
        """async_set_custom_meta accepts metadata previously fetched through async_get_meta."""
        client, _, _ = client_setup

        # Fetch metadata, then attach the custom_meta payload to it.
        batch_meta = await client.async_get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0")
        batch_meta.update_custom_meta(self._token_counts())

        # The test passes as long as the awaited call returns without raising.
        await client.async_set_custom_meta(batch_meta)
        print("✓ async_set_custom_meta async method works")
155 changes: 154 additions & 1 deletion tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_controller_with_single_partition(self, ray_setup):
field_names=metadata.field_names,
dtypes=dtypes,
shapes=shapes,
custom_meta=None,
custom_backend_meta=None,
)
)
assert success
Expand Down Expand Up @@ -450,3 +450,156 @@ def test_controller_clear_meta(self, ray_setup):
assert set(partition_after.global_indexes) == set([4, 5, 7])

print("✓ Clear meta correct")


class TestTransferQueueControllerCustomMeta:
    """Integration tests for TransferQueueController custom_meta and custom_backend_meta methods.

    Note: In this codebase:
    - custom_meta: per-sample metadata (simple key-value pairs per sample)
    - custom_backend_meta: per-sample per-field metadata (stored via update_production_status)
    """

    @staticmethod
    def _insert_and_produce(tq_controller, partition_id, batch_size, custom_backend_meta=None):
        """Allocate ``batch_size`` samples in ``partition_id`` and report both fields as produced.

        Args:
            tq_controller: Remote TransferQueueController actor handle.
            partition_id: Partition to create the samples in.
            batch_size: Number of samples requested from ``get_metadata``.
            custom_backend_meta: Optional ``{global_index: {field: {...}}}`` mapping that is
                forwarded unchanged to ``update_production_status``.

        Returns:
            The metadata returned by ``get_metadata`` in insert mode.
        """
        metadata = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=["prompt_ids", "attention_mask"],
                batch_size=batch_size,
                partition_id=partition_id,
                mode="insert",
            )
        )

        dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in metadata.global_indexes}
        shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,)} for k in metadata.global_indexes}
        success = ray.get(
            tq_controller.update_production_status.remote(
                partition_id=partition_id,
                global_indexes=metadata.global_indexes,
                field_names=metadata.field_names,
                dtypes=dtypes,
                shapes=shapes,
                custom_backend_meta=custom_backend_meta,
            )
        )
        assert success
        return metadata

    def test_controller_with_custom_meta(self, ray_setup):
        """Test TransferQueueController with custom_backend_meta and custom_meta functionality"""

        batch_size = 3
        partition_id = "custom_meta_test"
        new_partition_id = "custom_meta_test2"

        tq_controller = TransferQueueController.remote()

        # Build custom_backend_meta (per-sample per-field metadata)
        custom_backend_meta = {
            0: {"prompt_ids": {"token_count": 100}, "attention_mask": {"mask_ratio": 0.1}},
            1: {"prompt_ids": {"token_count": 120}, "attention_mask": {"mask_ratio": 0.15}},
            2: {"prompt_ids": {"token_count": 90}, "attention_mask": {"mask_ratio": 0.12}},
        }

        # Create the first partition and store custom_backend_meta via update_production_status
        metadata = self._insert_and_produce(
            tq_controller, partition_id, batch_size, custom_backend_meta=custom_backend_meta
        )
        assert metadata.global_indexes == list(range(batch_size))

        # Get partition snapshot and verify custom_backend_meta is stored
        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
        assert partition is not None

        # Verify custom_backend_meta via get_field_custom_backend_meta
        result = partition.get_field_custom_backend_meta(list(range(batch_size)), ["prompt_ids", "attention_mask"])
        assert len(result) == batch_size
        assert result[0]["prompt_ids"]["token_count"] == 100
        assert result[2]["attention_mask"]["mask_ratio"] == 0.12

        print("✓ Controller set custom_backend_meta via update_production_status correct")

        # Now set custom_meta (per-sample metadata)
        # Format: {partition_id: {global_index: custom_meta_dict}}
        custom_meta = {
            partition_id: {
                0: {"sample_score": 0.9, "quality": "high"},
                1: {"sample_score": 0.8, "quality": "medium"},
                # Index 2 is deliberately omitted: custom_meta may cover only some samples.
            }
        }

        ray.get(tq_controller.set_custom_meta.remote(partition_custom_meta=custom_meta))

        # Verify via partition snapshot
        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
        result = partition.get_custom_meta([0, 1])
        assert 0 in result
        assert result[0]["sample_score"] == 0.9
        assert result[0]["quality"] == "high"
        assert 1 in result
        assert result[1]["sample_score"] == 0.8
        assert 2 not in result

        # Init another partition, this time without custom_backend_meta
        self._insert_and_produce(tq_controller, new_partition_id, batch_size)

        # Complicated case: one call that targets both partitions, adds a new entry (index 2)
        # and overwrites an existing one (index 0) in the first partition
        new_custom_meta = {
            new_partition_id: {
                3: {"sample_score": 1, "quality": "high"},
                4: {"sample_score": 0, "quality": "low"},
            },
            partition_id: {
                2: {"sample_score": 0.7, "quality": "high"},
                0: {"sample_score": 0.001, "quality": "low"},
            },
        }

        ray.get(tq_controller.set_custom_meta.remote(partition_custom_meta=new_custom_meta))

        # Verify via partition snapshot
        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
        result = partition.get_custom_meta([0, 1, 2])
        assert 0 in result
        assert result[0]["sample_score"] == 0.001  # updated!
        assert result[0]["quality"] == "low"  # updated!
        assert 1 in result  # unchanged
        assert result[1]["sample_score"] == 0.8  # unchanged
        assert 2 in result  # new
        assert result[2]["sample_score"] == 0.7  # new

        new_partition = ray.get(tq_controller.get_partition_snapshot.remote(new_partition_id))
        result = new_partition.get_custom_meta([3, 4, 5])
        assert 3 in result
        assert result[3]["sample_score"] == 1
        assert result[3]["quality"] == "high"
        assert 4 in result
        assert result[4]["sample_score"] == 0
        assert 5 not in result  # 5 has no custom_meta, so it is not returned even when requested

        # Clean up both partitions created by this test
        ray.get(tq_controller.clear_partition.remote(partition_id))
        ray.get(tq_controller.clear_partition.remote(new_partition_id))
Loading