class ConnectionBase(InitCheckpoint):
    """
    Base class for all connections.

    It keeps a weak set of the background commands spawned through the
    connection so they can be cancelled in bulk, and provides a
    :meth:`close` method that runs the subclass-provided :meth:`_close` at
    most once.
    """
    def __init__(self):
        # Weak references only: a background command that is not used anymore
        # by its owner must not be kept alive by the connection.
        self._current_bg_cmds = WeakSet()
        self._closed = False
        self._close_lock = threading.Lock()

    def cancel_running_command(self):
        """
        Cancel all the background commands currently tracked by the
        connection.
        """
        # Take a snapshot with strong references, since the WeakSet can shrink
        # while we iterate if a command gets garbage collected.
        bg_cmds = set(self._current_bg_cmds)
        for bg_cmd in bg_cmds:
            bg_cmd.cancel()

    @abstractmethod
    def _close(self):
        """
        Close the connection.

        The public :meth:`close` method makes sure that :meth:`_close` will
        only be called once, and will serialize accesses to it if it happens to
        be called from multiple threads at once.
        """

    def close(self):
        """
        Close the connection if it has not been closed already.
        """
        # Locking the closing allows any thread to safely call close() as long
        # as the connection can be closed from a thread that is not the one it
        # started its life in.
        with self._close_lock:
            if not self._closed:
                self._close()
                # Only flag the connection as closed if _close() succeeded, so
                # that a failed attempt can be retried.
                self._closed = True

    # Ideally, that should not be relied upon but that will improve the chances
    # of the connection being properly cleaned up when it's not in use anymore.
    def __del__(self):
        # Since __del__ will be called if an exception is raised in __init__
        # (e.g. we cannot connect), we only run close() when we are sure
        # __init__ has completed successfully.
        #
        # getattr() is used since the attribute does not exist at all on an
        # instance for which __init__ was never entered (e.g. created with
        # __new__() only), and __del__ must not raise in that case.
        if getattr(self, 'initialized', False):
            self.close()
class PopenBackgroundCommand(BackgroundCommand):
    """
    :class:`subprocess.Popen`-based background command.

    Signals are sent with :func:`os.killpg` using the PID of the wrapped
    process as PGID, so the process is expected to have been started as the
    leader of its own process group.
    """

    def __init__(self, popen):
        self.popen = popen

    def send_signal(self, sig):
        return os.killpg(self.popen.pid, sig)

    @property
    def stdin(self):
        return self.popen.stdin

    @property
    def stdout(self):
        return self.popen.stdout

    @property
    def stderr(self):
        return self.popen.stderr

    @property
    def pid(self):
        return self.popen.pid

    def wait(self):
        return self.popen.wait()

    def poll(self):
        return self.popen.poll()

    def _signal_group(self, sig):
        """
        Send ``sig`` to the process group of the command, doing nothing if the
        command has already completed.
        """
        popen = self.popen
        # Do not signal a PID/PGID that might have been recycled after the
        # process has been reaped.
        if popen.poll() is not None:
            return

        try:
            os.killpg(popen.pid, sig)
        except ProcessLookupError:
            # No such process group: fall back on signaling the process
            # itself, which is a no-op if it terminated in the meantime.
            popen.send_signal(sig)

    def cancel(self, kill_timeout=_KILL_TIMEOUT):
        """
        Send ``SIGTERM`` to the process group of the command, then ``SIGKILL``
        if it has not terminated after ``kill_timeout`` seconds.
        """
        self._signal_group(signal.SIGTERM)
        try:
            # Honor the timeout requested by the caller rather than the
            # module-level default.
            self.popen.wait(timeout=kill_timeout)
        except subprocess.TimeoutExpired:
            self._signal_group(signal.SIGKILL)

    def close(self):
        """
        Close the streams of the command and wait for its completion.

        :returns: Exit code of the command.
        """
        self.popen.__exit__(None, None, None)
        return self.popen.returncode

    def __enter__(self):
        self.popen.__enter__()
        return self

    def __exit__(self, *args, **kwargs):
        self.popen.__exit__(*args, **kwargs)
