From a71a5972b73bdbb337b0603b4643657f8d774d65 Mon Sep 17 00:00:00 2001 From: Nikolay Gribanov Date: Tue, 9 Feb 2021 17:38:45 +0300 Subject: [PATCH] HH-123895 add total_timeout for long requests --- consul/aio.py | 9 ++++---- consul/base.py | 23 +++++++++++++++++--- consul/std.py | 2 +- consul/tornado.py | 5 +++-- consul/twisted.py | 5 +++-- tests/test_base.py | 14 ++++++------ tests/test_std.py | 48 +++++++++++++++++++++++++++++++++++++++-- tests/test_std_token.py | 2 +- 8 files changed, 86 insertions(+), 22 deletions(-) diff --git a/consul/aio.py b/consul/aio.py index e782f62..b22d2bc 100644 --- a/consul/aio.py +++ b/consul/aio.py @@ -5,6 +5,7 @@ import warnings import aiohttp +from aiohttp import ClientTimeout from consul import base @@ -20,10 +21,10 @@ def __init__(self, *args, loop=None, **kwargs): self._session = None self._loop = loop or asyncio.get_event_loop() - async def _request(self, callback, method, uri, data=None, headers=None): + async def _request(self, callback, method, uri, data=None, headers=None, total_timeout=None): connector = aiohttp.TCPConnector(loop=self._loop, verify_ssl=self.verify) - async with aiohttp.ClientSession(connector=connector) as session: + async with aiohttp.ClientSession(connector=connector, timeout=ClientTimeout(total=total_timeout)) as session: self._session = session resp = await session.request(method=method, url=uri, @@ -47,9 +48,9 @@ def __del__(self): ResourceWarning) asyncio.ensure_future(self.close()) - async def get(self, callback, path, params=None, headers=None): + async def get(self, callback, path, params=None, headers=None, total_timeout=None): uri = self.uri(path, params) - return await self._request(callback, 'GET', uri, headers=headers) + return await self._request(callback, 'GET', uri, headers=headers, total_timeout=total_timeout) async def put(self, callback, path, params=None, data='', headers=None): uri = self.uri(path, params) diff --git a/consul/base.py b/consul/base.py index 2b412be..7159fb6 100755 --- a/consul/base.py +++ b/consul/base.py @@ -4,6 +4,7 @@ import json import logging import os +import re import warnings import six @@ -316,7 +317,7 @@ def uri(self, path, params=None): return uri @abc.abstractmethod - def get(self, callback, path, params=None, headers=None): + def get(self, callback, path, params=None, headers=None, total_timeout=None): raise NotImplementedError @abc.abstractmethod @@ -2955,7 +2956,8 @@ def get( consistency=None, keys=False, separator=None, - dc=None): + dc=None, + total_timeout=None): """ Returns a tuple of (*index*, *value[s]*) @@ -3002,6 +3004,10 @@ def get( if index: params.append(('index', index)) if wait: + assert total_timeout, \ + 'total_timeout should be setted' + assert not self._convert_wait_to_seconds(wait) >= total_timeout, \ + f'wait: {wait} should be less than total_timeout: {total_timeout}s' params.append(('wait', wait)) if recurse: params.append(('recurse', '1')) @@ -3030,7 +3036,18 @@ def get( CB.json(index=True, decode=decode, one=one, map=lambda x: x if x else None), path='/v1/kv/%s' % key, - params=params, headers=headers) + params=params, headers=headers, total_timeout=total_timeout) + + def _convert_wait_to_seconds(self, wait): + unit_to_seconds_multiplier = { + 'ms': 0.001, + 's': 1, + 'm': 60 + } + wait_digit = int(re.search(r'\d+', wait).group()) + multiplier = unit_to_seconds_multiplier[re.search(r'ms|s|m', wait).group()] + + return wait_digit * multiplier def put( self, diff --git a/consul/std.py b/consul/std.py index 224b6fe..feb54ea 100644 --- a/consul/std.py +++ b/consul/std.py @@ -19,7 +19,7 @@ def response(response): response.text, response.content) - def get(self, callback, path, params=None, headers=None): + def get(self, callback, path, params=None, headers=None, total_timeout=None): uri = self.uri(path, params) return callback(self.response( self.session.get(uri, diff --git a/consul/tornado.py b/consul/tornado.py index 8294885..d90e610 100644 --- a/consul/tornado.py +++ b/consul/tornado.py @@ -31,12 +31,13 @@ def _request(self, callback, request): response = e.response raise gen.Return(callback(self.response(response))) - def get(self, callback, path, params=None, headers=None): + def get(self, callback, path, params=None, headers=None, total_timeout=None): uri = self.uri(path, params) request = httpclient.HTTPRequest(uri, method='GET', validate_cert=self.verify, - headers=headers) + headers=headers, + connect_timeout=total_timeout) return self._request(callback, request) def put(self, callback, path, params=None, data='', headers=None): diff --git a/consul/twisted.py b/consul/twisted.py index 900ea75..02c37e3 100644 --- a/consul/twisted.py +++ b/consul/twisted.py @@ -100,13 +100,14 @@ def request(self, callback, method, url, **kwargs): 'Request incomplete: {} {}'.format(method.upper(), url)) @inlineCallbacks - def get(self, callback, path, params=None, headers=None): + def get(self, callback, path, params=None, headers=None, total_timeout=None): uri = self.uri(path, params) response = yield self.request(callback, 'get', uri, params=params, - headers=headers) + headers=headers, + total_timeout=total_timeout) returnValue(response) @inlineCallbacks diff --git a/tests/test_base.py b/tests/test_base.py index a80227f..5d815dc 100755 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -11,7 +11,7 @@ Response = consul.base.Response Request = collections.namedtuple( - 'Request', ['method', 'path', 'params', 'data', 'headers']) + 'Request', ['method', 'path', 'params', 'data', 'headers', 'total_timeout']) class HTTPClient(object): @@ -19,14 +19,14 @@ def __init__(self, host=None, port=None, scheme=None, verify=True, cert=None, timeout=None, headers=None): pass - def get(self, callback, path, params=None, headers=None): - return Request('get', path, params, None, headers) + def get(self, callback, path, params=None, headers=None, total_timeout=None): + return Request('get', path, params, None, headers, total_timeout) - def put(self, callback, path, params=None, data='', headers=None): - return Request('put', path, params, data, headers) + def put(self, callback, path, params=None, data='', headers=None, total_timeout=None): + return Request('put', path, params, data, headers, total_timeout) - def delete(self, callback, path, params=None, headers=None): - return Request('delete', path, params, None, headers) + def delete(self, callback, path, params=None, headers=None, total_timeout=None): + return Request('delete', path, params, None, headers, total_timeout) class Consul(consul.base.Consul): diff --git a/tests/test_std.py b/tests/test_std.py index 68712fa..7475b99 100644 --- a/tests/test_std.py +++ b/tests/test_std.py @@ -31,13 +31,57 @@ def test_kv(self, consul_port): index, data = c.kv.get('foo') assert data['Value'] == six.b('bar') - def test_kv_wait(self, consul_port): + def test_kv_wait_ms(self, consul_port): c = consul.Consul(port=consul_port) assert c.kv.put('foo', 'bar') is True index, data = c.kv.get('foo') - check, data = c.kv.get('foo', index=index, wait='20ms') + check, data = c.kv.get('foo', index=index, wait='20ms', total_timeout=30) assert index == check + def test_kv_wait_s(self, consul_port): + c = consul.Consul(port=consul_port) + assert c.kv.put('foo', 'bar') is True + index, data = c.kv.get('foo') + check, data = c.kv.get('foo', index=index, wait='20s', total_timeout=30) + assert index == check + + def test_kv_wait_m(self, consul_port): + c = consul.Consul(port=consul_port) + assert c.kv.put('foo', 'bar') is True + index, data = c.kv.get('foo') + check, data = c.kv.get('foo', index=index, wait='1m', total_timeout=61) + assert index == check + + def test_kv_wait_more_timeout_ms(self, consul_port): + c = consul.Consul(port=consul_port) + assert c.kv.put('foo', 'bar') is True + index, data = c.kv.get('foo') + pytest.raises( + AssertionError, + c.kv.get, + 'foo', index=index, wait='20ms', total_timeout=0 + ) + + def test_kv_wait_more_timeout_s(self, consul_port): + c = consul.Consul(port=consul_port) + assert c.kv.put('foo', 'bar') is True + index, data = c.kv.get('foo') + pytest.raises( + AssertionError, + c.kv.get, + 'foo', index=index, wait='20s', total_timeout=19 + ) + + def test_kv_wait_more_timeout_m(self, consul_port): + c = consul.Consul(port=consul_port) + assert c.kv.put('foo', 'bar') is True + index, data = c.kv.get('foo') + pytest.raises( + AssertionError, + c.kv.get, + 'foo', index=index, wait='1m', total_timeout=59 + ) + def test_kv_encoding(self, consul_port): c = consul.Consul(port=consul_port) diff --git a/tests/test_std_token.py b/tests/test_std_token.py index 86b60c9..b5852ad 100644 --- a/tests/test_std_token.py +++ b/tests/test_std_token.py @@ -24,7 +24,7 @@ def test_kv_wait(self, acl_consul): c = consul.Consul(port=acl_consul.port, token=acl_consul.token) assert c.kv.put('foo', 'bar') is True index, data = c.kv.get('foo') - check, data = c.kv.get('foo', index=index, wait='20ms') + check, data = c.kv.get('foo', index=index, wait='20ms', total_timeout=30) assert index == check def test_kv_encoding(self, acl_consul):