Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@ jobs:
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=False)
- name: Test with pytest
run: |
pytest
- name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=True)
run: |
ray stop --force
export TQ_ZERO_COPY_SERIALIZATION=True
pytest
163 changes: 61 additions & 102 deletions scripts/put_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@
import sys
import time
from pathlib import Path

import numpy as np
import ray
import torch
from omegaconf import OmegaConf
from tensordict import TensorDict
from tensordict.utils import LinkedList


parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))

from transfer_queue import (
from transfer_queue import ( # noqa: E402
AsyncTransferQueueClient,
SimpleStorageUnit,
TransferQueueController,
process_zmq_server_info,
)
from transfer_queue.utils.utils import get_placement_group
from transfer_queue.utils.utils import get_placement_group # noqa: E402

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand All @@ -33,48 +33,13 @@
# Configuration Map
# =========================================================
CONFIG_MAP = {
"debug": {
"global_batch_size": 32,
"seq_length": 128,
"field_num": 2,
"desc": "Debug (~32KB)"
},
"tiny": {
"global_batch_size": 64,
"seq_length": 1024,
"field_num": 4,
"desc": "Tiny (~1MB)"
},
"small": {
"global_batch_size": 512,
"seq_length": 12800,
"field_num": 4,
"desc": "Small (~100MB)"
},
"medium": {
"global_batch_size": 1024,
"seq_length": 65536,
"field_num": 4,
"desc": "Medium (~1GB)"
},
"large": {
"global_batch_size": 2048,
"seq_length": 128000,
"field_num": 5,
"desc": "Large (~5GB)"
},
"xlarge": {
"global_batch_size": 4096,
"seq_length": 128000,
"field_num": 5,
"desc": "X-Large (~10GB)"
},
"huge": {
"global_batch_size": 4096,
"seq_length": 128000,
"field_num": 10,
"desc": "Huge (~20GB)"
}
"debug": {"global_batch_size": 32, "seq_length": 128, "field_num": 2, "desc": "Debug (~32KB)"},
"tiny": {"global_batch_size": 64, "seq_length": 1024, "field_num": 4, "desc": "Tiny (~1MB)"},
"small": {"global_batch_size": 512, "seq_length": 12800, "field_num": 4, "desc": "Small (~100MB)"},
"medium": {"global_batch_size": 1024, "seq_length": 65536, "field_num": 4, "desc": "Medium (~1GB)"},
"large": {"global_batch_size": 2048, "seq_length": 128000, "field_num": 5, "desc": "Large (~5GB)"},
"xlarge": {"global_batch_size": 4096, "seq_length": 128000, "field_num": 5, "desc": "X-Large (~10GB)"},
"huge": {"global_batch_size": 4096, "seq_length": 128000, "field_num": 10, "desc": "Huge (~20GB)"},
}


Expand All @@ -89,7 +54,7 @@ def calculate_stats(data: list) -> dict:
"mean": float(np.mean(data)),
"max": float(np.max(data)),
"min": float(np.min(data)),
"p99": float(np.percentile(data, 99))
"p99": float(np.percentile(data, 99)),
}


Expand Down Expand Up @@ -119,26 +84,26 @@ def _generate_nested_tensor(batch_size, total_elements, dtype):
# Use Dirichlet distribution to generate random proportions summing to 1
proportions = np.random.dirichlet(np.ones(batch_size))
lengths = (proportions * total_elements).astype(int)

# Fix rounding errors to ensure exact total element count
diff = total_elements - lengths.sum()
if diff != 0:
# Distribute difference to largest elements
indices = np.argsort(lengths)[::-1]
for i in range(abs(diff)):
lengths[indices[i % batch_size]] += 1 if diff > 0 else -1

# Ensure each length is at least 1
lengths = np.maximum(lengths, 1)

# Generate tensors with different lengths
tensors = []
for length in lengths:
if dtype in (torch.int32, torch.int64):
tensors.append(torch.randint(0, 10000, (int(length),), dtype=dtype))
else:
tensors.append(torch.randn(int(length), dtype=dtype))

return torch.nested.nested_tensor(tensors, dtype=dtype)


Expand Down Expand Up @@ -168,7 +133,7 @@ def create_complex_test_case(batch_size, seq_length, field_num):
fields[field_name] = tensor_data
total_size_bytes += total_elements_per_field * bytes_per_elem

total_size_gb = total_size_bytes / (1024 ** 3)
total_size_gb = total_size_bytes / (1024**3)

