Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,25 @@
}


def should_run(nodeid: str, num_shards: int, shard_index: int) -> bool:
def find_shard_index(nodeid: str, num_shards: int) -> int:
"""
Return true if this test should run on this shard
Return the index of the shard that should run this test
"""
for prefix, target_shard_idx in FIXED_ALLOCATION_PREFIXES.items():
if nodeid.startswith(prefix):
if target_shard_idx >= num_shards:
raise RuntimeError(
f"Cannot collect sharded tests, {nodeid} has hardcoded shard index {target_shard_idx} among only {num_shards} shards"
)
return target_shard_idx == shard_index
return target_shard_idx

if nodeid in HARDCODED_ALLOCATIONS:
hash = HARDCODED_ALLOCATIONS[nodeid]
else:
hash = hashlib.md5(nodeid.encode())
hash = int(hash.hexdigest(), 16)

return hash % num_shards == shard_index
return hash % num_shards


def pytest_collection_modifyitems(config, items):
Expand All @@ -89,5 +89,10 @@ def pytest_collection_modifyitems(config, items):

print(f"Marking tests for shard {shard_index} of {num_shards}")
for item in items:
if not should_run(item.nodeid, num_shards=num_shards, shard_index=shard_index):
item.add_marker(pytest.mark.skip())
item_shard_index = find_shard_index(item.nodeid, num_shards=num_shards)
item.add_marker(
pytest.mark.skipif(
item_shard_index != shard_index,
reason=f"Test running on shard {item_shard_index} of {num_shards}",
)
)