diff --git a/gql/client.py b/gql/client.py index 99cd6e46..a4e80dcb 100644 --- a/gql/client.py +++ b/gql/client.py @@ -184,6 +184,24 @@ def _build_schema_from_introspection( self.introspection = cast(IntrospectionQuery, execution_result.data) self.schema = build_client_schema(self.introspection) + @staticmethod + def _get_event_loop() -> asyncio.AbstractEventLoop: + """Get the current asyncio event loop. + + Or create a new event loop if there isn't one (in a new Thread). + """ + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop + @overload def execute_sync( self, @@ -358,6 +376,58 @@ async def execute_async( **kwargs, ) + @overload + async def execute_batch_async( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[False] = ..., + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover + + @overload + async def execute_batch_async( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[True], + **kwargs: Any, + ) -> List[ExecutionResult]: ... # pragma: no cover + + @overload + async def execute_batch_async( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool, + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... 
# pragma: no cover + + async def execute_batch_async( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool = False, + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + """:meta private:""" + async with self as session: + return await session.execute_batch( + requests, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, + ) + @overload def execute( self, @@ -430,17 +500,7 @@ def execute( """ if isinstance(self.transport, AsyncTransport): - # Get the current asyncio event loop - # Or create a new event loop if there isn't one (in a new Thread) - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = self._get_event_loop() assert not loop.is_running(), ( "Cannot run client.execute(query) if an asyncio loop is running." @@ -537,7 +597,24 @@ def execute_batch( """ if isinstance(self.transport, AsyncTransport): - raise NotImplementedError("Batching is not implemented for async yet.") + loop = self._get_event_loop() + + assert not loop.is_running(), ( + "Cannot run client.execute_batch(query) if an asyncio loop is running." + " Use 'await client.execute_batch_async(query)' instead." + ) + + data = loop.run_until_complete( + self.execute_batch_async( + requests, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, + ) + ) + + return data else: # Sync transports return self.execute_batch_sync( 
""" - # Get the current asyncio event loop - # Or create a new event loop if there isn't one (in a new Thread) - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = self._get_event_loop() + + assert not loop.is_running(), ( + "Cannot run client.subscribe(query) if an asyncio loop is running." + " Use 'await client.subscribe_async(query)' instead." + ) async_generator: Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] @@ -699,11 +771,6 @@ def subscribe( **kwargs, ) - assert not loop.is_running(), ( - "Cannot run client.subscribe(query) if an asyncio loop is running." - " Use 'await client.subscribe_async(query)' instead." - ) - try: while True: # Note: we need to create a task here in order to be able to close @@ -1626,6 +1693,149 @@ async def execute( return result.data + async def _execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + validate_document: Optional[bool] = True, + **kwargs: Any, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch, using + the async transport, returning a list of ExecutionResult objects. + + :param requests: List of requests that will be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. + :param parse_result: Whether gql will deserialize the result. + By default use the parse_results argument of the client. + :param validate_document: Whether we still need to validate the document. 
+ + The extra arguments are passed to the transport execute_batch method.""" + + # Validate document + if self.client.schema: + + if validate_document: + for req in requests: + self.client.validate(req.document) + + # Parse variable values for custom scalars if requested + if serialize_variables or ( + serialize_variables is None and self.client.serialize_variables + ): + requests = [ + ( + req.serialize_variable_values(self.client.schema) + if req.variable_values is not None + else req + ) + for req in requests + ] + + results = await self.transport.execute_batch(requests, **kwargs) + + # Unserialize the result if requested + if self.client.schema: + if parse_result or (parse_result is None and self.client.parse_results): + for req, result in zip(requests, results): + result.data = parse_result_fn( + self.client.schema, + req.document, + result.data, + operation_name=req.operation_name, + ) + + return results + + @overload + async def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[False] = ..., + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover + + @overload + async def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[True], + **kwargs: Any, + ) -> List[ExecutionResult]: ... # pragma: no cover + + @overload + async def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool, + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... 
# pragma: no cover + + async def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool = False, + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + """Execute multiple GraphQL requests in a batch, using + the async transport. This method sends the requests to the server all at once. + + Raises a TransportQueryError if an error has been returned in any + ExecutionResult. + + :param requests: List of requests that will be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. + :param parse_result: Whether gql will deserialize the result. + By default use the parse_results argument of the client. + :param get_execution_result: return the full ExecutionResult instance instead of + only the "data" field. Necessary if you want to get the "extensions" field. + + The extra arguments are passed to the transport execute method.""" + + # Validate and execute on the transport + results = await self._execute_batch( + requests, + serialize_variables=serialize_variables, + parse_result=parse_result, + **kwargs, + ) + + for result in results: + # Raise an error if an error is returned in the ExecutionResult object + if result.errors: + raise TransportQueryError( + str_first_element(result.errors), + errors=result.errors, + data=result.data, + extensions=result.extensions, + ) + + assert ( + result.data is not None + ), "Transport returned an ExecutionResult without data or errors" + + if get_execution_result: + return results + + return cast(List[Dict[str, Any]], [result.data for result in results]) + async def fetch_schema(self) -> None: """Fetch the GraphQL schema explicitly using introspection. 
diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index b2633abb..9535eef4 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -8,7 +8,7 @@ AsyncGenerator, Callable, Dict, - NoReturn, + List, Optional, Tuple, Type, @@ -23,9 +23,11 @@ from graphql import DocumentNode, ExecutionResult, print_ast from multidict import CIMultiDictProxy +from ..graphql_request import GraphQLRequest from .appsync_auth import AppSyncAuthentication from .async_transport import AsyncTransport from .common.aiohttp_closed_event import create_aiohttp_closed_event +from .common.batch import get_batch_execution_result_list from .exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -162,172 +164,274 @@ async def close(self) -> None: self.session = None - async def execute( + def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: + query_str = print_ast(req.document) + payload: Dict[str, Any] = {"query": query_str} + + if req.operation_name: + payload["operationName"] = req.operation_name + + if req.variable_values: + payload["variables"] = req.variable_values + + return payload + + def _prepare_batch_request( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + reqs: List[GraphQLRequest], extra_args: Optional[Dict[str, Any]] = None, - upload_files: bool = False, - ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server - using the current session. - This uses the aiohttp library to perform a HTTP POST request asynchronously - to the remote server. + ) -> Dict[str, Any]: - Don't call this coroutine directly on the transport, instead use - :code:`execute` on a client or a session. 
+ payload = [self._build_payload(req) for req in reqs] - :param document: the parsed GraphQL request - :param variable_values: An optional Dict of variable values - :param operation_name: An optional Operation name for the request - :param extra_args: additional arguments to send to the aiohttp post method - :param upload_files: Set to True if you want to put files in the variable values - :returns: an ExecutionResult object. - """ + post_args = {"json": payload} - query_str = print_ast(document) + # Log the payload + if log.isEnabledFor(logging.INFO): + log.info(">>> %s", self.json_serialize(post_args["json"])) - payload: Dict[str, Any] = { - "query": query_str, - } + # Pass post_args to aiohttp post method + if extra_args: + post_args.update(extra_args) + + return post_args + + def _prepare_request( + self, + req: GraphQLRequest, + extra_args: Optional[Dict[str, Any]] = None, + upload_files: bool = False, + ) -> Dict[str, Any]: - if operation_name: - payload["operationName"] = operation_name + payload = self._build_payload(req) if upload_files: + post_args = self._prepare_file_uploads(req, payload) + else: + post_args = {"json": payload} - # If the upload_files flag is set, then we need variable_values - assert variable_values is not None + # Log the payload + if log.isEnabledFor(logging.INFO): + log.info(">>> %s", self.json_serialize(payload)) + + # Pass post_args to aiohttp post method + if extra_args: + post_args.update(extra_args) - # If we upload files, we will extract the files present in the - # variable_values dict and replace them by null values - nulled_variable_values, files = extract_files( - variables=variable_values, - file_classes=self.file_classes, + # Add headers for AppSync if requested + if isinstance(self.auth, AppSyncAuthentication): + post_args["headers"] = self.auth.get_headers( + self.json_serialize(payload), + {"content-type": "application/json"}, ) - # Opening the files using the FileVar parameters - open_files(list(files.values()), 
transport_supports_streaming=True) - self.files = files + return post_args + + def _prepare_file_uploads( + self, req: GraphQLRequest, payload: Dict[str, Any] + ) -> Dict[str, Any]: + + # If the upload_files flag is set, then we need variable_values + variable_values = req.variable_values + assert variable_values is not None - # Save the nulled variable values in the payload - payload["variables"] = nulled_variable_values + # If we upload files, we will extract the files present in the + # variable_values dict and replace them by null values + nulled_variable_values, files = extract_files( + variables=variable_values, + file_classes=self.file_classes, + ) - # Prepare aiohttp to send multipart-encoded data - data = aiohttp.FormData() + # Opening the files using the FileVar parameters + open_files(list(files.values()), transport_supports_streaming=True) + self.files = files - # Generate the file map - # path is nested in a list because the spec allows multiple pointers - # to the same file. But we don't support that. - # Will generate something like {"0": ["variables.file"]} - file_map = {str(i): [path] for i, path in enumerate(files)} + # Save the nulled variable values in the payload + payload["variables"] = nulled_variable_values - # Enumerate the file streams - # Will generate something like {'0': FileVar object} - file_vars = {str(i): files[path] for i, path in enumerate(files)} + # Prepare aiohttp to send multipart-encoded data + data = aiohttp.FormData() + + # Generate the file map + # path is nested in a list because the spec allows multiple pointers + # to the same file. But we don't support that. 
+ # Will generate something like {"0": ["variables.file"]} + file_map = {str(i): [path] for i, path in enumerate(files)} + + # Enumerate the file streams + # Will generate something like {'0': FileVar object} + file_vars = {str(i): files[path] for i, path in enumerate(files)} + + # Add the payload to the operations field + operations_str = self.json_serialize(payload) + log.debug("operations %s", operations_str) + data.add_field("operations", operations_str, content_type="application/json") + + # Add the file map field + file_map_str = self.json_serialize(file_map) + log.debug("file_map %s", file_map_str) + data.add_field("map", file_map_str, content_type="application/json") + + for k, file_var in file_vars.items(): + assert isinstance(file_var, FileVar) - # Add the payload to the operations field - operations_str = self.json_serialize(payload) - log.debug("operations %s", operations_str) data.add_field( - "operations", operations_str, content_type="application/json" + k, + file_var.f, + filename=file_var.filename, + content_type=file_var.content_type, ) - # Add the file map field - file_map_str = self.json_serialize(file_map) - log.debug("file_map %s", file_map_str) - data.add_field("map", file_map_str, content_type="application/json") + post_args: Dict[str, Any] = {"data": data} - for k, file_var in file_vars.items(): - assert isinstance(file_var, FileVar) + return post_args - data.add_field( - k, - file_var.f, - filename=file_var.filename, - content_type=file_var.content_type, - ) + async def raise_response_error( + self, + resp: aiohttp.ClientResponse, + reason: str, + ) -> None: + # We raise a TransportServerError if status code is 400 or higher + # We raise a TransportProtocolError in the other cases + + try: + # Raise ClientResponseError if response status is 400 or higher + resp.raise_for_status() + except ClientResponseError as e: + raise TransportServerError(str(e), e.status) from e - post_args: Dict[str, Any] = {"data": data} + result_text = await 
resp.text() + self._raise_invalid_result(result_text, reason) - else: - if variable_values: - payload["variables"] = variable_values + async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any: + + # Saving latest response headers in the transport + self.response_headers = response.headers + + try: + result = await response.json(loads=self.json_deserialize, content_type=None) if log.isEnabledFor(logging.INFO): - log.info(">>> %s", self.json_serialize(payload)) + result_text = await response.text() + log.info("<<< %s", result_text) - post_args = {"json": payload} + except Exception: + await self.raise_response_error(response, "Not a JSON answer") - # Pass post_args to aiohttp post method - if extra_args: - post_args.update(extra_args) + if result is None: + await self.raise_response_error(response, "Not a JSON answer") - # Add headers for AppSync if requested - if isinstance(self.auth, AppSyncAuthentication): - post_args["headers"] = self.auth.get_headers( - self.json_serialize(payload), - {"content-type": "application/json"}, + return result + + async def _prepare_result( + self, response: aiohttp.ClientResponse + ) -> ExecutionResult: + + result = await self._get_json_result(response) + + if "errors" not in result and "data" not in result: + await self.raise_response_error( + response, 'No "data" or "errors" keys in answer' ) - if self.session is None: - raise TransportClosed("Transport is not connected") + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) - try: - async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: + async def _prepare_batch_result( + self, + reqs: List[GraphQLRequest], + response: aiohttp.ClientResponse, + ) -> List[ExecutionResult]: - # Saving latest response headers in the transport - self.response_headers = resp.headers + answers = await self._get_json_result(response) - async def raise_response_error( - resp: 
aiohttp.ClientResponse, reason: str - ) -> NoReturn: - # We raise a TransportServerError if status code is 400 or higher - # We raise a TransportProtocolError in the other cases + return get_batch_execution_result_list(reqs, answers) - try: - # Raise ClientResponseError if response status is 400 or higher - resp.raise_for_status() - except ClientResponseError as e: - raise TransportServerError(str(e), e.status) from e + def _raise_invalid_result(self, result_text: str, reason: str) -> None: + raise TransportProtocolError( + f"Server did not return a valid GraphQL result: " + f"{reason}: " + f"{result_text}" + ) - result_text = await resp.text() - raise TransportProtocolError( - f"Server did not return a GraphQL result: " - f"{reason}: " - f"{result_text}" - ) + async def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + extra_args: Optional[Dict[str, Any]] = None, + upload_files: bool = False, + ) -> ExecutionResult: + """Execute the provided document AST against the configured remote server + using the current session. + This uses the aiohttp library to perform a HTTP POST request asynchronously + to the remote server. - try: - result = await resp.json( - loads=self.json_deserialize, content_type=None - ) + Don't call this coroutine directly on the transport, instead use + :code:`execute` on a client or a session. - if log.isEnabledFor(logging.INFO): - result_text = await resp.text() - log.info("<<< %s", result_text) + :param document: the parsed GraphQL request + :param variable_values: An optional Dict of variable values + :param operation_name: An optional Operation name for the request + :param extra_args: additional arguments to send to the aiohttp post method + :param upload_files: Set to True if you want to put files in the variable values + :returns: an ExecutionResult object. 
+ """ - except Exception: - await raise_response_error(resp, "Not a JSON answer") + req = GraphQLRequest( + document=document, + variable_values=variable_values, + operation_name=operation_name, + ) - if result is None: - await raise_response_error(resp, "Not a JSON answer") + post_args = self._prepare_request( + req, + extra_args, + upload_files, + ) - if "errors" not in result and "data" not in result: - await raise_response_error( - resp, 'No "data" or "errors" keys in answer' - ) + if self.session is None: + raise TransportClosed("Transport is not connected") - return ExecutionResult( - errors=result.get("errors"), - data=result.get("data"), - extensions=result.get("extensions"), - ) + try: + async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: + return await self._prepare_result(resp) finally: if upload_files: close_files(list(self.files.values())) + async def execute_batch( + self, + reqs: List[GraphQLRequest], + extra_args: Optional[Dict[str, Any]] = None, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. + + Don't call this coroutine directly on the transport, instead use + :code:`execute_batch` on a client or a session. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. + :param extra_args: additional arguments to send to the aiohttp post method + :return: A list of results of execution. + For every result `data` is the result of executing the query, + `errors` is null if no errors occurred, and is a non-empty array + if an error occurred. 
+ """ + + post_args = self._prepare_batch_request( + reqs, + extra_args, + ) + + if self.session is None: + raise TransportClosed("Transport is not connected") + + async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: + return await self._prepare_batch_result(reqs, resp) + def subscribe( self, document: DocumentNode, diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py index 4cecc9f9..243746e6 100644 --- a/gql/transport/async_transport.py +++ b/gql/transport/async_transport.py @@ -1,8 +1,10 @@ import abc -from typing import Any, AsyncGenerator, Dict, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional from graphql import DocumentNode, ExecutionResult +from ..graphql_request import GraphQLRequest + class AsyncTransport(abc.ABC): @abc.abstractmethod @@ -32,6 +34,23 @@ async def execute( "Any AsyncTransport subclass must implement execute method" ) # pragma: no cover + async def execute_batch( + self, + reqs: List[GraphQLRequest], + *args: Any, + **kwargs: Any, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. + + Execute the provided requests for either a remote or local GraphQL Schema. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. 
+ :return: a list of ExecutionResult objects + """ + raise NotImplementedError( + "This Transport has not implemented the execute_batch method" + ) # pragma: no cover + @abc.abstractmethod def subscribe( self, diff --git a/gql/transport/common/batch.py b/gql/transport/common/batch.py new file mode 100644 index 00000000..4feadee6 --- /dev/null +++ b/gql/transport/common/batch.py @@ -0,0 +1,76 @@ +from typing import ( + Any, + Dict, + List, +) + +from graphql import ExecutionResult + +from ...graphql_request import GraphQLRequest +from ..exceptions import ( + TransportProtocolError, +) + + +def _raise_protocol_error(result_text: str, reason: str) -> None: + raise TransportProtocolError( + f"Server did not return a valid GraphQL result: " f"{reason}: " f"{result_text}" + ) + + +def _validate_answer_is_a_list(results: Any) -> None: + if not isinstance(results, list): + _raise_protocol_error( + str(results), + "Answer is not a list", + ) + + +def _validate_data_and_errors_keys_in_answers(results: List[Dict[str, Any]]) -> None: + for result in results: + if "errors" not in result and "data" not in result: + _raise_protocol_error( + str(results), + 'No "data" or "errors" keys in answer', + ) + + +def _validate_every_answer_is_a_dict(results: List[Dict[str, Any]]) -> None: + for result in results: + if not isinstance(result, dict): + _raise_protocol_error(str(results), "Not every answer is dict") + + +def _validate_num_of_answers_same_as_requests( + reqs: List[GraphQLRequest], + results: List[Dict[str, Any]], +) -> None: + if len(reqs) != len(results): + _raise_protocol_error( + str(results), + ( + "Invalid number of answers: " + f"{len(results)} answers received for {len(reqs)} requests" + ), + ) + + +def _answer_to_execution_result(result: Dict[str, Any]) -> ExecutionResult: + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) + + +def get_batch_execution_result_list( + reqs: List[GraphQLRequest], 
+ answers: List, +) -> List[ExecutionResult]: + + _validate_answer_is_a_list(answers) + _validate_num_of_answers_same_as_requests(reqs, answers) + _validate_every_answer_is_a_dict(answers) + _validate_data_and_errors_keys_in_answers(answers) + + return [_answer_to_execution_result(answer) for answer in answers] diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index eb15ac57..406c0523 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -17,7 +17,9 @@ import httpx from graphql import DocumentNode, ExecutionResult, print_ast +from ..graphql_request import GraphQLRequest from . import AsyncTransport, Transport +from .common.batch import get_batch_execution_result_list from .exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -55,32 +57,30 @@ def __init__( self.json_deserialize = json_deserialize self.kwargs = kwargs + def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: + query_str = print_ast(req.document) + payload: Dict[str, Any] = {"query": query_str} + + if req.operation_name: + payload["operationName"] = req.operation_name + + if req.variable_values: + payload["variables"] = req.variable_values + + return payload + def _prepare_request( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + req: GraphQLRequest, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> Dict[str, Any]: - query_str = print_ast(document) - - payload: Dict[str, Any] = { - "query": query_str, - } - if operation_name: - payload["operationName"] = operation_name + payload = self._build_payload(req) if upload_files: - # If the upload_files flag is set, then we need variable_values - assert variable_values is not None - - post_args = self._prepare_file_uploads(variable_values, payload) + post_args = self._prepare_file_uploads(req, payload) else: - if variable_values: - payload["variables"] = variable_values - post_args = {"json": payload} # Log 
the payload @@ -93,9 +93,37 @@ def _prepare_request( return post_args + def _prepare_batch_request( + self, + reqs: List[GraphQLRequest], + extra_args: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + + payload = [self._build_payload(req) for req in reqs] + + post_args = {"json": payload} + + # Log the payload + if log.isEnabledFor(logging.DEBUG): + log.debug(">>> %s", self.json_serialize(payload)) + + # Pass post_args to the httpx post method + if extra_args: + post_args.update(extra_args) + + return post_args + def _prepare_file_uploads( - self, variable_values: Dict[str, Any], payload: Dict[str, Any] + self, + request: GraphQLRequest, + payload: Dict[str, Any], ) -> Dict[str, Any]: + + variable_values = request.variable_values + + # If the upload_files flag is set, then we need variable_values + assert variable_values is not None + # If we upload files, we will extract the files present in the + # variable_values dict and replace them by null values nulled_variable_values, files = extract_files( @@ -143,8 +171,9 @@ def _prepare_file_uploads( return {"data": data, "files": file_streams} - def _prepare_result(self, response: httpx.Response) -> ExecutionResult: - # Save latest response headers in transport + def _get_json_result(self, response: httpx.Response) -> Any: + + # Saving latest response headers in the transport self.response_headers = response.headers if log.isEnabledFor(logging.DEBUG): @@ -152,10 +181,15 @@ def _prepare_result(self, response: httpx.Response) -> ExecutionResult: try: result: Dict[str, Any] = self.json_deserialize(response.content) - except Exception: self._raise_response_error(response, "Not a JSON answer") + return result + + def _prepare_result(self, response: httpx.Response) -> ExecutionResult: + + result = self._get_json_result(response) + if "errors" not in result and "data" not in result: self._raise_response_error(response, 'No "data" or "errors" keys in answer') @@ -165,6 +199,16 @@ def _prepare_result(self, response: 
httpx.Response) -> ExecutionResult: extensions=result.get("extensions"), ) + def _prepare_batch_result( + self, + reqs: List[GraphQLRequest], + response: httpx.Response, + ) -> List[ExecutionResult]: + + answers = self._get_json_result(response) + + return get_batch_execution_result_list(reqs, answers) + def _raise_response_error(self, response: httpx.Response, reason: str) -> NoReturn: # We raise a TransportServerError if the status code is 400 or higher # We raise a TransportProtocolError in the other cases @@ -223,10 +267,14 @@ def execute( # type: ignore if not self.client: raise TransportClosed("Transport is not connected") + request = GraphQLRequest( + document=document, + variable_values=variable_values, + operation_name=operation_name, + ) + post_args = self._prepare_request( - document, - variable_values, - operation_name, + request, extra_args, upload_files, ) @@ -239,6 +287,36 @@ def execute( # type: ignore return self._prepare_result(response) + def execute_batch( + self, + reqs: List[GraphQLRequest], + extra_args: Optional[Dict[str, Any]] = None, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. + + Don't call this method directly on the transport, instead use + :code:`execute_batch` on a client or a session. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. + :param extra_args: additional arguments to send to the httpx post method + :return: A list of results of execution. + For every result `data` is the result of executing the query, + `errors` is null if no errors occurred, and is a non-empty array + if an error occurred. 
+ """ + + if not self.client: + raise TransportClosed("Transport is not connected") + + post_args = self._prepare_batch_request( + reqs, + extra_args, + ) + + response = self.client.post(self.url, **post_args) + + return self._prepare_batch_result(reqs, response) + def close(self): """Closing the transport by closing the inner session""" if self.client: @@ -290,10 +368,14 @@ async def execute( if not self.client: raise TransportClosed("Transport is not connected") + request = GraphQLRequest( + document=document, + variable_values=variable_values, + operation_name=operation_name, + ) + post_args = self._prepare_request( - document, - variable_values, - operation_name, + request, extra_args, upload_files, ) @@ -306,11 +388,35 @@ async def execute( return self._prepare_result(response) - async def close(self): - """Closing the transport by closing the inner session""" - if self.client: - await self.client.aclose() - self.client = None + async def execute_batch( + self, + reqs: List[GraphQLRequest], + extra_args: Optional[Dict[str, Any]] = None, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. + + Don't call this coroutine directly on the transport, instead use + :code:`execute_batch` on a client or a session. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. + :param extra_args: additional arguments to send to the aiohttp post method + :return: A list of results of execution. + For every result `data` is the result of executing the query, + `errors` is null if no errors occurred, and is a non-empty array + if an error occurred. 
+ """ + + if not self.client: + raise TransportClosed("Transport is not connected") + + post_args = self._prepare_batch_request( + reqs, + extra_args, + ) + + response = await self.client.post(self.url, **post_args) + + return self._prepare_batch_result(reqs, response) def subscribe( self, @@ -323,3 +429,9 @@ def subscribe( :meta private: """ raise NotImplementedError("The HTTP transport does not support subscriptions") + + async def close(self): + """Closing the transport by closing the inner session""" + if self.client: + await self.client.aclose() + self.client = None diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 5fb7e827..d84ba9d3 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -25,6 +25,7 @@ from gql.transport import Transport from ..graphql_request import GraphQLRequest +from .common.batch import get_batch_execution_result_list from .exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -307,7 +308,7 @@ def raise_response_error(resp: requests.Response, reason: str) -> NoReturn: extensions=result.get("extensions"), ) - def execute_batch( # type: ignore + def execute_batch( self, reqs: List[GraphQLRequest], timeout: Optional[int] = None, @@ -340,52 +341,7 @@ def execute_batch( # type: ignore answers = self._extract_response(response) - self._validate_answer_is_a_list(answers) - self._validate_num_of_answers_same_as_requests(reqs, answers) - self._validate_every_answer_is_a_dict(answers) - self._validate_data_and_errors_keys_in_answers(answers) - - return [self._answer_to_execution_result(answer) for answer in answers] - - def _answer_to_execution_result(self, result: Dict[str, Any]) -> ExecutionResult: - return ExecutionResult( - errors=result.get("errors"), - data=result.get("data"), - extensions=result.get("extensions"), - ) - - def _validate_answer_is_a_list(self, results: Any) -> None: - if not isinstance(results, list): - self._raise_invalid_result( - str(results), - "Answer is not a list", 
- ) - - def _validate_data_and_errors_keys_in_answers( - self, results: List[Dict[str, Any]] - ) -> None: - for result in results: - if "errors" not in result and "data" not in result: - self._raise_invalid_result( - str(results), - 'No "data" or "errors" keys in answer', - ) - - def _validate_every_answer_is_a_dict(self, results: List[Dict[str, Any]]) -> None: - for result in results: - if not isinstance(result, dict): - self._raise_invalid_result(str(results), "Not every answer is dict") - - def _validate_num_of_answers_same_as_requests( - self, - reqs: List[GraphQLRequest], - results: List[Dict[str, Any]], - ) -> None: - if len(reqs) != len(results): - self._raise_invalid_result( - str(results), - "Invalid answer length", - ) + return get_batch_execution_result_list(reqs, answers) def _raise_invalid_result(self, result_text: str, reason: str) -> None: raise TransportProtocolError( @@ -427,7 +383,7 @@ def _build_batch_post_args( } data_key = "json" if self.use_json else "data" - post_args[data_key] = [self._build_data(req) for req in reqs] + post_args[data_key] = [self._build_payload(req) for req in reqs] # Log the payload if log.isEnabledFor(logging.INFO): @@ -442,7 +398,7 @@ def _build_batch_post_args( return post_args - def _build_data(self, req: GraphQLRequest) -> Dict[str, Any]: + def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: query_str = print_ast(req.document) payload: Dict[str, Any] = {"query": query_str} diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index 39f1a1cb..8b4a99f4 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -784,6 +784,32 @@ def test_code(): await run_sync_test(server, test_code) +@pytest.mark.asyncio +@pytest.mark.aiohttp +async def test_custom_scalar_serialize_variables_async_transport(aiohttp_server): + transport = await make_money_transport(aiohttp_server) + + async with Client( + schema=schema, transport=transport, parse_results=True + 
) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + results = await session.execute_batch( + [ + GraphQLRequest(document=query, variable_values=variable_values), + GraphQLRequest(document=query, variable_values=variable_values), + ], + serialize_variables=True, + ) + + print(f"result = {results!r}") + assert results[0]["toEuros"] == 5 + assert results[1]["toEuros"] == 5 + + def test_serialize_value_with_invalid_type(): with pytest.raises(GraphQLError) as exc_info: diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index fe36585e..0642e536 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -295,27 +295,28 @@ async def handler(request): { "response": "{}", "expected_exception": ( - "Server did not return a GraphQL result: " + "Server did not return a valid GraphQL result: " 'No "data" or "errors" keys in answer: {}' ), }, { "response": "qlsjfqsdlkj", "expected_exception": ( - "Server did not return a GraphQL result: Not a JSON answer: qlsjfqsdlkj" + "Server did not return a valid GraphQL result: " + "Not a JSON answer: qlsjfqsdlkj" ), }, { "response": '{"not_data_or_errors": 35}', "expected_exception": ( - "Server did not return a GraphQL result: " + "Server did not return a valid GraphQL result: " 'No "data" or "errors" keys in answer: {"not_data_or_errors": 35}' ), }, { "response": "", "expected_exception": ( - "Server did not return a GraphQL result: Not a JSON answer: " + "Server did not return a valid GraphQL result: Not a JSON answer: " ), }, ] diff --git a/tests/test_aiohttp_batch.py b/tests/test_aiohttp_batch.py new file mode 100644 index 00000000..f04f05e4 --- /dev/null +++ b/tests/test_aiohttp_batch.py @@ -0,0 +1,335 @@ +from typing import Mapping + +import pytest + +from gql import Client, GraphQLRequest, gql +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, +) + +# Marking all tests in 
this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}]' +) + + +@pytest.mark.asyncio +async def test_aiohttp_batch_query(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + # Execute query asynchronously + results = await session.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + +@pytest.mark.asyncio +async def test_aiohttp_batch_query_without_session(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = 
server.make_url("/") + + def test_code(): + transport = AIOHTTPTransport(url=url, timeout=10) + + client = Client(transport=transport) + + query = [GraphQLRequest(document=gql(query1_str))] + + results = client.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(server, test_code) + + +query1_server_error_answer_list = '[{"errors": ["Error 1", "Error 2"]}]' + + +@pytest.mark.asyncio +async def test_aiohttp_batch_error_code(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_error_answer_list, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportQueryError): + await session.execute_batch(query) + + +invalid_protocol_responses = [ + "{}", + "qlsjfqsdlkj", + '{"not_data_or_errors": 35}', + "[{}]", + "[qlsjfqsdlkj]", + '[{"not_data_or_errors": 35}]', + "[]", + "[1]", +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("response", invalid_protocol_responses) +async def test_aiohttp_batch_invalid_protocol(aiohttp_server, response): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=response, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await 
aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportProtocolError): + await session.execute_batch(query) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_aiohttp_batch_cannot_execute_if_not_connected( + aiohttp_server, run_sync_test +): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportClosed): + await transport.execute_batch(query) + + +@pytest.mark.asyncio +async def test_aiohttp_batch_extra_args(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + # passing extra arguments to aiohttp.ClientSession + from aiohttp import DummyCookieJar + + jar = DummyCookieJar() + transport = AIOHTTPTransport( + url=url, timeout=10, client_session_args={"version": "1.1", "cookie_jar": jar} + ) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + # Passing extra arguments to the post method of aiohttp + results = await session.execute_batch( + query, extra_args={"allow_redirects": False} + ) + + result = 
results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +query1_server_answer_with_extensions_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]},' + '"extensions": {"key1": "val1"}' + "}]" +) + + +@pytest.mark.asyncio +async def test_aiohttp_batch_query_with_extensions(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_with_extensions_list, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url) + + query = [GraphQLRequest(document=gql(query1_str))] + + async with Client(transport=transport) as session: + + execution_results = await session.execute_batch( + query, get_execution_result=True + ) + + assert execution_results[0].extensions["key1"] == "val1" + + +ONLINE_URL = "https://countries.trevorblades.workers.dev/graphql" + + +@pytest.mark.online +@pytest.mark.asyncio +async def test_aiohttp_batch_online_manual(): + + from gql.transport.aiohttp import AIOHTTPTransport + + client = Client( + transport=AIOHTTPTransport(url=ONLINE_URL, timeout=10), + ) + + query = gql( + """ + query getContinentName($continent_code: ID!) 
{ + continent(code: $continent_code) { + name + } + } + """ + ) + + async with client as session: + + request_eu = GraphQLRequest(query, variable_values={"continent_code": "EU"}) + request_af = GraphQLRequest(query, variable_values={"continent_code": "AF"}) + + result_eu, result_af = await session.execute_batch([request_eu, request_af]) + + assert result_eu["continent"]["name"] == "Europe" + assert result_af["continent"]["name"] == "Africa" diff --git a/tests/test_client.py b/tests/test_client.py index 8669b4a3..55993a9e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -54,19 +54,6 @@ def execute( ) -@pytest.mark.aiohttp -def test_request_async_execute_batch_not_implemented_yet(): - from gql.transport.aiohttp import AIOHTTPTransport - - transport = AIOHTTPTransport(url="http://localhost/") - client = Client(transport=transport) - - with pytest.raises(NotImplementedError) as exc_info: - client.execute_batch([GraphQLRequest(document=gql("{dummy}"))]) - - assert "Batching is not implemented for async yet." 
== str(exc_info.value) - - @pytest.mark.requests @mock.patch("urllib3.connection.HTTPConnection._new_conn") def test_retries_on_transport(execute_mock): diff --git a/tests/test_httpx.py b/tests/test_httpx.py index 9558e137..0991355a 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -387,6 +387,7 @@ def test_code(): "{}", "qlsjfqsdlkj", '{"not_data_or_errors": 35}', + "", ] diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index ddacbc14..87f1675a 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -457,7 +457,7 @@ async def handler(request): query = gql(query1_str) - # Passing extra arguments to the post method of aiohttp + # Passing extra arguments to the post method result = await session.execute(query, extra_args={"follow_redirects": True}) continents = result["continents"] diff --git a/tests/test_httpx_batch.py b/tests/test_httpx_batch.py new file mode 100644 index 00000000..9e5b9b93 --- /dev/null +++ b/tests/test_httpx_batch.py @@ -0,0 +1,440 @@ +from typing import Mapping + +import pytest + +from gql import Client, GraphQLRequest, gql +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, +) + +# Marking all tests in this file with the httpx marker +pytestmark = pytest.mark.httpx + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}]' +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_query(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + 
content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + # Execute query asynchronously + results = await session.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_sync_batch_query(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXTransport(url=url, timeout=10) + + def test_code(): + with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + results = session.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def 
test_httpx_async_batch_query_without_session(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXAsyncTransport(url=url, timeout=10) + + client = Client(transport=transport) + + query = [GraphQLRequest(document=gql(query1_str))] + + results = client.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(server, test_code) + + +query1_server_error_answer_list = '[{"errors": ["Error 1", "Error 2"]}]' + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_error_code(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_error_answer_list, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportQueryError): + await session.execute_batch(query) + + +invalid_protocol_responses = [ + "{}", + "qlsjfqsdlkj", + '{"not_data_or_errors": 35}', + "[{}]", + "[qlsjfqsdlkj]", + 
'[{"not_data_or_errors": 35}]', + "[]", + "[1]", +] + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("response", invalid_protocol_responses) +async def test_httpx_async_batch_invalid_protocol(aiohttp_server, response): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response(text=response, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportProtocolError): + await session.execute_batch(query) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_cannot_execute_if_not_connected(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportClosed): + await transport.execute_batch(query) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_sync_batch_cannot_execute_if_not_connected(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await 
aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXTransport(url=url, timeout=10) + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportClosed): + transport.execute_batch(query) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_extra_args(aiohttp_server): + import httpx + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + # passing extra arguments to httpx.AsyncClient + inner_transport = httpx.AsyncHTTPTransport(retries=2) + transport = HTTPXAsyncTransport(url=url, max_redirects=2, transport=inner_transport) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + # Passing extra arguments to the post method + results = await session.execute_batch( + query, extra_args={"follow_redirects": True} + ) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +query1_server_answer_with_extensions_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]},' + '"extensions": {"key1": "val1"}' + "}]" +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_query_with_extensions(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + 
text=query1_server_answer_with_extensions_list, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url) + + query = [GraphQLRequest(document=gql(query1_str))] + + async with Client(transport=transport) as session: + + execution_results = await session.execute_batch( + query, get_execution_result=True + ) + + assert execution_results[0].extensions["key1"] == "val1" + + +ONLINE_URL = "https://countries.trevorblades.workers.dev/graphql" + + +@pytest.mark.online +@pytest.mark.asyncio +async def test_httpx_batch_online_async_manual(): + + from gql.transport.httpx import HTTPXAsyncTransport + + client = Client( + transport=HTTPXAsyncTransport(url=ONLINE_URL), + ) + + query = gql( + """ + query getContinentName($continent_code: ID!) { + continent(code: $continent_code) { + name + } + } + """ + ) + + async with client as session: + + request_eu = GraphQLRequest(query, variable_values={"continent_code": "EU"}) + request_af = GraphQLRequest(query, variable_values={"continent_code": "AF"}) + + result_eu, result_af = await session.execute_batch([request_eu, request_af]) + + assert result_eu["continent"]["name"] == "Europe" + assert result_af["continent"]["name"] == "Africa" + + +@pytest.mark.online +@pytest.mark.asyncio +async def test_httpx_batch_online_sync_manual(): + + from gql.transport.httpx import HTTPXTransport + + client = Client( + transport=HTTPXTransport(url=ONLINE_URL), + ) + + query = gql( + """ + query getContinentName($continent_code: ID!) 
{ + continent(code: $continent_code) { + name + } + } + """ + ) + + with client as session: + + request_eu = GraphQLRequest(query, variable_values={"continent_code": "EU"}) + request_af = GraphQLRequest(query, variable_values={"continent_code": "AF"}) + + result_eu, result_af = session.execute_batch([request_eu, request_af]) + + assert result_eu["continent"]["name"] == "Europe" + assert result_af["continent"]["name"] == "Africa" diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 4b9e09b8..38850d56 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -545,14 +545,11 @@ def test_code(): await run_sync_test(server, test_code) -ONLINE_URL = "https://countries.trevorblades.com/" - -skip_reason = "backend does not support batching anymore..." +ONLINE_URL = "https://countries.trevorblades.workers.dev/graphql" @pytest.mark.online @pytest.mark.requests -@pytest.mark.skip(reason=skip_reason) def test_requests_sync_batch_auto(): from threading import Thread @@ -619,7 +616,6 @@ def get_continent_name(session, continent_code): @pytest.mark.online @pytest.mark.requests -@pytest.mark.skip(reason=skip_reason) def test_requests_sync_batch_auto_execute_future(): from gql.transport.requests import RequestsHTTPTransport @@ -657,7 +653,6 @@ def test_requests_sync_batch_auto_execute_future(): @pytest.mark.online @pytest.mark.requests -@pytest.mark.skip(reason=skip_reason) def test_requests_sync_batch_manual(): from gql.transport.requests import RequestsHTTPTransport