Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 31 additions & 21 deletions src/scriptworker/cot/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,8 +619,8 @@ def _sort_dependencies_by_name_then_task_id(dependencies):


# build_task_dependencies {{{1
async def build_link(chain, task_name, task_id):
"""Build a LinkOfTrust and add it to the chain.
async def add_link(chain, task_name, task_id):
"""Fetch a task definition and add it as a LinkOfTrust to the chain.

Args:
chain (ChainOfTrust): the chain of trust to add to.
Expand All @@ -633,24 +633,22 @@ async def build_link(chain, task_name, task_id):
"""
link = LinkOfTrust(chain.context, task_name, task_id)
json_path = link.get_artifact_full_path("task.json")
task_defn = await retry_get_task_definition(chain.context.queue, task_id, exception=CoTError)
link.task = task_defn
link.task = await retry_get_task_definition(chain.context.queue, task_id, exception=CoTError)
chain.links.append(link)
# write task json to disk
makedirs(os.path.dirname(json_path))
with open(json_path, "w") as fh:
fh.write(format_json(task_defn))
await build_task_dependencies(chain, task_defn, task_name, task_id)
fh.write(format_json(link.task))


async def build_task_dependencies(chain, task, name, my_task_id):
async def build_task_dependencies(chain, task, name, my_task_id, seen=None):
"""Recursively build the task dependencies of a task.

Args:
chain (ChainOfTrust): the chain of trust to add to.
task (dict): the task definition to operate on.
name (str): the name of the task to operate on.
my_task_id (str): the taskId of the task to operate on.
seen (set): shared set of already-seen task IDs to avoid duplicates

Raises:
CoTError: on failure.
Expand All @@ -661,9 +659,20 @@ async def build_task_dependencies(chain, task, name, my_task_id):
raise CoTError("Too deep recursion!\n{}".format(name))
sorted_dependencies = find_sorted_task_dependencies(task, name, my_task_id)

if seen is None:
seen = set(chain.dependent_task_ids())
new_deps = []
for task_name, task_id in sorted_dependencies:
if task_id not in chain.dependent_task_ids():
await build_link(chain, task_name, task_id)
if task_id not in seen:
seen.add(task_id)
new_deps.append((task_name, task_id))

if not new_deps:
return

await asyncio.gather(*[add_link(chain, task_name, task_id) for task_name, task_id in new_deps])

await asyncio.gather(*[build_task_dependencies(chain, chain.get_link(task_id).task, task_name, task_id, seen) for task_name, task_id in new_deps])


# download_cot {{{1
Expand Down Expand Up @@ -1565,8 +1574,10 @@ async def get_jsone_context_and_template(chain, parent_link, decision_link, task
if tasks_for in ("action", "pr-action"):
jsone_context, tmpl = await get_action_context_and_template(chain, parent_link, decision_link, tasks_for)
else:
tmpl = await get_in_tree_template(decision_link)
jsone_context = await populate_jsone_context(chain, parent_link, decision_link, tasks_for)
tmpl, jsone_context = await asyncio.gather(
get_in_tree_template(decision_link),
populate_jsone_context(chain, parent_link, decision_link, tasks_for),
)
return jsone_context, tmpl


Expand Down Expand Up @@ -1803,14 +1814,14 @@ async def verify_task_types(chain):
"""
valid_task_types = get_valid_task_types()
task_count = {}
tasks = []
for obj in chain.get_all_links_in_chain():
task_type = obj.task_type
log.info("Verifying {} {} as a {} task...".format(obj.name, obj.task_id, task_type))
task_count.setdefault(task_type, 0)
task_count[task_type] += 1
# Run tests synchronously for now. We can parallelize if efficiency
# is more important than a single simple logfile.
await valid_task_types[task_type](chain, obj)
tasks.append(valid_task_types[task_type](chain, obj))
await asyncio.gather(*tasks)
return task_count


Expand Down Expand Up @@ -1880,12 +1891,12 @@ async def verify_worker_impls(chain):

"""
valid_worker_impls = get_valid_worker_impls()
tasks = []
for obj in chain.get_all_links_in_chain():
worker_impl = obj.worker_impl
log.info("Verifying {} {} as a {} task...".format(obj.name, obj.task_id, worker_impl))
# Run tests synchronously for now. We can parallelize if efficiency
# is more important than a single simple logfile.
await valid_worker_impls[worker_impl](chain, obj)
tasks.append(valid_worker_impls[worker_impl](chain, obj))
await asyncio.gather(*tasks)


# get_source_url {{{1
Expand Down Expand Up @@ -2042,9 +2053,8 @@ async def verify_chain_of_trust(chain, *, check_task=False):
try:
# build LinkOfTrust objects
if check_task:
await build_link(chain, chain.name, chain.task_id)
else:
await build_task_dependencies(chain, chain.task, chain.name, chain.task_id)
await add_link(chain, chain.name, chain.task_id)
await build_task_dependencies(chain, chain.task, chain.name, chain.task_id)
# download the signed chain of trust artifacts
await download_cot(chain)
# verify the signatures and populate the ``link.cot``s
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cot_verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -2208,7 +2208,7 @@ async def maybe_die(*args):
if exc is not None:
raise exc("blah")

for func in ("build_task_dependencies", "build_link", "download_cot", "download_cot_artifacts", "verify_task_types", "verify_worker_impls"):
for func in ("build_task_dependencies", "add_link", "download_cot", "download_cot_artifacts", "verify_task_types", "verify_worker_impls"):
mocker.patch.object(cotverify, func, new=noop_async)
mocker.patch.object(cotverify, "verify_cot_signatures", new=noop_sync)
mocker.patch.object(cotverify, "trace_back_to_tree", new=maybe_die)
Expand Down