diff --git a/src/scriptworker/cot/verify.py b/src/scriptworker/cot/verify.py index bedad9d9..58821d1c 100644 --- a/src/scriptworker/cot/verify.py +++ b/src/scriptworker/cot/verify.py @@ -619,8 +619,8 @@ def _sort_dependencies_by_name_then_task_id(dependencies): # build_task_dependencies {{{1 -async def build_link(chain, task_name, task_id): - """Build a LinkOfTrust and add it to the chain. +async def add_link(chain, task_name, task_id): + """Fetch a task definition and add it as a LinkOfTrust to the chain. Args: chain (ChainOfTrust): the chain of trust to add to. @@ -633,17 +633,14 @@ async def build_link(chain, task_name, task_id): """ link = LinkOfTrust(chain.context, task_name, task_id) json_path = link.get_artifact_full_path("task.json") - task_defn = await retry_get_task_definition(chain.context.queue, task_id, exception=CoTError) - link.task = task_defn + link.task = await retry_get_task_definition(chain.context.queue, task_id, exception=CoTError) chain.links.append(link) - # write task json to disk makedirs(os.path.dirname(json_path)) with open(json_path, "w") as fh: - fh.write(format_json(task_defn)) - await build_task_dependencies(chain, task_defn, task_name, task_id) + fh.write(format_json(link.task)) -async def build_task_dependencies(chain, task, name, my_task_id): +async def build_task_dependencies(chain, task, name, my_task_id, seen=None): """Recursively build the task dependencies of a task. Args: @@ -651,6 +648,7 @@ async def build_task_dependencies(chain, task, name, my_task_id): task (dict): the task definition to operate on. name (str): the name of the task to operate on. my_task_id (str): the taskId of the task to operate on. + seen (set): shared set of already-seen task IDs to avoid duplicates Raises: CoTError: on failure. @@ -661,9 +659,20 @@ async def build_task_dependencies(chain, task, name, my_task_id): raise CoTError("Too deep recursion!\n{}".format(name)) sorted_dependencies = find_sorted_task_dependencies(task, name, my_task_id) + if seen is None: + seen = set(chain.dependent_task_ids()) + new_deps = [] for task_name, task_id in sorted_dependencies: - if task_id not in chain.dependent_task_ids(): - await build_link(chain, task_name, task_id) + if task_id not in seen: + seen.add(task_id) + new_deps.append((task_name, task_id)) + + if not new_deps: + return + + await asyncio.gather(*[add_link(chain, task_name, task_id) for task_name, task_id in new_deps]) + + await asyncio.gather(*[build_task_dependencies(chain, chain.get_link(task_id).task, task_name, task_id, seen) for task_name, task_id in new_deps]) # download_cot {{{1 @@ -1565,8 +1574,10 @@ async def get_jsone_context_and_template(chain, parent_link, decision_link, task if tasks_for in ("action", "pr-action"): jsone_context, tmpl = await get_action_context_and_template(chain, parent_link, decision_link, tasks_for) else: - tmpl = await get_in_tree_template(decision_link) - jsone_context = await populate_jsone_context(chain, parent_link, decision_link, tasks_for) + tmpl, jsone_context = await asyncio.gather( + get_in_tree_template(decision_link), + populate_jsone_context(chain, parent_link, decision_link, tasks_for), + ) return jsone_context, tmpl @@ -1803,14 +1814,14 @@ async def verify_task_types(chain): """ valid_task_types = get_valid_task_types() task_count = {} + tasks = [] for obj in chain.get_all_links_in_chain(): task_type = obj.task_type log.info("Verifying {} {} as a {} task...".format(obj.name, obj.task_id, task_type)) task_count.setdefault(task_type, 0) task_count[task_type] += 1 - # Run tests synchronously for now. We can parallelize if efficiency - # is more important than a single simple logfile. - await valid_task_types[task_type](chain, obj) + tasks.append(valid_task_types[task_type](chain, obj)) + await asyncio.gather(*tasks) return task_count @@ -1880,12 +1891,12 @@ async def verify_worker_impls(chain): """ valid_worker_impls = get_valid_worker_impls() + tasks = [] for obj in chain.get_all_links_in_chain(): worker_impl = obj.worker_impl log.info("Verifying {} {} as a {} task...".format(obj.name, obj.task_id, worker_impl)) - # Run tests synchronously for now. We can parallelize if efficiency - # is more important than a single simple logfile. - await valid_worker_impls[worker_impl](chain, obj) + tasks.append(valid_worker_impls[worker_impl](chain, obj)) + await asyncio.gather(*tasks) # get_source_url {{{1 @@ -2042,9 +2053,8 @@ async def verify_chain_of_trust(chain, *, check_task=False): try: # build LinkOfTrust objects if check_task: - await build_link(chain, chain.name, chain.task_id) - else: - await build_task_dependencies(chain, chain.task, chain.name, chain.task_id) + await add_link(chain, chain.name, chain.task_id) + await build_task_dependencies(chain, chain.task, chain.name, chain.task_id) # download the signed chain of trust artifacts await download_cot(chain) # verify the signatures and populate the ``link.cot``s diff --git a/tests/test_cot_verify.py b/tests/test_cot_verify.py index d06ac5a0..20434fb4 100644 --- a/tests/test_cot_verify.py +++ b/tests/test_cot_verify.py @@ -2208,7 +2208,7 @@ async def maybe_die(*args): if exc is not None: raise exc("blah") - for func in ("build_task_dependencies", "build_link", "download_cot", "download_cot_artifacts", "verify_task_types", "verify_worker_impls"): + for func in ("build_task_dependencies", "add_link", "download_cot", "download_cot_artifacts", "verify_task_types", "verify_worker_impls"): mocker.patch.object(cotverify, func, new=noop_async) mocker.patch.object(cotverify, "verify_cot_signatures", new=noop_sync) mocker.patch.object(cotverify, "trace_back_to_tree", new=maybe_die)