From 4361e73b1ccf382ed137ed2e605af5509075cd3d Mon Sep 17 00:00:00 2001 From: Douglas RAILLARD Date: Thu, 12 Dec 2019 17:25:47 +0000 Subject: [PATCH 1/6] devlib.utils.misc: Use Popen.communicate(timeout=...) in check_output Use the timeout parameter added in Python 3.3, which removes the need for the timer thread and avoids some weird issues in preexec_fn, as it's now documented to sometimes not work when threads are involved. --- devlib/utils/misc.py | 43 +++++++++++++++---------------------------- 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/devlib/utils/misc.py b/devlib/utils/misc.py index 4d488b678..0738b3eca 100644 --- a/devlib/utils/misc.py +++ b/devlib/utils/misc.py @@ -136,9 +136,6 @@ def get_cpu_name(implementer, part, variant): def preexec_function(): - # Ignore the SIGINT signal by setting the handler to the standard - # signal handler SIG_IGN. - signal.signal(signal.SIGINT, signal.SIG_IGN) # Change process group in case we have to kill the subprocess and all of # its children later. # TODO: this is Unix-specific; would be good to find an OS-agnostic way @@ -167,13 +164,6 @@ def check_output(command, timeout=None, ignore=None, inputtext=None, if 'stdout' in kwargs: raise ValueError('stdout argument not allowed, it will be overridden.') - def callback(pid): - try: - check_output_logger.debug('{} timed out; sending SIGKILL'.format(pid)) - os.killpg(pid, signal.SIGKILL) - except OSError: - pass # process may have already terminated. - with check_output_lock: stderr = subprocess.STDOUT if combined_output else subprocess.PIPE process = subprocess.Popen(command, @@ -183,27 +173,24 @@ def callback(pid): preexec_fn=preexec_function, **kwargs) - if timeout: - timer = threading.Timer(timeout, callback, [process.pid, ]) - timer.start() - try: - output, error = process.communicate(inputtext) - if sys.version_info[0] == 3: - # Currently errors=replace is needed as 0x8c throws an error - output = output.decode(sys.stdout.encoding or 'utf-8', "replace") - if error: - error = error.decode(sys.stderr.encoding or 'utf-8', "replace") - finally: - if timeout: - timer.cancel() + output, error = process.communicate(inputtext, timeout=timeout) + except subprocess.TimeoutExpired as e: + timeout_expired = e + else: + timeout_expired = None + + # Currently errors=replace is needed as 0x8c throws an error + output = output.decode(sys.stdout.encoding or 'utf-8', "replace") + if error: + error = error.decode(sys.stderr.encoding or 'utf-8', "replace") + + if timeout_expired: + raise TimeoutError(command, output='\n'.join([output or '', error or ''])) retcode = process.poll() - if retcode: - if retcode == -9: # killed, assume due to timeout callback - raise TimeoutError(command, output='\n'.join([output or '', error or ''])) - elif ignore != 'all' and retcode not in ignore: - raise subprocess.CalledProcessError(retcode, command, output='\n'.join([output or '', error or ''])) + if retcode and ignore != 'all' and retcode not in ignore: + raise subprocess.CalledProcessError(retcode, command, output='\n'.join([output or '', error or ''])) return output, error From 2c12314be003e2a0e80d83e424b4dd40c5d6f47c Mon Sep 17 00:00:00 2001 From: Douglas RAILLARD Date: Wed, 15 Jan 2020 16:05:44 +0000 Subject: [PATCH 2/6] utils/misc: Add tls_property() Similar to a regular property(), with the following differences: * Values are memoized and are threadlocal * The value returned by the property needs to be called (like a weakref) to get the actual value. This level of indirection is needed to allow methods to be implemented in the proxy object without clashing with the value's methods. * If the above is too annoying, a "sub property" can be created with the regular property() behavior (and therefore without the additional methods) using tls_property.basic_property . --- devlib/utils/misc.py | 118 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/devlib/utils/misc.py b/devlib/utils/misc.py index 0738b3eca..4f8beb12c 100644 --- a/devlib/utils/misc.py +++ b/devlib/utils/misc.py @@ -23,6 +23,7 @@ from functools import partial, reduce from itertools import groupby from operator import itemgetter +from weakref import WeakKeyDictionary, WeakSet import ctypes import functools @@ -705,3 +706,120 @@ def batch_contextmanager(f, kwargs_list): for kwargs in kwargs_list: stack.enter_context(f(**kwargs)) yield + +class tls_property: + """ + Use it like `property` decorator, but the result will be memoized per + thread. When the owning thread dies, the values for that thread will be + destroyed. + + In order to get the values, it's necessary to call the object + given by the property. This is necessary in order to be able to add methods + to that object, like :meth:`_BoundTLSProperty.get_all_values`. + + Values can be set and deleted as well, which will be a thread-local set. + """ + + @property + def name(self): + return self.factory.__name__ + + def __init__(self, factory): + self.factory = factory + # Lock accesses to shared WeakKeyDictionary and WeakSet + self.lock = threading.Lock() + + def __get__(self, instance, owner=None): + return _BoundTLSProperty(self, instance, owner) + + def _get_value(self, instance, owner): + tls, values = self._get_tls(instance) + try: + return tls.value + except AttributeError: + # Bind the method to `instance` + f = self.factory.__get__(instance, owner) + obj = f() + tls.value = obj + # Since that's a WeakSet, values will be removed automatically once + # the threading.local variable that holds them is destroyed + with self.lock: + values.add(obj) + return obj + + def _get_all_values(self, instance, owner): + with self.lock: + # Grab a reference to all the objects at the time of the call by + # using a regular set + tls, values = self._get_tls(instance=instance) + return set(values) + + def __set__(self, instance, value): + tls, values = self._get_tls(instance) + tls.value = value + with self.lock: + values.add(value) + + def __delete__(self, instance): + tls, values = self._get_tls(instance) + with self.lock: + values.discard(tls.value) + del tls.value + + def _get_tls(self, instance): + dct = instance.__dict__ + name = self.name + try: + # Using instance.__dict__[self.name] is safe as + # getattr(instance, name) will return the property instead, as + # the property is a descriptor + tls = dct[name] + except KeyError: + with self.lock: + # Double check after taking the lock to avoid a race + if name not in dct: + tls = (threading.local(), WeakSet()) + dct[name] = tls + + return tls + + @property + def basic_property(self): + """ + Return a basic property that can be used to access the TLS value + without having to call it first. + + The drawback is that it's not possible to do anything over than + getting/setting/deleting. + """ + def getter(instance, owner=None): + prop = self.__get__(instance, owner) + return prop() + + return property(getter, self.__set__, self.__delete__) + +class _BoundTLSProperty: + """ + Simple proxy object to allow either calling it to get the TLS value, or get + some other informations by calling methods. + """ + def __init__(self, tls_property, instance, owner): + self.tls_property = tls_property + self.instance = instance + self.owner = owner + + def __call__(self): + return self.tls_property._get_value( + instance=self.instance, + owner=self.owner, + ) + + def get_all_values(self): + """ + Returns all the thread-local values currently in use in the process for + that property for that instance. + """ + return self.tls_property._get_all_values( + instance=self.instance, + owner=self.owner, + ) From 73f2201dc72c4b79eb996591e77f8f386ab78198 Mon Sep 17 00:00:00 2001 From: Douglas RAILLARD Date: Wed, 15 Jan 2020 17:16:47 +0000 Subject: [PATCH 3/6] target: Use tls_property() to manage a thread-local connection This frees the connection to have to handle threading issues, since each thread using the Target will have its own connection. The connection will be garbage collected when the thread using it dies, avoiding connection leaks. --- devlib/target.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/devlib/target.py b/devlib/target.py index e3765a389..2b8296cd6 100644 --- a/devlib/target.py +++ b/devlib/target.py @@ -53,7 +53,7 @@ from devlib.utils.misc import memoized, isiterable, convert_new_lines from devlib.utils.misc import commonprefix, merge_lists from devlib.utils.misc import ABI_MAP, get_cpu_name, ranges_to_list -from devlib.utils.misc import batch_contextmanager +from devlib.utils.misc import batch_contextmanager, tls_property from devlib.utils.types import integer, boolean, bitmask, identifier, caseless_string, bytes_regex @@ -214,22 +214,19 @@ def page_size_kb(self): cmd = "cat /proc/self/smaps | {0} grep KernelPageSize | {0} head -n 1 | {0} awk '{{ print $2 }}'" return int(self.execute(cmd.format(self.busybox))) - @property - def conn(self): - if self._connections: - tid = id(threading.current_thread()) - if tid not in self._connections: - self._connections[tid] = self.get_connection() - return self._connections[tid] - else: - return None - @property def shutils(self): if self._shutils is None: self._setup_shutils() return self._shutils + @tls_property + def _conn(self): + return self.get_connection() + + # Add a basic property that does not require calling to get the value + conn = _conn.basic_property + def __init__(self, connection_settings=None, platform=None, @@ -242,6 +239,7 @@ def __init__(self, conn_cls=None, is_container=False ): + self._is_rooted = None self.connection_settings = connection_settings or {} # Set self.platform: either it's given directly (by platform argument) @@ -271,7 +269,6 @@ def __init__(self, self._installed_binaries = {} self._installed_modules = {} self._cache = {} - self._connections = {} self._shutils = None self._file_transfer_cache = None self.busybox = None @@ -290,8 +287,9 @@ def __init__(self, def connect(self, timeout=None, check_boot_completed=True): self.platform.init_target_connection(self) - tid = id(threading.current_thread()) - self._connections[tid] = self.get_connection(timeout=timeout) + # Forcefully set the thread-local value for the connection, with the + # timeout we want + self.conn = self.get_connection(timeout=timeout) if check_boot_completed: self.wait_boot_complete(timeout) self._resolve_paths() @@ -304,9 +302,9 @@ def connect(self, timeout=None, check_boot_completed=True): self._install_module(get_module('bl')) def disconnect(self): - for conn in self._connections.values(): + connections = self._conn.get_all_values() + for conn in connections: conn.close() - self._connections = {} def get_connection(self, timeout=None): if self.conn_cls is None: From da3d94e28dc2983babeaba6afbb52bf94e7d80c9 Mon Sep 17 00:00:00 2001 From: Douglas RAILLARD Date: Wed, 15 Jan 2020 16:06:20 +0000 Subject: [PATCH 4/6] utils/misc: Add redirect_streams() helper Update a command line to redirect standard streams as specified using the parameters. This helper allows honoring streams specified in the same way as subprocess.Popen, by doing it as much using shell redirections as possible. --- devlib/utils/misc.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/devlib/utils/misc.py b/devlib/utils/misc.py index 4f8beb12c..2413cc722 100644 --- a/devlib/utils/misc.py +++ b/devlib/utils/misc.py @@ -46,6 +46,11 @@ except AttributeError: from contextlib2 import ExitStack +try: + from shlex import quote +except ImportError: + from pipes import quote + from past.builtins import basestring # pylint: disable=redefined-builtin @@ -232,6 +237,32 @@ def __try_import(path): mods.append(submod) return mods +def redirect_streams(stdout, stderr, command): + """ + Update a command to redirect a given stream to /dev/null if it's + ``subprocess.DEVNULL``. + + :return: A tuple (stdout, stderr, command) with stream set to ``subprocess.PIPE`` + if the `stream` parameter was set to ``subprocess.DEVNULL``. + """ + def redirect(stream, redirection): + if stream == subprocess.DEVNULL: + suffix = '{}/dev/null'.format(redirection) + elif stream == subprocess.STDOUT: + suffix = '{}&1'.format(redirection) + # Indicate that there is nothing to monitor for stderr anymore + # since it's merged into stdout + stream = subprocess.DEVNULL + else: + suffix = '' + + return (stream, suffix) + + stdout, suffix1 = redirect(stdout, '>') + stderr, suffix2 = redirect(stderr, '2>') + + command = 'sh -c {} {} {}'.format(quote(command), suffix1, suffix2) + return (stdout, stderr, command) def ensure_directory_exists(dirpath): """A filter for directory paths to ensure they exist.""" From 9b712109f5618efe1aabba4350c0072941ffddc1 Mon Sep 17 00:00:00 2001 From: Douglas RAILLARD Date: Wed, 15 Jan 2020 17:19:24 +0000 Subject: [PATCH 5/6] connections: Unify BackgroundCommand API and use paramiko for SSH * Unify the behavior of background commands in connections.BackgroundCommand(). This implements a subset of subprocess.Popen class, with a unified behavior across all connection types * Implement the SSH connection using paramiko rather than pxssh. --- devlib/connection.py | 351 ++++++++++++++++++++++ devlib/host.py | 23 +- devlib/utils/android.py | 52 +++- devlib/utils/misc.py | 49 +++- devlib/utils/ssh.py | 628 +++++++++++++++++++++++++++++++++++----- setup.py | 1 + 6 files changed, 1022 insertions(+), 82 deletions(-) create mode 100644 devlib/connection.py diff --git a/devlib/connection.py b/devlib/connection.py new file mode 100644 index 000000000..e6d67c0af --- /dev/null +++ b/devlib/connection.py @@ -0,0 +1,351 @@ +# Copyright 2019 ARM Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import time +import subprocess +import signal +import threading +from weakref import WeakSet +from abc import ABC, abstractmethod + +from devlib.utils.misc import InitCheckpoint + +_KILL_TIMEOUT = 3 + + +def _kill_pgid_cmd(pgid, sig): + return 'kill -{} -{}'.format(sig.name, pgid) + + +class ConnectionBase(InitCheckpoint): + """ + Base class for all connections. + """ + def __init__(self): + self._current_bg_cmds = WeakSet() + self._closed = False + self._close_lock = threading.Lock() + + def cancel_running_command(self): + bg_cmds = set(self._current_bg_cmds) + for bg_cmd in bg_cmds: + bg_cmd.cancel() + + @abstractmethod + def _close(self): + """ + Close the connection. + + The public :meth:`close` method makes sure that :meth:`_close` will + only be called once, and will serialize accesses to it if it happens to + be called from multiple threads at once. + """ + + def close(self): + # Locking the closing allows any thread to safely call close() as long + # as the connection can be closed from a thread that is not the one it + # started its life in. + with self._close_lock: + if not self._closed: + self._close() + self._closed = True + + # Ideally, that should not be relied upon but that will improve the chances + # of the connection being properly cleaned up when it's not in use anymore. + def __del__(self): + # Since __del__ will be called if an exception is raised in __init__ + # (e.g. we cannot connect), we only run close() when we are sure + # __init__ has completed successfully. + if self.initialized: + self.close() + + +class BackgroundCommand(ABC): + """ + Allows managing a running background command using a subset of the + :class:`subprocess.Popen` API. + + Instances of this class can be used as context managers, with the same + semantic as :class:`subprocess.Popen`. + """ + @abstractmethod + def send_signal(self, sig): + """ + Send a POSIX signal to the background command's process group ID + (PGID). + + :param signal: Signal to send. + :type signal: signal.Signals + """ + + def kill(self): + """ + Send SIGKILL to the background command. + """ + self.send_signal(signal.SIGKILL) + + @abstractmethod + def cancel(self, kill_timeout=_KILL_TIMEOUT): + """ + Try to gracefully terminate the process by sending ``SIGTERM``, then + waiting for ``kill_timeout`` to send ``SIGKILL``. + """ + + @abstractmethod + def wait(self): + """ + Block until the background command completes, and return its exit code. + """ + + @abstractmethod + def poll(self): + """ + Return exit code if the command has exited, None otherwise. + """ + + @property + @abstractmethod + def stdin(self): + """ + File-like object connected to the background's command stdin. + """ + + @property + @abstractmethod + def stdout(self): + """ + File-like object connected to the background's command stdout. + """ + + @property + @abstractmethod + def stderr(self): + """ + File-like object connected to the background's command stderr. + """ + + @property + @abstractmethod + def pid(self): + """ + Process Group ID (PGID) of the background command. + + Since the command is usually wrapped in shell processes for IO + redirections, sudo etc, the PID cannot be assumed to be the actual PID + of the command passed by the user. It's is guaranteed to be a PGID + instead, which means signals sent to it as such will target all + subprocesses involved in executing that command. + """ + + @abstractmethod + def close(self): + """ + Close all opened streams and then wait for command completion. + + :returns: Exit code of the command. + + .. note:: If the command is writing to its stdout/stderr, it might be + blocked on that and die when the streams are closed. + """ + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + self.close() + + +class PopenBackgroundCommand(BackgroundCommand): + """ + :class:`subprocess.Popen`-based background command. + """ + + def __init__(self, popen): + self.popen = popen + + def send_signal(self, sig): + return os.killpg(self.popen.pid, sig) + + @property + def stdin(self): + return self.popen.stdin + + @property + def stdout(self): + return self.popen.stdout + + @property + def stderr(self): + return self.popen.stderr + + @property + def pid(self): + return self.popen.pid + + def wait(self): + return self.popen.wait() + + def poll(self): + return self.popen.poll() + + def cancel(self, kill_timeout=_KILL_TIMEOUT): + popen = self.popen + popen.send_signal(signal.SIGTERM) + try: + popen.wait(timeout=_KILL_TIMEOUT) + except subprocess.TimeoutExpired: + popen.kill() + + def close(self): + self.popen.__exit__(None, None, None) + return self.popen.returncode + + def __enter__(self): + self.popen.__enter__() + return self + + def __exit__(self, *args, **kwargs): + self.popen.__exit__(*args, **kwargs) + +class ParamikoBackgroundCommand(BackgroundCommand): + """ + :mod:`paramiko`-based background command. + """ + def __init__(self, conn, chan, pid, as_root, stdin, stdout, stderr, redirect_thread): + self.chan = chan + self.as_root = as_root + self.conn = conn + self._pid = pid + self._stdin = stdin + self._stdout = stdout + self._stderr = stderr + self.redirect_thread = redirect_thread + + def send_signal(self, sig): + # If the command has already completed, we don't want to send a signal + # to another process that might have gotten that PID in the meantime. + if self.poll() is not None: + return + # Use -PGID to target a process group rather than just the process + # itself + cmd = _kill_pgid_cmd(self.pid, sig) + self.conn.execute(cmd, as_root=self.as_root) + + @property + def pid(self): + return self._pid + + def wait(self): + return self.chan.recv_exit_status() + + def poll(self): + if self.chan.exit_status_ready(): + return self.wait() + else: + return None + + def cancel(self, kill_timeout=_KILL_TIMEOUT): + self.send_signal(signal.SIGTERM) + # Check if the command terminated quickly + time.sleep(10e-3) + # Otherwise wait for the full timeout and kill it + if self.poll() is None: + time.sleep(kill_timeout) + self.send_signal(signal.SIGKILL) + self.wait() + + @property + def stdin(self): + return self._stdin + + @property + def stdout(self): + return self._stdout + + @property + def stderr(self): + return self._stderr + + def close(self): + for x in (self.stdin, self.stdout, self.stderr): + if x is not None: + x.close() + + exit_code = self.wait() + thread = self.redirect_thread + if thread: + thread.join() + + return exit_code + + +class AdbBackgroundCommand(BackgroundCommand): + """ + ``adb``-based background command. + """ + + def __init__(self, conn, adb_popen, pid, as_root): + self.conn = conn + self.as_root = as_root + self.adb_popen = adb_popen + self._pid = pid + + def send_signal(self, sig): + self.conn.execute( + _kill_pgid_cmd(self.pid, sig), + as_root=self.as_root, + ) + + @property + def stdin(self): + return self.adb_popen.stdin + + @property + def stdout(self): + return self.adb_popen.stdout + + @property + def stderr(self): + return self.adb_popen.stderr + + @property + def pid(self): + return self._pid + + def wait(self): + return self.adb_popen.wait() + + def poll(self): + return self.adb_popen.poll() + + def cancel(self, kill_timeout=_KILL_TIMEOUT): + self.send_signal(signal.SIGTERM) + try: + self.adb_popen.wait(timeout=_KILL_TIMEOUT) + except subprocess.TimeoutExpired: + self.send_signal(signal.SIGKILL) + self.adb_popen.kill() + + def close(self): + self.adb_popen.__exit__(None, None, None) + return self.adb_popen.returncode + + def __enter__(self): + self.adb_popen.__enter__() + return self + + def __exit__(self, *args, **kwargs): + self.adb_popen.__exit__(*args, **kwargs) diff --git a/devlib/host.py b/devlib/host.py index 11f48b914..a694e5e3f 100644 --- a/devlib/host.py +++ b/devlib/host.py @@ -24,6 +24,7 @@ from devlib.exception import TargetTransientError, TargetStableError from devlib.utils.misc import check_output +from devlib.connection import ConnectionBase, PopenBackgroundCommand PACKAGE_BIN_DIRECTORY = os.path.join(os.path.dirname(__file__), 'bin') @@ -37,7 +38,7 @@ def kill_children(pid, signal=signal.SIGKILL): os.kill(cpid, signal) -class LocalConnection(object): +class LocalConnection(ConnectionBase): name = 'local' host = 'localhost' @@ -56,6 +57,7 @@ def connected_as_root(self, state): # pylint: disable=unused-argument def __init__(self, platform=None, keep_password=True, unrooted=False, password=None, timeout=None): + super().__init__() self._connected_as_root = None self.logger = logging.getLogger('local_connection') self.keep_password = keep_password @@ -105,9 +107,24 @@ def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as raise TargetStableError('unrooted') password = self._get_password() command = 'echo {} | sudo -S '.format(quote(password)) + command - return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True) - def close(self): + # Make sure to get a new PGID so PopenBackgroundCommand() can kill + # all sub processes that could be started without troubles. + def preexec_fn(): + os.setpgrp() + + popen = subprocess.Popen( + command, + stdout=stdout, + stderr=stderr, + shell=True, + preexec_fn=preexec_fn, + ) + bg_cmd = PopenBackgroundCommand(popen) + self._current_bg_cmds.add(bg_cmd) + return bg_cmd + + def _close(self): pass def cancel_running_command(self): diff --git a/devlib/utils/android.py b/devlib/utils/android.py index 4c19b8fe3..6a6c3c1d4 100755 --- a/devlib/utils/android.py +++ b/devlib/utils/android.py @@ -30,6 +30,7 @@ import pexpect import xml.etree.ElementTree import zipfile +import uuid try: from shlex import quote @@ -37,7 +38,8 @@ from pipes import quote from devlib.exception import TargetTransientError, TargetStableError, HostError -from devlib.utils.misc import check_output, which, ABI_MAP +from devlib.utils.misc import check_output, which, ABI_MAP, redirect_streams +from devlib.connection import ConnectionBase, AdbBackgroundCommand logger = logging.getLogger('android') @@ -233,7 +235,7 @@ def _run(self, command): return output -class AdbConnection(object): +class AdbConnection(ConnectionBase): # maintains the count of parallel active connections to a device, so that # adb disconnect is not invoked untill all connections are closed @@ -263,6 +265,7 @@ def connected_as_root(self, state): # pylint: disable=unused-argument def __init__(self, device=None, timeout=None, platform=None, adb_server=None, adb_as_root=False): + super().__init__() self.timeout = timeout if timeout is not None else self.default_timeout if device is None: device = adb_get_device(timeout=timeout, adb_server=adb_server) @@ -312,13 +315,21 @@ def execute(self, command, timeout=None, check_exit_code=False, raise def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False): - return adb_background_shell(self.device, command, stdout, stderr, as_root, adb_server=self.adb_server) - - def close(self): + adb_shell, pid = adb_background_shell(self, command, stdout, stderr, as_root) + bg_cmd = AdbBackgroundCommand( + conn=self, + adb_popen=adb_shell, + pid=pid, + as_root=as_root + ) + self._current_bg_cmds.add(bg_cmd) + return bg_cmd + + def _close(self): AdbConnection.active_connections[self.device] -= 1 if AdbConnection.active_connections[self.device] <= 0: if self.adb_as_root: - self.adb_root(self.device, enable=False) + self.adb_root(enable=False) adb_disconnect(self.device, self.adb_server) del AdbConnection.active_connections[self.device] @@ -536,20 +547,41 @@ def adb_shell(device, command, timeout=None, check_exit_code=False, return output -def adb_background_shell(device, command, +def adb_background_shell(conn, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - as_root=False, - adb_server=None): + as_root=False): """Runs the sepcified command in a subprocess, returning the the Popen object.""" + device = conn.device + adb_server = conn.adb_server + _check_env() + stdout, stderr, command = redirect_streams(stdout, stderr, command) if as_root: command = 'echo {} | su'.format(quote(command)) + # Attach a unique UUID to the command line so it can be looked for without + # any ambiguity with ps + uuid_ = uuid.uuid4().hex + uuid_var = 'BACKGROUND_COMMAND_UUID={}'.format(uuid_) + command = "{} sh -c {}".format(uuid_var, quote(command)) + adb_cmd = get_adb_command(None, 'shell', adb_server) full_command = '{} {}'.format(adb_cmd, quote(command)) logger.debug(full_command) - return subprocess.Popen(full_command, stdout=stdout, stderr=stderr, shell=True) + p = subprocess.Popen(full_command, stdout=stdout, stderr=stderr, shell=True) + + # Out of band PID lookup, to avoid conflicting needs with stdout redirection + find_pid = 'ps -A -o pid,args | grep {}'.format(quote(uuid_var)) + ps_out = conn.execute(find_pid) + pids = [ + int(line.strip().split(' ', 1)[0]) + for line in ps_out.splitlines() + ] + # The line we are looking for is the first one, since it was started before + # any look up command + pid = sorted(pids)[0] + return (p, pid) def adb_kill_server(timeout=30, adb_server=None): adb_command(None, 'kill-server', timeout, adb_server) diff --git a/devlib/utils/misc.py b/devlib/utils/misc.py index 2413cc722..c3254c7cc 100644 --- a/devlib/utils/misc.py +++ b/devlib/utils/misc.py @@ -20,7 +20,7 @@ """ from __future__ import division from contextlib import contextmanager -from functools import partial, reduce +from functools import partial, reduce, wraps from itertools import groupby from operator import itemgetter from weakref import WeakKeyDictionary, WeakSet @@ -854,3 +854,50 @@ def get_all_values(self): instance=self.instance, owner=self.owner, ) + + +class InitCheckpointMeta(type): + """ + Metaclass providing an ``initialized`` boolean attributes on instances. + + ``initialized`` is set to ``True`` once the ``__init__`` constructor has + returned. It will deal cleanly with nested calls to ``super().__init__``. + """ + def __new__(metacls, name, bases, dct, **kwargs): + cls = super().__new__(metacls, name, bases, dct, **kwargs) + init_f = cls.__init__ + + @wraps(init_f) + def init_wrapper(self, *args, **kwargs): + self.initialized = False + + # Track the nesting of super()__init__ to set initialized=True only + # when the outer level is finished + try: + stack = self._init_stack + except AttributeError: + stack = [] + self._init_stack = stack + + stack.append(init_f) + try: + x = init_f(self, *args, **kwargs) + finally: + stack.pop() + + if not stack: + self.initialized = True + del self._init_stack + + return x + + cls.__init__ = init_wrapper + + return cls + + +class InitCheckpoint(metaclass=InitCheckpointMeta): + """ + Inherit from this class to set the :class:`InitCheckpointMeta` metaclass. + """ + pass diff --git a/devlib/utils/ssh.py b/devlib/utils/ssh.py index 056ee325a..b871e22bc 100644 --- a/devlib/utils/ssh.py +++ b/devlib/utils/ssh.py @@ -26,9 +26,18 @@ import sys import time import atexit +import contextlib +import weakref +import select +import copy from pipes import quote from future.utils import raise_from +from paramiko.client import SSHClient, AutoAddPolicy, RejectPolicy +import paramiko.ssh_exception +# By default paramiko is very verbose, including at the INFO level +logging.getLogger("paramiko").setLevel(logging.WARNING) + # pylint: disable=import-error,wrong-import-position,ungrouped-imports,wrong-import-order import pexpect from distutils.version import StrictVersion as V @@ -42,8 +51,9 @@ from devlib.exception import (HostError, TargetStableError, TargetNotRespondingError, TimeoutError, TargetTransientError) from devlib.utils.misc import (which, strip_bash_colors, check_output, - sanitize_cmd_template, memoized) + sanitize_cmd_template, memoized, redirect_streams) from devlib.utils.types import boolean +from devlib.connection import ConnectionBase, ParamikoBackgroundCommand, PopenBackgroundCommand ssh = None @@ -54,31 +64,113 @@ logger = logging.getLogger('ssh') gem5_logger = logging.getLogger('gem5-connection') -def ssh_get_shell(host, +@contextlib.contextmanager +def _handle_paramiko_exceptions(command=None): + try: + yield + except paramiko.ssh_exception.NoValidConnectionsError as e: + raise TargetNotRespondingError('Connection lost: {}'.format(e)) + except paramiko.ssh_exception.AuthenticationException as e: + raise TargetStableError('Could not authenticate: {}'.format(e)) + except paramiko.ssh_exception.BadAuthenticationType as e: + raise TargetStableError('Bad authentication type: {}'.format(e)) + except paramiko.ssh_exception.BadHostKeyException as e: + raise TargetStableError('Bad host key: {}'.format(e)) + except paramiko.ssh_exception.ChannelException as e: + raise TargetStableError('Could not open an SSH channel: {}'.format(e)) + except paramiko.ssh_exception.PasswordRequiredException as e: + raise TargetStableError('Please unlock the private key file: {}'.format(e)) + except paramiko.ssh_exception.ProxyCommandFailure as e: + raise TargetStableError('Proxy command failure: {}'.format(e)) + except paramiko.ssh_exception.SSHException as e: + raise TargetTransientError('SSH logic error: {}'.format(e)) + except socket.timeout: + raise TimeoutError(command, output=None) + + +def _read_paramiko_streams(stdout, stderr, select_timeout, callback, init, chunk_size=int(1e42)): + try: + return _read_paramiko_streams_internal(stdout, stderr, select_timeout, callback, init, chunk_size) + finally: + # Close the channel to make sure the remove process will receive + # SIGPIPE when writing on its streams. That could happen if the + # user closed the out_streams but the remote process has not + # finished yet. + assert stdout.channel is stderr.channel + stdout.channel.close() + + +def _read_paramiko_streams_internal(stdout, stderr, select_timeout, callback, init, chunk_size): + channel = stdout.channel + assert stdout.channel is stderr.channel + + def read_channel(callback_state): + read_list, _, _ = select.select([channel], [], [], select_timeout) + for desc in read_list: + for ready, recv, name in ( + (desc.recv_ready(), desc.recv, 'stdout'), + (desc.recv_stderr_ready(), desc.recv_stderr, 'stderr') + ): + if ready: + chunk = recv(chunk_size) + if chunk: + try: + callback_state = callback(callback_state, name, chunk) + except Exception as e: + return (e, callback_state) + + return (None, callback_state) + + def read_all_channel(callback=None, callback_state=None): + for stream, name in ((stdout, 'stdout'), (stderr, 'stderr')): + try: + chunk = stream.read() + except Exception: + continue + + if callback is not None and chunk: + callback_state = callback(callback_state, name, chunk) + + return callback_state + + callback_excep = None + try: + callback_state = init + while not channel.exit_status_ready(): + callback_excep, callback_state = read_channel(callback_state) + if callback_excep is not None: + raise callback_excep + # Make sure to always empty the streams to unblock the remote process on + # the way to exit, in case something bad happened. For example, the + # callback could raise an exception to signal it does not want to do + # anything anymore, or only reading from one of the stream might have + # raised an exception, leaving the other one non-empty. + except Exception as e: + if callback_excep is None: + # Only call the callback if there was no exception originally, as + # we don't want to reenter it if it raised an exception + read_all_channel(callback, callback_state) + raise e + else: + # Finish emptying the buffers + callback_state = read_all_channel(callback, callback_state) + exit_code = channel.recv_exit_status() + return (callback_state, exit_code) + + +def telnet_get_shell(host, username, password=None, - keyfile=None, port=None, timeout=10, - telnet=False, - original_prompt=None, - options=None): + original_prompt=None): _check_env() start_time = time.time() while True: - if telnet: - if keyfile: - raise ValueError('keyfile may not be used with a telnet connection.') - conn = TelnetPxssh(original_prompt=original_prompt) - else: # ssh - conn = pxssh.pxssh(options=options, - echo=False) + conn = TelnetPxssh(original_prompt=original_prompt) try: - if keyfile: - conn.login(host, username, ssh_key=keyfile, port=port, login_timeout=timeout) - else: - conn.login(host, username, password, port=port, login_timeout=timeout) + conn.login(host, username, password, port=port, login_timeout=timeout) break except EOF: timeout -= time.time() - start_time @@ -157,10 +249,11 @@ def check_keyfile(keyfile): return keyfile -class SshConnection(object): +class SshConnectionBase(ConnectionBase): + """ + Base class for SSH connections. + """ - default_password_prompt = '[sudo] password' - max_cancel_attempts = 5 default_timeout = 10 @property @@ -170,51 +263,473 @@ def name(self): @property def connected_as_root(self): if self._connected_as_root is None: - # Execute directly to prevent deadlocking of connection - result = self._execute_and_wait_for_prompt('id', as_root=False) - self._connected_as_root = 'uid=0(' in result + try: + result = self.execute('id', as_root=False) + except TargetStableError: + is_root = False + else: + is_root = 'uid=0(' in result + self._connected_as_root = is_root return self._connected_as_root @connected_as_root.setter def connected_as_root(self, state): self._connected_as_root = state - # pylint: disable=unused-argument,super-init-not-called def __init__(self, host, username, password=None, keyfile=None, port=None, - timeout=None, - telnet=False, - password_prompt=None, - original_prompt=None, platform=None, - sudo_cmd="sudo -- sh -c {}", - options=None + sudo_cmd="sudo -S -- sh -c {}", + strict_host_check=True, ): + super().__init__() self._connected_as_root = None self.host = host self.username = username self.password = password self.keyfile = check_keyfile(keyfile) if keyfile else keyfile self.port = port + self.sudo_cmd = sanitize_cmd_template(sudo_cmd) + self.platform = platform + self.strict_host_check = strict_host_check + logger.debug('Logging in {}@{}'.format(username, host)) + + +class SshConnection(SshConnectionBase): + # pylint: disable=unused-argument,super-init-not-called + def __init__(self, + host, + username, + password=None, + keyfile=None, + port=22, + timeout=None, + platform=None, + sudo_cmd="sudo -S -- sh -c {}", + strict_host_check=True, + ): + + super().__init__( + host=host, + username=username, + password=password, + keyfile=keyfile, + port=port, + platform=platform, + sudo_cmd=sudo_cmd, + strict_host_check=strict_host_check, + ) + self.timeout = timeout if timeout is not None else self.default_timeout + + self.client = self._make_client() + atexit.register(self.close) + + # Use a marker in the output so that we will be able to differentiate + # target connection issues with "password needed". + # Also, sudo might not be installed at all on the target (but + # everything will work as long as we login as root). If sudo is still + # needed, it will explode when someone tries to use it. After all, the + # user might not be interested in being root at all. + self._sudo_needs_password = ( + 'NEED_PASSOWRD' in + self.execute( + # sudo -n is broken on some versions on MacOSX, revisit that if + # someone ever cares + 'sudo -n true || echo NEED_PASSWORD', + as_root=False, + check_exit_code=False, + ) + ) + + def _make_client(self): + if self.strict_host_check: + policy = RejectPolicy + else: + policy = AutoAddPolicy + + with _handle_paramiko_exceptions(): + client = SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(policy) + client.connect( + hostname=self.host, + port=self.port, + username=self.username, + password=self.password, + key_filename=self.keyfile, + timeout=self.timeout, + ) + + return client + + def _make_channel(self): + with _handle_paramiko_exceptions(): + transport = self.client.get_transport() + channel = transport.open_session() + return channel + + def _get_sftp(self, timeout): + sftp = self.client.open_sftp() + sftp.get_channel().settimeout(timeout) + return sftp + + @classmethod + def _push_file(cls, sftp, src, dst): + try: + sftp.put(src, dst) + # Maybe the dst was a folder + except OSError: + # This might fail if the folder already exists + with contextlib.suppress(IOError): + sftp.mkdir(dst) + + new_dst = os.path.join( + dst, + os.path.basename(src), + ) + + return cls._push_file(sftp, src, new_dst) + + + @classmethod + def _push_folder(cls, sftp, src, dst): + # Behave like the "mv" command or adb push: a new folder is created + # inside the destination folder, rather than merging the trees. + dst = os.path.join( + dst, + os.path.basename(src), + ) + return cls._push_folder_internal(sftp, src, dst) + + @classmethod + def _push_folder_internal(cls, sftp, src, dst): + # This might fail if the folder already exists + with contextlib.suppress(IOError): + sftp.mkdir(dst) + + for entry in os.scandir(src): + name = entry.name + src_path = os.path.join(src, name) + dst_path = os.path.join(dst, name) + if entry.is_dir(): + push = cls._push_folder_internal + else: + push = cls._push_file + + push(sftp, src_path, dst_path) + + @classmethod + def _push_path(cls, sftp, src, dst): + push = cls._push_folder if os.path.isdir(src) else cls._push_file + push(sftp, src, dst) + + @classmethod + def _pull_file(cls, sftp, src, dst): + # Pulling a file into a folder will use the source basename + if os.path.isdir(dst): + dst = os.path.join( + dst, + os.path.basename(src), + ) + + with contextlib.suppress(FileNotFoundError): + os.remove(dst) + + sftp.get(src, dst) + + @classmethod + def _pull_folder(cls, sftp, src, dst): + with contextlib.suppress(FileNotFoundError): + try: + shutil.rmtree(dst) + except OSError: + os.remove(dst) + + os.makedirs(dst) + for fileattr in sftp.listdir_attr(src): + filename = fileattr.filename + src_path = os.path.join(src, filename) + dst_path = os.path.join(dst, filename) + if stat.S_ISDIR(fileattr.st_mode): + pull = cls._pull_folder + else: + pull = cls._pull_file + + pull(sftp, src_path, dst_path) + + @classmethod + def _pull_path(cls, sftp, src, dst): + try: + cls._pull_file(sftp, src, dst) + except IOError: + # Maybe that was a directory, so retry as such + cls._pull_folder(sftp, src, dst) + + def push(self, source, dest, timeout=30): + with _handle_paramiko_exceptions(), self._get_sftp(timeout) as sftp: + self._push_path(sftp, source, dest) + + def pull(self, source, dest, timeout=30): + with _handle_paramiko_exceptions(), self._get_sftp(timeout) as sftp: + self._pull_path(sftp, source, dest) + + def execute(self, command, timeout=None, check_exit_code=True, + as_root=False, strip_colors=True, will_succeed=False): #pylint: disable=unused-argument + if command == '': + return '' + try: + with _handle_paramiko_exceptions(command): + exit_code, output = self._execute(command, timeout, as_root, strip_colors) + except TargetStableError as e: + if will_succeed: + raise TargetTransientError(e) + else: + raise + else: + if check_exit_code and exit_code: + message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}' + raise TargetStableError(message.format(exit_code, command, output)) + return output + + def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False): + with _handle_paramiko_exceptions(command): + bg_cmd = self._background(command, stdout, stderr, as_root) + + self._current_bg_cmds.add(bg_cmd) + return bg_cmd + + def _background(self, command, stdout, stderr, as_root): + stdout, stderr, command = redirect_streams(stdout, stderr, command) + + command = "printf '%s\n' $$; exec sh -c {}".format(quote(command)) + channel = self._make_channel() + + def executor(cmd, timeout): + channel.exec_command(cmd) + # Read are not buffered so we will always get the data as soon as + # they arrive + return ( + channel.makefile_stdin(), + channel.makefile(), + channel.makefile_stderr(), + ) + + stdin, stdout_in, stderr_in = self._execute_command( + command, + as_root=as_root, + log=False, + timeout=None, + executor=executor, + ) + pid = int(stdout_in.readline()) + + def create_out_stream(stream_in, stream_out): + """ + Create a pair of file-like objects. The first one is used to read + data and the second one to write. + """ + + if stream_out == subprocess.DEVNULL: + r, w = None, None + # When asked for a pipe, we just give the file-like object as the + # reading end and no writing end, since paramiko already writes to + # it + elif stream_out == subprocess.PIPE: + r, w = os.pipe() + r = os.fdopen(r, 'rb') + w = os.fdopen(w, 'wb') + # Turn a file descriptor into a file-like object + elif isinstance(stream_out, int) and stream_out >= 0: + r = os.fdopen(stream_out, 'rb') + w = os.fdopen(stream_out, 'wb') + # file-like object + else: + r = stream_out + w = stream_out + + return (r, w) + + out_streams = { + name: create_out_stream(stream_in, stream_out) + for stream_in, stream_out, name in ( + (stdout_in, stdout, 'stdout'), + (stderr_in, stderr, 'stderr'), + ) + } + + def redirect_thread_f(stdout_in, stderr_in, out_streams, select_timeout): + def callback(out_streams, name, chunk): + try: + r, w = out_streams[name] + except KeyError: + return out_streams + + try: + w.write(chunk) + # Write failed + except ValueError: + # Since that stream is now closed, stop trying to write to it + del out_streams[name] + # If that was the last open stream, we raise an + # exception so the thread can terminate. + if not out_streams: + raise + + return out_streams + + try: + _read_paramiko_streams(stdout_in, stderr_in, select_timeout, callback, copy.copy(out_streams)) + # The streams closed while we were writing to it, the job is done here + except ValueError: + pass + + # Make sure the writing end are closed proper since we are not + # going to write anything anymore + for r, w in out_streams.values(): + if r is not w and w is not None: + w.close() + + # If there is anything we need to redirect to, spawn a thread taking + # care of that + select_timeout = 1 + thread_out_streams = { + name: (r, w) + for name, (r, w) in out_streams.items() + if w is not None + } + redirect_thread = threading.Thread( + target=redirect_thread_f, + args=(stdout_in, stderr_in, thread_out_streams, select_timeout), + # The thread will die when the main thread dies + daemon=True, + ) + redirect_thread.start() + + return ParamikoBackgroundCommand( + conn=self, + as_root=as_root, + chan=channel, + pid=pid, + stdin=stdin, + # We give the reading end to the consumer of the data + stdout=out_streams['stdout'][0], + stderr=out_streams['stderr'][0], + redirect_thread=redirect_thread, + ) + + def _close(self): + logger.debug('Logging out {}@{}'.format(self.username, self.host)) + with _handle_paramiko_exceptions(): + bg_cmds = set(self._current_bg_cmds) + for bg_cmd in bg_cmds: + bg_cmd.close() + self.client.close() + + def _execute_command(self, command, as_root, log, timeout, executor): + # As we're already root, there is no need to use sudo. + log_debug = logger.debug if log else lambda msg: None + use_sudo = as_root and not self.connected_as_root + + if use_sudo: + if self._sudo_needs_password and not self.password: + raise TargetStableError('Attempt to use sudo but no password was specified') + + command = self.sudo_cmd.format(quote(command)) + + log_debug(command) + streams = executor(command, timeout=timeout) + if self._sudo_needs_password: + stdin = streams[0] + stdin.write(self.password + '\n') + stdin.flush() + else: + log_debug(command) + streams = executor(command, timeout=timeout) + + return streams + + def _execute(self, command, timeout=None, as_root=False, strip_colors=True, log=True): + # Merge stderr into stdout since we are going without a TTY + command = '({}) 2>&1'.format(command) + + stdin, stdout, stderr = self._execute_command( + command, + as_root=as_root, + log=log, + timeout=timeout, + executor=self.client.exec_command, + ) + stdin.close() + + # Empty the stdout buffer of the command, allowing it to carry on to + # completion + def callback(output_chunks, name, chunk): + output_chunks.append(chunk) + return output_chunks + + select_timeout = 1 + output_chunks, exit_code = _read_paramiko_streams(stdout, stderr, select_timeout, callback, []) + # Join in one go to avoid O(N^2) concatenation + output = b''.join(output_chunks) + + if sys.version_info[0] == 3: + output = output.decode(sys.stdout.encoding or 'utf-8', 'replace') + if strip_colors: + output = strip_bash_colors(output) + + return (exit_code, output) + + +class TelnetConnection(SshConnectionBase): + + default_password_prompt = '[sudo] password' + max_cancel_attempts = 5 + + # pylint: disable=unused-argument,super-init-not-called + def __init__(self, + host, + username, + password=None, + port=None, + timeout=None, + password_prompt=None, + original_prompt=None, + sudo_cmd="sudo -- sh -c {}", + strict_host_check=True, + platform=None): + + super().__init__( + host=host, + username=username, + password=password, + keyfile=None, + port=port, + platform=platform, + sudo_cmd=sudo_cmd, + strict_host_check=strict_host_check, + ) + + if self.strict_host_check: + options = { + 'StrictHostKeyChecking': 'yes', + } + else: + options = { + 'StrictHostKeyChecking': 'no', + 'UserKnownHostsFile': '/dev/null', + } + self.options = options + self.lock = threading.Lock() self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt - self.sudo_cmd = sanitize_cmd_template(sudo_cmd) logger.debug('Logging in {}@{}'.format(username, host)) timeout = timeout if timeout is not None else self.default_timeout - self.options = options if options is not None else {} - self.conn = ssh_get_shell(host, - username, - password, - self.keyfile, - port, - timeout, - False, - None, - self.options) + + self.conn = telnet_get_shell(host, username, password, port, timeout, original_prompt) atexit.register(self.close) def push(self, source, dest, timeout=30): @@ -282,7 +797,7 @@ def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as except EOF: raise TargetNotRespondingError('Connection lost.') - def close(self): + def _close(self): logger.debug('Logging out {}@{}'.format(self.username, self.host)) try: self.conn.logout() @@ -351,8 +866,8 @@ def _scp(self, source, dest, timeout=30): # only specify -P for scp if the port is *not* the default. port_string = '-P {}'.format(quote(str(self.port))) if (self.port and self.port != 22) else '' keyfile_string = '-i {}'.format(quote(self.keyfile)) if self.keyfile else '' - options = " ".join(["-o {}={}".format(key,val) - for key,val in self.options.items()]) + options = " ".join(["-o {}={}".format(key, val) + for key, val in self.options.items()]) command = '{} {} -r {} {} {} {}'.format(scp, options, keyfile_string, @@ -387,29 +902,6 @@ def _get_prompt_length(self): def _get_window_size(self): return self.conn.getwinsize() -class TelnetConnection(SshConnection): - - # pylint: disable=super-init-not-called - def __init__(self, - host, - username, - password=None, - port=None, - timeout=None, - password_prompt=None, - original_prompt=None, - platform=None): - self.host = host - self.username = username - self.password = password - self.port = port - self.keyfile = None - self.lock = threading.Lock() - self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt - logger.debug('Logging in {}@{}'.format(username, host)) - timeout = timeout if timeout is not None else self.default_timeout - self.conn = ssh_get_shell(host, username, password, None, port, timeout, True, original_prompt) - class Gem5Connection(TelnetConnection): @@ -616,7 +1108,7 @@ def background(self, command, stdout=subprocess.PIPE, 'get this file'.format(redirection_file)) return output - def close(self): + def _close(self): """ Close and disconnect from the gem5 simulation. Additionally, we remove the temporary directory used to pass files into the simulation. diff --git a/setup.py b/setup.py index fda27a843..ca30eb178 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,7 @@ 'python-dateutil', # converting between UTC and local time. 'pexpect>=3.3', # Send/recieve to/from device 'pyserial', # Serial port interface + 'paramiko', # SSH connection 'wrapt', # Basic for construction of decorator functions 'future', # Python 2-3 compatibility 'enum34;python_version<"3.4"', # Enums for Python < 3.4 From 30e0ead3ea41703faf8c65112e8ba252c718d3fb Mon Sep 17 00:00:00 2001 From: Douglas RAILLARD Date: Fri, 17 Jan 2020 17:47:24 +0000 Subject: [PATCH 6/6] target: Check that the connection works cleanly upon connection Check that executing the most basic command works without troubles or stderr content. If that's not the case, raise a TargetStableError. --- devlib/target.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/devlib/target.py b/devlib/target.py index 2b8296cd6..da712de05 100644 --- a/devlib/target.py +++ b/devlib/target.py @@ -292,6 +292,7 @@ def connect(self, timeout=None, check_boot_completed=True): self.conn = self.get_connection(timeout=timeout) if check_boot_completed: self.wait_boot_complete(timeout) + self.check_connection() self._resolve_paths() self.execute('mkdir -p {}'.format(quote(self.working_directory))) self.execute('mkdir -p {}'.format(quote(self.executables_directory))) @@ -301,6 +302,14 @@ def connect(self, timeout=None, check_boot_completed=True): if self.platform.big_core and self.load_default_modules: self._install_module(get_module('bl')) + def check_connection(self): + """ + Check that the connection works without obvious issues. + """ + out = self.execute('true', as_root=False) + if out.strip(): + raise TargetStableError('The shell seems to not be functional and adds content to stderr: {}'.format(out)) + def disconnect(self): connections = self._conn.get_all_values() for conn in connections: