diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2fee9775..7fd828a4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -12,10 +12,10 @@ We welcome all kinds of contributions: ## Getting started -If you have a specific contribution in mind, be sure to check the -[issues](https://github.com/graphql-python/gql/issues) -and [pull requests](https://github.com/graphql-python/gql/pulls) -in progress - someone could already be working on something similar +If you have a specific contribution in mind, be sure to check the +[issues](https://github.com/graphql-python/gql/issues) +and [pull requests](https://github.com/graphql-python/gql/pulls) +in progress - someone could already be working on something similar and you can help out. ## Project setup @@ -31,10 +31,10 @@ virtualenv gql-dev Activate the virtualenv and install dependencies by running: ```console -python pip install -e[dev] +python pip install -e.[dev] ``` -If you are using Linux or MacOS, you can make use of Makefile command +If you are using Linux or MacOS, you can make use of Makefile command `make dev-setup`, which is a shortcut for the above python command. ### Development on Conda @@ -55,7 +55,7 @@ pip install -e.[dev] And you ready to start development! - + ## Running tests @@ -65,7 +65,7 @@ After developing, the full test suite can be evaluated by running: pytest tests --cov=gql -vv ``` -If you are using Linux or MacOS, you can make use of Makefile command +If you are using Linux or MacOS, you can make use of Makefile command `make tests`, which is a shortcut for the above python command. You can also test on several python environments by using tox. @@ -77,8 +77,8 @@ Install tox: pip install tox ``` -Run `tox` on your virtualenv (do not forget to activate it!) -and that's it! +Run `tox` on your virtualenv (do not forget to activate it!) +and that's it! ### Running tox on Conda @@ -93,5 +93,5 @@ This install tox underneath so no need to install it before. Then uncomment the `requires = tox-conda` line on `tox.ini` file. -Run `tox` and you will see all the environments being created -and all passing tests. :rocket: \ No newline at end of file +Run `tox` and you will see all the environments being created +and all passing tests. :rocket: diff --git a/MANIFEST.in b/MANIFEST.in index 8ccdab11..5dd78801 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -13,7 +13,7 @@ include tox.ini include scripts/gql-cli recursive-include tests *.py *.yaml *.graphql -recursive-include tests_py36 *.py +recursive-include tests_py36 *.py *.cnf *.pem prune gql-checker diff --git a/Makefile b/Makefile index 6e7b9b75..355a7f22 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,13 @@ +.PHONY: clean tests + dev-setup: python pip install -e ".[test]" tests: - pytest tests --cov=gql -vv + pytest tests tests_py36 --cov=gql --cov-report=term-missing -vv + +all_tests: + pytest tests tests_py36 --cov=gql --cov-report=term-missing --run-online -vv clean: find . -name "*.pyc" -delete diff --git a/README.md b/README.md index 1f9c7fab..ae455df6 100644 --- a/README.md +++ b/README.md @@ -67,10 +67,6 @@ from gql.transport.requests import RequestsHTTPTransport sample_transport=RequestsHTTPTransport( url='https://countries.trevorblades.com/', - use_json=True, - headers={ - "Content-type": "application/json", - }, verify=False, retries=3, ) @@ -215,7 +211,6 @@ async def main(): sample_transport = WebsocketsTransport( url='wss://countries.trevorblades.com/graphql', - ssl=True, headers={'Authorization': 'token'} ) @@ -262,8 +257,7 @@ import ssl sample_transport = WebsocketsTransport( url='wss://SERVER_URL:SERVER_PORT/graphql', - headers={'Authorization': 'token'}, - ssl=True + headers={'Authorization': 'token'} ) ``` @@ -298,8 +292,7 @@ There are two ways to send authentication tokens with websockets depending on th ```python sample_transport = WebsocketsTransport( url='wss://SERVER_URL:SERVER_PORT/graphql', - headers={'Authorization': 'token'}, - ssl=True + headers={'Authorization': 'token'} ) ``` @@ -308,8 +301,7 @@ sample_transport = WebsocketsTransport( ```python sample_transport = WebsocketsTransport( url='wss://SERVER_URL:SERVER_PORT/graphql', - init_payload={'Authorization': 'token'}, - ssl=True + init_payload={'Authorization': 'token'} ) ``` diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index e24c4045..402666ef 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -35,7 +35,7 @@ def __init__( auth: Optional[BasicAuth] = None, ssl: Union[SSLContext, bool, Fingerprint] = False, timeout: Optional[int] = None, - **kwargs, + client_session_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given aiohttp parameters. @@ -44,7 +44,7 @@ def __init__( :param cookies: Dict of HTTP cookies. :param auth: BasicAuth object to enable Basic HTTP auth if needed :param ssl: ssl_context of the connection. Use ssl=False to disable encryption - :param kwargs: Other parameters forwarded to aiohttp.ClientSession + :param client_session_args: Dict of extra parameters passed to aiohttp.ClientSession """ self.url: str = url self.headers: Optional[LooseHeaders] = headers @@ -52,7 +52,7 @@ def __init__( self.auth: Optional[BasicAuth] = auth self.ssl: Union[SSLContext, bool, Fingerprint] = ssl self.timeout: Optional[int] = timeout - self.kwargs = kwargs + self.client_session_args = client_session_args self.session: Optional[aiohttp.ClientSession] = None @@ -78,7 +78,7 @@ async def connect(self) -> None: ) # Adding custom parameters passed from init - client_session_args.update(self.kwargs) + client_session_args.update(self.client_session_args) self.session = aiohttp.ClientSession(**client_session_args) @@ -95,7 +95,7 @@ async def execute( document: Document, variable_values: Optional[Dict[str, str]] = None, operation_name: Optional[str] = None, - **kwargs, + extra_args: Dict[str, Any] = {}, ) -> ExecutionResult: """Execute the provided document AST against the configured remote server. This uses the aiohttp library to perform a HTTP POST request asynchronously to the remote server. @@ -114,8 +114,8 @@ async def execute( "json": payload, } - # Pass kwargs to aiohttp post method - post_args.update(kwargs) + # Pass post_args to aiohttp post method + post_args.update(extra_args) if self.session is None: raise TransportClosed("Transport is not connected") diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 3dea3bad..9d0aaaac 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -97,6 +97,7 @@ def __init__( connect_timeout: int = 10, close_timeout: int = 10, ack_timeout: int = 10, + connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given request parameters. @@ -107,6 +108,7 @@ def __init__( :param connect_timeout: Timeout in seconds for the establishment of the websocket connection. :param close_timeout: Timeout in seconds for the close. :param ack_timeout: Timeout in seconds to wait for the connection_ack message from the server. + :param connect_args: Other parameters forwarded to websockets.connect """ self.url: str = url self.ssl: Union[SSLContext, bool] = ssl @@ -117,6 +119,8 @@ def __init__( self.close_timeout: int = close_timeout self.ack_timeout: int = ack_timeout + self.connect_args = connect_args + self.websocket: Optional[WebSocketClientProtocol] = None self.next_query_id: int = 1 self.listeners: Dict[int, ListenerQueue] = {} @@ -460,16 +464,27 @@ async def connect(self) -> None: if self.websocket is None: + # If the ssl parameter is not provided, generate the ssl value depending on the url + ssl: Optional[Union[SSLContext, bool]] + if self.ssl: + ssl = self.ssl + else: + ssl = True if self.url.startswith("wss") else None + + # Set default arguments used in the websockets.connect call + connect_args: Dict[str, Any] = { + "ssl": ssl, + "extra_headers": self.headers, + "subprotocols": [GRAPHQLWS_SUBPROTOCOL], + } + + # Adding custom parameters passed from init + connect_args.update(self.connect_args) + # Connection to the specified url # Generate a TimeoutError if taking more than connect_timeout seconds self.websocket = await asyncio.wait_for( - websockets.connect( - self.url, - ssl=self.ssl if self.ssl else None, - extra_headers=self.headers, - subprotocols=[GRAPHQLWS_SUBPROTOCOL], - ), - self.connect_timeout, + websockets.connect(self.url, **connect_args,), self.connect_timeout, ) self.next_query_id = 1 diff --git a/tests_py36/conftest.py b/tests_py36/conftest.py index 07032e2c..ac9dfc56 100644 --- a/tests_py36/conftest.py +++ b/tests_py36/conftest.py @@ -2,6 +2,8 @@ import json import logging import os +import pathlib +import ssl import types import pytest @@ -79,12 +81,35 @@ class TestServer: Will allow us to test our client by simulating different correct and incorrect server responses """ + def __init__(self, with_ssl: bool = False): + self.with_ssl = with_ssl + async def start(self, handler): print("Starting server") + extra_serve_args = {} + + if self.with_ssl: + # This is a copy of certificate from websockets tests folder + # + # Generate TLS certificate with: + # $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ + # -out test_localhost.crt -keyout test_localhost.key + # $ cat test_localhost.key test_localhost.crt > test_localhost.pem + # $ rm test_localhost.key test_localhost.crt + self.testcert = bytes( + pathlib.Path(__file__).with_name("test_localhost.pem") + ) + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain(self.testcert) + + extra_serve_args["ssl"] = ssl_context + # Start a server with a random open port - self.start_server = websockets.server.serve(handler, "localhost", 0) + self.start_server = websockets.server.serve( + handler, "localhost", 0, **extra_serve_args + ) # Wait that the server is started self.server = await self.start_server @@ -137,13 +162,9 @@ async def wait_connection_terminate(ws): assert json_result["type"] == "connection_terminate" -@pytest.fixture -async def server(request): - """server is a fixture used to start a dummy server to test the client behaviour. - - It can take as argument either a handler function for the websocket server for complete control - OR an array of answers to be sent by the default server handler - """ +def get_server_handler(request): + """ Get the server handler provided from test or use the default + server handler if the test provides only an array of answers""" if isinstance(request.param, types.FunctionType): server_handler = request.param @@ -179,6 +200,42 @@ async def default_server_handler(ws, path): server_handler = default_server_handler + return server_handler + + +@pytest.fixture +async def ws_ssl_server(request): + """websockets server fixture using ssl + + It can take as argument either a handler function for the websocket server for complete control + OR an array of answers to be sent by the default server handler + """ + + server_handler = get_server_handler(request) + + try: + test_server = TestServer(with_ssl=True) + + # Starting the server with the fixture param as the handler function + await test_server.start(server_handler) + + yield test_server + except Exception as e: + print("Exception received in server fixture: " + str(e)) + finally: + await test_server.stop() + + +@pytest.fixture +async def server(request): + """server is a fixture used to start a dummy server to test the client behaviour. + + It can take as argument either a handler function for the websocket server for complete control + OR an array of answers to be sent by the default server handler + """ + + server_handler = get_server_handler(request) + try: test_server = TestServer() diff --git a/tests_py36/test_aiohttp.py b/tests_py36/test_aiohttp.py index 3e2ef5e2..213ccb4c 100644 --- a/tests_py36/test_aiohttp.py +++ b/tests_py36/test_aiohttp.py @@ -1,5 +1,5 @@ import pytest -from aiohttp import web +from aiohttp import DummyCookieJar, web from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport @@ -190,3 +190,34 @@ async def handler(request): with pytest.raises(TransportClosed): await sample_transport.execute(query) + + +@pytest.mark.asyncio +async def test_aiohttp_extra_args(event_loop, aiohttp_server): + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + # passing extra arguments to aiohttp.ClientSession + jar = DummyCookieJar() + sample_transport = AIOHTTPTransport( + url=url, timeout=10, client_session_args={"version": "1.1", "cookie_jar": jar} + ) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + # Passing extra arguments to the post method of aiohttp + result = await session.execute(query, extra_args={"allow_redirects": False}) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" diff --git a/tests_py36/test_localhost.cnf b/tests_py36/test_localhost.cnf new file mode 100644 index 00000000..6dc331ac --- /dev/null +++ b/tests_py36/test_localhost.cnf @@ -0,0 +1,26 @@ +[ req ] + +default_md = sha256 +encrypt_key = no + +prompt = no + +distinguished_name = dn +x509_extensions = ext + +[ dn ] + +C = "FR" +L = "Paris" +O = "Aymeric Augustin" +CN = "localhost" + +[ ext ] + +subjectAltName = @san + +[ san ] + +DNS.1 = localhost +IP.2 = 127.0.0.1 +IP.3 = ::1 diff --git a/tests_py36/test_localhost.pem b/tests_py36/test_localhost.pem new file mode 100644 index 00000000..b8a9ea9a --- /dev/null +++ b/tests_py36/test_localhost.pem @@ -0,0 +1,48 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCUgrQVkNbAWRlo +zZUj14Ufz7YEp2MXmvmhdlfOGLwjy+xPO98aJRv5/nYF2eWM3llcmLe8FbBSK+QF +To4su7ZVnc6qITOHqcSDUw06WarQUMs94bhHUvQp1u8+b2hNiMeGw6+QiBI6OJRO +iGpLRbkN6Uj3AKwi8SYVoLyMiztuwbNyGf8fF3DDpHZtBitGtMSBCMsQsfB465pl +2UoyBrWa2lsbLt3VvBZZvHqfEuPjpjjKN5USIXnaf0NizaR6ps3EyfftWy4i7zIQ +N5uTExvaPDyPn9nH3q/dkT99mSMSU1AvTTpX8PN7DlqE6wZMbQsBPRGW7GElQ+Ox +IKdKOLk5AgMBAAECggEAd3kqzQqnaTiEs4ZoC9yPUUc1pErQ8iWP27Ar9TZ67MVa +B2ggFJV0C0sFwbFI9WnPNCn77gj4vzJmD0riH+SnS/tXThDFtscBu7BtvNp0C4Bj +8RWMvXxjxuENuQnBPFbkRWtZ6wk8uK/Zx9AAyyt9M07Qjz1wPfAIdm/IH7zHBFMA +gsqjnkLh1r0FvjNEbLiuGqYU/GVxaZYd+xy+JU52IxjHUUL9yD0BPWb+Szar6AM2 +gUpmTX6+BcCZwwZ//DzCoWYZ9JbP8akn6edBeZyuMPqYgLzZkPyQ+hRW46VPPw89 +yg4LR9nzgQiBHlac0laB4NrWa+d9QRRLitl1O3gVAQKBgQDDkptxXu7w9Lpc+HeE +N/pJfpCzUuF7ZC4vatdoDzvfB5Ky6W88Poq+I7bB9m7StXdFAbDyUBxvisjTBMVA +OtYqpAk/rhX8MjSAtjoFe2nH+eEiQriuZmtA5CdKEXS4hNbc/HhEPWhk7Zh8OV5v +y7l4r6l4UHqaN9QyE0vlFdmcmQKBgQDCZZR/trJ2/g2OquaS+Zd2h/3NXw0NBq4z +4OBEWqNa/R35jdK6WlWJH7+tKOacr+xtswLpPeZHGwMdk64/erbYWBuJWAjpH72J +DM9+1H5fFHANWpWTNn94enQxwfzZRvdkxq4IWzGhesptYnHIzoAmaqC3lbn/e3u0 +Flng32hFoQKBgQCF3D4K3hib0lYQtnxPgmUMktWF+A+fflViXTWs4uhu4mcVkFNz +n7clJ5q6reryzAQjtmGfqRedfRex340HRn46V2aBMK2Znd9zzcZu5CbmGnFvGs3/ +iNiWZNNDjike9sV+IkxLIODoW/vH4xhxWrbLFSjg0ezoy5ew4qZK2abF2QKBgQC5 +M5efeQpbjTyTUERtf/aKCZOGZmkDoPq0GCjxVjzNQdqd1z0NJ2TYR/QP36idXIlu +FZ7PYZaS5aw5MGpQtfOe94n8dm++0et7t0WzunRO1yTNxCA+aSxWNquegAcJZa/q +RdKlyWPmSRqzzZdDzWCPuQQ3AyF5wkYfUy/7qjwoIQKBgB2v96BV7+lICviIKzzb +1o3A3VzAX5MGd98uLGjlK4qsBC+s7mk2eQztiNZgbA0W6fhQ5Dz3HcXJ5ppy8Okc +jeAktrNRzz15hvi/XkWdO+VMqiHW4l+sWYukjhCyod1oO1KGHq0LYYvv076syxGw +vRKLq7IJ4WIp1VtfaBlrIogq +-----END PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIIDTTCCAjWgAwIBAgIJAJ6VG2cQlsepMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV +BAYTAkZSMQ4wDAYDVQQHDAVQYXJpczEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3Rp +bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTE4MDUwNTE2NTc1NloYDzIwNjAwNTA0 +MTY1NzU2WjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM +EEF5bWVyaWMgQXVndXN0aW4xEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAJSCtBWQ1sBZGWjNlSPXhR/PtgSnYxea+aF2 +V84YvCPL7E873xolG/n+dgXZ5YzeWVyYt7wVsFIr5AVOjiy7tlWdzqohM4epxINT +DTpZqtBQyz3huEdS9CnW7z5vaE2Ix4bDr5CIEjo4lE6IaktFuQ3pSPcArCLxJhWg +vIyLO27Bs3IZ/x8XcMOkdm0GK0a0xIEIyxCx8HjrmmXZSjIGtZraWxsu3dW8Flm8 +ep8S4+OmOMo3lRIhedp/Q2LNpHqmzcTJ9+1bLiLvMhA3m5MTG9o8PI+f2cfer92R +P32ZIxJTUC9NOlfw83sOWoTrBkxtCwE9EZbsYSVD47Egp0o4uTkCAwEAAaMwMC4w +LAYDVR0RBCUwI4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMA0G +CSqGSIb3DQEBCwUAA4IBAQA0imKp/rflfbDCCx78NdsR5rt0jKem2t3YPGT6tbeU ++FQz62SEdeD2OHWxpvfPf+6h3iTXJbkakr2R4lP3z7GHUe61lt3So9VHAvgbtPTH +aB1gOdThA83o0fzQtnIv67jCvE9gwPQInViZLEcm2iQEZLj6AuSvBKmluTR7vNRj +8/f2R4LsDfCWGrzk2W+deGRvSow7irS88NQ8BW8S8otgMiBx4D2UlOmQwqr6X+/r +jYIDuMb6GDKRXtBUGDokfE94hjj9u2mrNRwt8y4tqu8ZNa//yLEQ0Ow2kP3QJPLY +941VZpwRi2v/+JvI7OBYlvbOTFwM8nAk79k+Dgviygd9 +-----END CERTIFICATE----- diff --git a/tests_py36/test_websocket_online.py b/tests_py36/test_websocket_online.py index 6f447a26..b9510a31 100644 --- a/tests_py36/test_websocket_online.py +++ b/tests_py36/test_websocket_online.py @@ -20,7 +20,7 @@ async def test_websocket_simple_query(): # Get Websockets transport sample_transport = WebsocketsTransport( - url="wss://countries.trevorblades.com/graphql", ssl=True + url="wss://countries.trevorblades.com/graphql" ) # Instanciate client diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py index f2c3c49b..3b4e50c8 100644 --- a/tests_py36/test_websocket_query.py +++ b/tests_py36/test_websocket_query.py @@ -1,5 +1,6 @@ import asyncio import json +import ssl from typing import Dict import pytest @@ -70,6 +71,44 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): assert sample_transport.websocket is None +@pytest.mark.asyncio +@pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) +async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): + + server = ws_ssl_server + + url = "wss://" + server.hostname + ":" + str(server.port) + "/graphql" + print(f"url = {url}") + + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.load_verify_locations(ws_ssl_server.testcert) + + sample_transport = WebsocketsTransport(url=url, ssl=ssl_context) + + async with Client(transport=sample_transport) as session: + + assert isinstance( + sample_transport.websocket, websockets.client.WebSocketClientProtocol + ) + + query1 = gql(query1_str) + + result = await session.execute(query1) + + print("Client received: " + str(result)) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Check client is disconnect here + assert sample_transport.websocket is None + + @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) @@ -419,3 +458,18 @@ def test_websocket_execute_sync(server): # Check client is disconnect here assert sample_transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websocket_add_extra_parameters_to_connect(event_loop, server): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + + # Increase max payload size to avoid websockets.exceptions.PayloadTooBig exceptions + sample_transport = WebsocketsTransport(url=url, connect_args={"max_size": 2 ** 21}) + + query = gql(query1_str) + + async with Client(transport=sample_transport) as session: + await session.execute(query)