class AdbBackgroundCommand(BackgroundCommand):
    """
    ``adb``-based background command.

    Streams and exit status are the ones of the local ``adb`` client process
    (``adb_popen``), whereas signals are delivered by running ``kill`` through
    the connection, using ``pid`` as process group ID.
    """

    def __init__(self, conn, adb_popen, pid, as_root):
        self.conn = conn
        self.as_root = as_root
        self.adb_popen = adb_popen
        self._pid = pid

    def send_signal(self, sig):
        self.conn.execute(
            _kill_pgid_cmd(self.pid, sig),
            as_root=self.as_root,
        )

    @property
    def stdin(self):
        return self.adb_popen.stdin

    @property
    def stdout(self):
        return self.adb_popen.stdout

    @property
    def stderr(self):
        return self.adb_popen.stderr

    @property
    def pid(self):
        return self._pid

    def wait(self):
        return self.adb_popen.wait()

    def poll(self):
        return self.adb_popen.poll()

    def cancel(self, kill_timeout=_KILL_TIMEOUT):
        """
        Send ``SIGTERM`` to the command, then ``SIGKILL`` if it has not
        terminated after ``kill_timeout`` seconds.
        """
        self.send_signal(signal.SIGTERM)
        try:
            # Honor the timeout requested by the caller rather than the
            # module-level default.
            self.adb_popen.wait(timeout=kill_timeout)
        except subprocess.TimeoutExpired:
            self.send_signal(signal.SIGKILL)
            # Also get rid of the local adb client process
            self.adb_popen.kill()

    def close(self):
        """
        Close the streams of the command and wait for its completion.

        :returns: Exit code of the local ``adb`` process.
        """
        self.adb_popen.__exit__(None, None, None)
        return self.adb_popen.returncode

    def __enter__(self):
        self.adb_popen.__enter__()
        return self

    def __exit__(self, *args, **kwargs):
        self.adb_popen.__exit__(*args, **kwargs)
keep_password @@ -105,9 +107,24 @@ def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as raise TargetStableError('unrooted') password = self._get_password() command = 'echo {} | sudo -S '.format(quote(password)) + command - return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True) - def close(self): + # Make sure to get a new PGID so PopenBackgroundCommand() can kill + # all sub processes that could be started without troubles. + def preexec_fn(): + os.setpgrp() + + popen = subprocess.Popen( + command, + stdout=stdout, + stderr=stderr, + shell=True, + preexec_fn=preexec_fn, + ) + bg_cmd = PopenBackgroundCommand(popen) + self._current_bg_cmds.add(bg_cmd) + return bg_cmd + + def _close(self): pass def cancel_running_command(self): diff --git a/devlib/target.py b/devlib/target.py index e3765a389..da712de05 100644 --- a/devlib/target.py +++ b/devlib/target.py @@ -53,7 +53,7 @@ from devlib.utils.misc import memoized, isiterable, convert_new_lines from devlib.utils.misc import commonprefix, merge_lists from devlib.utils.misc import ABI_MAP, get_cpu_name, ranges_to_list -from devlib.utils.misc import batch_contextmanager +from devlib.utils.misc import batch_contextmanager, tls_property from devlib.utils.types import integer, boolean, bitmask, identifier, caseless_string, bytes_regex @@ -214,22 +214,19 @@ def page_size_kb(self): cmd = "cat /proc/self/smaps | {0} grep KernelPageSize | {0} head -n 1 | {0} awk '{{ print $2 }}'" return int(self.execute(cmd.format(self.busybox))) - @property - def conn(self): - if self._connections: - tid = id(threading.current_thread()) - if tid not in self._connections: - self._connections[tid] = self.get_connection() - return self._connections[tid] - else: - return None - @property def shutils(self): if self._shutils is None: self._setup_shutils() return self._shutils + @tls_property + def _conn(self): + return self.get_connection() + + # Add a basic property that does not require calling to get 
def check_connection(self):
    """
    Sanity check of the connection.

    Runs ``true`` as a regular user, and fails if anything other than
    whitespace comes back, since that command is not expected to print
    anything.

    :raises TargetStableError: If the command produced some output.
    """
    output = self.execute('true', as_root=False)
    # Guard clause: a silent shell is a healthy shell
    if not output.strip():
        return

    message = 'The shell seems to not be functional and adds content to stderr: {}'.format(output)
    raise TargetStableError(message)
def adb_background_shell(conn, command,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE,
                         as_root=False):
    """
    Run the specified command in the background through ``adb shell``.

    :param conn: Connection to use. Its ``device`` and ``adb_server``
        attributes select the target, and its ``execute()`` method is used to
        look up the PID of the command.
    :param command: Shell command to run on the target.
    :param stdout: Same meaning as for :class:`subprocess.Popen`.
    :param stderr: Same meaning as for :class:`subprocess.Popen`.
    :param as_root: If ``True``, the command is piped into ``su``.

    :returns: A tuple ``(popen, pid)`` with the :class:`subprocess.Popen`
        object of the local ``adb`` process and the PID of the command found
        on the target.

    :raises TargetStableError: If the PID of the command could not be found.
    """
    device = conn.device
    adb_server = conn.adb_server

    _check_env()
    stdout, stderr, command = redirect_streams(stdout, stderr, command)
    if as_root:
        command = 'echo {} | su'.format(quote(command))

    # Attach a unique UUID to the command line so it can be looked for without
    # any ambiguity with ps
    uuid_ = uuid.uuid4().hex
    uuid_var = 'BACKGROUND_COMMAND_UUID={}'.format(uuid_)
    command = "{} sh -c {}".format(uuid_var, quote(command))

    # Target the device of the connection explicitly, otherwise adb picks the
    # default one (or fails if several devices are attached).
    adb_cmd = get_adb_command(device, 'shell', adb_server)
    full_command = '{} {}'.format(adb_cmd, quote(command))
    logger.debug(full_command)
    p = subprocess.Popen(full_command, stdout=stdout, stderr=stderr, shell=True)

    # Out of band PID lookup, to avoid conflicting needs with stdout redirection
    find_pid = 'ps -A -o pid,args | grep {}'.format(quote(uuid_var))
    ps_out = conn.execute(find_pid)

    pids = []
    for line in ps_out.splitlines():
        # The PID column is right-aligned, so split on any run of whitespace
        # and ignore lines that do not start with a number.
        fields = line.split(None, 1)
        if fields and fields[0].isdigit():
            pids.append(int(fields[0]))

    if not pids:
        raise TargetStableError(
            'Could not find the PID of the background command: {}'.format(command)
        )

    # The line we are looking for is the first one, since it was started before
    # any look up command
    pid = min(pids)
    return (p, pid)
- with check_output_lock: stderr = subprocess.STDOUT if combined_output else subprocess.PIPE process = subprocess.Popen(command, @@ -183,27 +179,24 @@ def callback(pid): preexec_fn=preexec_function, **kwargs) - if timeout: - timer = threading.Timer(timeout, callback, [process.pid, ]) - timer.start() - try: - output, error = process.communicate(inputtext) - if sys.version_info[0] == 3: - # Currently errors=replace is needed as 0x8c throws an error - output = output.decode(sys.stdout.encoding or 'utf-8', "replace") - if error: - error = error.decode(sys.stderr.encoding or 'utf-8', "replace") - finally: - if timeout: - timer.cancel() + output, error = process.communicate(inputtext, timeout=timeout) + except subprocess.TimeoutExpired as e: + timeout_expired = e + else: + timeout_expired = None + + # Currently errors=replace is needed as 0x8c throws an error + output = output.decode(sys.stdout.encoding or 'utf-8', "replace") + if error: + error = error.decode(sys.stderr.encoding or 'utf-8', "replace") + + if timeout_expired: + raise TimeoutError(command, output='\n'.join([output or '', error or ''])) retcode = process.poll() - if retcode: - if retcode == -9: # killed, assume due to timeout callback - raise TimeoutError(command, output='\n'.join([output or '', error or ''])) - elif ignore != 'all' and retcode not in ignore: - raise subprocess.CalledProcessError(retcode, command, output='\n'.join([output or '', error or ''])) + if retcode and ignore != 'all' and retcode not in ignore: + raise subprocess.CalledProcessError(retcode, command, output='\n'.join([output or '', error or ''])) return output, error @@ -244,6 +237,32 @@ def __try_import(path): mods.append(submod) return mods +def redirect_streams(stdout, stderr, command): + """ + Update a command to redirect a given stream to /dev/null if it's + ``subprocess.DEVNULL``. + + :return: A tuple (stdout, stderr, command) with stream set to ``subprocess.PIPE`` + if the `stream` parameter was set to ``subprocess.DEVNULL``. 
class tls_property:
    """
    Use it like `property` decorator, but the result will be memoized per
    thread. When the owning thread dies, the values for that thread will be
    destroyed.

    In order to get the values, it's necessary to call the object
    given by the property. This is necessary in order to be able to add methods
    to that object, like :meth:`_BoundTLSProperty.get_all_values`.

    Values can be set and deleted as well, which will be a thread-local set.

    .. note:: Values are tracked in a :class:`weakref.WeakSet`, so they need
        to be hashable and to support weak references.
    """

    @property
    def name(self):
        """
        Name of the decorated factory, used as key in the instance's
        ``__dict__`` to store the per-instance state.
        """
        return self.factory.__name__

    def __init__(self, factory):
        self.factory = factory
        # Lock accesses to shared WeakKeyDictionary and WeakSet
        self.lock = threading.Lock()

    def __get__(self, instance, owner=None):
        return _BoundTLSProperty(self, instance, owner)

    def _get_value(self, instance, owner):
        """
        Return the value for the current thread, creating it with the factory
        on first access.
        """
        tls, values = self._get_tls(instance)
        try:
            return tls.value
        except AttributeError:
            # Bind the method to `instance`
            f = self.factory.__get__(instance, owner)
            obj = f()
            tls.value = obj
            # Since that's a WeakSet, values will be removed automatically once
            # the threading.local variable that holds them is destroyed
            with self.lock:
                values.add(obj)
            return obj

    def _get_all_values(self, instance, owner):
        """
        Return a set of the values of all threads for that instance.
        """
        # _get_tls() takes the lock itself when it needs to create the
        # storage, so it must be called *before* taking the non-reentrant lock
        # to avoid a deadlock.
        tls, values = self._get_tls(instance=instance)
        with self.lock:
            # Grab a reference to all the objects at the time of the call by
            # using a regular set
            return set(values)

    def __set__(self, instance, value):
        tls, values = self._get_tls(instance)
        tls.value = value
        with self.lock:
            values.add(value)

    def __delete__(self, instance):
        tls, values = self._get_tls(instance)
        with self.lock:
            values.discard(tls.value)
            del tls.value

    def _get_tls(self, instance):
        """
        Return the ``(threading.local, WeakSet)`` tuple attached to
        ``instance``, creating it if necessary.
        """
        dct = instance.__dict__
        name = self.name
        try:
            # Using instance.__dict__[self.name] is safe as
            # getattr(instance, name) will return the property instead, as
            # the property is a descriptor
            tls = dct[name]
        except KeyError:
            with self.lock:
                # Double check after taking the lock to avoid a race. If
                # another thread created the storage in the meantime, reuse
                # it instead of overwriting it.
                try:
                    tls = dct[name]
                except KeyError:
                    tls = (threading.local(), WeakSet())
                    dct[name] = tls

        return tls

    @property
    def basic_property(self):
        """
        Return a basic property that can be used to access the TLS value
        without having to call it first.

        The drawback is that it's not possible to do anything other than
        getting/setting/deleting.
        """
        def getter(instance, owner=None):
            prop = self.__get__(instance, owner)
            return prop()

        return property(getter, self.__set__, self.__delete__)


class _BoundTLSProperty:
    """
    Simple proxy object to allow either calling it to get the TLS value, or get
    some other information by calling methods.
    """
    def __init__(self, tls_property, instance, owner):
        self.tls_property = tls_property
        self.instance = instance
        self.owner = owner

    def __call__(self):
        return self.tls_property._get_value(
            instance=self.instance,
            owner=self.owner,
        )

    def get_all_values(self):
        """
        Returns all the thread-local values currently in use in the process for
        that property for that instance.
        """
        return self.tls_property._get_all_values(
            instance=self.instance,
            owner=self.owner,
        )
