diff --git a/devlib/target.py b/devlib/target.py index 5f2c595a0..907fe2807 100644 --- a/devlib/target.py +++ b/devlib/target.py @@ -13,6 +13,7 @@ # limitations under the License. # +import atexit import asyncio from contextlib import contextmanager import io @@ -40,6 +41,7 @@ from past.types import basestring from numbers import Number from shlex import quote +from weakref import WeakMethod try: from collections.abc import Mapping except ImportError: @@ -413,6 +415,10 @@ def kind_conflict(kind, names): )) self._modules = modules + atexit.register( + WeakMethod(self.disconnect, atexit.unregister) + ) + self._update_modules('early') if connect: self.connect(max_async=max_async) @@ -521,10 +527,32 @@ async def check_connection(self): def disconnect(self): connections = self._conn.get_all_values() + # Now that we have all the connection objects, we simply reset the TLS + # property so that the connections we got will not be reused anywhere. + del self._conn + + unused_conns = self._unused_conns + self._unused_conns.clear() + for conn in itertools.chain(connections, self._unused_conns): conn.close() - if self._async_pool is not None: - self._async_pool.__exit__(None, None, None) + + pool = self._async_pool + self._async_pool = None + if pool is not None: + pool.__exit__(None, None, None) + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + self.disconnect() + + async def __aenter__(self): + return self.__enter__() + + async def __aexit__(self, *args, **kwargs): + return self.__exit__(*args, **kwargs) def get_connection(self, timeout=None): if self.conn_cls is None: diff --git a/devlib/utils/ssh.py b/devlib/utils/ssh.py index 0eb7db2ba..aa1e873bb 100644 --- a/devlib/utils/ssh.py +++ b/devlib/utils/ssh.py @@ -24,14 +24,12 @@ import socket import sys import time -import atexit import contextlib import select import copy import functools import shutil from shlex import quote -from weakref import WeakMethod from paramiko.client import SSHClient, AutoAddPolicy, RejectPolicy import paramiko.ssh_exception @@ -372,8 +370,6 @@ def __init__(self, self.client = None try: self.client = self._make_client() - weak_close = WeakMethod(self.close, atexit.unregister) - atexit.register(weak_close) # Use a marker in the output so that we will be able to differentiate # target connection issues with "password needed". @@ -815,9 +811,6 @@ def __init__(self, self.conn = telnet_get_shell(host, username, password, port, timeout, original_prompt) - weak_close = WeakMethod(self.close, atexit.unregister) - atexit.register(weak_close) - def fmt_remote_path(self, path): return '{}@{}:{}'.format(self.username, self.host, path)