Skip to content
This repository was archived by the owner on Jan 10, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions adb/adb_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,17 @@ def ConnectDevice(self, port_path=None, serial=None, default_timeout_ms=None, **
# If there isnt a handle override (used by tests), build one here
if 'handle' in kwargs:
self._handle = kwargs.pop('handle')
elif serial and b':' in serial:
self._handle = common.TcpHandle(serial, timeout_ms=default_timeout_ms)
else:
self._handle = common.UsbHandle.FindAndOpen(
DeviceIsAvailable, port_path=port_path, serial=serial,
timeout_ms=default_timeout_ms)
# if necessary, convert serial to a unicode string
if isinstance(serial, (bytes, bytearray)):
serial = serial.decode('utf-8')

if serial and ':' in serial:
self._handle = common.TcpHandle(serial, timeout_ms=default_timeout_ms)
else:
self._handle = common.UsbHandle.FindAndOpen(
DeviceIsAvailable, port_path=port_path, serial=serial,
timeout_ms=default_timeout_ms)

self._Connect(**kwargs)

Expand Down
5 changes: 5 additions & 0 deletions adb/adb_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,11 @@ def Connect(cls, usb, banner=b'notadb', rsa_keys=None, auth_timeout_ms=100):
InvalidResponseError: When the device does authentication in an
unexpected way.
"""
# In py3, convert unicode to bytes. In py2, convert str to bytes.
# It's later joined into a byte string, so in py2, this ends up kind of being a no-op.
if isinstance(banner, str):
banner = bytearray(banner, 'utf-8')

msg = cls(
command=b'CNXN', arg0=VERSION, arg1=MAX_ADB_DATA,
data=b'host::%s\0' % banner)
Expand Down
23 changes: 17 additions & 6 deletions adb/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,15 +298,26 @@ def __init__(self, serial, timeout_ms=None):

Host may be an IP address or a host name.
"""
if b':' in serial:
(host, port) = serial.split(b':')
# if necessary, convert serial to a unicode string
if isinstance(serial, (bytes, bytearray)):
serial = serial.decode('utf-8')

if ':' in serial:
self.host, self.port = serial.split(':')
else:
host = serial
port = 5555
self._serial_number = '%s:%s' % (host, port)
self.host = serial
self.port = 5555

self._connection = None
self._serial_number = '%s:%s' % (self.host, self.port)
self._timeout_ms = float(timeout_ms) if timeout_ms else None

self._connect()

def _connect(self):
timeout = self.TimeoutSeconds(self._timeout_ms)
self._connection = socket.create_connection((host, port), timeout=timeout)
self._connection = socket.create_connection((self.host, self.port),
timeout=timeout)
if timeout:
self._connection.setblocking(0)

Expand Down
49 changes: 42 additions & 7 deletions test/adb_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
from io import BytesIO
import struct
import unittest
from mock import mock


from adb import common
from adb import adb_commands
from adb import adb_protocol
from adb.usb_exceptions import TcpTimeoutException
from adb.usb_exceptions import TcpTimeoutException, DeviceNotFoundError
import common_stub


Expand Down Expand Up @@ -78,10 +81,9 @@ def _Connect(cls, usb):


class AdbTest(BaseAdbTest):

@classmethod
def _ExpectCommand(cls, service, command, *responses):
usb = common_stub.StubUsb()
usb = common_stub.StubUsb(device=None, setting=None)
cls._ExpectConnection(usb)
cls._ExpectOpen(usb, b'%s:%s\0' % (service, command))

Expand All @@ -91,12 +93,19 @@ def _ExpectCommand(cls, service, command, *responses):
return usb

def testConnect(self):
usb = common_stub.StubUsb()
usb = common_stub.StubUsb(device=None, setting=None)
self._ExpectConnection(usb)

dev = adb_commands.AdbCommands()
dev.ConnectDevice(handle=usb, banner=BANNER)

def testConnectSerialString(self):
dev = adb_commands.AdbCommands()

with mock.patch.object(common.UsbHandle, 'FindAndOpen', return_value=None):
with mock.patch.object(adb_commands.AdbCommands, '_Connect', return_value=None):
dev.ConnectDevice(serial='/dev/invalidHandle')

def testSmallResponseShell(self):
command = b'keepin it real'
response = 'word.'
Expand Down Expand Up @@ -196,7 +205,7 @@ def _MakeWriteSyncPacket(cls, command, data=b'', size=None):

@classmethod
def _ExpectSyncCommand(cls, write_commands, read_commands):
usb = common_stub.StubUsb()
usb = common_stub.StubUsb(device=None, setting=None)
cls._ExpectConnection(usb)
cls._ExpectOpen(usb, b'sync:\0')

Expand Down Expand Up @@ -246,7 +255,7 @@ class TcpTimeoutAdbTest(BaseAdbTest):

@classmethod
def _ExpectCommand(cls, service, command, *responses):
tcp = common_stub.StubTcp()
tcp = common_stub.StubTcp('10.0.0.123')
cls._ExpectConnection(tcp)
cls._ExpectOpen(tcp, b'%s:%s\0' % (service, command))

Expand All @@ -262,7 +271,7 @@ def _run_shell(self, cmd, timeout_ms=None):
dev.Shell(cmd, timeout_ms=timeout_ms)

def testConnect(self):
tcp = common_stub.StubTcp()
tcp = common_stub.StubTcp('10.0.0.123')
self._ExpectConnection(tcp)
dev = adb_commands.AdbCommands()
dev.ConnectDevice(handle=tcp, banner=BANNER)
Expand All @@ -276,5 +285,31 @@ def testTcpTimeout(self):
command,
timeout_ms=timeout_ms)


