Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 132 additions & 134 deletions devlib/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
#

from abc import ABC, abstractmethod
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from datetime import datetime
from functools import partial
from weakref import WeakSet
from shlex import quote
from time import monotonic
import os
import signal
import socket
Expand Down Expand Up @@ -61,13 +60,27 @@ class ConnectionBase(InitCheckpoint):
"""
Base class for all connections.
"""
def __init__(self):
def __init__(
self,
poll_transfers=False,
start_transfer_poll_delay=30,
total_transfer_timeout=3600,
transfer_poll_period=30,
):
self._current_bg_cmds = set()
self._closed = False
self._close_lock = threading.Lock()
self.busybox = None
self.logger = logging.getLogger('Connection')

self.transfer_manager = TransferManager(
self,
start_transfer_poll_delay=start_transfer_poll_delay,
total_transfer_timeout=total_transfer_timeout,
transfer_poll_period=transfer_poll_period,
) if poll_transfers else NoopTransferManager()


def cancel_running_command(self):
bg_cmds = set(self._current_bg_cmds)
for bg_cmd in bg_cmds:
Expand Down Expand Up @@ -405,13 +418,13 @@ def create_out():
b''.join(out[stderr])
)

start = monotonic()
start = time.monotonic()

while ret is None:
# Even if ret is not None anymore, we need to drain the streams
ret = self.poll()

if timeout is not None and ret is None and monotonic() - start >= timeout:
if timeout is not None and ret is None and time.monotonic() - start >= timeout:
self.cancel()
_stdout, _stderr = create_out()
raise subprocess.TimeoutExpired(self.cmd, timeout, _stdout, _stderr)
Expand Down Expand Up @@ -563,166 +576,151 @@ def __enter__(self):
return self


class TransferManagerBase(ABC):
class TransferManager:
def __init__(self, conn, transfer_poll_period=30, start_transfer_poll_delay=30, total_transfer_timeout=3600):
self.conn = conn
self.transfer_poll_period = transfer_poll_period
self.total_transfer_timeout = total_transfer_timeout
self.start_transfer_poll_delay = start_transfer_poll_delay

def _pull_dest_size(self, dest):
if os.path.isdir(dest):
return sum(
os.stat(os.path.join(dirpath, f)).st_size
for dirpath, _, fnames in os.walk(dest)
for f in fnames
)
else:
return os.stat(dest).st_size
return 0
self.logger = logging.getLogger('FileTransfer')

def _push_dest_size(self, dest):
cmd = '{} du -s {}'.format(quote(self.conn.busybox), quote(dest))
out = self.conn.execute(cmd)
@contextmanager
def manage(self, sources, dest, direction, handle):
excep = None
stop_thread = threading.Event()

def monitor():
nonlocal excep

def cancel(reason):
self.logger.warning(
f'Cancelling file transfer {sources} -> {dest} due to: {reason}'
)
handle.cancel()

start_t = time.monotonic()
stop_thread.wait(self.start_transfer_poll_delay)
while not stop_thread.wait(self.transfer_poll_period):
if not handle.isactive():
cancel(reason='transfer inactive')
elif time.monotonic() - start_t > self.total_transfer_timeout:
cancel(reason='transfer timed out')
excep = TimeoutError(f'{direction}: {sources} -> {dest}')

m_thread = threading.Thread(target=monitor, daemon=True)
try:
return int(out.split()[0])
except ValueError:
return 0
m_thread.start()
yield self
finally:
stop_thread.set()
m_thread.join()
if excep is not None:
raise excep

def __init__(self, conn, poll_period, start_transfer_poll_delay, total_timeout):
self.conn = conn
self.poll_period = poll_period
self.total_timeout = total_timeout
self.start_transfer_poll_delay = start_transfer_poll_delay

self.logger = logging.getLogger('FileTransfer')
self.managing = threading.Event()
self.transfer_started = threading.Event()
self.transfer_completed = threading.Event()
self.transfer_aborted = threading.Event()
class NoopTransferManager:
def manage(self, *args, **kwargs):
return nullcontext(self)

self.monitor_thread = None
self.sources = None
self.dest = None
self.direction = None

@abstractmethod
def _cancel(self):
pass
class TransferHandleBase(ABC):
def __init__(self, manager):
self.manager = manager

def cancel(self, reason=None):
msg = 'Cancelling file transfer {} -> {}'.format(self.sources, self.dest)
if reason is not None:
msg += ' due to \'{}\''.format(reason)
self.logger.warning(msg)
self.transfer_aborted.set()
self._cancel()
@property
def logger(self):
return self.manager.logger

@abstractmethod
def isactive(self):
pass

@contextmanager
def manage(self, sources, dest, direction):
try:
self.sources, self.dest, self.direction = sources, dest, direction
m_thread = threading.Thread(target=self._monitor)
@abstractmethod
def cancel(self):
pass

self.transfer_completed.clear()
self.transfer_aborted.clear()
self.transfer_started.set()

m_thread.start()
yield self
except BaseException:
self.cancel(reason='exception during transfer')
raise
finally:
self.transfer_completed.set()
self.transfer_started.set()
m_thread.join()
self.transfer_started.clear()
self.transfer_completed.clear()
self.transfer_aborted.clear()
class PopenTransferHandle(TransferHandleBase):
def __init__(self, bg_cmd, dest, direction, *args, **kwargs):
super().__init__(*args, **kwargs)

