diff --git a/pyignite/binary.py b/pyignite/binary.py index 4e34267..5a5f895 100644 --- a/pyignite/binary.py +++ b/pyignite/binary.py @@ -201,7 +201,7 @@ def write_footer(obj, stream, header, header_class, schema_items, offsets, initi stream.write(schema) if save_to_buf: - obj._buffer = bytes(stream.mem_view(initial_pos, stream.tell() - initial_pos)) + obj._buffer = stream.slice(initial_pos, stream.tell() - initial_pos) obj._hashcode = header.hash_code def _setattr(self, attr_name: str, attr_value: Any): diff --git a/pyignite/connection/aio_connection.py b/pyignite/connection/aio_connection.py index ce32592..020f8d4 100644 --- a/pyignite/connection/aio_connection.py +++ b/pyignite/connection/aio_connection.py @@ -158,7 +158,7 @@ async def _connect_version(self) -> Union[dict, OrderedDict]: with AioBinaryStream(self.client) as stream: await hs_request.from_python_async(stream) - await self._send(stream.getbuffer(), reconnect=False) + await self._send(stream.getvalue(), reconnect=False) with AioBinaryStream(self.client, await self._recv(reconnect=False)) as stream: hs_response = await HandshakeResponse.parse_async(stream, self.protocol_context) @@ -185,7 +185,7 @@ async def _reconnect(self): except connection_errors: pass - async def request(self, data: Union[bytes, bytearray, memoryview]) -> bytearray: + async def request(self, data: Union[bytes, bytearray]) -> bytearray: """ Perform request. @@ -195,7 +195,7 @@ async def request(self, data: Union[bytes, bytearray, memoryview]) -> bytearray: await self._send(data) return await self._recv() - async def _send(self, data: Union[bytes, bytearray, memoryview], reconnect=True): + async def _send(self, data: Union[bytes, bytearray], reconnect=True): if self.closed: raise SocketError('Attempt to use closed connection.') @@ -212,21 +212,43 @@ async def _recv(self, reconnect=True) -> bytearray: if self.closed: raise SocketError('Attempt to use closed connection.') - with BytesIO() as stream: + data = bytearray(1024) + buffer = memoryview(data) + bytes_total_received, bytes_to_receive = 0, 0 + while True: try: - buf = await self._reader.readexactly(4) - response_len = int.from_bytes(buf, PROTOCOL_BYTE_ORDER) + chunk = await self._reader.read(len(buffer)) + bytes_received = len(chunk) + if bytes_received == 0: + raise SocketError('Connection broken.') - stream.write(buf) - - stream.write(await self._reader.readexactly(response_len)) + buffer[0:bytes_received] = chunk + bytes_total_received += bytes_received except connection_errors: self.failed = True if reconnect: await self._reconnect() raise - return bytearray(stream.getbuffer()) + if bytes_total_received < 4: + continue + elif bytes_to_receive == 0: + response_len = int.from_bytes(data[0:4], PROTOCOL_BYTE_ORDER) + bytes_to_receive = response_len + + if response_len + 4 > len(data): + buffer.release() + data.extend(bytearray(response_len + 4 - len(data))) + buffer = memoryview(data)[bytes_total_received:] + continue + + if bytes_total_received >= bytes_to_receive: + buffer.release() + break + + buffer = buffer[bytes_received:] + + return data async def close(self): async with self._mux: diff --git a/pyignite/connection/connection.py b/pyignite/connection/connection.py index 7d5778c..e8437dc 100644 --- a/pyignite/connection/connection.py +++ b/pyignite/connection/connection.py @@ -212,7 +212,7 @@ def _connect_version(self) -> Union[dict, OrderedDict]: with BinaryStream(self.client) as stream: hs_request.from_python(stream) - self.send(stream.getbuffer(), reconnect=False) + self.send(stream.getvalue(), reconnect=False) with BinaryStream(self.client, self.recv(reconnect=False)) as stream: hs_response = HandshakeResponse.parse(stream, self.protocol_context) @@ -235,7 +235,7 @@ def reconnect(self): except connection_errors: pass - def request(self, data: Union[bytes, bytearray, memoryview], flags=None) -> bytearray: + def request(self, data: Union[bytes, bytearray], flags=None) -> bytearray: """ Perform request. @@ -245,7 +245,7 @@ def request(self, data: Union[bytes, bytearray, memoryview], flags=None) -> byte self.send(data, flags=flags) return self.recv() - def send(self, data: Union[bytes, bytearray, memoryview], flags=None, reconnect=True): + def send(self, data: Union[bytes, bytearray], flags=None, reconnect=True): """ Send data down the socket. @@ -275,22 +275,6 @@ def recv(self, flags=None, reconnect=True) -> bytearray: :param flags: (optional) OS-specific flags. :param reconnect: (optional) reconnect on failure, default True. """ - def _recv(buffer, num_bytes): - bytes_to_receive = num_bytes - while bytes_to_receive > 0: - try: - bytes_rcvd = self._socket.recv_into(buffer, bytes_to_receive, **kwargs) - if bytes_rcvd == 0: - raise SocketError('Connection broken.') - except connection_errors: - self.failed = True - if reconnect: - self.reconnect() - raise - - buffer = buffer[bytes_rcvd:] - bytes_to_receive -= bytes_rcvd - if self.closed: raise SocketError('Attempt to use closed connection.') @@ -298,12 +282,39 @@ def _recv(buffer, num_bytes): if flags is not None: kwargs['flags'] = flags - data = bytearray(4) - _recv(memoryview(data), 4) - response_len = int.from_bytes(data, PROTOCOL_BYTE_ORDER) + data = bytearray(1024) + buffer = memoryview(data) + bytes_total_received, bytes_to_receive = 0, 0 + while True: + try: + bytes_received = self._socket.recv_into(buffer, len(buffer), **kwargs) + if bytes_received == 0: + raise SocketError('Connection broken.') + bytes_total_received += bytes_received + except connection_errors: + self.failed = True + if reconnect: + self.reconnect() + raise + + if bytes_total_received < 4: + continue + elif bytes_to_receive == 0: + response_len = int.from_bytes(data[0:4], PROTOCOL_BYTE_ORDER) + bytes_to_receive = response_len + + if response_len + 4 > len(data): + buffer.release() + data.extend(bytearray(response_len + 4 - len(data))) + buffer = memoryview(data)[bytes_total_received:] + continue + + if bytes_total_received >= bytes_to_receive: + buffer.release() + break + + buffer = buffer[bytes_received:] - data.extend(bytearray(response_len)) - _recv(memoryview(data)[4:], response_len) return data def close(self): diff --git a/pyignite/datatypes/internal.py b/pyignite/datatypes/internal.py index 0de50e2..55ed844 100644 --- a/pyignite/datatypes/internal.py +++ b/pyignite/datatypes/internal.py @@ -36,7 +36,10 @@ from ..stream import READ_BACKWARD -def tc_map(key: bytes, _memo_map: dict = {}): +_tc_map = {} + + +def tc_map(key: bytes): """ Returns a default parser/generator class for the given type code. @@ -49,7 +52,8 @@ def tc_map(key: bytes, _memo_map: dict = {}): of the “type code-type class” mapping, :return: parser/generator class for the type code. """ - if not _memo_map: + global _tc_map + if not _tc_map: from pyignite.datatypes import ( Null, ByteObject, ShortObject, IntObject, LongObject, FloatObject, DoubleObject, CharObject, BoolObject, UUIDObject, DateObject, @@ -64,7 +68,7 @@ def tc_map(key: bytes, _memo_map: dict = {}): MapObject, BinaryObject, WrappedDataObject, ) - _memo_map = { + _tc_map = { TC_NULL: Null, TC_BYTE: ByteObject, @@ -110,7 +114,7 @@ def tc_map(key: bytes, _memo_map: dict = {}): TC_COMPLEX_OBJECT: BinaryObject, TC_ARRAY_WRAPPED_OBJECTS: WrappedDataObject, } - return _memo_map[key] + return _tc_map[key] class Conditional: @@ -183,7 +187,7 @@ async def parse_async(self, stream): def __parse_length(self, stream): counter_type_len = ctypes.sizeof(self.counter_type) length = int.from_bytes( - stream.mem_view(offset=counter_type_len), + stream.slice(offset=counter_type_len), byteorder=PROTOCOL_BYTE_ORDER ) stream.seek(counter_type_len, SEEK_CUR) @@ -348,6 +352,9 @@ class AnyDataObject: """ _python_map = None _python_array_map = None + _map_obj_type = None + _collection_obj_type = None + _binary_obj_type = None @staticmethod def get_subtype(iterable, allow_none=False): @@ -391,7 +398,7 @@ async def parse_async(cls, stream): @classmethod def __data_class_parse(cls, stream): - type_code = bytes(stream.mem_view(offset=ctypes.sizeof(ctypes.c_byte))) + type_code = stream.slice(offset=ctypes.sizeof(ctypes.c_byte)) try: return tc_map(type_code) except KeyError: @@ -416,15 +423,17 @@ def __data_class_from_ctype(cls, ctype_object): return tc_map(type_code) @classmethod - def _init_python_map(cls): + def _init_python_mapping(cls): """ Optimizes Python types→Ignite types map creation for speed. Local imports seem inevitable here. """ from pyignite.datatypes import ( - LongObject, DoubleObject, String, BoolObject, Null, UUIDObject, - DateObject, TimeObject, DecimalObject, ByteArrayObject, + LongObject, DoubleObject, String, BoolObject, Null, UUIDObject, DateObject, TimeObject, + DecimalObject, ByteArrayObject, LongArrayObject, DoubleArrayObject, StringArrayObject, + BoolArrayObject, UUIDArrayObject, DateArrayObject, TimeArrayObject, DecimalArrayObject, + MapObject, CollectionObject, BinaryObject ) cls._python_map = { @@ -442,17 +451,6 @@ def _init_python_map(cls): decimal.Decimal: DecimalObject, } - @classmethod - def _init_python_array_map(cls): - """ - Optimizes Python types→Ignite array types map creation for speed. - """ - from pyignite.datatypes import ( - LongArrayObject, DoubleArrayObject, StringArrayObject, - BoolArrayObject, UUIDArrayObject, DateArrayObject, TimeArrayObject, - DecimalArrayObject, - ) - cls._python_array_map = { int: LongArrayObject, float: DoubleArrayObject, @@ -466,18 +464,20 @@ def _init_python_array_map(cls): decimal.Decimal: DecimalArrayObject, } + cls._map_obj_type = MapObject + cls._collection_obj_type = CollectionObject + cls._binary_obj_type = BinaryObject + @classmethod def map_python_type(cls, value): - from pyignite.datatypes import ( - MapObject, CollectionObject, BinaryObject, - ) - - if cls._python_map is None: - cls._init_python_map() - if cls._python_array_map is None: - cls._init_python_array_map() + if cls._python_map is None or cls._python_array_map is None: + cls._init_python_mapping() value_type = type(value) + + if value_type in cls._python_map: + return cls._python_map[value_type] + if is_iterable(value) and value_type not in (str, bytearray, bytes): value_subtype = cls.get_subtype(value) if value_subtype in cls._python_array_map: @@ -490,7 +490,7 @@ def map_python_type(cls, value): isinstance(value[0], int), isinstance(value[1], dict), ]): - return MapObject + return cls._map_obj_type if all([ value_subtype is None, @@ -498,7 +498,7 @@ def map_python_type(cls, value): isinstance(value[0], int), is_iterable(value[1]), ]): - return CollectionObject + return cls._collection_obj_type # no default for ObjectArrayObject, sorry @@ -507,10 +507,8 @@ def map_python_type(cls, value): ) if is_binary(value): - return BinaryObject + return cls._binary_obj_type - if value_type in cls._python_map: - return cls._python_map[value_type] raise TypeError( 'Type `{}` is invalid.'.format(value_type) ) diff --git a/pyignite/datatypes/null_object.py b/pyignite/datatypes/null_object.py index f16034f..8ac47b2 100644 --- a/pyignite/datatypes/null_object.py +++ b/pyignite/datatypes/null_object.py @@ -140,7 +140,7 @@ async def to_python_async(cls, ctypes_object, *args, **kwargs): def __check_null_input(cls, stream): type_len = ctypes.sizeof(ctypes.c_byte) - if stream.mem_view(offset=type_len) == TC_NULL: + if stream.slice(offset=type_len) == TC_NULL: stream.seek(type_len, SEEK_CUR) return True, Null.build_c_type() diff --git a/pyignite/datatypes/standard.py b/pyignite/datatypes/standard.py index 2b61235..4ca6795 100644 --- a/pyignite/datatypes/standard.py +++ b/pyignite/datatypes/standard.py @@ -91,7 +91,7 @@ def build_c_type(cls, length: int): @classmethod def parse_not_null(cls, stream): length = int.from_bytes( - stream.mem_view(stream.tell() + ctypes.sizeof(ctypes.c_byte), ctypes.sizeof(ctypes.c_int)), + stream.slice(stream.tell() + ctypes.sizeof(ctypes.c_byte), ctypes.sizeof(ctypes.c_int)), byteorder=PROTOCOL_BYTE_ORDER ) diff --git a/pyignite/queries/query.py b/pyignite/queries/query.py index d9e6aaf..8dac64f 100644 --- a/pyignite/queries/query.py +++ b/pyignite/queries/query.py @@ -122,7 +122,7 @@ def perform( """ with BinaryStream(conn.client) as stream: self.from_python(stream, query_params) - response_data = conn.request(stream.getbuffer()) + response_data = conn.request(stream.getvalue()) response_struct = self.response_type(protocol_context=conn.protocol_context, following=response_config, **kwargs) @@ -154,7 +154,7 @@ async def perform_async( """ with AioBinaryStream(conn.client) as stream: await self.from_python_async(stream, query_params) - data = await conn.request(stream.getbuffer()) + data = await conn.request(stream.getvalue()) response_struct = self.response_type(protocol_context=conn.protocol_context, following=response_config, **kwargs) diff --git a/pyignite/queries/response.py b/pyignite/queries/response.py index 6495802..f0338e1 100644 --- a/pyignite/queries/response.py +++ b/pyignite/queries/response.py @@ -27,42 +27,42 @@ from pyignite.stream import READ_BACKWARD +class StatusFlagResponseHeader(ctypes.LittleEndianStructure): + _pack_ = 1 + _fields_ = [ + ('length', ctypes.c_int), + ('query_id', ctypes.c_longlong), + ('flags', ctypes.c_short) + ] + + +class ResponseHeader(ctypes.LittleEndianStructure): + _pack_ = 1 + _fields_ = [ + ('length', ctypes.c_int), + ('query_id', ctypes.c_longlong), + ('status_code', ctypes.c_int) + ] + + @attr.s class Response: following = attr.ib(type=list, factory=list) protocol_context = attr.ib(type=type(ProtocolContext), default=None) - _response_header = None _response_class_name = 'Response' def __attrs_post_init__(self): # replace None with empty list self.following = self.following or [] - def __build_header(self): - if self._response_header is None: - fields = [ - ('length', ctypes.c_int), - ('query_id', ctypes.c_longlong), - ] - - if self.protocol_context.is_status_flags_supported(): - fields.append(('flags', ctypes.c_short)) - else: - fields.append(('status_code', ctypes.c_int),) - - self._response_header = type( - 'ResponseHeader', - (ctypes.LittleEndianStructure,), - { - '_pack_': 1, - '_fields_': fields, - }, - ) - return self._response_header - def __parse_header(self, stream): init_pos = stream.tell() - header_class = self.__build_header() + + if self.protocol_context.is_status_flags_supported(): + header_class = StatusFlagResponseHeader + else: + header_class = ResponseHeader + header_len = ctypes.sizeof(header_class) header = stream.read_ctype(header_class) stream.seek(header_len, SEEK_CUR) diff --git a/pyignite/stream/binary_stream.py b/pyignite/stream/binary_stream.py index 57b4b83..3923a3b 100644 --- a/pyignite/stream/binary_stream.py +++ b/pyignite/stream/binary_stream.py @@ -23,7 +23,12 @@ READ_BACKWARD = 1 -class BinaryStreamBaseMixin: +class BinaryStreamBase: + def __init__(self, client, buf=None): + self.client = client + self.stream = BytesIO(buf) if buf else BytesIO() + self._buffer = None + @property def compact_footer(self) -> bool: return self.client.compact_footer @@ -50,10 +55,11 @@ def read_ctype(self, ctype_class, position=None, direction=READ_FORWARD): else: start, end = init_position - ctype_len, init_position - buf = self.stream.getbuffer()[start:end] - return ctype_class.from_buffer_copy(buf) + with self.getbuffer()[start:end] as buf: + return ctype_class.from_buffer_copy(buf) def write(self, buf): + self._release_buffer() return self.stream.write(buf) def tell(self): @@ -62,30 +68,39 @@ def tell(self): def seek(self, *args, **kwargs): return self.stream.seek(*args, **kwargs) + def getbuffer(self): + if self._buffer: + return self._buffer + + self._buffer = self.stream.getbuffer() + return self._buffer + def getvalue(self): return self.stream.getvalue() - def getbuffer(self): - return self.stream.getbuffer() - - def mem_view(self, start=-1, offset=0): + def slice(self, start=-1, offset=0): start = start if start >= 0 else self.tell() - return self.stream.getbuffer()[start:start + offset] + with self.getbuffer()[start:start + offset] as buf: + return bytes(buf) def hashcode(self, start, bytes_len): - return ignite_utils.hashcode(self.stream.getbuffer()[start:start + bytes_len]) + with self.getbuffer()[start:start + bytes_len] as buf: + return ignite_utils.hashcode(buf) + + def _release_buffer(self): + if self._buffer: + self._buffer.release() + self._buffer = None def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): - try: - self.stream.close() - except BufferError: - pass + self._release_buffer() + self.stream.close() -class BinaryStream(BinaryStreamBaseMixin): +class BinaryStream(BinaryStreamBase): """ Synchronous binary stream. """ @@ -94,8 +109,7 @@ def __init__(self, client: 'pyignite.Client', buf: Optional[Union[bytes, bytearr :param client: Client instance, required. :param buf: Buffer, optional parameter. If not passed, creates empty BytesIO. """ - self.client = client - self.stream = BytesIO(buf) if buf else BytesIO() + super().__init__(client, buf) def get_dataclass(self, header): result = self.client.query_binary_type(header.type_id, header.schema_id) @@ -107,7 +121,7 @@ def register_binary_type(self, *args, **kwargs): self.client.register_binary_type(*args, **kwargs) -class AioBinaryStream(BinaryStreamBaseMixin): +class AioBinaryStream(BinaryStreamBase): """ Asyncio binary stream. """ @@ -118,8 +132,7 @@ def __init__(self, client: 'pyignite.AioClient', buf: Optional[Union[bytes, byte :param client: AioClient instance, required. :param buf: Buffer, optional parameter. If not passed, creates empty BytesIO. """ - self.client = client - self.stream = BytesIO(buf) if buf else BytesIO() + super().__init__(client, buf) async def get_dataclass(self, header): result = await self.client.query_binary_type(header.type_id, header.schema_id)