@contextlib.contextmanager
def _handle_paramiko_exceptions(command=None):
    """
    Context manager translating :mod:`paramiko` exceptions and socket timeouts
    into devlib exceptions.

    :param command: Command being executed, only used to build the
        :class:`TimeoutError` raised on socket timeout.

    :raises TargetNotRespondingError: If no connection could be established.
    :raises TargetStableError: On authentication, host key, channel or proxy
        command errors.
    :raises TargetTransientError: On any other SSH error.
    :raises TimeoutError: On socket timeout.

    .. note:: Handlers for subclasses must come before the ones of their base
        class: ``BadAuthenticationType`` and ``PasswordRequiredException``
        derive from ``AuthenticationException``, which itself derives from
        ``SSHException`` like most of the others.
    """
    try:
        yield
    except paramiko.ssh_exception.NoValidConnectionsError as e:
        raise TargetNotRespondingError('Connection lost: {}'.format(e)) from e
    except paramiko.ssh_exception.BadAuthenticationType as e:
        raise TargetStableError('Bad authentication type: {}'.format(e)) from e
    except paramiko.ssh_exception.PasswordRequiredException as e:
        raise TargetStableError('Please unlock the private key file: {}'.format(e)) from e
    except paramiko.ssh_exception.AuthenticationException as e:
        raise TargetStableError('Could not authenticate: {}'.format(e)) from e
    except paramiko.ssh_exception.BadHostKeyException as e:
        raise TargetStableError('Bad host key: {}'.format(e)) from e
    except paramiko.ssh_exception.ChannelException as e:
        raise TargetStableError('Could not open an SSH channel: {}'.format(e)) from e
    except paramiko.ssh_exception.ProxyCommandFailure as e:
        raise TargetStableError('Proxy command failure: {}'.format(e)) from e
    except paramiko.ssh_exception.SSHException as e:
        raise TargetTransientError('SSH logic error: {}'.format(e)) from e
    except socket.timeout as e:
        raise TimeoutError(command, output=None) from e
channel.exit_status_ready(): + callback_excep, callback_state = read_channel(callback_state) + if callback_excep is not None: + raise callback_excep + # Make sure to always empty the streams to unblock the remote process on + # the way to exit, in case something bad happened. For example, the + # callback could raise an exception to signal it does not want to do + # anything anymore, or only reading from one of the stream might have + # raised an exception, leaving the other one non-empty. + except Exception as e: + if callback_excep is None: + # Only call the callback if there was no exception originally, as + # we don't want to reenter it if it raised an exception + read_all_channel(callback, callback_state) + raise e + else: + # Finish emptying the buffers + callback_state = read_all_channel(callback, callback_state) + exit_code = channel.recv_exit_status() + return (callback_state, exit_code) + + +def telnet_get_shell(host, username, password=None, - keyfile=None, port=None, timeout=10, - telnet=False, - original_prompt=None, - options=None): + original_prompt=None): _check_env() start_time = time.time() while True: - if telnet: - if keyfile: - raise ValueError('keyfile may not be used with a telnet connection.') - conn = TelnetPxssh(original_prompt=original_prompt) - else: # ssh - conn = pxssh.pxssh(options=options, - echo=False) + conn = TelnetPxssh(original_prompt=original_prompt) try: - if keyfile: - conn.login(host, username, ssh_key=keyfile, port=port, login_timeout=timeout) - else: - conn.login(host, username, password, port=port, login_timeout=timeout) + conn.login(host, username, password, port=port, login_timeout=timeout) break except EOF: timeout -= time.time() - start_time @@ -157,10 +249,11 @@ def check_keyfile(keyfile): return keyfile -class SshConnection(object): +class SshConnectionBase(ConnectionBase): + """ + Base class for SSH connections. 
# pylint: disable=unused-argument,super-init-not-called
def __init__(self,
             host,
             username,
             password=None,
             keyfile=None,
             port=22,
             timeout=None,
             platform=None,
             sudo_cmd="sudo -S -- sh -c {}",
             strict_host_check=True,
             ):
    """
    Open the SSH connection and probe whether ``sudo`` needs a password.

    Parameters are forwarded to the base class, except ``timeout`` which
    falls back on ``default_timeout`` when it is ``None``.
    """
    super().__init__(
        host=host,
        username=username,
        password=password,
        keyfile=keyfile,
        port=port,
        platform=platform,
        sudo_cmd=sudo_cmd,
        strict_host_check=strict_host_check,
    )
    self.timeout = timeout if timeout is not None else self.default_timeout

    self.client = self._make_client()
    atexit.register(self.close)

    # Use a marker in the output so that we will be able to differentiate
    # target connection issues with "password needed".
    # Also, sudo might not be installed at all on the target (but
    # everything will work as long as we login as root). If sudo is still
    # needed, it will explode when someone tries to use it. After all, the
    # user might not be interested in being root at all.
    #
    # The marker is defined once and used for both the command and the
    # check, so the two cannot get out of sync.
    marker = 'NEED_PASSWORD'
    output = self.execute(
        # sudo -n is broken on some versions on MacOSX, revisit that if
        # someone ever cares
        'sudo -n true || echo {}'.format(marker),
        as_root=False,
        check_exit_code=False,
    )
    self._sudo_needs_password = marker in output