def _monitor(self):
start_t = monotonic()
self.transfer_completed.wait(self.start_transfer_poll_delay)
while not self.transfer_completed.wait(self.poll_period):
if not self.isactive():
self.cancel(reason='transfer inactive')
elif monotonic() - start_t > self.total_timeout:
self.cancel(reason='transfer timed out')
if direction == 'push':
sample_size = self._push_dest_size
elif direction == 'pull':
sample_size = self._pull_dest_size
else:
raise ValueError(f'Unknown direction: {direction}')

self.sample_size = lambda: sample_size(dest)

class PopenTransferManager(TransferManagerBase):
self.bg_cmd = bg_cmd
self.last_sample = 0

def __init__(self, conn, poll_period=30, start_transfer_poll_delay=30, total_timeout=3600):
super().__init__(conn, poll_period, start_transfer_poll_delay, total_timeout)
self.transfer = None
self.last_sample = None
@staticmethod
def _pull_dest_size(dest):
if os.path.isdir(dest):
return sum(
os.stat(os.path.join(dirpath, f)).st_size
for dirpath, _, fnames in os.walk(dest)
for f in fnames
)
else:
return os.stat(dest).st_size

def _cancel(self):
if self.transfer:
self.transfer.cancel()
self.transfer = None
self.last_sample = None
def _push_dest_size(self, dest):
conn = self.manager.conn
cmd = '{} du -s -- {}'.format(quote(conn.busybox), quote(dest))
out = conn.execute(cmd)
return int(out.split()[0])

def cancel(self):
self.bg_cmd.cancel()

def isactive(self):
size_fn = self._push_dest_size if self.direction == 'push' else self._pull_dest_size
curr_size = size_fn(self.dest)
self.logger.debug('Polled file transfer, destination size {}'.format(curr_size))
active = True if self.last_sample is None else curr_size > self.last_sample
self.last_sample = curr_size
return active
try:
curr_size = self.sample_size()
except Exception as e:
self.logger.debug(f'File size polling failed: {e}')
return True
else:
self.logger.debug(f'Polled file transfer, destination size: {curr_size}')
if curr_size:
active = curr_size > self.last_sample
self.last_sample = curr_size
return active
# If the file is empty it will never grow in size, so we assume
# everything is going well.
else:
return True

def set_transfer_and_wait(self, popen_bg_cmd):
self.transfer = popen_bg_cmd
self.last_sample = None
ret = self.transfer.wait()

if ret and not self.transfer_aborted.is_set():
raise subprocess.CalledProcessError(ret, self.transfer.popen.args)
elif self.transfer_aborted.is_set():
raise TimeoutError(self.transfer.popen.args)
class SSHTransferHandle(TransferHandleBase):

def __init__(self, handle, *args, **kwargs):
super().__init__(*args, **kwargs)

class SSHTransferManager(TransferManagerBase):
# SFTPClient or SSHClient
self.handle = handle

def __init__(self, conn, poll_period=30, start_transfer_poll_delay=30, total_timeout=3600):
super().__init__(conn, poll_period, start_transfer_poll_delay, total_timeout)
self.transferer = None
self.progressed = False
self.transferred = None
self.to_transfer = None
self.transferred = 0
self.to_transfer = 0

def _cancel(self):
self.transferer.close()
def cancel(self):
self.handle.close()

def isactive(self):
progressed = self.progressed
self.progressed = False
msg = 'Polled transfer: {}% [{}B/{}B]'
pc = format((self.transferred / self.to_transfer) * 100, '.2f')
self.logger.debug(msg.format(pc, self.transferred, self.to_transfer))
if progressed:
self.progressed = False
pc = (self.transferred / self.to_transfer) * 100
self.logger.debug(
f'Polled transfer: {pc:.2f}% [{self.transferred}B/{self.to_transfer}B]'
)
return progressed

@contextmanager
def manage(self, sources, dest, direction, transferer):
with super().manage(sources, dest, direction):
try:
self.progressed = False
self.transferer = transferer # SFTPClient or SCPClient
yield self
except socket.error as e:
if self.transfer_aborted.is_set():
self.transfer_aborted.clear()
method = 'SCP' if self.conn.use_scp else 'SFTP'
raise TimeoutError('{} {}: {} -> {}'.format(method, self.direction, sources, self.dest))
else:
raise e

def progress_cb(self, *args):
if self.transfer_started.is_set():
self.progressed = True
if len(args) == 3: # For SCPClient callbacks
self.transferred = args[2]
self.to_transfer = args[1]
elif len(args) == 2: # For SFTPClient callbacks
self.transferred = args[0]
self.to_transfer = args[1]
def progress_cb(self, transferred, to_transfer):
self.progressed = True
self.transferred = transferred
self.to_transfer = to_transfer
2 changes: 1 addition & 1 deletion devlib/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -2931,7 +2931,7 @@ def __init__(self,
ssh_conn_params = ['host', 'username', 'password', 'keyfile',
'port', 'timeout', 'sudo_cmd',
'strict_host_check', 'use_scp',
'total_timeout', 'poll_transfers',
'total_transfer_timeout', 'poll_transfers',
'start_transfer_poll_delay']
self.ssh_connection_settings = {}
for setting in ssh_conn_params:
Expand Down
Loading