class TcpHandleTest(unittest.TestCase):
def testInitWithHost(self):
tcp = common_stub.StubTcp('10.11.12.13')

self.assertEqual('10.11.12.13:5555', tcp._serial_number)
self.assertEqual(None, tcp._timeout_ms)

def testInitWithHostAndPort(self):
tcp = common_stub.StubTcp('10.11.12.13:5678')

self.assertEqual('10.11.12.13:5678', tcp._serial_number)
self.assertEqual(None, tcp._timeout_ms)

def testInitWithTimeout(self):
tcp = common_stub.StubTcp('10.0.0.2', timeout_ms=234.5)

self.assertEqual('10.0.0.2:5555', tcp._serial_number)
self.assertEqual(234.5, tcp._timeout_ms)

def testInitWithTimeoutInt(self):
tcp = common_stub.StubTcp('10.0.0.2', timeout_ms=234)

self.assertEqual('10.0.0.2:5555', tcp._serial_number)
self.assertEqual(234.0, tcp._timeout_ms)

if __name__ == '__main__':
unittest.main()
115 changes: 70 additions & 45 deletions test/common_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import string
import sys
import time
from adb.usb_exceptions import TcpTimeoutException
from mock import mock

from adb.common import TcpHandle, UsbHandle
from adb.usb_exceptions import TcpTimeoutException

PRINTABLE_DATA = set(string.printable) - set(string.whitespace)

Expand All @@ -16,33 +19,23 @@ def _Dotify(data):
return ''.join(char if char in PRINTABLE_DATA else '.' for char in data)


class StubUsb(object):
"""UsbHandle stub."""

def __init__(self):
class StubHandleBase(object):
def __init__(self, timeout_ms, is_tcp=False):
self.written_data = []
self.read_data = []
self.timeout_ms = 0
self.is_tcp = is_tcp
self.timeout_ms = timeout_ms

def BulkWrite(self, data, unused_timeout_ms=None):
expected_data = self.written_data.pop(0)
if isinstance(data, bytearray):
data = bytes(data)
if not isinstance(data, bytes):
data = data.encode('utf8')
if expected_data != data:
raise ValueError('Expected %s (%s) got %s (%s)' % (
binascii.hexlify(expected_data), _Dotify(expected_data),
binascii.hexlify(data), _Dotify(data)))
def _signal_handler(self, signum, frame):
raise TcpTimeoutException('End of time')

def BulkRead(self, length,
timeout_ms=None): # pylint: disable=unused-argument
data = self.read_data.pop(0)
if length < len(data):
raise ValueError(
'Overflow packet length. Read %d bytes, got %d bytes: %s',
length, len(data))
return bytearray(data)
def _return_seconds(self, time_ms):
return (float(time_ms)/1000) if time_ms else 0

