diff --git a/README.md b/README.md index ea855e28..cacf77d5 100644 --- a/README.md +++ b/README.md @@ -99,3 +99,9 @@ To lint, install requirements (included in the previous step) and run ```bash make lint ``` + +## Acknowledgmnts + +We would like to thank the following people for their contributions: + +- [@aadamson](https://github.com/aadamson) for their contributions in supporting custom `requests.Session` objects [#170](https://github.com/sigopt/sigopt-python/pull/170) diff --git a/requirements-dev.txt b/requirements-dev.txt index 14c6464e..f37a8f83 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,4 @@ # For continuous integration and development -mock==1.0.1 +mock>=3.0.5 pytest==2.8.7 twine==1.9.1 diff --git a/sigopt/interface.py b/sigopt/interface.py index 23fc5bee..33dd9edb 100644 --- a/sigopt/interface.py +++ b/sigopt/interface.py @@ -274,7 +274,7 @@ class Connection(object): Client-facing interface for creating Connections. Shouldn't be changed without a major version change. """ - def __init__(self, client_token=None, user_agent=None): + def __init__(self, client_token=None, user_agent=None, session=None): client_token = client_token or os.environ.get('SIGOPT_API_TOKEN') api_url = os.environ.get('SIGOPT_API_URL') or DEFAULT_API_URL if not client_token: @@ -289,6 +289,7 @@ def __init__(self, client_token=None, user_agent=None): client_token, '', default_headers, + session=session, ) self.impl = ConnectionImpl(requestor, api_url=api_url) diff --git a/sigopt/requestor.py b/sigopt/requestor.py index 94fc0c68..70c24f22 100644 --- a/sigopt/requestor.py +++ b/sigopt/requestor.py @@ -12,10 +12,11 @@ def __init__( user, password, headers, - verify_ssl_certs=True, + verify_ssl_certs=None, proxies=None, timeout=DEFAULT_HTTP_TIMEOUT, client_ssl_certs=None, + session=None, ): self._set_auth(user, password) self.default_headers = headers or {} @@ -23,6 +24,7 @@ def __init__( self.proxies = proxies self.timeout = timeout self.client_ssl_certs = client_ssl_certs + self.session = session def _set_auth(self, username, password): if username is not None: @@ -48,7 +50,8 @@ def delete(self, url, params=None, json=None, headers=None): def request(self, method, url, params=None, json=None, headers=None): headers = self._with_default_headers(headers) try: - response = requests.request( + caller = (self.session or requests) + response = caller.request( method=method, url=url, params=params, diff --git a/test/test_interface.py b/test/test_interface.py index 3775b513..9ee82d20 100644 --- a/test/test_interface.py +++ b/test/test_interface.py @@ -10,12 +10,26 @@ class TestInterface(object): def test_create(self): conn = Connection(client_token='client_token') assert conn.impl.api_url == 'https://api.sigopt.com' - assert conn.impl.requestor.verify_ssl_certs is True + assert conn.impl.requestor.verify_ssl_certs is None + assert conn.impl.requestor.session is None assert conn.impl.requestor.proxies is None assert conn.impl.requestor.timeout == DEFAULT_HTTP_TIMEOUT assert isinstance(conn.clients, ApiResource) assert isinstance(conn.experiments, ApiResource) + def test_create_uses_session_if_provided(self): + session = mock.Mock() + conn = Connection(client_token='client_token', session=session) + assert conn.impl.requestor.session is session + + response = mock.Mock() + session.request.return_value = response + response.status_code = 200 + response.text = '{}' + session.request.assert_not_called() + conn.experiments().fetch() + session.request.assert_called_once() + def test_environment_variable(self): with mock.patch.dict(os.environ, {'SIGOPT_API_TOKEN': 'client_token'}): Connection()