Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 97 additions & 47 deletions ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ def _run_device_worker_subprocess(
pto_isa_commit: str | None = None,
print_log_on_fail: bool = False,
quiet: bool = True,
timeout: int | None = None,
) -> list[TaskResult]:
"""Run a task batch in one device-worker subprocess and return its reported results.

Expand Down Expand Up @@ -703,9 +704,11 @@ def _run_device_worker_subprocess(
logger.info(f"[{tag}:dev{device_id}] Launching: {' '.join(full_cmd)}")
try:
if quiet:
proc = subprocess.run(full_cmd, check=False, capture_output=True, text=True)
proc = subprocess.run(full_cmd, check=False, capture_output=True, text=True, timeout=timeout)
else:
proc = subprocess.run(full_cmd, check=False, stdout=None, stderr=subprocess.PIPE, text=True)
proc = subprocess.run(
full_cmd, check=False, stdout=None, stderr=subprocess.PIPE, text=True, timeout=timeout
)
device_results = _read_results_json(result_path)
if proc.returncode != 0:
if print_log_on_fail and quiet:
Expand All @@ -732,6 +735,24 @@ def _run_device_worker_subprocess(
)
)
return device_results
except subprocess.TimeoutExpired:
logger.error(f"[{tag}:dev{device_id}] Subprocess timed out after {timeout}s")
device_results = _read_results_json(result_path)
reported_names = {r.name for r in device_results}
for t in tasks:
if t.name not in reported_names:
device_results.append(
TaskResult(
name=t.name,
platform=t.platform,
passed=False,
device=str(device_id),
attempt=0,
elapsed_s=0,
error=f"Timed out after {timeout}s",
)
)
return device_results
finally:
task_list_path.unlink(missing_ok=True)
result_path.unlink(missing_ok=True)
Expand All @@ -756,39 +777,6 @@ def _normalize_task_result(
)


def run_sim_tasks_subprocess(
    tasks: list[TaskSpec],
    args: argparse.Namespace,
    pto_isa_commit: str | None = None,
) -> list[TaskResult]:
    """Run simulation tasks, launching one device-worker subprocess per runtime group.

    Tasks that share a runtime are batched into the same subprocess so they
    can reuse a single ChipWorker. Distinct runtimes get distinct subprocesses,
    which guarantees the host SO is never dlclose/dlopen'd inside one process.

    Args:
        tasks: Simulation task specs to execute.
        args: Parsed CLI namespace forwarded to each subprocess.
        pto_isa_commit: When set, this is a pinned-PTO-ISA retry pass; worker
            logs are printed on failure in that case.

    Returns:
        The combined TaskResult list reported by all subprocesses.
    """
    # Bucket tasks by runtime name; each bucket becomes one subprocess batch.
    by_runtime: dict[str, list[TaskSpec]] = {}
    for spec in tasks:
        by_runtime.setdefault(spec.runtime_name, []).append(spec)

    pin_retry = pto_isa_commit is not None
    collected: list[TaskResult] = []
    for runtime, batch in by_runtime.items():
        logger.info(f"[sim] Launching subprocess for runtime {runtime} ({len(batch)} task(s))")
        batch_results = _run_device_worker_subprocess(
            batch,
            0,
            args,
            tag="sim",
            pto_isa_commit=pto_isa_commit,
            print_log_on_fail=pin_retry,
            quiet=False,
        )
        collected.extend(batch_results)
    return collected


def run_hw_tasks_subprocess(
tasks: list[TaskSpec],
devices: list[int],
Expand Down Expand Up @@ -992,7 +980,13 @@ def _run_tasks_on_device(
pto_isa_root: str,
args: argparse.Namespace,
) -> list[TaskResult]:
"""Compile and run all tasks on a single device. Returns all TaskResults."""
"""Compile and run all tasks on a single device. Returns all TaskResults.