prompt_batch = TensorDict(
fields,
Expand All @@ -190,18 +155,18 @@ def _compare_nested_tensors(original, retrieved, path):
# Unbind to list for element-wise comparison
orig_tensors = original.unbind()
retr_tensors = retrieved.unbind()

if len(orig_tensors) != len(retr_tensors):
return False, f"[{path}] NestedTensor batch size mismatch: {len(orig_tensors)} vs {len(retr_tensors)}"
for idx, (orig, retr) in enumerate(zip(orig_tensors, retr_tensors)):

for idx, (orig, retr) in enumerate(zip(orig_tensors, retr_tensors, strict=False)):
if orig.shape != retr.shape:
return False, f"[{path}][{idx}] Shape mismatch: {orig.shape} vs {retr.shape}"
if orig.dtype != retr.dtype:
return False, f"[{path}][{idx}] Dtype mismatch: {orig.dtype} vs {retr.dtype}"
if not torch.equal(orig.cpu(), retr.cpu()):
return False, f"[{path}][{idx}] Values mismatch"

return True, "Passed"


Expand All @@ -214,8 +179,8 @@ def check_data_consistency(original, retrieved, path="root"):
retrieved = list(retrieved)

# NestedTensor check (must be before regular Tensor since NestedTensor is also a Tensor)
if original.is_nested if hasattr(original, 'is_nested') else False:
if not (retrieved.is_nested if hasattr(retrieved, 'is_nested') else False):
if original.is_nested if hasattr(original, "is_nested") else False:
if not (retrieved.is_nested if hasattr(retrieved, "is_nested") else False):
return False, f"[{path}] Type mismatch: NestedTensor vs non-NestedTensor"
return _compare_nested_tensors(original, retrieved, path)

Expand Down Expand Up @@ -255,10 +220,11 @@ def check_data_consistency(original, retrieved, path="root"):
# Core Tester Class
# =========================================================


def sync_stage(flag_to_create, flag_to_wait):
"""Profile sync helper function for synchronizing with external profiler process"""
with open(flag_to_create, 'w') as f:
f.write('1')
with open(flag_to_create, "w") as f:
f.write("1")
while not os.path.exists(flag_to_wait):
time.sleep(0.05)
try:
Expand All @@ -283,12 +249,14 @@ def __init__(self, target_ip=None, storage_units=8, enable_profile=False):
def initialize_system(self, config_dict):
"""Initialize TransferQueue system based on current configuration"""
# Basic config conversion
self.tq_config = OmegaConf.create({
"global_batch_size": config_dict["global_batch_size"],
"num_global_batch": 1,
"num_data_storage_units": self.num_storage_units,
"num_data_controllers": 1
})
self.tq_config = OmegaConf.create(
{
"global_batch_size": config_dict["global_batch_size"],
"num_global_batch": 1,
"num_data_storage_units": self.num_storage_units,
"num_data_controllers": 1,
}
)

total_storage_size = self.tq_config.global_batch_size * 2

Expand All @@ -301,9 +269,7 @@ def initialize_system(self, config_dict):
num_cpus=1,
resources={f"node:{self.target_ip}": 0.001},
runtime_env={"env_vars": {"OMP_NUM_THREADS": "2"}},
).remote(
storage_unit_size=math.ceil(total_storage_size / self.num_storage_units)
)
).remote(storage_unit_size=math.ceil(total_storage_size / self.num_storage_units))
else:
# Local Mode: Use placement group
self.storage_placement_group = get_placement_group(self.num_storage_units, num_cpus_per_actor=2)
Expand All @@ -312,9 +278,7 @@ def initialize_system(self, config_dict):
placement_group=self.storage_placement_group,
placement_group_bundle_index=rank,
runtime_env={"env_vars": {"OMP_NUM_THREADS": "2"}},
).remote(
storage_unit_size=math.ceil(total_storage_size / self.num_storage_units)
)
).remote(storage_unit_size=math.ceil(total_storage_size / self.num_storage_units))

# Controller Init
self.data_system_controller = TransferQueueController.remote()
Expand All @@ -331,11 +295,11 @@ def initialize_system(self, config_dict):

# Client Init
self.data_system_client = AsyncTransferQueueClient(
client_id='Trainer',
controller_info=self.data_system_controller_info
client_id="Trainer", controller_info=self.data_system_controller_info
)
self.data_system_client.initialize_storage_manager(
manager_type="AsyncSimpleStorageManager", config=self.tq_config
)
self.data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager",
config=self.tq_config)

return self.data_system_client

Expand All @@ -361,6 +325,7 @@ def cleanup(self):

# 4. Force garbage collection
import gc

gc.collect()

# 5. Wait for Ray scheduler to update state (prevent race condition)
Expand All @@ -370,9 +335,7 @@ def run_benchmark_rounds(self, config_name, config, rounds):
"""Run multiple rounds of PUT/GET bandwidth tests"""
logger.info(f"Generating test data [{config_name}]...")
big_input_ids, total_gb = create_complex_test_case(
batch_size=config["global_batch_size"],
seq_length=config["seq_length"],
field_num=config["field_num"]
batch_size=config["global_batch_size"], seq_length=config["seq_length"], field_num=config["field_num"]
)
logger.info(f"Data Size: {total_gb:.4f} GB")

