diff --git a/packages/gds-framework/gds/helpers.py b/packages/gds-framework/gds/helpers.py index b51c635..bd47e8c 100644 --- a/packages/gds-framework/gds/helpers.py +++ b/packages/gds-framework/gds/helpers.py @@ -7,6 +7,7 @@ from __future__ import annotations +import threading from collections.abc import Callable from typing import Any @@ -168,6 +169,7 @@ def _to_ports(names: list[str] | None) -> tuple[Port, ...]: CheckFn = Callable[[SystemIR], list[Finding]] _CUSTOM_CHECKS: list[CheckFn] = [] +_CUSTOM_CHECKS_LOCK = threading.Lock() def gds_check( @@ -177,7 +179,8 @@ def gds_check( """Decorator that registers a verification check function. Attaches ``check_id`` and ``severity`` as function attributes and - adds it to the module-level custom check registry. + adds it to the module-level custom check registry. Registration is + thread-safe. Usage:: @@ -189,19 +192,25 @@ def check_no_orphan_spaces(system: SystemIR) -> list[Finding]: def decorator(fn: CheckFn) -> CheckFn: fn.check_id = check_id # type: ignore[attr-defined] fn.severity = severity # type: ignore[attr-defined] - _CUSTOM_CHECKS.append(fn) + with _CUSTOM_CHECKS_LOCK: + _CUSTOM_CHECKS.append(fn) return fn return decorator def get_custom_checks() -> list[CheckFn]: - """Return all check functions registered via ``@gds_check``.""" - return list(_CUSTOM_CHECKS) + """Return a snapshot of all check functions registered via ``@gds_check``. + + Returns a copy so callers cannot mutate the internal registry. + """ + with _CUSTOM_CHECKS_LOCK: + return list(_CUSTOM_CHECKS) def all_checks() -> list[CheckFn]: """Return built-in generic checks + all custom-registered checks.""" from gds.verification.engine import ALL_CHECKS - return list(ALL_CHECKS) + list(_CUSTOM_CHECKS) + with _CUSTOM_CHECKS_LOCK: + return list(ALL_CHECKS) + list(_CUSTOM_CHECKS) diff --git a/packages/gds-framework/tests/test_helpers.py b/packages/gds-framework/tests/test_helpers.py index 4d22e29..d244e88 100644 --- a/packages/gds-framework/tests/test_helpers.py +++ b/packages/gds-framework/tests/test_helpers.py @@ -20,6 +20,7 @@ ) from gds.helpers import ( _CUSTOM_CHECKS, + _CUSTOM_CHECKS_LOCK, all_checks, entity, gds_check, @@ -245,11 +246,13 @@ def test_empty(self): class TestGdsCheck: def setup_method(self): """Clear custom check registry before each test.""" - _CUSTOM_CHECKS.clear() + with _CUSTOM_CHECKS_LOCK: + _CUSTOM_CHECKS.clear() def teardown_method(self): """Clear custom check registry after each test.""" - _CUSTOM_CHECKS.clear() + with _CUSTOM_CHECKS_LOCK: + _CUSTOM_CHECKS.clear() def test_registers_check(self): @gds_check("TEST-001", Severity.WARNING) @@ -309,3 +312,36 @@ def check_b(system: SystemIR) -> list[Finding]: assert len(checks) == 2 assert checks[0] is check_a assert checks[1] is check_b + + def test_concurrent_registration(self): + """Register checks from multiple threads; all must appear.""" + import threading + + num_threads = 20 + barrier = threading.Barrier(num_threads) + errors: list[Exception] = [] + + def register_check(index: int) -> None: + try: + barrier.wait(timeout=5) + + @gds_check(f"THREAD-{index:03d}") + def _check(system: SystemIR) -> list[Finding]: + return [] + except Exception as exc: + errors.append(exc) + + threads = [ + threading.Thread(target=register_check, args=(i,)) + for i in range(num_threads) + ] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + + assert not errors, f"Thread errors: {errors}" + checks = get_custom_checks() + assert len(checks) == num_threads + check_ids = {c.check_id for c in checks} # type: ignore[attr-defined] + assert check_ids == {f"THREAD-{i:03d}" for i in range(num_threads)}