diff --git a/pyignite/connection/connection.py b/pyignite/connection/connection.py index 98ba7e0..3d86f01 100644 --- a/pyignite/connection/connection.py +++ b/pyignite/connection/connection.py @@ -156,6 +156,9 @@ def _connection_listener(self): return self.client._event_listeners +DEFAULT_INITIAL_BUF_SIZE = 1024 + + class Connection(BaseConnection): """ This is a `pyignite` class, that represents a connection to Ignite @@ -348,15 +351,15 @@ def recv(self, flags=None, reconnect=True) -> bytearray: if flags is not None: kwargs['flags'] = flags - data = bytearray(1024) + data = bytearray(DEFAULT_INITIAL_BUF_SIZE) buffer = memoryview(data) - bytes_total_received, bytes_to_receive = 0, 0 + total_rcvd, packet_len = 0, 0 while True: try: - bytes_received = self._socket.recv_into(buffer, len(buffer), **kwargs) - if bytes_received == 0: + bytes_rcvd = self._socket.recv_into(buffer, len(buffer), **kwargs) + if bytes_rcvd == 0: raise SocketError('Connection broken.') - bytes_total_received += bytes_received + total_rcvd += bytes_rcvd except connection_errors as e: self.failed = True if reconnect: @@ -364,23 +367,19 @@ def recv(self, flags=None, reconnect=True) -> bytearray: self.reconnect() raise e - if bytes_total_received < 4: - continue - elif bytes_to_receive == 0: - response_len = int.from_bytes(data[0:4], PROTOCOL_BYTE_ORDER) - bytes_to_receive = response_len - - if response_len + 4 > len(data): + if packet_len == 0 and total_rcvd > 4: + packet_len = int.from_bytes(data[0:4], PROTOCOL_BYTE_ORDER, signed=True) + 4 + if packet_len > len(data): buffer.release() - data.extend(bytearray(response_len + 4 - len(data))) - buffer = memoryview(data)[bytes_total_received:] + data.extend(bytearray(packet_len - len(data))) + buffer = memoryview(data)[total_rcvd:] continue - if bytes_total_received >= bytes_to_receive: + if 0 < packet_len <= total_rcvd: buffer.release() break - buffer = buffer[bytes_received:] + buffer = buffer[bytes_rcvd:] return data diff --git a/tests/common/test_sync_socket.py b/tests/common/test_sync_socket.py new file mode 100644 index 0000000..cd41809 --- /dev/null +++ b/tests/common/test_sync_socket.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import secrets +import socket +import unittest.mock as mock + +import pytest + +from pyignite import Client +from tests.util import get_or_create_cache + +old_recv_into = socket.socket.recv_into + + +def patched_recv_into_factory(buf_len): + def patched_recv_into(self, buffer, nbytes, **kwargs): + return old_recv_into(self, buffer, min(nbytes, buf_len) if buf_len else nbytes, **kwargs) + return patched_recv_into + + +@pytest.mark.parametrize('buf_len', [0, 1, 4, 16, 32, 64, 128, 256, 512, 1024]) +def test_get_large_value(buf_len): + with mock.patch.object(socket.socket, 'recv_into', new=patched_recv_into_factory(buf_len)): + c = Client() + with c.connect("127.0.0.1", 10801): + with get_or_create_cache(c, 'test') as cache: + value = secrets.token_hex((1 << 16) + 1) + cache.put(1, value) + assert value == cache.get(1)