Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions packages/gds-framework/gds/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from __future__ import annotations

import threading
from collections.abc import Callable
from typing import Any

Expand Down Expand Up @@ -168,6 +169,7 @@ def _to_ports(names: list[str] | None) -> tuple[Port, ...]:
CheckFn = Callable[[SystemIR], list[Finding]]

_CUSTOM_CHECKS: list[CheckFn] = []
_CUSTOM_CHECKS_LOCK = threading.Lock()


def gds_check(
Expand All @@ -177,7 +179,8 @@ def gds_check(
"""Decorator that registers a verification check function.

Attaches ``check_id`` and ``severity`` as function attributes and
adds it to the module-level custom check registry.
adds it to the module-level custom check registry. Registration is
thread-safe.

Usage::

Expand All @@ -189,19 +192,25 @@ def check_no_orphan_spaces(system: SystemIR) -> list[Finding]:
def decorator(fn: CheckFn) -> CheckFn:
fn.check_id = check_id # type: ignore[attr-defined]
fn.severity = severity # type: ignore[attr-defined]
_CUSTOM_CHECKS.append(fn)
with _CUSTOM_CHECKS_LOCK:
_CUSTOM_CHECKS.append(fn)
return fn

return decorator


def get_custom_checks() -> list[CheckFn]:
"""Return all check functions registered via ``@gds_check``."""
return list(_CUSTOM_CHECKS)
"""Return a snapshot of all check functions registered via ``@gds_check``.

Returns a copy so callers cannot mutate the internal registry.
"""
with _CUSTOM_CHECKS_LOCK:
return list(_CUSTOM_CHECKS)


def all_checks() -> list[CheckFn]:
"""Return built-in generic checks + all custom-registered checks."""
from gds.verification.engine import ALL_CHECKS

return list(ALL_CHECKS) + list(_CUSTOM_CHECKS)
with _CUSTOM_CHECKS_LOCK:
return list(ALL_CHECKS) + list(_CUSTOM_CHECKS)
40 changes: 38 additions & 2 deletions packages/gds-framework/tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from gds.helpers import (
_CUSTOM_CHECKS,
_CUSTOM_CHECKS_LOCK,
all_checks,
entity,
gds_check,
Expand Down Expand Up @@ -245,11 +246,13 @@ def test_empty(self):
class TestGdsCheck:
def setup_method(self):
"""Clear custom check registry before each test."""
_CUSTOM_CHECKS.clear()
with _CUSTOM_CHECKS_LOCK:
_CUSTOM_CHECKS.clear()

def teardown_method(self):
"""Clear custom check registry after each test."""
_CUSTOM_CHECKS.clear()
with _CUSTOM_CHECKS_LOCK:
_CUSTOM_CHECKS.clear()

def test_registers_check(self):
@gds_check("TEST-001", Severity.WARNING)
Expand Down Expand Up @@ -309,3 +312,36 @@ def check_b(system: SystemIR) -> list[Finding]:
assert len(checks) == 2
assert checks[0] is check_a
assert checks[1] is check_b

def test_concurrent_registration(self):
"""Register checks from multiple threads; all must appear."""
import threading

num_threads = 20
barrier = threading.Barrier(num_threads)
errors: list[Exception] = []

def register_check(index: int) -> None:
try:
barrier.wait(timeout=5)

@gds_check(f"THREAD-{index:03d}")
def _check(system: SystemIR) -> list[Finding]:
return []
except Exception as exc:
errors.append(exc)

threads = [
threading.Thread(target=register_check, args=(i,))
for i in range(num_threads)
]
for t in threads:
t.start()
for t in threads:
t.join(timeout=10)

assert not errors, f"Thread errors: {errors}"
checks = get_custom_checks()
assert len(checks) == num_threads
check_ids = {c.check_id for c in checks} # type: ignore[attr-defined]
assert check_ids == {f"THREAD-{i:03d}" for i in range(num_threads)}