def _alarm_sounder(self, timeout_ms):
signal.signal(signal.SIGALRM, self._signal_handler)
signal.setitimer(signal.ITIMER_REAL,
self._return_seconds(timeout_ms))

def ExpectWrite(self, data):
if not isinstance(data, bytes):
Expand All @@ -54,22 +47,6 @@ def ExpectRead(self, data):
data = data.encode('utf8')
self.read_data.append(data)

def Timeout(self, timeout_ms):
return timeout_ms if timeout_ms is not None else self.timeout_ms

class StubTcp(StubUsb):

def _signal_handler(self, signum, frame):
raise TcpTimeoutException('End of time')

def _return_seconds(self, time_ms):
return (float(time_ms)/1000) if time_ms else 0

def _alarm_sounder(self, timeout_ms):
signal.signal(signal.SIGALRM, self._signal_handler)
signal.setitimer(signal.ITIMER_REAL,
self._return_seconds(timeout_ms))

def BulkWrite(self, data, timeout_ms=None):
expected_data = self.written_data.pop(0)
if isinstance(data, bytearray):
Expand All @@ -80,8 +57,8 @@ def BulkWrite(self, data, timeout_ms=None):
raise ValueError('Expected %s (%s) got %s (%s)' % (
binascii.hexlify(expected_data), _Dotify(expected_data),
binascii.hexlify(data), _Dotify(data)))
if b'i_need_a_timeout' in data:
self._alarm_sounder(timeout_ms)
if self.is_tcp and b'i_need_a_timeout' in data:
self._alarm_sounder(timeout_ms)
time.sleep(2*self._return_seconds(timeout_ms))

def BulkRead(self, length,
Expand All @@ -91,8 +68,56 @@ def BulkRead(self, length,
raise ValueError(
'Overflow packet length. Read %d bytes, got %d bytes: %s',
length, len(data))
if b'i_need_a_timeout' in data:
self._alarm_sounder(timeout_ms)
if self.is_tcp and b'i_need_a_timeout' in data:
self._alarm_sounder(timeout_ms)
time.sleep(2*self._return_seconds(timeout_ms))
return bytearray(data)
return bytearray(data)

def Timeout(self, timeout_ms):
return timeout_ms if timeout_ms is not None else self.timeout_ms


class StubUsb(UsbHandle):
"""UsbHandle stub."""
def __init__(self, device, setting, usb_info=None, timeout_ms=None):
super(StubUsb, self).__init__(device, setting, usb_info, timeout_ms)
self.stub_base = StubHandleBase(0)

def ExpectWrite(self, data):
return self.stub_base.ExpectWrite(data)

def ExpectRead(self, data):
return self.stub_base.ExpectRead(data)

def BulkWrite(self, data, unused_timeout_ms=None):
return self.stub_base.BulkWrite(data, unused_timeout_ms)

def BulkRead(self, length, timeout_ms=None):
return self.stub_base.BulkRead(length, timeout_ms)

def Timeout(self, timeout_ms):
return self.stub_base.Timeout(timeout_ms)


class StubTcp(TcpHandle):
def __init__(self, serial, timeout_ms=None):
"""TcpHandle stub."""
self._connect = mock.MagicMock(return_value=None)

super(StubTcp, self).__init__(serial, timeout_ms)
self.stub_base = StubHandleBase(0, is_tcp=True)

def ExpectWrite(self, data):
return self.stub_base.ExpectWrite(data)

def ExpectRead(self, data):
return self.stub_base.ExpectRead(data)

def BulkWrite(self, data, unused_timeout_ms=None):
return self.stub_base.BulkWrite(data, unused_timeout_ms)

def BulkRead(self, length, timeout_ms=None):
return self.stub_base.BulkRead(length, timeout_ms)

def Timeout(self, timeout_ms):
return self.stub_base.Timeout(timeout_ms)
2 changes: 1 addition & 1 deletion test/fastboot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
class FastbootTest(unittest.TestCase):

def setUp(self):
self.usb = common_stub.StubUsb()
self.usb = common_stub.StubUsb(device=None, setting=None)

@staticmethod
def _SumLengths(items):
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ envlist =
deps =
pytest
pytest-cov
mock
usedevelop = True
commands = py.test --cov adb test