Expand All @@ -387,27 +350,29 @@ def run_benchmark_rounds(self, config_name, config, rounds):
# PUT operation
start_put = time.time()
if i == 0 and self.enable_profile:
sync_stage('init_ready.flag', 'put_start.flag')
sync_stage("init_ready.flag", "put_start.flag")
asyncio.run(self.data_system_client.async_put(data=big_input_ids, partition_id=partition_key))
put_time = time.time() - start_put
if i == 0 and self.enable_profile:
sync_stage('put_done.flag', 'get_prepare.flag')
sync_stage("put_done.flag", "get_prepare.flag")

put_gbps = (total_gb * 8) / put_time
put_speeds.append(put_gbps)
time.sleep(2)
# Get metadata (required step for TQ flow)
prompt_meta = asyncio.run(self.data_system_client.async_get_meta(
data_fields=list(big_input_ids.keys()),
batch_size=big_input_ids.size(0),
partition_id=partition_key,
task_name='generate_sequences',
))
prompt_meta = asyncio.run(
self.data_system_client.async_get_meta(
data_fields=list(big_input_ids.keys()),
batch_size=big_input_ids.size(0),
partition_id=partition_key,
task_name="generate_sequences",
)
)

# GET operation
start_get = time.time()
if i == 0 and self.enable_profile:
sync_stage('get_ready.flag', 'get_start.flag')
sync_stage("get_ready.flag", "get_start.flag")
retrieved_data = asyncio.run(self.data_system_client.async_get_data(prompt_meta))
get_time = time.time() - start_get

Expand All @@ -422,7 +387,7 @@ def run_benchmark_rounds(self, config_name, config, rounds):
if not is_consistent:
print(f" ❌ FAIL: {msg}")
else:
print(f" ✅ PASS", end="")
print(" ✅ PASS", end="")
asyncio.run(self.data_system_client.async_clear_partition(partition_id=partition_key))
print("\n")

Expand All @@ -434,7 +399,7 @@ def make_result(op, speeds):
"data_volume": f"{total_gb * 1024:.2f} MB" if total_gb * 1024 < 10 else f"{total_gb:.4f} GB",
"operation": op,
"payload_gb": total_gb,
"stats_gbps": calculate_stats(speeds)
"stats_gbps": calculate_stats(speeds),
}

return [make_result("PUT", put_speeds), make_result("GET", get_speeds)]
Expand All @@ -446,8 +411,9 @@ def make_result(op, speeds):
def main():
parser = argparse.ArgumentParser(description="TransferQueue Bandwidth Benchmark")
parser.add_argument("--ip", type=str, default=None, help="Worker node IP, local test if not set")
parser.add_argument("--config", type=str, default=None, choices=list(CONFIG_MAP.keys()),
help="Specific config to run")
parser.add_argument(
"--config", type=str, default=None, choices=list(CONFIG_MAP.keys()), help="Specific config to run"
)
parser.add_argument("--output", type=str, default="tq_benchmark_result.json", help="Output JSON file")
parser.add_argument("--rounds", type=int, default=20, help="Test rounds per config (default: 20)")
parser.add_argument("--shards", type=int, default=8, help="Number of storage units (default: 8)")
Expand All @@ -458,12 +424,9 @@ def main():
# Initialize Ray
current_working_dir = os.getcwd()
if not ray.is_initialized():
ray.init(
address="auto" if args.ip else None,
runtime_env={"working_dir": current_working_dir}
)
ray.init(address="auto" if args.ip else None, runtime_env={"working_dir": current_working_dir})

target_address = args.ip if args.ip else '127.0.0.1'
target_address = args.ip if args.ip else "127.0.0.1"
logger.info(f"Ray initialized. Target: {target_address}")

# Create tester
Expand All @@ -488,16 +451,12 @@ def main():
logger.info(f"💾 Results saved to {args.output}")

except Exception as e:
logger.error(f"❌ Critical error: {e}", exc_info=True)
logger.exception(f"❌ Critical error: {e}")
finally:
if ray.is_initialized():
ray.shutdown()


if __name__ == "__main__":
try:
from transfer_queue.utils import serial_utils
print(f'[Startup Check] TQ_ZERO_COPY_SERIALIZATION = {serial_utils.TQ_ZERO_COPY_SERIALIZATION}')
except ImportError:
print('[Startup Check] Could not import serial_utils to check flag')
print("[Startup Check]")
main()
6 changes: 3 additions & 3 deletions tests/test_serial_utils_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ def test_single_nested_tensor_serialization():

def test_large_string_serialization():
"""Test serialization of large strings (>10KB).

Note: msgpack natively handles str type, so enc_hook is not called for strings.
This test verifies large strings are correctly serialized/deserialized.
"""
Expand All @@ -583,9 +583,9 @@ def test_large_string_serialization():

# Create a string larger than 10KB
large_string = "x" * 11000 # ~11KB

serialized = encoder.encode({"text": large_string})

# Verify content is correctly restored
decoded = decoder.decode(serialized)
assert decoded["text"] == large_string
Expand Down
Loading