+ dst = os.path.join( + dst, + os.path.basename(src), + ) + return cls._push_folder_internal(sftp, src, dst) + + @classmethod + def _push_folder_internal(cls, sftp, src, dst): + # This might fail if the folder already exists + with contextlib.suppress(IOError): + sftp.mkdir(dst) + + for entry in os.scandir(src): + name = entry.name + src_path = os.path.join(src, name) + dst_path = os.path.join(dst, name) + if entry.is_dir(): + push = cls._push_folder_internal + else: + push = cls._push_file + + push(sftp, src_path, dst_path) + + @classmethod + def _push_path(cls, sftp, src, dst): + push = cls._push_folder if os.path.isdir(src) else cls._push_file + push(sftp, src, dst) + + @classmethod + def _pull_file(cls, sftp, src, dst): + # Pulling a file into a folder will use the source basename + if os.path.isdir(dst): + dst = os.path.join( + dst, + os.path.basename(src), + ) + + with contextlib.suppress(FileNotFoundError): + os.remove(dst) + + sftp.get(src, dst) + + @classmethod + def _pull_folder(cls, sftp, src, dst): + with contextlib.suppress(FileNotFoundError): + try: + shutil.rmtree(dst) + except OSError: + os.remove(dst) + + os.makedirs(dst) + for fileattr in sftp.listdir_attr(src): + filename = fileattr.filename + src_path = os.path.join(src, filename) + dst_path = os.path.join(dst, filename) + if stat.S_ISDIR(fileattr.st_mode): + pull = cls._pull_folder + else: + pull = cls._pull_file + + pull(sftp, src_path, dst_path) + + @classmethod + def _pull_path(cls, sftp, src, dst): + try: + cls._pull_file(sftp, src, dst) + except IOError: + # Maybe that was a directory, so retry as such + cls._pull_folder(sftp, src, dst) + + def push(self, source, dest, timeout=30): + with _handle_paramiko_exceptions(), self._get_sftp(timeout) as sftp: + self._push_path(sftp, source, dest) + + def pull(self, source, dest, timeout=30): + with _handle_paramiko_exceptions(), self._get_sftp(timeout) as sftp: + self._pull_path(sftp, source, dest) + + def execute(self, command, timeout=None, 
check_exit_code=True, + as_root=False, strip_colors=True, will_succeed=False): #pylint: disable=unused-argument + if command == '': + return '' + try: + with _handle_paramiko_exceptions(command): + exit_code, output = self._execute(command, timeout, as_root, strip_colors) + except TargetStableError as e: + if will_succeed: + raise TargetTransientError(e) + else: + raise + else: + if check_exit_code and exit_code: + message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}' + raise TargetStableError(message.format(exit_code, command, output)) + return output + + def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False): + with _handle_paramiko_exceptions(command): + bg_cmd = self._background(command, stdout, stderr, as_root) + + self._current_bg_cmds.add(bg_cmd) + return bg_cmd + + def _background(self, command, stdout, stderr, as_root): + stdout, stderr, command = redirect_streams(stdout, stderr, command) + + command = "printf '%s\n' $$; exec sh -c {}".format(quote(command)) + channel = self._make_channel() + + def executor(cmd, timeout): + channel.exec_command(cmd) + # Read are not buffered so we will always get the data as soon as + # they arrive + return ( + channel.makefile_stdin(), + channel.makefile(), + channel.makefile_stderr(), + ) + + stdin, stdout_in, stderr_in = self._execute_command( + command, + as_root=as_root, + log=False, + timeout=None, + executor=executor, + ) + pid = int(stdout_in.readline()) + + def create_out_stream(stream_in, stream_out): + """ + Create a pair of file-like objects. The first one is used to read + data and the second one to write. 
+ """ + + if stream_out == subprocess.DEVNULL: + r, w = None, None + # When asked for a pipe, we just give the file-like object as the + # reading end and no writing end, since paramiko already writes to + # it + elif stream_out == subprocess.PIPE: + r, w = os.pipe() + r = os.fdopen(r, 'rb') + w = os.fdopen(w, 'wb') + # Turn a file descriptor into a file-like object + elif isinstance(stream_out, int) and stream_out >= 0: + r = os.fdopen(stream_out, 'rb') + w = os.fdopen(stream_out, 'wb') + # file-like object + else: + r = stream_out + w = stream_out + + return (r, w) + + out_streams = { + name: create_out_stream(stream_in, stream_out) + for stream_in, stream_out, name in ( + (stdout_in, stdout, 'stdout'), + (stderr_in, stderr, 'stderr'), + ) + } + + def redirect_thread_f(stdout_in, stderr_in, out_streams, select_timeout): + def callback(out_streams, name, chunk): + try: + r, w = out_streams[name] + except KeyError: + return out_streams + + try: + w.write(chunk) + # Write failed + except ValueError: + # Since that stream is now closed, stop trying to write to it + del out_streams[name] + # If that was the last open stream, we raise an + # exception so the thread can terminate. 
+ if not out_streams: + raise + + return out_streams + + try: + _read_paramiko_streams(stdout_in, stderr_in, select_timeout, callback, copy.copy(out_streams)) + # The streams closed while we were writing to it, the job is done here + except ValueError: + pass + + # Make sure the writing end are closed proper since we are not + # going to write anything anymore + for r, w in out_streams.values(): + if r is not w and w is not None: + w.close() + + # If there is anything we need to redirect to, spawn a thread taking + # care of that + select_timeout = 1 + thread_out_streams = { + name: (r, w) + for name, (r, w) in out_streams.items() + if w is not None + } + redirect_thread = threading.Thread( + target=redirect_thread_f, + args=(stdout_in, stderr_in, thread_out_streams, select_timeout), + # The thread will die when the main thread dies + daemon=True, + ) + redirect_thread.start() + + return ParamikoBackgroundCommand( + conn=self, + as_root=as_root, + chan=channel, + pid=pid, + stdin=stdin, + # We give the reading end to the consumer of the data + stdout=out_streams['stdout'][0], + stderr=out_streams['stderr'][0], + redirect_thread=redirect_thread, + ) + + def _close(self): + logger.debug('Logging out {}@{}'.format(self.username, self.host)) + with _handle_paramiko_exceptions(): + bg_cmds = set(self._current_bg_cmds) + for bg_cmd in bg_cmds: + bg_cmd.close() + self.client.close() + + def _execute_command(self, command, as_root, log, timeout, executor): + # As we're already root, there is no need to use sudo. 
+ log_debug = logger.debug if log else lambda msg: None + use_sudo = as_root and not self.connected_as_root + + if use_sudo: + if self._sudo_needs_password and not self.password: + raise TargetStableError('Attempt to use sudo but no password was specified') + + command = self.sudo_cmd.format(quote(command)) + + log_debug(command) + streams = executor(command, timeout=timeout) + if self._sudo_needs_password: + stdin = streams[0] + stdin.write(self.password + '\n') + stdin.flush() + else: + log_debug(command) + streams = executor(command, timeout=timeout) + + return streams + + def _execute(self, command, timeout=None, as_root=False, strip_colors=True, log=True): + # Merge stderr into stdout since we are going without a TTY + command = '({}) 2>&1'.format(command) + + stdin, stdout, stderr = self._execute_command( + command, + as_root=as_root, + log=log, + timeout=timeout, + executor=self.client.exec_command, + ) + stdin.close() + + # Empty the stdout buffer of the command, allowing it to carry on to + # completion + def callback(output_chunks, name, chunk): + output_chunks.append(chunk) + return output_chunks + + select_timeout = 1 + output_chunks, exit_code = _read_paramiko_streams(stdout, stderr, select_timeout, callback, []) + # Join in one go to avoid O(N^2) concatenation + output = b''.join(output_chunks) + + if sys.version_info[0] == 3: + output = output.decode(sys.stdout.encoding or 'utf-8', 'replace') + if strip_colors: + output = strip_bash_colors(output) + + return (exit_code, output) + + +class TelnetConnection(SshConnectionBase): + + default_password_prompt = '[sudo] password' + max_cancel_attempts = 5 + + # pylint: disable=unused-argument,super-init-not-called + def __init__(self, + host, + username, + password=None, + port=None, + timeout=None, + password_prompt=None, + original_prompt=None, + sudo_cmd="sudo -- sh -c {}", + strict_host_check=True, + platform=None): + + super().__init__( + host=host, + username=username, + password=password, + 
keyfile=None, + port=port, + platform=platform, + sudo_cmd=sudo_cmd, + strict_host_check=strict_host_check, + ) + + if self.strict_host_check: + options = { + 'StrictHostKeyChecking': 'yes', + } + else: + options = { + 'StrictHostKeyChecking': 'no', + 'UserKnownHostsFile': '/dev/null', + } + self.options = options + self.lock = threading.Lock() self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt - self.sudo_cmd = sanitize_cmd_template(sudo_cmd) logger.debug('Logging in {}@{}'.format(username, host)) timeout = timeout if timeout is not None else self.default_timeout - self.options = options if options is not None else {} - self.conn = ssh_get_shell(host, - username, - password, - self.keyfile, - port, - timeout, - False, - None, - self.options) + + self.conn = telnet_get_shell(host, username, password, port, timeout, original_prompt) atexit.register(self.close) def push(self, source, dest, timeout=30): @@ -282,7 +797,7 @@ def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as except EOF: raise TargetNotRespondingError('Connection lost.') - def close(self): + def _close(self): logger.debug('Logging out {}@{}'.format(self.username, self.host)) try: self.conn.logout() @@ -351,8 +866,8 @@ def _scp(self, source, dest, timeout=30): # only specify -P for scp if the port is *not* the default. 
port_string = '-P {}'.format(quote(str(self.port))) if (self.port and self.port != 22) else '' keyfile_string = '-i {}'.format(quote(self.keyfile)) if self.keyfile else '' - options = " ".join(["-o {}={}".format(key,val) - for key,val in self.options.items()]) + options = " ".join(["-o {}={}".format(key, val) + for key, val in self.options.items()]) command = '{} {} -r {} {} {} {}'.format(scp, options, keyfile_string, @@ -387,29 +902,6 @@ def _get_prompt_length(self): def _get_window_size(self): return self.conn.getwinsize() -class TelnetConnection(SshConnection): - - # pylint: disable=super-init-not-called - def __init__(self, - host, - username, - password=None, - port=None, - timeout=None, - password_prompt=None, - original_prompt=None, - platform=None): - self.host = host - self.username = username - self.password = password - self.port = port - self.keyfile = None - self.lock = threading.Lock() - self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt - logger.debug('Logging in {}@{}'.format(username, host)) - timeout = timeout if timeout is not None else self.default_timeout - self.conn = ssh_get_shell(host, username, password, None, port, timeout, True, original_prompt) - class Gem5Connection(TelnetConnection): @@ -616,7 +1108,7 @@ def background(self, command, stdout=subprocess.PIPE, 'get this file'.format(redirection_file)) return output - def close(self): + def _close(self): """ Close and disconnect from the gem5 simulation. Additionally, we remove the temporary directory used to pass files into the simulation. diff --git a/setup.py b/setup.py index fda27a843..ca30eb178 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,7 @@ 'python-dateutil', # converting between UTC and local time. 
'pexpect>=3.3', # Send/recieve to/from device 'pyserial', # Serial port interface + 'paramiko', # SSH connection 'wrapt', # Basic for construction of decorator functions 'future', # Python 2-3 compatibility 'enum34;python_version<"3.4"', # Enums for Python < 3.4