diff --git a/src/commands/run.py b/src/commands/run.py index 028da2e..035ee4b 100644 --- a/src/commands/run.py +++ b/src/commands/run.py @@ -126,21 +126,30 @@ async def start_tasks(tasks): active_tasks = [] def remove_done_tasks(): for active_task in active_tasks: - if active_task._state == 'FINISHED': + if active_task.finished: active_tasks.remove(active_task) break - for task in tasks: - while len(active_tasks) > max_ssh: + + async def add_tasks(): + for task in tasks: + while len(active_tasks) > max_ssh: + remove_done_tasks() + await asyncio.sleep(0.01) + asyncio.ensure_future(task, loop = loop) + active_tasks.append(task) + run_print('Starting Job: {}'.format(task.node)) + await(asyncio.sleep(start_delay)) + + async def await_finished(): + while len(active_tasks) > 0: remove_done_tasks() await asyncio.sleep(0.01) - asyncio.ensure_future(task, loop = loop) - active_tasks.append(task) - run_print('Starting Job: {}'.format(task.node)) - await(asyncio.sleep(start_delay)) - run_print('Waiting for jobs to finish...') - while len(active_tasks) > 0: - remove_done_tasks() - await asyncio.sleep(0.01) + + try: + await add_tasks() + run_print('Waiting for jobs to finish...') + except concurrent.futures.CancelledError: pass + finally: await await_finished() session = Session() try: @@ -153,7 +162,6 @@ def remove_done_tasks(): run_print('Executing: {}'.format(batch.name())) tasks = map(lambda j: j.task(thread_pool = pool), batch.jobs) loop.run_until_complete(start_tasks(tasks)) - except concurrent.futures.CancelledError: pass finally: run_print('Cleaning up...') pool.shutdown(wait = True) diff --git a/src/models/job.py b/src/models/job.py index 8cb5770..d9b40d0 100644 --- a/src/models/job.py +++ b/src/models/job.py @@ -64,10 +64,17 @@ def __init__(self, job, thread_pool = None): self.thread_pool = thread_pool super().__init__(self.run_async()) self.job = job - self.add_job_callback(lambda job: job.connection().close()) + self.add_job_callback(type(self).close) self.add_job_callback(lambda job: job.set_ssh_results()) self.add_done_callback(type(self).report_results) + def close(self): + try: job.connection.close() + except: pass + + def finished(self): + return self._state == 'FINISHED' + def __getattr__(self, attr): return getattr(self.job, attr) @@ -117,8 +124,7 @@ def catch_errors(func, *args): async def run_async(self): if self.check_command(): - try: await self._run_thread(self.connection().open) - except concurrent.futures.CancelledError as e: raise e + await self._run_thread(self.connection().open) if self.connection().is_connected: await self._run_thread(self.run, self.batch)