For simulation platforms with sufficient CPUs, tasks are distributed
across multiple virtual device IDs and executed in parallel threads.
ChipWorker.run() internally uses std::thread + join, so GIL is released
during execution, enabling true parallelism.
"""
logger.info(f"Compiling {len(tasks)} tasks...")
try:
compiled = compile_all_tasks(
Expand All @@ -1012,6 +1006,60 @@ def _run_tasks_on_device(
for t in tasks
]

is_sim = platform.endswith("sim")
if is_sim:
cpu_count = os.cpu_count() or 1
max_workers = min(max(cpu_count // 20, 1), len(compiled))
else:
max_workers = 1

if max_workers <= 1:
return _run_compiled_tasks(compiled, device_id, platform)

# Parallel: distribute tasks round-robin across virtual device IDs
buckets: list[list[CompiledTask]] = [[] for _ in range(max_workers)]
for i, ct in enumerate(compiled):
buckets[i % max_workers].append(ct)

logger.info(f"[sim] Parallel execution: {max_workers} workers, {len(compiled)} tasks")

results: list[TaskResult] = []
results_lock = Lock()
completed_count = [0]
total = len(compiled)

def _worker(worker_id: int, worker_tasks: list[CompiledTask]):
dev_id = worker_id
worker_results = _run_compiled_tasks(worker_tasks, dev_id, platform)
with results_lock:
for r in worker_results:
completed_count[0] += 1
n = completed_count[0]
results.append(r)
status = "PASS" if r.passed else "FAIL"
logger.info(f"[dev{dev_id}] [{n}/{total}] {status}: {r.name} ({r.elapsed_s:.1f}s)")

threads = []
for i in range(max_workers):
if not buckets[i]:
continue
t = Thread(target=_worker, args=(i, buckets[i]))
t.start()
threads.append(t)

for t in threads:
t.join()

return results


def _run_compiled_tasks(
compiled: list[CompiledTask],
device_id: int,
platform: str,
) -> list[TaskResult]:
"""Run compiled tasks serially on a single device."""

groups = group_by_runtime(compiled)
all_results: list[TaskResult] = []

Expand Down Expand Up @@ -1163,16 +1211,13 @@ def _run_single_platform(platform: str, args: argparse.Namespace) -> list[TaskRe
return []
logger.info(f"[{platform}] Discovered {len(tasks)} tasks")

# Compile and run.
# Both sim and hw use subprocess isolation (different runtimes cannot share a process).
# Within each subprocess, tasks with the same runtime share a ChipWorker.
# Override platform in args for subprocess spawning.
# Compile and run via subprocess isolation.
# Sim: single subprocess with all tasks (ChipWorker reuse + parallel within).
# HW: one subprocess per task with device-level quarantine.
sub_args = argparse.Namespace(**vars(args))
sub_args.platform = platform
if is_sim:
all_results = _run_with_timeout(
f"{platform} initial pass", args.timeout, lambda: run_sim_tasks_subprocess(tasks, sub_args)
)
all_results = _run_device_worker_subprocess(tasks, 0, sub_args, tag="sim", timeout=args.timeout, quiet=False)
else:
all_results = _run_with_timeout(
f"{platform} initial pass",
Expand All @@ -1191,10 +1236,15 @@ def _run_single_platform(platform: str, args: argparse.Namespace) -> list[TaskRe
failed_tasks = [t for t in tasks if t.name in failed_names]
logger.info(f"[{platform}] {len(failed_tasks)} failure(s), retrying with pinned PTO-ISA {args.pto_isa_commit}")
if is_sim:
pin_results = _run_with_timeout(
f"{platform} pin retry",
args.timeout,
lambda: run_sim_tasks_subprocess(failed_tasks, sub_args, pto_isa_commit=args.pto_isa_commit),
pin_results = _run_device_worker_subprocess(
failed_tasks,
0,
sub_args,
tag="sim",
pto_isa_commit=args.pto_isa_commit,
print_log_on_fail=True,
quiet=False,
timeout=args.timeout,
)
else:
pin_results = _run_with_timeout(
Expand Down
30 changes: 21 additions & 9 deletions src/a2a3/platform/onboard/host/device_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ std::thread DeviceRunner::create_thread(std::function<void()> fn) {
return std::thread([dev_id, fn = std::move(fn)]() {
rtSetDevice(dev_id);
fn();
rtDeviceReset(dev_id);
});
}

Expand All @@ -248,20 +247,20 @@ int DeviceRunner::ensure_device_initialized(
}

int DeviceRunner::ensure_device_set(int device_id) {
// Check if already initialized
if (stream_aicpu_ != nullptr) {
return 0;
}

device_id_ = device_id;

// Set device
// Always set device for the calling thread (CANN device context is per-thread)
int rc = rtSetDevice(device_id);
if (rc != 0) {
LOG_ERROR("rtSetDevice(%d) failed: %d", device_id, rc);
return rc;
}

// Create streams only on first call
if (stream_aicpu_ != nullptr) {
return 0;
}

device_id_ = device_id;

// Create streams
rc = rtStreamCreate(&stream_aicpu_, 0);
if (rc != 0) {
Expand All @@ -281,6 +280,19 @@ int DeviceRunner::ensure_device_set(int device_id) {
return 0;
}

void DeviceRunner::reset_device_context() {
  // Streams belong to the calling thread's CANN context, so release them
  // before resetting the device; the next run (possibly on a fresh thread)
  // recreates them in its own context.
  auto release_stream = [](auto &stream) {
    if (stream != nullptr) {
      rtStreamDestroy(stream);
      stream = nullptr;
    }
  };
  release_stream(stream_aicpu_);
  release_stream(stream_aicore_);
  // Tear down this thread's device context last.
  rtDeviceReset(device_id_);
}

int DeviceRunner::ensure_binaries_loaded(
const std::vector<uint8_t> &aicpu_so_binary, const std::vector<uint8_t> &aicore_kernel_binary
) {
Expand Down
7 changes: 7 additions & 0 deletions src/a2a3/platform/onboard/host/device_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,13 @@ class DeviceRunner {
*/
int ensure_device_set(int device_id);

/**
* Reset per-thread CANN device context and clear cached streams.
* Called after each run_runtime() completes so the next run on a
* fresh thread can recreate streams in its own context.
*/
void reset_device_context();

private:
// Internal state
int device_id_{-1};
Expand Down
Loading
Loading